In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
import jax
import os
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp
import anndata as ad
import pandas as pd
from tqdm.auto import tqdm
import numpy as np



In [3]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]

In [4]:
split = 2

In [5]:
DATA_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"

adata_train_path = os.path.join(DATA_DIR, f"adata_train_pca_50_split_{split}.h5ad")
adata_test_path = os.path.join(DATA_DIR, f"adata_val_pca_50_split_{split}.h5ad")
adata_ood_path = os.path.join(DATA_DIR, f"adata_test_pca_50_split_{split}.h5ad")

# load data splits
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)

In [6]:
train_conditions = adata_train.obs.condition.str.replace("+ctrl", "").str.replace("ctrl+", "").unique()

assert not adata_ood[adata_ood.obs.condition != "ctrl"].obs.condition.isin(train_conditions).any()

mask_single_perturbation = adata_ood.obs.condition.str.contains("ctrl")
mask_double_perturbation_seen_0 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    ~adata_ood.obs.gene_1.isin(train_conditions) & 
    ~adata_ood.obs.gene_2.isin(train_conditions)
)
mask_double_perturbation_seen_1 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    (
        (adata_ood.obs.gene_1.isin(train_conditions) & ~adata_ood.obs.gene_2.isin(train_conditions)) | 
        (~adata_ood.obs.gene_1.isin(train_conditions) & adata_ood.obs.gene_2.isin(train_conditions))
    )
)
mask_double_perturbation_seen_2 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    adata_ood.obs.gene_1.isin(train_conditions) & 
    adata_ood.obs.gene_2.isin(train_conditions)
)
adata_ood.obs.loc[mask_single_perturbation, "subgroup"] = "single"
adata_ood.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
adata_ood.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
adata_ood.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"

display(adata_ood.obs.subgroup.value_counts())

subgroup
double_seen_1    13048
single           10681
double_seen_0     4449
double_seen_2     2592
Name: count, dtype: int64

In [7]:
# compute pca on full dataset
adata_all = ad.concat((adata_train, adata_test, adata_ood))
cfpp.centered_pca(adata_all, n_comps=10)

  utils.warn_names_duplicates("obs")


In [8]:
adata_train.obs

Unnamed: 0_level_0,condition,cell_type,dose_val,control,condition_name,cell_line,gene_1,gene_2,num_control,kategory
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
AAACCTGAGGCATGTG-1,TSC22D1+ctrl,A549,1+1,False,A549_TSC22D1+ctrl_1+1,A549,TSC22D1,ctrl,1,single
AAACCTGCACGAAGCA-1,ctrl,A549,1,True,A549_ctrl_1,A549,ctrl,ctrl,2,ctrl
AAACCTGCAGCCTTGG-1,MAML2+ctrl,A549,1+1,False,A549_MAML2+ctrl_1+1,A549,MAML2,ctrl,1,single
AAACCTGCATCTCCCA-1,ctrl+CEBPE,A549,1+1,False,A549_ctrl+CEBPE_1+1,A549,ctrl,CEBPE,1,single
AAACCTGGTATCGCAT-1,CBL+PTPN9,A549,1+1,False,A549_CBL+PTPN9_1+1,A549,CBL,PTPN9,0,double
...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTAGGCATG-8,COL2A1+ctrl,A549,1+1,False,A549_COL2A1+ctrl_1+1,A549,COL2A1,ctrl,1,single
TTTGTCAGTCAGAATA-8,ctrl,A549,1,True,A549_ctrl_1,A549,ctrl,ctrl,2,ctrl
TTTGTCATCAGTACGT-8,FOXA3+ctrl,A549,1+1,False,A549_FOXA3+ctrl_1+1,A549,FOXA3,ctrl,1,single
TTTGTCATCCACTCCA-8,CELF2+ctrl,A549,1+1,False,A549_CELF2+ctrl_1+1,A549,CELF2,ctrl,1,single


In [9]:
adata_train_ctrl = adata_train[adata_train.obs.condition == "ctrl"]
adata_train_ctrl.shape

(6853, 5045)

In [10]:
adata_ood_double_seen_2 = adata_ood[adata_ood.obs.subgroup == "double_seen_2"]
adata_train_single = adata_train[adata_train.obs.kategory == "single"]

adata_train_ctrl = adata_train[adata_train.obs.condition == "ctrl"]

all_predictions, all_conditions = [], []

num_sampled_cells = 500

for condition in tqdm(adata_ood_double_seen_2.obs.condition.unique()):
    gene_1, gene_2 = condition.split("+")
    
    cells_1 = adata_train_single[adata_train_single.obs.gene_1 == gene_1].X
    random_idcs_1 = np.random.choice(cells_1.shape[0], size=num_sampled_cells, replace=True)
    cells_1 = cells_1[random_idcs_1].todense()
    
    cells_2 = adata_train_single[adata_train_single.obs.gene_2 == gene_2].X
    random_idcs_2 = np.random.choice(cells_2.shape[0], size=num_sampled_cells, replace=True)
    cells_2 = cells_2[random_idcs_2].todense()

    random_idcs_ctrl = np.random.choice(adata_train_ctrl.shape[0], size=num_sampled_cells, replace=True)
    ctrl_cells = adata_train_ctrl.X[random_idcs_ctrl].todense()

    displacement_1 = cells_1 - ctrl_cells
    displacement_2 = cells_2 - ctrl_cells

    predictions = np.asarray(ctrl_cells + displacement_1 + displacement_2)
    all_predictions.append(predictions)
    all_conditions.extend([condition] * num_sampled_cells)

  0%|          | 0/9 [00:00<?, ?it/s]

In [11]:
adata_pred_ood  = ad.AnnData(X=np.vstack(all_predictions), obs=pd.DataFrame(all_conditions, columns=["condition"]))



In [12]:
adata_pred_ood

AnnData object with n_obs × n_vars = 4500 × 5045
    obs: 'condition'

In [13]:
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

subgroups = ["double_seen_2"]

for subgroup in subgroups:

    ood_data_target_encoded[subgroup] = {}
    ood_data_target_decoded[subgroup] = {}
    ood_data_target_encoded_predicted[subgroup] = {}
    ood_data_target_decoded_predicted[subgroup] = {}
    
    for cond in adata_ood.obs["condition"].cat.categories:
        if cond == "ctrl":
            continue
        
        select = adata_ood.obs["condition"] == cond
        select2 = adata_pred_ood.obs["condition"] == cond
        if subgroup != "all":
            select = select & (adata_ood.obs.subgroup == subgroup)

        if not any(select):
            # the condition is not part of this subgroup
            continue
        
        # pca space
        ood_data_target_encoded[subgroup][cond] = adata_ood[select].obsm["X_pca"]
        ood_data_target_encoded_predicted[subgroup][cond] = adata_pred_ood[select2].obsm["X_pca"]
        # print(ood_data_target_encoded[subgroup][cond].shape, ood_data_target_encoded_predicted[subgroup][cond].shape)
    
        # gene space
        ood_data_target_decoded[subgroup][cond] = np.asarray(adata_ood[select].X.todense())
        ood_data_target_decoded_predicted[subgroup][cond] = adata_pred_ood[select2].X
        # print(ood_data_target_decoded[subgroup][cond].shape, ood_data_target_decoded_predicted[subgroup][cond].shape)

In [14]:
ood_metrics_encoded, mean_ood_metrics_encoded = {}, {}
ood_metrics_decoded, mean_ood_metrics_decoded = {}, {}
deg_ood_metrics, deg_mean_ood_metrics = {}, {}
ood_deg_dict = {}
ood_deg_target_decoded_predicted, ood_deg_target_decoded = {}, {}

for subgroup in tqdm(subgroups[::-1]):

    print(f"subgroup: {subgroup}")

    print("Computing ood_metrics_encoded")
    # ood set: evaluation in encoded (=pca) space
    ood_metrics_encoded[subgroup] = jax.tree_util.tree_map(
        compute_metrics, 
        # compute_metrics_fast, 
        ood_data_target_encoded[subgroup], 
        ood_data_target_encoded_predicted[subgroup]
    )
    mean_ood_metrics_encoded[subgroup] = compute_mean_metrics(
        ood_metrics_encoded[subgroup], 
        prefix="encoded_ood_",
    )

    print("Computing ood_metrics_decoded")
    # ood set: evaluation in decoded (=gene) space
    ood_metrics_decoded[subgroup] = jax.tree_util.tree_map(
        # compute_metrics, 
        compute_metrics_fast, 
        ood_data_target_decoded[subgroup], 
        ood_data_target_decoded_predicted[subgroup]
    )
    mean_ood_metrics_decoded[subgroup] = compute_mean_metrics(
        ood_metrics_decoded[subgroup], 
        prefix="decoded_ood_",
    )

    # ood set
    ood_deg_dict[subgroup] = {
        k: v
        for k, v in adata_train.uns['rank_genes_groups_cov_all'].items() 
        if k in ood_data_target_decoded_predicted[subgroup].keys()
    }

    print("Apply DEG mask")
    # ood set
    ood_deg_target_decoded_predicted[subgroup] = jax.tree_util.tree_map(
        get_mask, 
        ood_data_target_decoded_predicted[subgroup], 
        ood_deg_dict[subgroup]
    )
    
    ood_deg_target_decoded[subgroup] = jax.tree_util.tree_map(
        get_mask, 
        ood_data_target_decoded[subgroup], 
        ood_deg_dict[subgroup]
    )

    print("Compute metrics on DEG subsetted decoded")
    deg_ood_metrics[subgroup] = jax.tree_util.tree_map(
        compute_metrics, 
        # compute_metrics_fast, 
        ood_deg_target_decoded[subgroup], 
        ood_deg_target_decoded_predicted[subgroup]
    )
    deg_mean_ood_metrics[subgroup] = compute_mean_metrics(
        deg_ood_metrics[subgroup], 
        prefix="deg_ood_"
    )

  0%|          | 0/1 [00:00<?, ?it/s]

subgroup: double_seen_2
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded


In [15]:
deg_mean_ood_metrics

{'double_seen_2': {'deg_ood_r_squared': 0.8535235921541849,
  'deg_ood_sinkhorn_div_1': 29.743914100858902,
  'deg_ood_sinkhorn_div_10': 10.864479806688097,
  'deg_ood_sinkhorn_div_100': 2.2280073960622153,
  'deg_ood_e_distance': np.float64(3.7429370868569456),
  'deg_ood_mmd': np.float32(0.06567267)}}

In [16]:
collected_results = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "mean_ood_metrics_encoded": mean_ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "mean_ood_metrics_decoded": mean_ood_metrics_decoded,
    "deg_ood_metrics": deg_ood_metrics,
    "deg_mean_ood_metrics": deg_mean_ood_metrics,
    "ood_deg_dict": ood_deg_dict,
    "ood_deg_target_decoded_predicted": ood_deg_target_decoded_predicted,
    "ood_deg_target_decoded": ood_deg_target_decoded,
}

In [17]:
OUT_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/additive"
os.makedirs(OUT_DIR, exist_ok=True)
pd.to_pickle(collected_results, os.path.join(OUT_DIR, f"norman_additive_split_{split}_collected_results.pkl"))