In [1]:
import functools

import jax
import numpy as np
import scanpy as sc

from ot_pert.metrics import compute_mean_metrics, compute_metrics

In [8]:
adata_train_path= "/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_train_0_seen_genes.h5ad"
adata_test_path= "/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_test_0_seen_genes.h5ad"
adata_ood_path= "/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_ood_0_seen_genes.h5ad"

In [9]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)
adata_pred_test = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/biolord/norman/biolord_output_test.h5ad")
adata_pred_ood = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/biolord/norman/biolord_output_ood.h5ad")



In [10]:
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [11]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add


def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    return np.matmul(data - mean_to_subtract, projection_matrix)

In [12]:
project_data_fn = functools.partial(
    project_data,
    projection_matrix=adata_train.varm["PCs"],
    mean_to_subtract=adata_train.varm["X_train_mean"].T,
)

In [14]:
test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}
for cond in adata_test.obs["condition"].cat.categories:
    if cond != "control":
        pred = adata_pred_test[adata_pred_test.obs["condition"] == cond].X
        test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"] == cond].obsm[OBSM_KEY_DATA_EMBEDDING]
        test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"] == cond].X.A
        test_data_target_decoded_predicted[cond] = pred
        test_data_target_encoded_predicted[cond] = project_data_fn(pred)

ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if cond != "control":
        pred = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].X
        ood_data_target_encoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].obsm[OBSM_KEY_DATA_EMBEDDING]
        ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].X.A
        ood_data_target_decoded_predicted[cond] = pred
        ood_data_target_encoded_predicted[cond] = project_data_fn(pred)

In [15]:
#test_deg_dict = {
#    k: v
#    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
#    if k in test_data_target_decoded_predicted.keys()
#}
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted.keys()
}

In [16]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]


ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

#test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
#test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)


ValueError: Dict key mismatch; expected keys: ['BAK1+KLF1', 'BAK1+TMSB4X', 'CBFA2T3+FEV', 'CBL+TGFBR2', 'CEBPB+CEBPE', 'CEBPE+PTPN12', 'DUSP9+KLF1', 'DUSP9+PRTG', 'ELMSAN1+MAP2K6', 'FOXA1+KLF1', 'MAP2K6+SPI1', 'MAPK1+PRTG', 'PRTG+TGFBR2', 'S1PR2+SGK1']; dict: {}.

In [None]:
test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(compute_metrics, test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [17]:
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

2024-06-26 13:57:44.398734: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [18]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9661981356589792,
 'decoded_ood_sinkhorn_div_1': 2007.3167550223213,
 'decoded_ood_sinkhorn_div_10': 1985.9652099609375,
 'decoded_ood_sinkhorn_div_100': 1622.461678641183,
 'decoded_ood_e_distance': 44.09208879204103,
 'decoded_ood_mmd': 0.06431907509633154}

In [None]:
mean_test_metrics_decoded

In [None]:
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


In [None]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/results/biolord"

In [None]:
import os
import pandas as pd

pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, "biolord_split_ood_metrics_encoded.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, "biolord_split_ood_metrics_decoded.csv"))
pd.DataFrame.from_dict(test_metrics_encoded).to_csv(os.path.join(output_dir, "biolord_split_test_metrics_encoded.csv"))
pd.DataFrame.from_dict(test_metrics_decoded).to_csv(os.path.join(output_dir, "biolord_split_test_metrics_decoded.csv"))
pd.DataFrame.from_dict(train_metrics_encoded).to_csv(os.path.join(output_dir, "biolord_split_train_metrics_encoded.csv"))
pd.DataFrame.from_dict(train_metrics_decoded).to_csv(os.path.join(output_dir, "biolord_split_train_metrics_decoded.csv"))