In [1]:
# Reload edited Python modules (e.g. `cfp`) automatically; mode 2 reloads all modules before running code.
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
import jax
import os
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp
import anndata as ad
import pandas as pd
from tqdm.auto import tqdm
import numpy as np



In [3]:
def get_mask(x, y, var_names=None):
    """Keep only the columns (genes) of ``x`` whose names appear in ``y``.

    Args:
        x: 2-D array indexed as ``x[:, mask]``; its columns are assumed to follow the
            order of ``var_names`` (TODO confirm for every caller).
        y: collection of gene names to keep (e.g. one condition's DEG list).
        var_names: names labelling the columns of ``x``. Defaults to the notebook-level
            ``adata_train.var_names`` for backward compatibility.

    Returns:
        ``x`` restricted to the selected columns, in the original column order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Set lookup: O(1) per gene instead of scanning `y` once for every column.
    wanted = set(y)
    keep = np.array([gene in wanted for gene in var_names], dtype=bool)
    return x[:, keep]

In [4]:
split = 0  # index of the data split: selects the *_split_{split}.h5ad input files and names the output pickle

### Load data

In [5]:
DATA_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"

# File naming vs. variable naming: the "val" file is loaded as `adata_test`,
# and the "test" file is loaded as the out-of-distribution set `adata_ood`.
adata_train_path, adata_test_path, adata_ood_path = (
    os.path.join(DATA_DIR, f"adata_train_pca_50_split_{split}.h5ad"),
    os.path.join(DATA_DIR, f"adata_val_pca_50_split_{split}.h5ad"),
    os.path.join(DATA_DIR, f"adata_test_pca_50_split_{split}.h5ad"),
)

# load data splits (read in train, val, test order)
adata_train, adata_test, adata_ood = (
    sc.read(path) for path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [9]:
# sanity check: does the condition AHR+KLF1 occur in the OOD split?
"AHR+KLF1" in set(adata_ood.obs["condition"])

True

In [27]:
# all distinct training conditions, in order of first appearance
adata_train.obs["condition"].drop_duplicates().tolist()

['TSC22D1+ctrl',
 'ctrl',
 'CEBPE+RUNX1T1',
 'MAML2+ctrl',
 'ctrl+CEBPE',
 'SGK1+TBX3',
 'ctrl+FOXA1',
 'FOXA3+FOXA1',
 'ETS2+IGDCC3',
 'GLB1L2+ctrl',
 'MAP2K6+IKZF3',
 'BAK1+ctrl',
 'FEV+ctrl',
 'MAP2K3+SLC38A2',
 'ctrl+ETS2',
 'ctrl+FEV',
 'ctrl+SET',
 'TBX3+ctrl',
 'LHX1+ctrl',
 'RREB1+ctrl',
 'ZNF318+ctrl',
 'ctrl+ZBTB25',
 'MAP4K5+ctrl',
 'UBASH3B+ctrl',
 'SLC6A9+ctrl',
 'MIDN+ctrl',
 'DLX2+ctrl',
 'CBFA2T3+ctrl',
 'HES7+ctrl',
 'SET+CEBPE',
 'IGDCC3+ZBTB25',
 'AHR+ctrl',
 'FOXO4+ctrl',
 'ctrl+CBFA2T3',
 'ctrl+RUNX1T1',
 'POU3F2+ctrl',
 'ctrl+CNN1',
 'IGDCC3+MAPK1',
 'MAP2K3+ctrl',
 'MAP4K3+ctrl',
 'ZBTB25+ctrl',
 'ZC3HAV1+CEBPE',
 'UBASH3B+UBASH3A',
 'MAP2K3+MAP2K6',
 'PTPN1+ctrl',
 'RUNX1T1+ctrl',
 'PTPN12+ctrl',
 'TP73+ctrl',
 'ctrl+MAP7D1',
 'FOSB+ctrl',
 'MAPK1+ctrl',
 'IRF1+ctrl',
 'TMSB4X+BAK1',
 'BPGM+SAMD1',
 'IKZF3+ctrl',
 'HOXB9+ctrl',
 'ctrl+HOXC13',
 'MAPK1+IKZF3',
 'ctrl+UBASH3B',
 'ctrl+HOXB9',
 'ETS2+ctrl',
 'CLDN6+ctrl',
 'FOXA3+ctrl',
 'CEBPE+ctrl',
 'KIF18B+KIF2

In [22]:
# does the double perturbation AHR+KLF1 occur as a training condition?
"AHR+KLF1" in set(adata_train.obs["condition"])

False

In [24]:
# is AHR present in training as the single perturbation "AHR+ctrl"?
"AHR+ctrl" in set(adata_train.obs["condition"])

True

In [26]:
# is KLF1 present in training as a single perturbation, in either orientation?
tuple(map(set(adata_train.obs["condition"]).__contains__, ("KLF1+ctrl", "ctrl+KLF1")))

(False, False)

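### Categorize OOD perturbations into subgroups: single, double_seen_0, double_seen_1, double_seen_2 (for a double perturbation, the suffix is the number of its two genes that appear as single perturbations in the training data)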

In [10]:
# Training condition names with the "ctrl" placeholder stripped: single perturbations reduce to the
# bare gene name ("AHR+ctrl" -> "AHR"), while double perturbations keep their "A+B" form.
# regex=False: "+" is a regex metacharacter, so both patterns must be matched literally
# (pandas < 2.0 defaults str.replace to regex=True).
train_conditions = (
    adata_train.obs.condition
    .str.replace("+ctrl", "", regex=False)
    .str.replace("ctrl+", "", regex=False)
    .unique()
)


def _perturbation_key(condition):
    """Order-independent key of a condition: its non-"ctrl" genes, sorted and joined with "+"."""
    return "+".join(sorted(gene for gene in condition.split("+") if gene != "ctrl"))


# make sure that none of the test perturbations is in the training data. Compare order-independent
# keys so that "X+ctrl" vs "ctrl+X" and "A+B" vs "B+A" count as the same perturbation; comparing the
# raw OOD names against the stripped training names could never flag a leaked single perturbation.
leaked_perturbations = sorted(
    {_perturbation_key(c) for c in adata_train.obs.condition.unique() if c != "ctrl"}
    & {_perturbation_key(c) for c in adata_ood.obs.condition.unique() if c != "ctrl"}
)
assert not leaked_perturbations, f"OOD perturbations also present in training: {leaked_perturbations}"

# building blocks for the subgroup masks
is_control = adata_ood.obs.condition == "ctrl"
has_ctrl = adata_ood.obs.condition.str.contains("ctrl", regex=False)  # "ctrl", "X+ctrl" or "ctrl+X"
# for double perturbations, a gene matches a stripped training condition only if it was trained as a single
gene_1_seen = adata_ood.obs.gene_1.isin(train_conditions)
gene_2_seen = adata_ood.obs.gene_2.isin(train_conditions)

# single perturbations; pure control cells also contain "ctrl", so exclude them explicitly
mask_single_perturbation = has_ctrl & ~is_control

# double perturbations, split by how many of the two genes were seen as singles in training
mask_double_perturbation_seen_0 = ~has_ctrl & ~gene_1_seen & ~gene_2_seen
mask_double_perturbation_seen_1 = ~has_ctrl & (gene_1_seen ^ gene_2_seen)
mask_double_perturbation_seen_2 = ~has_ctrl & gene_1_seen & gene_2_seen

# add perturbation subgroup to anndata
adata_ood.obs.loc[mask_single_perturbation, "subgroup"] = "single"
adata_ood.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
adata_ood.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
adata_ood.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"
adata_ood.obs.loc[is_control, "subgroup"] = "ctrl"

display(adata_ood.obs.subgroup.value_counts())

subgroup
double_seen_1    11417
single           11317
double_seen_2     4593
double_seen_0     1927
Name: count, dtype: int64

In [16]:
# is AHR+KLF1 assigned to the double_seen_1 subgroup?
"AHR+KLF1" in set(adata_ood.obs.loc[adata_ood.obs["subgroup"] == "double_seen_1", "condition"])

True

In [19]:
# which of the two genes of AHR+KLF1 appear among the stripped training conditions?
tuple(map(set(train_conditions).__contains__, ("AHR", "KLF1")))

(True, False)

In [75]:
# compute pca on full dataset
adata_all = ad.concat((adata_train, adata_test, adata_ood))
cfpp.centered_pca(adata_all, n_comps=10)

  utils.warn_names_duplicates("obs")


In [76]:
# Inspect the per-cell metadata of the training split (condition, gene_1/gene_2, kategory, ...).
adata_train.obs

Unnamed: 0_level_0,condition,cell_type,dose_val,control,condition_name,cell_line,gene_1,gene_2,num_control,kategory
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
AAACCTGAGGCATGTG-1,TSC22D1+ctrl,A549,1+1,False,A549_TSC22D1+ctrl_1+1,A549,TSC22D1,ctrl,1,single
AAACCTGCACGAAGCA-1,ctrl,A549,1,True,A549_ctrl_1,A549,ctrl,ctrl,2,ctrl
AAACCTGCAGACGTAG-1,CEBPE+RUNX1T1,A549,1+1,False,A549_CEBPE+RUNX1T1_1+1,A549,CEBPE,RUNX1T1,0,double
AAACCTGCAGCCTTGG-1,MAML2+ctrl,A549,1+1,False,A549_MAML2+ctrl_1+1,A549,MAML2,ctrl,1,single
AAACCTGCATCTCCCA-1,ctrl+CEBPE,A549,1+1,False,A549_ctrl+CEBPE_1+1,A549,ctrl,CEBPE,1,single
...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTAGCGTGA-8,ctrl,A549,1,True,A549_ctrl_1,A549,ctrl,ctrl,2,ctrl
TTTGTCAGTATAAACG-8,SAMD1+UBASH3B,A549,1+1,False,A549_SAMD1+UBASH3B_1+1,A549,SAMD1,UBASH3B,0,double
TTTGTCAGTCAGAATA-8,ctrl,A549,1,True,A549_ctrl_1,A549,ctrl,ctrl,2,ctrl
TTTGTCATCAGTACGT-8,FOXA3+ctrl,A549,1+1,False,A549_FOXA3+ctrl_1+1,A549,FOXA3,ctrl,1,single


In [77]:
# control cells used to make predictions
adata_train_ctrl = adata_train[adata_train.obs.condition == "ctrl"].copy()

# perturbed cells used to create additive model predictions
adata_train_single = adata_train[adata_train.obs.kategory == "single"].copy()

# conditions that the additive model can predict
adata_ood_double_seen_2 = adata_ood[adata_ood.obs.subgroup == "double_seen_2"].copy()

NUM_SAMPLED_CELLS = 500

### Make sure that the perturbed gene is always gene 1 (rewrite `ctrl+X` as `X+ctrl`)

In [78]:
# Split each "A+B" condition into its two parts (object array with one row per cell).
genes = adata_train_single.obs["condition"].str.split("+", expand=True).to_numpy()

# Rows written as "ctrl+GENE": flip them in place so the perturbed gene always comes first.
mask = genes[:, 0] == "ctrl"
genes[mask] = genes[mask, ::-1]
genes_1, genes_2 = genes[:, 0], genes[:, 1]

adata_train_single.obs["condition_ordered"] = genes_1 + "+" + genes_2
adata_train_single.obs["gene_1_ordered"] = genes_1  # the perturbed gene
adata_train_single.obs["gene_2_ordered"] = genes_2  # always "ctrl"

assert (adata_train_single.obs["gene_1_ordered"] != "ctrl").all()
assert (adata_train_single.obs["gene_2_ordered"] == "ctrl").all()

In [79]:
# Check the new *_ordered columns: conditions written "ctrl+GENE" now read "GENE+ctrl".
display(adata_train_single.obs)

Unnamed: 0_level_0,condition,cell_type,dose_val,control,condition_name,cell_line,gene_1,gene_2,num_control,kategory,condition_ordered,gene_1_ordered,gene_2_ordered
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
AAACCTGAGGCATGTG-1,TSC22D1+ctrl,A549,1+1,False,A549_TSC22D1+ctrl_1+1,A549,TSC22D1,ctrl,1,single,TSC22D1+ctrl,TSC22D1,ctrl
AAACCTGCAGCCTTGG-1,MAML2+ctrl,A549,1+1,False,A549_MAML2+ctrl_1+1,A549,MAML2,ctrl,1,single,MAML2+ctrl,MAML2,ctrl
AAACCTGCATCTCCCA-1,ctrl+CEBPE,A549,1+1,False,A549_ctrl+CEBPE_1+1,A549,ctrl,CEBPE,1,single,CEBPE+ctrl,CEBPE,ctrl
AAACCTGTCAGGCGAA-1,ctrl+FOXA1,A549,1+1,False,A549_ctrl+FOXA1_1+1,A549,ctrl,FOXA1,1,single,FOXA1+ctrl,FOXA1,ctrl
AAACCTGTCGTCCAGG-1,GLB1L2+ctrl,A549,1+1,False,A549_GLB1L2+ctrl_1+1,A549,GLB1L2,ctrl,1,single,GLB1L2+ctrl,GLB1L2,ctrl
...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCACATCGTCGG-8,ctrl+MAP2K6,A549,1+1,False,A549_ctrl+MAP2K6_1+1,A549,ctrl,MAP2K6,1,single,MAP2K6+ctrl,MAP2K6,ctrl
TTTGTCAGTACCTACA-8,TSC22D1+ctrl,A549,1+1,False,A549_TSC22D1+ctrl_1+1,A549,TSC22D1,ctrl,1,single,TSC22D1+ctrl,TSC22D1,ctrl
TTTGTCAGTACGCTGC-8,ctrl+MAP7D1,A549,1+1,False,A549_ctrl+MAP7D1_1+1,A549,ctrl,MAP7D1,1,single,MAP7D1+ctrl,MAP7D1,ctrl
TTTGTCATCAGTACGT-8,FOXA3+ctrl,A549,1+1,False,A549_FOXA3+ctrl_1+1,A549,FOXA3,ctrl,1,single,FOXA3+ctrl,FOXA3,ctrl


### Get predictions — the additive model can only make predictions for subgroup double_seen_2, because it needs single-perturbation training cells for both genes

In [80]:
# Seeded generator: the sampled control cells (and hence the predictions) are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)

all_predictions, all_conditions = [], []

for condition in tqdm(adata_ood_double_seen_2.obs.condition.unique()):

    # get perturbed genes
    gene_1, gene_2 = condition.split("+")

    # single-perturbation training cells per gene (gene_1_ordered is always the perturbed gene)
    cells_1 = np.asarray(adata_train_single[adata_train_single.obs.gene_1_ordered == gene_1].X.todense())
    cells_2 = np.asarray(adata_train_single[adata_train_single.obs.gene_1_ordered == gene_2].X.todense())
    # an empty selection would silently turn the displacement (and every prediction) into NaN
    assert len(cells_1) > 0 and len(cells_2) > 0, f"missing single-perturbation training cells for {condition}"

    # sample control cells (with replacement); they are both the baseline and the reference mean
    random_idcs_ctrl = rng.choice(adata_train_ctrl.shape[0], size=NUM_SAMPLED_CELLS, replace=True)
    ctrl_cells = np.asarray(adata_train_ctrl.X[random_idcs_ctrl].todense())
    ctrl_mean = ctrl_cells.mean(axis=0)

    # compute displacements of each single perturbation relative to control
    displacement_1 = cells_1.mean(axis=0) - ctrl_mean
    displacement_2 = cells_2.mean(axis=0) - ctrl_mean

    # additive model: shift every sampled control cell by both displacements
    predictions = ctrl_cells + displacement_1 + displacement_2
    all_predictions.append(predictions)
    all_conditions.extend([condition] * NUM_SAMPLED_CELLS)

  0%|          | 0/15 [00:00<?, ?it/s]

In [81]:
# Predictions are built from adata_train_single / adata_train_ctrl, which share adata_train's gene
# axis, so reuse adata_train.var to keep the gene names (otherwise var_names default to "0", "1", ...).
# The obs index is given explicitly as strings, since AnnData expects string obs names.
adata_pred_ood = ad.AnnData(
    X=np.vstack(all_predictions),
    obs=pd.DataFrame(
        {"condition": all_conditions},
        index=[str(i) for i in range(len(all_conditions))],
    ),
    var=adata_train.var.copy(),
)
adata_pred_ood



AnnData object with n_obs × n_vars = 7500 × 5045
    obs: 'condition'

### Create dicts with predictions and ground truths per condition

In [82]:
# project predictions and ground truth onto the PCA reference built on adata_all
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood_double_seen_2, ref_adata=adata_all)

ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

# only conditions actually present: an unused category would give empty selections below
ood_conditions = adata_ood_double_seen_2.obs["condition"].cat.remove_unused_categories().cat.categories

for cond in tqdm(ood_conditions):
    if cond == "ctrl":
        continue

    # get masks for truths and preds for this condition
    select = adata_ood_double_seen_2.obs["condition"] == cond
    select2 = adata_pred_ood.obs["condition"] == cond
    assert select2.any(), f"no predictions for condition {cond}"

    # truth and preds in pca space (np.asarray turns AnnData array views into plain arrays)
    ood_data_target_encoded[cond] = np.asarray(adata_ood_double_seen_2[select].obsm["X_pca"])
    ood_data_target_encoded_predicted[cond] = np.asarray(adata_pred_ood[select2].obsm["X_pca"])

    # truths and preds in gene space
    ood_data_target_decoded[cond] = np.asarray(adata_ood_double_seen_2[select].X.todense())
    ood_data_target_decoded_predicted[cond] = np.asarray(adata_pred_ood[select2].X)

  0%|          | 0/15 [00:00<?, ?it/s]

In [83]:
def _map_over_conditions(fn, left, right):
    """Return ``{cond: fn(left[cond], right[cond])}`` for every condition key.

    Replaces ``jax.tree_util.tree_map`` over flat dicts: keys are visited in sorted order
    (matching the key order jax uses for dicts). A key mismatch raises a KeyError that names
    the offending conditions, instead of an opaque pytree-structure error.
    """
    if left.keys() != right.keys():
        mismatched = sorted(set(left) ^ set(right))
        raise KeyError(f"condition keys differ between the two dicts: {mismatched}")
    return {cond: fn(left[cond], right[cond]) for cond in sorted(left)}


print("Computing ood_metrics_encoded")
# ood set: evaluation in encoded (=pca) space
ood_metrics_encoded = _map_over_conditions(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded,
    prefix="encoded_ood_",
)

print("Computing ood_metrics_decoded")
# ood set: evaluation in decoded (=gene) space.
# NOTE(review): gene space uses compute_metrics_fast while the PCA and DEG evaluations use
# compute_metrics - presumably for speed on the full gene set; confirm the results are comparable.
ood_metrics_decoded = _map_over_conditions(
    compute_metrics_fast,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded,
    prefix="decoded_ood_",
)

# ood set: per-condition DEG lists for the evaluated conditions
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted
}

print("Apply DEG mask")
# restrict predictions and ground truth to each condition's DEGs
ood_deg_target_decoded_predicted = _map_over_conditions(
    get_mask,
    ood_data_target_decoded_predicted,
    ood_deg_dict,
)
ood_deg_target_decoded = _map_over_conditions(
    get_mask,
    ood_data_target_decoded,
    ood_deg_dict,
)

print("Compute metrics on DEG subsetted decoded")
deg_ood_metrics = _map_over_conditions(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)
deg_mean_ood_metrics = compute_mean_metrics(
    deg_ood_metrics,
    prefix="deg_ood_",
)

Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded


In [84]:
# Summary of the DEG-restricted OOD metrics (the output of compute_mean_metrics with prefix "deg_ood_").
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.9736408591270447,
 'deg_ood_sinkhorn_div_1': 17.404671065012614,
 'deg_ood_sinkhorn_div_10': 4.6770671844482425,
 'deg_ood_sinkhorn_div_100': 1.8972557067871094,
 'deg_ood_e_distance': np.float64(3.1158367545999774),
 'deg_ood_mmd': np.float32(0.02011226)}

In [85]:
# Everything that gets pickled: per-condition and summary metrics, plus the DEG-restricted arrays.
collected_results = dict(
    ood_metrics_encoded=ood_metrics_encoded,
    mean_ood_metrics_encoded=mean_ood_metrics_encoded,
    ood_metrics_decoded=ood_metrics_decoded,
    mean_ood_metrics_decoded=mean_ood_metrics_decoded,
    deg_ood_metrics=deg_ood_metrics,
    deg_mean_ood_metrics=deg_mean_ood_metrics,
    ood_deg_dict=ood_deg_dict,
    ood_deg_target_decoded_predicted=ood_deg_target_decoded_predicted,
    ood_deg_target_decoded=ood_deg_target_decoded,
)

In [86]:
# The output folder is labelled with the number of sampled control cells. Derive the label from
# NUM_SAMPLED_CELLS instead of hard-coding "500", which would mislabel results if the constant changed.
OUT_DIR = (
    "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/additive_mean_displacement/"
    f"num_samples_{NUM_SAMPLED_CELLS}"
)
os.makedirs(OUT_DIR, exist_ok=True)
pd.to_pickle(collected_results, os.path.join(OUT_DIR, f"norman_additive_split_{split}_collected_results.pkl"))