In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
from data import DataSplitter

# Load the perturbation AnnData (filename suggests Norman 2019, HVGs + perturbed genes — confirm).
# NOTE(review): absolute cluster path; consider a configurable DATA_DIR constant so the notebook runs elsewhere.
adata = sc.read('/dfs/project/perturb-gnn/datasets/Norman2019_hvg+perts.h5ad')

In [3]:
# Number of distinct condition labels (the list below shows it includes 'ctrl').
len(adata.obs.condition.unique())

284

In [11]:
# Labels look like 'GENE+ctrl' / 'ctrl+GENE' (single) or 'GENEA+GENEB' (double), plus 'ctrl'.
adata.obs.condition.unique()

['TSC22D1+ctrl', 'KLF1+MAP2K6', 'ctrl', 'CEBPE+RUNX1T1', 'MAML2+ctrl', ..., 'STIL+ctrl', 'CDKN1C+ctrl', 'ctrl+CDKN1B', 'CDKN1B+CDKN1A', 'C3orf72+FOXL2']
Length: 284
Categories (284, object): ['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', ..., 'ZC3HAV1+HOXC13', 'ZC3HAV1+ctrl', 'ZNF318+FOXL2', 'ZNF318+ctrl']

In [2]:
def print_split_stats(adata, split_column):
    """Print per-split perturbation counts for an AnnData-like object.

    For each of the 'train', 'val' and 'test' splits, reports:
      - the number of unique condition labels (the count includes 'ctrl' if present),
      - how many are single perturbations ('GENE+ctrl' / 'ctrl+GENE'),
      - how many are double perturbations ('GENEA+GENEB'),
      - the number of rows (cells) in the split.

    Args:
        adata: object whose ``obs`` DataFrame has a ``'condition'`` column and
            the ``split_column`` column.
        split_column: name of the ``obs`` column holding the split labels.

    A split label with no rows is reported with zero counts instead of raising.
    """
    obs = adata.obs
    for name in ['train', 'val', 'test']:
        # Boolean mask instead of groupby + identity agg: behaves the same across
        # pandas versions and tolerates a split label that is absent.
        conditions = obs.loc[obs[split_column] == name, 'condition']
        unique_perts = list(conditions.unique())
        single_count = 0
        double_count = 0
        for pert in unique_perts:
            if pert == 'ctrl':
                continue
            # Classify on the '+'-separated tokens rather than a substring test.
            if 'ctrl' in pert.split('+'):
                single_count += 1
            else:
                double_count += 1
        print(name + ' set has ' + str(len(unique_perts)) +
              ' unique perts (# single pert ' + str(single_count) +
              ', # double pert ' + str(double_count) +
              ') with ' + str(len(conditions)) + ' data point')

In [25]:
# 'single' setting: in the recorded run, val/test contain only single perturbations.
D = DataSplitter(adata, 'single')
new_adata = D.split_data(
    test_pert_genes=None,
    test_perts=None,
    split_name='split_new',
    test_size=0.1,
    seed=1,
)
print_split_stats(new_adata, 'split_new')

train set has 178 unique perts (# single pert 111, # double pert 66) with 63964 data point
val set has 18 unique perts (# single pert 18, # double pert 0) with 6286 data point
test set has 23 unique perts (# single pert 23, # double pert 0) with 5778 data point


In [3]:
# 'single_only' setting: in the recorded run, no split contains double perturbations.
D = DataSplitter(adata, 'single_only')
new_adata = D.split_data(
    test_pert_genes=None,
    test_perts=None,
    split_name='split_new',
    test_size=0.1,
    seed=1,
)
print_split_stats(new_adata, 'split_new')

train set has 113 unique perts (# single pert 112, # double pert 0) with 44168 data point
val set has 18 unique perts (# single pert 18, # double pert 0) with 6040 data point
test set has 22 unique perts (# single pert 22, # double pert 0) with 5552 data point


In [10]:
# 'combo' setting with seen=0 and a 25% test fraction.
# Author's note: drop all the singles A/B, and single A/B + C
D = DataSplitter(adata, 'combo', seen=0)
new_adata = D.split_data(
    test_pert_genes=None,
    test_perts=None,
    split_name='split_new',
    test_size=0.25,
    seed=1,
)
print_split_stats(new_adata, 'split_new')

train set has 92 unique perts (# single pert 70, # double pert 21) with 34784 data point
val set has 45 unique perts (# single pert 35, # double pert 10) with 15783 data point
test set has 59 unique perts (# single pert 47, # double pert 12) with 16264 data point


In [28]:
# 'combo' setting with seen=1 and a 10% test fraction.
# Author's note: select A, drop all the singles B, and single B + C
D = DataSplitter(adata, 'combo', seen=1)
new_adata = D.split_data(
    test_pert_genes=None,
    test_perts=None,
    split_name='split_new',
    test_size=0.1,
    seed=1,
)
print_split_stats(new_adata, 'split_new')

train set has 178 unique perts (# single pert 111, # double pert 66) with 63964 data point
val set has 42 unique perts (# single pert 18, # double pert 24) with 13215 data point
test set has 62 unique perts (# single pert 23, # double pert 39) with 13780 data point


In [8]:
# 'combo' setting with seen=2 and a 30% test fraction.
# Author's note: just include everything
D = DataSplitter(adata, 'combo', seen=2)
new_adata = D.split_data(
    test_pert_genes=None,
    test_perts=None,
    split_name='split_new',
    test_size=0.3,
    seed=1,
)
print_split_stats(new_adata, 'split_new')

train set has 224 unique perts (# single pert 152, # double pert 71) with 75081 data point
val set has 25 unique perts (# single pert 0, # double pert 25) with 6600 data point
test set has 35 unique perts (# single pert 0, # double pert 35) with 9524 data point


In [4]:
def print_split_stats_dict(d):
    """Print perturbation counts for a split dictionary (e.g. a saved split pickle).

    Args:
        d: mapping from split name ('train', 'val', 'test') to an iterable of
            condition labels.

    For each split, prints how many unique labels it holds (the count includes
    'ctrl' when present) and how many of them are single ('GENE+ctrl' /
    'ctrl+GENE') versus double ('GENEA+GENEB') perturbations. A split name
    missing from ``d`` is reported as empty instead of raising KeyError.
    """
    for name in ['train', 'val', 'test']:
        # Order-preserving de-duplication so the printed count really is "unique".
        perts = list(dict.fromkeys(d.get(name, [])))
        single_count = 0
        double_count = 0
        for pert in perts:
            if pert == 'ctrl':
                continue
            # Classify on '+'-separated tokens rather than a substring test.
            if 'ctrl' in pert.split('+'):
                single_count += 1
            else:
                double_count += 1
        print(name + ' set has ' + str(len(perts)) +
              ' unique perts: # single pert ' + str(single_count) +
              ', # double pert ' + str(double_count))

In [24]:
import pickle

# NOTE: only unpickle split files you created yourself — pickle.load can execute arbitrary code.
split_path = './splits/Norman2019_combo_seen0_1_0.1.pkl'
# The context manager closes the file handle (a bare open() leaves it to the garbage collector).
with open(split_path, 'rb') as split_file:
    print_split_stats_dict(pickle.load(split_file))

train set has 178 unique perts: # single pert 111, # double pert 66
val set has 22 unique perts: # single pert 18, # double pert 3
test set has 26 unique perts: # single pert 23, # double pert 3


In [30]:
# NOTE: only unpickle split files you created yourself — pickle.load can execute arbitrary code.
split_path = './splits/Norman2019_combo_seen1_1_0.1.pkl'
# The context manager closes the file handle (a bare open() leaves it to the garbage collector).
with open(split_path, 'rb') as split_file:
    print_split_stats_dict(pickle.load(split_file))

train set has 178 unique perts: # single pert 111, # double pert 66
val set has 43 unique perts: # single pert 18, # double pert 24
test set has 62 unique perts: # single pert 23, # double pert 39


In [31]:
# NOTE: only unpickle split files you created yourself — pickle.load can execute arbitrary code.
split_path = './splits/Norman2019_combo_seen2_1_0.1.pkl'
# The context manager closes the file handle (a bare open() leaves it to the garbage collector).
with open(split_path, 'rb') as split_file:
    print_split_stats_dict(pickle.load(split_file))

train set has 260 unique perts: # single pert 152, # double pert 107
val set has 12 unique perts: # single pert 0, # double pert 11
test set has 13 unique perts: # single pert 0, # double pert 13


In [14]:
# Helpers for parsing '+'-joined condition labels.
def parse_single_pert(i):
    """Return the perturbed gene of a single-perturbation label.

    'GENE+ctrl' and 'ctrl+GENE' both yield 'GENE': the first '+'-separated
    token is returned unless it is 'ctrl', in which case the second token is.
    Raises IndexError when the label has no '+'.
    """
    tokens = i.split('+')
    first, second = tokens[0], tokens[1]
    return second if first == 'ctrl' else first

def parse_combo_pert(i):
    """Return the first two '+'-separated tokens of a combo label as a tuple."""
    parts = i.split('+')
    return parts[0], parts[1]


def parse_any_pert(p):
    """Return the list of perturbed genes named in a condition label.

    - single perturbation ('GENE+ctrl' or 'ctrl+GENE') -> ['GENE']
    - combo perturbation ('GENEA+GENEB')               -> ['GENEA', 'GENEB']
    - 'ctrl' (nothing perturbed)                        -> None

    The label is split on '+' and 'ctrl' tokens are dropped. This avoids the
    previous substring test, which misrouted any label that merely contained
    'ctrl'. The 'ctrl' case now returns None explicitly instead of falling off
    the end of the function.
    """
    genes = [token for token in p.split('+') if token != 'ctrl']
    if not genes:
        return None
    return genes
        

In [5]:
# Every perturbation condition: all distinct labels except the control.
pert_list = list(filter(lambda cond: cond != 'ctrl',
                        D.adata.obs['condition'].unique()))

In [6]:
# Genes appearing in pert_list (per the helper's name; DataSplitter is not shown here — confirm).
unique_pert_genes = D.get_genes_from_perts(pert_list)

In [7]:
len(unique_pert_genes)

105

In [8]:
# Full array of gene names (recorded output below is truncated).
unique_pert_genes

array(['AHR', 'ARID1A', 'ARRDC3', 'ATL1', 'BAK1', 'BCL2L11', 'BCORL1',
       'BPGM', 'C19orf26', 'C3orf72', 'CBFA2T3', 'CBL', 'CDKN1A',
       'CDKN1B', 'CDKN1C', 'CEBPA', 'CEBPB', 'CEBPE', 'CELF2', 'CITED1',
       'CKS1B', 'CLDN6', 'CNN1', 'CNNM4', 'COL1A1', 'COL2A1', 'CSRNP1',
       'DLX2', 'DUSP9', 'EGR1', 'ELMSAN1', 'ETS2', 'FEV', 'FOSB', 'FOXA1',
       'FOXA3', 'FOXF1', 'FOXL2', 'FOXO4', 'GLB1L2', 'HES7', 'HK2',
       'HNF4A', 'HOXA13', 'HOXB9', 'HOXC13', 'IER5L', 'IGDCC3', 'IKZF3',
       'IRF1', 'ISL2', 'JUN', 'KIAA1804', 'KIF18B', 'KIF2C', 'KLF1',
       'KMT2A', 'LHX1', 'LYL1', 'MAML2', 'MAP2K3', 'MAP2K6', 'MAP4K3',
       'MAP4K5', 'MAP7D1', 'MAPK1', 'MEIS1', 'MIDN', 'NCL', 'NIT1',
       'OSR2', 'PLK4', 'POU3F2', 'PRDM1', 'PRTG', 'PTPN1', 'PTPN12',
       'PTPN13', 'PTPN9', 'RHOXF2BB', 'RREB1', 'RUNX1T1', 'S1PR2',
       'SAMD1', 'SET', 'SGK1', 'SLC38A2', 'SLC4A1', 'SLC6A9', 'SNAI1',
       'SPI1', 'STIL', 'TBX2', 'TBX3', 'TGFBR2', 'TMSB4X', 'TP73',
       'TSC22D1', 'U

In [98]:
# NOTE(review): this re-defines parse_combo_pert from an earlier cell with the
# same behavior; one of the two definitions can be removed.
def parse_combo_pert(i):
    """Return the first two '+'-separated tokens of a combo label as a tuple."""
    tokens = i.split('+', 2)
    return tokens[0], tokens[1]

In [161]:
# Parameters for the hand-rolled split prototype below (it appears to mirror
# DataSplitter's 'simulation' split used later — confirm).
train_gene_set_size = 0.7      # fraction of perturbed genes sampled as training genes
combo_seen2_train_frac = 0.7   # fraction of both-genes-seen combos sent to train (used in a later cell)
seed = 1

# Output groups; the combo_seen* lists are reassigned by later cells.
pert_train = []
unseen_single = []
combo_seen0 = []
combo_seen1 = []
combo_seen2 = []
np.random.seed(seed=seed)
# Sample the training gene set without replacement.
train_gene_candidates = np.random.choice(unique_pert_genes,
                                        int(len(unique_pert_genes) * train_gene_set_size),
                                        replace=False)

In [162]:
# Perturbed genes not sampled into the training gene set ("out-of-distribution" genes).
ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)

In [163]:
len(train_gene_candidates)

73

In [164]:
# Same as above: the sample has no duplicates (replace=False).
len(np.unique(train_gene_candidates))

73

In [165]:
len(pert_list)

283

In [166]:
# Same as len(pert_list): perturbation labels are distinct.
len(np.unique(pert_list))

283

In [167]:
# Single perturbations of training genes go straight to train; combos returned for
# the training genes are triaged in the following cells.
pert_single_train = D.get_perts_from_genes(train_gene_candidates, pert_list,'single')
pert_combo = D.get_perts_from_genes(train_gene_candidates, pert_list,'combo')
# NOTE(review): extend is not idempotent — re-running this cell duplicates entries in pert_train.
pert_train.extend(pert_single_train)

In [168]:
len(pert_single_train)

108

In [169]:
len(pert_combo)

120

In [170]:
# combo_seen1: combos in which exactly one '+'-separated gene is a training gene.
combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if
                                     t in train_gene_candidates]) == 1]

In [171]:
len(combo_seen1)

61

In [172]:
# Drop them from pert_combo. np.setdiff1d returns a sorted array of unique values, not a list.
# NOTE(review): rebinds pert_combo, so re-running the previous cell afterwards sees the reduced set.
pert_combo = np.setdiff1d(pert_combo, combo_seen1)

In [173]:
len(pert_combo)

59

In [174]:
# Re-seed so this combo sample is reproducible on its own.
np.random.seed(seed=seed)
# Send a combo_seen2_train_frac share of the remaining combos to train.
pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False)

In [175]:
# The remaining combos are held out as combo_seen2.
combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train)

In [176]:
# NOTE(review): extend is not idempotent — re-running this cell appends the combos to pert_train again.
pert_train.extend(pert_combo_train)

In [177]:
len(pert_combo)

59

In [178]:
# 149 = 108 single train perturbations + 41 sampled combos.
len(pert_train)

149

In [179]:
# 18 = 59 - 41 held-out combos.
len(combo_seen2)

18

In [188]:
len(pert_combo_train)

41

In [180]:
# 228 = 108 singles + 120 combos returned for the training genes.
len(D.get_perts_from_genes(train_gene_candidates, pert_list,'both'))

228

In [181]:
# Single perturbations of the held-out (OOD) genes.
unseen_single = D.get_perts_from_genes(ood_genes, pert_list, 'single')

In [182]:
len(unseen_single)

44

In [183]:
len(np.unique(unseen_single))

44

In [184]:
# Combos returned for the OOD genes (presumably those containing at least one OOD gene — confirm).
combo_ood = D.get_perts_from_genes(ood_genes, pert_list, 'combo')

In [185]:
# combo_seen0: combos in which no '+'-separated gene is a training gene.
combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if
                                     t in train_gene_candidates]) == 0]

In [186]:
# Sanity check: the five groups must partition pert_list (disjoint and complete).
# Displays the total, which should equal len(pert_list) (283 in the recorded run).
split_groups = [combo_seen1, combo_seen0, unseen_single, pert_train, combo_seen2]
n_assigned = sum(len(group) for group in split_groups)
assert n_assigned == len(pert_list), 'group sizes do not add up to len(pert_list)'
assert set().union(*(set(group) for group in split_groups)) == set(pert_list), \
    'groups do not cover exactly the perturbations in pert_list'
n_assigned

283

In [187]:
# 72 is consistent with combo_seen1 (61) + combo_seen0 (283 - 61 - 44 - 149 - 18 = 11).
len(combo_ood)

72

In [189]:
len(pert_train)

149

In [2]:
# Library 'simulation' split (appears to implement the prototype above, with 0.75 fractions).
DS = DataSplitter(adata, split_type='simulation')
                
# NOTE(review): this rebinds `adata` to the object returned by split_data, so re-running
# the cell passes that result (not the freshly loaded data) back into DataSplitter.
# subgroup['test_subgroup'] maps subgroup names to test perturbations (used below).
adata, subgroup = DS.split_data(train_gene_set_size = 0.75, 
                                combo_seen2_train_frac = 0.75,
                                seed=1)

In [10]:
# Number of cells per split label.
adata.obs.split.value_counts()

train    51283
test     31898
val       8024
Name: split, dtype: int64

In [13]:
# Distinct conditions in the test split.
test_perts = adata.obs.loc[adata.obs.split == 'test', 'condition'].unique()

In [14]:
# Placeholder metrics (the same r2 / r2_de for every test perturbation) used to
# exercise the subgroup aggregation below.
test_pert_res = {pert: {'r2': 0.35, 'r2_de': 0.58} for pert in test_perts}

In [19]:
# Metrics dict of the first test perturbation (shows which metric keys exist).
list(test_pert_res.values())[0]

{'r2': 0.35, 'r2_de': 0.58}

In [20]:
# Average each metric over the perturbations in every test subgroup and print it.
# Metric names come from the first per-perturbation result; a subgroup with no
# perturbations yields NaN (np.mean of an empty list), as before.
metric_names = list(next(iter(test_pert_res.values())).keys())
subgroup_analysis = {}
# Loop variable is `subgroup_perts` so the global pert_list from earlier cells is not clobbered.
for name, subgroup_perts in subgroup['test_subgroup'].items():
    scores = {m: [] for m in metric_names}
    for pert in subgroup_perts:
        for m, res in test_pert_res[pert].items():
            scores[m].append(res)
    subgroup_analysis[name] = {m: np.mean(vals) for m, vals in scores.items()}

for name, result in subgroup_analysis.items():
    for m, value in result.items():
        print('test_' + name + '_' + m + ': ' + str(value))

test_combo_seen0_r2: 0.35
test_combo_seen0_r2_de: 0.58
test_combo_seen1_r2: 0.35000000000000014
test_combo_seen1_r2_de: 0.5799999999999998
test_combo_seen2_r2: 0.3499999999999999
test_combo_seen2_r2_de: 0.58
test_unseen_single_r2: 0.3499999999999999
test_unseen_single_r2_de: 0.5799999999999997


In [None]:

# Example of a per-perturbation results structure: condition label -> metric name -> value.
# The '_de' metrics are presumably restricted to differentially expressed genes — confirm.
# This cell has not been executed (In [None]); the literal is illustrative only.
{'AHR+FEV': {'r2': 0.9321523789238141,
  'mse': 0.011103348,
  'r2_de': 0.6442925181646209,
  'mse_de': 0.92472154},
 'AHR+KLF1': {'r2': 0.966444615617381,
  'mse': 0.0054913764,
  'r2_de': -0.8421060510555765,
  'mse_de': 0.6459228}
}