In [33]:
import os
import cfp
import scanpy as sc
import numpy as np
import optax
import hydra 
from omegaconf import DictConfig, OmegaConf
from functools import partial
from cfp.training import Metrics, PCADecodedMetrics
from cfp.metrics import compute_metrics, compute_metrics_fast
from sklearn.metrics import r2_score
import datetime
import yaml
import jax.numpy as jnp
import torch
from utils import Config, get_highest_checkpoint_file, reconstruct_data_fn
from cfp.data._dataloader import ValidationSampler
import pandas as pd
import jax.tree_util as jtu
import cfp.preprocessing as cfpp

import seaborn as sns
import anndata as ad
import matplotlib.pyplot as plt
from pathlib import Path 

## Load the data of split 0 

In [10]:
# Directory that holds the split files; the train / test / OOD .h5ad paths
# used in this notebook are built relative to it.
path_to_splits = Path("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_genes/adata_ood_final_genesIFNG_IFNB_TNFA_TGFB_INS_hvg-500_pca-100_counts_ms_0.5")

In [11]:
# Build the three split-0 file paths (train, test, OOD) under path_to_splits.
adata_train_path, adata_test_path, adata_ood_path = (
    path_to_splits / file_name
    for file_name in (
        "adata_train_split_0.h5ad",
        "adata_test_split_0.h5ad",
        "adata_ood_split_0.h5ad",
    )
)

# Read each file into an AnnData object, in the order train -> test -> OOD.
adata_train, adata_test, adata_ood = (
    sc.read_h5ad(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)

Collect the control cells and one covariate row per perturbed condition

In [12]:
# Control cells of the OOD and test splits; these are the AnnData objects
# handed to cf.predict.
# .copy() turns the AnnData views into standalone objects: the stored output
# of the predict cells shows an assignment to adata.obs inside predict raising
# pandas' "set on a copy of a slice" warning when a view is passed in.
adata_ood_ctrl = adata_ood[adata_ood.obs["control"]].copy()
adata_test_ctrl = adata_test[adata_test.obs["control"]].copy()

# One obs row per distinct "perturbation_condition" among the non-control
# cells. The trailing .copy() detaches each frame from the sliced obs table
# it was derived from, because predict also adds a column to it
# (see the same warning output).
covariate_data_ood = (
    adata_ood[~adata_ood.obs["control"]]
    .obs.drop_duplicates(subset=["perturbation_condition"])
    .copy()
)
covariate_data_test = (
    adata_test[~adata_test.obs["control"]]
    .obs.drop_duplicates(subset=["perturbation_condition"])
    .copy()
)

Load the CellFlow model from its checkpoint

In [13]:
# Restore the CellFlow model from the split-0 checkpoint file.
# The path is a plain string literal: the original f-prefix had no placeholders.
# NOTE(review): the checkpoint is a .pkl file, presumably deserialized with
# pickle -- only load checkpoints from a trusted location; confirm.
cf = cfp.model.CellFlow.load("/lustre/groups/ml01/workspace/alessandro.palma/ot_pert/out/ckpt_split_0/satija_gene_ood_0_CellFlow.pkl")

In [14]:
# Run the model on the OOD control cells, using their "X_pca" representation.
# The result is consumed further down as a mapping condition -> array.
preds_ood = cf.predict(
    adata=adata_ood_ctrl,
    sample_rep="X_pca",
    condition_id_key="perturbation_condition",
    covariate_data=covariate_data_ood,
)

  adata.obs[self._control_key] = adata.obs[self._control_key].astype(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index






























2024-12-20 10:50:21.607355: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: : Root element must be a message.
Fall back to parse the raw backend config str.
2024-12-20 10:50:21.616085: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: : Root element must be a message.
Fall back to parse the raw backend config str.
2024-12-20 10:50:30.153043: W external/xla/xla/service/gpu/ir_emitt

In [15]:
# Keep this disabled: adata_train is still used further down as the ref_adata
# argument of cfpp.reconstruct_pca, so it must stay in memory.
# del adata_train

## Assemble and save

In [16]:
# Fresh, independent accumulators: prediction arrays and their per-row condition labels.
all_data, conditions = [], []

In [17]:
# Reset the accumulators inside this cell so that re-running it does not
# append the same predictions a second time (the cell is now idempotent).
all_data, conditions = [], []

# Collect every predicted array and repeat its condition label once per row,
# so that `conditions` lines up with the rows of the stacked predictions.
for condition, array in preds_ood.items():
    all_data.append(array)
    conditions.extend([condition] * array.shape[0])

In [18]:
# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# Create a DataFrame for the .obs attribute
obs_data = pd.DataFrame({
    'perturbation_condition': conditions
})

**Predict on OOD**

In [19]:
# Create the Anndata object
adata_ood_result = ad.AnnData(X=np.empty((len(all_data_array), 8265)), obs=obs_data)
adata_ood_result.obsm["X_pca_pred"] = all_data_array



In [20]:
# Display the summary repr of the assembled OOD result object.
adata_ood_result

AnnData object with n_obs × n_vars = 19300 × 8265
    obs: 'perturbation_condition'
    obsm: 'X_pca_pred'

In [21]:
# Reconstruct from the predicted PCA representation ("X_pca_pred"), using
# adata_train as reference; the result is added under the key "X_recon_pred".
cfpp.reconstruct_pca(
    query_adata=adata_ood_result,
    use_rep="X_pca_pred",
    ref_adata=adata_train,
    layers_key_added="X_recon_pred",
)

In [22]:
# Persist the OOD predictions for split 0 to the generated_data results folder.
adata_ood_result.write("/lustre/groups/ml01/workspace/alessandro.palma/ot_pert/out/results_metrics/generated_data/adata_ood_split_0.h5ad")

**Predict on test**

In [None]:
# Run the model on the test control cells, using their "X_pca" representation.
# The result is consumed further down as a mapping condition -> array.
preds_test = cf.predict(
    adata=adata_test_ctrl,
    sample_rep="X_pca",
    condition_id_key="perturbation_condition",
    covariate_data=covariate_data_test,
)

  adata.obs[self._control_key] = adata.obs[self._control_key].astype(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index






























2024-12-20 11:13:08.203331: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: : Root element must be a message.
Fall back to parse the raw backend config str.
2024-12-20 11:13:08.203448: W external/xla/xla/service/gpu/ir_emitter_unnested.cc:1171] Unable to parse backend config for custom call: Could not convert JSON string to proto: : Root element must be a message.
Fall back to parse the raw backend config str.
2024-12-20 11:13:14.902107: W external/xla/xla/service/gpu/ir_emitt

In [30]:
# Start from empty accumulators, then gather each predicted array together
# with one copy of its condition label per row.
all_data, conditions = [], []

for condition, array in preds_test.items():
    all_data += [array]
    conditions += [condition] * array.shape[0]

In [31]:
# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# Create a DataFrame for the .obs attribute
obs_data = pd.DataFrame({
    'condition': conditions
})

In [None]:
# Create the Anndata object
adata_test_result = ad.AnnData(X=np.empty((len(all_data_array), 8265)), obs=obs_data)
adata_test_result.obsm["X_pca_pred"] = all_data_array

In [None]:
# Reconstruct from the predicted PCA representation ("X_pca_pred"), using
# adata_train as reference; the result is added under the key "X_recon_pred".
cfpp.reconstruct_pca(
    query_adata=adata_test_result,
    use_rep="X_pca_pred",
    ref_adata=adata_train,
    layers_key_added="X_recon_pred",
)

In [None]:
# Persist the test predictions for split 0 to the generated_data results folder.
adata_test_result.write("/lustre/groups/ml01/workspace/alessandro.palma/ot_pert/out/results_metrics/generated_data/adata_test_split_0.h5ad")