In [35]:
import functools
import os

import jax
import numpy as np
import pandas as pd
import scanpy as sc

from ot_pert.metrics import compute_mean_metrics, compute_metrics

In [36]:
# Load the AnnData file that stores the model predictions; later cells read
# them from `obsm["CPA_pred"]`.
adata_preds = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/chemcpa/combosciplex/adata_with_predictions.h5ad")

  utils.warn_names_duplicates("obs")


In [37]:
# Locations of the train / test / OOD AnnData files that are read in the next cell.
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_300.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_300.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_300.h5ad"

In [38]:
# Read the three splits from disk, in the order train, test, OOD.
adata_train, adata_test, adata_ood = (
    sc.read(split_path) for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)



In [39]:
# Inspect which `obsm` entries the prediction object carries.
adata_preds.obsm

AxisArrays with keys: CPA_pred, perts, perts_doses

In [40]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Map an embedding back to data space.

    Multiplies ``embedding`` by the transpose of ``projection_matrix`` and
    then adds ``mean_to_add`` (using NumPy broadcasting).

    Parameters
    ----------
    embedding
        Array to be multiplied from the left.
    projection_matrix
        Matrix whose transpose is applied to ``embedding``.
    mean_to_add
        Offset added to the product.

    Returns
    -------
    The array ``embedding @ projection_matrix.T + mean_to_add``.
    """
    back_projected = np.matmul(embedding, projection_matrix.T)
    reconstructed = back_projected + mean_to_add
    return reconstructed


def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    """Project data into the embedding space.

    Subtracts ``mean_to_subtract`` from ``data`` (using NumPy broadcasting) and
    multiplies the centred result by ``projection_matrix``.

    Parameters
    ----------
    data
        Array to project.
    projection_matrix
        Matrix applied to the centred data from the right.
    mean_to_subtract
        Offset removed from ``data`` before projecting.

    Returns
    -------
    The array ``(data - mean_to_subtract) @ projection_matrix``.
    """
    centred = data - mean_to_subtract
    projected = np.matmul(centred, projection_matrix)
    return projected

In [41]:
# `obsm` key from which the encoded (embedding-space) target data is read.
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [42]:
# Bind the projection matrix and the mean stored on the training AnnData, so
# that `project_data_fn(data)` only needs the data to project.
# NOTE(review): the `.T` presumably turns the stored mean into a row vector
# that broadcasts against (n_cells, n_genes) data — confirm the stored shape.
project_data_fn = functools.partial(
    project_data,
    projection_matrix=adata_train.varm["PCs"],
    mean_to_subtract=adata_train.varm["X_train_mean"].T,
)

In [43]:
# Show a summary of the prediction AnnData (dimensions and obs/uns/obsm keys).
adata_preds

AnnData object with n_obs × n_vars = 67382 × 2000
    obs: 'condition', 'cell_type', 'condition_ID', 'log_dose', 'smiles_rdkit', 'split', 'split_1ct_MEC', 'CPA_cat', 'CPA_CHEMBL504', '_scvi_condition_ID', '_scvi_cell_type', '_scvi_CPA_cat'
    uns: '_scvi_manager_uuid', '_scvi_uuid'
    obsm: 'CPA_pred', 'perts', 'perts_doses'

In [50]:
def _collect_condition_data(adata_split, adata_predictions, project_fn, embedding_key):
    """Gather per-condition targets and predictions for one data split.

    For every category of ``adata_split.obs["condition"]`` except ``"control"``,
    four arrays are stored under the condition name:

    * target, encoded: ``obsm[embedding_key]`` of the split's cells,
    * target, decoded: ``X`` of the split's cells as a dense array,
    * prediction, decoded: ``obsm["CPA_pred"]`` of the cells in
      ``adata_predictions`` that have the same condition,
    * prediction, encoded: ``project_fn`` applied to the decoded prediction.

    Parameters
    ----------
    adata_split
        AnnData with the ground-truth cells of one split.
    adata_predictions
        AnnData holding the predictions in ``obsm["CPA_pred"]``.
    project_fn
        Callable mapping decoded data to the embedding space.
    embedding_key
        ``obsm`` key of the encoded ground-truth data.

    Returns
    -------
    Tuple of four dicts keyed by condition, in the order
    ``(target_encoded, target_decoded, predicted_encoded, predicted_decoded)``.
    """
    target_encoded, target_decoded = {}, {}
    predicted_encoded, predicted_decoded = {}, {}
    for cond in adata_split.obs["condition"].cat.categories:
        if cond == "control":
            continue
        # Subset once per condition and reuse the view for both fields.
        adata_cond = adata_split[adata_split.obs["condition"] == cond]
        target_encoded[cond] = adata_cond.obsm[embedding_key]
        # `.toarray()` replaces the deprecated sparse `.A` shorthand; dense
        # inputs are passed through so the helper does not require sparse X.
        expression = adata_cond.X
        target_decoded[cond] = expression.toarray() if hasattr(expression, "toarray") else np.asarray(expression)
        # NOTE(review): predictions are selected by condition only, i.e. from
        # every cell of `adata_predictions` regardless of its `split` column —
        # confirm this is the intended evaluation set.
        pred_cpa = adata_predictions[adata_predictions.obs["condition"] == cond].obsm["CPA_pred"]
        predicted_decoded[cond] = pred_cpa
        predicted_encoded[cond] = project_fn(pred_cpa)
    return target_encoded, target_decoded, predicted_encoded, predicted_decoded


(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = _collect_condition_data(adata_test, adata_preds, project_data_fn, OBSM_KEY_DATA_EMBEDDING)

(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = _collect_condition_data(adata_ood, adata_preds, project_data_fn, OBSM_KEY_DATA_EMBEDDING)

In [51]:
# Restrict the per-condition gene lists stored on the training AnnData to the
# conditions for which predictions were collected in each split.
rank_genes_by_condition = adata_train.uns["rank_genes_groups_cov_all"]

test_deg_dict = {
    cond: genes for cond, genes in rank_genes_by_condition.items() if cond in test_data_target_decoded_predicted
}
ood_deg_dict = {
    cond: genes for cond, genes in rank_genes_by_condition.items() if cond in ood_data_target_decoded_predicted
}

In [52]:
def get_mask(x, y):
    """Return the columns of ``x`` whose gene name is contained in ``y``.

    One boolean per entry of ``adata_train.var_names`` is computed and used to
    select columns of ``x``.

    Parameters
    ----------
    x
        2D array to subset column-wise.
        NOTE(review): assumes its columns are ordered like
        ``adata_train.var_names`` — confirm.
    y
        Container of gene names to keep.

    Returns
    -------
    ``x`` restricted to the selected columns.
    """
    keep_column = [gene in y for gene in adata_train.var_names]
    return x[:, keep_column]


# Apply `get_mask` condition by condition: each decoded array is paired with
# the gene list stored under the same key in the corresponding DEG dict.
# NOTE(review): this pairing relies on both dicts having identical keys —
# confirm every condition has an entry in `rank_genes_groups_cov_all`.
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)


In [47]:
# Metrics on the DEG-restricted decoded data: `compute_metrics` is called per
# condition with (target, prediction), then `compute_mean_metrics` aggregates
# the per-condition results and prefixes the metric names.
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")


In [53]:
# Test split, encoded space: per-condition metrics between the target
# embedding and the projected predictions, followed by their aggregate.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

# Test split, decoded space: same computation on the dense `X` targets and
# the raw predictions.
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_decoded, test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [None]:
# OOD split, encoded space: per-condition metrics between the target
# embedding and the projected predictions, followed by their aggregate.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

# OOD split, decoded space: same computation on the dense `X` targets and
# the raw predictions.
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [32]:
# Display the aggregated OOD metrics in decoded space.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7406127580322035,
 'decoded_ood_sinkhorn_div_1': 104.2943115234375,
 'decoded_ood_sinkhorn_div_10': 72.1079818725586,
 'decoded_ood_sinkhorn_div_100': 27.471092224121094,
 'decoded_ood_e_distance': 54.005889533955006,
 'decoded_ood_mmd': 0.37630357790693436}

1

In [None]:
# Display the aggregated test metrics in decoded space.
mean_test_metrics_decoded

In [None]:
# Display the aggregated OOD metrics in decoded space.
mean_ood_metrics_decoded

In [None]:
# Display the aggregated OOD metrics computed on the DEG-restricted genes.
deg_mean_ood_metrics

In [None]:
# Display the aggregated test metrics in decoded space.
# NOTE(review): this repeats an earlier display cell; possibly
# `deg_mean_test_metrics` was meant here — confirm.
mean_test_metrics_decoded

In [None]:
# Display the aggregated test metrics in encoded space.
mean_test_metrics_encoded

In [None]:
# Display the per-condition (non-aggregated) OOD metrics in encoded space.
ood_metrics_encoded

In [None]:
# Directory into which the per-condition metric tables are written as CSV.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/chemcpa"

In [None]:
# Write each per-condition metric table to `<output_dir>/<name>.csv`.
# `os` and `pd` come from the import cell at the top of the notebook (the
# original cell used `pd` without it ever being imported).
# Create the target directory up front so the first write cannot fail on a
# missing folder.
os.makedirs(output_dir, exist_ok=True)

# The original cell wrote the two test tables twice; each table is now
# written exactly once.
# NOTE(review): `deg_ood_metrics` / `deg_test_metrics` are computed above but
# never saved — confirm whether the duplicated lines were meant to export them.
metrics_to_save = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
}
for file_stem, metrics in metrics_to_save.items():
    pd.DataFrame.from_dict(metrics).to_csv(os.path.join(output_dir, f"{file_stem}.csv"))


In [None]:
# Display the aggregated OOD metrics in encoded space.
mean_ood_metrics_encoded

In [None]:
# Display the aggregated OOD metrics in decoded space.
mean_ood_metrics_decoded

In [None]:
# Number of cells per condition among the cells labelled as OOD.
# The original cell referenced `adata`, which is never defined in this
# notebook; `adata_preds` is the loaded object whose `obs` contains both the
# "split" and the "condition" column.
# NOTE(review): assumes "ood" is a value of `adata_preds.obs["split"]` — confirm.
adata_preds[adata_preds.obs["split"] == "ood"].obs["condition"].value_counts()

In [None]:
# Sanity check: count the `split` labels present in the training AnnData.
adata_train.obs["split"].value_counts()

In [None]:
# Sanity check: count the `split` labels present in the test AnnData.
adata_test.obs["split"].value_counts()

In [None]:
# Sanity check: count the `split` labels present in the OOD AnnData.
adata_ood.obs["split"].value_counts()