In [4]:
pwd

'/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/B_results/Afigure/04Jaccard_similarity'

In [1]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

import argparse
import sys
import numpy as np
from gears import GEARS
from gears.inference import evaluate

from gears import PertData, GEARS
import scanpy as sc
import diffxpy.api as de

from scipy.stats import hypergeom
from sklearn.metrics import jaccard_score
from statsmodels.stats.multitest import multipletests

In [2]:
def make_gene_list_unique(adata):
    """De-duplicate the ``gene_name`` column of ``adata.var`` in place.

    If no value in ``adata.var['gene_name']`` occurs more than once, nothing
    is changed. Otherwise ``gene_name`` is temporarily made the ``var`` index
    so that ``adata.var_names_make_unique()`` can rename the duplicates, and
    the ``gene`` column is then restored as the index.

    Parameters
    ----------
    adata
        Object exposing a ``var`` DataFrame with a ``gene_name`` column and a
        ``var_names_make_unique()`` method.
        NOTE(review): a ``gene`` column must exist after the final
        ``reset_index()`` -- presumably the original ``var`` index is named
        ``gene``; confirm against the dataset.

    Returns
    -------
    None
    """
    most_frequent_count = adata.var['gene_name'].value_counts().max()
    if most_frequent_count == 1:
        # Every gene name already occurs exactly once.
        return

    # Move the current index into a column and index by gene_name instead.
    var_by_name = adata.var.reset_index()
    var_by_name['gene_name'] = var_by_name['gene_name'].astype('string')
    adata.var = var_by_name.set_index('gene_name')

    # Let the container rename duplicated index entries.
    adata.var_names_make_unique()

    # Put the de-duplicated names back into a column and re-index by 'gene'.
    adata.var = adata.var.reset_index().set_index('gene')

In [3]:
def run_DE_test(pert, high_expression_filter, filter_level=0.01):
    """Run rank-based DE tests for one perturbation: observed and predicted.

    Two ``de.test.rank_test`` calls are made, both grouped by ``condition``:

    1. control cells vs. the observed cells of ``pert``;
    2. control cells vs. the expression predicted by ``gears_model`` for
       ``pert``.

    This function reads the notebook-level globals ``adata`` and
    ``gears_model``; both must be defined before it is called.

    Parameters
    ----------
    pert : str
        Condition label as found in ``adata.obs['condition']``; gene names
        are joined by ``'+'`` and ``'ctrl'`` entries are dropped before
        querying the model.
    high_expression_filter : bool
        If true, restrict both tests to genes whose mean of ``adata.X`` over
        all cells is greater than ``filter_level``.
    filter_level : float, default 0.01
        Mean-expression cutoff used when ``high_expression_filter`` is true.

    Returns
    -------
    tuple
        ``(ctrl_test, pred_test)``: the ``rank_test`` results for the
        observed and the predicted comparison, in that order.
    """
    ctrl_adata = adata[adata.obs['condition'] == 'ctrl']
    DE_adata = adata[adata.obs['condition'].isin(['ctrl', pert])]

    if high_expression_filter:
        high_expression_idx = np.where(adata.X.toarray().mean(0) > filter_level)[0]
        high_expression_genes = adata.var.index[high_expression_idx]
        ctrl_adata = ctrl_adata[:, high_expression_genes]
        DE_adata = DE_adata[:, high_expression_genes]

    ## First compute for true DE
    # NOTE(review): kept from the original. X is still read through
    # `.toarray()` right below, so this assignment looks redundant -- confirm
    # against the installed anndata version before removing it.
    DE_adata.X = DE_adata.X.todense()

    slim_adata = sc.AnnData(DE_adata.X.toarray())
    slim_adata.obs = DE_adata.obs

    ctrl_test = de.test.rank_test(
        data=slim_adata,
        grouping="condition"
    )

    ## For perts in test
    genes = [g for g in pert.split('+') if g != 'ctrl']

    # With uncertainty enabled, predict() returns a pair; only the first
    # element (the predictions dict) is used here.
    if gears_model.config['uncertainty']:
        preds_dict, _ = gears_model.predict([genes])
    else:
        preds_dict = gears_model.predict([genes])

    preds = np.array(preds_dict['_'.join(genes)])
    if preds.ndim == 1:
        # Make a single predicted profile a one-row matrix.
        preds = preds.reshape(1, -1)

    if high_expression_filter:
        preds = preds[:, high_expression_idx]

    # Stack the control cells on top of the predicted profile(s).
    pred_adata = sc.AnnData(np.concatenate([ctrl_adata.X.toarray(), preds]))
    pred_adata.obs_names_make_unique()

    pred_adata.obs['condition'] = ['ctrl'] * len(ctrl_adata) + [pert] * len(preds)

    # Copy the remaining obs columns as constants, taken from the first
    # observed cell. `.iloc[0]` is explicit positional access (plain `[0]` on
    # a label-indexed Series relies on a deprecated fallback).
    for c in slim_adata.obs.columns:
        if c == 'condition':
            continue
        pred_adata.obs[c] = slim_adata.obs[c].iloc[0]

    pred_test = de.test.rank_test(
        data=pred_adata,
        grouping="condition"
    )

    return ctrl_test, pred_test

def filter_summaries(ctrl_test, pred_test, abslogfc_thresh=1, qval_thresh=0.05):
    """Keep only the significantly changed genes of two DE test results.

    A gene is kept when ``|log2fc| > abslogfc_thresh`` and
    ``qval < qval_thresh``. Both cutoffs are strict. Previously the two
    threshold arguments were ignored and ``1`` / ``0.05`` were hard-coded;
    the defaults reproduce that behaviour.

    Parameters
    ----------
    ctrl_test, pred_test
        Objects whose ``summary()`` returns a DataFrame with ``log2fc`` and
        ``qval`` columns.
    abslogfc_thresh : float, default 1
        Minimum absolute log2 fold change (exclusive).
    qval_thresh : float, default 0.05
        Maximum q-value (exclusive).

    Returns
    -------
    tuple of DataFrame
        ``(ctrl_subset, pred_subset)``: the filtered summaries, each carrying
        an added ``abs_log2fc`` column.
    """
    def _significant_rows(test):
        # Filter one summary table by effect size and q-value.
        summary = test.summary()
        summary['abs_log2fc'] = summary['log2fc'].abs()
        keep = (summary['abs_log2fc'] > abslogfc_thresh) & (summary['qval'] < qval_thresh)
        return summary[keep]

    ctrl_subset = _significant_rows(ctrl_test)
    pred_subset = _significant_rows(pred_test)
    return ctrl_subset, pred_subset

def get_p_val_jaccard(ctrl_subset, pred_subset, total_genes=None):
    """Score the overlap between observed and predicted DE gene sets.

    Computes the Jaccard index of the two ``gene`` columns and a
    hypergeometric enrichment p-value for their overlap,
    ``P(X >= k) = hypergeom.sf(k - 1, M, n, N)``.

    The previous implementation returned ``hypergeom.cdf(k - 1, ...)``, i.e.
    ``P(X < k)``, which is the complement of the enrichment p-value (a zero
    overlap produced the smallest value). Results saved with the old code
    need to be regenerated.

    Parameters
    ----------
    ctrl_subset, pred_subset : DataFrame
        Filtered DE summaries with a ``gene`` column.
    total_genes : int, optional
        Number of genes tested (population size ``M``). When omitted, falls
        back to ``len(ctrl_test.summary())`` using the notebook-level
        ``ctrl_test`` variable, as the original code did.

    Returns
    -------
    tuple
        ``(adjusted_p_value, jaccard)`` where ``adjusted_p_value`` is the
        array returned by ``multipletests`` for this one p-value and
        ``jaccard`` is a float.
    """
    y_true = set(ctrl_subset['gene'].values)
    y_pred = set(pred_subset['gene'].values)
    overlap = len(y_true.intersection(y_pred))
    jaccard = get_jaccard(y_true, y_pred)

    if total_genes is None:
        # Backward-compatible fallback on hidden notebook state.
        total_genes = len(ctrl_test.summary())

    M = total_genes       # population size
    N = len(pred_subset)  # number of draws
    n = len(ctrl_subset)  # number of successes in the population
    k = overlap           # observed successes
    print(M, N, n, k)

    # Upper tail: probability of an overlap at least as large as observed.
    p_value = hypergeom.sf(k - 1, M, n, N)

    # NOTE(review): Bonferroni over a single p-value leaves it unchanged; the
    # call is kept so the return value stays an array as before. A real
    # correction has to be applied across all perturbations downstream.
    adjusted_p_value = multipletests([p_value], method='bonferroni')[1]

    return (adjusted_p_value, jaccard)
    
def get_jaccard(A, B):
    """Return the Jaccard index ``|A & B| / |A | B|`` of two sets.

    When both sets are empty the index is undefined; ``0.0`` is returned
    instead of raising ``ZeroDivisionError`` so that a perturbation with no
    DE genes on either side does not abort a long evaluation loop.
    """
    union_size = len(A.union(B))
    if union_size == 0:
        return 0.0
    return len(A.intersection(B)) / union_size

In [4]:
# Load the processed dataset; `run_DE_test` reads this notebook-level `adata`.
adata = sc.read_h5ad('/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/normanweissman2019/perturb_processed.h5ad')

# Lookup tables: var index -> gene_name, and gene_name -> var index.
gene_name_dict = adata.var['gene_name'].to_dict()
gene_name_dict_inverse = dict(zip(gene_name_dict.values(), gene_name_dict.keys()))

In [5]:
import torch

# Result accumulators, filled by the evaluation loop: seed -> split -> list.
jaccards = dict()
p_vals = dict()

# Use the second CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')

dataset = "NormanWeissman2019"

In [6]:
# Settings shared by every split seed (they do not change inside the loop).
batch_size = 32
epoch = 15
no_perturb = False
naive = None
model = 'gears'
high_expression_filter = True
filter_level = 0.01

for seed in range(1, 6):
    # Load the dataset and the simulation split that belongs to this seed.
    pert_data = PertData('/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/')  # specific saved folder
    pert_data.load(data_path='/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/normanweissman2019')
    pert_data.prepare_split(split='simulation', seed=seed)
    pert_data.get_dataloader(batch_size=batch_size, test_batch_size=batch_size)

    # `run_DE_test` reads this notebook-level `gears_model`.
    gears_model = GEARS(
        pert_data,
        device=device,
        weight_bias_track=False,
        proj_name=dataset,
        exp_name=str(model) + '_seed' + str(seed),
    )
    gears_model.load_pretrained('/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/01_A_results/NormanWeissman2019/gears/split' + str(seed))

    p_vals[seed] = {}
    jaccards[seed] = {}
    for key, test_perts in pert_data.subgroup['test_subgroup'].items():
        # Register the lists first so partial results survive an interruption.
        p_vals[seed][key] = split_p_vals = []
        jaccards[seed][key] = split_jaccards = []
        for pert in test_perts:
            print(pert)
            # Keep the name `ctrl_test`: `get_p_val_jaccard` reads it from
            # the notebook namespace.
            ctrl_test, pred_test = run_DE_test(pert, high_expression_filter, filter_level=filter_level)
            ctrl_subset, pred_subset = filter_summaries(ctrl_test, pred_test)
            p_val, jaccard = get_p_val_jaccard(ctrl_subset, pred_subset)
            split_p_vals.append(p_val)
            split_jaccards.append(jaccard)

    # Drop the model before the next seed builds a new one.
    del gears_model

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:17
unseen_single:27
Done!
Creating dataloaders....
Done!


here1
CBL+PTPN12
1750 456 116 61
CDKN1C+CDKN1A
1750 444 145 110
C3orf72+FOXL2
1750 510 24 14
CDKN1C+CDKN1B
1750 435 125 95
CEBPB+PTPN12
1750 519 118 93
ZBTB10+PTPN12
1750 433 170 104
RHOXF2+SET
1750 455 281 167
CDKN1B+CDKN1A
1750 460 129 90
POU3F2+FOXL2
1750 535 170 103
SET+KLF1
1750 295 150 78
FOXA3+FOXL2
1750 540 145 82
IRF1+SET
1750 477 303 176
SAMD1+PTPN12
1750 503 129 82
CBL+TGFBR2
1750 368 61 31
PTPN12+SNAI1
1750 486 85 58
DUSP9+PRTG
1750 456 257 149
KLF1+COL2A1
1750 281 135 63
ZBTB10+SNAI1
1750 442 120 58
CBL+PTPN9
1750 425 106 58
ETS2+MAP7D1
1750 530 116 73
PTPN12+ZBTB25
1750 424 141 64
BPGM+SAMD1
1750 491 123 76
PTPN12+PTPN9
1750 455 124 65
CEBPE+PTPN12
1750 509 82 73
CEBPB+CEBPA
1750 555 152 88
CEBPE+SPI1
1750 516 262 136
LYL1+CEBPB
1750 499 131 98
CEBPB+MAPK1
1750 483 243 154
CBL+CNN1
1750 391 166 86
CBL+UBASH3B
1750 376 142 81
BCL2L11+BAK1
1750 61 3 0
FOSB+PTPN12
1750 516 52 37
RHOXF2+ZBTB25
1750 398 169 100
PTPN12+UBASH3A
1750 480 124 71
SET+CEBPE
1750 501 145 110
BPGM+ZBT

Found local copy...


1750 411 37 24


These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:7
combo_seen1:51
combo_seen2:18
unseen_single:27
Done!
Creating dataloaders....
Done!


here1
ETS2+CNN1
1750 461 53 36
ETS2+IGDCC3
1750 458 83 59
SGK1+S1PR2
1750 364 207 122
CNN1+UBASH3A
1750 346 104 50
FOXA1+FOXL2
1750 519 108 73
POU3F2+FOXL2
1750 511 170 100
JUN+CEBPA
1750 473 174 84
KLF1+BAK1
1750 241 58 32
FOXA3+FOXL2
1750 514 145 83
IRF1+SET
1750 390 303 164
IGDCC3+MAPK1
1750 384 120 79
SGK1+TBX3
1750 349 83 50
C3orf72+FOXL2
1750 450 24 14
FOSB+UBASH3B
1750 487 54 38
ETS2+MAP7D1
1750 367 116 58
UBASH3B+UBASH3A
1750 360 39 28
IGDCC3+ZBTB25
1750 345 173 92
FOSB+OSR2
1750 475 293 170
CNN1+MAPK1
1750 379 92 53
BPGM+SAMD1
1750 446 123 75
FOXA3+FOXA1
1750 526 197 128
TGFBR2+ETS2
1750 326 60 21
ETS2+PRTG
1750 498 135 98
KLF1+CEBPA
1750 255 251 124
FOXA1+HOXB9
1750 551 172 119
CEBPB+CEBPA
1750 494 152 85
LYL1+CEBPB
1750 441 131 99
IGDCC3+PRTG
1750 440 128 76
TMSB4X+BAK1
1750 263 8 0
CBL+CNN1
1750 395 166 82
CEBPE+CNN1
1750 422 103 78
BCL2L11+BAK1
1750 138 3 0
SGK1+TBX2
1750 325 147 84
FOSB+PTPN12
1750 401 52 33
PTPN12+UBASH3A
1750 345 124 65
BPGM+ZBTB1
1750 425 123 85
MAP2K3

Found local copy...


1750 362 68 30


These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:2
combo_seen1:39
combo_seen2:23
unseen_single:27
Done!
Creating dataloaders....
Done!


here1
UBASH3B+CNN1
1750 402 187 82
POU3F2+CBFA2T3
1750 386 129 64
CDKN1C+CDKN1A
1750 467 145 113
AHR+FEV
1750 464 253 135
TGFBR2+PRTG
1750 298 112 63
CDKN1C+CDKN1B
1750 448 125 98
MAP2K3+MAP2K6
1750 406 13 9
MAPK1+PRTG
1750 386 149 96
MAP2K3+IKZF3
1750 499 128 70
DUSP9+PRTG
1750 476 257 150
KLF1+CLDN6
1750 279 170 66
CEBPE+RUNX1T1
1750 347 144 106
FOSB+UBASH3B
1750 339 54 36
MAP2K3+ELMSAN1
1750 277 98 37
UBASH3B+UBASH3A
1750 394 39 28
CNN1+MAPK1
1750 381 92 58
ETS2+CNN1
1750 308 53 35
ETS2+PRTG
1750 296 135 69
FOXA1+HOXB9
1750 456 172 113
IGDCC3+PRTG
1750 420 128 74
CBL+CNN1
1750 371 166 88
CEBPE+CNN1
1750 402 103 74
CBL+UBASH3B
1750 389 142 83
MAP2K3+SLC38A2
1750 167 18 4
UBASH3B+OSR2
1750 404 54 42
AHR+KLF1
1750 363 105 50
UBASH3B+ZBTB25
1750 381 100 59
FOXA3+HOXB9
1750 472 202 148
FOXL2+HOXB9
1750 455 118 53
UBASH3B+PTPN9
1750 371 120 67
FOXF1+HOXB9
1750 432 136 76
UBASH3B+PTPN12
1750 386 123 66
CNN1+UBASH3A
1750 371 104 50
SNAI1+UBASH3B
1750 510 57 40
FOXL2+MEIS1
1750 475 135 78
SA

Found local copy...


1750 400 68 35


These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:5
combo_seen1:44
combo_seen2:21
unseen_single:25
Done!
Creating dataloaders....
Done!


here1
KLF1+CLDN6
1750 332 170 69
KLF1+TGFBR2
1750 330 70 31
AHR+KLF1
1750 388 105 49
FOXF1+HOXB9
1750 375 136 61
SAMD1+TGFBR2
1750 237 43 17
SET+KLF1
1750 292 150 78
KLF1+BAK1
1750 305 58 26
SAMD1+PTPN12
1750 258 129 64
AHR+FEV
1750 474 253 137
SGK1+TBX3
1750 317 83 53
C3orf72+FOXL2
1750 417 24 13
TGFBR2+PRTG
1750 246 112 51
CBL+TGFBR2
1750 329 61 32
LHX1+ELMSAN1
1750 164 120 31
FOXA3+FOXF1
1750 403 116 65
KLF1+COL2A1
1750 395 135 73
KLF1+MAP2K6
1750 235 48 19
ZC3HAV1+HOXC13
1750 459 77 58
UBASH3B+UBASH3A
1750 264 39 29
ZC3HAV1+CEBPE
1750 431 68 49
TBX3+TBX2
1750 309 109 62
BPGM+SAMD1
1750 299 123 74
CEBPE+KLF1
1750 417 79 57
TGFBR2+ETS2
1750 317 60 30
KLF1+CEBPA
1750 495 251 169
FOXA1+HOXB9
1750 423 172 106
LYL1+CEBPB
1750 482 131 99
TMSB4X+BAK1
1750 211 8 0
PTPN12+UBASH3A
1750 249 124 57
TGFBR2+IGDCC3
1750 341 93 58
MAPK1+TGFBR2
1750 268 134 49
KLF1+FOXA1
1750 368 132 71
FOXA3+HOXB9
1750 428 202 130
BCL2L11+TGFBR2
1750 261 10 3
SAMD1+ZBTB1
1750 362 145 88
FOXL2+HOXB9
1750 466 118 55


Found local copy...


1750 352 56 25


These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:10
combo_seen1:58
combo_seen2:16
unseen_single:27
Done!
Creating dataloaders....
Done!


here1
IGDCC3+MAPK1
1750 486 120 82
CEBPB+PTPN12
1750 475 118 84
UBASH3B+UBASH3A
1750 484 39 32
PTPN12+PTPN9
1750 502 124 73
CEBPB+CEBPA
1750 504 152 72
MAP2K6+ELMSAN1
1750 430 34 21
CEBPB+MAPK1
1750 486 243 129
PTPN12+UBASH3A
1750 481 124 71
UBASH3B+PTPN9
1750 487 120 75
UBASH3B+PTPN12
1750 487 123 68
CBL+PTPN12
1750 487 116 64
SAMD1+PTPN12
1750 488 129 81
C3orf72+FOXL2
1750 512 24 14
MAP2K3+MAP2K6
1750 361 13 9
MAPK1+IKZF3
1750 508 224 142
PTPN12+SNAI1
1750 504 85 58
LHX1+ELMSAN1
1750 500 120 75
MAPK1+PRTG
1750 484 149 102
KLF1+MAP2K6
1750 388 48 25
CBL+PTPN9
1750 483 106 66
KIF18B+KIF2C
1750 477 5 3
FOSB+UBASH3B
1750 469 54 38
MAP2K3+ELMSAN1
1750 395 98 51
IGDCC3+ZBTB25
1750 479 173 104
PTPN12+ZBTB25
1750 488 141 77
CNN1+MAPK1
1750 484 92 54
BPGM+SAMD1
1750 489 123 80
KLF1+CEBPA
1750 423 251 136
FOXA1+HOXB9
1750 524 172 116
ZBTB10+PTPN12
1750 533 170 115
CEBPE+PTPN12
1750 490 82 68
LYL1+CEBPB
1750 472 131 90
IGDCC3+PRTG
1750 489 128 79
TMSB4X+BAK1
1750 433 8 1
SNAI1+DLX2
1750 534 126

In [7]:
# Persist both result dicts; np.save appends '.npy' and pickles the dict.
np.save(file='p_values_norman_filter_0.01_gears', arr=p_vals, allow_pickle=True)
np.save(file='jaccards_norman_filter_0.01_gears', arr=jaccards, allow_pickle=True)

In [2]:
import numpy as np


def _load_saved_dict(path):
    """Load the single pickled object stored in the ``.npy`` file at ``path``."""
    # allow_pickle=True unpickles arbitrary objects: use on trusted files only.
    return np.load(path, allow_pickle=True).item()


p_vals = _load_saved_dict('p_values_norman_filter_0.01_gears.npy')
jaccards = _load_saved_dict('jaccards_norman_filter_0.01_gears.npy')

In [8]:
# Report how many perturbations were scored per split and seed, and sum them.
tot_num_tests = 0
for setting in jaccards[1]:
    for seed in range(1, 6):
        n_scored = len(jaccards[seed][setting])
        print(seed, setting, n_scored)
        tot_num_tests = tot_num_tests + n_scored

1 combo_seen0 9
2 combo_seen0 7
3 combo_seen0 2
4 combo_seen0 5
5 combo_seen0 10
1 combo_seen1 52
2 combo_seen1 51
3 combo_seen1 39
4 combo_seen1 44
5 combo_seen1 58
1 combo_seen2 17
2 combo_seen2 18
3 combo_seen2 23
4 combo_seen2 21
5 combo_seen2 16
1 unseen_single 27
2 unseen_single 27
3 unseen_single 27
4 unseen_single 25
5 unseen_single 27


In [9]:
# One value per seed: mean Jaccard over all test perturbations of all splits.
jaccards_list = [
    np.mean(np.hstack(list(per_split.values())))
    for per_split in jaccards.values()
]

In [10]:
# Module-level value; the function below uses its own `split_key` argument.
split_key = 'combo_seen0'

def get_jaccards(split_key):
    """Return the mean Jaccard index of ``split_key`` for each seed.

    Reads the notebook-level ``jaccards`` dict (seed -> split -> list of
    values) and returns one mean per seed, in the dict's key order.
    """
    return [np.mean(per_split[split_key]) for per_split in jaccards.values()]

# Per-seed mean Jaccard for each combination split: label line, then values.
combo_seen2 = get_jaccards('combo_seen2')
print('combo_seen2', combo_seen2, sep='\n')

combo_seen1 = get_jaccards('combo_seen1')
print('combo_seen1', combo_seen1, sep='\n')

combo_seen0 = get_jaccards('combo_seen0')
print('combo_seen0', combo_seen0, sep='\n')

combo_seen2
[0.16141461592110307, 0.1582908402046712, 0.17920395022354837, 0.15522005751020926, 0.19681981729421538]
combo_seen1
[0.15252131079104497, 0.16352731224959907, 0.15135640323052388, 0.14842975120026075, 0.1352018805845803]
combo_seen0
[0.17828557442443146, 0.14964601024565893, 0.15182128690571467, 0.11078377760481421, 0.13037876026871137]


In [11]:
# Pool the p-values of all five seeds into one list per split.
p_vals_all = {split: [] for split in p_vals[1]}

for seed in range(1, 6):
    for split, pooled in p_vals_all.items():
        pooled.extend(p_vals[seed][split])

In [12]:
# Display labels for the split names stored in the 'variable' column.
_SPLIT_LABELS = {'combo_seen0':'0/2 Seen',
                 'combo_seen1':'1/2 Seen',
                 'combo_seen2':'2/2 Seen',
                 'unseen_single':'0/1 Seen'}


def _neg_log10(p):
    """Return -log10(p), offset by 1e-20 so that p == 0 stays finite."""
    return -np.log10(p + 1e-20)


# Build one long-format frame per seed (one row per p-value), then stack them.
dfs = []

for seed in range(1, 6):
    # Splits become columns (shorter ones padded with NaN); melt to long
    # format and drop the padding rows.
    df = (
        pd.DataFrame.from_dict(p_vals[seed], orient='index')
        .T
        .melt()
        .dropna()
    )
    # Applied per element rather than vectorised: a cell may hold a
    # one-element array instead of a plain float.
    df['-log(p)'] = df['value'].apply(_neg_log10)
    df['condition'] = df['variable'].map(_SPLIT_LABELS)
    df['seed'] = seed
    dfs.append(df)

# From here on `dfs` is a single DataFrame, no longer a list.
dfs = pd.concat(dfs)