In [29]:
import scanpy as sc
import numpy as np
import anndata as ad
import pandas as pd
from scipy.sparse import csr_matrix
from tqdm import tqdm 

import sys 
sys.path.insert(0, "..")
from utils import get_DE_genes

import os
import pickle as pkl
from cfp import preprocessing as cfpp

# Replicate the preprocessing 

Higly variable genes

In [30]:
# Highly variable genes 
hvg = 500
pca_dim = 100
ms = 0.5

# The pathways and the ood_condition 
pathway = 'IFNG_IFNB_TNFA_TGFB_INS'
ood_condition = 'split_1'

In [7]:
# The final output dir
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genes" + pathway + '_hvg-' + str(hvg) + '_pca-' + str(pca_dim) + '_counts' + '_ms_' + str(ms)
output_dir

'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genesIFNG_IFNB_TNFA_TGFB_INS_hvg-500_pca-100_counts_ms_0.5'

In [8]:
genes_from_paper = [
    "AHNAK", "RNF213", "APOL6", "ASTN2", "B2M", "CFH", "CXCL9", "DENND4A", 
    "DOCK9", "EFNA5", "ERAP2", "FAT1", "GBP1", "GBP4", "HAPLN3", "HSPG2", 
    "IDO1", "IFI6", "IRF1", "LAP3", "LI", "LINC02328", "MAGI1", "MUC4", 
    "NLRC5", "NUB1", "PARP14", "PARP9", "RARRES1", "RNF213", "ROR1", "SCN9A", 
    "SERPING1", "ST5", "STAT1", "TAP1", "TAP2", "THBS1", "THSD4", "TPM1", "VCL", 
    "WARS", "XRN1"
]

Now read datasets filtered with an ms score 0.5

In [9]:
datasets = []
for pw in pathway.split('_'):
    if ms == None:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq.h5ad' # '_Perturb_seq_ms_0.5.h5ad'
    else:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq_ms_' + str(ms) + '.h5ad'
    print('Loading dataset from ' + data_path)
    dataset = sc.read_h5ad(data_path)
    dataset.obs['pathway'] = pw
    datasets.append(dataset)

Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNG_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TNFA_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TGFB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/INS_Perturb_seq_ms_0.5.h5ad


In [10]:
adata = ad.concat(datasets, join='outer')
print('Datasets concatenated')

Datasets concatenated


  utils.warn_names_duplicates("obs")


In [11]:
adata.obs_names_make_unique()

In [12]:
columns_to_drop = ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'mixscale_score', 'RNA_snn_res.0.9', 'seurat_clusters']
adata.obs.drop(columns=columns_to_drop, inplace=True)
print('Datasets prepared, running hvg analysis')

Datasets prepared, running hvg analysis


## Create gene splits 

In [13]:
from collections import defaultdict

# Use a defaultdict for convenience
pathway_to_gene = defaultdict(set)

# Populate the dictionary
for pathway, gene in zip(adata.obs.pathway, adata.obs.gene):
    pathway_to_gene[pathway].add(gene)

# Convert sets to lists (optional, depending on downstream use)
pathway_to_gene = {key: list(value) for key, value in pathway_to_gene.items()}

In [14]:
for path, gene_list in pathway_to_gene.items():
    print(path, len(gene_list))

IFNG 58
IFNB 62
TNFA 55
TGFB 44
INS 45


In [15]:
pathway_to_gene

{'IFNG': ['IRF5',
  'FOXN3',
  'NFKB1',
  'ATF3',
  'ATF5',
  'SRC',
  'JUN',
  'ZC3H3',
  'ZNFX1',
  'GUK1',
  'TRAFD1',
  'STAT3',
  'PIK3CA',
  'MYC',
  'NT',
  'KLF4',
  'JAK2',
  'ETV7',
  'PTGES3',
  'ZNF267',
  'PRDM1',
  'PTPN11',
  'RARRES3',
  'IRF7',
  'PARP12',
  'SOX2',
  'TBX21',
  'KIN',
  'CEBPB',
  'STAT1',
  'IRF2',
  'IFNGR2',
  'TAPBPL',
  'EHF',
  'STAT2',
  'PPARG',
  'RFX5',
  'FMNL2',
  'HLX',
  'MAFF',
  'IFNGR1',
  'CLK1',
  'RNF14',
  'IRF9',
  'JAK1',
  'FBXO6',
  'PLEK',
  'ZFP36',
  'CUL1',
  'SP110',
  'RUNX1',
  'IFI16',
  'SP100',
  'BATF2',
  'MAFB',
  'MCRS1',
  'IRF1',
  'HLA-DQB1'],
 'IFNB': ['NFKB1',
  'MAPK14',
  'SOCS1',
  'HES4',
  'JUN',
  'STAT6',
  'STAT4',
  'STAT5A',
  'ZNFX1',
  'ZBP1',
  'TRAFD1',
  'H1F0',
  'UBE2L6',
  'STAT3',
  'MYC',
  'NT',
  'USP18',
  'ETS2',
  'IFNAR2',
  'HERC5',
  'IRF7',
  'PARP12',
  'TYK2',
  'AKT1',
  'HERC6',
  'IFNAR1',
  'TRIM21',
  'ADAR',
  'CEBPB',
  'STAT1',
  'RAP1GAP',
  'MAPK8',
  'VAV1',
  'ID2',

**Check if there are overlapping genes**

In [16]:
import itertools

for path_1, path_2 in itertools.combinations(list(pathway_to_gene.keys()), 2):
    print(f"Number of shared genes {path_1, path_2}: {len(np.intersect1d(pathway_to_gene[path_1], pathway_to_gene[path_2]))}")

Number of shared genes ('IFNG', 'IFNB'): 19
Number of shared genes ('IFNG', 'TNFA'): 8
Number of shared genes ('IFNG', 'TGFB'): 8
Number of shared genes ('IFNG', 'INS'): 2
Number of shared genes ('IFNB', 'TNFA'): 11
Number of shared genes ('IFNB', 'TGFB'): 8
Number of shared genes ('IFNB', 'INS'): 3
Number of shared genes ('TNFA', 'TGFB'): 10
Number of shared genes ('TNFA', 'INS'): 6
Number of shared genes ('TGFB', 'INS'): 8


**Leave out splits**

In [17]:
np.random.choice(42)

splits = {}
for pathway in pathway_to_gene:

    pathway_to_gene_no_controls = [gene for gene in pathway_to_gene[pathway] if gene != 'NT']
    
    # Split in 4 equal values 
    oods = np.array_split(np.array(pathway_to_gene_no_controls), 4)
    splits[pathway] = oods

In [18]:
splits.keys()

dict_keys(['IFNG', 'IFNB', 'TNFA', 'TGFB', 'INS'])

In [19]:
splits["IFNG"]

[array(['IRF5', 'FOXN3', 'NFKB1', 'ATF3', 'ATF5', 'SRC', 'JUN', 'ZC3H3',
        'ZNFX1', 'GUK1', 'TRAFD1', 'STAT3', 'PIK3CA', 'MYC', 'KLF4'],
       dtype='<U8'),
 array(['JAK2', 'ETV7', 'PTGES3', 'ZNF267', 'PRDM1', 'PTPN11', 'RARRES3',
        'IRF7', 'PARP12', 'SOX2', 'TBX21', 'KIN', 'CEBPB', 'STAT1'],
       dtype='<U8'),
 array(['IRF2', 'IFNGR2', 'TAPBPL', 'EHF', 'STAT2', 'PPARG', 'RFX5',
        'FMNL2', 'HLX', 'MAFF', 'IFNGR1', 'CLK1', 'RNF14', 'IRF9'],
       dtype='<U8'),
 array(['JAK1', 'FBXO6', 'PLEK', 'ZFP36', 'CUL1', 'SP110', 'RUNX1',
        'IFI16', 'SP100', 'BATF2', 'MAFB', 'MCRS1', 'IRF1', 'HLA-DQB1'],
       dtype='<U8')]

In [20]:
splits["IFNB"]

[array(['NFKB1', 'MAPK14', 'SOCS1', 'HES4', 'JUN', 'STAT6', 'STAT4',
        'STAT5A', 'ZNFX1', 'ZBP1', 'TRAFD1', 'H1F0', 'UBE2L6', 'STAT3',
        'MYC', 'USP18'], dtype='<U7'),
 array(['ETS2', 'IFNAR2', 'HERC5', 'IRF7', 'PARP12', 'TYK2', 'AKT1',
        'HERC6', 'IFNAR1', 'TRIM21', 'ADAR', 'CEBPB', 'STAT1', 'RAP1GAP',
        'MAPK8'], dtype='<U7'),
 array(['VAV1', 'ID2', 'ELK1', 'DRAP1', 'ETS1', 'STAT2', 'POU2F1',
        'NFE2L3', 'CRKL', 'MAP3K14', 'SMARCA5', 'RAPGEF1', 'IRF9', 'JAK1',
        'IRF3'], dtype='<U7'),
 array(['ID3', 'UBA7', 'MEF2A', 'FOS', 'ID1', 'TRIM22', 'SP110', 'IFI16',
        'SP100', 'BATF2', 'CEBPG', 'DTX3L', 'BRD9', 'RNF114', 'IRF1'],
       dtype='<U7')]

In [21]:
splits["TNFA"]

[array(['NFKB1', 'MAPK14', 'IKBKG', 'KLF6', 'PTGS2', 'BIRC2', 'MTF1',
        'TRAF1', 'CASP10', 'JUN', 'CSF2', 'PDCD5', 'MAPK9', 'JUNB'],
       dtype='<U8'),
 array(['SOX9', 'ZFP36L1', 'TRAF3', 'PIK3CA', 'MYC', 'NFKBIA', 'ZNF267',
        'IKBKB', 'RELB', 'CASP3', 'MTOR', 'TNFRSF1A', 'CSF1', 'CEBPB'],
       dtype='<U8'),
 array(['TRAF2', 'DNM1L', 'MAPK8', 'ID2', 'BIRC3', 'ITCH', 'MAP3K14',
        'NKX3-1', 'CASP8', 'REPIN1', 'CHUK', 'MMP9', 'MAP2K3'], dtype='<U8'),
 array(['MAP3K7', 'BATF', 'SOX4', 'FOS', 'CASP7', 'ARID5B', 'NFKBIE',
        'FADD', 'TNFRSF1B', 'CREB1', 'IKBKE', 'NFAT5', 'IRF1'], dtype='<U8')]

In [22]:
splits["TGFB"]

[array(['NFKB1', 'MAPK14', 'MAPK3', 'PPP2CA', 'RHOA', 'JUN', 'HRAS',
        'RELA', 'SMAD6', 'KRAS', 'PIK3CA'], dtype='<U6'),
 array(['MYC', 'IKBKB', 'IRF7', 'SMAD4', 'CREBBP', 'HDAC4', 'MAPK1',
        'ATF2', 'AKT1', 'SP1', 'SMAD7'], dtype='<U6'),
 array(['TGFBR2', 'FOXP2', 'RUNX3', 'SMAD3', 'MED15', 'SMURF1', 'EP300',
        'TGIF1', 'NRAS', 'CHUK', 'FOS'], dtype='<U6'),
 array(['MAP3K7', 'SKP1', 'RBL1', 'TGFBR1', 'CUL1', 'TGFBR3', 'SMAD2',
        'RUNX1', 'SMAD9', 'SMAD5'], dtype='<U6')]

In [23]:
splits["INS"]

[array(['IRS4', 'POLR2L', 'MAPK3', 'TSC1', 'TTF1', 'TAF3', 'PIK3CA', 'DR1',
        'FOXO3', 'IRS2', 'TTF2'], dtype='<U7'),
 array(['PTEN', 'IKBKB', 'SREBF2', 'TAF8', 'MTOR', 'MAPK1', 'TSC2', 'SP1',
        'SRF', 'RAD51', 'POLR2G'], dtype='<U7'),
 array(['ELK1', 'TAF7', 'FOXO1', 'HSF1', 'IRS1', 'GRB2', 'SGK1', 'FOXO4',
        'COPS4', 'EIF2B1', 'XBP1'], dtype='<U7'),
 array(['CHUK', 'FOS', 'ZNF593', 'GRB10', 'SREBF1', 'IGF2', 'SHC1',
        'RPS6KB1', 'FOXM1', 'INSR', 'SMARCE1'], dtype='<U7')]

Add splits to the AnnData 

In [24]:
# We will have controls assigned to all observations for and then convert to split name per gene
adata.obs["split_encodings"] = ["controls"] * len(adata)

In [25]:
for split_no in range(4):    
    for pathway in splits:
        for gene in splits[pathway][split_no]:
            idx = np.logical_and(adata.obs.pathway==pathway, 
                                 adata.obs.gene==gene)
            
            adata.obs.loc[idx, "split_encodings"] = f"split_{split_no}"

In [26]:
np.unique(adata.obs.split_encodings, return_counts=True)

(array(['controls', 'split_0', 'split_1', 'split_2', 'split_3'],
       dtype=object),
 array([ 84269, 124869, 136114, 143788, 136839]))

## Two aspects, condition and background 

Condition is now cell type pathway and the encoding 

In [27]:
adata.obs['split_condition'] = adata.obs.apply(lambda x: "_".join([x.pathway, x.split_encodings]), axis=1)
adata.obs['perturbation_condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway, x.gene]), axis=1)
adata.obs['background'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway]), axis=1)

In [27]:
# Condition is ct_pathway_gene
print(adata.obs.perturbation_condition)

# Condition is ct_pathway
print(adata.obs.background)

05_33_45_1_1_1_1_1_1_1_1_1     A549_IFNG_IRF2
05_17_93_1_1_1_1_1_1_1_1_1       A549_IFNG_NT
06_63_14_1_1_1_1_1_1_1_1_1       A549_IFNG_NT
06_89_90_1_1_1_1_1_1_1_1_1    A549_IFNG_STAT2
05_59_54_1_1_1_1_1_1_1_1_1     A549_IFNG_IRF2
                                   ...       
83_82_89_2_2                    MCF7_INS_TAF3
81_30_72_2_2                     MCF7_INS_DR1
84_70_17_2_2                   MCF7_INS_RAD51
82_83_73_2_2                    MCF7_INS_TAF7
84_92_02_2_2                  MCF7_INS_PIK3CA
Name: perturbation_condition, Length: 625879, dtype: object
05_33_45_1_1_1_1_1_1_1_1_1    A549_IFNG
05_17_93_1_1_1_1_1_1_1_1_1    A549_IFNG
06_63_14_1_1_1_1_1_1_1_1_1    A549_IFNG
06_89_90_1_1_1_1_1_1_1_1_1    A549_IFNG
05_59_54_1_1_1_1_1_1_1_1_1    A549_IFNG
                                ...    
83_82_89_2_2                   MCF7_INS
81_30_72_2_2                   MCF7_INS
84_70_17_2_2                   MCF7_INS
82_83_73_2_2                   MCF7_INS
84_92_02_2_2                   MCF

Only keep the conditions with more than 100 cells - still at the gene level 

In [28]:
condition_counts = adata.obs['perturbation_condition'].value_counts()
filtered_conditions = condition_counts[condition_counts >= 100]
adata = adata[adata.obs['perturbation_condition'].isin(filtered_conditions.index)]

# Preprocessing for the entire dataset

In [29]:
adata.layers["counts"] = adata.layers["counts"].astype(np.float32)
adata.X = csr_matrix(adata.layers["counts"])
del adata.layers['counts']
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  adata.layers["counts"] = adata.layers["counts"].astype(np.float32)


Collect highly variable genes for each background 

In [30]:
highly_var_genes = {}
for bg in tqdm(adata.obs['background'].unique()):
    temp = adata[adata.obs['background'] == bg, :]
    sc.pp.highly_variable_genes(temp, inplace=True, n_top_genes=hvg)
    temp = adata[:,temp.var["highly_variable"]==True]
    highly_var_genes[bg] = set(temp.var.index)
    del temp 

  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}
  adata.uns["hvg"] = {"flavor": flavor}


In [32]:
# Compile the union list and add the genes from the paper 
combined_set = set()
for key in highly_var_genes:
    combined_set.update(highly_var_genes[key])
combined_set = combined_set.union(set(genes_from_paper))
adata = adata[:, adata.var.index.isin(combined_set)]

We are left with 600k observations and 8.2k genes. We compute differentually expressed genes per condition (maybe this has to change and condition should be split )

In [33]:
adata = get_DE_genes(adata, by='perturbation_condition', covariate='background')
print('DE genes calculated')

  self.obj[key] = value
  adata.uns[key_added] = {}
  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  sel

DE genes calculated


  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group

In [34]:
# controls = {}
# for bg in adata.obs["background"].unique():
#     controls[bg] = adata[adata.obs["perturbation_condition"]==bg+'_NT'].X.toarray()

for col in adata.obs.select_dtypes(include=["category"]):
    adata.obs[col].cat.remove_unused_categories()

## Condition processing 

In [35]:
adata.obs['split_condition'].unique()

['IFNG_split_1', 'IFNG_controls', 'IFNG_split_0', 'IFNG_split_3', 'IFNG_split_2', ..., 'INS_split_2', 'INS_controls', 'INS_split_0', 'INS_split_3', 'INS_split_1']
Length: 25
Categories (25, object): ['IFNB_controls', 'IFNB_split_0', 'IFNB_split_1', 'IFNB_split_2', ..., 'TNFA_split_0', 'TNFA_split_1', 'TNFA_split_2', 'TNFA_split_3']

In [36]:
# Filter the condition
filtered_conditions = adata.obs['split_condition'].unique() # unnecessary
perturbations = list(adata.obs[adata.obs['gene'] != 'NT']["split_condition"].unique())
# Here I have to pass the split 
ood_conditions = [c for c in perturbations if c.endswith(ood_condition) and c in filtered_conditions]


## Add a column saying if an observation is ood or not 

In [35]:
adata.obs["is_ood"] = adata.obs.apply(lambda x: x["split_condition"] in ood_conditions, axis=1)
adata_train = adata[~adata.obs["is_ood"]]
adata_ood = adata[adata.obs["is_ood"]]
print(adata_ood.obs['control'].value_counts())

control
0    109119
Name: count, dtype: int64


In [37]:
# Remove original anndata 
adata.write_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/full_adata_with_splits.h5ad")
# del adata

## Now we perform the splits 

Perform training and test split 

In [37]:
rng = np.random.default_rng(0)
split_dfs = []
adata_train.obs["split"] = "not_included"

  adata_train.obs["split"] = "not_included"


In [38]:
for c in adata_train.obs["perturbation_condition"].unique():
    n_cells = adata_train[(adata_train.obs["perturbation_condition"]==c)].n_obs
    # Subsample the controls, not treated 
    if c.endswith('_NT'):
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
        remaining_idcs = set(np.arange(n_cells)) - set(idx_test)
        adata_train.obs.loc[adata_train.obs['perturbation_condition'] == c, 'split'] = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
    elif n_cells>300:
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
        adata_train.obs.loc[adata_train.obs['perturbation_condition'] == c, 'split'] = ["test" if idx in idx_test else "train" for idx in range(n_cells)]

adata_train_final = adata_train[adata_train.obs["split"]=="train"]
adata_test_final = adata_train[adata_train.obs["split"]=="test"]
# For evluation
adata_ood_final = ad.concat((adata_ood, adata_test_final[adata_test_final.obs["perturbation_condition"].str.endswith('_NT')]))
adata_ood_final.uns = adata_ood.uns
print(adata_ood_final.obs['control'].value_counts())

control
0    109119
1      3000
Name: count, dtype: int64


In [39]:
# adata_train_final = adata_train_final[~(adata_train_final.obs['split_condition'] == ood_condition), :]
# adata_test_final = adata_test_final[~(adata_test_final.obs['split_condition'] == ood_condition), :]
# adata_ood_final = adata_ood_final[adata_ood_final.obs['split_condition'] == ood_condition, :]

Here technically Lea saves a `wo` version of the training adata 

## PCA on real data 

In [40]:
cfpp.centered_pca(adata_train_final, n_comps=pca_dim)

  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)


In [41]:
# Initialize a log-count layer
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
# Training data mean 
adata_train_final_mean = adata_train_final.varm["X_mean"].flatten()

# Define the gene means for the anndata train and ood as the training one 
adata_ood_final.varm["X_mean"] = adata_train_final.varm["X_mean"]
adata_test_final.varm["X_mean"] = adata_train_final.varm["X_mean"]

# Center both test and ood data by the mean of the training set and compute PCA based on this
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.toarray() - adata_train_final_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.toarray() - adata_train_final_mean)
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])

  adata_test_final.varm["X_mean"] = adata_train_final.varm["X_mean"]


Add if an observation is a control 

In [42]:
# Add the control key to the obs data frame
adata_train_final.obs['control'] = adata_train_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_test_final.obs['control'] = adata_test_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)
adata_ood_final.obs['control'] = adata_ood_final.obs.apply(lambda x: x['gene'] == 'NT', axis=1)

Collect ESM embeddings 

In [43]:
path_to_embeddings = os.path.join('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/embeddings/perturb_emb/satijas_v2', 'gene_embeddings.pkl')
# Gene KO embeddings 
ko_embeddings = pkl.load(open(path_to_embeddings, 'rb'))
ko_embeddings = pd.DataFrame(ko_embeddings).T
ko_embeddings = ko_embeddings.astype(np.float32)
gene_embeddings_dict = dict(zip(ko_embeddings.index, ko_embeddings.values))

In [44]:
# Cell line embedding 
cell_embeddings = pd.read_csv('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/cell_line_embedding_full_ccle_300_normalized.csv', index_col=0)
cell_embeddings = cell_embeddings.astype(np.float32)
cell_embeddings_dict = dict(zip(cell_embeddings.index, cell_embeddings.values))
cell_embeddings_dict = {k: v for k, v in cell_embeddings_dict.items() if k in adata_train_final.obs['cell_type'].unique()}

In [45]:
# Control embedding as zero 
gene_embeddings_dict['NT'] = np.zeros(gene_embeddings_dict['IFNG'].shape)
pathway_embeddings = {k: v for k, v in gene_embeddings_dict.items() if k in adata_train_final.obs['pathway'].unique()}

In [46]:
# Add all the embeddings to the uns of the adata 
adata_train_final.uns['gene_emb'] = gene_embeddings_dict
adata_train_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_train_final.uns['pathway_emb'] = pathway_embeddings

adata_test_final.uns['gene_emb'] = gene_embeddings_dict
adata_test_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_test_final.uns['pathway_emb'] = pathway_embeddings

adata_ood_final.uns['gene_emb'] = gene_embeddings_dict
adata_ood_final.uns['cell_type_emb'] = cell_embeddings_dict
adata_ood_final.uns['pathway_emb'] = pathway_embeddings

In [47]:
# Subset for cells for which we have the embeddings 
adata_train_final = adata_train_final[adata_train_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_train_final = adata_train_final[adata_train_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_train_final = adata_train_final[(adata_train_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_train_final.obs['gene'] == 'NT')), :]

adata_test_final = adata_test_final[adata_test_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_test_final = adata_test_final[adata_test_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_test_final = adata_test_final[(adata_test_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_test_final.obs['gene'] == 'NT')), :]

adata_ood_final = adata_ood_final[adata_ood_final.obs['cell_type'].isin(cell_embeddings_dict.keys()), :]
adata_ood_final = adata_ood_final[adata_ood_final.obs['pathway'].isin(pathway_embeddings.keys()), :]
adata_ood_final = adata_ood_final[(adata_ood_final.obs['gene'].isin(gene_embeddings_dict.keys()) | (adata_ood_final.obs['gene'] == 'NT')), :]
print(adata_ood_final.obs['control'].value_counts())

control
False    109119
True       3000
Name: count, dtype: int64


In [48]:
adata_ood_final

View of AnnData object with n_obs × n_vars = 112119 × 8328
    obs: 'cell_type', 'gene', 'pathway', 'split_encodings', 'split_condition', 'perturbation_condition', 'background', 'control', 'is_ood'
    uns: 'log1p', 'rank_genes_groups_cov_all', 'gene_emb', 'cell_type_emb', 'pathway_emb'
    obsm: 'X_pca'
    varm: 'X_mean'
    layers: 'centered_X'

In [49]:
adata_test_final

View of AnnData object with n_obs × n_vars = 39000 × 8328
    obs: 'cell_type', 'gene', 'pathway', 'split_encodings', 'split_condition', 'perturbation_condition', 'background', 'control', 'is_ood', 'split'
    uns: 'log1p', 'rank_genes_groups_cov_all', 'gene_emb', 'cell_type_emb', 'pathway_emb'
    obsm: 'X_pca'
    varm: 'X_mean'
    layers: 'centered_X'

In [50]:
adata_train_final

View of AnnData object with n_obs × n_vars = 292913 × 8328
    obs: 'cell_type', 'gene', 'pathway', 'split_encodings', 'split_condition', 'perturbation_condition', 'background', 'control', 'is_ood', 'split'
    uns: 'log1p', 'rank_genes_groups_cov_all', 'pca', 'gene_emb', 'cell_type_emb', 'pathway_emb'
    obsm: 'X_pca'
    varm: 'X_mean', 'PCs'
    layers: 'X_centered', 'X_log1p'

In [38]:
adata

AnnData object with n_obs × n_vars = 618023 × 8265
    obs: 'cell_type', 'gene', 'pathway', 'split_encodings', 'split_condition', 'perturbation_condition', 'background', 'control'
    uns: 'log1p', 'rank_genes_groups_cov_all'

In [52]:
# sc.read_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genes/adata_ood_final_genesIFNG_IFNB_TNFA_TGFB_INS_hvg-500_pca-100_counts_ms_0.5/adata_ood_split_2.h5ad")