In [1]:
import functools

import jax
import numpy as np
import scanpy as sc

from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast

In [2]:
adata_preds = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/cpa/combosciplex/adata_with_predictions_2.h5ad")

  utils.warn_names_duplicates("obs")


In [3]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_2.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_2.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_2.h5ad"

In [4]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)

In [30]:
adata_ref_test = adata_test.copy()
cfpp.centered_pca(adata_ref_test, n_comps=10)

In [5]:
adata_ref_ood = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_2_pca_for_validation.h5ad")

In [6]:
import cfp.preprocessing as cfpp

In [7]:
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [8]:
adata_preds.obs["split"].value_counts()


split
train    45967
test     12000
ood       3500
Name: count, dtype: int64

In [31]:
adata_pred_test = adata_preds[adata_preds.obs["split"]=="test"]
adata_pred_test.layers["CPA_PRED"] = adata_pred_test.obsm["CPA_pred"]
cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_ref_test, layer="CPA_PRED")
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_ref_test, layer="CPA_PRED")
test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}
for cond in adata_test.obs["condition"].cat.categories:
    if cond == "control":
        continue
    test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"] == cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"] == cond].X.toarray()
    test_data_target_decoded_predicted[cond] = adata_pred_test[adata_pred_test.obs["condition"] == cond].obsm["CPA_pred"]
    test_data_target_encoded_predicted[cond] = adata_pred_test[adata_pred_test.obs["condition"] == cond].obsm["X_pca"]

adata_pred_ood = adata_preds[adata_preds.obs["split"]=="ood"]
adata_pred_ood.layers["CPA_PRED"] = adata_pred_ood.obsm["CPA_pred"]
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_ref_ood, layer="CPA_PRED")
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if cond == "control":
        continue
    ood_data_target_encoded[cond] = adata_ref_ood[adata_ref_ood.obs["condition"] == cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].X.toarray()
    ood_data_target_decoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].obsm["CPA_pred"]
    ood_data_target_encoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].obsm["X_pca"]

  adata_pred_test.layers["CPA_PRED"] = adata_pred_test.obsm["CPA_pred"]
  adata_pred_ood.layers["CPA_PRED"] = adata_pred_ood.obsm["CPA_pred"]


In [18]:
test_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in test_data_target_decoded_predicted.keys()
}
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted.keys()
}

In [19]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]


ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)


In [20]:
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


In [21]:
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, test_data_target_decoded, test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [22]:
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [23]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.6862101980618068,
 'decoded_ood_sinkhorn_div_1': 110.23924255371094,
 'decoded_ood_sinkhorn_div_10': 77.27082334245954,
 'decoded_ood_sinkhorn_div_100': 30.684291294642858,
 'decoded_ood_e_distance': 60.5277461113519,
 'decoded_ood_mmd': 0.21484620443412236}

In [24]:
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.8049486900369326,
 'decoded_test_e_distance': 38.56786602168032,
 'decoded_test_mmd_distance': 0.19895603445669016}

In [16]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7391917749232417,
 'decoded_ood_sinkhorn_div_1': 103.79780883789063,
 'decoded_ood_sinkhorn_div_10': 72.30388488769532,
 'decoded_ood_sinkhorn_div_100': 27.833563232421874,
 'decoded_ood_e_distance': 54.769449029362136,
 'decoded_ood_mmd': 0.246560040331874}

In [25]:
mean_ood_metrics_encoded

{'encoded_ood_r_squared': -1.8743951235498701,
 'encoded_ood_sinkhorn_div_1': 33.54921545301165,
 'encoded_ood_sinkhorn_div_10': 30.328326089041575,
 'encoded_ood_sinkhorn_div_100': 29.202074595860072,
 'encoded_ood_e_distance': 58.14460522488314,
 'encoded_ood_mmd': 0.23800723467554366}

In [26]:
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.4604918360710144,
 'deg_ood_sinkhorn_div_1': 27.17738880429949,
 'deg_ood_sinkhorn_div_10': 18.563781329563685,
 'deg_ood_sinkhorn_div_100': 17.91834463391985,
 'deg_ood_e_distance': 35.70447261741083,
 'deg_ood_mmd': 0.28966505399772097}

In [27]:
deg_mean_test_metrics

{'deg_test_r_squared': 0.6556011984745661,
 'deg_test_sinkhorn_div_1': 20.04078009724617,
 'deg_test_sinkhorn_div_10': 11.906160990397135,
 'deg_test_sinkhorn_div_100': 10.913932899634043,
 'deg_test_e_distance': 21.686399635285508,
 'deg_test_mmd': 0.26243185127774876}

In [17]:
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.956132341509689,
 'decoded_test_sinkhorn_div_1': 82.31482153672438,
 'decoded_test_sinkhorn_div_10': 63.47894507188063,
 'decoded_test_sinkhorn_div_100': 6.181912348820613,
 'decoded_test_e_distance': 8.298093321851823,
 'decoded_test_mmd': 0.1823629892674941}

In [18]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7391917749232417,
 'decoded_ood_sinkhorn_div_1': 103.79780883789063,
 'decoded_ood_sinkhorn_div_10': 72.30388488769532,
 'decoded_ood_sinkhorn_div_100': 27.833563232421874,
 'decoded_ood_e_distance': 54.769449029362136,
 'decoded_ood_mmd': 0.246560040331874}

In [19]:
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.5159547481300902,
 'deg_ood_sinkhorn_div_1': 24.493817806243896,
 'deg_ood_sinkhorn_div_10': 16.49903392791748,
 'deg_ood_sinkhorn_div_100': 15.873676109313966,
 'deg_ood_e_distance': 31.61525838731057,
 'deg_ood_mmd': 0.34742090103299406}

In [20]:
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.956132341509689,
 'decoded_test_sinkhorn_div_1': 82.31482153672438,
 'decoded_test_sinkhorn_div_10': 63.47894507188063,
 'decoded_test_sinkhorn_div_100': 6.181912348820613,
 'decoded_test_e_distance': 8.298093321851823,
 'decoded_test_mmd': 0.1823629892674941}

In [21]:
mean_test_metrics_encoded

{'encoded_test_r_squared': 0.7785426383460113,
 'encoded_test_sinkhorn_div_1': 53.68342032799354,
 'encoded_test_sinkhorn_div_10': 34.97506611163799,
 'encoded_test_sinkhorn_div_100': 4.564601017878606,
 'encoded_test_e_distance': 7.258525984396837,
 'encoded_test_mmd': 0.16077130219967567}

In [22]:
ood_metrics_encoded

{'Cediranib+PCI-34051': {'r_squared': 0.9074379652405279,
  'sinkhorn_div_1': 44.772613525390625,
  'sinkhorn_div_10': 14.927448272705078,
  'sinkhorn_div_100': 1.5101356506347656,
  'e_distance': 2.6157738667272463,
  'mmd': 0.17343959115221655},
 'Givinostat+SRT1720': {'r_squared': 0.8628172793421973,
  'sinkhorn_div_1': 46.553855895996094,
  'sinkhorn_div_10': 16.30636978149414,
  'sinkhorn_div_100': 1.3358192443847656,
  'e_distance': 2.271262392813491,
  'mmd': 0.16894391608478376},
 'Panobinostat+Crizotinib': {'r_squared': -0.9476498324506673,
  'sinkhorn_div_1': 115.91564178466797,
  'sinkhorn_div_10': 83.86618041992188,
  'sinkhorn_div_100': 68.12455749511719,
  'e_distance': 135.32155961493643,
  'mmd': 0.32815091943184865},
 'Panobinostat+PCI-34051': {'r_squared': -1.1604597817810318,
  'sinkhorn_div_1': 105.5364761352539,
  'sinkhorn_div_10': 73.5385513305664,
  'sinkhorn_div_100': 58.305912017822266,
  'e_distance': 115.77804275511745,
  'mmd': 0.3094106159503713},
 'SRT210

In [23]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/cpa"

In [25]:
import os
import pandas as pd

pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, "ood_metrics_encoded_1.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, "ood_metrics_decoded_1.csv"))
pd.DataFrame.from_dict(test_metrics_encoded).to_csv(os.path.join(output_dir, "test_metrics_encoded_1.csv"))
pd.DataFrame.from_dict(test_metrics_decoded).to_csv(os.path.join(output_dir, "test_metrics_decoded_1.csv"))
pd.DataFrame.from_dict(deg_test_metrics).to_csv(os.path.join(output_dir, "test_metrics_deg_1.csv"))
pd.DataFrame.from_dict(deg_ood_metrics).to_csv(os.path.join(output_dir, "ood_metrics_ood_1.csv"))
