In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
import jax
import os
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp
import anndata as ad
import pandas as pd
from tqdm.auto import tqdm



In [3]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]

In [4]:
split = 0

#### Data loading

In [5]:
DATA_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"

adata_train_path = os.path.join(DATA_DIR, f"adata_train_pca_3_split_{split}.h5ad")
adata_test_path = os.path.join(DATA_DIR, f"adata_val_pca_3_split_{split}.h5ad")
adata_ood_path = os.path.join(DATA_DIR, f"adata_test_pca_3_split_{split}.h5ad")

# load data splits
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)

#### Add subgroup annotations

- single: Only one gene has been perturbed. This perturbation is not part of the training data.
- double_seen_0: Two genes have been perturbed. Neither of these gene perturbations have been seen as single perturbations in the training data.
- double_seen_1: Two genes have been perturbed. One of the two perturbed genes has been seen as single perturbation in the training data.
- double_seen_2: Two genes have been perturbed. Both gene perturbations have been seen as single perturbation in the training data.

In [7]:
train_conditions = adata_train.obs.condition.str.replace("+ctrl", "").str.replace("ctrl+", "").unique()

assert not adata_ood[adata_ood.obs.condition != "ctrl"].obs.condition.isin(train_conditions).any()

mask_single_perturbation = adata_ood.obs.condition.str.contains("ctrl")
mask_double_perturbation_seen_0 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    ~adata_ood.obs.gene_1.isin(train_conditions) & 
    ~adata_ood.obs.gene_2.isin(train_conditions)
)
mask_double_perturbation_seen_1 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    (
        (adata_ood.obs.gene_1.isin(train_conditions) & ~adata_ood.obs.gene_2.isin(train_conditions)) | 
        (~adata_ood.obs.gene_1.isin(train_conditions) & adata_ood.obs.gene_2.isin(train_conditions))
    )
)
mask_double_perturbation_seen_2 = (
    ~adata_ood.obs.condition.str.contains("ctrl") & 
    adata_ood.obs.gene_1.isin(train_conditions) & 
    adata_ood.obs.gene_2.isin(train_conditions)
)
adata_ood.obs.loc[mask_single_perturbation, "subgroup"] = "single"
adata_ood.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
adata_ood.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
adata_ood.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"

display(adata_ood.obs.subgroup.value_counts())

subgroup
double_seen_1    11417
single           11317
double_seen_2     4593
double_seen_0     1927
Name: count, dtype: int64

In [8]:
# compute pca on full dataset
adata_all = ad.concat((adata_train, adata_test, adata_ood))
cfpp.centered_pca(adata_all, n_comps=10)

  utils.warn_names_duplicates("obs")


#### Predict on test set

In [9]:
# compute predictions of identity model --> the identity model always predicts that the perturbation has not effect, 
# i.e., predictions correspond to cells of the control condition
adata_pred_test = adata_test[adata_test.obs["condition"]=="ctrl"]

# project predictions and ground truth data onto pca space
cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_all)

# dict for ground truths per condition, encoded (=pca space) and decoded (= gene space)
test_data_target_encoded, test_data_target_decoded = {}, {}

# dict for predictions per condition, encoded (=pca space) and decoded (= gene space)
test_data_target_encoded_predicted, test_data_target_decoded_predicted = {}, {}

for cond in adata_test.obs["condition"].cat.categories:
    if cond == "ctrl":
        continue

    # pca space
    test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"] == cond].obsm["X_pca"]
    test_data_target_encoded_predicted[cond] = adata_pred_test.obsm["X_pca"]
    
    # gene space
    test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"] == cond].X
    test_data_target_decoded_predicted[cond] = adata_pred_test.X

  query_adata.obsm[obsm_key_added] = np.array(


#### Predict on ood set (full ood set + subgroups)

In [10]:
# compute predictions of identity model --> the identity model always predicts that the perturbation has not effect, 
# i.e., predictions correspond to cells of the control condition
adata_pred_ood = adata_ood[adata_ood.obs["condition"]=="ctrl"]

cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

subgroups = ["all", "single", "double_seen_0", "double_seen_1", "double_seen_2"]

for subgroup in subgroups:

    ood_data_target_encoded[subgroup] = {}
    ood_data_target_decoded[subgroup] = {}
    ood_data_target_encoded_predicted[subgroup] = {}
    ood_data_target_decoded_predicted[subgroup] = {}
    
    for cond in adata_ood.obs["condition"].cat.categories:
        if cond == "ctrl":
            continue
        
        select = adata_ood.obs["condition"] == cond
        if subgroup != "all":
            select = select & (adata_ood.obs.subgroup == subgroup)

        if not any(select):
            # the condition is not part of this subgroup
            continue
        
        # pca space
        ood_data_target_encoded[subgroup][cond] = adata_ood[select].obsm["X_pca"]
        ood_data_target_encoded_predicted[subgroup][cond] = adata_pred_ood.obsm["X_pca"]
    
        # gene space
        ood_data_target_decoded[subgroup][cond] = adata_ood[select].X
        ood_data_target_decoded_predicted[subgroup][cond] = adata_pred_ood.X

#### Evaluation on test set

In [11]:
# test set: evaluation in encoded (=pca) space
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, 
    # compute_metrics_fast, 
    test_data_target_encoded, 
    test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(
    test_metrics_encoded, 
    prefix="encoded_test_"
)

# test set: evaluation in decoded (=gene) space
test_metrics_decoded = jax.tree_util.tree_map(
    # compute_metrics, 
    compute_metrics_fast, 
    test_data_target_decoded, 
    test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(
    test_metrics_decoded, 
    prefix="decoded_test_"
)

# test set
test_deg_dict = {
    k: v 
    for k, v in adata_train.uns['rank_genes_groups_cov_all'].items() 
    if k in test_data_target_decoded_predicted.keys()
}

# test set
test_deg_target_decoded_predicted = jax.tree_util.tree_map(
    get_mask, 
    test_data_target_decoded_predicted, 
    test_deg_dict
)
test_deg_target_decoded = jax.tree_util.tree_map(
    get_mask, 
    test_data_target_decoded, 
    test_deg_dict
)

deg_test_metrics = jax.tree_util.tree_map(
    compute_metrics, 
    # compute_metrics_fast, 
    test_deg_target_decoded, 
    test_deg_target_decoded_predicted
)
deg_mean_test_metrics = compute_mean_metrics(
    deg_test_metrics, 
    prefix="deg_test_"
)

In [12]:
mean_test_metrics_encoded

{'encoded_test_r_squared': -0.9467601449258866,
 'encoded_test_sinkhorn_div_1': 17.070092062796316,
 'encoded_test_sinkhorn_div_10': 13.196097466253466,
 'encoded_test_sinkhorn_div_100': 11.004474824474704,
 'encoded_test_e_distance': np.float64(21.113644804908095),
 'encoded_test_mmd': np.float32(0.055912003)}

In [13]:
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.9612018273722741,
 'decoded_test_e_distance': np.float64(29.58447520908832),
 'decoded_test_mmd_distance': np.float32(0.025115578)}

In [14]:
deg_mean_test_metrics

{'deg_test_r_squared': 0.807751226809717,
 'deg_test_sinkhorn_div_1': 23.67160471024052,
 'deg_test_sinkhorn_div_10': 12.807433466757498,
 'deg_test_sinkhorn_div_100': 11.447026160455518,
 'deg_test_e_distance': np.float64(22.63092522060512),
 'deg_test_mmd': np.float32(0.061670344)}

#### Evaluation on ood set

In [15]:
ood_metrics_encoded, mean_ood_metrics_encoded = {}, {}
ood_metrics_decoded, mean_ood_metrics_decoded = {}, {}
deg_ood_metrics, deg_mean_ood_metrics = {}, {}
ood_deg_dict = {}
ood_deg_target_decoded_predicted, ood_deg_target_decoded = {}, {}

for subgroup in tqdm(subgroups[::-1]):

    print(f"subgroup: {subgroup}")

    print("Computing ood_metrics_encoded")
    # ood set: evaluation in encoded (=pca) space
    ood_metrics_encoded[subgroup] = jax.tree_util.tree_map(
        compute_metrics, 
        # compute_metrics_fast, 
        ood_data_target_encoded[subgroup], 
        ood_data_target_encoded_predicted[subgroup]
    )
    mean_ood_metrics_encoded[subgroup] = compute_mean_metrics(
        ood_metrics_encoded[subgroup], 
        prefix="encoded_ood_",
    )

    print("Computing ood_metrics_decoded")
    # ood set: evaluation in decoded (=gene) space
    ood_metrics_decoded[subgroup] = jax.tree_util.tree_map(
        # compute_metrics, 
        compute_metrics_fast, 
        ood_data_target_decoded[subgroup], 
        ood_data_target_decoded_predicted[subgroup]
    )
    mean_ood_metrics_decoded[subgroup] = compute_mean_metrics(
        ood_metrics_decoded[subgroup], 
        prefix="decoded_ood_",
    )

    # ood set
    ood_deg_dict[subgroup] = {
        k: v
        for k, v in adata_train.uns['rank_genes_groups_cov_all'].items() 
        if k in ood_data_target_decoded_predicted[subgroup].keys()
    }

    print("Apply DEG mask")
    # ood set
    ood_deg_target_decoded_predicted[subgroup] = jax.tree_util.tree_map(
        get_mask, 
        ood_data_target_decoded_predicted[subgroup], 
        ood_deg_dict[subgroup]
    )
    
    ood_deg_target_decoded[subgroup] = jax.tree_util.tree_map(
        get_mask, 
        ood_data_target_decoded[subgroup], 
        ood_deg_dict[subgroup]
    )

    print("Compute metrics on DEG subsetted decoded")
    deg_ood_metrics[subgroup] = jax.tree_util.tree_map(
        compute_metrics, 
        # compute_metrics_fast, 
        ood_deg_target_decoded[subgroup], 
        ood_deg_target_decoded_predicted[subgroup]
    )
    deg_mean_ood_metrics[subgroup] = compute_mean_metrics(
        deg_ood_metrics[subgroup], 
        prefix="deg_ood_"
    )

  0%|          | 0/5 [00:00<?, ?it/s]

subgroup: double_seen_2
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded
subgroup: double_seen_1
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded
subgroup: double_seen_0
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded
subgroup: single
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded
subgroup: all
Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded


In [16]:
mean_ood_metrics_encoded

{'double_seen_2': {'encoded_ood_r_squared': -1.3647322495778402,
  'encoded_ood_sinkhorn_div_1': 28.734171422322593,
  'encoded_ood_sinkhorn_div_10': 23.849644088745116,
  'encoded_ood_sinkhorn_div_100': 19.890652465820313,
  'encoded_ood_e_distance': np.float64(38.124879154580334),
  'encoded_ood_mmd': np.float32(0.08488757)},
 'double_seen_1': {'encoded_ood_r_squared': -0.9417630569501356,
  'encoded_ood_sinkhorn_div_1': 33.04746570370414,
  'encoded_ood_sinkhorn_div_10': 28.161083178086713,
  'encoded_ood_sinkhorn_div_100': 23.962311701341108,
  'encoded_ood_e_distance': np.float64(45.932549951830666),
  'encoded_ood_mmd': np.float32(0.09454148)},
 'double_seen_0': {'encoded_ood_r_squared': -1.0625868042310078,
  'encoded_ood_sinkhorn_div_1': 31.665705521901447,
  'encoded_ood_sinkhorn_div_10': 26.255151748657227,
  'encoded_ood_sinkhorn_div_100': 21.613754590352375,
  'encoded_ood_e_distance': np.float64(41.34482861795814),
  'encoded_ood_mmd': np.float32(0.09608994)},
 'single': {

In [17]:
mean_ood_metrics_decoded

{'double_seen_2': {'decoded_ood_r_squared': 0.9349655826886495,
  'decoded_ood_e_distance': np.float64(47.73458487247292),
  'decoded_ood_mmd_distance': np.float32(0.036045324)},
 'double_seen_1': {'decoded_ood_r_squared': 0.9226068596948277,
  'decoded_ood_e_distance': np.float64(56.05413620752554),
  'decoded_ood_mmd_distance': np.float32(0.039416768)},
 'double_seen_0': {'decoded_ood_r_squared': 0.9290193617343903,
  'decoded_ood_e_distance': np.float64(51.601263640622),
  'decoded_ood_mmd_distance': np.float32(0.04251587)},
 'single': {'decoded_ood_r_squared': 0.9670263926188151,
  'decoded_ood_e_distance': np.float64(24.169279849000336),
  'decoded_ood_mmd_distance': np.float32(0.021176755)},
 'all': {'decoded_ood_r_squared': 0.9400034396447868,
  'decoded_ood_e_distance': np.float64(43.66084116326818),
  'decoded_ood_mmd_distance': np.float32(0.03315487)}}

In [18]:
deg_mean_ood_metrics

{'double_seen_2': {'deg_ood_r_squared': 0.6987430771191915,
  'deg_ood_sinkhorn_div_1': 33.79127197265625,
  'deg_ood_sinkhorn_div_10': 20.059823989868164,
  'deg_ood_sinkhorn_div_100': 17.492684427897135,
  'deg_ood_e_distance': np.float64(34.43584635916893),
  'deg_ood_mmd': np.float32(0.0786094)},
 'double_seen_1': {'deg_ood_r_squared': 0.7215764969587326,
  'deg_ood_sinkhorn_div_1': 30.022345504977487,
  'deg_ood_sinkhorn_div_10': 19.109158900651064,
  'deg_ood_sinkhorn_div_100': 17.06007580323653,
  'deg_ood_e_distance': np.float64(33.654820371900875),
  'deg_ood_mmd': np.float32(0.0894188)},
 'double_seen_0': {'deg_ood_r_squared': 0.6453550110260645,
  'deg_ood_sinkhorn_div_1': 28.701750442385674,
  'deg_ood_sinkhorn_div_10': 18.24525158603986,
  'deg_ood_sinkhorn_div_100': 16.163390015562374,
  'deg_ood_e_distance': np.float64(31.879554067707645),
  'deg_ood_mmd': np.float32(0.09031654)},
 'single': {'deg_ood_r_squared': 0.859203490946028,
  'deg_ood_sinkhorn_div_1': 20.65730370

In [19]:
collected_results = {
    # test
    "test_metrics_encoded": test_metrics_encoded,
    "mean_test_metrics_encoded": mean_test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded, 
    "mean_test_metrics_decoded": mean_test_metrics_decoded,
    "test_deg_dict": test_deg_dict,
    "test_deg_target_decoded_predicted": test_deg_target_decoded_predicted,
    "test_deg_target_decoded": test_deg_target_decoded,
    "deg_test_metrics": deg_test_metrics,
    "deg_mean_test_metrics": deg_mean_test_metrics,
    # ood
    "ood_metrics_encoded": ood_metrics_encoded,
    "mean_ood_metrics_encoded": mean_ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "mean_ood_metrics_decoded": mean_ood_metrics_decoded,
    "deg_ood_metrics": deg_ood_metrics,
    "deg_mean_ood_metrics": deg_mean_ood_metrics,
    "ood_deg_dict": ood_deg_dict,
    "ood_deg_target_decoded_predicted": ood_deg_target_decoded_predicted,
    "ood_deg_target_decoded": ood_deg_target_decoded,
}

In [20]:
collected_results["ood_deg_dict"]

{'double_seen_2': {'CEBPE+CNN1': array(['ENSG00000092067', 'ENSG00000130176', 'ENSG00000154479',
         'ENSG00000135047', 'ENSG00000167996', 'ENSG00000128683',
         'ENSG00000108106', 'ENSG00000149516', 'ENSG00000144061',
         'ENSG00000185650', 'ENSG00000225077', 'ENSG00000065978',
         'ENSG00000164611', 'ENSG00000143013', 'ENSG00000162772',
         'ENSG00000128228', 'ENSG00000134333', 'ENSG00000130656',
         'ENSG00000144591', 'ENSG00000239672', 'ENSG00000130513',
         'ENSG00000146278', 'ENSG00000204482', 'ENSG00000117632',
         'ENSG00000279669', 'ENSG00000116161', 'ENSG00000064886',
         'ENSG00000111057', 'ENSG00000197766', 'ENSG00000026025',
         'ENSG00000236824', 'ENSG00000078795', 'ENSG00000134107',
         'ENSG00000133636', 'ENSG00000206172', 'ENSG00000117450',
         'ENSG00000074800', 'ENSG00000115053', 'ENSG00000121769',
         'ENSG00000126067', 'ENSG00000245532', 'ENSG00000085514',
         'ENSG00000198736', 'ENSG00000214595'

In [21]:
OUT_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/identity"
pd.to_pickle(collected_results, os.path.join(OUT_DIR, f"norman_identity_split_{split}_collected_results.pkl"))