In [None]:
import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc

import cfp
from cfp.external import CFJaxSCVI

In [None]:
# Identifier of the trained CellFlow run; interpolated into the checkpoint file name.
model_name = "vital-dragon-422"
# Index of the data split; interpolated into the AnnData input and output file names.
split = 5

In [None]:
# Locations of the train / test / out-of-distribution AnnData files for the chosen split.
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_{split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_{split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_{split}.h5ad"

# Read the three files in the same order as before: train, test, OOD.
adata_train, adata_test, adata_ood = (
    sc.read_h5ad(path)
    for path in (adata_train_path, adata_test_path, adata_ood_path)
)
    

In [None]:
# Load the saved VAE from disk, handing it the training AnnData.
vae = CFJaxSCVI.load(
    "/lustre/groups/ml01/workspace/ot_perturbation/models/vaes/sciplex/32_1024",
    adata=adata_train,
)

# Embed every split with the same VAE and store the result under obsm["X_scVI"].
for _adata in (adata_train, adata_test, adata_ood):
    _adata.obsm["X_scVI"] = vae.get_latent_representation(_adata)


In [None]:
# Control cells of the OOD and test splits.
# Subsetting an AnnData yields a view; assigning a new obs column to a view makes
# anndata silently turn it into a copy and emit a warning. Taking an explicit
# `.copy()` makes that step deliberate and keeps the parent objects untouched.
adata_ood_ctrl = adata_ood[adata_ood.obs["condition"] == "control"].copy()
adata_test_ctrl = adata_test[adata_test.obs["condition"] == "control"].copy()

# Flag every cell of the control subsets in a boolean "control" column.
adata_ood_ctrl.obs["control"] = True
adata_test_ctrl.obs["control"] = True

# One obs row per distinct non-control condition (first occurrence is kept);
# these frames are later passed to `cf.predict` as `covariate_data`.
covariate_data_ood = (
    adata_ood[adata_ood.obs["condition"] != "control"]
    .obs.drop_duplicates(subset=["condition"])
)
covariate_data_test = (
    adata_test[adata_test.obs["condition"] != "control"]
    .obs.drop_duplicates(subset=["condition"])
)

In [None]:
# Checkpoint of the trained CellFlow run selected by `model_name`.
# NOTE(review): the .pkl extension suggests pickle deserialisation - only load
# checkpoints from a trusted location.
_cellflow_ckpt = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/out/{model_name}_CellFlow.pkl"
cf = cfp.model.CellFlow.load(_cellflow_ckpt)

In [None]:
# Display the OOD control subset as a quick check before it is used for prediction.
adata_ood_ctrl

In [None]:
# Predict from the OOD control cells (embedded in obsm["X_scVI"]) for the
# conditions described by `covariate_data_ood`. The result is consumed further
# down as a mapping from condition to predicted array.
preds_ood = cf.predict(
    adata=adata_ood_ctrl,
    sample_rep="X_scVI",
    condition_id_key="condition",
    covariate_data=covariate_data_ood,
)

In [None]:
# Same prediction as for the OOD split, but starting from the test-split control
# cells and the test-split conditions.
preds_test = cf.predict(
    adata=adata_test_ctrl,
    sample_rep="X_scVI",
    condition_id_key="condition",
    covariate_data=covariate_data_test,
)

In [None]:
# Flatten the per-condition OOD predictions into one array plus one condition
# label per predicted row.
all_data = []
conditions = []

for condition, array in preds_ood.items():
    all_data.append(array)
    conditions.extend([condition] * array.shape[0])

# Stack all predictions vertically; row order matches `conditions`.
all_data_array = np.vstack(all_data)

# Per-row annotation for the .obs attribute.
obs_data = pd.DataFrame({"condition": conditions})
# AnnData requires string obs names; converting here avoids its implicit
# conversion (and the accompanying warning) of the default integer index.
obs_data.index = obs_data.index.astype(str)

# X is only a placeholder so that layers of shape (n_cells, n_genes) can be
# attached later. It is zero-filled rather than `np.empty`, because the object
# is written to disk and uninitialised memory would end up in the file.
# The gene dimension is taken from the training data instead of a hard-coded
# 2000. NOTE(review): assumes the VAE decodes to adata_train.n_vars genes,
# since it was loaded with adata=adata_train - confirm.
adata_ood_result = ad.AnnData(
    X=np.zeros((all_data_array.shape[0], adata_train.n_vars)),
    obs=obs_data,
)
adata_ood_result.obsm["X_scVI_pred"] = all_data_array

In [None]:
# Run the predicted latent vectors in obsm["X_scVI_pred"] through the VAE's
# expression reconstruction and keep the output as a layer.
_ood_reconstruction = vae.get_reconstructed_expression(adata_ood_result, key="X_scVI_pred")
adata_ood_result.layers["pred_reconstruction"] = _ood_reconstruction

In [13]:
# Persist the OOD predictions (latent vectors and reconstructed layer) for this split.
_ood_out_path = f"/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/sciplex/adata_ood_with_predictions_{split}.h5ad"
adata_ood_result.write(_ood_out_path)


In [13]:
# Flatten the per-condition test-split predictions into one array plus one
# condition label per predicted row.
all_data = []
conditions = []

for condition, array in preds_test.items():
    all_data.append(array)
    conditions.extend([condition] * array.shape[0])

# Stack all predictions vertically; row order matches `conditions`.
all_data_array = np.vstack(all_data)

# Per-row annotation for the .obs attribute.
obs_data = pd.DataFrame({"condition": conditions})
# AnnData requires string obs names; converting here avoids its implicit
# conversion (and the accompanying warning) of the default integer index.
obs_data.index = obs_data.index.astype(str)

# X is only a placeholder so that layers of shape (n_cells, n_genes) can be
# attached later. It is zero-filled rather than `np.empty`, because the object
# is written to disk and uninitialised memory would end up in the file.
# The gene dimension is taken from the training data instead of a hard-coded
# 2000. NOTE(review): assumes reconstructions have adata_train.n_vars genes - confirm.
adata_test_result = ad.AnnData(
    X=np.zeros((all_data_array.shape[0], adata_train.n_vars)),
    obs=obs_data,
)
# NOTE(review): despite the key name, these rows are predictions in the X_scVI
# space (preds_test was produced with sample_rep="X_scVI"). The key is kept
# unchanged because it is read back by name when the predictions are decoded.
adata_test_result.obsm["X_pca_pred"] = all_data_array



In [14]:
# The test predictions were generated with sample_rep="X_scVI", i.e. in the VAE
# latent space, so they are decoded with the VAE exactly like the OOD
# predictions. A PCA back-projection does not apply to these vectors, and
# `cfpp` is never imported in this notebook.
# The obsm key and the output layer name are kept as they were so that
# consumers of the saved file keep finding the same entries.
adata_test_result.layers["X_recon_pred"] = vae.get_reconstructed_expression(
    adata_test_result, key="X_pca_pred"
)

In [15]:
# Persist the test-split predictions for this split.
# Every input of this notebook (data, VAE, CellFlow checkpoint) is the sciplex
# dataset, so the output belongs in the sciplex directory; writing into
# "combosciplex" would mix it with, or overwrite, another dataset's results.
# `adata_ood_result` is not written again here: it is already saved to the
# sciplex directory right after its reconstruction layer is computed, and it is
# not modified afterwards.
adata_test_result.write(
    f"/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/sciplex/adata_test_with_predictions_{split}.h5ad"
)