In [1]:
import scanpy as sc
import numpy as np
import functools
import jax
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast
import cfp.preprocessing as cfpp

In [2]:
# Index of the data split to evaluate; it is interpolated into every input
# .h5ad path and every output CSV filename in this notebook.
split = 4

In [3]:
# Train / test / OOD AnnData files for the selected split. All three live in
# the combosciplex data directory and are named adata_<part>_<split>.h5ad.
adata_train_path, adata_test_path, adata_ood_path = (
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_{part}_{split}.h5ad"
    for part in ("train", "test", "ood")
)

In [4]:
# Load the train, test and OOD AnnData objects for this split (in that order).
adata_train, adata_test, adata_ood = (
    sc.read(path) for path in (adata_train_path, adata_test_path, adata_ood_path)
)


In [5]:
# PCA references: the perturbed (non-control) cells of each split, copied into
# standalone AnnData objects. Control cells are later projected onto the PCA
# fitted here (cfpp.project_pca with ref_adata=adata_ref_*).
adata_ref_test, adata_ref_ood = (
    adata[adata.obs["condition"] != "control"].copy()
    for adata in (adata_test, adata_ood)
)
cfpp.centered_pca(adata_ref_test, n_comps=10)
cfpp.centered_pca(adata_ref_ood, n_comps=10)

In [6]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array.

    Anything exposing ``.toarray()`` (sparse matrices) is densified; other
    inputs go through ``np.asarray`` so the notebook also works when ``.X`` is
    stored densely (a dense array has no ``.toarray()``).
    """
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def _identity_baseline(adata, adata_ref, control="control"):
    """Build targets and identity-baseline "predictions" per perturbed condition.

    The control cells of ``adata`` are used unchanged as the prediction for
    every non-control condition, i.e. a do-nothing (identity) baseline.
    Both the control subset and the full ``adata`` are projected onto the PCA
    of ``adata_ref`` with ``cfpp.project_pca``. That call writes into
    ``.obsm``; the projection is read back from ``.obsm["X_pca"]``.

    Args:
        adata: AnnData with a categorical ``obs["condition"]`` column.
        adata_ref: AnnData whose PCA (fitted by ``cfpp.centered_pca``) defines
            the encoded space.
        control: label of the control cells in ``obs["condition"]``.

    Returns:
        Tuple ``(adata_ctrl, target_encoded, target_decoded,
        predicted_encoded, predicted_decoded)``. ``adata_ctrl`` is a standalone
        copy of the control cells. Each dict maps every non-control condition
        to an array: ``encoded`` = ``obsm["X_pca"]``, ``decoded`` = dense ``.X``.
        The ``predicted_*`` dicts share ONE control array across all conditions,
        so callers must not modify those arrays in place.
    """
    # .copy(): project_pca writes into .obsm. Writing into a view makes anndata
    # implicitly turn the view into a copy and emit an ImplicitModificationWarning.
    adata_ctrl = adata[adata.obs["condition"] == control].copy()
    cfpp.project_pca(query_adata=adata_ctrl, ref_adata=adata_ref)
    cfpp.project_pca(query_adata=adata, ref_adata=adata_ref)

    # The control matrices are identical for every condition, so fetch and
    # densify them once instead of making one dense copy per condition.
    ctrl_decoded = _to_dense(adata_ctrl.X)
    ctrl_encoded = adata_ctrl.obsm["X_pca"]

    target_encoded, target_decoded = {}, {}
    predicted_encoded, predicted_decoded = {}, {}
    for cond in adata.obs["condition"].cat.categories:
        if cond == control:
            continue
        adata_cond = adata[adata.obs["condition"] == cond]  # subset once per condition
        target_encoded[cond] = adata_cond.obsm["X_pca"]
        target_decoded[cond] = _to_dense(adata_cond.X)
        predicted_encoded[cond] = ctrl_encoded
        predicted_decoded[cond] = ctrl_decoded
    return adata_ctrl, target_encoded, target_decoded, predicted_encoded, predicted_decoded


(
    adata_pred_test,
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = _identity_baseline(adata_test, adata_ref_test)

(
    adata_pred_ood,
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = _identity_baseline(adata_ood, adata_ref_ood)

  query_adata.obsm[obsm_key_added] = np.array(


In [7]:
# Identity-baseline metrics on the test split. jax.tree_util.tree_map pairs the
# target and predicted dicts by condition and applies the metric function to
# each pair. PCA (encoded) space uses compute_metrics; gene (decoded) space uses
# compute_metrics_fast, whose output (see the decoded means) has no Sinkhorn terms.
# compute_mean_metrics then summarises the per-condition results under a prefix.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    test_data_target_encoded,
    test_data_target_encoded_predicted,
)
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    test_data_target_decoded,
    test_data_target_decoded_predicted,
)

mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [8]:
# Same evaluation as for the test split, on the OOD split: per-condition
# metrics in PCA (encoded) and gene (decoded) space, then their summaries.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)

mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")




In [9]:
# DEG lists stored on the training AnnData, restricted to the conditions that
# are evaluated in each split (the keys of its prediction dict).
test_deg_dict = {
    cond: degs
    for cond, degs in adata_train.uns["rank_genes_groups_cov_all"].items()
    if cond in test_data_target_decoded_predicted
}
ood_deg_dict = {
    cond: degs
    for cond, degs in adata_train.uns["rank_genes_groups_cov_all"].items()
    if cond in ood_data_target_decoded_predicted
}

In [10]:
def get_mask(x, y, var_names=None):
    """Keep only the columns (genes) of ``x`` whose names appear in ``y``.

    Args:
        x: 2-D array (cells x genes) whose columns are ordered like ``var_names``.
        y: iterable of gene names to keep (here: one condition's DEG list).
        var_names: column names of ``x``. Defaults to ``adata_train.var_names``.
            This assumes the test/OOD matrices share the training gene order
            (NOTE(review): confirm).

    Returns:
        ``x[:, mask]`` where ``mask[j]`` is True iff ``var_names[j]`` is in ``y``.
        Columns keep the ``var_names`` order, not the order of ``y``.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Vectorized membership test. The previous per-gene ``gene in y`` scan cost
    # O(n_genes * n_degs) Python-level comparisons. list(y) also accepts
    # non-array iterables (lists, sets) the same way ``in`` did.
    mask = np.isin(np.asarray(var_names), list(y))
    return x[:, mask]

# Restrict both the identity predictions and the observed targets to each
# condition's DEG columns; tree_map pairs every matrix with its condition's DEGs.
test_deg_target_decoded_predicted, test_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, decoded, test_deg_dict)
    for decoded in (test_data_target_decoded_predicted, test_data_target_decoded)
)
ood_deg_target_decoded_predicted, ood_deg_target_decoded = (
    jax.tree_util.tree_map(get_mask, decoded, ood_deg_dict)
    for decoded in (ood_data_target_decoded_predicted, ood_data_target_decoded)
)


In [11]:
# Full metric set (compute_metrics) on the DEG-restricted gene matrices,
# per condition and then summarised with a "deg_<split>_" prefix.
deg_test_metrics = jax.tree_util.tree_map(
    compute_metrics,
    test_deg_target_decoded,
    test_deg_target_decoded_predicted,
)
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)

deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")


In [12]:
# Summary of gene-space (decoded) identity-baseline metrics on the test split.
mean_test_metrics_decoded 

{'decoded_test_r_squared': 0.6186772441864014,
 'decoded_test_e_distance': 74.62210162309972,
 'decoded_test_mmd_distance': 0.05043948840349913}

In [13]:
# Summary of identity-baseline metrics on the OOD split, restricted to each condition's DEGs.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.8950363000233968,
 'deg_ood_sinkhorn_div_1': 11.893356243769327,
 'deg_ood_sinkhorn_div_10': 3.2520058949788413,
 'deg_ood_sinkhorn_div_100': 3.0406672159830728,
 'deg_ood_e_distance': 6.048266410641813,
 'deg_ood_mmd': 0.024619585640418034}

In [14]:
# Summary of identity-baseline metrics on the test split, restricted to each condition's DEGs.
deg_mean_test_metrics

{'deg_test_r_squared': 0.30246726751327513,
 'deg_test_sinkhorn_div_1': 37.282035942077634,
 'deg_test_sinkhorn_div_10': 23.675760192871095,
 'deg_test_sinkhorn_div_100': 22.383598136901856,
 'deg_test_e_distance': 44.593742921132815,
 'deg_test_mmd': 0.10848859447985887}

In [15]:
# Summary of PCA-space (encoded) identity-baseline metrics on the test split.
mean_test_metrics_encoded

{'encoded_test_r_squared': -2.076555087215783,
 'encoded_test_sinkhorn_div_1': 39.54037618637085,
 'encoded_test_sinkhorn_div_10': 35.89650901794434,
 'encoded_test_sinkhorn_div_100': 34.65567150115967,
 'encoded_test_e_distance': 69.03055698791623,
 'encoded_test_mmd': 0.1936525147780776}

In [16]:
# Summary of gene-space (decoded) identity-baseline metrics on the OOD split.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9424620866775513,
 'decoded_ood_e_distance': 10.26261064340016,
 'decoded_ood_mmd_distance': 0.00872026092838496}

In [17]:
# Directory that receives the per-condition metric CSVs of the identity baseline.
output_dir = (
    "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"
    "/results/identity"
)

In [18]:
import os
import pandas as pd

# Write one CSV of per-condition metrics per split/space. The outer dict keys
# (conditions) become columns under DataFrame.from_dict's default orient.
# NOTE: the DEG-restricted OOD table used to be written as
# "ood_metrics_ood_{split}.csv" (typo). It is now "ood_metrics_deg_{split}.csv",
# matching "test_metrics_deg_{split}.csv"; update any reader of the old name.
os.makedirs(output_dir, exist_ok=True)  # fail-safe for a fresh results directory
for stem, per_condition_metrics in {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "test_metrics_deg": deg_test_metrics,
    "ood_metrics_deg": deg_ood_metrics,
}.items():
    pd.DataFrame.from_dict(per_condition_metrics).to_csv(
        os.path.join(output_dir, f"{stem}_{split}.csv")
    )
