In [1]:
### Script for computing GI genes
## Designed for the leave-one-out (LOO) training/testing strategy

import glob
import numpy as np
import torch
import scanpy as sc
import pandas as pd 
import copy
import sys
import os

sys.path.append('../../gears_misc/')
from gears import PertData, GEARS
from gears.inference import GIs

#from data import PertDataloader
#from inference import evaluate, compute_metrics

from scipy.stats import pearsonr
import pdb

# Linear model fitting functions
import statsmodels.api as sm
from sklearn.linear_model import LinearRegression, TheilSenRegressor
from dcor import distance_correlation, partial_distance_correlation
from sklearn.metrics import r2_score


In [2]:
def get_ctrl_pert(x):
    """
    Return the single-perturbation condition label that pairs gene `x` with 'ctrl'.

    Sometimes 'ctrl' is first ('ctrl+X') and sometimes it is second ('X+ctrl'),
    so the label is looked up in the module-level `all_conditions` instead of
    being constructed.

    Each condition is split on '+' and compared token by token. A plain
    substring test would let 'ZBTB1' match 'ZBTB10+ctrl' and return the wrong
    single perturbation.

    Raises
    ------
    IndexError
        If no condition in `all_conditions` contains both `x` and 'ctrl'.
    """
    for cond in all_conditions:
        tokens = cond.split('+')
        if x in tokens and 'ctrl' in tokens:
            return cond
    raise IndexError('no single-perturbation (ctrl) condition found for ' + repr(x))

def create_output_df(pert, model, split='truth'):
    """
    Build a one-row DataFrame of per-gene expression for a single condition.

    Parameters
    ----------
    pert : str
        Condition label, e.g. 'A+B' or 'A+ctrl'.
    model : object
        Must expose `.adata` (indexable by a boolean cell mask, with
        `.obs.condition`, `.X.toarray()` and `.var.gene_name`) and `.predict`.
    split : {'truth', 'pred'}
        'truth' -> mean over axis 0 of the cells whose condition equals `pert`.
        'pred'  -> `model.predict` output for the non-'ctrl' genes of `pert`
                   (looked up under the key formed by joining them with '_').

    Returns
    -------
    pandas.DataFrame
        One row, one column per gene name, plus a 'condition' column set to `pert`.

    Raises
    ------
    ValueError
        If `split` is neither 'truth' nor 'pred' (previously this fell through
        and failed later with an UnboundLocalError).
    """
    if split == 'truth':
        out = model.adata[model.adata.obs.condition == pert].X.toarray().mean(0)
    elif split == 'pred':
        # 'ctrl' is a placeholder, not a gene, so it is not passed to the model
        p_ = [x for x in pert.split('+') if x != 'ctrl']
        out = model.predict([p_])['_'.join(p_)]
    else:
        raise ValueError("split must be 'truth' or 'pred', got " + repr(split))

    out_df = pd.DataFrame(out).T
    out_df.columns = model.adata.var.gene_name
    out_df['condition'] = pert
    return out_df

def get_combo_adata_pred(combo, out_df, DE_genes, control_adata):
    """
    Control-subtracted expression profiles for a combo perturbation and its
    two component single perturbations, restricted to `DE_genes`.

    Rows for both single-gene conditions and for `combo` are taken from
    `out_df`, the per-gene mean of the control cells is subtracted, and only
    the `DE_genes` columns are returned.

    NOTE(review): this reads a global `pert_adata` (used to select the control
    genes) that is not defined in the visible part of this notebook -- confirm
    it exists before calling.
    """
    ctrl_view = control_adata[:, pert_adata.var.index]

    first_gene, second_gene = combo.split('+')
    wanted_rows = [get_ctrl_pert(first_gene), get_ctrl_pert(second_gene), combo]

    selected = out_df.loc[wanted_rows]
    shifted = selected - ctrl_view.to_df().mean(0)
    return shifted.loc[:, DE_genes]

# Get DE means
def get_de_exp(pert, mean_ctrl_exp):
    """
    Select from `mean_ctrl_exp` the entries for the genes that
    `get_covar_genes` returns for `pert`.

    NOTE(review): `get_covar_genes` and the global `gene_name_dict` are not
    defined in the visible notebook cells -- confirm they are available.
    """
    selected_genes = get_covar_genes(pert, gene_name_dict=gene_name_dict)
    ctrl_subset = mean_ctrl_exp[selected_genes]
    return ctrl_subset

def get_single_name(g, adata):
    """
    Return the single-perturbation label for gene `g`.

    '<g>+ctrl' is returned when that label occurs in
    `adata.obs.condition.values`; otherwise 'ctrl+<g>' is returned
    (without checking that it exists).
    """
    known_conditions = adata.obs.condition.values
    gene_first = g + '+ctrl'
    return gene_first if gene_first in known_conditions else 'ctrl+' + g

def combine_train_val(train_res, val_res):
    """
    Extend every array in `train_res` with the matching array from `val_res`.

    `train_res` is updated in place (each value is replaced by the
    concatenation) and the same dict object is returned. Every key of
    `train_res` must also be present in `val_res`.
    """
    # Only values of existing keys are reassigned, so iterating items() is safe.
    for key, train_vals in train_res.items():
        train_res[key] = np.concatenate((train_vals, val_res[key]))
    return train_res


## GI calculation functions
def get_high_umi_idx(adata):
    """
    Positions (into `adata.var.gene_name`) of the genes listed in
    './genes_with_hi_mean.npy'.

    Returns a 1-D integer array of indices, in the order the genes appear
    in `adata.var.gene_name.values`.
    """
    # Genes used for linear model fitting
    keep_genes = np.load('./genes_with_hi_mean.npy', allow_pickle=True)
    gene_names = adata.var.gene_name.values
    is_listed = [name in keep_genes for name in gene_names]
    return np.where(is_listed)[0]

# Unexpected genes

## Identify the most unexpectedly expressed genes

def get_combo_single_df(combo, model, split='truth'):
    """
    Assemble one DataFrame holding the combo perturbation and its two
    single perturbations.

    Rows are ordered combo, first single, second single, and the frame is
    indexed by the 'condition' column produced by `create_output_df`.
    `split` is forwarded unchanged to `create_output_df`.
    """
    gene_a, gene_b = combo.split('+')
    conditions = [combo, get_ctrl_pert(gene_a), get_ctrl_pert(gene_b)]

    ## One single-row frame per condition, stacked and indexed by condition
    frames = [create_output_df(cond, model, split=split) for cond in conditions]
    return pd.concat(frames).set_index('condition')

def mean_subtract(data_df, mean_control):
    """
    Subtract the control expression from every row of `data_df`.

    `mean_control` is a DataFrame whose columns include all columns of
    `data_df`; its values are reordered to `data_df`'s column order and
    subtracted positionally (as a raw array, not by label alignment).
    """
    ctrl_aligned = mean_control.loc[:, data_df.columns]
    return data_df - ctrl_aligned.values

def add_naive_combo(data_df, combo):
    """
    Return `data_df` with an extra row labelled 'Naive' appended.

    The 'Naive' row is the element-wise sum of the two single-perturbation
    rows of `combo` (looked up via `get_ctrl_pert`), i.e. the additive
    expectation for the combination. `data_df` itself is not modified.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Indexed by condition label; must contain both single-perturbation rows.
    combo : str
        Combination label of the form 'X+Y'.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same columns and index name, plus the 'Naive' row.
    """
    x, y = combo.split('+')
    x_pert = get_ctrl_pert(x)
    y_pert = get_ctrl_pert(y)

    naive_sum = data_df.loc[x_pert, :] + data_df.loc[y_pert, :]
    naive_sum.name = 'Naive'

    # DataFrame.append was removed in pandas 2.0; build the one-row frame
    # explicitly and concatenate. rename_axis keeps the original index name,
    # which concat would otherwise drop.
    naive_row = naive_sum.to_frame().T.infer_objects().rename_axis(data_df.index.name)
    return pd.concat([data_df, naive_row])

def get_unexpect_genes(data_df, combo, k=25):
    """
    Rank genes by how far the combo row departs from the 'Naive' row.

    Adds the 'Naive' row via `add_naive_combo`, then orders the columns by
    |combo - Naive| in descending order.

    Returns a tuple `(frame_with_naive_row, top_k_gene_labels)`.
    """
    with_naive = add_naive_combo(data_df, combo)
    deviation = with_naive.loc[combo] - with_naive.loc['Naive']
    ranked = deviation.apply(np.abs).sort_values(ascending=False)
    return with_naive, ranked.index[:k]

def get_unexpect_mse(combo, model, mean_control, return_plot_df = False, k=20):
    """
    Score the model and the naive additive baseline on the `k` genes whose
    observed combo expression departs most from the additive expectation.

    Steps: build control-subtracted truth profiles (combo + singles), pick
    the top-`k` unexpected genes from the truth, build control-subtracted
    predicted profiles on the same genes, then compare both the prediction
    and the 'Naive' row against the observed combo row.

    Parameters
    ----------
    combo : str
        Combination label 'X+Y'.
    model : object
        Forwarded to `get_combo_single_df`.
    mean_control : pandas.DataFrame
        Control expression, forwarded to `mean_subtract`.
    return_plot_df : bool
        When True, also return the truth frame (incl. 'Naive') with the
        predicted combo row (labelled '<combo>_p') appended.
    k : int
        Number of unexpected genes to evaluate on.

    Returns
    -------
    dict, or (dict, pandas.DataFrame) when `return_plot_df` is True
        Keys: 'pred_mse', 'pred_pearson', 'naive_mse', 'naive_pearson'.
        The '*_pearson' entries hold whatever `pearsonr` returns.
    """
    results = {}

    truth_df = get_combo_single_df(combo, model, split='truth')
    truth_df = mean_subtract(truth_df, mean_control)

    # Gene selection is driven by the truth only; predictions are scored on the same genes
    truth_df, unexp_genes = get_unexpect_genes(truth_df, combo, k=k)
    truth_df = truth_df.loc[:,unexp_genes]

    pred_df = get_combo_single_df(combo, model, split='pred')
    pred_df = mean_subtract(pred_df, mean_control)
    pred_df = pred_df.loc[:,unexp_genes]
    # Suffix predicted rows so they cannot collide with truth labels
    pred_df.index = pred_df.index+'_p'

    pred_label = combo+'_p'
    combo_pred = pred_df.loc[pred_label,:]
    combo_truth = truth_df.loc[combo,:]
    naive = truth_df.loc['Naive',:]

    results['pred_mse'] = np.mean((combo_pred - combo_truth)**2)
    results['pred_pearson'] = pearsonr(combo_pred, combo_truth)
    results['naive_mse'] = np.mean((naive - combo_truth)**2)
    results['naive_pearson'] = pearsonr(naive, combo_truth)

    if return_plot_df:
        # DataFrame.append was removed in pandas 2.0; concatenate a one-row
        # frame instead, keeping the truth frame's index name.
        pred_row = pred_df.loc[[pred_label]].rename_axis(truth_df.index.name)
        plot_df = pd.concat([truth_df, pred_row])
        return results, plot_df

    return results
    
def get_unexpect_precision(combo, model, mean_control, k=10, k_real=None):
    """
    Precision of the predicted top-`k` unexpected genes against the observed
    top-`k_real` unexpected genes (`k_real` defaults to `k`).

    The predicted frame reuses the observed single-perturbation rows and only
    swaps in the predicted combo row, so both rankings share the same 'Naive'
    expectation. A baseline is also reported: the `k` genes with the largest
    squared 'Naive' value, scored against the same observed gene set.

    Returns a dict with two entries, both divided by `k`:
    'naive_precision_at_<k>_<k_real>' and 'precision_at_<k>_<k_real>'.
    """
    observed = mean_subtract(get_combo_single_df(combo, model, split='truth'), mean_control)
    predicted = mean_subtract(get_combo_single_df(combo, model, split='pred'), mean_control)

    # Replace single perturbation in the predicted frame with true values:
    # start from the observed frame and overwrite only the combo row.
    hybrid = observed.copy()
    hybrid.loc[combo, :] = predicted.loc[combo, :]

    n_true = k if k_real is None else k_real
    observed_naive, true_hits = get_unexpect_genes(observed, combo, k=n_true)
    _, pred_hits = get_unexpect_genes(hybrid, combo, k=k)

    naive_sq = observed_naive.loc['Naive', :] ** 2
    naive_hits = naive_sq.sort_values(ascending=False)[:k].keys()

    true_set = set(true_hits)
    suffix = str(k) + '_' + str(n_true)
    return {
        'naive_precision_at_' + suffix: len(true_set.intersection(set(naive_hits))) / k,
        'precision_at_' + suffix: len(true_set.intersection(set(pred_hits))) / k,
    }

def load_data(seed, test_pert):
    """
    Load the perturbation dataset, prepare its split/dataloaders and compute
    per-condition mean expression.

    Reads the module-level `data_path` and `dataset`. `seed` and `test_pert`
    are forwarded to `PertData.prepare_split` (as `seed` / `test_perts`).

    Returns `(pert_data, ctrl_mean, mean_df)` where `mean_df` is the mean
    expression per condition and `ctrl_mean` is its 'ctrl' row.
    """
    # specific saved folder: data_path with its last character dropped
    loader = PertData(data_path[:-1],
                      gene_path='/dfs/user/yhr/gears2/Evaluation/data/essential_norman.pkl')
    # load the processed data, the path is saved folder + dataset_name
    loader.load(data_path=data_path + dataset)
    loader.prepare_split(split='combo_seen2', seed=seed, test_perts=test_pert)
    loader.get_dataloader(batch_size=32, test_batch_size=32)

    expression = loader.adata.to_df()
    expression['condition'] = loader.adata.obs['condition']
    per_condition_mean = expression.groupby('condition').mean()

    return loader, per_condition_mean.loc['ctrl'], per_condition_mean

def load_model(pert_data, test_pert, device = 'cuda:8'):
    """
    Construct a GEARS model on `device` and load the pretrained weights
    stored under './model_ckpt/GI/norman_umi_go_<test_pert>'.

    Returns the loaded GEARS model.
    """
    net = GEARS(
        pert_data,
        device=device,
        weight_bias_track=False,
        proj_name='norman_go',
        exp_name='gears',
    )
    # Checkpoint folder is resolved relative to the current working directory
    checkpoint_dir = './model_ckpt/GI/norman_umi_go_' + test_pert
    net.load_pretrained(checkpoint_dir)
    return net
    

In [3]:
# Run configuration read by the helper functions and by the main loop.

# Dataset name; load_data concatenates it onto data_path when loading.
dataset = 'norman_umi_go'
# Data root. The trailing '/' matters: load_data passes data_path[:-1] to
# PertData and loads from data_path + dataset.
data_path = '/dfs/project/perturb-gnn/datasets/data/'
# NOTE(review): `model` is not read by any code visible in this notebook.
model = 'gears'
# Device string passed to load_model in the main loop.
device = 'cuda:3'
# When True, the main loop writes the results dict (np.save) and plot_df (CSV)
# under 'GI_gene_mse/'.
save = True
# Label used in the output file names / sub-folder.
# NOTE(review): the main loop calls load_data with a hard-coded seed=1,
# independently of this label -- keep them in sync manually.
seed_name = 'seed1'
# NOTE(review): `seen` is not read by any code visible in this notebook.
seen = 2

In [4]:
# Evaluate every leave-one-out checkpoint: the held-out combo is the text after
# the last '_' in the checkpoint path.
for idx, GI_loo_trained_model in enumerate(glob.glob('/dfs/user/yhr/gears2/Evaluation/model_ckpt/GI/norman_umi_go_*')):
    test_pert = GI_loo_trained_model.split('_')[-1]
    print('------- Running ' + test_pert +' ---------')

    print(idx)
    if idx == 0:
        # The dataset (and everything derived from it) is loaded once, on the
        # first checkpoint, and reused for all later ones.
        pert_data, ctrl_mean, mean_df = load_data(seed=1, test_pert=test_pert)
        ctrl_df = pd.DataFrame(ctrl_mean).T
        ctrl_df.columns = pert_data.adata.var.gene_name.values
        all_conditions = pert_data.adata.obs['condition'].unique()
    gears_model = load_model(pert_data, test_pert, device)

    results, plot_df = get_unexpect_mse(
        test_pert, gears_model, ctrl_df, return_plot_df=True, k=20,
    )
    # Precision at several (k, k_real) settings, merged into the same results dict
    for precision_kwargs in ({'k': 10}, {'k': 5}, {'k': 5, 'k_real': 25}):
        results.update(
            get_unexpect_precision(test_pert, gears_model, ctrl_df, **precision_kwargs)
        )

    if save:
        np.save('GI_gene_mse/'+test_pert+seed_name, results)
        plot_df.to_csv('GI_gene_mse/dfs/'+seed_name+'/'+test_pert)

    # Drop the reference before the next checkpoint is loaded
    del gears_model

------- Running MAP2K3+MAP2K6 ---------
0


Found local copy...
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Done!


------- Running BPGM+SAMD1 ---------
1
------- Running CBL+UBASH3B ---------
2
------- Running PRDM1+CBFA2T3 ---------
3
------- Running ZBTB10+ELMSAN1 ---------
4
------- Running FOXA1+FOXL2 ---------
5
------- Running TMSB4X+BAK1 ---------
6
------- Running CEBPE+RUNX1T1 ---------
7
------- Running TGFBR2+C19orf26 ---------
8
------- Running MAP2K6+ELMSAN1 ---------
9
------- Running FEV+MAP7D1 ---------
10
------- Running CEBPB+PTPN12 ---------
11
------- Running CEBPE+SPI1 ---------
12
------- Running PTPN12+PTPN9 ---------
13
------- Running MAPK1+PRTG ---------
14
------- Running JUN+CEBPB ---------
15
------- Running RHOXF2BB+SET ---------
16
------- Running AHR+KLF1 ---------
17
------- Running MAP2K3+IKZF3 ---------
18
------- Running KLF1+MAP2K6 ---------
19
------- Running CDKN1C+CDKN1B ---------
20
------- Running ETS2+CEBPE ---------
21
------- Running SAMD1+ZBTB1 ---------
22
------- Running CBL+TGFBR2 ---------
23
------- Running SET+KLF1 ---------
24
------- Running ZBT