# Run GSEA on DE output
* Run prerank GSEA on all contrasts using the gene sets defined in the second cell. Results are saved to uns['gsea_results_prerank'].
* Run traditional GSEA on normalized counts with the defined gene sets. Results are saved to uns['gsea_results_standard'].
* Run single-sample GSEA on normalized counts with the defined gene sets. ES and NES are saved to obsm as arrays whose column names are stored in uns['ssgsea_colnames'].

In [None]:
import os
import pandas as pd
import numpy as np
from pybiomart import Dataset
import anndata as ad
import matplotlib.pyplot as plt
import warnings

import gseapy as gp
from gseapy import Msigdb
import gseapy.parser as gmtparser

warnings.filterwarnings('ignore')

### Define gene sets for gsea.
* Available gene sets are listed:
    * At the bottom of this page (Enrichr API): https://gseapy.readthedocs.io/en/latest/gseapy_tutorial.html
    * In the Msigdb API section of this page: https://gseapy.readthedocs.io/en/latest/gseapy_example.html#Msigdb-API


In [None]:
# Number of threads handed to the gseapy calls below.
NUM_CPUS = 8

# Experiment directory to analyse. The commented-out alternatives are kept so
# a different experiment can be selected by swapping which line is active.
# DATA_PATH = os.getcwd()
# DATA_PATH = '/data/expression_atlas/v1/GSE162828/'
# DATA_PATH = '/data/expression_atlas/v1/GSE122459/'
# DATA_PATH = '/data/expression_atlas/v1/GSE110914/'
DATA_PATH = '/data/expression_atlas/v1/GSE80183/'

# Results are prefixed with the last component of DATA_PATH (trailing slash
# ignored). The leading empty string is a no-op, presumably a placeholder for
# a base directory.
RESULTS_PATH = '' + 'de_results/%s' % (
    DATA_PATH.rstrip('/').split('/')[-1]
)

# AnnData file that is read at the start of the analysis and written at the end.
DDS_GENE_FH = RESULTS_PATH + '_dds_gene.h5_ad'

# Gene set collections to test. Names containing '.all' are fetched through
# the Msigdb API; every other name is treated as an Enrichr library.
GENE_SET_ANNOTATIONS = [
    'h.all',
    'c2.all',
    'c3.all',
    'c6.all',
    'c7.all',
    'c8.all',
    'Reactome_2016',
    'KEGG_2016',
]

In [None]:
# Download every non-MSigDB entry of GENE_SET_ANNOTATIONS as an Enrichr library.
# Entries containing '.all' are skipped here; they are fetched through the
# Msigdb API in a later cell.

enrichr_libraries = [name for name in GENE_SET_ANNOTATIONS if '.all' not in name]

for library_name in enrichr_libraries:
    gmtparser.download_library(library_name)

### Contrasts defined in experiment.

In [None]:
# Load adata objects containing normed_counts for running traditional gsea.
# The same object also carries the DE tables (uns['stat_results']) and the
# contrast definitions (uns['contrasts']) that the cells below consume.

dds_gene = ad.read_h5ad(DDS_GENE_FH)

# Notebook display: matrix shape alongside the contrasts stored with the object.
dds_gene.X.shape, dds_gene.uns['contrasts']

In [None]:
# Work on copies of the DE tables so the filtering and renaming done for GSEA
# never touches the originals in uns['stat_results'].

gsea_tables = {}
for contrast_name, de_table in dds_gene.uns['stat_results'].items():
    gsea_tables[contrast_name] = de_table.copy()

dds_gene.uns['stat_results_gsea'] = gsea_tables


In [None]:
# Keep only rows whose index starts with the Ensembl gene prefix; anything
# else in the DE tables is dropped before GSEA.

gene_prefix = 'ENSG'

for contrast_name in dds_gene.uns['stat_results_gsea'].keys():
    de_table = dds_gene.uns['stat_results_gsea'][contrast_name]
    is_ensembl_gene = de_table.index.str.startswith(gene_prefix)
    dds_gene.uns['stat_results_gsea'][contrast_name] = de_table[is_ensembl_gene]

### PA1.1 Fetch gene_name/ensembl gene id mappings.

In [None]:
# Query biomart for the Ensembl gene id -> external gene name table.

dataset = Dataset(name='hsapiens_gene_ensembl', host='http://www.ensembl.org')

mapping_attributes = ['ensembl_gene_id', 'external_gene_name']
external_gene_mapping = dataset.query(attributes=mapping_attributes)

# Rename the returned headers to the column names used by the merges below.
external_gene_mapping.rename(
    columns={'Gene stable ID': 'gene_id', 'Gene name': 'gene_name'},
    inplace=True,
)

# Notebook display of the mapping table.
external_gene_mapping

In [None]:
# Notebook display: names of all attributes that can be requested from this biomart dataset.
dataset.attributes.keys()

In [None]:
# Strip the version suffix from the Ensembl ids (text after the first '.') and
# attach gene names from the biomart mapping. Tables that already carry a
# 'gene_name' column are left alone, so the cell can be re-run safely.

for contrast_name in dds_gene.uns['stat_results_gsea'].keys():
    de_table = dds_gene.uns['stat_results_gsea'][contrast_name]

    if 'gene_name' in de_table.columns:
        continue

    de_table.index = de_table.index.str.split('.').str[0]

    dds_gene.uns['stat_results_gsea'][contrast_name] = de_table.merge(
        external_gene_mapping,
        left_index=True,
        right_on='gene_id',
    )

In [None]:
# Remove rows that have no gene_name after the merge.

for contrast_name in dds_gene.uns['stat_results_gsea'].keys():
    de_table = dds_gene.uns['stat_results_gsea'][contrast_name]
    has_gene_name = de_table['gene_name'].notna()
    dds_gene.uns['stat_results_gsea'][contrast_name] = de_table[has_gene_name]

In [None]:
# Fetch the MSigDB collections (the '.all' entries of GENE_SET_ANNOTATIONS)
# and keep them keyed by collection name for the GSEA cells below.

msig = Msigdb()
msig.list_category(dbver="2023.1.Hs")

msigdb_gene_sets = {}

msigdb_collections = [name for name in GENE_SET_ANNOTATIONS if '.all' in name]

for collection in msigdb_collections:
    msigdb_gene_sets[collection] = msig.get_gmt(category=collection, dbver='2023.1.Hs')


### PA1.2 Run GSEA prerank using Wald statistic.

In [None]:
# Prerank GSEA: for every gene set collection and every contrast, rank genes
# by the 'stat' column of the GSEA-ready DE table and store the result table
# (cast to strings) in uns['gsea_results_prerank'] under '<contrast>_<collection>'.

dds_gene.uns['gsea_results_prerank'] = {}

for ga in GENE_SET_ANNOTATIONS:

    # MSigDB collections are passed as the fetched gene set objects; every
    # other collection is passed to gseapy by name.
    if '.all' in ga:
        gene_sets_arg = msigdb_gene_sets[ga]
    else:
        gene_sets_arg = ga

    for k in dds_gene.uns['stat_results_gsea'].keys():

        run_label = '%s_%s' % (k, ga)
        ranking = dds_gene.uns['stat_results_gsea'][k][['gene_name', 'stat']]

        gs_res = gp.prerank(
            rnk=ranking,
            gene_sets=gene_sets_arg,
            threads=NUM_CPUS,
            min_size=5,
            max_size=1000,
            permutation_num=1000,
            outdir=None,
            seed=42,
            verbose=True,
        )

        # Stored as strings, as in the other result tables written to uns.
        results_as_str = gs_res.res2d.astype(str, copy=True)
        dds_gene.uns['gsea_results_prerank'][run_label] = results_as_str.copy()

        print(run_label)
        display(gs_res.res2d.head(10))

        # Plot the first 20 terms of the result table and label the figure.
        terms = gs_res.res2d.Term
        axs = gs_res.plot(terms=terms[0:20])
        axs.suptitle(run_label, y=0.0)

In [None]:
# Build the genes x samples expression table used by traditional and
# single-sample GSEA, indexed by gene name, and store it in uns['gsea_gene_df'].

gene_df = pd.DataFrame(
    dds_gene.layers['normed_counts'].T,
    columns=dds_gene.obs.index,
    index=dds_gene.var.index,
)

# Filter dataframes by gene_id, only keep ensembl gene ids.
gene_prefix = 'ENSG'
gene_df = gene_df[gene_df.index.str.startswith(gene_prefix)]

# Always map Ensembl ids to gene names, matching what is done for the DE
# tables. The mapping used to be skipped when no id carried a version suffix,
# which left ENSG ids as the index. Stripping the suffix is a no-op on
# unversioned ids, so running it unconditionally is safe.
gene_df.index = gene_df.index.str.split('.').str[0]

gene_df = gene_df.merge(external_gene_mapping, left_index=True, right_on='gene_id')

# Drop genes without a gene name, then index by gene name. The 'gene_id'
# column added by the merge is kept alongside the sample columns.
gene_df = gene_df[gene_df['gene_name'].notna()]
gene_df = gene_df.set_index('gene_name')

dds_gene.uns['gsea_gene_df'] = gene_df.copy()


### PA1.2 Run traditional GSEA on normalized counts.

In [None]:
# Run GSEA on normed_counts for each contrast specified in the dds object.
# Results (cast to strings) are stored in uns['gsea_results_standard'] under
# '<contrast>_<collection>'. Each contrast v is read as
# (obs column, level, level) via v[0], v[1], v[2].

dds_gene.uns['gsea_results_standard'] = {}

for ga in GENE_SET_ANNOTATIONS:

    # MSigDB collections are passed as the fetched gene set objects; every
    # other collection is passed to gseapy by name.
    gene_sets_arg = ga if '.all' not in ga else msigdb_gene_sets[ga]

    for k, v in dds_gene.uns['contrasts'].items():

        run_label = '%s_%s' % (k, ga)

        # Restrict to the samples that belong to one of the two contrast levels.
        in_contrast = dds_gene.obs[v[0]].isin([v[1], v[2]])
        samples_in_contrast = dds_gene.obs[in_contrast].index
        conditions = dds_gene.obs.loc[samples_in_contrast, v[0]]

        # Skip contrasts where either group has fewer than 3 samples. Only
        # the two levels of this contrast are counted, so an unrelated small
        # level cannot veto it. Counting by value also avoids positional
        # indexing into value_counts(), which is not label-safe.
        smallest_condition_size = min(
            int((conditions == v[1]).sum()),
            int((conditions == v[2]).sum()),
        )

        if smallest_condition_size < 3:
            continue

        # NOTE(review): gseapy derives the positive/negative phenotype from
        # `cls` itself; confirm the resulting sign convention matches the
        # direction intended by the contrast definition.
        gs_res = gp.gsea(
            data=dds_gene.uns['gsea_gene_df'][samples_in_contrast],
            gene_sets=gene_sets_arg,
            cls=conditions,
            # Phenotype permutation only when there are more than 15 samples.
            permutation_type='phenotype' if len(conditions) > 15 else 'gene_set',
            permutation_num=1000,
            outdir=None,
            method='signal_to_noise',
            min_size=5,
            max_size=1000,
            threads=NUM_CPUS,
            seed=42,
            verbose=True,
        )

        dds_gene.uns['gsea_results_standard'][run_label] = gs_res.res2d.astype(str).copy()

        print(run_label)
        print(gs_res.res2d.head(10).to_string())

        # Plot the first 20 terms. The title includes the gene set collection
        # so figures from different collections can be told apart, as in the
        # prerank cell.
        terms = gs_res.res2d.Term
        axs = gs_res.plot(terms=terms[0:20])
        axs.suptitle(run_label, y=0.0)

### PA1.3 Run single-sample GSEA. Save enrichment scores (ES) and normalized enrichment scores (NES) to anndata.

In [None]:
# Run all samples through ssgsea, create matrix of ssgsea output in dds.obsm storing ssgsea NES and ES.
# There's an issue saving pandas DataFrames with large headers via hdf5, so saving ssgsea results in two
# arrays in obsm, and then saving header to separate arrays in uns.
#
# For each collection `ga`:
#   obsm['<ga>_ssgsea_es']  / obsm['<ga>_ssgsea_nes'] : samples x gene sets, zero where a
#                                                       gene set is absent from the result
#   uns['ssgsea_colnames'][ga]                        : gene set name of each column

dds_gene.uns['ssgsea_colnames'] = {}

for ga in GENE_SET_ANNOTATIONS:

    if '.all' in ga:
        gmt = msigdb_gene_sets[ga]
    else:
        gmt = gmtparser.get_library(ga)

    dds_gene.uns['ssgsea_colnames'][ga] = np.array(list(gmt.keys()))

    # Column position of every gene set, built once per collection so each
    # result row is placed with a dict lookup instead of scanning the names.
    term_to_column = {term: j for j, term in enumerate(gmt.keys())}

    n_samples = dds_gene.obs.shape[0]
    es_matrix = np.zeros((n_samples, len(gmt)))
    nes_matrix = np.zeros((n_samples, len(gmt)))

    for i, s in enumerate(dds_gene.obs.index):

        # rename(None) returns a renamed copy. The inplace form is documented
        # to return None, which would hand `data=None` to gseapy.
        sample_expression = dds_gene.uns['gsea_gene_df'].loc[:, s].rename(None)

        ss = gp.ssgsea(
            data=sample_expression,
            gene_sets=ga if '.all' not in ga else msigdb_gene_sets[ga],
            outdir=None,
            sample_norm_method='rank',
            no_plot=True,
            verbose=True,
            min_size=5,
            max_size=1000,
        )

        # Index by term without mutating the result object held by gseapy.
        sample_scores = ss.res2d.set_index('Term')

        # Only the terms present in the result are filled; the rest stay zero.
        columns = [term_to_column[term] for term in sample_scores.index]
        es_matrix[i, columns] = sample_scores['ES'].astype(float).to_numpy()
        nes_matrix[i, columns] = sample_scores['NES'].astype(float).to_numpy()

    dds_gene.obsm['%s_ssgsea_es' % ga] = es_matrix
    dds_gene.obsm['%s_ssgsea_nes' % ga] = nes_matrix
            

In [None]:
# Save output of gsea analyses.
# Written to DDS_GENE_FH, the same path the object was loaded from, so the
# input file is overwritten with the GSEA results added.

dds_gene.write(DDS_GENE_FH)