#### This notebook evaluates the identity model on the Norman dataset.

The identity model (the "no-effect" model) assumes that perturbations have no effect, so its predictions are simply the control cells. These 'predictions' are computed on the fly in this notebook rather than loaded from disk, as is done for the other baseline models.

In [1]:
%load_ext autoreload
%autoreload 2

In [12]:
import scanpy as sc
import numpy as np
import functools
import jax
import os
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp
import anndata as ad

In [3]:
# Index of the data split to evaluate; selects the `*_split_{split}.h5ad` files loaded below.
split: int = 0

In [8]:
# Directory holding the preprocessed Norman splits. Override it with the
# NORMAN_DATA_DIR environment variable instead of editing this absolute,
# machine-specific path; the default keeps the original location.
DATA_DIR = os.environ.get(
    "NORMAN_DATA_DIR",
    "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata",
)

In [35]:
# File -> variable mapping is shifted on purpose: the "val" file serves as this
# notebook's test set and the "test" file as the out-of-distribution (OOD) set.
adata_train_path, adata_test_path, adata_ood_path = (
    os.path.join(DATA_DIR, file_name)
    for file_name in (
        f"adata_train_pca_3_split_{split}.h5ad",
        f"adata_val_pca_3_split_{split}.h5ad",
        f"adata_test_pca_3_split_{split}.h5ad",
    )
)

In [36]:
# Read the three preprocessed splits (train, test = "val" file, OOD = "test" file).
adata_train = sc.read(filename=adata_train_path)
adata_test = sc.read(filename=adata_test_path)
adata_ood = sc.read(filename=adata_ood_path)

In [37]:
# Fit a centered PCA on the concatenation of all three splits. Predictions and
# ground truth are later projected with `ref_adata=adata_all`, so both live in
# the same PCA space.
# NOTE(review): the PCA is fitted on the held-out (test/OOD) splits too —
# confirm this is intended for the comparison with the other baselines.
# NOTE(review): the saved output shows an anndata duplicate-obs-names warning
# after concatenation; check whether some cells (e.g. controls) appear in more
# than one split file, which would weight them multiple times in the PCA.
adata_all = ad.concat((adata_train, adata_test, adata_ood))
cfpp.centered_pca(adata_all, n_comps=10)

  utils.warn_names_duplicates("obs")


In [38]:
# Identity-model predictions for the test split: the identity model predicts that
# a perturbation has no effect, i.e. the prediction for every condition is the set
# of control cells. `.copy()` materializes the subset, so `project_pca` writes into
# a standalone AnnData rather than implicitly converting a view (the saved output
# showed a warning at that obsm assignment).
adata_pred_test = adata_test[adata_test.obs["condition"] == "ctrl"].copy()

# Project predictions and ground truth onto the PCA space fitted on `adata_all`.
cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_all)

# Ground truth per condition, encoded (= PCA space) and decoded (= gene space).
test_data_target_encoded, test_data_target_decoded = {}, {}

# Predictions per condition, encoded (= PCA space) and decoded (= gene space).
test_data_target_encoded_predicted, test_data_target_decoded_predicted = {}, {}

for cond in adata_test.obs["condition"].cat.categories:
    if cond == "ctrl":
        continue

    # Subset once per condition. Skip categories without cells in this split;
    # they would otherwise yield empty ground-truth arrays for the metrics.
    adata_cond = adata_test[adata_test.obs["condition"] == cond]
    if adata_cond.n_obs == 0:
        continue

    # PCA space: the same control-cell prediction is used for every condition.
    test_data_target_encoded[cond] = adata_cond.obsm["X_pca"]
    test_data_target_encoded_predicted[cond] = adata_pred_test.obsm["X_pca"]

    # gene space
    test_data_target_decoded[cond] = adata_cond.X
    test_data_target_decoded_predicted[cond] = adata_pred_test.X

  query_adata.obsm[obsm_key_added] = np.array(


In [42]:
# Identity-model predictions for the OOD split: the identity model predicts that
# a perturbation has no effect, i.e. the prediction for every condition is the set
# of control cells. `.copy()` materializes the subset, so `project_pca` writes into
# a standalone AnnData rather than implicitly converting a view.
adata_pred_ood = adata_ood[adata_ood.obs["condition"] == "ctrl"].copy()

# Project predictions and ground truth onto the PCA space fitted on `adata_all`.
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

# Ground truth and predictions per condition, encoded (= PCA space) and decoded (= gene space).
ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

for cond in adata_ood.obs["condition"].cat.categories:
    if cond == "ctrl":
        continue

    # Subset once per condition. Skip categories without cells in this split;
    # they would otherwise yield empty ground-truth arrays for the metrics.
    adata_cond = adata_ood[adata_ood.obs["condition"] == cond]
    if adata_cond.n_obs == 0:
        continue

    # PCA space: the same control-cell prediction is used for every condition.
    ood_data_target_encoded[cond] = adata_cond.obsm["X_pca"]
    ood_data_target_encoded_predicted[cond] = adata_pred_ood.obsm["X_pca"]

    # gene space
    ood_data_target_decoded[cond] = adata_cond.X
    ood_data_target_decoded_predicted[cond] = adata_pred_ood.X

In [43]:
# Test split, PCA ("encoded") space: compute metrics per condition by mapping
# `compute_metrics` over the (ground truth, prediction) dicts, then summarize
# with `compute_mean_metrics` under the "encoded_test_" prefix.
# NOTE(review): in the saved run this cell raised
# `AttributeError: 'tuple' object has no attribute 'divergence'`, presumably from
# inside `compute_metrics` (cause not visible here) — TODO investigate, e.g. a
# version mismatch in a dependency of `cfp.metrics`.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

AttributeError: 'tuple' object has no attribute 'divergence'

In [None]:
# Test split, gene ("decoded") space: per-condition metrics via the
# `compute_metrics_fast` variant, summarized under the "decoded_test_" prefix.
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, test_data_target_decoded, test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [None]:
# OOD split, PCA ("encoded") space: full `compute_metrics` per condition,
# summarized under the "encoded_ood_" prefix.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

# OOD split, gene ("decoded") space: `compute_metrics_fast` per condition,
# summarized under the "decoded_ood_" prefix.
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [None]:
# Per-condition DEG lists from the training AnnData, restricted to the conditions
# evaluated above (test set and OOD set respectively).
test_deg_dict = {
    k: v
    for k, v in adata_train.uns['rank_genes_groups_cov_all'].items()
    if k in test_data_target_decoded_predicted
}

ood_deg_dict = {
    k: v
    for k, v in adata_train.uns['rank_genes_groups_cov_all'].items()
    if k in ood_data_target_decoded_predicted
}

# The masking step maps over these dicts jointly with the per-condition data via
# jax.tree_util.tree_map, which requires identical dict keys. Fail early with a
# readable message instead of an opaque tree-structure mismatch.
assert set(test_deg_dict) == set(test_data_target_decoded_predicted), (
    "no DEGs in adata_train.uns['rank_genes_groups_cov_all'] for test conditions: "
    f"{sorted(set(test_data_target_decoded_predicted) - set(test_deg_dict))}"
)
assert set(ood_deg_dict) == set(ood_data_target_decoded_predicted), (
    "no DEGs in adata_train.uns['rank_genes_groups_cov_all'] for ood conditions: "
    f"{sorted(set(ood_data_target_decoded_predicted) - set(ood_deg_dict))}"
)

In [None]:
def get_mask(x, y, var_names=None):
    """Restrict the gene columns of ``x`` to the genes listed in ``y``.

    Args:
        x: 2-D matrix (cells x genes) whose columns are ordered like ``var_names``.
        y: Collection of gene names to keep (e.g. one condition's DEGs).
        var_names: Gene names labelling the columns of ``x``. Defaults to
            ``adata_train.var_names``, looked up at call time.

    Returns:
        ``x`` with only the columns whose gene name occurs in ``y``, in the
        original column order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Vectorized membership test; replaces a Python-level `gene in y` scan per gene.
    keep = np.isin(var_names, y)
    return x[:, keep]

# Restrict every per-condition matrix (ground truth and identity-model prediction)
# to the DEG columns of that condition, for the test set and then the OOD set.
test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)

ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

In [None]:
# Metrics restricted to DEG columns, OOD split first, then the test split.
# NOTE(review): uses `compute_metrics` (as for PCA space) instead of
# `compute_metrics_fast` (as for full gene space) — presumably because the DEG
# subset is small; confirm this is intended.
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted
)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(
    compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted
)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")