In [9]:
import functools
import os

import jax
import numpy as np
import pandas as pd
import scanpy as sc

import cfp.preprocessing as cfpp
from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast

In [2]:
# Index of the data split to evaluate; it selects the adata_*_{split}.h5ad files
# that are read and the suffix of the result files that are written.
split = 1

In [3]:
# Locations of the train / test / ood AnnData files belonging to the chosen split.
adata_train_path, adata_test_path, adata_ood_path = (
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_{split}.h5ad",
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_{split}.h5ad",
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_{split}.h5ad",
)

In [4]:
# Read the three AnnData objects, in train -> test -> ood order.
adata_train, adata_test, adata_ood = (
    sc.read(h5ad_path)
    for h5ad_path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [5]:
# Number of cells per "split" label in the training AnnData.
train_split_labels = adata_train.obs["split"]
train_split_labels.value_counts()

split
train    355056
Name: count, dtype: int64

In [6]:
# Number of cells per "split" label in the test AnnData.
test_split_labels = adata_test.obs["split"]
test_split_labels.value_counts()

split
test    81650
Name: count, dtype: int64

In [7]:
# Number of cells per "split" label in the OOD AnnData.
ood_split_labels = adata_ood.obs["split"]
ood_split_labels.value_counts()

split
ood     134723
test      1500
Name: count, dtype: int64

In [28]:
# Reference set for the PCA: every OOD cell whose condition name does not
# mention "Vehicle". A copy is taken so the PCA call works on an independent object.
is_vehicle_cell = adata_ood.obs["condition"].str.contains("Vehicle")
adata_ref_ood = adata_ood[~is_vehicle_cell].copy()
cfpp.centered_pca(adata_ref_ood, n_comps=10)



In [29]:
# Show the summary of the non-Vehicle OOD reference subset after the centered_pca call.
adata_ref_ood

AnnData object with n_obs × n_vars = 134723 × 2001
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'perturbation', 'drug', 'cell_line', 'logdose', 'condition', 'n_genes', 'pubchem_name', 'pubchem_ID', 'smiles', 'control', 'ood_1', 'ood_2', 'ood_3', 'ood_4', 'ood_5', 'split'
    uns: 'cell_line_dict', 'ecfp_dict', 'pca'
    obsm: 'X_pca', 'ecfp'
    varm: 'X_mean', 'PCs'
    layers: 'X_centered'

In [30]:
# Project all OOD cells (controls included) using the PCA fitted on the non-Vehicle
# reference cells. The loop below reads adata_ood.obsm["X_pca"]; presumably that is the
# key project_pca writes by default -- TODO confirm against cfp.preprocessing.
#
# NOTE: an earlier pre-loop `adata_pred_ood = adata_ood["Vehicle" in ...]` assignment and
# its project_pca call were dropped. `"Vehicle" in <Series>` tests the Series index (not
# its values), and the resulting object was overwritten inside the loop before being read.
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref_ood)

# For every non-Vehicle OOD condition, the "predicted" population is the Vehicle control
# population of the same cell type; the "target" population is the condition itself.
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue  # controls serve as the prediction, they are not evaluation targets

    # Slice the target cells once and reuse the view below.
    adata_target = adata_ood[adata_ood.obs["condition"] == cond]
    src_str = list(adata_target.obs["cell_type"].unique())
    assert len(src_str) == 1, f"condition {cond!r} spans several cell types: {src_str}"

    control_cond = src_str[0] + "_Vehicle_0.0"
    adata_pred_ood = adata_ood[adata_ood.obs["condition"] == control_cond]
    # An empty control population would silently produce meaningless metrics later on.
    assert adata_pred_ood.n_obs > 0, f"no control cells found for {control_cond!r}"

    ood_data_target_encoded[cond] = adata_target.obsm["X_pca"]
    ood_data_target_decoded[cond] = adata_target.X.toarray()
    ood_data_target_decoded_predicted[cond] = adata_pred_ood.X.toarray()
    ood_data_target_encoded_predicted[cond] = adata_pred_ood.obsm["X_pca"]

  query_adata.obsm[obsm_key_added] = np.array(


In [31]:
# Per-condition metrics on the "X_pca" representation (target vs. predicted),
# followed by their aggregate under the "encoded_ood_" prefix.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded,
    prefix="encoded_ood_",
)

# Same comparison on the dense expression matrices, using compute_metrics_fast,
# aggregated under the "decoded_ood_" prefix.
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded,
    prefix="decoded_ood_",
)




In [33]:
# Restrict the per-condition gene lists stored in adata_train.uns["rank_genes_groups_cov_all"]
# (presumably differentially expressed genes -- confirm with the preprocessing step)
# to the OOD conditions evaluated above.
#
# NOTE: a `test_deg_dict` used to be built here as well. It referenced
# `test_data_target_decoded_predicted`, which is not defined anywhere in this notebook
# (NameError on a fresh kernel), and its result was never used, so it was removed.
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted
}

In [34]:
def get_mask(x, y):
    """Return the columns of ``x`` whose gene name appears in ``y``.

    Despite the name, this returns the column-subsetted data, not the boolean mask.

    Args:
        x: 2-D array indexed as ``x[:, mask]``; its columns are assumed to follow
            the order of ``adata_train.var_names`` -- TODO confirm that the OOD
            AnnData shares this gene order.
        y: Iterable of gene names to keep.

    Returns:
        ``x`` restricted to the columns whose gene name is contained in ``y``.
    """
    # Vectorised membership test instead of a Python-level `gene in y` per gene.
    keep_gene = np.isin(adata_train.var_names, list(y))
    return x[:, keep_gene]

# Apply get_mask condition by condition: keep, for the predicted and the target
# expression matrices alike, only the genes listed for that condition in ood_deg_dict.
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded_predicted,
    ood_deg_dict,
)
ood_deg_target_decoded = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded,
    ood_deg_dict,
)


In [35]:
# Per-condition metrics restricted to the gene subsets selected by get_mask,
# then aggregated under the "deg_ood_" prefix.
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)
deg_mean_ood_metrics = compute_mean_metrics(
    deg_ood_metrics,
    prefix="deg_ood_",
)


In [36]:
# Directory that receives the per-condition metric CSV files written below.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/results/identity"

In [37]:
# Write the per-condition OOD metric tables of this split as CSV files.
# (`os` and `pd` are imported in the notebook's top import cell.)
os.makedirs(output_dir, exist_ok=True)  # to_csv does not create missing directories

pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, f"ood_metrics_encoded_{split}.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, f"ood_metrics_decoded_{split}.csv"))
# NOTE(review): this table holds the DEG-restricted metrics, yet the file is named
# "ood_metrics_ood_..."; the name is kept unchanged so existing readers of the file
# keep working -- confirm whether "ood_metrics_deg_..." was intended.
pd.DataFrame.from_dict(deg_ood_metrics).to_csv(os.path.join(output_dir, f"ood_metrics_ood_{split}.csv"))


In [38]:
# Show the aggregated metrics computed on the dense expression matrices.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9579607951686757,
 'decoded_ood_e_distance': 2.8310221939633498,
 'decoded_ood_mmd_distance': 0.008437863255804583}

In [39]:
# Show the aggregated metrics computed on the "X_pca" representation.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.9436248149223254,
 'encoded_ood_sinkhorn_div_1': 1.2033273474923496,
 'encoded_ood_sinkhorn_div_10': 0.6251639278455713,
 'encoded_ood_sinkhorn_div_100': 0.5681027352124796,
 'encoded_ood_e_distance': 1.1270504940033497,
 'encoded_ood_mmd': 0.01547972324894834}