In [1]:
import pandas as pd
import numpy as np

import scanpy as sc
import biolord

import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from scipy.stats import ttest_rel
import anndata

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location and naming of the pre-split sciplex AnnData files (train / test / ood).
sciplex_data_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
split_suffix = "biolord_split_30"
adata_train_path = f"{sciplex_data_dir}/adata_train_{split_suffix}.h5ad"
adata_test_path = f"{sciplex_data_dir}/adata_test_{split_suffix}.h5ad"
adata_ood_path = f"{sciplex_data_dir}/adata_ood_{split_suffix}.h5ad"

In [3]:
# Load the three splits (read in the order train, test, ood).
adata_train, adata_test, adata_ood = (
    sc.read(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)



In [4]:
# Stack the three splits into one AnnData; the "split" obs column records which
# split each cell came from. NOTE(review): the cell output shows a duplicate
# obs-names warning, i.e. obs names are not unique across the splits.
adata = anndata.concat(
    [adata_train, adata_test, adata_ood],
    keys=["train", "test", "ood"],
    label="split",
)

  utils.warn_names_duplicates("obs")


In [5]:
frac_valid = 0.1
# Seeded generator so the train/valid split is reproducible across kernel
# restarts (drawing from the unseeded global np.random state is not).
split_rng = np.random.default_rng(0)


def create_split(x, rng=None):
    """Assign one cell to a training sub-split.

    Cells whose ``split`` is not ``"train"`` are labelled ``"other"``. Training
    cells are labelled ``"train_valid"`` with probability ``frac_valid`` and
    ``"train_train"`` otherwise.

    Parameters
    ----------
    x : row-like
        Anything indexable by ``"split"`` (e.g. a row of ``adata.obs``).
    rng : numpy.random.Generator, optional
        Source of randomness; defaults to the module-level seeded ``split_rng``.

    Returns
    -------
    str
        ``"other"``, ``"train_train"`` or ``"train_valid"``.
    """
    if x["split"] != "train":
        return "other"
    if rng is None:
        rng = split_rng
    # Draws 1 ("keep for training") with probability 1 - frac_valid.
    is_train = rng.choice(2, p=[frac_valid, 1 - frac_valid])
    return "train_train" if is_train else "train_valid"


# Label every cell row-by-row as train_train / train_valid / other.
split_labels = adata.obs.apply(create_split, axis=1)
adata.obs["new_split"] = split_labels

In [6]:
# Rescale the dose so that the largest dose in the dataset maps to 1.
dose_values = adata.obs["dose"].astype("float")
dose = dose_values / np.max(dose_values)

In [7]:
# Append the rescaled dose as one extra column after the ECFP features.
dose_column = dose.to_numpy()[:, np.newaxis]
adata.obsm["ecfp_dose"] = np.concatenate([adata.obsm["ecfp"], dose_column], axis=1)

In [8]:
# Register the attributes biolord conditions on: the ECFP+dose vector as an
# ordered attribute and the cell type as a categorical attribute.
setup_kwargs = {
    "ordered_attributes_keys": ["ecfp_dose"],
    "categorical_attributes_keys": ["cell_type"],
    "retrieval_attribute_key": None,
}
biolord.Biolord.setup_anndata(adata, **setup_kwargs)

[34mINFO    [0m Generating sequential column names                                                                        


In [9]:
# Architecture / loss hyper-parameters passed to the biolord module.
module_params = {
    "decoder_width": 4096,
    "decoder_depth": 4,
    "attribute_dropout_rate": 0.1,
    "attribute_nn_width": 2048,
    "attribute_nn_depth": 2,
    "unknown_attribute_noise_param": 2e+1,
    "unknown_attribute_penalty": 1e-1,
    "gene_likelihood": "normal",
    "n_latent_attribute_ordered": 256,
    "n_latent_attribute_categorical": 3,
    "reconstruction_penalty": 1e+4,
    "use_batch_norm": False,
    "use_layer_norm": False,
}

# Optimiser / scheduler settings forwarded to training as plan_kwargs.
trainer_params = {
    "latent_lr": 1e-4,
    "latent_wd": 1e-4,
    "decoder_lr": 1e-4,
    "decoder_wd": 1e-4,
    "attribute_nn_lr": 1e-2,
    "attribute_nn_wd": 4e-8,
    "cosine_scheduler": True,
    "scheduler_final_lr": 1e-5,
    "step_size_lr": 45,
}

In [10]:
# Build the model on the custom train_train / train_valid / other split.
model_kwargs = {
    "n_latent": 256,
    "model_name": "sciplex3",
    "module_params": module_params,
    "train_classifiers": False,
    "split_key": "new_split",
    "train_split": "train_train",
    "valid_split": "train_valid",
    "test_split": "other",
}
model = biolord.Biolord(adata=adata, **model_kwargs)

[rank: 0] Seed set to 0


In [None]:
# Train with early stopping; optimiser settings come from trainer_params.
train_kwargs = {
    "max_epochs": 200,
    "batch_size": 512,
    "early_stopping": True,
    "early_stopping_patience": 20,
    "check_val_every_n_epoch": 10,
    "num_workers": 10,
    "enable_checkpointing": False,
}
model.train(plan_kwargs=trainer_params, **train_kwargs)

/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  self.pid = os.fork()


Epoch 2/200:   0%|          | 1/200 [00:42<2:18:20, 41.71s/it, v_num=1, val_generative_mean_accuracy=0.0917, val_generative_var_accuracy=-86.6, val_biolord_metric=-43.3, val_LOSS_KEYS.RECONSTRUCTION=261, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=235]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3/200:   1%|          | 2/200 [01:13<1:57:56, 35.74s/it, v_num=1, val_generative_mean_accuracy=0.348, val_generative_var_accuracy=-9.56, val_biolord_metric=-4.6, val_LOSS_KEYS.RECONSTRUCTION=227, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=216, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=213, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4/200:   2%|▏         | 3/200 [01:43<1:49:29, 33.35s/it, v_num=1, val_generative_mean_accuracy=0.461, val_generative_var_accuracy=-4.4, val_biolord_metric=-1.97, val_LOSS_KEYS.RECONSTRUCTION=216, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=198, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=179, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5/200:   2%|▏         | 4/200 [02:14<1:45:09, 32.19s/it, v_num=1, val_generative_mean_accuracy=0.545, val_generative_var_accuracy=-1.81, val_biolord_metric=-0.632, val_LOSS_KEYS.RECONSTRUCTION=207, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=181, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=158, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6/200:   2%|▎         | 5/200 [02:44<1:42:45, 31.62s/it, v_num=1, val_generative_mean_accuracy=0.626, val_generative_var_accuracy=-0.492, val_biolord_metric=0.0668, val_LOSS_KEYS.RECONSTRUCTION=198, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=166, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7/200:   3%|▎         | 6/200 [03:15<1:40:55, 31.22s/it, v_num=1, val_generative_mean_accuracy=0.687, val_generative_var_accuracy=0.281, val_biolord_metric=0.484, val_LOSS_KEYS.RECONSTRUCTION=191, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=152, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]  

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8/200:   4%|▎         | 7/200 [03:45<1:39:44, 31.01s/it, v_num=1, val_generative_mean_accuracy=0.74, val_generative_var_accuracy=0.427, val_biolord_metric=0.583, val_LOSS_KEYS.RECONSTRUCTION=185, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=139, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9/200:   4%|▍         | 8/200 [04:16<1:39:08, 30.98s/it, v_num=1, val_generative_mean_accuracy=0.819, val_generative_var_accuracy=0.743, val_biolord_metric=0.781, val_LOSS_KEYS.RECONSTRUCTION=177, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=127, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 10/200:   4%|▍         | 9/200 [04:47<1:37:57, 30.77s/it, v_num=1, val_generative_mean_accuracy=0.862, val_generative_var_accuracy=0.771, val_biolord_metric=0.817, val_LOSS_KEYS.RECONSTRUCTION=172, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=116, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 11/200:   5%|▌         | 10/200 [05:17<1:37:16, 30.72s/it, v_num=1, val_generative_mean_accuracy=0.854, val_generative_var_accuracy=0.799, val_biolord_metric=0.827, val_LOSS_KEYS.RECONSTRUCTION=172, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=106, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 12/200:   6%|▌         | 11/200 [05:50<1:38:55, 31.40s/it, v_num=1, val_generative_mean_accuracy=0.848, val_generative_var_accuracy=0.831, val_biolord_metric=0.839, val_LOSS_KEYS.RECONSTRUCTION=172, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=96, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


In [None]:
def bool2idx(x):
    """
    Returns the indices of the True-valued entries in a boolean array `x`.

    Also accepts the other return types of ``pandas.Index.get_loc``, which
    returns an integer position for a unique label and a ``slice`` for a
    monotonic index instead of a boolean mask; ``np.where`` cannot handle those.

    Parameters
    ----------
    x : array-like of bool, slice, or int
        Boolean mask, bounded slice, or single integer position.

    Returns
    -------
    numpy.ndarray
        1-D array of integer positions.

    Raises
    ------
    ValueError
        If `x` is a slice without an explicit stop (its length is unknown).
    """
    if isinstance(x, slice):
        if x.stop is None:
            raise ValueError("cannot convert an unbounded slice to indices")
        return np.arange(x.start or 0, x.stop, x.step or 1)
    # Python bools are ints; keep them on the mask path below.
    if isinstance(x, (int, np.integer)) and not isinstance(x, bool):
        return np.array([x])
    return np.where(x)[0]

def repeat_n(x, n, device=None):
    """
    Returns an n-times repeated version of the Tensor x,
    repetition dimension is axis 0.

    `x` is flattened to a single row, so the result has shape
    ``(n, x.numel())``.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to replicate (any shape; it is flattened).
    n : int
        Number of rows in the output.
    device : str or torch.device, optional
        Target device; defaults to "cuda" when available, else "cpu".

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(n, x.numel())`` on `device`.
    """
    # Local import: the notebook's `import torch` cell runs after this one.
    import torch

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # copy tensor to device BEFORE replicating it n times; reshape (unlike
    # view) also works for non-contiguous inputs.
    return x.to(device).reshape(1, -1).repeat(n, 1)


In [None]:
# Control cells of the test split: the starting population for predictions.
test_control_mask = (adata.obs["split"] == "test") & (adata.obs["control"] == 1)
idx_test_control = np.flatnonzero(test_control_mask.to_numpy())
adata_test_control = adata[idx_test_control].copy()

# Perturbed (non-control) OOD cells. NOTE: this rebinds `adata_ood`, which
# previously held the raw OOD file read from disk.
ood_perturbed_mask = (adata.obs["split"] == "ood") & (adata.obs["control"] == 0)
idx_ood = np.flatnonzero(ood_perturbed_mask.to_numpy())
adata_ood = adata[idx_ood].copy()
dataset_ood = model.get_dataset(adata_ood)

In [None]:
# NOTE(review): duplicates the identical `model.get_dataset(adata_ood)` call at the
# end of the previous cell; one of the two is redundant.
dataset_ood = model.get_dataset(adata_ood)

In [None]:
import torch
from tqdm import tqdm

In [None]:
# Alias consumed by the prediction loop: the dataset built from the perturbed OOD cells.
dataset = dataset_ood

In [None]:
# Predict every OOD condition by decoding that condition's attributes on top of
# the matching control cells (same cell type) from the test split.
pert_categories_index = pd.Index(adata_ood.obs["condition"].values, dtype="category")
allowed_cell_lines = []  # NOTE(review): not used below
# NOTE(review): no longer used to size the batch. The number of predicted cells
# per condition equals the number of matching control cells (see n_ctrl).
n_obs = 500

# Maps cell-type codes (1-element tensors) to cell-line names.
cl_dict = {
    torch.Tensor([0.]): "A549",
    torch.Tensor([1.]): "K562",
    torch.Tensor([2.]): "MCF7",
}

layer = "X" if "X" in dataset else "layers"
# Only conditions whose cell line is listed here are predicted.
cell_lines = ["A549", "K562", "MCF7"]

predictions_dict = {}
true_dict = {}
# The control dataset depends only on the cell type, so build each one once.
control_datasets = {}

for cell_drug_dose_comb in tqdm(np.unique(pert_categories_index.values)):

    if "Vehicle" in cell_drug_dose_comb:
        continue

    # Locate this condition's cells BEFORE they are used to pick the cell type.
    # Reading them first avoids reusing the previous iteration's indices.
    bool_category = np.asarray(pert_categories_index == cell_drug_dose_comb)
    idx_all = bool2idx(bool_category)
    idx = idx_all[0]

    cell_type_idx = np.unique(dataset["cell_type"][idx_all])
    assert len(cell_type_idx) == 1

    # Skip conditions whose cell line is not in `cell_lines` before doing any work.
    stop = False
    for tensor, cl in cl_dict.items():
        if (tensor == dataset["cell_type"][idx]).all() and cl not in cell_lines:
            stop = True
    if stop:
        continue

    cell_type_code = cell_type_idx[0]
    if cell_type_code not in control_datasets:
        control_datasets[cell_type_code] = model.get_dataset(
            adata_test_control[adata_test_control.obs["_scvi_cell_type"] == cell_type_code]
        )
    dataset_control = control_datasets[cell_type_code]

    # Keep the control expression and ind_x; every other attribute is replaced by
    # this condition's values, repeated once per control cell so all entries of
    # dataset_comb have the same number of rows.
    n_ctrl = dataset_control[layer].size(0)
    dataset_comb = {
        layer: dataset_control[layer].to(model.device),
        "ind_x": dataset_control["ind_x"].to(model.device),
    }
    for key in dataset_control:
        if key not in (layer, "ind_x"):
            dataset_comb[key] = repeat_n(dataset[key][idx, :], n_ctrl)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred_mean, pred_std = model.module.get_expression(dataset_comb)
        samples = torch.normal(pred_mean, pred_std).cpu().numpy()

    predictions_dict[cell_drug_dose_comb] = samples
    X_true = adata_ood[adata_ood.obs["condition"] == cell_drug_dose_comb].X
    # `.A` on sparse matrices is deprecated/removed in newer SciPy; toarray() is portable.
    true_dict[cell_drug_dose_comb] = (
        X_true.toarray() if hasattr(X_true, "toarray") else np.asarray(X_true)
    )



In [None]:
# Persist predicted samples and ground-truth expression, one array per condition name.
for out_path, arrays_by_condition in (
    ("/lustre/groups/ml01/workspace/ot_perturbation/models/biolord/sciplex/biolord_ood_preds.npz", predictions_dict),
    ("/lustre/groups/ml01/workspace/ot_perturbation/models/biolord/sciplex/biolord_ood_true.npz", true_dict),
):
    np.savez(out_path, **arrays_by_condition)

In [None]:
#with np.load("/lustre/groups/ml01/workspace/ot_perturbation/models/biolord/sciplex/biolord_ood_preds.npz") as data:
#    loaded_data = {key: data[key] for key in data.keys()}