In [2]:
import numpy as np
import torch
import scanpy as sc
from data import PertDataloader, Network


name2path = {
    'GNN_Disentangle-L2': 'GNN_Disentangle_GAT_string_20.0_64_2_l2_Norman2019_gene_emb_pert_emb_constant_sim_gnn',
    'GNN_Disentangle_Sim': 'GNN_Disentangle_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn',
    'No-Perturb': 'No_Perturb_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn'
}

name = 'No-Perturb'

model_name = name2path[name]
args = np.load('./saved_args/'+model_name+'.npy', allow_pickle = True).item()
args['device'] = 'cuda:3'

In [3]:
#import wandb
#wandb.init(project='pert_gnn_simulation', entity='kexinhuang', name=name)

In [4]:
if args['network_name'] == 'string':
    args['network_path'] = '/dfs/project/perturb-gnn/graphs/STRING_full_9606.csv'

if args['dataset'] == 'Norman2019':
    data_path = '/dfs/project/perturb-gnn/datasets/Norman2019_hvg+perts_more_de.h5ad'

adata = sc.read_h5ad(data_path)
if 'gene_symbols' not in adata.var.columns.values:
    adata.var['gene_symbols'] = adata.var['gene_name']
gene_list = [f for f in adata.var.gene_symbols.values]
# Set up message passing network
network = Network(fname=args['network_path'], gene_list=args['gene_list'],
                  percentile=args['top_edge_percent'])

# Pertrubation dataloader
pertdl = PertDataloader(adata, network.G, network.weights, args)

There are 101013 edges in the PPI.
Creating pyg object for each cell in the data...
Local copy of pyg dataset is detected. Loading...
Loading splits...
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Creating dataloaders....
Dataloaders created...


In [5]:
model = torch.load('./saved_models/' + model_name)
#model.pert_emb_agg = 'constant'
#model.lambda_emission = False
#model.sim_gnn = False
#model.args = args
from inference import evaluate, compute_metrics, deeper_analysis, GI_subgroup

test_res = evaluate(pertdl.loaders['test_loader'],
                        pertdl.loaders['edge_index'],
                        pertdl.loaders['edge_attr'], model, args)

test_metrics, test_pert_res = compute_metrics(test_res)



In [5]:
import pickle
metrics = ['mse', 'mae', 'spearman', 'pearson', 'r2']
for m in metrics:
    wandb.log({'test_' + m: test_metrics[m],
               'test_de_'+m: test_metrics[m + '_de']                     
              })
subgroup_path = './splits/Norman2019_simulation_1_0.1_subgroup.pkl'
subgroup = pickle.load(open(subgroup_path, "rb"))
        
subgroup_analysis = {}
for name in subgroup['test_subgroup'].keys():
    subgroup_analysis[name] = {}
    for m in list(list(test_pert_res.values())[0].keys()):
        subgroup_analysis[name][m] = []

for name, pert_list in subgroup['test_subgroup'].items():
    for pert in pert_list:
        for m, res in test_pert_res[pert].items():
            subgroup_analysis[name][m].append(res)

for name, result in subgroup_analysis.items():
    for m in result.keys():
        subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m])
        wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]})

        print('test_' + name + '_' + m + ': ' + str(subgroup_analysis[name][m]))

test_combo_seen0_mse: 0.006539415
test_combo_seen0_mae: 0.024473794
test_combo_seen0_spearman: 0.8554495607638996
test_combo_seen0_pearson: 0.9798589015861244
test_combo_seen0_r2: 0.9598484850494818
test_combo_seen0_mse_de: 0.38452822
test_combo_seen0_mae_de: 0.4249717
test_combo_seen0_spearman_de: 0.6664525191936057
test_combo_seen0_pearson_de: 0.7978656813909297
test_combo_seen0_r2_de: 0.21557548380438085
test_combo_seen1_mse: 0.008615982
test_combo_seen1_mae: 0.026082836
test_combo_seen1_spearman: 0.8544749063877007
test_combo_seen1_pearson: 0.9739889917017668
test_combo_seen1_r2: 0.9473155717003212
test_combo_seen1_mse_de: 0.5381686
test_combo_seen1_mae_de: 0.5313102
test_combo_seen1_spearman_de: 0.670112338770854
test_combo_seen1_pearson_de: 0.7268227974597503
test_combo_seen1_r2_de: -0.1511186308570229
test_combo_seen2_mse: 0.005411647
test_combo_seen2_mae: 0.019519575
test_combo_seen2_spearman: 0.8663311193796182
test_combo_seen2_pearson: 0.9836433716631533
test_combo_seen2_r2: 

In [6]:
out = deeper_analysis(adata, test_res)
GI_out = GI_subgroup(out)

  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


In [7]:
metrics = ['frac_in_range', 'mean_sigma', 'std_sigma', 'frac_sigma_below_1', 'frac_sigma_below_2', 
          'spearman_delta', 'spearman_delta_de', 'pearson_delta', 'pearson_delta_de', 'fold_change_gap_all', 
          'spearman_delta_top200_hvg', 'pearson_delta_top200_hvg', 'fold_change_gap_upreg_3', 'fold_change_gap_downreg_0.33',
          'fold_change_gap_downreg_0.1', 'fold_change_gap_upreg_10', 'spearman_top200_hvg', 'pearson_top200_hvg', 
          'pearson_top200_de', 'spearman_top200_de', 'pearson_delta_top200_de', 'spearman_delta_top200_de',
          'pearson_top100_de', 'spearman_top100_de', 'pearson_delta_top100_de', 'spearman_delta_top100_de']

for m in metrics:
    wandb.log({'test_' + m: np.mean([j[m] for i,j in out.items() if m in j])})

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [8]:
subgroup_path = './splits/Norman2019_simulation_1_0.1_subgroup.pkl'
subgroup = pickle.load(open(subgroup_path, "rb"))
        
subgroup_analysis = {}
for name in subgroup['test_subgroup'].keys():
    subgroup_analysis[name] = {}
    for m in metrics:
        subgroup_analysis[name][m] = []

for name, pert_list in subgroup['test_subgroup'].items():
    for pert in pert_list:
        for m, res in out[pert].items():
            subgroup_analysis[name][m].append(res)

for name, result in subgroup_analysis.items():
    for m in result.keys():
        subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m])
        wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]})

        print('test_' + name + '_' + m + ': ' + str(subgroup_analysis[name][m]))

test_combo_seen0_frac_in_range: 1.0
test_combo_seen0_mean_sigma: 1.1593934
test_combo_seen0_std_sigma: 0.183148
test_combo_seen0_frac_sigma_below_1: 0.2971749226006192
test_combo_seen0_frac_sigma_below_2: 0.9875
test_combo_seen0_spearman_delta: 0.12472215133661285
test_combo_seen0_spearman_delta_de: 0.8058479532163743
test_combo_seen0_pearson_delta: 0.004367435041405094
test_combo_seen0_pearson_delta_de: 0.8607025363597248
test_combo_seen0_fold_change_gap_all: 0.972206
test_combo_seen0_spearman_delta_top200_hvg: 0.0019227147345350336
test_combo_seen0_pearson_delta_top200_hvg: -0.0227856575276294
test_combo_seen0_fold_change_gap_upreg_3: 71.17499
test_combo_seen0_fold_change_gap_downreg_0.33: nan
test_combo_seen0_fold_change_gap_downreg_0.1: nan
test_combo_seen0_fold_change_gap_upreg_10: 265.04575
test_combo_seen0_spearman_top200_hvg: 0.9271347596043279
test_combo_seen0_pearson_top200_hvg: 0.9379116568866616
test_combo_seen0_pearson_top200_de: 0.8966080945853451
test_combo_seen0_spearma

In [9]:
for i,j in GI_out.items():
    for m in ['mean_sigma', 'std_sigma', 'fold_change_gap_all', 'pearson_delta_top200_de', 
              'spearman_delta_top200_hvg', 'pearson_delta_top200_hvg', 
              'spearman_delta_top200_de', 'pearson_delta_top100_de', 'spearman_delta_top100_de']:
        wandb.log({'test_' + i + '_' + m: j[m]})

In [10]:
for i,j in  GI_subgroup(test_pert_res).items():
    for m in ['mse_de']:
        wandb.log({'test_' + i + '_' + m: j[m]})

In [27]:
np.quantile(test_res['truth_de'][pert_idx], 0.25, axis = 0)

array([1.98486948, 0.        , 0.        , 0.62915419, 1.09004056,
       1.97734478, 0.        , 0.67672339, 0.5912282 , 0.52250315,
       0.        , 0.        , 5.06873918, 0.        , 0.        ,
       0.99365979, 1.06806734, 0.        , 0.        , 2.5388093 ])

In [35]:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error as mse

metric2fct = {
       #'spearman': spearmanr, # not meaningful
       'pearson': pearsonr,
       'mse': mse
}

pert_metric = {}

## in silico modeling and upperbounding
pert2pert_full_id = dict(adata.obs[['condition', 'cov_drug_dose_name']].values)
geneid2name = dict(zip(adata.var.index.values, adata.var['gene_name']))
geneid2idx = dict(zip(adata.var.index.values, range(len(adata.var.index.values))))

# calculate mean expression for each condition
unique_conditions = adata.obs.condition.unique()
conditions2index = {}
for i in unique_conditions:
    conditions2index[i] = np.where(adata.obs.condition == i)[0]

condition2mean_expression = {}
for i, j in conditions2index.items():
    condition2mean_expression[i] = np.mean(adata.X[j], axis = 0)
pert_list = np.array(list(condition2mean_expression.keys()))
mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1])
ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]]
most_variable_genes = np.argsort(np.std(mean_expression, axis = 0))[-200:]
gene_list = adata.var['gene_name'].values

for pert in np.unique(test_res['pert_cat']):
    pert_metric[pert] = {}
    #de_names = [geneid2name[i] for i in adata.uns['rank_genes_groups_cov'][pert2pert_full_id[pert]]]
    de_idx = [geneid2idx[i] for i in adata.uns['rank_genes_groups_cov'][pert2pert_full_id[pert]]]
    de_idx_200 = [geneid2idx[i] for i in adata.uns['rank_genes_groups_cov_top200'][pert2pert_full_id[pert]]]
    de_idx_100 = [geneid2idx[i] for i in adata.uns['rank_genes_groups_cov_top100'][pert2pert_full_id[pert]]]
    de_idx_50 = [geneid2idx[i] for i in adata.uns['rank_genes_groups_cov_top50'][pert2pert_full_id[pert]]]

    pert_idx = np.where(test_res['pert_cat'] == pert)[0]    
    pred_mean = np.mean(test_res['pred_de'][pert_idx], axis = 0).reshape(-1,)
    true_mean = np.mean(test_res['truth_de'][pert_idx], axis = 0).reshape(-1,)

    mean = np.mean(test_res['truth_de'][pert_idx], axis = 0)
    std = np.std(test_res['truth_de'][pert_idx], axis = 0)
    min_ = np.min(test_res['truth_de'][pert_idx], axis = 0)
    max_ = np.max(test_res['truth_de'][pert_idx], axis = 0)
    q25 = np.quantile(test_res['truth_de'][pert_idx], 0.25, axis = 0)
    q75 = np.quantile(test_res['truth_de'][pert_idx], 0.75, axis = 0)
    q55 = np.quantile(test_res['truth_de'][pert_idx], 0.55, axis = 0)
    q45 = np.quantile(test_res['truth_de'][pert_idx], 0.45, axis = 0)
    q40 = np.quantile(test_res['truth_de'][pert_idx], 0.4, axis = 0)
    q60 = np.quantile(test_res['truth_de'][pert_idx], 0.6, axis = 0)
    
    zero_des = np.intersect1d(np.where(min_ == 0)[0], np.where(max_ == 0)[0])
    nonzero_des = np.setdiff1d(list(range(20)), zero_des)
    if len(nonzero_des) == 0:
        pass
        # pert that all de genes are 0...
    else:
        in_range = (pred_mean[nonzero_des] >= min_[nonzero_des]) & (pred_mean[nonzero_des] <= max_[nonzero_des])
        frac_in_range = sum(in_range)/len(nonzero_des)
        pert_metric[pert]['frac_in_range'] = frac_in_range
        
        in_range_5 = (pred_mean[nonzero_des] >= q45[nonzero_des]) & (pred_mean[nonzero_des] <= q55[nonzero_des])
        frac_in_range_45_55 = sum(in_range_5)/len(nonzero_des)
        pert_metric[pert]['frac_in_range_45_55'] = frac_in_range_45_55
        
        in_range_10 = (pred_mean[nonzero_des] >= q40[nonzero_des]) & (pred_mean[nonzero_des] <= q60[nonzero_des])
        frac_in_range_40_60 = sum(in_range_10)/len(nonzero_des)
        pert_metric[pert]['frac_in_range_40_60'] = frac_in_range_40_60
        
        in_range_25 = (pred_mean[nonzero_des] >= q25[nonzero_des]) & (pred_mean[nonzero_des] <= q75[nonzero_des])
        frac_in_range_25_75 = sum(in_range_25)/len(nonzero_des)
        pert_metric[pert]['frac_in_range_25_75'] = frac_in_range_25_75
        
        zero_idx = np.where(std > 0)[0]
        sigma = (np.abs(pred_mean[zero_idx] - mean[zero_idx]))/(std[zero_idx])
        pert_metric[pert]['mean_sigma'] = np.mean(sigma)
        pert_metric[pert]['std_sigma'] = np.std(sigma)
        pert_metric[pert]['frac_sigma_below_1'] = 1 - len(np.where(sigma > 1)[0])/len(zero_idx)
        pert_metric[pert]['frac_sigma_below_2'] = 1 - len(np.where(sigma > 2)[0])/len(zero_idx)

    ## correlation on delta
    p_idx = np.where(test_res['pert_cat'] == pert)[0]

    for m, fct in metric2fct.items():
        if m != 'mse':
            val = fct(test_res['pred'][p_idx].mean(0)- ctrl[0], test_res['truth'][p_idx].mean(0)-ctrl[0])[0]
            if np.isnan(val):
                val = 0

            pert_metric[pert][m + '_delta'] = val

            val = fct(test_res['pred_de'][p_idx].mean(0)- ctrl[0][de_idx], test_res['truth_de'][p_idx].mean(0)-ctrl[0][de_idx])[0]
            if np.isnan(val):
                val = 0

            pert_metric[pert][m + '_delta_de'] = val

    ## up fold changes > 10?
    pert_mean = np.mean(test_res['truth'][p_idx], axis = 0).reshape(-1,)

    fold_change = pert_mean/ctrl
    fold_change[np.isnan(fold_change)] = 0
    fold_change[np.isinf(fold_change)] = 0
    ## this is to remove the ones that are super low and the fold change becomes unmeaningful
    fold_change[0][np.where(pert_mean < 0.5)[0]] = 0

    o =  np.where(fold_change[0] > 0)[0]

    pred_fc = test_res['pred'][p_idx].mean(0)[o]
    true_fc = test_res['truth'][p_idx].mean(0)[o]
    ctrl_fc = ctrl[0][o]

    if len(o) > 0:
        pert_metric[pert]['fold_change_gap_all'] = np.mean(np.abs(pred_fc/ctrl_fc - true_fc/ctrl_fc))


    o = np.intersect1d(np.where(fold_change[0] <0.333)[0], np.where(fold_change[0] > 0)[0])

    pred_fc = test_res['pred'][p_idx].mean(0)[o]
    true_fc = test_res['truth'][p_idx].mean(0)[o]
    ctrl_fc = ctrl[0][o]

    if len(o) > 0:
        pert_metric[pert]['fold_change_gap_downreg_0.33'] = np.mean(np.abs(pred_fc/ctrl_fc - true_fc/ctrl_fc))


    o = np.intersect1d(np.where(fold_change[0] <0.1)[0], np.where(fold_change[0] > 0)[0])

    pred_fc = test_res['pred'][p_idx].mean(0)[o]
    true_fc = test_res['truth'][p_idx].mean(0)[o]
    ctrl_fc = ctrl[0][o]

    if len(o) > 0:
        pert_metric[pert]['fold_change_gap_downreg_0.1'] = np.mean(np.abs(pred_fc/ctrl_fc - true_fc/ctrl_fc))

    o = np.where(fold_change[0] > 3)[0]

    pred_fc = test_res['pred'][p_idx].mean(0)[o]
    true_fc = test_res['truth'][p_idx].mean(0)[o]
    ctrl_fc = ctrl[0][o]

    if len(o) > 0:
        pert_metric[pert]['fold_change_gap_upreg_3'] = np.mean(np.abs(pred_fc/ctrl_fc - true_fc/ctrl_fc))

    o = np.where(fold_change[0] > 10)[0]

    pred_fc = test_res['pred'][p_idx].mean(0)[o]
    true_fc = test_res['truth'][p_idx].mean(0)[o]
    ctrl_fc = ctrl[0][o]

    if len(o) > 0:
        pert_metric[pert]['fold_change_gap_upreg_10'] = np.mean(np.abs(pred_fc/ctrl_fc - true_fc/ctrl_fc))

    ## most variable genes
    for m, fct in metric2fct.items():
        if m != 'mse':
            val = fct(test_res['pred'][p_idx].mean(0)[most_variable_genes] - ctrl[0][most_variable_genes], test_res['truth'][p_idx].mean(0)[most_variable_genes]-ctrl[0][most_variable_genes])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_delta_top200_hvg'] = val


            val = fct(test_res['pred'][p_idx].mean(0)[most_variable_genes], test_res['truth'][p_idx].mean(0)[most_variable_genes])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_top200_hvg'] = val
        else:
            val = fct(test_res['pred'][p_idx].mean(0)[most_variable_genes], test_res['truth'][p_idx].mean(0)[most_variable_genes])
            pert_metric[pert][m + '_top200_hvg'] = val


    ## top 50/100/200 DEs
    for m, fct in metric2fct.items():
        if m != 'mse':
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_200] - ctrl[0][de_idx_200], test_res['truth'][p_idx].mean(0)[de_idx_200]-ctrl[0][de_idx_200])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_delta_top200_de'] = val


            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_200], test_res['truth'][p_idx].mean(0)[de_idx_200])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_top200_de'] = val
        else:
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_200] - ctrl[0][de_idx_200], test_res['truth'][p_idx].mean(0)[de_idx_200]-ctrl[0][de_idx_200])
            pert_metric[pert][m + '_top200_de'] = val

    for m, fct in metric2fct.items():
        if m != 'mse':
        
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_100] - ctrl[0][de_idx_100], test_res['truth'][p_idx].mean(0)[de_idx_100]-ctrl[0][de_idx_100])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_delta_top100_de'] = val


            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_100], test_res['truth'][p_idx].mean(0)[de_idx_100])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_top100_de'] = val
        else:
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_100] - ctrl[0][de_idx_100], test_res['truth'][p_idx].mean(0)[de_idx_100]-ctrl[0][de_idx_100])
            pert_metric[pert][m + '_top100_de'] = val
            
    for m, fct in metric2fct.items():
        if m != 'mse':
        
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_50] - ctrl[0][de_idx_50], test_res['truth'][p_idx].mean(0)[de_idx_50]-ctrl[0][de_idx_50])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_delta_top50_de'] = val


            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_50], test_res['truth'][p_idx].mean(0)[de_idx_50])[0]
            if np.isnan(val):
                val = 0
            pert_metric[pert][m + '_top50_de'] = val
        else:
            val = fct(test_res['pred'][p_idx].mean(0)[de_idx_50] - ctrl[0][de_idx_50], test_res['truth'][p_idx].mean(0)[de_idx_50]-ctrl[0][de_idx_50])
            pert_metric[pert][m + '_top50_de'] = val



  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_c

  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_c

  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_c

  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


In [37]:
pert_metric

{'AHR+KLF1': {'frac_in_range': 1.0,
  'frac_in_range_45_55': 0.0,
  'frac_in_range_40_60': 0.0,
  'frac_in_range_25_75': 0.15,
  'mean_sigma': 1.1150635,
  'std_sigma': 0.40569043,
  'frac_sigma_below_1': 0.4,
  'frac_sigma_below_2': 0.95,
  'pearson_delta': 0.022298707646802935,
  'pearson_delta_de': 0.960052268454196,
  'fold_change_gap_all': 0.1878515,
  'fold_change_gap_upreg_3': 2.952547,
  'pearson_delta_top200_hvg': 0.10723303505101547,
  'pearson_top200_hvg': 0.9595004643472538,
  'mse_top200_hvg': 0.08491335,
  'pearson_delta_top200_de': 0.09362745184777875,
  'pearson_top200_de': 0.9622425492148956,
  'mse_top200_de': 0.09853176,
  'pearson_delta_top100_de': 0.1683593718888371,
  'pearson_top100_de': 0.9381953503523631,
  'mse_top100_de': 0.17942572,
  'pearson_delta_top50_de': 0.23618640501101235,
  'pearson_top50_de': 0.9026560662482132,
  'mse_top50_de': 0.3203701},
 'ARID1A+ctrl': {'frac_in_range': 1.0,
  'frac_in_range_45_55': 0.0,
  'frac_in_range_40_60': 0.0,
  'frac_i

In [36]:
fold_change

array([[0.       , 0.       , 0.       , ..., 0.9791699, 0.       ,
        0.       ]], dtype=float32)

In [19]:
fold_change_gap_downreg_num = 0
fold_change_gap_upreg_num = 0

for i,j in pert_metric.items():
    if 'fold_change_gap_downreg_0.33' in j:
        fold_change_gap_downreg_num += 1
    elif 'fold_change_gap_upreg_3' in j:
        fold_change_gap_upreg_num += 1

In [20]:
fold_change_gap_upreg_num

102

In [21]:
fold_change_gap_downreg_num

3