In [48]:
from pathlib import Path

import anndata as ad
import numpy as np
import torch

from chemCPA.data import PerturbationDataModule, load_dataset_splits
from chemCPA.lightning_module import ChemCPA

In [49]:
ckpt = "last.ckpt"
# orig
# run_id = "hnjtnxyk"
# split 0
# run_id = "2kxbiqm8"
# split = 1
# run_id = "sb2w11gg"
# split = 2
# run_id = "km6ft6aj"
# split = 3
# run_id = "6w5pv7js"
split = 4
run_id = "hgzlvf0t"
cp_path = Path("/lustre/groups/ml01/workspace/artur.szalata/code/ot_pert_reproducibility/competing_methods/sciplex/chemcpa/chemCPA/project_folder/checkpoints_hydra") / run_id / ckpt

In [50]:
module = ChemCPA.load_from_checkpoint(cp_path)

/home/icb/artur.szalata/miniconda3/envs/perturbation_models_biolord/lib/python3.9/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.2.4, which is newer than your current Lightning version: v2.1.4


In [51]:
data_params = module.config["dataset"]

In [52]:
datasets, dataset = load_dataset_splits(**data_params, return_dataset=True)

In [53]:
dm = PerturbationDataModule(datasplits=datasets, train_bs=module.config["model"]["hparams"]["batch_size"])
dm.setup(stage="fit")  # fit, validate/test, predict

In [54]:
module.model.eval();

In [55]:
ood_control_dataset = dm.ood_control_dataset
test_control_dataset = dm.test_control_dataset
ood_treated_dataset = dm.ood_treated_dataset
test_treated_dataset = dm.test_treated_dataset

In [68]:
def compute_preds(control_dataset, treated_dataset):
    control_genes = {}

    # Iterate over the dataset
    _genes = control_dataset.genes
    _cov_names = control_dataset.covariate_names["cell_type"]
    
    for covariate, gene in zip(_cov_names, _genes):
        if covariate not in control_genes:
            control_genes[covariate] = gene.unsqueeze(0)
            continue
        control_genes[covariate] = torch.concat([control_genes[covariate], gene.unsqueeze(0)], dim=0)
    module.model.eval()
    module.model.to("cuda")
    
    preds = {}
    targs = {}
    
    for pert_cat, item in zip(dm.ood_treated_dataset.pert_categories, dm.ood_treated_dataset):
        if pert_cat not in preds:
            genes = item[0]
            drug_idx = item[1]
            dosages = item[2]
            covariates = item[4:]
            cl = pert_cat.split("_")[0]
            dose = pert_cat.split("_")[-1]
            drug = "_".join(pert_cat.split("_")[1:-1])
    
            genes = control_genes[cl]
            n_obs = len(control_genes[cl])
    
            # repeat torch tensor n_obs times
            drugs_idx = drug_idx.repeat(n_obs)
            dosages = dosages.repeat(n_obs)
            covariates = [cov.repeat(n_obs, 1) for cov in covariates]
            gene_reconstructions, cell_drug_embedding, latent_basal = module.model.predict(
                genes=genes,
                drugs=None,
                drugs_idx=drugs_idx,
                dosages=dosages,
                covariates=covariates,
                return_latent_basal=True,
            )
    
            dim = gene_reconstructions.size(1) // 2
            mean = gene_reconstructions[:, :dim]
            var = gene_reconstructions[:, dim:]
    
            preds[pert_cat] = mean.detach().cpu().numpy()
            targs[pert_cat] = (
                (treated_dataset.genes[treated_dataset.pert_categories == pert_cat]).clone().numpy()
            )
    predictions = []
    targets = []
    cl_p = []
    cl_t = []
    drug_p = []
    drug_t = []
    dose_p = []
    dose_t = []
    control = {}
    control_cl = {}
    for key, val in preds.items():
        cl = key.split("_")[0]
        drug = "_".join(key.split("_")[1:-1])
        dose = key.split("_")[-1]
    
        control[cl] = control_genes[cl].numpy()
        control_cl[cl] = control[cl].shape[0] * [cl]
    
        predictions.append(val)
        cl_p.extend(val.shape[0] * [cl])
        drug_p.extend(val.shape[0] * [drug])
        dose_p.extend(val.shape[0] * [float(dose)])
    
        targets.append(targs[key])
        cl_t.extend(targs[key].shape[0] * [cl])
        drug_t.extend(targs[key].shape[0] * [drug])
        dose_t.extend(targs[key].shape[0] * [float(dose)])
    
    adata_c = ad.AnnData(np.concatenate([control[cl] for cl in control], axis=0))
    adata_c.obs["cell_line"] = list(np.concatenate([control_cl[cl] for cl in control], axis=0))
    adata_c.obs["condition"] = "control"
    adata_c.obs["perturbation"] = "Vehicle"
    adata_c.obs["dose"] = 1.0
    
    adata_p = ad.AnnData(np.concatenate(predictions, axis=0))
    adata_p.obs["condition"] = "prediction"
    adata_p.obs["cell_line"] = cl_p
    adata_p.obs["perturbation"] = drug_p
    adata_p.obs["dose"] = dose_p
    
    
    adata_t = ad.AnnData(np.concatenate(targets, axis=0))
    adata_t.obs["condition"] = "target"
    adata_t.obs["cell_line"] = cl_t
    adata_t.obs["perturbation"] = drug_t
    adata_t.obs["dose"] = dose_t
    
    adata = ad.concat([adata_c, adata_p, adata_t])
    adata.obs_names_make_unique()
    adata.obs["pert_category"] = None
    
    for key in np.unique(ood_treated_dataset.pert_categories):
        cl = key.split("_")[0]
        drug = "_".join(key.split("_")[1:-1])
        dose = float(key.split("_")[-1])
    
        cond = adata.obs["cell_line"] == cl
        cond *= adata.obs["perturbation"] == drug
        cond *= adata.obs["dose"] == dose
        adata.obs.loc[cond, "pert_category"] = key
    return adata

In [69]:
ood_preds = compute_preds(ood_control_dataset, ood_treated_dataset)

ValueError: need at least one array to concatenate

In [None]:
ood_preds.write(f"/lustre/groups/ml01/workspace/artur.szalata/code/ot_pert_reproducibility/competing_methods/sciplex/chemcpa/chemCPA/project_folder/adata_biolord_split_{split}_300_pred_ood.h5ad")

In [70]:
test_preds = compute_preds(test_control_dataset, test_treated_dataset)

ValueError: need at least one array to concatenate

In [None]:
test_preds.write(f"/lustre/groups/ml01/workspace/artur.szalata/code/ot_pert_reproducibility/competing_methods/sciplex/chemcpa/chemCPA/project_folder/adata_biolord_split_{split}_300_pred_test.h5ad")

In [None]:
!ls /lustre/groups/ml01/workspace/artur.szalata/code/ot_pert_reproducibility/competing_methods/sciplex/chemcpa/chemCPA/project_folder