# Run GSEA on DE output

In [None]:
import os
import pandas as pd
import numpy as np
import glob
from pybiomart import Dataset
import anndata as ad
import matplotlib.pyplot as plt

import gseapy as gp
from gseapy import Msigdb


In [None]:
# Number of worker threads handed to the gseapy calls below.
NUM_CPUS = 8

# Dataset directory; switch datasets by changing which assignment is active.
# DATA_PATH = os.getcwd()
# DATA_PATH = '/data/expression_atlas/v1/GSE162828/'
# DATA_PATH = '/data/expression_atlas/v1/GSE122459/'
# DATA_PATH = '/data/expression_atlas/v1/GSE110914/'
DATA_PATH = '/data/expression_atlas/v1/GSE80183/'

# Last component of DATA_PATH (trailing slash ignored); every results file
# name is prefixed with it.
_dataset_id = DATA_PATH.rstrip('/').split('/')[-1]

# Common prefix of the per-dataset result files.
RESULTS_PATH = f'results/{_dataset_id}'

# AnnData file read as input, and the file written once GSEA results are added.
DDS_GENE_FH = f'{RESULTS_PATH}_dds_gene.h5_ad'
DDS_GENE_FH_OUT = f'{RESULTS_PATH}_dds_gene_gsea.h5_ad'

In [None]:
# Collect the paths of all pydeseq2 gene-level output files: CSVs whose name
# starts with RESULTS_PATH followed by '_gene'.

gene_csv_pattern = RESULTS_PATH + '_gene*.csv'
de_gene_files = glob.glob(gene_csv_pattern)
de_gene_files

In [None]:
# Read each DE results file into a DataFrame keyed by its file path; the first
# CSV column becomes the index.

de_gene_dfs = {}
for csv_path in de_gene_files:
    de_gene_dfs[csv_path] = pd.read_csv(csv_path, index_col=0)

In [None]:
# Keep only the rows whose index (gene id) starts with the Ensembl gene prefix.

gene_prefix = 'ENSG'

for csv_path, de_table in de_gene_dfs.items():
    is_ensembl_gene = de_table.index.str.startswith(gene_prefix)
    de_gene_dfs[csv_path] = de_table[is_ensembl_gene]


In [None]:
# Query biomart for the Ensembl gene id -> external gene name table and give
# its two columns short names ('gene_id', 'gene_name').

dataset = Dataset(name='hsapiens_gene_ensembl', host='http://www.ensembl.org')

biomart_attributes = ['ensembl_gene_id', 'external_gene_name']

external_gene_mapping = dataset.query(attributes=biomart_attributes).rename(
    columns={'Gene stable ID': 'gene_id', 'Gene name': 'gene_name'},
)

external_gene_mapping

In [None]:
# Strip the version suffix (everything after the first '.') from the Ensembl
# gene ids, then attach the external gene name from the biomart mapping.
# Tables that already carry a 'gene_name' column are left untouched.

for csv_path, de_table in de_gene_dfs.items():
    if 'gene_name' in de_table.columns:
        continue

    de_table.index = de_table.index.str.split('.').str[0]

    # Default (inner) merge: ids missing from the mapping are dropped.
    annotated = de_table.merge(
        external_gene_mapping,
        left_index=True,
        right_on='gene_id',
    )
    de_gene_dfs[csv_path] = annotated.set_index('gene_id')

In [None]:
# Discard rows that have no gene_name.

for csv_path, de_table in de_gene_dfs.items():
    has_gene_name = de_table['gene_name'].notnull()
    de_gene_dfs[csv_path] = de_table[has_gene_name]

In [None]:
# Build one rank table per DE result: gene_name plus the Wald test statistic
# ('stat'), sorted by ascending statistic.

de_gene_ranks = {}
for csv_path, de_table in de_gene_dfs.items():
    rank_table = de_table.reset_index()[['gene_name', 'stat']]
    rank_table = rank_table.sort_values(by='stat')

    # Slot 0 holds the rank table; slot 1 is a placeholder for the prerank
    # result stored by the GSEA cell further down.
    de_gene_ranks[csv_path] = [rank_table, None]

In [None]:
# Download the MSigDB gene sets used by every GSEA run below.

msigdb_version = '2023.1.Hs'

msig = Msigdb()
msig.list_category(dbver=msigdb_version)
gmt = msig.get_gmt(category='h.all', dbver=msigdb_version)
# gmt = msig.get_gmt(category='c2.all', dbver='2023.1.Hs')
len(gmt), gmt


In [None]:
# Run GSEA prerank on every rank table against the chosen gene sets, store the
# result object next to its rank table, print the first rows of the result
# table and plot the first 20 terms.

for k, rank_entry in de_gene_ranks.items():
    gs_res = gp.prerank(
        rnk=rank_entry[0],
        # Alternative gene-set libraries that can be swapped in:
        # gene_sets='KEGG_2016',
        # gene_sets='GO_Biological_Process_2013',
        # gene_sets='ENCODE_and_ChEA_Consensus_TFs_from_ChIP-X',
        # gene_sets='Reactome_2016',
        # gene_sets='MSigDB_Computational',
        gene_sets=gmt,
        threads=NUM_CPUS,
        min_size=5,
        max_size=1000,
        # Previously misspelled as 'permuation_num', which is not a prerank
        # parameter, so the permutation count was never actually passed.
        permutation_num=1000,
        outdir=None,  # do not write output to disk
        seed=42,
        verbose=True,
    )

    # Slot 1 of each entry holds the prerank result for that rank table.
    rank_entry[1] = gs_res

    print(k)
    print(gs_res.res2d.head(10).to_string())

    # Hand the plot a plain list of term names rather than a Series slice.
    top_terms = gs_res.res2d.Term.iloc[:20].tolist()
    fig = gs_res.plot(terms=top_terms)
    fig.suptitle(k, y=0.0)


In [None]:
# Read the AnnData object whose 'normed_counts' layer feeds the traditional
# (expression-matrix based) GSEA below, then show the matrix shape and the
# stored contrasts.

dds_gene = ad.read_h5ad(DDS_GENE_FH)

contrasts = dds_gene.uns['contrasts']
dds_gene.X.shape, contrasts

In [None]:
# Build the genes x samples expression table used for GSEA on normed_counts.
# Rows come from dds_gene.var, columns from dds_gene.obs.

gene_df = pd.DataFrame(
    dds_gene.layers['normed_counts'].T,
    index=dds_gene.var.index,
    columns=dds_gene.obs.index,
)

# Filter dataframes by gene_id, only keep ensembl gene ids.
gene_prefix = 'ENSG'
gene_df = gene_df[gene_df.index.str.startswith(gene_prefix)]

# Always map to gene names, as the DE tables above do. Previously the mapping
# only ran when some id contained a '.', so unversioned Ensembl ids were left
# as the index. Splitting on '.' is a no-op for ids without a version suffix.
# NOTE(review): assumes the gene sets in `gmt` are keyed by gene name - confirm.
gene_df.index = gene_df.index.str.split('.').str[0]

# Default (inner) merge: ids missing from the biomart mapping are dropped. The
# merged frame keeps a 'gene_id' column alongside the sample columns.
gene_df = gene_df.merge(external_gene_mapping, left_index=True, right_on='gene_id')

gene_df = gene_df[gene_df['gene_name'].notnull()]

gene_df = gene_df.set_index('gene_name')

gene_df.shape

In [None]:
# Run GSEA on normed_counts for each contrast specified in the dds object.
# Each result table is stored (as strings) in dds_gene.uns under 'gsea_<name>'.

for k, v in dds_gene.uns['contrasts'].items():

    # v[0] is the obs column; v[1] and v[2] are the two levels being compared.
    factor, level_a, level_b = v[0], v[1], v[2]

    samples_in_contrast = dds_gene.obs[dds_gene.obs[factor].isin([level_a, level_b])].index

    conditions = dds_gene.obs.loc[samples_in_contrast, factor]

    # Size of the smaller of the two compared groups. The previous
    # value_counts()[-1] relied on a positional lookup on a label-indexed
    # Series and counted every level in the column, so an unrelated small
    # level could cause a valid contrast to be skipped.
    smallest_condition_size = min(
        int((conditions == level).sum()) for level in (level_a, level_b)
    )

    if smallest_condition_size < 3:
        continue

    gs_res = gp.gsea(
        data=gene_df[samples_in_contrast],
        gene_sets=gmt,
        # Class labels in the same order as the selected sample columns.
        cls=conditions.tolist(),
        # Phenotype permutation only with more than 15 samples in the contrast;
        # otherwise permute gene sets.
        permutation_type='phenotype' if len(conditions) > 15 else 'gene_set',
        permutation_num=1000,
        outdir=None,  # do not write output to disk
        method='signal_to_noise',
        min_size=5,
        max_size=1000,
        threads=NUM_CPUS,
        seed=42,
        verbose=True,
    )

    dds_gene.uns['gsea_%s' % k] = gs_res.res2d.astype(str)

    print(k)
    print(gs_res.res2d.head(10).to_string())

    top_terms = gs_res.res2d.Term.iloc[:20].tolist()
    fig = gs_res.plot(terms=top_terms)
    fig.suptitle(k, y=0.0)

In [None]:
# Run every sample through ssgsea and store two samples x gene-sets matrices in
# dds_gene.obsm: 'ssgsea_es' (ES) and 'ssgsea_nes' (NES).

gene_set_names = list(gmt.keys())

# Per-sample score vectors, keyed by sample id.
es_by_sample = {}
nes_by_sample = {}

for i in dds_gene.obs.index:

    # rename(None) returns a copy with the name cleared. The previous
    # rename(None, inplace=True) is documented to return None, so passing its
    # result as `data` relied on undocumented behaviour.
    sample_expression = gene_df.loc[:, i].rename(None)

    ss = gp.ssgsea(
        data=sample_expression,
        gene_sets=gmt,
        outdir=None,
        sample_norm_method='rank',
        no_plot=True,
        verbose=True,
        min_size=5,
        max_size=1000,
    )

    # Index a copy by term instead of mutating the result object's table.
    sample_scores = ss.res2d.set_index('Term')

    # reindex leaves NaN for gene sets absent from the result table (presumably
    # dropped by the min_size/max_size filter) instead of raising KeyError.
    es_by_sample[i] = pd.to_numeric(sample_scores['ES']).reindex(gene_set_names)
    nes_by_sample[i] = pd.to_numeric(sample_scores['NES']).reindex(gene_set_names)

# Assemble each matrix once, with rows in obs order and columns in gmt order.
dds_gene.obsm['ssgsea_es'] = pd.DataFrame(es_by_sample).T.reindex(
    index=dds_gene.obs.index,
    columns=gene_set_names,
)

dds_gene.obsm['ssgsea_nes'] = pd.DataFrame(nes_by_sample).T.reindex(
    index=dds_gene.obs.index,
    columns=gene_set_names,
)

In [None]:
# Save output of gsea analyses: write the AnnData object, now carrying the
# 'gsea_*' entries in .uns and the ssgsea matrices in .obsm, to DDS_GENE_FH_OUT.

dds_gene.write(DDS_GENE_FH_OUT)

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [None]:
# Standardise the ssgsea ES matrix column by column before PCA.
scaler = StandardScaler()
scaler.fit(dds_gene.obsm['ssgsea_es'])
scaled_data = scaler.transform(dds_gene.obsm['ssgsea_es'])

In [None]:
# Display the standardised ssgsea ES matrix.
scaled_data

In [None]:
# PCA of the standardised ssgsea scores, keeping as many components as the
# data allows. The component count cannot exceed min(n_samples, n_features);
# the previous n_components=n_samples fails once samples outnumber gene sets.
pca = PCA(n_components=min(scaled_data.shape))

pca_out = pca.fit_transform(scaled_data)

In [None]:
# Display the shape of the PCA-transformed matrix.
pca_out.shape

In [None]:
# Plot the explained-variance ratio of each principal component, in component order.
plt.plot(pca.explained_variance_ratio_)

In [None]:
# Scatter the samples on the first two principal components. Samples whose
# 'condition-1' value is 'CONTROL' are drawn in blue, all others in black.
point_colours = []
for condition in dds_gene.obs['condition-1']:
    point_colours.append('b' if condition == 'CONTROL' else 'k')

pc1 = pca_out[:, 0]
pc2 = pca_out[:, 1]

plt.scatter(pc1, pc2, c=point_colours)
plt.xlabel('PC1')
plt.ylabel('PC2')