In [1]:
import functools

import jax
import numpy as np
import scanpy as sc
import cfp.preprocessing as cfpp
from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast

In [2]:
ood_split=2

In [3]:
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_{ood_split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_{ood_split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_{ood_split}.h5ad"

In [4]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)
adata_pred_test = sc.read(f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_test_{ood_split}.h5ad")
adata_pred_ood = sc.read(f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_ood_{ood_split}.h5ad")

In [5]:
adata_ref_ood = adata_ood[adata_ood.obs["condition"].str.contains('Vehicle')].copy()
cfpp.centered_pca(adata_ref_ood, n_comps=10)

adata_ref_test = adata_test[adata_test.obs["condition"].str.contains('Vehicle')].copy()
cfpp.centered_pca(adata_ref_test, n_comps=10)

In [6]:

cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_ref_test)
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_ref_test)
test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}
for cond in adata_test.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"] == cond].obsm["X_pca"]
    test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"] == cond].X.toarray()
    test_data_target_decoded_predicted[cond] = adata_pred_test[adata_pred_test.obs["condition"] == cond].X
    test_data_target_encoded_predicted[cond] = adata_pred_test[adata_pred_test.obs["condition"] == cond].obsm["X_pca"]

In [7]:

cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_ref_ood)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref_ood)
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    ood_data_target_encoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].obsm["X_pca"]
    ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].X.toarray()
    ood_data_target_decoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].X
    ood_data_target_encoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].obsm["X_pca"]

In [8]:

test_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in test_data_target_decoded_predicted.keys()
}

ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted.keys()
}

In [9]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]


ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)


In [10]:
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


2024-09-24 12:11:21.989336: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782
2024-09-24 12:11:24.086306: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782


In [11]:
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

2024-09-24 14:55:41.956524: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 27610039482389330461
2024-09-24 15:02:49.077833: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782
2024-09-24 15:02:51.800500: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782
2024-09-24 15:10:31.326758: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 8086530955859784442449
2024-09-24 16:04:52.471675: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 9828457231919883051687
2024-09-24 16:13:42.415840: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782
2024-09-24 16:13:45.863812: E external/xla/xla/service/hlo_lexer.cc:438] Failed to parse int literal: 28675072374034695991782


In [12]:
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, test_data_target_decoded, test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [13]:
ood_metrics_encoded

{'A549_A-366_10.0': {'r_squared': 0.9699674248695374,
  'sinkhorn_div_1': 1.6901137828826904,
  'sinkhorn_div_10': 0.767064094543457,
  'sinkhorn_div_100': 0.5508427619934082,
  'e_distance': 1.0531073075517878,
  'mmd': 0.07655787},
 'A549_A-366_100.0': {'r_squared': 0.9848592281341553,
  'sinkhorn_div_1': 1.6192491054534912,
  'sinkhorn_div_10': 0.5155563354492188,
  'sinkhorn_div_100': 0.26420068740844727,
  'e_distance': 0.4720503815107109,
  'mmd': 0.071314014},
 'A549_A-366_1000.0': {'r_squared': 0.9900884628295898,
  'sinkhorn_div_1': 1.487593412399292,
  'sinkhorn_div_10': 0.41553211212158203,
  'sinkhorn_div_100': 0.18543672561645508,
  'e_distance': 0.31912549307749405,
  'mmd': 0.0701856},
 'A549_A-366_10000.0': {'r_squared': 0.9534733295440674,
  'sinkhorn_div_1': 2.0659985542297363,
  'sinkhorn_div_10': 0.9924569129943848,
  'sinkhorn_div_100': 0.7533035278320312,
  'e_distance': 1.4531223909216395,
  'mmd': 0.082502075},
 'A549_AR-42_10.0': {'r_squared': 0.984667837619781

In [14]:
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.9466261406739552,
 'encoded_ood_sinkhorn_div_1': 2.300351776458599,
 'encoded_ood_sinkhorn_div_10': 0.9819601526966801,
 'encoded_ood_sinkhorn_div_100': 0.5561536846337495,
 'encoded_ood_e_distance': 1.0180754927178237,
 'encoded_ood_mmd': 0.07880950635644021}

In [15]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9614022057365488,
 'decoded_ood_sinkhorn_div_1': 47.79938374625312,
 'decoded_ood_sinkhorn_div_10': 23.220301373799643,
 'decoded_ood_sinkhorn_div_100': 1.6887848324245878,
 'decoded_ood_e_distance': 2.651926480439029,
 'decoded_ood_mmd': 0.03470122341273559}

In [16]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/results/biolord"

In [17]:
import os
import pandas as pd
split = ood_split
pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, f"ood_metrics_encoded_{split}.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, f"ood_metrics_decoded_{split}.csv"))
pd.DataFrame.from_dict(test_metrics_encoded).to_csv(os.path.join(output_dir, f"test_metrics_encoded_{split}.csv"))
pd.DataFrame.from_dict(test_metrics_decoded).to_csv(os.path.join(output_dir, f"test_metrics_decoded_{split}.csv"))
pd.DataFrame.from_dict(deg_test_metrics).to_csv(os.path.join(output_dir, f"test_metrics_deg_{split}.csv"))
pd.DataFrame.from_dict(deg_ood_metrics).to_csv(os.path.join(output_dir, f"ood_metrics_ood_{split}.csv"))


In [None]:
1