In [1]:
import functools
import jax
import numpy as np
import scanpy as sc
import cfp.preprocessing as cfpp
from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
import os
import pandas as pd
import sys
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Index of the data split evaluated in this notebook; it is interpolated into
# file names further down.
split = 0 

In [3]:
# Directory containing the per-split train / test / OOD `.h5ad` files read below.
path_to_splits = Path("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genes/adata_ood_final_genesIFNG_IFNB_TNFA_TGFB_INS_hvg-500_pca-100_counts_ms_0.5")
# NOTE(review): not referenced by any cell visible here — presumably the location of the
# generated (predicted) AnnData objects; confirm and either use it or drop it.
path_to_generated = Path("/lustre/groups/ml01/workspace/alessandro.palma/ot_pert/out/results_metrics")

Load the real data 

In [4]:
# Derive the file names from `split` so that changing it in the configuration cell
# switches the inputs together with the output file names (previously the inputs
# were pinned to split 0 regardless of `split`).
adata_train_path = path_to_splits / f"adata_train_split_{split}.h5ad"
adata_test_path = path_to_splits / f"adata_test_split_{split}.h5ad"
adata_ood_path = path_to_splits / f"adata_ood_split_{split}.h5ad"

# Read the train, test and out-of-distribution AnnData objects of the selected split.
adata_train = sc.read_h5ad(adata_train_path)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)

In [5]:
# Sanity check: number of variable (gene) names shared by the train and test splits.
np.intersect1d(adata_train.var.index, adata_test.var.index).size

8265

Load the generated data 

Read the whole dataset (used as the reference for PCA)

In [61]:
# Full dataset; it is the reference object passed to the PCA calls below.
adata_ref = sc.read_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/full_adata_with_splits.h5ad")

In [71]:
# Variable (gene) names present in both the full reference data and the train split.
np.intersect1d(
    adata_ref.var.index.unique(),
    adata_train.var.index.unique(),
)

array(['A2M', 'A2M-AS1', 'A4GNT', ..., 'ZPLD1', 'ZSCAN12P1', 'ZXDB'],
      dtype=object)

In [66]:
# Distinct variable (gene) names of the train split; `var_names` is AnnData's alias
# for `var.index`.
adata_train.var_names.unique()

Index(['A2M', 'A2M-AS1', 'A4GNT', 'AADACL2', 'AADAT', 'AAK1', 'AANAT', 'AARD',
       'AATBC', 'AB015752.1',
       ...
       'ZNRD1', 'ZNRF3-AS1', 'ZP2', 'ZPBP', 'ZPLD1', 'ZSCAN10', 'ZSCAN12P1',
       'ZSCAN5A', 'ZSCAN5DP', 'ZXDB'],
      dtype='object', length=8265)

Compute centered PCA on `adata_ref`

In [None]:
# Run `cfpp.centered_pca` on the full reference data with 20 components.
# The return value is discarded, so the result is presumably stored on `adata_ref`
# itself — TODO confirm against the cfp documentation.
cfpp.centered_pca(adata_ref, n_comps=20)

Projected PCA 

In [None]:
# Project the predicted and the observed test cells with `cfpp.project_pca`, using
# `adata_ref` as the reference. Later cells read `.obsm["X_pca"]` from both objects,
# so the projection is presumably written there — TODO confirm.
# NOTE(review): `adata_pred_test` is not defined in any cell visible here; on a fresh
# kernel this raises NameError unless the predictions are loaded first.
cfpp.project_pca(query_adata=adata_pred_test, ref_adata=adata_ref)
cfpp.project_pca(query_adata=adata_test, ref_adata=adata_ref)

In [None]:
# Per-condition arrays for the test split: observed vs. predicted cells, each taken
# from `.obsm["X_pca"]` ("encoded") and from the expression matrix / reconstruction
# layer ("decoded").
test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}

for cond in adata_test.obs["condition"].cat.categories:
    # Conditions whose name contains "Vehicle" are left out of all four dictionaries.
    if "Vehicle" not in cond:
        # Subset the observed cells of this condition once and read both views from it.
        observed_cells = adata_test[adata_test.obs["condition"] == cond]
        test_data_target_encoded[cond] = observed_cells.obsm["X_pca"]
        # `.toarray()` densifies `.X`, which is therefore expected to be sparse.
        test_data_target_decoded[cond] = observed_cells.X.toarray()

        # Same selection on the predictions.
        predicted_cells = adata_pred_test[adata_pred_test.obs["condition"] == cond]
        test_data_target_decoded_predicted[cond] = predicted_cells.layers["X_recon_pred"]
        test_data_target_encoded_predicted[cond] = predicted_cells.obsm["X_pca"]

In [None]:
# Project the predicted and the observed OOD cells with `adata_ref` as the reference.
# NOTE(review): `adata_pred_ood` is not defined in any cell visible here — make sure
# the predictions are loaded before this cell runs.
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_ref)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref)

# Per-condition arrays for the OOD split: observed vs. predicted cells, each taken
# from `.obsm["X_pca"]` ("encoded") and from the expression matrix / reconstruction
# layer ("decoded").
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}

for cond in adata_ood.obs["condition"].cat.categories:
    # Conditions whose name contains "Vehicle" are left out of all four dictionaries.
    if "Vehicle" not in cond:
        # Observed cells of this condition, selected once.
        ood_observed = adata_ood[adata_ood.obs["condition"] == cond]
        ood_data_target_encoded[cond] = ood_observed.obsm["X_pca"]
        # `.toarray()` densifies `.X`, which is therefore expected to be sparse.
        ood_data_target_decoded[cond] = ood_observed.X.toarray()

        # Predicted cells of the same condition.
        ood_predicted = adata_pred_ood[adata_pred_ood.obs["condition"] == cond]
        ood_data_target_decoded_predicted[cond] = ood_predicted.layers["X_recon_pred"]
        ood_data_target_encoded_predicted[cond] = ood_predicted.obsm["X_pca"]

## Collect the differentially expressed genes used for evaluation

In [None]:
# `adata_train.uns["rank_genes_groups_cov_all"]` maps a condition name to that
# condition's collection of differentially expressed genes. Keep only the
# conditions for which predictions were collected.
test_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in test_data_target_decoded_predicted
}

# Same filter for the out-of-distribution conditions. The closing brace was missing
# here, which made the whole cell a SyntaxError.
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()
    if k in ood_data_target_decoded_predicted
}

In [None]:
def get_mask(x, y):
    """Select the columns of ``x`` whose gene name is contained in ``y``.

    Columns are matched positionally against ``adata_train.var_names`` (a notebook
    global), so ``x`` is expected to have one column per entry of that index.

    Args:
        x: Two-dimensional array indexed as ``x[:, mask]``.
        y: Container of gene names, tested with ``gene in y``.

    Returns:
        ``x`` restricted to the columns whose gene name is in ``y``.
    """
    keep_column = []
    for gene in adata_train.var_names:
        keep_column.append(gene in y)
    return x[:, keep_column]

In [None]:
# NOTE(review): the inputs are read from a "satija" dataset directory while the results
# go to a "sciplex" directory — confirm this destination is intended.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/results/otfm/pca_mean_pooling"
# `DataFrame.to_csv` does not create missing directories. Create the target before the
# expensive metric computation so a missing folder cannot fail the export afterwards.
os.makedirs(output_dir, exist_ok=True)

# Restrict the decoded OOD arrays (predicted and observed) to each condition's
# differentially expressed genes.
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

# Per-condition metrics, observed vs. predicted: on the PCA embeddings, on the decoded
# arrays (using `compute_metrics_fast`), and on the DEG-restricted decoded arrays.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, ood_data_target_decoded, ood_data_target_decoded_predicted
)
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)

# One CSV per metric set, tagged with the split index.
pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, f"ood_metrics_encoded_{split}.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, f"ood_metrics_decoded_{split}.csv"))
# NOTE(review): the DEG metrics are written to "ood_metrics_ood_<split>.csv"; the name
# may have been meant to contain "deg". It is left unchanged in case downstream code
# relies on it.
pd.DataFrame.from_dict(deg_ood_metrics).to_csv(os.path.join(output_dir, f"ood_metrics_ood_{split}.csv"))