In [3]:
import numpy as np
import pandas as pd
import pickle
import torch
import sys
import matplotlib.pyplot as plt
from matplotlib import cm,colors
from warnings import filterwarnings
from tqdm import tqdm
import seaborn as sns
from torch import nn

# Blanket-suppress warnings for the rest of the session.
# NOTE(review): this also hides pandas / regex warnings raised by later cells.
filterwarnings('ignore')
# Put the project's source directory on the import path so the local modules
# imported in the next cell (inference, utils, dataset) can be found.
project_src_dir = '/home/che/perturb-project/git/gene_ptb_prediction/src/'
sys.path.append(project_src_dir)

# scanpy: logging verbosity level 3 and compact, frameless white figures.
import scanpy as sc
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white', frameon=False)
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
import pytorch_lightning as pl

In [4]:
from inference import *
from utils import SCDATA_sampler, MMD_loss
from dataset import SCDataset
from torch.utils.data import DataLoader

In [5]:
# set random seed (Lightning seeds Python's `random`, NumPy and torch in one call)
rand_seed = 12
pl.seed_everything(rand_seed)

# Preferred GPU index. Fall back to CPU when CUDA is unavailable or the machine
# has fewer GPUs than this index needs, so `device` is always a usable device string.
gpu_index = 7
device = (f'cuda:{gpu_index}'
          if torch.cuda.is_available() and torch.cuda.device_count() > gpu_index
          else 'cpu')

Seed set to 12


In [18]:
# first read in the csv file of dataset
dataset = 'replogle_rpe1'
use_hvg = 'True'
# NOTE(review): the '_hvg' suffix is added regardless of `use_hvg`; confirm the
# intended name for non-HVG runs before flipping the flag.
dataset_name = dataset + '_hvg'
scdata_file = pd.read_csv('/home/che/perturb-project/git/gene_ptb_prediction/scdata_file_path.csv')
# Combine both conditions into a single mask. Chaining `df[mask1][mask2]` indexes the
# already-filtered frame with a mask aligned to the unfiltered one, which makes pandas
# reindex the mask and warn (the warning is hidden by filterwarnings('ignore') above).
# Assumes the csv's `use_hvg` column parses as booleans - TODO confirm.
row_mask = (scdata_file['dataset'] == dataset) & (scdata_file['use_hvg'] == (use_hvg == 'True'))
matching_paths = scdata_file.loc[row_mask, 'file_path'].values
if len(matching_paths) == 0:
    raise ValueError(f'No entry in scdata_file_path.csv for dataset={dataset!r}, use_hvg={use_hvg!r}')
adata_path = matching_paths[0]
print('Load data from: ', adata_path)
adata = sc.read_h5ad(adata_path)

Load data from:  /home/che/perturb-project/predict_model/dataset/replogle_rpe1/rpe1_normalized_hvg.h5ad
Load gene embeddings from:  /home/che/perturb-project/data/gene_ptb_emb/STRING/STRING_gene_embedding_ada_text.pkl


In [19]:
# List the single-gene perturbations in adata.obs['gene'], ordered by cell count
# (most cells first). Combination labels (containing '+') and control labels are excluded.
# Match '+' literally: the previous pattern '\+' was an invalid escape sequence in a
# plain string literal (SyntaxWarning on Python 3.12+).
single_perturbations = adata[~adata.obs['gene'].str.contains('+', regex=False)].copy()
single_perturbations = single_perturbations[single_perturbations.obs['gene'] != 'ctrl']
single_counts = single_perturbations.obs['gene'].value_counts()
# For a categorical column, value_counts also lists categories that have zero cells
# in this subset; keep only labels that actually occur.
single_lst = single_counts[single_counts > 0].index.tolist()
# Not every dataset uses a 'non-targeting' label; list.remove would raise ValueError.
if 'non-targeting' in single_lst:
    single_lst.remove('non-targeting')
print(single_lst)
print(len(single_lst))

['TFAM', 'SLC1A5', 'GFM1', 'MRPL36', 'TARDBP', 'PPP6C', 'MRPL35', 'NBPF12', 'CCDC6', 'GTF3C4', 'BCAR1', 'MYBL2', 'FBXO42', 'HSD17B10', 'POGLUT3', 'GAB2', 'CLOCK', 'TMEM214', 'TWF1', 'EPS8L1', 'TMEM242', 'TACC3', 'TPT1', 'KRT17', 'ANAPC15', 'FAM136A', 'SPC25', 'CCDC78', 'SMG5', 'IK', 'SSBP3', 'ZNF718', 'SMN2', 'MVK', 'BCL2L1', 'ZDHHC7', 'DDX19A', 'EIF4B', 'SLC35G2', 'WTAP', 'TREX2', 'ADAM10', 'RGPD6', 'C7orf26', 'LAMTOR1', 'DDX19B', 'PPP1R37', 'INTS14', 'TRNT1', 'ESPN', 'CCNK', 'TIMM23B', 'RHOQ', 'CLASRP', 'PCBP2', 'PMF1', 'EIF3CL', 'HSD17B12', 'CSH2', 'DNAAF3', 'DNM1', 'SRSF11', 'CENPT', 'NAA35', 'TFRC', 'SNX15', 'TBX1', 'INTS13', 'RTEL1', 'UBTF', 'PRODH', 'PPP2R1A', 'GYG1', 'PARS2', 'C14orf178', 'PSTK', 'ADAT3', 'MRPL39', 'ZBTB17', 'ATP6AP2', 'ZFP69B', 'YRDC', 'GTF2E2', 'UBE2M', 'BTF3L4', 'TUBB', 'CSE1L', 'H2AFZ', 'VPS41', 'DSTYK', 'LMO2', 'DDN', 'PSMG3', 'MRPL38', 'ANKS6', 'MRPL16', 'HMGCS1', 'NAGLU', 'TKT', 'RBM14-RBM4', 'GLRX5', 'SSU72', 'SIRT7', 'SLC7A5', 'MRPS21', 'ATP6V0C', 'URO

### First, compute the kernel (Gram) matrix of the ground-truth perturbation effects

In [20]:
# Per-gene mean expression over the non-targeting control cells (ctrl_effect) and
# over every other cell (mean_effect).
is_ctrl_cell = adata.obs['gene'] == 'non-targeting'
ctrl_effect = adata[is_ctrl_cell].X.mean(axis=0)
mean_effect = adata[~is_ctrl_cell].X.mean(axis=0)
print(mean_effect.shape)

(5000,)


In [21]:
# For every label in obs['gene'] (controls included), store:
#   pert2effect                 - mean expression of that label's cells
#   pert2effect_delta           - that mean minus the non-targeting control mean
#   pert2effect_delta_mean_pert - that mean minus the mean over all non-control cells
pert2effect = {}
pert2effect_delta = {}
pert2effect_delta_mean_pert = {}

gene_labels = adata.obs['gene']
for label in tqdm(gene_labels.unique()):
    label_mean = adata[gene_labels == label].X.mean(axis=0)
    pert2effect[label] = label_mean
    pert2effect_delta[label] = label_mean - ctrl_effect
    pert2effect_delta_mean_pert[label] = label_mean - mean_effect

100%|██████████| 2265/2265 [00:12<00:00, 175.02it/s]


In [22]:
# One row per label (in first-appearance order), one column per gene: the delta vs. control.
delta_matrix = np.stack(list(pert2effect_delta.values()))
df = pd.DataFrame(delta_matrix, index=list(pert2effect_delta.keys()))
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,4990,4991,4992,4993,4994,4995,4996,4997,4998,4999
MRPS31,0.018916,0.069750,-0.078424,0.038682,0.048687,-0.088258,-0.001392,-0.029275,-0.023797,-0.020134,...,0.378234,0.480464,0.185599,0.221763,-0.128041,0.202955,0.263885,-0.175177,-0.033543,-0.474197
LRRC37A3,-0.014735,0.054704,-0.019607,-0.024678,0.038480,0.023722,0.003681,0.055208,0.002020,0.002808,...,-0.029753,-0.027705,-0.018238,-0.014040,-0.015686,0.003506,0.018576,-0.010476,0.000979,-0.030286
SRCAP,-0.002037,-0.025275,-0.061436,-0.033214,-0.058680,-0.029110,0.017428,-0.027701,-0.011107,-0.030638,...,0.169415,0.122641,0.018163,0.124532,-0.104019,0.113555,0.076830,0.014054,-0.068117,0.007625
WBP1,-0.020565,0.005715,-0.005894,-0.026890,0.023274,-0.047800,0.001036,0.002708,-0.007588,-0.036539,...,0.089768,0.058501,-0.115592,-0.000131,0.021866,0.035734,0.116105,0.125756,0.028886,0.033675
NOMO3,0.000566,-0.023716,0.040369,0.059700,-0.017800,-0.040362,0.027142,-0.024143,0.020768,0.029720,...,0.049844,-0.024692,-0.062294,-0.005582,0.004206,0.039863,-0.027244,0.021976,-0.090477,0.082441
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
RPS25,-0.029348,-0.045336,0.080781,0.140281,-0.145603,-0.107816,-0.068170,0.181169,0.053343,-0.026197,...,0.068118,-0.055280,-0.127365,-0.057015,0.098385,0.079964,0.072934,-0.006090,-0.016989,0.144037
RPS5,-0.037272,-0.069859,0.111033,0.100079,0.089286,-0.136246,0.034712,0.134877,-0.019901,-0.112315,...,-0.112403,-0.554280,0.082346,-0.540485,-0.348049,-0.214475,-0.127736,-0.383793,-0.445792,-0.137346
MAK16,-0.047020,-0.266297,-0.044655,0.012306,-0.024358,-0.199365,-0.145598,0.170681,0.004376,-0.153326,...,-0.062019,-0.376635,-0.104205,-0.414019,-0.166746,0.027689,-0.057538,-0.057108,-0.134620,0.024440
SARS,-0.071610,-0.285547,-0.197200,0.306740,0.136848,-0.118266,-0.285046,-0.112197,0.067805,-0.014244,...,-0.089093,-0.534158,0.105756,-0.433396,-0.219990,-0.195817,-0.331946,-0.369912,-0.354668,-0.074652


In [23]:
# Lookup from perturbation label to its delta-expression row. Keys keep only the text
# before the first '+'. Single-gene labels are unchanged, but combination labels sharing
# a prefix collide, and the later row overwrites the earlier one.
id2emb = {label.split('+')[0]: delta_row for label, delta_row in zip(df.index.values, df.values)}

In [24]:
# Ground-truth feature matrix: one row per non-control perturbation (its mean expression
# shift vs. the non-targeting control), followed by its linear kernel (Gram matrix).
pert_list = df.index.tolist()
pert_list_non_ctrl = [p for p in pert_list if p != 'non-targeting']
# Select rows from df by their full label. id2emb keys are truncated at '+', so looking
# combination labels up there raises KeyError or returns another perturbation's row.
truth_feat = df.loc[pert_list_non_ctrl].to_numpy()

G = truth_feat @ truth_feat.T

In [25]:
# One row per non-control perturbation, one column per gene in df.
assert truth_feat.shape == (len(pert_list_non_ctrl), df.shape[1])
truth_feat.shape

(2264, 5000)

In [26]:
# G = truth_feat @ truth_feat.T: square (number of perturbations x number of perturbations)
# and symmetric up to floating-point rounding.
assert G.shape == (truth_feat.shape[0], truth_feat.shape[0])
assert np.allclose(G, G.T)
G.shape

(2264, 2264)

In [27]:
# Name of this dataset's output sub-folder (the same value assigned when the data was loaded).
dataset_name = f'{dataset}_hvg'
dataset_name

'replogle_rpe1_hvg'

In [28]:
# save the truth_feat and G (plus the perturbation order their rows follow)
import os
folder_name = 'ground_truth_delta'
kernels_root = '/home/che/perturb-project/git/gene_ptb_prediction/active_learning/kernels'
# Trailing '' keeps the trailing separator, matching the previous string-built path.
output_dir = os.path.join(kernels_root, dataset_name, folder_name, '')
# exist_ok avoids the race between a separate exists() check and makedirs().
os.makedirs(output_dir, exist_ok=True)

# Row i of feat/kernel must correspond to pert_list[i]; refuse to save misaligned files.
assert len(pert_list_non_ctrl) == G.shape[0] == truth_feat.shape[0]
artifacts = {
    'pert_list.pkl': pert_list_non_ctrl,
    'kernel.pkl': G,
    'feat.pkl': truth_feat,
}
for file_name, obj in artifacts.items():
    with open(os.path.join(output_dir, file_name), 'wb') as f:
        pickle.dump(obj, f)