In [1]:
import functools
import os

import jax
import numpy as np
import pandas as pd
import scanpy as sc

from ot_pert.metrics import compute_mean_metrics, compute_metrics

In [2]:
# AnnData holding every split plus the chemCPA predictions in .obsm["CPA_pred"].
# NOTE(review): reading it warns about duplicate obs names (see cell output).
PREDICTIONS_H5AD = "/lustre/groups/ml01/workspace/ot_perturbation/models/chemcpa/combosciplex/adata_with_predictions_30.h5ad"
adata = sc.read(PREDICTIONS_H5AD)

  utils.warn_names_duplicates("obs")


In [3]:
# Original training AnnData. It supplies the PCA loadings and training mean (.varm)
# used to project predictions, and the per-condition DEG lists (.uns) used below.
TRAIN_H5AD = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_30.h5ad"
adata_train_orig = sc.read(TRAIN_H5AD)



In [4]:
# Inspect the available .uns entries; "rank_genes_groups_cov_all" provides the DEG lists used below.
adata_train_orig.uns.keys()

dict_keys(['Drug1_colors', 'Drug2_colors', 'Well_colors', 'condition_colors', 'dendrogram_leiden', 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pathway1_colors', 'pathway2_colors', 'pathway_colors', 'pca', 'rank_genes_groups', 'rank_genes_groups_cov_all', 'split_colors', 'umap'])

In [5]:
# NOTE(review): `preds` is not referenced by any later cell in this notebook; predictions
# are re-read per split/condition from .obsm["CPA_pred"] instead. Consider removing.
preds = adata.obsm["CPA_pred"]

In [6]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Map embeddings back to feature space (the reverse direction of ``project_data``).

    Computes ``embedding @ projection_matrix.T + mean_to_add``.

    Args:
        embedding: rows are samples in the reduced space.
        projection_matrix: one column per reduced dimension.
        mean_to_add: offset added back after the linear map (broadcast over rows).

    Returns:
        The reconstructed feature-space matrix.
    """
    back_projected = np.matmul(embedding, projection_matrix.T)
    return back_projected + mean_to_add


def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    """Center ``data`` and project it onto the columns of ``projection_matrix``.

    Computes ``(data - mean_to_subtract) @ projection_matrix``.

    Args:
        data: rows are samples in feature space.
        projection_matrix: one row per feature, one column per output dimension.
        mean_to_subtract: offset removed before projecting (broadcast over rows).

    Returns:
        The projected matrix.
    """
    centered = data - mean_to_subtract
    return np.matmul(centered, projection_matrix)

In [7]:
# .obsm key of the ground-truth embedding that the projected predictions are compared
# against in the "encoded" metrics.
# NOTE(review): predictions are projected with adata_train_orig.varm["PCs"]; confirm this
# embedding was computed with the same PCA basis and mean.
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [8]:
# Boolean-mask views of the prediction AnnData, one per value of obs["split"].
split_column = adata.obs["split"]
adata_train = adata[split_column == "train"]
adata_test = adata[split_column == "test"]
adata_ood = adata[split_column == "ood"]

In [9]:
# Sanity check: rows are OOD cells; columns are genes (get_mask indexes them by var_names).
adata_ood.obsm["CPA_pred"].shape

(8896, 2000)

In [10]:
# Project gene-space matrices into the training PCA space, using the loadings and the
# training mean stored in adata_train_orig.varm. The mean is transposed so it broadcasts
# across cells (presumably stored as (n_vars, 1) — confirm).
train_pcs = adata_train_orig.varm["PCs"]
train_mean_row = adata_train_orig.varm["X_train_mean"].T
project_data_fn = functools.partial(
    project_data,
    mean_to_subtract=train_mean_row,
    projection_matrix=train_pcs,
)

In [11]:
def _to_dense(matrix) -> np.ndarray:
    """Return ``matrix`` as a dense ndarray.

    Sparse matrices are densified with ``.toarray()``. The ``.A`` shorthand is deprecated
    and has been removed from SciPy sparse matrices. Dense inputs go through ``np.asarray``.
    """
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def collect_condition_arrays(adata_split):
    """Group one data split into per-condition ground-truth and predicted arrays.

    The "control" condition is skipped. Only conditions that actually have cells in
    ``adata_split`` are included.

    Returns:
        Four dicts keyed by condition name, in this order:
        - encoded ground truth (``.obsm[OBSM_KEY_DATA_EMBEDDING]``),
        - decoded ground truth (dense ``.X``),
        - encoded predictions (``project_data_fn`` applied to ``.obsm["CPA_pred"]``),
        - decoded predictions (``.obsm["CPA_pred"]``).
    """
    encoded, decoded, encoded_predicted, decoded_predicted = {}, {}, {}, {}
    condition_labels = adata_split.obs["condition"]
    # Iterate only categories present in this split, so an unused category can never
    # produce empty arrays for the metrics.
    present_conditions = condition_labels.cat.remove_unused_categories().cat.categories
    for cond in present_conditions:
        if cond == "control":
            continue
        subset = adata_split[condition_labels == cond]  # subset once, reuse below
        # Predictions are used as stored (no log1p).
        # NOTE(review): confirm CPA_pred is on the same scale as .X.
        pred_cpa = subset.obsm["CPA_pred"]
        # NOTE(review): ground-truth encodings come from this AnnData's .obsm, while
        # predictions are projected with adata_train_orig's PCs/mean; confirm both share
        # one PCA basis.
        encoded[cond] = subset.obsm[OBSM_KEY_DATA_EMBEDDING]
        decoded[cond] = _to_dense(subset.X)
        decoded_predicted[cond] = pred_cpa
        encoded_predicted[cond] = project_data_fn(pred_cpa)
    return encoded, decoded, encoded_predicted, decoded_predicted


(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = collect_condition_arrays(adata_test)

(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = collect_condition_arrays(adata_ood)

In [12]:
def _deg_dict_for(conditions) -> dict:
    """Select the entries of adata_train_orig.uns["rank_genes_groups_cov_all"] for ``conditions``.

    Entries keep the order in which they appear in the .uns dict.

    Raises:
        KeyError: if any condition has no DEG entry. Without this check, the later
            jax.tree_util.tree_map calls would fail with a less specific
            structure-mismatch error.
    """
    all_degs = adata_train_orig.uns["rank_genes_groups_cov_all"]
    missing = sorted(set(conditions) - set(all_degs.keys()))
    if missing:
        raise KeyError(f"No 'rank_genes_groups_cov_all' entry for conditions: {missing}")
    return {k: v for k, v in all_degs.items() if k in conditions}


test_deg_dict = _deg_dict_for(test_data_target_decoded_predicted.keys())
ood_deg_dict = _deg_dict_for(ood_data_target_decoded_predicted.keys())

In [13]:
def get_mask(x, y, var_names=None):
    """Keep only the columns of ``x`` whose gene name appears in ``y``.

    Args:
        x: 2-D array whose columns are ordered like ``var_names``.
        y: array-like of gene names to keep (e.g. one condition's DEG list).
        var_names: column (gene) names of ``x``. Defaults to ``adata_train.var_names``,
            which preserves the original two-argument behaviour.

    Returns:
        ``x[:, mask]``, where ``mask[j]`` is True iff ``var_names[j]`` occurs in ``y``.
        Kept columns stay in ``var_names`` order.
    """
    if var_names is None:
        var_names = adata_train.var_names
    # Set lookup is O(1) per gene; `gene in ndarray` scans all of `y` for every gene.
    wanted = set(np.ravel(y).tolist())
    mask = np.array([gene in wanted for gene in var_names], dtype=bool)
    return x[:, mask]


# Restrict every condition's ground-truth and predicted expression to that condition's
# DEG columns. tree_map pairs the two dicts' entries by key, so each matrix is masked with
# the DEG list of the same condition.
# (Train-split masking is not done: no train_data_target_* dicts are built in this notebook.)
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)

In [14]:
def _deg_per_condition_and_mean(targets, predictions, prefix):
    """Run compute_metrics per condition, then average the results with compute_mean_metrics."""
    per_condition = jax.tree_util.tree_map(compute_metrics, targets, predictions)
    return per_condition, compute_mean_metrics(per_condition, prefix=prefix)


# DEG-restricted metrics; the mean dicts carry "deg_ood_" / "deg_test_" key prefixes.
deg_ood_metrics, deg_mean_ood_metrics = _deg_per_condition_and_mean(
    ood_deg_target_decoded, ood_deg_target_decoded_predicted, "deg_ood_"
)
deg_test_metrics, deg_mean_test_metrics = _deg_per_condition_and_mean(
    test_deg_target_decoded, test_deg_target_decoded_predicted, "deg_test_"
)

2024-06-13 15:12:47.319419: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [15]:
# Train-split metrics are not computed. Only test and OOD per-condition arrays are
# collected above; no train_data_target_* dicts exist in this notebook.

In [16]:
# Test split: per-condition metrics in embedding ("encoded") space and in gene ("decoded")
# space, each followed by its mean over conditions.
metrics_per_test_condition = functools.partial(jax.tree_util.tree_map, compute_metrics)

test_metrics_encoded = metrics_per_test_condition(test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = metrics_per_test_condition(test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [17]:
# OOD split: the same encoded/decoded metrics on held-out conditions, plus their means.
ood_encoded_pair = (ood_data_target_encoded, ood_data_target_encoded_predicted)
ood_metrics_encoded = jax.tree_util.tree_map(compute_metrics, *ood_encoded_pair)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_decoded_pair = (ood_data_target_decoded, ood_data_target_decoded_predicted)
ood_metrics_decoded = jax.tree_util.tree_map(compute_metrics, *ood_decoded_pair)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [18]:
# Mean decoded (gene-space) metrics over test conditions.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7633810649874615,
 'decoded_test_sinkhorn_div_1': 101.37817060030423,
 'decoded_test_sinkhorn_div_10': 81.09938944303073,
 'decoded_test_sinkhorn_div_100': 24.61082282433143,
 'decoded_test_e_distance': 45.52553616143315,
 'decoded_test_mmd': 0.42768340171662145}

In [19]:
# Mean decoded (gene-space) metrics over OOD conditions.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7406613479771971,
 'decoded_ood_sinkhorn_div_1': 104.32796325683594,
 'decoded_ood_sinkhorn_div_10': 72.12513732910156,
 'decoded_ood_sinkhorn_div_100': 27.474310302734374,
 'decoded_ood_e_distance': 54.009219277932914,
 'decoded_ood_mmd': 0.4245061780581649}

In [20]:
# Mean metrics over OOD conditions, restricted to each condition's DEG columns.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.5194864385439965,
 'deg_ood_sinkhorn_div_1': 24.515447902679444,
 'deg_ood_sinkhorn_div_10': 16.35088233947754,
 'deg_ood_sinkhorn_div_100': 15.705111694335937,
 'deg_ood_e_distance': 31.27432490899064,
 'deg_ood_mmd': 0.47959513276259713}

In [21]:
# NOTE(review): repeats the In [18] cell verbatim. It may have been meant to show a result
# no cell displays (e.g. deg_mean_test_metrics or mean_ood_metrics_encoded) — confirm or delete.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7633810649874615,
 'decoded_test_sinkhorn_div_1': 101.37817060030423,
 'decoded_test_sinkhorn_div_10': 81.09938944303073,
 'decoded_test_sinkhorn_div_100': 24.61082282433143,
 'decoded_test_e_distance': 45.52553616143315,
 'decoded_test_mmd': 0.42768340171662145}

In [22]:
# Mean encoded (embedding-space) metrics over test conditions.
# NOTE(review): the mean r_squared here is negative while the decoded r_squared is ~0.76;
# worth checking that the ground-truth embedding and the projected predictions share a basis.
mean_test_metrics_encoded

{'encoded_test_r_squared': -0.2775952443570995,
 'encoded_test_sinkhorn_div_1': 33.99979521678044,
 'encoded_test_sinkhorn_div_10': 23.868965295644905,
 'encoded_test_sinkhorn_div_100': 21.590697655310997,
 'encoded_test_e_distance': 42.79317362639634,
 'encoded_test_mmd': 0.38909356945343226}

In [19]:
# Per-condition encoded metrics for OOD conditions.
# NOTE(review): this cell's execution count (In [19]) is out of order with its neighbours,
# and the saved output is truncated; re-run top-to-bottom to refresh it.
ood_metrics_encoded

{'Cediranib+PCI-34051': {'r_squared': 0.9670571868604838,
  'sinkhorn_div_1': 5.284296035766602,
  'sinkhorn_div_10': 0.6397514343261719,
  'sinkhorn_div_100': 0.46512794494628906,
  'e_distance': 0.900585264605688,
  'mmd': 0.009130221615835847},
 'Givinostat+SRT1720': {'r_squared': 0.9742046491386951,
  'sinkhorn_div_1': 5.229059219360352,
  'sinkhorn_div_10': 0.3977165222167969,
  'sinkhorn_div_100': 0.2260417938232422,
  'e_distance': 0.42198174452832404,
  'mmd': 0.007237379008675678},
 'Panobinostat+Crizotinib': {'r_squared': 0.9887869542746437,
  'sinkhorn_div_1': 4.363508224487305,
  'sinkhorn_div_10': 0.5683212280273438,
  'sinkhorn_div_100': 0.3817615509033203,
  'e_distance': 0.7275057020834055,
  'mmd': 0.011857479357851378},
 'Panobinostat+PCI-34051': {'r_squared': 0.9895649534603667,
  'sinkhorn_div_1': 4.524935722351074,
  'sinkhorn_div_10': 0.4970741271972656,
  'sinkhorn_div_100': 0.28163909912109375,
  'e_distance': 0.5249919406934112,
  'mmd': 0.011476454900411095},


In [26]:
# Destination directory for the per-condition metric CSVs written in the next cell.
# NOTE(review): absolute cluster path; consider making it a configurable constant.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/chemcpa"

In [29]:
# Write each per-condition metrics dict to CSV. With pandas' default "columns" orient,
# each condition becomes a column and each metric a row.
os.makedirs(output_dir, exist_ok=True)  # to_csv does not create missing directories

metrics_to_save = {
    "ood_metrics_encoded.csv": ood_metrics_encoded,
    "ood_metrics_decoded.csv": ood_metrics_decoded,
    "test_metrics_encoded.csv": test_metrics_encoded,
    "test_metrics_decoded.csv": test_metrics_decoded,
    "test_metrics_deg.csv": deg_test_metrics,
    "ood_metrics_deg.csv": deg_ood_metrics,
}
for csv_name, metrics in metrics_to_save.items():
    pd.DataFrame.from_dict(metrics).to_csv(os.path.join(output_dir, csv_name))