In [1]:
import os

# Must run before `jax` is imported (next cell). Per the JAX GPU memory docs, this
# disables XLA's up-front preallocation of most GPU memory, so memory is claimed
# on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [2]:
import functools

import jax
import numpy as np
import pandas as pd
import scanpy as sc

from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Index of the combosciplex split to evaluate. It selects the input .h5ad files
# and suffixes every output CSV name.
split: int = 3

In [4]:
# chemCPA predictions for this split. Later cells read obs["split"],
# obs["condition"] and obsm["CPA_pred"] from it.
# NOTE: scanpy warns on load that obs_names are not unique.
_preds_path = (
    "/lustre/groups/ml01/workspace/ot_perturbation/models/chemcpa/combosciplex/"
    f"adata_with_predictions_{split}.h5ad"
)
adata_preds = sc.read(_preds_path)

  utils.warn_names_duplicates("obs")


In [5]:
# Observed data for the chosen split: one file each for train, test and ood.
_combosciplex_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"
adata_train_path, adata_test_path, adata_ood_path = (
    f"{_combosciplex_dir}/adata_{subset}_{split}.h5ad" for subset in ("train", "test", "ood")
)

In [6]:
# Load the three observed AnnData objects (read in train, test, ood order).
adata_train, adata_test, adata_ood = (
    sc.read(path) for path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [7]:
# Build one PCA reference per split from its perturbed (non-control) cells only.
# Observed and predicted cells are projected onto these references in the next cell.
adata_ref_test = adata_test[adata_test.obs["condition"] != "control"].copy()
adata_ref_ood = adata_ood[adata_ood.obs["condition"] != "control"].copy()
for _reference in (adata_ref_test, adata_ref_ood):
    cfpp.centered_pca(_reference, n_comps=10)

In [8]:
def _predictions_for_split(split_name, adata_ref):
    """Return an independent copy of the chemCPA predictions for one split.

    The predicted expression (obsm["CPA_pred"]) is exposed as the "CPA_PRED"
    layer and projected onto ``adata_ref``'s PCA via ``cfpp.project_pca``.
    Slicing an AnnData yields a *view*, and writing ``.layers`` on a view
    silently re-initialises it as a copy with an ImplicitModificationWarning.
    Copying explicitly gives the same object without the warning.
    """
    adata_pred = adata_preds[adata_preds.obs["split"] == split_name].copy()
    adata_pred.layers["CPA_PRED"] = adata_pred.obsm["CPA_pred"]
    cfpp.project_pca(query_adata=adata_pred, ref_adata=adata_ref, layer="CPA_PRED")
    return adata_pred


def _to_dense(matrix):
    """Return ``matrix`` as a dense array (sparse inputs via ``.toarray()``)."""
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def _collect_targets(adata_true, adata_pred):
    """Group observed and predicted cells per condition, skipping "control".

    Returns four dicts keyed by condition, in ``adata_true``'s category order:
    observed obsm["X_pca"], observed dense X, predicted obsm["X_pca"], and
    predicted obsm["CPA_pred"]. Both objects are assumed to already carry
    obsm["X_pca"] from ``cfpp.project_pca``.
    """
    encoded, decoded, encoded_pred, decoded_pred = {}, {}, {}, {}
    for cond in adata_true.obs["condition"].cat.categories:
        if cond == "control":
            continue
        true_cells = adata_true[adata_true.obs["condition"] == cond]
        pred_cells = adata_pred[adata_pred.obs["condition"] == cond]
        encoded[cond] = true_cells.obsm["X_pca"]
        decoded[cond] = _to_dense(true_cells.X)
        encoded_pred[cond] = pred_cells.obsm["X_pca"]
        decoded_pred[cond] = pred_cells.obsm["CPA_pred"]
    return encoded, decoded, encoded_pred, decoded_pred


# Test split: project predictions and observed cells onto the test reference PCA.
adata_pred_test = _predictions_for_split("test", adata_ref_test)
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_ref_test)
(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = _collect_targets(adata_test, adata_pred_test)

# OOD split: same procedure against the OOD reference PCA.
adata_pred_ood = _predictions_for_split("ood", adata_ref_ood)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref_ood)
(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = _collect_targets(adata_ood, adata_pred_ood)


In [9]:
# Per-condition gene lists stored on the training data, kept only for the
# conditions that are actually evaluated in each split.
_deg_lists = adata_train.uns["rank_genes_groups_cov_all"]
test_deg_dict = {
    cond: genes for cond, genes in _deg_lists.items() if cond in test_data_target_decoded_predicted
}
ood_deg_dict = {
    cond: genes for cond, genes in _deg_lists.items() if cond in ood_data_target_decoded_predicted
}

In [10]:
def get_mask(x, y, var_names=None):
    """Keep only the columns (genes) of ``x`` whose name appears in ``y``.

    Parameters
    ----------
    x
        2-D array whose columns are labelled by ``var_names``.
    y
        Collection of gene names to keep (e.g. one condition's DEG list).
    var_names
        Column labels of ``x``. Defaults to ``adata_train.var_names``, the
        labels the original implementation was hard-wired to.

    Returns
    -------
    The selected columns of ``x``, in their original column order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Vectorised membership test; the former `gene in y` comprehension
    # rescanned `y` once per gene.
    keep = np.isin(np.asarray(var_names), np.asarray(y))
    return x[:, keep]


# Restrict observed and predicted expression to each condition's DEG columns.
# tree_map pairs the per-condition dicts key by key and calls get_mask(matrix, genes).
_restrict_to_degs = functools.partial(jax.tree_util.tree_map, get_mask)

ood_deg_target_decoded_predicted = _restrict_to_degs(ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = _restrict_to_degs(ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = _restrict_to_degs(test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = _restrict_to_degs(test_data_target_decoded, test_deg_dict)


In [11]:
# Per-condition metrics on the DEG-restricted expression, then summarised with
# compute_mean_metrics (keys prefixed by split).
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(
    compute_metrics,
    test_deg_target_decoded,
    test_deg_target_decoded_predicted,
)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


In [12]:
# Test split metrics. PCA space (obsm["X_pca"]) uses compute_metrics; the
# gene-space (decoded) comparison uses compute_metrics_fast.
# NOTE(review): the OOD decoded metrics use compute_metrics instead. The two
# summaries therefore have different keys (no sinkhorn_div_* here, "mmd_distance"
# vs "mmd"). Confirm this asymmetry is intended before comparing the CSVs.
test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(compute_metrics_fast, test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [13]:
# OOD split metrics. Both PCA-space and gene-space (decoded) comparisons use the
# full compute_metrics.
# NOTE(review): the test-split decoded metrics use compute_metrics_fast, so the
# decoded OOD and test summaries do not share the same metric keys.
ood_metrics_encoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [14]:
# compute_mean_metrics summary of the gene-space (decoded) OOD metrics from compute_metrics.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.5770284788949149,
 'decoded_ood_sinkhorn_div_1': 120.45567212785993,
 'decoded_ood_sinkhorn_div_10': 87.89769799368722,
 'decoded_ood_sinkhorn_div_100': 42.39277866908482,
 'decoded_ood_e_distance': 83.78755889324448,
 'decoded_ood_mmd': 0.4680623199258532}

In [15]:
# compute_mean_metrics summary of the gene-space (decoded) test metrics. These come from
# compute_metrics_fast, so there is no sinkhorn_div_* and the key is "mmd_distance", not "mmd".
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7594517519076666,
 'decoded_test_e_distance': 46.338353357117484,
 'decoded_test_mmd_distance': 0.4408578996857007}

In [16]:
# Duplicate display of mean_ood_metrics_decoded (shown earlier); candidate for removal.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.5770284788949149,
 'decoded_ood_sinkhorn_div_1': 120.45567212785993,
 'decoded_ood_sinkhorn_div_10': 87.89769799368722,
 'decoded_ood_sinkhorn_div_100': 42.39277866908482,
 'decoded_ood_e_distance': 83.78755889324448,
 'decoded_ood_mmd': 0.4680623199258532}

In [17]:
# compute_mean_metrics summary of the OOD metrics in PCA (obsm["X_pca"]) space.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': -13.597259368215289,
 'encoded_ood_sinkhorn_div_1': 39.86796079363142,
 'encoded_ood_sinkhorn_div_10': 35.57119410378592,
 'encoded_ood_sinkhorn_div_100': 34.06618915285383,
 'encoded_ood_e_distance': 67.76772728496633,
 'encoded_ood_mmd': 0.49574320231165203}

In [18]:
# compute_mean_metrics summary of the OOD metrics restricted to each condition's DEG genes.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.25832273278917584,
 'deg_ood_sinkhorn_div_1': 33.83009392874582,
 'deg_ood_sinkhorn_div_10': 24.708332061767578,
 'deg_ood_sinkhorn_div_100': 23.940852846418107,
 'deg_ood_e_distance': 47.72245416540198,
 'deg_ood_mmd': 0.5313356646469661}

In [19]:
# compute_mean_metrics summary of the test metrics restricted to each condition's DEG genes.
deg_mean_test_metrics

{'deg_test_r_squared': 0.604872907201449,
 'deg_test_sinkhorn_div_1': 22.147478818893433,
 'deg_test_sinkhorn_div_10': 13.676102101802826,
 'deg_test_sinkhorn_div_100': 12.632001141707102,
 'deg_test_e_distance': 25.116608559157182,
 'deg_test_mmd': 0.46812040234605473}

In [20]:
# Duplicate display of mean_test_metrics_decoded (shown earlier); candidate for removal.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7594517519076666,
 'decoded_test_e_distance': 46.338353357117484,
 'decoded_test_mmd_distance': 0.4408578996857007}

In [21]:
# Duplicate display of mean_ood_metrics_decoded (shown earlier); candidate for removal.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.5770284788949149,
 'decoded_ood_sinkhorn_div_1': 120.45567212785993,
 'decoded_ood_sinkhorn_div_10': 87.89769799368722,
 'decoded_ood_sinkhorn_div_100': 42.39277866908482,
 'decoded_ood_e_distance': 83.78755889324448,
 'decoded_ood_mmd': 0.4680623199258532}

In [22]:
# Duplicate display of deg_mean_ood_metrics (shown earlier); candidate for removal.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.25832273278917584,
 'deg_ood_sinkhorn_div_1': 33.83009392874582,
 'deg_ood_sinkhorn_div_10': 24.708332061767578,
 'deg_ood_sinkhorn_div_100': 23.940852846418107,
 'deg_ood_e_distance': 47.72245416540198,
 'deg_ood_mmd': 0.5313356646469661}

In [23]:
# Duplicate display of mean_test_metrics_decoded (shown earlier); candidate for removal.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7594517519076666,
 'decoded_test_e_distance': 46.338353357117484,
 'decoded_test_mmd_distance': 0.4408578996857007}

In [24]:
# compute_mean_metrics summary of the test metrics in PCA (obsm["X_pca"]) space.
mean_test_metrics_encoded

{'encoded_test_r_squared': -0.3432772200306198,
 'encoded_test_sinkhorn_div_1': 27.504093031088512,
 'encoded_test_sinkhorn_div_10': 22.98530109723409,
 'encoded_test_sinkhorn_div_100': 21.618591626485188,
 'encoded_test_e_distance': 42.94924890336099,
 'encoded_test_mmd': 0.3953709602355957}

In [25]:
# Per-condition OOD metrics in PCA space (condition -> {metric: value}).
# NOTE(review): the raw dict repr is long and gets truncated; a tabular view such as
# pd.DataFrame(ood_metrics_encoded).T would be easier to read.
ood_metrics_encoded

{'Panobinostat+Crizotinib': {'r_squared': -7.597105026245117,
  'sinkhorn_div_1': 74.4083251953125,
  'sinkhorn_div_10': 68.1750717163086,
  'sinkhorn_div_100': 65.7416763305664,
  'e_distance': 130.85884892569663,
  'mmd': 0.61414474},
 'Panobinostat+Curcumin': {'r_squared': -20.320356369018555,
  'sinkhorn_div_1': 56.02846145629883,
  'sinkhorn_div_10': 50.24510955810547,
  'sinkhorn_div_100': 48.27759552001953,
  'e_distance': 96.08875109637636,
  'mmd': 0.5602491},
 'Panobinostat+SRT1720': {'r_squared': -54.51448440551758,
  'sinkhorn_div_1': 48.59428787231445,
  'sinkhorn_div_10': 42.99931335449219,
  'sinkhorn_div_100': 41.0257568359375,
  'e_distance': 81.60193937575166,
  'mmd': 0.5518963},
 'Panobinostat+Sorafenib': {'r_squared': -14.233306884765625,
  'sinkhorn_div_1': 64.84175872802734,
  'sinkhorn_div_10': 58.6452751159668,
  'sinkhorn_div_100': 56.38493728637695,
  'e_distance': 112.22745213330732,
  'mmd': 0.5870464},
 'SRT2104+Alvespimycin': {'r_squared': 0.8866101503372

In [26]:
# Destination folder for this model's per-condition metric CSVs.
_results_root = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results"
output_dir = f"{_results_root}/chemcpa"

In [27]:
# Write one CSV per metric family, named "<split>_metrics_<family>_<split index>.csv".
# With pandas' default orient, each condition becomes a column and each metric a row.
os.makedirs(output_dir, exist_ok=True)
metric_tables = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "test_metrics_deg": deg_test_metrics,
    # Mirrors "test_metrics_deg"; this file was previously misnamed "ood_metrics_ood_<split>.csv".
    "ood_metrics_deg": deg_ood_metrics,
}
for stem, metrics in metric_tables.items():
    pd.DataFrame.from_dict(metrics).to_csv(os.path.join(output_dir, f"{stem}_{split}.csv"))
