In [1]:
import os
# Set before JAX is imported (imports happen in the next cell) so that XLA does not
# preallocate a large fraction of GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

In [2]:
import functools

import jax
import numpy as np
import scanpy as sc

from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Index of the data split to evaluate; it is embedded in every input .h5ad path and output CSV name.
split = 3

In [4]:
# CPA predictions for this split: predicted expression lives in .obsm["CPA_pred"], and the
# obs column "split" marks train / test / ood cells.
# NOTE: obs names in this file are not unique (scanpy warns on load).
adata_preds = sc.read(f"/lustre/groups/ml01/workspace/ot_perturbation/models/cpa/combosciplex/adata_with_predictions_{split}.h5ad")

  utils.warn_names_duplicates("obs")


In [5]:
# Ground-truth AnnData files for the train / test / OOD partitions of this split.
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_{split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_{split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_{split}.h5ad"

In [6]:
# Load the three ground-truth partitions (read in train -> test -> ood order).
adata_train, adata_test, adata_ood = (
    sc.read(h5ad_path)
    for h5ad_path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [7]:
# Reference PCA spaces (10 components) are fitted on the perturbed, non-control cells of each
# evaluation partition; predictions and ground truth are projected onto them later.
adata_ref_test, adata_ref_ood = (
    adata_split[adata_split.obs["condition"] != "control"].copy()
    for adata_split in (adata_test, adata_ood)
)
cfpp.centered_pca(adata_ref_test, n_comps=10)
cfpp.centered_pca(adata_ref_ood, n_comps=10)

In [8]:
def _collect_targets(adata_true, adata_pred, pred_key="CPA_pred"):
    """Split ground-truth and predicted cells by condition into PCA-space and gene-space arrays.

    Both inputs must already carry ``obsm["X_pca"]`` (from ``cfpp.project_pca``).
    Conditions are taken from the ground truth; "control" is skipped.

    Returns
    -------
    (encoded_true, decoded_true, encoded_pred, decoded_pred)
        Four dicts keyed by condition name. "Encoded" holds ``obsm["X_pca"]``.
        "Decoded" holds dense ``X`` for the ground truth and ``obsm[pred_key]`` for the predictions.
    """
    encoded_true, decoded_true, encoded_pred, decoded_pred = {}, {}, {}, {}
    for cond in adata_true.obs["condition"].cat.categories:
        if cond == "control":
            continue
        # Build each condition's subset once instead of re-masking for every field.
        true_sub = adata_true[adata_true.obs["condition"] == cond]
        pred_sub = adata_pred[adata_pred.obs["condition"] == cond]
        encoded_true[cond] = true_sub.obsm["X_pca"]
        decoded_true[cond] = true_sub.X.toarray()  # assumes X is sparse, as in the original loop
        decoded_pred[cond] = pred_sub.obsm[pred_key]
        encoded_pred[cond] = pred_sub.obsm["X_pca"]
    return encoded_true, decoded_true, encoded_pred, decoded_pred


# Take explicit copies: assigning .layers (and project_pca writing .obsm) on an AnnData view
# raises ImplicitModificationWarning and materialises a copy implicitly anyway.
adata_pred_test = adata_preds[adata_preds.obs["split"] == "test"].copy()
adata_pred_test.layers["CPA_PRED"] = adata_pred_test.obsm["CPA_pred"]
cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_ref_test, layer="CPA_PRED")
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_ref_test)
(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = _collect_targets(adata_test, adata_pred_test)

adata_pred_ood = adata_preds[adata_preds.obs["split"] == "ood"].copy()
adata_pred_ood.layers["CPA_PRED"] = adata_pred_ood.obsm["CPA_pred"]
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_ref_ood, layer="CPA_PRED")
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref_ood)
(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = _collect_targets(adata_ood, adata_pred_ood)

  adata_pred_test.layers["CPA_PRED"] = adata_pred_test.obsm["CPA_pred"]
  adata_pred_ood.layers["CPA_PRED"] = adata_pred_ood.obsm["CPA_pred"]


In [9]:
# Conditions present in both the OOD ground truth and the OOD predictions.
set(adata_ood.obs["condition"].cat.categories) & set(adata_pred_ood.obs["condition"].cat.categories)

{'Panobinostat+Crizotinib',
 'Panobinostat+Curcumin',
 'Panobinostat+SRT1720',
 'Panobinostat+Sorafenib',
 'SRT2104+Alvespimycin',
 'control+Alvespimycin',
 'control+Dacinostat'}

In [10]:
# Sanity check: number of cells per condition in the training split of the predictions.
adata_preds[adata_preds.obs["split"]=="train"].obs["condition"].value_counts()

condition
Dacinostat+PCI-34051         3198
SRT3025+Cediranib            2916
Givinostat+Cediranib         2683
control+SRT2104              2656
Givinostat+Curcumin          2636
Givinostat+Sorafenib         2634
Givinostat+Carmofur          2592
Givinostat+Crizotinib        2562
Givinostat+Dasatinib         2321
Givinostat+SRT2104           2253
control+Dasatinib            2243
Givinostat+SRT1720           2160
Cediranib+PCI-34051          2061
Panobinostat+SRT2104         1871
Panobinostat+Dasatinib       1855
Dacinostat+Danusertib        1839
Panobinostat+SRT3025         1789
Panobinostat+PCI-34051       1714
control+Givinostat           1582
control+Panobinostat         1478
Givinostat+Tanespimycin      1210
Dacinostat+Dasatinib         1131
control                       951
Panobinostat+Alvespimycin     896
Alvespimycin+Pirarubicin      376
Name: count, dtype: int64

In [11]:
# Number of predicted cells in the OOD split.
int((adata_preds.obs["split"] == "ood").sum())

3500

In [12]:
# Number of predicted cells in the test split.
int((adata_preds.obs["split"] == "test").sum())

12000

In [13]:
# Inspect the layout (obs / uns / obsm keys) of the training-split predictions.
adata_preds[adata_preds.obs["split"]=="train"]


View of AnnData object with n_obs × n_vars = 49607 × 2000
    obs: 'condition', 'cell_type', 'condition_ID', 'log_dose', 'smiles_rdkit', 'split', 'split_1ct_MEC', 'CPA_cat', 'CPA_CHEMBL504', '_scvi_condition_ID', '_scvi_cell_type', '_scvi_CPA_cat'
    uns: '_scvi_manager_uuid', '_scvi_uuid'
    obsm: 'CPA_pred', 'perts', 'perts_doses'

In [14]:
# Conditions in the OOD ground truth (this list includes "control").
adata_ood.obs["condition"].cat.categories

Index(['Panobinostat+Crizotinib', 'Panobinostat+Curcumin',
       'Panobinostat+SRT1720', 'Panobinostat+Sorafenib',
       'SRT2104+Alvespimycin', 'control', 'control+Alvespimycin',
       'control+Dacinostat'],
      dtype='object')

In [15]:
# Conditions in the OOD predictions; "control" does not appear here.
adata_pred_ood.obs["condition"].cat.categories

Index(['Panobinostat+Crizotinib', 'Panobinostat+Curcumin',
       'Panobinostat+SRT1720', 'Panobinostat+Sorafenib',
       'SRT2104+Alvespimycin', 'control+Alvespimycin', 'control+Dacinostat'],
      dtype='object')

In [16]:
# Restrict the training-set DEG lists to the conditions evaluated in each split
# (test first, then OOD), preserving the original iteration order of the DEG mapping.
test_deg_dict, ood_deg_dict = (
    {
        cond: degs
        for cond, degs in adata_train.uns["rank_genes_groups_cov_all"].items()
        if cond in evaluated_conditions
    }
    for evaluated_conditions in (
        test_data_target_decoded_predicted,
        ood_data_target_decoded_predicted,
    )
)

In [17]:
def get_mask(x, y, var_names=None):
    """Keep only the columns of ``x`` whose gene name appears in ``y``.

    Parameters
    ----------
    x : array-like, 2-D
        Cells x genes matrix whose columns are labelled by ``var_names``, in that order.
    y : iterable of str
        Gene names to keep (e.g. one condition's DEG list). Order does not matter.
    var_names : sequence of str, optional
        Column labels of ``x``. Defaults to the global ``adata_train.var_names``.

    Returns
    -------
    The column subset of ``x``, with columns kept in ``var_names`` order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Vectorised membership test; list(y) makes sets and other iterables work with np.isin.
    keep = np.isin(np.asarray(var_names), list(y))
    return x[:, keep]


# Reduce each condition's gene-space matrices (predicted first, then observed) to its DEG columns.
ood_deg_target_decoded_predicted, ood_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, decoded_by_cond, ood_deg_dict)
    for decoded_by_cond in (ood_data_target_decoded_predicted, ood_data_target_decoded)
)

test_deg_target_decoded_predicted, test_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, decoded_by_cond, test_deg_dict)
    for decoded_by_cond in (test_data_target_decoded_predicted, test_data_target_decoded)
)


In [20]:
# Per-condition metrics on DEG columns only, then aggregated with a split-specific prefix.
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)
deg_mean_ood_metrics = compute_mean_metrics(
    deg_ood_metrics, prefix="deg_ood_"
)

deg_test_metrics = jax.tree_util.tree_map(
    compute_metrics,
    test_deg_target_decoded,
    test_deg_target_decoded_predicted,
)
deg_mean_test_metrics = compute_mean_metrics(
    deg_test_metrics, prefix="deg_test_"
)


In [21]:
# Test split, PCA (encoded) space: full metric suite per condition, then aggregated.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    test_data_target_encoded,
    test_data_target_encoded_predicted,
)
mean_test_metrics_encoded = compute_mean_metrics(
    test_metrics_encoded, prefix="encoded_test_"
)

# Test split, gene (decoded) space.
# NOTE(review): this uses compute_metrics_fast, while the OOD decoded metrics use compute_metrics.
# The resulting keys differ ("mmd_distance" vs "mmd", no sinkhorn_div_*), so test and OOD
# decoded tables are not directly comparable — confirm this is intended.
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    test_data_target_decoded,
    test_data_target_decoded_predicted,
)
mean_test_metrics_decoded = compute_mean_metrics(
    test_metrics_decoded, prefix="decoded_test_"
)

In [22]:
# OOD split, PCA (encoded) space: per-condition metrics, then aggregated.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded, prefix="encoded_ood_"
)

# OOD split, gene (decoded) space.
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded, prefix="decoded_ood_"
)

In [23]:
# Gene-space (decoded) OOD metrics, aggregated over conditions by compute_mean_metrics.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.6416063053267342,
 'decoded_ood_sinkhorn_div_1': 113.5741446358817,
 'decoded_ood_sinkhorn_div_10': 81.50048991612026,
 'decoded_ood_sinkhorn_div_100': 36.08098384312221,
 'decoded_ood_e_distance': 71.17611389136802,
 'decoded_ood_mmd': 0.29216617132936207}

In [24]:
# Gene-space (decoded) test metrics; computed with compute_metrics_fast, hence fewer keys than the OOD version.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7978128120303154,
 'decoded_test_e_distance': 39.116195608926574,
 'decoded_test_mmd_distance': 0.2726070123414199}

In [25]:
# NOTE(review): repeats the earlier display of mean_ood_metrics_decoded (In [23]); candidate for removal.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.6416063053267342,
 'decoded_ood_sinkhorn_div_1': 113.5741446358817,
 'decoded_ood_sinkhorn_div_10': 81.50048991612026,
 'decoded_ood_sinkhorn_div_100': 36.08098384312221,
 'decoded_ood_e_distance': 71.17611389136802,
 'decoded_ood_mmd': 0.29216617132936207}

In [26]:
# OOD metrics in the 10-component PCA space fitted on non-control OOD cells, aggregated over conditions.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': -11.87179936681475,
 'encoded_ood_sinkhorn_div_1': 34.78901651927403,
 'encoded_ood_sinkhorn_div_10': 30.948229721614293,
 'encoded_ood_sinkhorn_div_100': 29.521682330540248,
 'encoded_ood_e_distance': 58.69375515452333,
 'encoded_ood_mmd': 0.32655507964747293}

In [27]:
# OOD metrics restricted to each condition's training-set DEGs, aggregated over conditions.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.3396240217345102,
 'deg_ood_sinkhorn_div_1': 31.017115456717356,
 'deg_ood_sinkhorn_div_10': 22.082514354160853,
 'deg_ood_sinkhorn_div_100': 21.333051409040177,
 'deg_ood_e_distance': 42.51000233426526,
 'deg_ood_mmd': 0.3858695243086134}

In [28]:
# Test metrics restricted to each condition's training-set DEGs, aggregated over conditions.
deg_mean_test_metrics

{'deg_test_r_squared': 0.6550388683875402,
 'deg_test_sinkhorn_div_1': 20.331246356169384,
 'deg_test_sinkhorn_div_10': 12.083612958590189,
 'deg_test_sinkhorn_div_100': 11.0665452281634,
 'deg_test_e_distance': 21.99070169244564,
 'deg_test_mmd': 0.32991578864554566}

In [29]:
# NOTE(review): repeats the earlier display of mean_test_metrics_decoded (In [24]); candidate for removal.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7978128120303154,
 'decoded_test_e_distance': 39.116195608926574,
 'decoded_test_mmd_distance': 0.2726070123414199}

In [30]:
# NOTE(review): repeats the earlier display of mean_ood_metrics_decoded (In [23]); candidate for removal.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.6416063053267342,
 'decoded_ood_sinkhorn_div_1': 113.5741446358817,
 'decoded_ood_sinkhorn_div_10': 81.50048991612026,
 'decoded_ood_sinkhorn_div_100': 36.08098384312221,
 'decoded_ood_e_distance': 71.17611389136802,
 'decoded_ood_mmd': 0.29216617132936207}

In [31]:
# NOTE(review): repeats the earlier display of deg_mean_ood_metrics (In [27]); candidate for removal.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.3396240217345102,
 'deg_ood_sinkhorn_div_1': 31.017115456717356,
 'deg_ood_sinkhorn_div_10': 22.082514354160853,
 'deg_ood_sinkhorn_div_100': 21.333051409040177,
 'deg_ood_e_distance': 42.51000233426526,
 'deg_ood_mmd': 0.3858695243086134}

In [32]:
# NOTE(review): repeats the earlier display of mean_test_metrics_decoded (In [24]); candidate for removal.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7978128120303154,
 'decoded_test_e_distance': 39.116195608926574,
 'decoded_test_mmd_distance': 0.2726070123414199}

In [33]:
# Test metrics in the 10-component PCA space fitted on non-control test cells, aggregated over conditions.
mean_test_metrics_encoded

{'encoded_test_r_squared': -0.06065262778467836,
 'encoded_test_sinkhorn_div_1': 23.0813467502594,
 'encoded_test_sinkhorn_div_10': 19.157341996828716,
 'encoded_test_sinkhorn_div_100': 17.929346044858296,
 'encoded_test_e_distance': 35.597068159694494,
 'encoded_test_mmd': 0.2378191373621424}

In [34]:
# Per-condition (non-aggregated) OOD metrics in PCA space.
ood_metrics_encoded

{'Panobinostat+Crizotinib': {'r_squared': -6.651731014251709,
  'sinkhorn_div_1': 66.77953338623047,
  'sinkhorn_div_10': 60.942237854003906,
  'sinkhorn_div_100': 58.54538345336914,
  'e_distance': 116.46903808980781,
  'mmd': 0.44453254},
 'Panobinostat+Curcumin': {'r_squared': -18.521371841430664,
  'sinkhorn_div_1': 51.51376724243164,
  'sinkhorn_div_10': 46.152496337890625,
  'sinkhorn_div_100': 44.22247314453125,
  'e_distance': 87.98091181681215,
  'mmd': 0.407159},
 'Panobinostat+SRT1720': {'r_squared': -47.89146423339844,
  'sinkhorn_div_1': 43.186866760253906,
  'sinkhorn_div_10': 38.07693862915039,
  'sinkhorn_div_100': 36.15467834472656,
  'e_distance': 71.86661872701488,
  'mmd': 0.37971333},
 'Panobinostat+Sorafenib': {'r_squared': -11.814579963684082,
  'sinkhorn_div_1': 55.43647766113281,
  'sinkhorn_div_10': 49.70075988769531,
  'sinkhorn_div_100': 47.47370910644531,
  'e_distance': 94.408157092309,
  'mmd': 0.40299368},
 'SRT2104+Alvespimycin': {'r_squared': 0.9526504

In [35]:
# Destination directory for the per-condition metric CSVs exported by this notebook.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/cpa"

In [36]:
import os
import pandas as pd

# Per-condition metric tables to export. pd.DataFrame.from_dict puts the outer keys
# (conditions) in columns and the metric names in rows.
metric_tables = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "test_metrics_deg": deg_test_metrics,
    # Was written as "ood_metrics_ood_{split}.csv", which broke the naming parallel with
    # "test_metrics_deg_{split}.csv".
    "ood_metrics_deg": deg_ood_metrics,
}

# Create the results directory on a fresh run instead of failing inside to_csv.
os.makedirs(output_dir, exist_ok=True)
for table_name, per_condition_metrics in metric_tables.items():
    pd.DataFrame.from_dict(per_condition_metrics).to_csv(
        os.path.join(output_dir, f"{table_name}_{split}.csv")
    )
