In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `cfp` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
import jax
import os
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp
import anndata as ad
import pandas as pd
from tqdm.auto import tqdm



In [3]:
def get_mask(x, y, var_names=None):
    """Keep only the columns of ``x`` whose gene name is contained in ``y``.

    Parameters
    ----------
    x
        Matrix that supports ``x[:, mask]`` indexing with a list of booleans.
        The mask has one entry per name in ``var_names``, so ``x`` is expected
        to have one column per entry of ``var_names``.
    y
        Container of gene names to keep. Only membership tests
        (``gene in y``) are performed on it.
    var_names
        Gene names labelling the columns of ``x``. When omitted, the names are
        taken from the notebook-level ``adata_train.var_names`` at call time,
        which is what the two-argument form has always done.

    Returns
    -------
    The column subset of ``x``, with columns in their original order.
    """
    if var_names is None:
        # Backward-compatible default: fall back to the training AnnData
        # defined at notebook level (resolved lazily, when the function runs).
        var_names = adata_train.var_names
    keep_column = [gene in y for gene in var_names]
    return x[:, keep_column]

In [4]:
# Index of the data split to evaluate; it is interpolated into the names of the
# input `.h5ad` files and of the output pickle.
split = 3

#### Data loading

In [5]:
DATA_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"

# Paths of the three stored splits. Note the naming shift between files and
# variables: the "val" file is loaded as `adata_test` and the "test" file is
# loaded as `adata_ood`.
adata_train_path = os.path.join(DATA_DIR, f"adata_train_pca_3_split_{split}.h5ad")
adata_test_path = os.path.join(DATA_DIR, f"adata_val_pca_3_split_{split}.h5ad")
adata_ood_path = os.path.join(DATA_DIR, f"adata_test_pca_3_split_{split}.h5ad")

# Read the splits from disk, in train / test / ood order.
adata_train, adata_test, adata_ood = (
    sc.read(path)
    for path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [6]:
# Fit the PCA on the concatenation of all three splits (train, test and ood),
# keeping 10 components.
# NOTE(review): the concatenation emits anndata's duplicate-obs-names warning,
# i.e. some cell names occur in more than one split — confirm this is intended.
adata_all = ad.concat([adata_train, adata_test, adata_ood])
cfpp.centered_pca(adata_all, n_comps=10)

  utils.warn_names_duplicates("obs")


#### Predict on ood set

In [7]:
# Identity-model baseline: the model always predicts that the perturbation has no
# effect, i.e., the prediction for every condition is the set of control cells.
# `.copy()` turns the subset into a standalone AnnData, so `project_pca` below
# writes into a real object instead of implicitly converting a view of `adata_ood`.
adata_pred_ood = adata_ood[adata_ood.obs["condition"] == "ctrl"].copy()

# Project predictions and targets using `adata_all` as the PCA reference; the
# projections are read back from `.obsm["X_pca"]` below.
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

# condition -> matrix, for the true targets and for the identity predictions,
# in PCA ("encoded") space and in gene ("decoded") space.
ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

for cond in adata_ood.obs["condition"].cat.categories:
    if cond == "ctrl":
        # Control cells are the prediction itself, not a condition to evaluate.
        continue

    select = adata_ood.obs["condition"] == cond
    if not select.any():
        # The category is declared but no cell of this split carries it.
        continue

    # Subset once and reuse it for both representations.
    adata_cond = adata_ood[select]

    # pca space
    ood_data_target_encoded[cond] = adata_cond.obsm["X_pca"]
    ood_data_target_encoded_predicted[cond] = adata_pred_ood.obsm["X_pca"]

    # gene space
    ood_data_target_decoded[cond] = adata_cond.X
    ood_data_target_decoded_predicted[cond] = adata_pred_ood.X

  query_adata.obsm[obsm_key_added] = np.array(


#### Evaluation on ood set

In [8]:
print("Computing ood_metrics_encoded")
# OOD set, encoded (= PCA) space: `compute_metrics` is applied to each
# (target, prediction) pair, giving one entry per condition.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded,
    prefix="encoded_ood_",
)

print("Computing ood_metrics_decoded")
# OOD set, decoded (= gene) space: this step deliberately uses
# `compute_metrics_fast` rather than `compute_metrics`
# (presumably for speed on the full gene matrix — TODO confirm).
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded,
    prefix="decoded_ood_",
)

# Differentially expressed genes (DEGs) stored with the training data,
# restricted to the conditions that are evaluated here.
ood_deg_dict = {
    cond: genes
    for cond, genes in adata_train.uns['rank_genes_groups_cov_all'].items()
    if cond in ood_data_target_decoded_predicted
}

# `tree_map` below needs a DEG entry for every evaluated condition; fail with a
# readable message instead of a pytree-structure error raised inside jax.
missing_deg_conditions = sorted(
    set(ood_data_target_decoded_predicted) - set(ood_deg_dict)
)
if missing_deg_conditions:
    raise KeyError(
        f"No DEG entry in adata_train.uns['rank_genes_groups_cov_all'] for: {missing_deg_conditions}"
    )

print("Apply DEG mask")
# Keep only each condition's DEG columns, for predictions and for targets.
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded_predicted,
    ood_deg_dict,
)
ood_deg_target_decoded = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded,
    ood_deg_dict,
)

print("Compute metrics on DEG subsetted decoded")
# Gene space restricted to DEGs: evaluated with `compute_metrics`.
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)
deg_mean_ood_metrics = compute_mean_metrics(
    deg_ood_metrics,
    prefix="deg_ood_",
)

Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded


In [9]:
# Bundle every artefact of the OOD evaluation into a single dict, keyed by the
# name of the variable that holds it.
collected_results = {
    # ood: per-condition metrics and the output of `compute_mean_metrics`,
    # for PCA space, gene space and DEG-restricted gene space
    "ood_metrics_encoded": ood_metrics_encoded,
    "mean_ood_metrics_encoded": mean_ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "mean_ood_metrics_decoded": mean_ood_metrics_decoded,
    "deg_ood_metrics": deg_ood_metrics,
    "deg_mean_ood_metrics": deg_mean_ood_metrics,
    # DEG lists per condition and the DEG-masked prediction / target matrices
    "ood_deg_dict": ood_deg_dict,
    "ood_deg_target_decoded_predicted": ood_deg_target_decoded_predicted,
    "ood_deg_target_decoded": ood_deg_target_decoded,
}

In [11]:
OUT_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/identity_wo_subgroups"

# Create the output directory if needed, then pickle the results of this split.
os.makedirs(OUT_DIR, exist_ok=True)
results_path = os.path.join(OUT_DIR, f"norman_identity_split_{split}_collected_results.pkl")
pd.to_pickle(collected_results, results_path)