In [6]:
import numpy as np
import multiprocessing
import string
import random
import pandas as pd
from tqdm.notebook import tqdm
from statsmodels.stats.multitest import multipletests

In [7]:
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a random identifier made of ``size`` characters from ``chars``.

    Characters are drawn one at a time with ``random.choice`` (so repeats
    are possible).  The default alphabet is the upper-case ASCII letters
    followed by the decimal digits.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return ''.join(picked)

# Synthetic example: 20000 random 7-character identifiers act as the profile,
# and 500 of them (sampled without replacement) form the gene set.
profile = np.array([id_generator(size=7) for _ in range(20000)])
geneset = np.random.choice(profile, size=500, replace=False)

In [8]:
def GSEA(profile, geneset, return_info=False):
    """Compute an enrichment score of ``geneset`` along a ranked ``profile``.

    The running statistic ``phit - pmiss`` is evaluated only where it can
    reach an extremum: at every gene-set member's 1-based rank and at the
    positions immediately before and after it.  The returned score is the
    value with the largest absolute magnitude.

    Parameters
    ----------
    profile : array-like of gene identifiers, in ranked order.
    geneset : array-like of gene identifiers to test.
    return_info : bool
        When true, also return a dictionary with details (see Returns).

    Returns
    -------
    score : float
        The signed enrichment score, or ``np.nan`` when ``geneset`` is empty
        or none of its members occur in ``profile``.
    (score, info) : tuple, only when ``return_info`` is true
        ``info`` has the keys ``'direction'`` ("pos"/"neg"), ``'top'``
        (percentage of the gene set in the leading edge), ``'genes'`` and
        ``'ids'`` (members found in the profile and their 0-based positions),
        ``'significant'`` ("Y"/"N" leading-edge flag per member) and ``'es'``
        (the evaluated running-statistic values).  For an empty/unmatched gene
        set the arrays are empty, ``'direction'`` is None and ``'top'`` is NaN,
        so callers can always unpack two values.
    """
    profile = np.asarray(profile)
    geneset = np.asarray(geneset)

    def _no_result():
        # Keep the return shape consistent with the non-empty case.
        if not return_info:
            return np.nan
        return np.nan, {'direction': None,
                        'top': np.nan,
                        'genes': profile[:0],
                        'ids': np.array([], dtype=int),
                        'significant': np.array([], dtype=str),
                        'es': np.array([], dtype=float)}

    if geneset.size == 0:
        return _no_result()

    # 0-based positions of gene-set members; np.where yields them ascending.
    hit_idx = np.where(np.isin(profile, geneset))[0]
    if hit_idx.size == 0:
        return _no_result()

    n = geneset.shape[0]
    N = profile.shape[0]

    # 1-based ranks, plus the neighbouring positions where the running
    # statistic may peak (just before / just after each hit).
    ranks = hit_idx + 1
    positions = np.hstack([ranks, ranks - 1, ranks + 1])

    # Number of hits with rank <= position.  searchsorted on the sorted ranks
    # gives the same counts as an all-pairs comparison without the k x 3k matrix.
    hits_so_far = np.searchsorted(ranks, positions, side='right')
    phit = hits_so_far / n
    pmiss = (positions - hits_so_far) / (N - n)

    es = phit - pmiss
    id_max = np.argmax(np.abs(es))
    similarity_score = es[id_max]

    if not return_info:
        return similarity_score

    # id_max indexes `es`/`positions`, not the profile: translate it to the
    # 1-based profile position of the peak before comparing with gene ranks.
    peak = positions[id_max]
    if similarity_score > 0:
        direction = "pos"
        leading = ranks <= peak
    else:
        direction = "neg"
        leading = ranks >= peak

    significant = np.where(leading, "Y", "N")
    percent_top = np.sum(leading) / n * 100

    info = {'direction': direction,
            'top': percent_top,
            'genes': profile[hit_idx],
            'ids': hit_idx,
            'significant': significant,
            'es': es}
    return similarity_score, info

def premutation_biGSEA(profile_genes, termup, termdw, rng=None):
    """Score one random (null) pair of gene sets against ``profile_genes``.

    Two random gene sets with the same sizes as ``termup`` and ``termdw`` are
    drawn without replacement from ``profile_genes`` and each is scored with
    ``GSEA``.

    Parameters
    ----------
    profile_genes : numpy array of gene identifiers.
    termup, termdw : numpy arrays; only their lengths are used.
    rng : numpy.random.Generator, optional
        Source of randomness.  When omitted, a fresh OS-seeded generator is
        created for this call.  The global ``np.random`` state is deliberately
        not used: worker processes created by fork each start with an identical
        copy of it, so every worker would replay the same permutations.

    Returns
    -------
    (score_up, score_dw) : the two null enrichment scores.
    """
    if rng is None:
        rng = np.random.default_rng()
    sample_up = rng.choice(profile_genes, size=termup.shape[0], replace=False)
    sample_dw = rng.choice(profile_genes, size=termdw.shape[0], replace=False)
    score_up = GSEA(profile_genes, sample_up, False)
    score_dw = GSEA(profile_genes, sample_dw, False)
    return score_up, score_dw

In [9]:
def association_test(profile, genesets_list, N_permutations=100, n_jobs=24):
    """Permutation test of up/down gene-set pairs against a ranked profile.

    For every term a null distribution is built from ``N_permutations``
    random gene-set pairs (``premutation_biGSEA``), the observed up and down
    enrichment scores are divided by the standard deviation of their null
    scores, and the combined statistic ``(NES_up - NES_down) / 2`` is compared
    with the equally combined null values of that same term.

    Parameters
    ----------
    profile : pandas.DataFrame
        Index holds gene identifiers; the first column is the statistic the
        genes are ranked by (descending).
    genesets_list : dict
        Maps a term name to ``{'Up': genes, 'Down': genes}``.
    N_permutations : int
        Number of random gene-set pairs drawn per term.
    n_jobs : int
        Number of worker processes.

    Returns
    -------
    pandas.DataFrame
        One row per term with columns ``Term, pval, padj, NES, Top_up,
        Top_down, Direction_up, Direction_down``.  Terms whose Up or Down set
        is missing or shares no gene with the profile are not tested and keep
        NaN in every column; ``padj`` (Benjamini-Hochberg) is computed over
        the tested terms only.
    """
    profile_name = profile.columns[0]
    profile_genes = profile.index.to_numpy()
    profile_sorted_genes = profile.sort_values(profile_name, ascending=False).index.to_numpy()

    out = pd.DataFrame(index=list(genesets_list.keys()),
                       columns=['Term', 'pval', 'padj', 'NES',
                                'Top_up', 'Top_down',
                                'Direction_up', 'Direction_down'])

    # The context manager guarantees the workers are released even if the
    # loop is interrupted or raises.
    with multiprocessing.Pool(n_jobs) as pool:
        for term, double_set in tqdm(genesets_list.items()):
            up_raw, dw_raw = double_set['Up'], double_set['Down']
            if up_raw is None or dw_raw is None or len(up_raw) == 0 or len(dw_raw) == 0:
                continue

            termup = np.intersect1d(up_raw, profile_genes)
            termdw = np.intersect1d(dw_raw, profile_genes)

            # The statistic combines both directions, so both sets must have
            # at least one gene present in the profile.
            if len(termup) == 0 or len(termdw) == 0:
                continue

            null_up, null_dw = zip(*pool.starmap(
                premutation_biGSEA,
                [(profile_genes, termup, termdw)] * N_permutations))
            null_up = np.asarray(null_up)
            null_dw = np.asarray(null_dw)

            # normalize resampling scores by std
            sd_up = null_up.std(ddof=1)
            sd_dw = null_dw.std(ddof=1)
            resampling_scores = (null_up / sd_up - null_dw / sd_dw) / 2

            # run GSEA for the ranked profile
            score_up, info_up = GSEA(profile_sorted_genes, termup, return_info=True)
            score_dw, info_dw = GSEA(profile_sorted_genes, termdw, return_info=True)
            final_score = (score_up / sd_up - score_dw / sd_dw) / 2

            # Empirical two-sided p-value against THIS term's null scores;
            # a count of zero is reported as 1 / N_permutations.
            n_extreme = np.sum(np.abs(resampling_scores) > np.abs(final_score))
            pval = max(n_extreme, 1) / N_permutations

            out.loc[term, "Term"] = term
            out.loc[term, "pval"] = pval
            out.loc[term, "NES"] = final_score
            out.loc[term, "Direction_up"] = info_up['direction']
            out.loc[term, "Direction_down"] = info_dw['direction']
            out.loc[term, "Top_up"] = info_up['top']
            out.loc[term, "Top_down"] = info_dw['top']

    # Multiple-testing correction only over terms that were actually tested.
    tested = out['pval'].notna()
    if tested.any():
        pvals = out.loc[tested, 'pval'].to_numpy(dtype=float)
        out.loc[tested, 'padj'] = multipletests(pvals, method='fdr_bh')[1]
    out['pval'] = out['pval'].astype(float)
    out['padj'] = out['padj'].astype(float)
    return out

In [10]:
def open_gmt_cmap(path_up, path_dw):
    """Read a pair of tab-separated GMT files into one dictionary of terms.

    Each non-blank line is split on tabs into ``title``, a second field that
    is ignored, and the remaining fields, which become the gene list.

    Parameters
    ----------
    path_up : path of the file providing the ``'Up'`` gene lists.
    path_dw : path of the file providing the ``'Down'`` gene lists.

    Returns
    -------
    dict
        ``{title: {'Up': genes_or_None, 'Down': genes_or_None}}``.  A side is
        ``None`` when the title does not occur in the corresponding file.

    Raises
    ------
    ValueError
        If a non-blank line has fewer than two tab-separated fields.
    """
    terms = {}
    for path, side in ((path_up, 'Up'), (path_dw, 'Down')):
        with open(path) as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank / trailing empty lines.
                    continue
                title, _, *genes = stripped.split('\t')
                # setdefault lets a title that appears in only one file be
                # stored instead of raising KeyError.
                terms.setdefault(title, {'Up': None, 'Down': None})[side] = genes
    return terms

In [24]:
# Sanity check: zipping two pairs transposes them into two tuples.
a, b = zip((1, 2), (3, 4))
np.asarray(a)

array([1, 3])

In [11]:
# Dummy data: N randomly named genes with standard-normal scores, and four
# toy terms ('a'..'d'), each with 100 "Up" and 100 "Down" genes drawn from
# the profile index without replacement.
N = 20000
profile = pd.DataFrame(
    np.random.randn(N),
    columns=["name"],
    index=np.array([id_generator(7) for _ in range(N)]),
)
genesets_list = dict(
    (term, {"Up": np.random.choice(profile.index, size=100, replace=False),
            "Down": np.random.choice(profile.index, size=100, replace=False)})
    for term in ['a', 'b', 'c', 'd']
)

# out = association_test(profile, genesets_list, 1000, 32)

In [25]:
#real data
reprog = pd.read_csv('reprogramming_full.csv', index_col=0)
# Build the profile in one chain so every step works on a new frame
# (no assignment into a slice of `reprog`), and name the dropna axis
# explicitly: recent pandas versions reject the positional form.
profile = (
    reprog[['symbol', 'logFC']]
    .dropna(axis=0)
    .assign(symbol=lambda frame: frame['symbol'].apply(str.upper))
    .set_index('symbol')
)

In [26]:
# Paths of the up- and down-regulated .gmt files, then load both into one dict.
path_up, path_dw = 'drugs/cmap_up.gmt', 'drugs/cmap_dw.gmt'
genesets_list = open_gmt_cmap(path_up, path_dw)

In [27]:
# Preview selected gene sets instead of dumping the whole dictionary.
# (The saved cell referenced an undefined name `adict`; the recorded output
# shows entries of `genesets_list`, so that is what is subset here.)
dict((k, genesets_list[k]) for k in ('1,4-chrysenequinone-1773',) if k in genesets_list)

{'1,4-chrysenequinone-1773': {'Up': ['ABHD3',
   'ADAM8',
   'ADCYAP1',
   'ALDOB',
   'APBA2',
   'ARHGAP33',
   'ARL4C',
   'ATXN1',
   'BAG3',
   'BIN1',
   'CA12',
   'CCT6B',
   'CLSPN',
   'CTSL',
   'CYP1B1',
   'DDR1',
   'DLG3',
   'DNAJA1',
   'DNAJB1',
   'DNAJB4',
   'EPB41L3',
   'FAM115A',
   'FAM206A',
   'FASTK',
   'FLNC',
   'FSCN3',
   'FTH1',
   'FTH1P5',
   'GATA2',
   'GCLM',
   'GPR182',
   'GPR27',
   'HFE',
   'HIST1H2AM',
   'HIST1H2BG',
   'HIST1H2BI',
   'HMOX1',
   'HSP90AA1',
   'HSPA4L',
   'HSPA6',
   'HSPA8',
   'HSPB1',
   'HSPH1',
   'ID2',
   'IER5',
   'IL24',
   'KIF3C',
   'LMF2',
   'LY6G6C',
   'MAGED2',
   'MAPK8IP3',
   'MARCKS',
   'ME1',
   'MPZL1',
   'NCOA3',
   'NDRG2',
   'NEAT1',
   'NF1',
   'NFKBIA',
   'NQO1',
   'NR2F2',
   'OPRL1',
   'PBX1',
   'PHLDA3',
   'PIAS2',
   'PKD1',
   'PLCB2',
   'PSG3',
   'RAB11B',
   'RELB',
   'RHPN1-AS1',
   'RIT1',
   'RORC',
   'SLC33A1',
   'SLC7A11',
   'SPDEF',
   'SPR',
   'SQSTM1',
   'SRD5

In [53]:
# Full run: 1000 permutations per term, spread over 32 worker processes.
out = association_test(profile, genesets_list, N_permutations=1000, n_jobs=32)

  0%|          | 0/6100 [00:00<?, ?it/s]

Process ForkPoolWorker-90:
Process ForkPoolWorker-85:
Process ForkPoolWorker-100:
Process ForkPoolWorker-92:
Process ForkPoolWorker-82:
Process ForkPoolWorker-89:
Process ForkPoolWorker-95:
Process ForkPoolWorker-87:
Process ForkPoolWorker-103:
Process ForkPoolWorker-84:
Process ForkPoolWorker-83:
Process ForkPoolWorker-73:
Process ForkPoolWorker-88:
Process ForkPoolWorker-91:
Process ForkPoolWorker-77:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkPoolWorker-97:
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last)

KeyboardInterrupt: 

Traceback (most recent call last):
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process ForkPoolWorker-94:
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process ForkPoolWorker-93:
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
Traceback (most recent call last):
  File "/opt/Anaconda/2020/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
 

In [45]:
# Display the table assigned to `out` by the association_test call; `out`
# only exists once that call has run to completion.
out

NameError: name 'out' is not defined