In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import cfp.preprocessing as cfpp
from cfp.metrics import compute_metrics, compute_metrics_fast



In [3]:
def get_mask(x, y, var_names=None):
    """Keep only the columns of ``x`` whose gene name appears in ``y``.

    Parameters
    ----------
    x
        2D array-like that supports ``x[:, mask]`` indexing; column ``j`` is
        taken to correspond to ``var_names[j]``.
    y
        Iterable of (hashable) gene names to keep.
    var_names
        Gene names labelling the columns of ``x``. When omitted, the
        notebook-level ``adata_train.var_names`` is used, so ``adata_train``
        must already be loaded in that case.

    Returns
    -------
    ``x`` restricted to the selected columns, in their original order.
    """
    if var_names is None:
        # Fall back to the global training AnnData (original behaviour).
        var_names = adata_train.var_names
    # A set gives O(1) membership tests instead of scanning ``y`` per gene.
    wanted = set(y)
    keep = np.fromiter((gene in wanted for gene in var_names), dtype=bool)
    return x[:, keep]

In [4]:
# Identifier of the data split; it is interpolated into the file names of the
# AnnData inputs read below and of the results pickle written at the end.
split = "3"

In [5]:
# Directory holding the per-split train / test / ood AnnData files.
# Kept in one place so the location only has to be changed once.
RESULTS_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/results"

adata_train = sc.read_h5ad(os.path.join(RESULTS_DIR, f"adata_train_{split}.h5ad"))
adata_test = sc.read_h5ad(os.path.join(RESULTS_DIR, f"adata_test_{split}.h5ad"))
adata_ood = sc.read_h5ad(os.path.join(RESULTS_DIR, f"adata_ood_{split}.h5ad"))

In [6]:
# Directory containing the predicted OOD AnnData for each split.
# NOTE: OUT_DIR is rebound further down to the location of the results pickle.
OUT_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/results_debug_biolord"
# Build the path from OUT_DIR instead of repeating the directory literal.
adata_pred_ood = sc.read_h5ad(os.path.join(OUT_DIR, f"biolord2_adata_pred_ood_{split}.h5ad"))

  utils.warn_names_duplicates("obs")


In [7]:
# Concatenate train, test and ood cells into one AnnData and run the PCA
# (10 components) on that combined object; it serves as the reference for the
# projections performed afterwards.
adata_all = ad.concat([adata_train, adata_test, adata_ood])
cfpp.centered_pca(adata_all, n_comps=10)

#### Collect targets and predictions on the OOD set (per condition, PCA space and gene space)

In [8]:
def _to_dense(matrix):
    """Return ``matrix`` as a plain ``np.ndarray``, densifying sparse input."""
    if hasattr(matrix, "todense"):
        return np.asarray(matrix.todense())
    return np.asarray(matrix)


# Project predicted and observed OOD cells with the PCA of ``adata_all``
# (presumably this fills ``.obsm["X_pca"]``, which is read below).
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

# One entry per non-control condition: observed ("target") and predicted cells.
ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

for cond in adata_ood.obs["condition"].cat.categories:
    if cond == "ctrl":
        continue

    # Slice each AnnData once per condition and reuse the subset.
    target_cells = adata_ood[adata_ood.obs["condition"] == cond]
    predicted_cells = adata_pred_ood[adata_pred_ood.obs["condition"] == cond]

    # pca space
    ood_data_target_encoded[cond] = target_cells.obsm["X_pca"]
    ood_data_target_encoded_predicted[cond] = predicted_cells.obsm["X_pca"]

    # gene space: treat observed and predicted expression the same way so both
    # end up as dense arrays regardless of how ``.X`` is stored.
    ood_data_target_decoded[cond] = _to_dense(target_cells.X)
    ood_data_target_decoded_predicted[cond] = _to_dense(predicted_cells.X)

#### Evaluation on ood set

In [9]:
# Every result below is a dict keyed by OOD condition; jax.tree_util.tree_map
# applies the given function leaf-wise to the matching entries of the dicts.

print("Computing ood_metrics_encoded")
# ood set: evaluation in encoded (=pca) space
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)

print("Computing ood_metrics_decoded")
# ood set: evaluation in decoded (=gene) space.
# compute_metrics_fast is used here instead of compute_metrics
# (assumed to be the cheaper variant for the full gene dimension — TODO confirm).
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)

# ood set: DEG entries stored on the training data, restricted to OOD conditions
ood_deg_dict = {
    k: v
    for k, v in adata_train.uns['rank_genes_groups_cov_all'].items()
    if k in ood_data_target_decoded_predicted.keys()
}

# tree_map needs an entry for every condition; fail with a readable message
# instead of a pytree-structure mismatch if some condition has no DEG entry.
missing_deg = sorted(set(ood_data_target_decoded_predicted) - set(ood_deg_dict))
if missing_deg:
    raise KeyError(
        "No entry in adata_train.uns['rank_genes_groups_cov_all'] "
        f"for OOD conditions: {missing_deg}"
    )

print("Apply DEG mask")
# ood set: keep only the DEG columns of predicted and observed expression
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded_predicted,
    ood_deg_dict,
)

ood_deg_target_decoded = jax.tree_util.tree_map(
    get_mask,
    ood_data_target_decoded,
    ood_deg_dict,
)

print("Compute metrics on DEG subsetted decoded")
deg_ood_metrics = jax.tree_util.tree_map(
    compute_metrics,
    ood_deg_target_decoded,
    ood_deg_target_decoded_predicted,
)

Computing ood_metrics_encoded
Computing ood_metrics_decoded
Apply DEG mask
Compute metrics on DEG subsetted decoded


In [10]:
# Bundle everything produced by the evaluation cell into one dict so it can be
# pickled in a single call. Each value is itself a dict keyed by OOD condition.
collected_results = {
    # metrics computed on the PCA embeddings
    "ood_metrics_encoded": ood_metrics_encoded,
    # metrics computed on the full gene-space matrices
    "ood_metrics_decoded": ood_metrics_decoded,
    # metrics computed on the DEG-masked gene-space matrices
    "deg_ood_metrics": deg_ood_metrics,
    # DEG-masked matrices themselves (predicted / observed)
    "ood_deg_target_decoded_predicted": ood_deg_target_decoded_predicted,
    "ood_deg_target_decoded": ood_deg_target_decoded,
    # DEG entry used for the masking of each condition
    "ood_deg_dict": ood_deg_dict,
}

In [11]:
# Destination directory for the pickled results; created on demand.
OUT_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/biolord_debug"
os.makedirs(OUT_DIR, exist_ok=True)

# File name carries the split identifier so runs for different splits coexist.
results_filename = f"biolord_split_{split}_collected_results_new.pkl"
out_file = os.path.join(OUT_DIR, results_filename)

pd.to_pickle(collected_results, out_file)
print(f"Saving results at: {out_file}")

Saving results at: /lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/biolord_debug/biolord_split_3_collected_results_new.pkl


In [12]:
# Display the collected results. NOTE(review): this renders the entire nested
# dict, including the masked matrices, so the output is very long.
collected_results

{'ood_metrics_encoded': {'AHR+FEV': {'r_squared': 0.6878296136856079,
   'sinkhorn_div_1': 42.86714172363281,
   'sinkhorn_div_10': 34.884613037109375,
   'sinkhorn_div_100': 15.697677612304688,
   'e_distance': np.float64(21.99960783375727),
   'mmd': np.float32(0.08559219)},
  'AHR+KLF1': {'r_squared': 0.1788308024406433,
   'sinkhorn_div_1': 22.348989486694336,
   'sinkhorn_div_10': 16.529306411743164,
   'sinkhorn_div_100': 9.04678726196289,
   'e_distance': np.float64(15.663890570870915),
   'mmd': np.float32(0.070343405)},
  'AHR+ctrl': {'r_squared': -0.4028204679489136,
   'sinkhorn_div_1': 20.3182430267334,
   'sinkhorn_div_10': 13.898359298706055,
   'sinkhorn_div_100': 4.0970611572265625,
   'e_distance': np.float64(4.757158108540384),
   'mmd': np.float32(0.0422702)},
  'ARID1A+ctrl': {'r_squared': -0.03427541255950928,
   'sinkhorn_div_1': 26.886159896850586,
   'sinkhorn_div_10': 20.206266403198242,
   'sinkhorn_div_100': 9.95840835571289,
   'e_distance': np.float64(14.53

In [14]:
# R² of every OOD condition, taken from the DEG-restricted gene-space metrics.
r2s = [
    condition_metrics["r_squared"]
    for condition_metrics in collected_results["deg_ood_metrics"].values()
]

In [16]:
# Mean and median of the DEG-space R² across OOD conditions. In the recorded
# run the mean is strongly negative while the median is about 0.95, i.e. the
# mean is dominated by a few conditions with very negative R²; the median is
# the more representative summary here.
np.mean(r2s), np.median(r2s)

(np.float64(-31.360833547136806), np.float64(0.9541035294532776))