In [1]:
import functools

import jax
import numpy as np
import scanpy as sc

from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp

In [2]:
# Index of the data split to evaluate; interpolated into every input and output file name below.
split = 2

In [3]:
# Model outputs (files tagged `attention_seed`) for the OOD and test sets of this split;
# the predicted expression is read later from the "X_recon_pred" layer.
adata_pred_ood, adata_pred_test = (
    sc.read(
        "/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/combosciplex/"
        f"adata_{subset}_with_predictions_{split}_attention_seed.h5ad"
    )
    for subset in ("ood", "test")
)

In [4]:
# Ground-truth AnnData files for this split, one per subset.
adata_train_path, adata_test_path, adata_ood_path = (
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_{subset}_{split}.h5ad"
    for subset in ("train", "test", "ood")
)

In [5]:
# Load train / test / OOD ground truth (read in that order).
adata_train, adata_test, adata_ood = map(
    sc.read, (adata_train_path, adata_test_path, adata_ood_path)
)

In [6]:
# Reference PCA per evaluation set, fitted on the non-control cells only.
# These objects serve as `ref_adata` for the PCA projections below.
# NOTE(review): exact centering behaviour lives in cfpp.centered_pca — confirm there.
adata_ref_test, adata_ref_ood = (
    adata[adata.obs["condition"] != "control"].copy()
    for adata in (adata_test, adata_ood)
)
for _ref in (adata_ref_test, adata_ref_ood):
    cfpp.centered_pca(_ref, n_comps=10)

In [7]:
def collect_condition_arrays(adata_true, adata_pred, adata_ref, pred_layer="X_recon_pred", control="control"):
    """Project truth and predictions onto `adata_ref`'s PCA and split them per condition.

    Calls ``cfpp.project_pca`` on both `adata_pred` (using `pred_layer`) and
    `adata_true`, then reads ``.obsm["X_pca"]`` from each. This presumably means
    project_pca writes that key in place (TODO confirm in cfp.preprocessing).

    Parameters
    ----------
    adata_true
        Ground-truth AnnData with a categorical ``obs["condition"]`` column.
    adata_pred
        AnnData holding predictions in ``layers[pred_layer]``, with the same
        ``obs["condition"]`` labels.
    adata_ref
        AnnData carrying the reference PCA to project onto.
    pred_layer
        Layer of `adata_pred` holding the predicted expression.
    control
        Condition label to exclude from evaluation.

    Returns
    -------
    (encoded, decoded, encoded_predicted, decoded_predicted)
        Four dicts keyed by condition name. The encoded dicts hold PCA
        coordinates; the decoded dicts hold dense expression matrices.
    """
    cfpp.project_pca(query_adata=adata_pred, ref_adata=adata_ref, layer=pred_layer)
    cfpp.project_pca(query_adata=adata_true, ref_adata=adata_ref)

    encoded, decoded, encoded_predicted, decoded_predicted = {}, {}, {}, {}
    true_conditions = adata_true.obs["condition"]
    pred_conditions = adata_pred.obs["condition"]
    for cond in true_conditions.cat.categories:
        # Compute each boolean mask once instead of once per lookup.
        true_mask = (true_conditions == cond).to_numpy()
        # Skip the control and any unused categorical level: an empty selection
        # would hand empty arrays to the metric functions.
        if cond == control or not true_mask.any():
            continue
        pred_mask = (pred_conditions == cond).to_numpy()
        true_subset = adata_true[true_mask]
        pred_subset = adata_pred[pred_mask]

        encoded[cond] = true_subset.obsm["X_pca"]
        expression = true_subset.X
        # .X may be sparse (needs densifying) or already a dense array.
        decoded[cond] = expression.toarray() if hasattr(expression, "toarray") else np.asarray(expression)
        decoded_predicted[cond] = pred_subset.layers[pred_layer]
        encoded_predicted[cond] = pred_subset.obsm["X_pca"]
    return encoded, decoded, encoded_predicted, decoded_predicted


(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = collect_condition_arrays(adata_test, adata_pred_test, adata_ref_test)

(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = collect_condition_arrays(adata_ood, adata_pred_ood, adata_ref_ood)

In [8]:
# DEGs per condition (from the training AnnData), restricted to the conditions
# evaluated in each set. The next cell pairs these with the expression dicts via
# jax.tree_util.tree_map, which requires matching dict structures.
test_deg_dict = {
    cond: genes
    for cond, genes in adata_train.uns["rank_genes_groups_cov_all"].items()
    if cond in test_data_target_decoded_predicted
}
ood_deg_dict = {
    cond: genes
    for cond, genes in adata_train.uns["rank_genes_groups_cov_all"].items()
    if cond in ood_data_target_decoded_predicted
}

# Fail early with a readable message. A condition with no DEG entry would
# otherwise show up as an opaque tree-structure mismatch inside tree_map.
missing_test_degs = set(test_data_target_decoded_predicted) - set(test_deg_dict)
missing_ood_degs = set(ood_data_target_decoded_predicted) - set(ood_deg_dict)
assert not missing_test_degs, f"no DEGs for test conditions: {sorted(missing_test_degs)}"
assert not missing_ood_degs, f"no DEGs for OOD conditions: {sorted(missing_ood_degs)}"

In [9]:
def get_mask(x, y, var_names=None):
    """Keep only the columns of `x` whose gene name appears in `y`.

    Parameters
    ----------
    x
        Array indexed along its second axis (``x[:, ...]``). Its column order is
        assumed to match `var_names` (TODO confirm for the prediction layers).
    y
        Iterable of gene names to keep, e.g. one condition's DEG list.
    var_names
        Names labelling the columns of `x`. Defaults to
        ``adata_train.var_names``, so two-argument calls such as
        ``jax.tree_util.tree_map(get_mask, arrays, degs)`` are unchanged.

    Returns
    -------
    The column subset of `x`, in `var_names` order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Hash lookups instead of a linear `in` scan of `y` for every gene.
    genes_to_keep = set(y)
    column_mask = np.fromiter(
        (gene in genes_to_keep for gene in var_names), dtype=bool, count=len(var_names)
    )
    return x[:, column_mask]


# Restrict every gene-space matrix to its condition's DEGs. tree_map pairs each
# array with the DEG entry stored under the same condition key.
ood_deg_target_decoded_predicted, ood_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, arrays, ood_deg_dict)
    for arrays in (ood_data_target_decoded_predicted, ood_data_target_decoded)
)
test_deg_target_decoded_predicted, test_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, arrays, test_deg_dict)
    for arrays in (test_data_target_decoded_predicted, test_data_target_decoded)
)


In [10]:
# Per-condition metrics on the DEG-restricted expression, then their means.
_tree_compute_metrics = functools.partial(jax.tree_util.tree_map, compute_metrics)

deg_ood_metrics = _tree_compute_metrics(ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = _tree_compute_metrics(test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


In [11]:
# Test conditions, PCA (encoded) space: full metric suite per condition.
test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

# Test conditions, decoded (full .X) space: compute_metrics_fast reports a
# reduced metric set (the captured output shows r_squared, e_distance and
# mmd_distance only).
test_metrics_decoded = jax.tree_util.tree_map(compute_metrics_fast, test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [12]:
# OOD conditions, PCA (encoded) space: full metric suite per condition.
ood_metrics_encoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

# OOD conditions, decoded (full .X) space.
# NOTE(review): this uses the full `compute_metrics`, whereas the decoded test
# metrics use `compute_metrics_fast`. The two decoded tables therefore carry
# different metric keys — confirm this is intended.
ood_metrics_decoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [13]:
# Mean decoded (full .X space) metrics over the OOD conditions.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9468989274993805,
 'decoded_ood_sinkhorn_div_1': 123.31844983782086,
 'decoded_ood_sinkhorn_div_10': 75.3612529209682,
 'decoded_ood_sinkhorn_div_100': 5.561011178152902,
 'decoded_ood_e_distance': 10.224704456031729,
 'decoded_ood_mmd': 0.014579998728420054}

In [14]:
# Mean PCA-space (encoded) metrics over the OOD conditions.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.5543742467113407,
 'encoded_ood_sinkhorn_div_1': 7.4788361958095,
 'encoded_ood_sinkhorn_div_10': 5.199157442365374,
 'encoded_ood_sinkhorn_div_100': 4.443621907915388,
 'encoded_ood_e_distance': 8.714489129419137,
 'encoded_ood_mmd': 0.048912342105593}

In [15]:
# Mean DEG-restricted metrics over the OOD conditions.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.9326301661203235,
 'deg_ood_sinkhorn_div_1': 15.514337403433663,
 'deg_ood_sinkhorn_div_10': 2.914127894810268,
 'deg_ood_sinkhorn_div_100': 2.2765846252441406,
 'deg_ood_e_distance': 4.435311405777582,
 'deg_ood_mmd': 0.018948799997035946}

In [16]:
# Mean DEG-restricted metrics over the test conditions.
deg_mean_test_metrics

{'deg_test_r_squared': 0.9843531548096424,
 'deg_test_sinkhorn_div_1': 12.556443870067596,
 'deg_test_sinkhorn_div_10': 1.4122504393259685,
 'deg_test_sinkhorn_div_100': 0.5204794009526571,
 'deg_test_e_distance': 0.9291850731203296,
 'deg_test_mmd': 0.012753711896948516}

In [17]:
# Mean decoded-space metrics over the test conditions (from compute_metrics_fast, hence the reduced key set).
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.9828076732206649,
 'decoded_test_e_distance': 3.2462426788455443,
 'decoded_test_mmd_distance': 0.013929410371929407}

In [18]:
# NOTE(review): re-displays the same value already shown in In [13]; consider removing this cell.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9468989274993805,
 'decoded_ood_sinkhorn_div_1': 123.31844983782086,
 'decoded_ood_sinkhorn_div_10': 75.3612529209682,
 'decoded_ood_sinkhorn_div_100': 5.561011178152902,
 'decoded_ood_e_distance': 10.224704456031729,
 'decoded_ood_mmd': 0.014579998728420054}

In [19]:
# NOTE(review): re-displays the same value already shown in In [15]; consider removing this cell.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.9326301661203235,
 'deg_ood_sinkhorn_div_1': 15.514337403433663,
 'deg_ood_sinkhorn_div_10': 2.914127894810268,
 'deg_ood_sinkhorn_div_100': 2.2765846252441406,
 'deg_ood_e_distance': 4.435311405777582,
 'deg_ood_mmd': 0.018948799997035946}

In [20]:
# NOTE(review): re-displays the same value already shown in In [17]; consider removing this cell.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.9828076732206649,
 'decoded_test_e_distance': 3.2462426788455443,
 'decoded_test_mmd_distance': 0.013929410371929407}

In [21]:
# Mean PCA-space (encoded) metrics over the test conditions.
mean_test_metrics_encoded

{'encoded_test_r_squared': 0.9585716433517278,
 'encoded_test_sinkhorn_div_1': 4.39786159992218,
 'encoded_test_sinkhorn_div_10': 1.3866101503372192,
 'encoded_test_sinkhorn_div_100': 0.6818737188975016,
 'encoded_test_e_distance': 1.218643865758949,
 'encoded_test_mmd': 0.018176127050537616}

In [22]:
# Per-condition (unaveraged) PCA-space metrics for the OOD conditions.
ood_metrics_encoded

{'Dacinostat+Dasatinib': {'r_squared': -0.02766840895330369,
  'sinkhorn_div_1': 7.21816349029541,
  'sinkhorn_div_10': 4.909151077270508,
  'sinkhorn_div_100': 4.3495683670043945,
  'e_distance': 8.591592556683414,
  'mmd': 0.054786656},
 'Dacinostat+PCI-34051': {'r_squared': 0.05521346272546557,
  'sinkhorn_div_1': 6.974331855773926,
  'sinkhorn_div_10': 4.813686370849609,
  'sinkhorn_div_100': 4.272054672241211,
  'e_distance': 8.432283673574675,
  'mmd': 0.054813523},
 'Givinostat+Cediranib': {'r_squared': 0.940542117498528,
  'sinkhorn_div_1': 2.198640823364258,
  'sinkhorn_div_10': 1.0289373397827148,
  'sinkhorn_div_100': 0.8444366455078125,
  'e_distance': 1.6717557388746336,
  'mmd': 0.013037602},
 'Givinostat+Curcumin': {'r_squared': 0.942776413317766,
  'sinkhorn_div_1': 2.1052322387695312,
  'sinkhorn_div_10': 1.040964126586914,
  'sinkhorn_div_100': 0.8444023132324219,
  'e_distance': 1.6434922518084303,
  'mmd': 0.009760681},
 'Panobinostat+Alvespimycin': {'r_squared': 0.

In [23]:
# Destination directory for the per-condition metric CSVs written below.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/otfm"

In [24]:
import os
import pandas as pd

# One CSV per metric table. Each table maps condition -> {metric: value}, so
# DataFrame.from_dict yields conditions as columns and metric names as rows.
metric_tables = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "test_metrics_deg": deg_test_metrics,
    # DEG metrics on the OOD conditions; the stem parallels "test_metrics_deg"
    # (it was previously mislabeled "ood_metrics_ood").
    "ood_metrics_deg": deg_ood_metrics,
}
os.makedirs(output_dir, exist_ok=True)
for stem, per_condition_metrics in metric_tables.items():
    pd.DataFrame.from_dict(per_condition_metrics).to_csv(
        os.path.join(output_dir, f"{stem}_{split}_attention_seed.csv")
    )
