In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

In [2]:
import cfp
import scanpy as sc
import anndata as ad
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
split = 5
model_name = "wild-snowball-463"

In [4]:
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_{split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_{split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_{split}.h5ad"
adata_train = sc.read_h5ad(adata_train_path)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)
    

In [5]:
adata_ood_ctrl = adata_ood[adata_ood.obs["condition"].str.contains("Vehicle")]
adata_test_ctrl = adata_test[adata_test.obs["condition"].str.contains("Vehicle")]
adata_ood_ctrl.obs["control"] = True
adata_test_ctrl.obs["control"] = True
covariate_data_ood = adata_ood[~adata_ood.obs["condition"].str.contains("Vehicle")].obs.drop_duplicates(subset=["condition"])
covariate_data_test = adata_test[~adata_test.obs["condition"].str.contains("Vehicle")].obs.drop_duplicates(subset=["condition"])

cf = cfp.model.CellFlow.load(f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/out/{model_name}_CellFlow.pkl")

preds_ood = cf.predict(adata=adata_ood_ctrl, sample_rep="X_pca", condition_id_key="condition", covariate_data=covariate_data_ood)


  adata_ood_ctrl.obs["control"] = True
  adata_test_ctrl.obs["control"] = True
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 27/27 [00:00<00:00, 122.49it/s]
100%|██████████| 32/32 [00:00<00:00, 691.80it/s]
100%|██████████| 24/24 [00:00<00:00, 755.02it/s]


In [9]:
import pandas as pd
import cfp.preprocessing as cfpp
all_data = []
conditions = []

for condition, array in preds_ood.items():
    all_data.append(array)
    conditions.extend([condition] * array.shape[0])

# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# Create a DataFrame for the .obs attribute
obs_data = pd.DataFrame({
    'condition': conditions
})

# Create the Anndata object
adata_ood_result = ad.AnnData(X=np.empty((len(all_data_array),2001)), obs=obs_data)
adata_ood_result.obsm["X_pca_pred"] = all_data_array



In [10]:
cfpp.reconstruct_pca(query_adata=adata_ood_result, use_rep="X_pca_pred", ref_adata=adata_train, layers_key_added="X_recon_pred")


In [11]:
adata_ref_ood = adata_ood[~adata_ood.obs["condition"].str.contains('Vehicle')].copy()
cfpp.centered_pca(adata_ref_ood, n_comps=10)


In [12]:
adata_pred_ood = adata_ood_result

In [13]:
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_ref_ood)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_ref_ood)
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    ood_data_target_encoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].obsm["X_pca"]
    ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"] == cond].X.toarray()
    ood_data_target_decoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].layers["X_recon_pred"]
    ood_data_target_encoded_predicted[cond] = adata_pred_ood[adata_pred_ood.obs["condition"] == cond].obsm["X_pca"]


In [14]:
import functools
import jax
import numpy as np
import scanpy as sc
import cfp.preprocessing as cfpp
from cfp.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
import os
import pandas as pd

In [15]:
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics_fast, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")


In [16]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.5348661258608282,
 'decoded_ood_e_distance': 25.234276212413874,
 'decoded_ood_mmd_distance': 0.04525858489235481}

In [17]:
mean_ood_metrics_encoded

{'encoded_ood_r_squared': -1.360208534227988,
 'encoded_ood_e_distance': 60.33726974321534,
 'encoded_ood_mmd_distance': 0.8656123043542885}

In [18]:
ood_metrics_encoded

{'A549_Alvespimycin_(17-DMAG)_HCl_10.0': {'r_squared': -0.8475600246691433,
  'e_distance': 74.82401513408907,
  'mmd_distance': 0.92467564},
 'A549_Alvespimycin_(17-DMAG)_HCl_100.0': {'r_squared': -1.1215174490755282,
  'e_distance': 89.91276421259919,
  'mmd_distance': 0.95831186},
 'A549_Belinostat_(PXD101)_10.0': {'r_squared': -1.0220381143122825,
  'e_distance': 80.89894198496127,
  'mmd_distance': 0.9173282},
 'A549_Belinostat_(PXD101)_100.0': {'r_squared': -1.1169971670162018,
  'e_distance': 81.91237946926333,
  'mmd_distance': 0.9222993},
 'A549_Belinostat_(PXD101)_1000.0': {'r_squared': -1.4577155690418548,
  'e_distance': 95.84725769201447,
  'mmd_distance': 0.95729846},
 'A549_Dacinostat_(LAQ824)_10.0': {'r_squared': -1.4612742828219591,
  'e_distance': 89.94282587157302,
  'mmd_distance': 0.93888617},
 'A549_Dacinostat_(LAQ824)_100.0': {'r_squared': -1.769159696517034,
  'e_distance': 72.33405853259143,
  'mmd_distance': 0.8933489},
 'A549_Dacinostat_(LAQ824)_1000.0': {'r_