# Over-representation analysis (ORA) and gene set enrichment analysis (GSEA)

## 1. Import

In [None]:
import decoupler as dc
import os
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pickle


# Seed NumPy's global random number generator so NumPy-based randomness in
# this notebook is reproducible.
# NOTE(review): this may not control randomness inside decoupler (e.g. GSEA
# permutations) if decoupler uses its own seed argument — confirm.
np.random.seed(14)

## 2. Load databases

### 2.1 MSigDB (GO terms)

In [None]:
# Load the human MSigDB gene-set resource through decoupler's `dc.op` helpers.
# The result is cached to CSV two cells below, so this presumably slow fetch
# only needs to run once.
msigdb = dc.op.resource('MSigDB', organism='human')

In [None]:
# Restrict to Gene Ontology collections (collection name starts with 'go'),
# then keep only the first row of each (gene set, gene symbol) pair.
msigdb = msigdb.loc[lambda frame: frame['collection'].str.startswith('go')]
msigdb = msigdb.drop_duplicates(subset=['geneset', 'genesymbol'], keep='first')
msigdb

In [None]:
# Cache the filtered GO gene sets on disk so the next cell (and later runs)
# can reload them without fetching the resource again.
# `DataFrame.to_csv` does not create missing parent folders, so make sure
# 'dataset' exists first.
os.makedirs('dataset', exist_ok=True)
msigdb.to_csv('dataset/msigdb_GOterms.csv', index=False)

In [None]:
# Reload the cached GO gene sets written by the previous cell. Starting the
# notebook from here avoids re-downloading the resource.
# Note: the reloaded frame has a fresh RangeIndex, not the filtered index.
msigdb = pd.read_csv('dataset/msigdb_GOterms.csv')

### 2.2 Hallmark

In [None]:
# Load the human MSigDB Hallmark gene sets; used as the network for GSEA in
# section 6.
hallmark = dc.op.hallmark(organism='human')
hallmark

## 3. Load consensus_top_genes_df

In [None]:
# Load the corrected count matrix ('corrected_counts.h5ad') from the cNMF
# output directory.
output_directory = 'cNMF_w_filtered_genes'
corrected_count_adat_fn = os.path.join(output_directory, 'corrected_counts.h5ad')
adata = sc.read_h5ad(corrected_count_adat_fn)
# Distinct values of the 'dataset' column in adata.obs (one per cohort).
# NOTE(review): `cohorts` is not used in this part of the notebook; if it is
# unused elsewhere too, reading the whole AnnData just for it is avoidable.
cohorts = adata.obs['dataset'].unique()

In [None]:
# Results sub-folder for this cNMF run; presumably the run on 2000 highly
# variable genes, judging by the folder name — confirm.
directory = os.path.join(output_directory, "2000hvg")

In [None]:
# Table of top genes per consensus program: one column per program, gene
# symbols in the cells. Presumably ranked by cosine similarity, as the file
# name suggests — confirm.
consensus_top_genes_df = pd.read_csv(f"{directory}/cos_similarity_consensus_top_genes_df.csv")

In [None]:
# The CSV stored its index as the first, unnamed column: promote it back to
# the index and give it a name.
# NOTE(review): the cells hold gene symbols and `.head(20)` later keeps the
# first rows, so this index looks like a rank rather than a gene id — confirm
# that "gene_id" is the intended name.
consensus_top_genes_df = consensus_top_genes_df.set_index('Unnamed: 0').rename_axis("gene_id")

In [None]:
# Keep an untouched copy of the full table, then restrict the working table to
# its first 20 rows (presumably the top 20 genes per program, assuming rows
# are in rank order).
# NOTE(review): re-running this cell after the truncation overwrites
# `consensus_top_genes_df_og` with the already-truncated 20-row table.
consensus_top_genes_df_og = consensus_top_genes_df.copy()
consensus_top_genes_df = consensus_top_genes_df.head(20)
consensus_top_genes_df

In [None]:
# Binary (0/1) membership matrix: one row per gene that appears anywhere in
# the truncated top-gene table, one column per consensus program.
all_genes = pd.unique(consensus_top_genes_df.values.ravel())
bool_df = pd.DataFrame(0, index=all_genes, columns=consensus_top_genes_df.columns)

# Set a 1 where the gene is among the rows kept for that consensus program,
# i.e. its top 20 genes after `.head(20)` above (not the top 100).
for consensus in consensus_top_genes_df.columns:
    genes_in_consensus = consensus_top_genes_df[consensus].values
    bool_df.loc[genes_in_consensus, consensus] = 1

## 4. ORA

### 4.1 MSigDB GO terms

In [None]:
# decoupler network: 'source' = gene-set name, 'target' = gene symbol,
# built from the GO subset of MSigDB.
net = msigdb[['geneset', 'genesymbol']].rename(
    columns={'geneset': 'source', 'genesymbol': 'target'}
)

In [None]:
# Keep only the genes that belong to at least one GO gene set; other genes
# cannot overlap any set in the ORA.
genes_in_net = set(net['target'])
filtered_bool_df = bool_df.loc[bool_df.index.intersection(genes_in_net)]

In [None]:
# Number of top genes per consensus program that survive the network filter.
# Inspect these counts before running ORA; programs with no remaining genes
# have nothing to test.
filtered_bool_df.sum()

In [None]:
# Exclude consensus_11 from the analysis; it is also removed from the GSEA
# input in section 5. errors='ignore' keeps the cell safe to re-run.
filtered_bool_df = filtered_bool_df.drop(columns=['consensus_11'], errors='ignore')

In [None]:
# Reorder the program columns by their numeric suffix, so that consensus_2
# comes before consensus_10.
filtered_bool_df = filtered_bool_df.loc[
    :, sorted(filtered_bool_df.columns, key=lambda name: int(name.split("_")[1]))
]

In [None]:
output_dir = "ora_MSigDB_barplots"
os.makedirs(output_dir, exist_ok=True)

dict_ora_results = {}  # consensus name -> tuple returned by dc.mt.ora (scores first)
for consensus in filtered_bool_df.columns:
    # One-row matrix (row = consensus program) of 0/1 gene membership.
    membership = pd.DataFrame(filtered_bool_df[consensus]).T
    # Query set = exactly the genes flagged with 1. Without `n_up`, decoupler's
    # ORA keeps only a top fraction of the features (5% per the decoupler docs —
    # verify for the installed version). On a 0/1 row that would silently drop
    # part of the selected genes (ties broken arbitrarily).
    n_selected = int(membership.values.sum())
    if n_selected == 0:
        print(f"Skipping {consensus}: none of its top genes are in the network.")
        continue
    ora_results = dc.mt.ora(
        data=membership,
        net=net,
        tmin=1,
        n_up=n_selected,
    )
    dict_ora_results[consensus] = ora_results

    enrichment_scores = ora_results[0]

    fig, ax = plt.subplots(figsize=(14, 5))
    dc.pl.barplot(data=enrichment_scores, name=consensus, ax=ax)
    ax.set_title(consensus)
    fig.tight_layout()
    output_path = os.path.join(output_dir, f"{consensus}.png")
    fig.savefig(output_path, dpi=300)

    plt.show()
    # Close explicitly so figures don't pile up when plt.show() does not close
    # them (e.g. non-inline backends).
    plt.close(fig)

### 4.2 MSigDB GO terms, separated by ontology (CC, BP, MF)

In [None]:
# Rebuild the decoupler network ('source' = gene-set name, 'target' = gene
# symbol) from the GO subset of MSigDB; same result as in section 4.1.
net = msigdb[['geneset', 'genesymbol']].rename(
    columns={'geneset': 'source', 'genesymbol': 'target'}
)

In [None]:
# Split the GO network by ontology, using the gene-set name prefix:
# GOCC = cellular component, GOBP = biological process, GOMF = molecular function.
net_CC = net[net['source'].str.startswith('GOCC')].copy()
net_BP = net[net['source'].str.startswith('GOBP')].copy()
net_MF = net[net['source'].str.startswith('GOMF')].copy()

#### 4.2.1 GO cellular component

In [None]:
# Keep only the genes annotated in at least one GO cellular-component set.
genes_in_net_CC = set(net_CC['target'])
filtered_bool_df_CC = bool_df.loc[bool_df.index.intersection(genes_in_net_CC)]

In [None]:
# Number of top genes per consensus program left after the CC network filter.
filtered_bool_df_CC.sum()

In [None]:
# Exclude consensus_11; errors='ignore' keeps the cell safe to re-run.
filtered_bool_df_CC = filtered_bool_df_CC.drop(columns=['consensus_11'], errors='ignore')

In [None]:
# Reorder the program columns by their numeric suffix (consensus_2 before consensus_10).
filtered_bool_df_CC = filtered_bool_df_CC.loc[
    :, sorted(filtered_bool_df_CC.columns, key=lambda name: int(name.split("_")[1]))
]

In [None]:
output_dir = "ora_MSigDB_CC_barplots"
os.makedirs(output_dir, exist_ok=True)

dict_ora_results_CC = {}  # consensus name -> tuple returned by dc.mt.ora (scores first)
for consensus in filtered_bool_df_CC.columns:
    # One-row matrix (row = consensus program) of 0/1 gene membership.
    membership_CC = pd.DataFrame(filtered_bool_df_CC[consensus]).T
    # Query set = exactly the genes flagged with 1. Without `n_up`, decoupler's
    # ORA keeps only a top fraction of the features (5% per the decoupler docs —
    # verify for the installed version), dropping part of the selected genes.
    n_selected_CC = int(membership_CC.values.sum())
    if n_selected_CC == 0:
        print(f"Skipping {consensus}: none of its top genes are in the CC network.")
        continue
    ora_results_CC = dc.mt.ora(
        data=membership_CC,
        net=net_CC,
        tmin=1,
        n_up=n_selected_CC,
    )
    dict_ora_results_CC[consensus] = ora_results_CC

    enrichment_scores_CC = ora_results_CC[0]

    fig, ax = plt.subplots(figsize=(14, 5))
    dc.pl.barplot(data=enrichment_scores_CC, name=consensus, ax=ax)
    ax.set_title(consensus)
    fig.tight_layout()
    output_path = os.path.join(output_dir, f"{consensus}.png")
    fig.savefig(output_path, dpi=300)

    plt.show()
    # Close explicitly so figures don't pile up on non-inline backends.
    plt.close(fig)

#### 4.2.2 GO biological process

In [None]:
# Keep only the genes annotated in at least one GO biological-process set.
genes_in_net_BP = set(net_BP['target'])
filtered_bool_df_BP = bool_df.loc[bool_df.index.intersection(genes_in_net_BP)]

In [None]:
# Number of top genes per consensus program left after the BP network filter.
filtered_bool_df_BP.sum()

In [None]:
# Exclude consensus_11; errors='ignore' keeps the cell safe to re-run.
filtered_bool_df_BP = filtered_bool_df_BP.drop(columns=['consensus_11'], errors='ignore')

In [None]:
# Reorder the program columns by their numeric suffix (consensus_2 before consensus_10).
filtered_bool_df_BP = filtered_bool_df_BP.loc[
    :, sorted(filtered_bool_df_BP.columns, key=lambda name: int(name.split("_")[1]))
]

In [None]:
output_dir = "ora_MSigDB_BP_barplots"
os.makedirs(output_dir, exist_ok=True)

dict_ora_results_BP = {}  # consensus name -> tuple returned by dc.mt.ora (scores first)
for consensus in filtered_bool_df_BP.columns:
    # One-row matrix (row = consensus program) of 0/1 gene membership.
    membership_BP = pd.DataFrame(filtered_bool_df_BP[consensus]).T
    # Query set = exactly the genes flagged with 1. Without `n_up`, decoupler's
    # ORA keeps only a top fraction of the features (5% per the decoupler docs —
    # verify for the installed version), dropping part of the selected genes.
    n_selected_BP = int(membership_BP.values.sum())
    if n_selected_BP == 0:
        print(f"Skipping {consensus}: none of its top genes are in the BP network.")
        continue
    ora_results_BP = dc.mt.ora(
        data=membership_BP,
        net=net_BP,
        tmin=1,
        n_up=n_selected_BP,
    )
    dict_ora_results_BP[consensus] = ora_results_BP

    enrichment_scores_BP = ora_results_BP[0]

    fig, ax = plt.subplots(figsize=(14, 5))
    dc.pl.barplot(data=enrichment_scores_BP, name=consensus, ax=ax)
    ax.set_title(consensus)
    fig.tight_layout()
    output_path = os.path.join(output_dir, f"{consensus}.png")
    fig.savefig(output_path, dpi=300)

    plt.show()
    # Close explicitly so figures don't pile up on non-inline backends.
    plt.close(fig)

#### 4.2.3 GO molecular function

In [None]:
# Keep only the genes annotated in at least one GO molecular-function set.
genes_in_net_MF = set(net_MF['target'])
filtered_bool_df_MF = bool_df.loc[bool_df.index.intersection(genes_in_net_MF)]

In [None]:
# Number of top genes per consensus program left after the MF network filter.
filtered_bool_df_MF.sum()

In [None]:
# Exclude consensus_11; errors='ignore' keeps the cell safe to re-run.
filtered_bool_df_MF = filtered_bool_df_MF.drop(columns=['consensus_11'], errors='ignore')

In [None]:
# Reorder the program columns by their numeric suffix (consensus_2 before consensus_10).
filtered_bool_df_MF = filtered_bool_df_MF.loc[
    :, sorted(filtered_bool_df_MF.columns, key=lambda name: int(name.split("_")[1]))
]

In [None]:
output_dir = "ora_MSigDB_MF_barplots"
os.makedirs(output_dir, exist_ok=True)

dict_ora_results_MF = {}  # consensus name -> tuple returned by dc.mt.ora (scores first)
for consensus in filtered_bool_df_MF.columns:
    # One-row matrix (row = consensus program) of 0/1 gene membership.
    membership_MF = pd.DataFrame(filtered_bool_df_MF[consensus]).T
    # Query set = exactly the genes flagged with 1. Without `n_up`, decoupler's
    # ORA keeps only a top fraction of the features (5% per the decoupler docs —
    # verify for the installed version), dropping part of the selected genes.
    n_selected_MF = int(membership_MF.values.sum())
    if n_selected_MF == 0:
        print(f"Skipping {consensus}: none of its top genes are in the MF network.")
        continue
    ora_results_MF = dc.mt.ora(
        data=membership_MF,
        net=net_MF,
        tmin=1,
        n_up=n_selected_MF,
    )
    dict_ora_results_MF[consensus] = ora_results_MF

    enrichment_scores_MF = ora_results_MF[0]

    fig, ax = plt.subplots(figsize=(14, 5))
    dc.pl.barplot(data=enrichment_scores_MF, name=consensus, ax=ax)
    ax.set_title(consensus)
    fig.tight_layout()
    output_path = os.path.join(output_dir, f"{consensus}.png")
    fig.savefig(output_path, dpi=300)

    plt.show()
    # Close explicitly so figures don't pile up on non-inline backends.
    plt.close(fig)

## 5. Load the consensus programs dictionary with spectra scores

In [None]:
# Load the pickled dict of consensus programs, keyed 'consensus_<N>'.
# Each value is turned into a one-row DataFrame for GSEA below; presumably
# per-gene spectra scores indexed by gene symbol — confirm.
# NOTE: pickle.load can execute arbitrary code; only load files this pipeline wrote.
with open(f"{directory}/consensus_programs_symbols.pkl", "rb") as f:
    consensus_programs_symbols = pickle.load(f)

In [None]:
# Exclude consensus_11 from GSEA as well, matching the ORA sections; the
# None default makes this a no-op if it was already removed.
consensus_programs_symbols.pop('consensus_11', None)

In [None]:
# Program names ordered by numeric suffix (consensus_2 before consensus_10).
sorted_keys = sorted(consensus_programs_symbols, key=lambda key: int(key.split('_')[1]))

## 6. GSEA

### 6.1 Hallmark gene set

In [None]:
output_dir = "gsea_Hallmark_barplots"
os.makedirs(output_dir, exist_ok=True)

dict_gsea_results_hallmark = {}  # consensus name -> tuple returned by dc.mt.gsea (scores first)
for consensus in sorted_keys:
    # One-row matrix holding this program's values, labelled with the program name.
    df = pd.DataFrame(consensus_programs_symbols[consensus]).T
    df.index = [consensus]
    gsea_results_hallmark = dc.mt.gsea(
        data=df,
        net=hallmark,
        tmin=1,
    )
    dict_gsea_results_hallmark[consensus] = gsea_results_hallmark

    enrichment_scores_hallmark = gsea_results_hallmark[0]

    fig, ax = plt.subplots(figsize=(10, 4))
    dc.pl.barplot(data=enrichment_scores_hallmark, name=consensus, ax=ax)
    ax.set_title(consensus)
    fig.tight_layout()
    output_path = os.path.join(output_dir, f"{consensus}.png")
    fig.savefig(output_path, dpi=300)  # dpi=300 for high-quality output

    plt.show()
    # Close explicitly so figures don't pile up on non-inline backends.
    plt.close(fig)