In [1]:
import scanpy as sc
import numpy as np
import anndata as ad
import pandas as pd
from scipy.sparse import csr_matrix
from tqdm import tqdm 
from collections import defaultdict
import itertools
import argparse

import sys 
sys.path.insert(0, "..")
from utils import get_DE_genes

import os
import pickle as pkl
from cfp import preprocessing as cfpp

  from .autonotebook import tqdm as notebook_tqdm


# Replicate the preprocessing 

Highly variable genes

In [2]:
# Held-out (out-of-distribution) cell type and the pathways whose datasets are merged
ood_cell_type = "BXPC3"
pathway = "IFNG_IFNB_TNFA_TGFB_INS"

# Preprocessing hyper-parameters
hvg = 500      # number of top highly variable genes selected per background
pca_dim = 100  # number of principal components
ms = 0.5       # threshold encoded in the input file names (None loads the unfiltered files)

# Seeded random generator for reproducibility
rng = np.random.default_rng(0)

In [3]:
# Directory for the final AnnData outputs of this held-out cell type (created if missing)
output_dir = os.path.join(
    "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_cell_type",
    ood_cell_type,
)
os.makedirs(output_dir, exist_ok=True)

In [4]:
# Genes reported in the paper. They are always added to the selected gene set; genes
# missing from `adata.var_names` are simply not selected by the later `isin` filter.
# NOTE(review): "LI" does not look like a standard HGNC symbol, and "WARS" may be named
# "WARS1" in newer annotations. Confirm both against the paper and the dataset's gene names.
genes_from_paper = [
    "AHNAK", "RNF213", "APOL6", "ASTN2", "B2M", "CFH", "CXCL9", "DENND4A",
    "DOCK9", "EFNA5", "ERAP2", "FAT1", "GBP1", "GBP4", "HAPLN3", "HSPG2",
    "IDO1", "IFI6", "IRF1", "LAP3", "LI", "LINC02328", "MAGI1", "MUC4",
    "NLRC5", "NUB1", "PARP14", "PARP9", "RARRES1", "ROR1", "SCN9A",
    "SERPING1", "ST5", "STAT1", "TAP1", "TAP2", "THBS1", "THSD4", "TPM1", "VCL",
    "WARS", "XRN1",
]
assert len(genes_from_paper) == len(set(genes_from_paper)), "duplicate gene in genes_from_paper"

Now read the datasets filtered at an `ms` score threshold of 0.5.

In [5]:
# Load one Perturb-seq dataset per pathway and tag each cell with its pathway
SATIJA_MERGED_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged"

datasets = []
for pw in pathway.split('_'):
    # `ms is None` selects the unfiltered file; otherwise the file filtered at threshold `ms`
    if ms is None:
        file_name = pw + '_Perturb_seq.h5ad'
    else:
        file_name = pw + '_Perturb_seq_ms_' + str(ms) + '.h5ad'
    data_path = os.path.join(SATIJA_MERGED_DIR, file_name)
    print('Loading dataset from ' + data_path)
    dataset = sc.read_h5ad(data_path)
    dataset.obs['pathway'] = pw
    datasets.append(dataset)

Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNG_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TNFA_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TGFB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/INS_Perturb_seq_ms_0.5.h5ad


In [6]:
# Stack the per-pathway datasets along cells; the outer join keeps the union of genes
adata = ad.concat(adatas=datasets, join="outer")
print('Datasets concatenated')

Datasets concatenated


  utils.warn_names_duplicates("obs")


In [7]:
# Concatenation produced duplicate cell names across datasets; suffix duplicates with "-<n>"
adata.obs_names_make_unique(join="-")

In [8]:
# Remove per-cell metadata columns that this pipeline does not use
columns_to_drop = [
    'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'percent.mito',
    'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide',
    'mixscale_score', 'RNA_snn_res.0.9', 'seurat_clusters',
]
adata.obs = adata.obs.drop(columns=columns_to_drop)
print('Datasets prepared, running hvg analysis')

Datasets prepared, running hvg analysis


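The condition is now the combination of cell type, pathway and perturbed gene (`<cell_type>_<pathway>_<gene>`); the background is the cell type and pathway (`<cell_type>_<pathway>`).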

In [9]:
# Build the grouping keys with vectorised string ops; a row-wise `apply` is very slow on
# hundreds of thousands of cells.
#   condition  = <cell_type>_<pathway>_<gene>
#   background = <cell_type>_<pathway>
key_columns = ["cell_type", "pathway", "gene"]
# Fail loudly on missing values rather than silently producing "nan" in the keys
assert not adata.obs[key_columns].isna().any().any(), "missing cell_type/pathway/gene values"
cell_type_str = adata.obs["cell_type"].astype(str)
pathway_str = adata.obs["pathway"].astype(str)
gene_str = adata.obs["gene"].astype(str)
adata.obs['condition'] = cell_type_str + "_" + pathway_str + "_" + gene_str
adata.obs['background'] = cell_type_str + "_" + pathway_str

Only keep conditions with at least 100 cells. At this point conditions are still defined at the perturbed-gene level.

In [10]:
# Keep only conditions (cell type x pathway x perturbed gene) with at least this many cells
min_cells_per_condition = 100
condition_counts = adata.obs['condition'].value_counts()
filtered_conditions = condition_counts[condition_counts >= min_cells_per_condition]
# `.copy()` materialises the subset, so the layer edits that follow act on a real AnnData
# rather than triggering AnnData's implicit view-to-copy conversion (and its warning).
adata = adata[adata.obs['condition'].isin(filtered_conditions.index)].copy()

# Preprocessing for the entire dataset

In [11]:
# Move the raw counts (cast to float32) into a sparse CSR `X`, drop the now-redundant
# "counts" layer, then normalise per cell and log1p-transform in place.
# NOTE(review): if `adata` is still a view at this point, the first assignment makes AnnData
# materialise a copy; the warning printed under this cell appears to come from that line.
# Keep the layer assignment first so the later `X` assignment acts on the materialised object.
adata.layers["counts"] = adata.layers["counts"].astype(np.float32)
adata.X = csr_matrix(adata.layers["counts"])
del adata.layers['counts']
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  adata.layers["counts"] = adata.layers["counts"].astype(np.float32)


Collect highly variable genes for each background 

In [12]:
# Select the top `hvg` highly variable genes separately within each background
# (cell type x pathway) and store the gene names per background.
highly_var_genes = {}
for bg in tqdm(adata.obs['background'].unique()):
    # Work on an explicit copy. Running HVG selection in place on a view makes AnnData copy
    # it implicitly; this is presumably the source of the repeated warnings under this cell.
    temp = adata[adata.obs['background'] == bg, :].copy()
    sc.pp.highly_variable_genes(temp, inplace=True, n_top_genes=hvg)
    # Read the selected gene names straight from the subset instead of re-indexing `adata`
    is_hvg = temp.var["highly_variable"].to_numpy(dtype=bool)
    highly_var_genes[bg] = set(temp.var_names[is_hvg])
    del temp

  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}


In [13]:
# Compile the union list and add the genes from the paper 
# Gene universe = union of all per-background HVG sets plus the genes listed from the paper.
# Paper genes that are absent from `adata.var` are simply not selected.
combined_set = set(genes_from_paper).union(*highly_var_genes.values())
adata = adata[:, adata.var.index.isin(combined_set)]
print("Highly variable genes selected")

Highly variable genes selected


We are left with about 600k observations and 8.2k genes. We compute differentially expressed genes per condition. This may need to change, e.g. by splitting the condition into its components.

In [14]:
# Differential expression per condition via the project helper `get_DE_genes`, with the
# background passed as covariate (presumably the per-condition reference; confirm in utils).
de_options = dict(by='condition', covariate='background')
adata = get_DE_genes(adata, **de_options)
print('DE genes calculated')

  self.obj[key] = value
  adata.uns[key_added] = {}
  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  sel

DE genes calculated


  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group

In [15]:
# Drop categories that no longer occur after the cell and gene filtering above.
# `remove_unused_categories()` returns a new Series, so the result must be assigned back;
# without the assignment this loop does nothing.
for col in adata.obs.select_dtypes(include=["category"]).columns:
    adata.obs[col] = adata.obs[col].cat.remove_unused_categories()

## Condition processing 

In [16]:
# OOD conditions are recognised by their cell-type prefix in the condition key
# (`<cell_type>_<pathway>_<gene>`), so the held-out prefix is the held-out cell type.
ood_condition = ood_cell_type 

In [17]:
# Filter the condition
filtered_conditions = adata.obs['condition'].unique() # unnecessary
filtered_conditions

['A549_IFNG_IRF2', 'A549_IFNG_NT', 'A549_IFNG_STAT2', 'A549_IFNG_CEBPB', 'A549_IFNG_IRF1', ..., 'MCF7_INS_FOXO4', 'MCF7_INS_XBP1', 'MCF7_INS_IRS4', 'MCF7_INS_FOXO3', 'MCF7_INS_RAD51']
Length: 855
Categories (855, object): ['A549_IFNB_ADAR', 'A549_IFNB_CEBPB', 'A549_IFNB_CEBPG', 'A549_IFNB_CRKL', ..., 'MCF7_TNFA_TNFRSF1A', 'MCF7_TNFA_TRAF1', 'MCF7_TNFA_TRAF3', 'MCF7_TNFA_ZFP36L1']

In [18]:
# Perturbed (non-control) conditions; those of the held-out cell type become OOD.
perturbations = list(adata.obs.loc[adata.obs['gene'] != 'NT', "condition"].unique())
# Match on "<cell_type>_" so a cell type whose name is a prefix of another's cannot leak in
ood_prefix = ood_condition + "_"
ood_conditions = [c for c in perturbations if c.startswith(ood_prefix) and c in filtered_conditions]

In [19]:
ood_conditions

['BXPC3_IFNG_PRDM1',
 'BXPC3_IFNG_PTPN11',
 'BXPC3_IFNG_IFNGR2',
 'BXPC3_IFNG_BATF2',
 'BXPC3_IFNG_FOXN3',
 'BXPC3_IFNG_IFNGR1',
 'BXPC3_IFNG_ZNFX1',
 'BXPC3_IFNG_ATF3',
 'BXPC3_IFNG_EHF',
 'BXPC3_IFNG_MAFF',
 'BXPC3_IFNG_JAK1',
 'BXPC3_IFNG_SP100',
 'BXPC3_IFNG_ATF5',
 'BXPC3_IFNG_JUN',
 'BXPC3_IFNG_MAFB',
 'BXPC3_IFNG_MYC',
 'BXPC3_IFNG_JAK2',
 'BXPC3_IFNG_TRAFD1',
 'BXPC3_IFNG_PARP12',
 'BXPC3_IFNG_TBX21',
 'BXPC3_IFNG_CEBPB',
 'BXPC3_IFNG_IRF1',
 'BXPC3_IFNG_RARRES3',
 'BXPC3_IFNG_PLEK',
 'BXPC3_IFNG_HLA-DQB1',
 'BXPC3_IFNG_RUNX1',
 'BXPC3_IFNG_SRC',
 'BXPC3_IFNG_IRF9',
 'BXPC3_IFNG_IFI16',
 'BXPC3_IFNG_PIK3CA',
 'BXPC3_IFNG_IRF2',
 'BXPC3_IFNG_KLF4',
 'BXPC3_IFNG_GUK1',
 'BXPC3_IFNG_ZNF267',
 'BXPC3_IFNG_RFX5',
 'BXPC3_IFNG_STAT3',
 'BXPC3_IFNG_IRF7',
 'BXPC3_IFNG_ETV7',
 'BXPC3_IFNG_PPARG',
 'BXPC3_IFNG_STAT1',
 'BXPC3_IFNG_IRF5',
 'BXPC3_IFNG_FMNL2',
 'BXPC3_IFNG_RNF14',
 'BXPC3_IFNG_PTGES3',
 'BXPC3_IFNG_CLK1',
 'BXPC3_IFNG_ZFP36',
 'BXPC3_IFNG_FBXO6',
 'BXPC3_IFNB_IFI16',
 'BX

## Add a column flagging whether an observation is out-of-distribution (OOD)

In [20]:
# adata.obs["ood"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions else False, axis=1)

In [22]:
# Flag OOD cells: those whose condition is one of the held-out perturbations
adata.obs["is_ood"] = adata.obs["condition"].isin(ood_conditions)
# The in-distribution part gets new obs columns in later cells, so materialise it now instead
# of letting AnnData implicitly copy a view. The OOD part is only read and concatenated.
adata_train = adata[~adata.obs["is_ood"]].copy()
adata_ood = adata[adata.obs["is_ood"]]
print(adata_ood.obs['control'].value_counts())

control
0    164306
Name: count, dtype: int64


## Now we perform the splits 

Perform training and test split 

In [None]:
# Re-seed so the train/test split below is reproducible on its own
rng = np.random.default_rng(0)
# Ensure we hold an actual AnnData (not a view) before adding the `split` column, to avoid
# AnnData's implicit view-to-copy conversion and its warning.
if adata_train.is_view:
    adata_train = adata_train.copy()
# Every cell starts as "not_included"; the split cell relabels sampled cells "train"/"test"
adata_train.obs["split"] = "not_included"

  adata_train.obs["split"] = "not_included"


In [None]:
# Per-condition train/test split of the in-distribution cells:
#   * controls ("*_NT"): n_test_controls random cells -> test, the rest -> train
#   * perturbations with more than min_cells_for_test cells: n_test_perturbed -> test, rest -> train
#   * smaller perturbations keep split == "not_included"
# NOTE(review): that last group ends up in NEITHER train nor test below. Confirm this is
# intended rather than sending small conditions entirely to train.
# NOTE(review): rng.choice raises ValueError for a control with fewer than n_test_controls cells.
n_test_controls = 500
n_test_perturbed = 100
min_cells_for_test = 300

condition_labels = adata_train.obs["condition"].to_numpy()
for c in adata_train.obs["condition"].unique():
    in_condition = condition_labels == c
    n_cells = int(in_condition.sum())
    if c.endswith('_NT'):
        n_test = n_test_controls
    elif n_cells > min_cells_for_test:
        n_test = n_test_perturbed
    else:
        continue
    # Keep this exact draw: changing it would change the split for a fixed seed
    idx_test = rng.choice(np.arange(n_cells), n_test, replace=False)
    # Boolean mask over this condition's cells (in obs order) instead of an O(n*k) `in` scan
    is_test = np.zeros(n_cells, dtype=bool)
    is_test[idx_test] = True
    adata_train.obs.loc[in_condition, 'split'] = np.where(is_test, "test", "train")

adata_train_final = adata_train[adata_train.obs["split"] == "train"]
adata_test_final = adata_train[adata_train.obs["split"] == "test"]
# Evaluation set: all OOD cells plus the held-out control cells of every cell type
test_controls = adata_test_final[adata_test_final.obs["condition"].str.endswith('_NT')]
adata_ood_final = ad.concat((adata_ood, test_controls))
adata_ood_final.uns = adata_ood.uns
print(adata_ood_final.obs['control'].value_counts())

In [None]:
# Inspect the evaluation AnnData (OOD cells plus held-out controls of every cell type)
adata_ood_final

In [None]:
# Remove the held-out cell type from train/test, and keep only that cell type in the OOD set.
# `.copy()` gives actual objects: later cells add layers, varm entries and obs columns, which
# on views would trigger AnnData's implicit (warning-emitting) copies.
adata_train_final = adata_train_final[adata_train_final.obs['cell_type'] != ood_cell_type, :].copy()
adata_test_final = adata_test_final[adata_test_final.obs['cell_type'] != ood_cell_type, :].copy()
adata_ood_final = adata_ood_final[adata_ood_final.obs['cell_type'] == ood_cell_type, :].copy()

In [None]:
# Inspect the evaluation AnnData after restricting it to the held-out cell type
adata_ood_final

At this point, Lea's original pipeline saves a `wo` version of the training AnnData; this notebook does not.

## PCA on real data 

In [None]:
# Fit PCA on the training cells only (project helper `cfp.preprocessing.centered_pca`).
# The next cell reads `varm["X_mean"]` and `varm["PCs"]` from `adata_train_final`, so this call
# is expected to store the per-gene training mean and the loadings there. Confirm in cfp.
cfpp.centered_pca(adata_train_final, n_comps=pca_dim)

In [None]:
# Initialize a log-count layer
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
# Training data mean 
adata_train_final_mean = adata_train_final.varm["X_mean"].flatten()

# Define the gene means for the anndata train and ood as the training one 
adata_ood_final.varm["X_mean"] = adata_train_final.varm["X_mean"]
adata_test_final.varm["X_mean"] = adata_train_final.varm["X_mean"]

# Center both test and ood data by the mean of the training set and compute PCA based on this
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.toarray() - adata_train_final_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.toarray() - adata_train_final_mean)
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])

Flag whether each observation is a control (`gene == 'NT'`)

In [None]:
# Add the control key to the obs data frame
adata_train_final.obs['control'] = adata_train_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_test_final.obs['control'] = adata_test_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_ood_final.obs['control'] = adata_ood_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)

Collect ESM embeddings 

In [None]:
# Perturbation (gene knock-out) embeddings, later looked up by the values in obs['gene']
path_to_embeddings = os.path.join('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/embeddings/perturb_emb/satijas_v2', 'gene_embeddings.pkl')
# NOTE: unpickling can execute arbitrary code; only load this trusted, internally produced file.
# The `with` block closes the file handle, which `pkl.load(open(...))` leaked.
with open(path_to_embeddings, 'rb') as handle:
    ko_embeddings = pkl.load(handle)
# Presumably the pickle maps gene -> vector, so genes become rows after the transpose (confirm)
ko_embeddings = pd.DataFrame(ko_embeddings).T.astype(np.float32)
gene_embeddings_dict = dict(zip(ko_embeddings.index, ko_embeddings.values))

In [None]:
# Cell line embedding 
cell_embeddings = pd.read_csv('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/cell_line_embedding_full_ccle_300_normalized.csv', index_col=0)
cell_embeddings = cell_embeddings.astype(np.float32)
cell_embeddings_dict = dict(zip(cell_embeddings.index, cell_embeddings.values))
cell_embeddings_dict = {k: v for k, v in cell_embeddings_dict.items() if k in adata_train_final.obs['cell_type'].unique()}

In [None]:
# Control embedding as zero 
gene_embeddings_dict['NT'] = np.zeros(gene_embeddings_dict['IFNG'].shape)
pathway_embeddings = {k: v for k, v in gene_embeddings_dict.items() if k in adata_train_final.obs['pathway'].unique()}

In [None]:
# Add all the embeddings to the uns of the adata 
# Attach the same embedding lookups (shared objects, not copies) to every split so each
# AnnData carries its own gene / cell-type / pathway embedding tables in `.uns`.
embedding_uns = {
    'gene_emb': gene_embeddings_dict,
    'cell_type_emb': cell_embeddings_dict,
    'pathway_emb': pathway_embeddings,
}
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    for uns_key, lookup in embedding_uns.items():
        split_adata.uns[uns_key] = lookup

In [None]:
# Subset for cells for which we have the embeddings 
def _subset_to_embedded(adata_split):
    """Return the cells whose cell type, pathway and perturbed gene all have an embedding."""
    obs = adata_split.obs
    keep = (
        obs['cell_type'].isin(cell_embeddings_dict.keys())
        & obs['pathway'].isin(pathway_embeddings.keys())
        & (obs['gene'].isin(gene_embeddings_dict.keys()) | (obs['gene'] == 'NT'))
    )
    return adata_split[keep, :]


adata_train_final = _subset_to_embedded(adata_train_final)
adata_test_final = _subset_to_embedded(adata_test_final)
adata_ood_final = _subset_to_embedded(adata_ood_final)

In [None]:
# Make sure the held-out cell type appears in neither the training nor the test split
train_keep = adata_train_final.obs['cell_type'] != ood_cell_type
adata_train_final = adata_train_final[train_keep, :]
test_keep = adata_test_final.obs['cell_type'] != ood_cell_type
adata_test_final = adata_test_final[test_keep, :]

In [None]:
# Persist the three splits into `output_dir` (currently disabled).
# NOTE(review): `ood_pathway` is not defined anywhere in this notebook, so uncommenting these
# lines as-is raises NameError. Substitute the intended name (e.g. `pathway`) first.
# adata_train_final.write(os.path.join(output_dir, "adata_train_" + ood_pathway + "_" + ood_cell_type + ".h5ad"))
# adata_ood_final.write(os.path.join(output_dir, "adata_ood_" + ood_pathway + "_" + ood_cell_type + ".h5ad"))
# adata_test_final.write(os.path.join(output_dir, "adata_test_" + ood_pathway + "_" + ood_cell_type + ".h5ad"))