In [1]:
import scanpy as sc
import numpy as np
import anndata as ad
import pandas as pd
from scipy.sparse import csr_matrix
from tqdm import tqdm 

import sys 
sys.path.insert(0, "..")
from utils import get_DE_genes

import os
import pickle as pkl
from cfp import preprocessing as cfpp

  from .autonotebook import tqdm as notebook_tqdm


# Replicate the preprocessing 

Higly variable genes

In [2]:
# Highly variable genes 
rng = np.random.default_rng(seed=42)  
hvg = 500
pca_dim = 100
ms = 0.5

# The pathways and the ood_condition 
pathway = 'IFNG_IFNB_TNFA_TGFB_INS'

In [3]:
# The final output dir
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genes" + pathway + '_hvg-' + str(hvg) + '_pca-' + str(pca_dim) + '_counts' + '_ms_' + str(ms)
output_dir

'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genesIFNG_IFNB_TNFA_TGFB_INS_hvg-500_pca-100_counts_ms_0.5'

In [4]:
genes_from_paper = [
    "AHNAK", "RNF213", "APOL6", "ASTN2", "B2M", "CFH", "CXCL9", "DENND4A", 
    "DOCK9", "EFNA5", "ERAP2", "FAT1", "GBP1", "GBP4", "HAPLN3", "HSPG2", 
    "IDO1", "IFI6", "IRF1", "LAP3", "LI", "LINC02328", "MAGI1", "MUC4", 
    "NLRC5", "NUB1", "PARP14", "PARP9", "RARRES1", "RNF213", "ROR1", "SCN9A", 
    "SERPING1", "ST5", "STAT1", "TAP1", "TAP2", "THBS1", "THSD4", "TPM1", "VCL", 
    "WARS", "XRN1"
]

Now read datasets filtered with an ms score 0.5

In [5]:
datasets = []
for pw in pathway.split('_'):
    if ms == None:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq.h5ad' # '_Perturb_seq_ms_0.5.h5ad'
    else:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq_ms_' + str(ms) + '.h5ad'
    print('Loading dataset from ' + data_path)
    dataset = sc.read_h5ad(data_path)
    dataset.obs['pathway'] = pw
    datasets.append(dataset)

Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNG_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TNFA_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TGFB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/INS_Perturb_seq_ms_0.5.h5ad


In [6]:
adata = ad.concat(datasets, join='outer')
adata = adata[~np.logical_and(adata.obs.cell_type=="K562",
                                 (adata.obs.pathway=="TGFB"))]
print('Datasets concatenated')
# sc.pp.subsample(adata, 0.8)

  utils.warn_names_duplicates("obs")


Datasets concatenated


In [7]:
adata.obs_names_make_unique()

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [8]:
columns_to_drop = ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'mixscale_score', 'RNA_snn_res.0.9', 'seurat_clusters']
adata.obs.drop(columns=columns_to_drop, inplace=True)
print('Datasets prepared, running hvg analysis')

Datasets prepared, running hvg analysis


In [9]:
# Use a defaultdict for convenience
pathway_to_gene = {}

# Populate the dictionary
for pathway, gene in zip(adata.obs.pathway, adata.obs.gene):
    if gene != "NT":
        if pathway not in pathway_to_gene:
            pathway_to_gene[pathway]=[]
        if gene not in pathway_to_gene[pathway]:
            pathway_to_gene[pathway].append(gene)

In [10]:
# Convert sets to lists (optional, depending on downstream use)
for key in pathway_to_gene:
    rng.shuffle(pathway_to_gene[key])  # Shuffle genes per pathway 

In [11]:
# Collect the different splits per pathway 
splits = {}
for pathway in pathway_to_gene:
    # Split in 4 equal values 
    oods = np.array_split(np.array(pathway_to_gene[pathway]), 4)
    splits[pathway] = oods

In [12]:
# We will have controls assigned to all observations for and then convert to split name per gene
adata.obs["split_encodings"] = ["controls"] * len(adata)

In [13]:
for split_no in range(4):    
    for pathway in splits:
        for gene in splits[pathway][split_no]:
            idx = np.logical_and(adata.obs.pathway==pathway, 
                                 adata.obs.gene==gene)
            
            adata.obs.loc[idx, "split_encodings"] = f"split_{split_no}"

Condition is now cell type pathway and the encoding 

In [14]:
adata.obs['perturbation_condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway, x.gene]), axis=1)
adata.obs['background'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway]), axis=1)
print("Added pathway keys to the .obs data frame")

Added pathway keys to the .obs data frame


In [15]:
# adata0_obs = adata[adata.obs.split_encodings=="split_0"].obs
# adata1_obs = adata[adata.obs.split_encodings=="split_1"].obs
# adata2_obs = adata[adata.obs.split_encodings=="split_2"].obs
# adata3_obs = adata[adata.obs.split_encodings=="split_3"].obs

In [15]:
condition_counts = adata.obs['perturbation_condition'].value_counts()
filtered_conditions = condition_counts[condition_counts >= 100]
adata = adata[adata.obs['perturbation_condition'].isin(filtered_conditions.index)]

# Preprocessing for the entire dataset

In [16]:
adata.layers["counts"] = adata.layers["counts"].astype(np.float32)
adata.X = csr_matrix(adata.layers["counts"])
del adata.layers['counts']
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  adata.layers["counts"] = adata.layers["counts"].astype(np.float32)


Collect highly variable genes for each background 

In [17]:
highly_var_genes = {}
for bg in tqdm(adata.obs['background'].unique()):
    temp = adata[adata.obs['background'] == bg, :]
    sc.pp.highly_variable_genes(temp, inplace=True, n_top_genes=hvg)
    temp = adata[:,temp.var["highly_variable"]==True]
    highly_var_genes[bg] = set(temp.var.index)
    del temp 

  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}


In [18]:
# Compile the union list and add the genes from the paper 
combined_set = set()
for key in highly_var_genes:
    combined_set.update(highly_var_genes[key])
combined_set = combined_set.union(set(genes_from_paper))
adata = adata[:, adata.var.index.isin(combined_set)]

We are left with 600k observations and 8.2k genes. We compute differentually expressed genes per condition (maybe this has to change and condition should be split )

In [19]:
adata = get_DE_genes(adata, by='perturbation_condition', covariate='background')
print('DE genes calculated')

  self.obj[key] = value
  adata.uns[key_added] = {}
  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  sel

DE genes calculated


  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group

In [20]:
for col in adata.obs.select_dtypes(include=["category"]):
    adata.obs[col].cat.remove_unused_categories()

In [21]:
ood_condition = 'split_1'
adata.obs["is_ood"] = adata.obs.apply(lambda x: x["split_encodings"] == ood_condition, axis=1)

In [23]:
adata.write_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/reference/full_adata_with_splits_BXCP3_ood.h5ad")

In [24]:
adata

AnnData object with n_obs × n_vars = 615504 × 8056
    obs: 'cell_type', 'gene', 'pathway', 'split_encodings', 'perturbation_condition', 'background', 'control', 'is_ood'
    uns: 'log1p', 'rank_genes_groups_cov_all'

In [25]:
adata_train = adata[~adata.obs["is_ood"]]
adata_ood = adata[adata.obs["is_ood"]]
print(adata_ood.obs['control'].value_counts())

control
0    101045
Name: count, dtype: int64


In [26]:
adata_train[(adata_train.obs.cell_type=="K562")*(adata_train.obs.pathway=="TGFB")].obs.gene

79_45_17_1_1_1                    NT
81_17_32_1_1_1_1_1            CREBBP
81_14_92_1_1_1_1_1             MAPK1
79_10_83_2_1_1_1_1_1            RBL1
79_91_37_2_1_1                    NT
                               ...  
79_15_13_1_1_1_1_1                NT
81_67_03_1_1_1_1              CREBBP
75_65_81_1_1_1_1_1_1_1_1_2        NT
73_17_34_2_1_1_1_2                NT
76_69_81_1_1_1_2              CREBBP
Name: gene, Length: 1675, dtype: category
Categories (7, object): ['CREBBP', 'JUN', 'MAPK1', 'MAPK14', 'NT', 'RBL1', 'SMAD4']

In [27]:
adata_ood[(adata_ood.obs.cell_type=="K562")*(adata_ood.obs.pathway=="TGFB")].obs.gene

80_38_15_1_1_1_1_1_1_1        FOXP2
74_22_11_1_1_1_2              FOXP2
79_14_34_2_1_1_1_1_1_1        FOXP2
73_76_69_2_1_1_1_1_1_1_2      FOXP2
74_88_65_1_1_1_1_1_2          FOXP2
                              ...  
80_70_06_1_1_1_1_1_1_1_1_1    FOXP2
73_84_40_1_1_1_1_1_2          FOXP2
75_15_60_2_1_1_1_1_1_1_1_2    FOXP2
80_21_01_2_1_1_1_1            FOXP2
80_09_90_1_1_1_1              FOXP2
Name: gene, Length: 151, dtype: category
Categories (1, object): ['FOXP2']

In [28]:
# Remove original anndata 
# adata.write_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/full_adata_with_splits.h5ad")
# del adata

## Now we perform the splits 

Perform training and test split 

In [29]:
adata_train.obs["split"] = "not_included"

  adata_train.obs["split"] = "not_included"


In [30]:
for c in adata_train.obs["perturbation_condition"].unique():
    n_cells = adata_train[(adata_train.obs["perturbation_condition"]==c)].n_obs
    # Subsample the controls, not treated 
    if c.endswith('_NT'):
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False) 
        adata_train.obs.loc[adata_train.obs['perturbation_condition'] == c, 'split'] = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
    elif n_cells>300:
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
        adata_train.obs.loc[adata_train.obs['perturbation_condition'] == c, 'split'] = ["test" if idx in idx_test else "train" for idx in range(n_cells)]

# Final training set 
adata_train_final = adata_train[adata_train.obs["split"]=="train"]
adata_test_final = adata_train[adata_train.obs["split"]=="test"]
# Add test controls to the ood data loader 
adata_ood_final = ad.concat((adata_ood, adata_test_final[adata_test_final.obs["perturbation_condition"].str.endswith('_NT')]))
adata_ood_final.uns = adata_ood.uns

In [31]:
adata_train_final[(adata_train_final.obs.cell_type=="K562")*(adata_train_final.obs.pathway=="TGFB")].obs.gene

79_91_37_2_1_1                NT
75_65_24_1_1_2                NT
79_23_63_1_1_1_1_1_1_1        NT
74_57_64_1_1_1_1_2            NT
76_93_89_1_2                  NT
                              ..
73_76_94_2_2                  NT
79_07_42_1_1                  NT
79_15_13_1_1_1_1_1            NT
75_65_81_1_1_1_1_1_1_1_1_2    NT
73_17_34_2_1_1_1_2            NT
Name: gene, Length: 622, dtype: category
Categories (1, object): ['NT']

In [32]:
adata_ood[(adata_ood.obs.cell_type=="K562")*(adata_ood.obs.pathway=="TGFB")].obs.gene

80_38_15_1_1_1_1_1_1_1        FOXP2
74_22_11_1_1_1_2              FOXP2
79_14_34_2_1_1_1_1_1_1        FOXP2
73_76_69_2_1_1_1_1_1_1_2      FOXP2
74_88_65_1_1_1_1_1_2          FOXP2
                              ...  
80_70_06_1_1_1_1_1_1_1_1_1    FOXP2
73_84_40_1_1_1_1_1_2          FOXP2
75_15_60_2_1_1_1_1_1_1_1_2    FOXP2
80_21_01_2_1_1_1_1            FOXP2
80_09_90_1_1_1_1              FOXP2
Name: gene, Length: 151, dtype: category
Categories (1, object): ['FOXP2']

In [33]:
# adata_train_final = adata_train_final[~(adata_train_final.obs['split_condition'] == ood_condition), :]
# adata_test_final = adata_test_final[~(adata_test_final.obs['split_condition'] == ood_condition), :]
# adata_ood_final = adata_ood_final[adata_ood_final.obs['split_condition'] == ood_condition, :]

Here technically Lea saves a `wo` version of the training adata 

## PCA on real data 

In [34]:
cfpp.centered_pca(adata_train_final, n_comps=pca_dim)

  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)


In [35]:
# Initialize a log-count layer
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
# Training data mean 
adata_train_final_mean = adata_train_final.varm["X_mean"].flatten()

# Define the gene means for the anndata train and ood as the training one 
adata_ood_final.varm["X_mean"] = adata_train_final.varm["X_mean"]
adata_test_final.varm["X_mean"] = adata_train_final.varm["X_mean"]

# Center both test and ood data by the mean of the training set and compute PCA based on this
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.toarray() - adata_train_final_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.toarray() - adata_train_final_mean)
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])

  adata_test_final.varm["X_mean"] = adata_train_final.varm["X_mean"]


Add if an observation is a control 

In [36]:
# Add the control key to the obs data frame
adata_train_final.obs['control'] = adata_train_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_test_final.obs['control'] = adata_test_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_ood_final.obs['control'] = adata_ood_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)

Collect ESM embeddings 

In [37]:
path_to_embeddings = os.path.join('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/embeddings/perturb_emb/satijas_v2', 'gene_embeddings.pkl')
# Gene KO embeddings 
ko_embeddings = pkl.load(open(path_to_embeddings, 'rb'))
ko_embeddings = pd.DataFrame(ko_embeddings).T
ko_embeddings = ko_embeddings.astype(np.float32)
gene_embeddings_dict = dict(zip(ko_embeddings.index, ko_embeddings.values))

In [38]:
# Cell line embedding 
cell_embeddings = pd.read_csv('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/cell_line_embedding_full_ccle_300_normalized.csv', index_col=0)
cell_embeddings = cell_embeddings.astype(np.float32)
cell_embeddings_dict = dict(zip(cell_embeddings.index, cell_embeddings.values))
cell_embeddings_dict = {k: v for k, v in cell_embeddings_dict.items() if k in adata_train_final.obs['cell_type'].unique()}

In [39]:
# Control embedding as zero 
gene_embeddings_dict['NT'] = np.zeros(gene_embeddings_dict['IFNG'].shape)
pathway_embeddings = {k: v for k, v in gene_embeddings_dict.items() if k in adata_train_final.obs['pathway'].unique()}

In [40]:
# Add all the embeddings to the uns of the adata 
adata_train_final.uns['gene_emb'] = gene_embeddings_dict
adata_train_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_train_final.uns['pathway_emb'] = pathway_embeddings

adata_test_final.uns['gene_emb'] = gene_embeddings_dict
adata_test_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_test_final.uns['pathway_emb'] = pathway_embeddings

adata_ood_final.uns['gene_emb'] = gene_embeddings_dict
adata_ood_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_ood_final.uns['pathway_emb'] = pathway_embeddings

In [41]:
# Subset for cells for which we have the embeddings 
adata_train_final = adata_train_final[adata_train_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_train_final = adata_train_final[adata_train_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_train_final = adata_train_final[(adata_train_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_train_final.obs['gene'] == 'NT')), :]

adata_test_final = adata_test_final[adata_test_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_test_final = adata_test_final[adata_test_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_test_final = adata_test_final[(adata_test_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_test_final.obs['gene'] == 'NT')), :]

adata_ood_final = adata_ood_final[adata_ood_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_ood_final = adata_ood_final[adata_ood_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_ood_final = adata_ood_final[(adata_ood_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_ood_final.obs['gene'] == 'NT')), :]
print(adata_ood_final.obs['control'].value_counts())

control
False    98543
True      3000
Name: count, dtype: int64


In [42]:
adata_train_final[(adata_train_final.obs.cell_type=="K562")*(adata_train_final.obs.pathway=="TGFB")].obs.gene

79_91_37_2_1_1                NT
75_65_24_1_1_2                NT
79_23_63_1_1_1_1_1_1_1        NT
74_57_64_1_1_1_1_2            NT
76_93_89_1_2                  NT
                              ..
73_76_94_2_2                  NT
79_07_42_1_1                  NT
79_15_13_1_1_1_1_1            NT
75_65_81_1_1_1_1_1_1_1_1_2    NT
73_17_34_2_1_1_1_2            NT
Name: gene, Length: 622, dtype: category
Categories (1, object): ['NT']

In [43]:
adata_ood_final[(adata_ood_final.obs.cell_type=="K562")*(adata_ood_final.obs.pathway=="TGFB")].obs.gene

80_38_15_1_1_1_1_1_1_1      FOXP2
74_22_11_1_1_1_2            FOXP2
79_14_34_2_1_1_1_1_1_1      FOXP2
73_76_69_2_1_1_1_1_1_1_2    FOXP2
74_88_65_1_1_1_1_1_2        FOXP2
                            ...  
79_17_87_1_1_1                 NT
73_70_78_2_1_1_1_1_1_2         NT
80_43_89_2_1_1_1_1_1           NT
80_91_54_2_1_1_1_1_1           NT
81_27_95_1_1_1_1               NT
Name: gene, Length: 251, dtype: category
Categories (2, object): ['FOXP2', 'NT']