In [1]:
import warnings

import anndata
import biolord
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from scipy.stats import ttest_rel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identifier of the OOD split; it is interpolated into every input and output
# file name used in this notebook.
ood_split = 0

In [3]:
# Locations of the train / test / OOD AnnData files for the selected `ood_split`.
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_300_{ood_split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_300_{ood_split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_300_{ood_split}.h5ad"

In [4]:
# Load the three pre-computed splits from disk.
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)



In [5]:
# Stack the three splits into one object; obs["split"] records each cell's origin
# ("train", "test" or "ood") and is what the later cells filter on.
adata = anndata.concat((adata_train, adata_test, adata_ood), label="split", keys=["train", "test", "ood"])

  utils.warn_names_duplicates("obs")


In [6]:
# Probability with which a cell of the "train" split is held out for validation.
frac_valid = 0.1


def create_split(x):
    """
    Assign one row of `adata.obs` to a training sub-split.

    Rows whose "split" entry is not "train" are labelled "other". Rows from the
    "train" split are labelled "train_valid" with probability `frac_valid`
    (module-level constant) and "train_train" otherwise.

    The draw uses NumPy's global random state, so results are only reproducible
    if that state is seeded beforehand.
    """
    if x["split"] == "train":
        # 0 is drawn with probability frac_valid, 1 with probability 1 - frac_valid.
        keep_for_training = np.random.choice(2, p=[frac_valid, 1 - frac_valid])
        return "train_train" if keep_for_training else "train_valid"
    return "other"


# Label every cell row by row: "train_train" / "train_valid" for cells of the
# "train" split, "other" for everything else.
adata.obs["new_split"] = adata.obs.apply(create_split, axis=1)

In [7]:
# Rescale the doses by the largest dose so that the maximum maps to 1.
dose = adata.obs["dose"].astype("float") / np.max(adata.obs["dose"].astype("float")) # following biolord repr.

In [8]:
# Append the rescaled dose as one extra column to the per-cell "ecfp" matrix.
adata.obsm["ecfp_dose"] = np.concatenate((adata.obsm["ecfp"], dose.values[:,None]), axis=1)

In [9]:
# Register the model inputs on `adata`: the "ecfp_dose" matrix as the ordered
# attribute and "cell_type" as the categorical attribute.
biolord.Biolord.setup_anndata(
    adata,
    ordered_attributes_keys=["ecfp_dose"],
    categorical_attributes_keys=["cell_type"],
    retrieval_attribute_key=None,
)

[34mINFO    [0m Generating sequential column names                                                                        


In [10]:
# Architecture / loss settings, handed to `biolord.Biolord` as `module_params`.
module_params = {
    "decoder_width": 4096,
    "decoder_depth": 4,
    "attribute_dropout_rate": 0.1,
    "attribute_nn_width": 2048,
    "attribute_nn_depth": 2,
    "unknown_attribute_noise_param": 2e+1,
    "unknown_attribute_penalty": 1e-1,
    "gene_likelihood": "normal",
    "n_latent_attribute_ordered": 256,
    "n_latent_attribute_categorical": 3,
    "reconstruction_penalty": 1e+4,
    "use_batch_norm": False,
    "use_layer_norm": False,
}

# Optimiser / scheduler settings, handed to `model.train` as `plan_kwargs`.
trainer_params = {
    "latent_lr": 1e-4,
    "latent_wd": 1e-4,
    "decoder_lr": 1e-4,
    "decoder_wd": 1e-4,
    "attribute_nn_lr": 1e-2,
    "attribute_nn_wd": 4e-8,
    "cosine_scheduler": True,
    "scheduler_final_lr": 1e-5,
    "step_size_lr": 45,
}

In [11]:
# Build the model on the concatenated data. The "new_split" column decides which
# cells are used for training ("train_train"), validation ("train_valid") and
# testing ("other").
model = biolord.Biolord(
    adata=adata,
    n_latent=256,
    model_name="sciplex3",
    module_params=module_params,
    train_classifiers=False,
    split_key="new_split",
    train_split="train_train",
    valid_split="train_valid",
    test_split="other",
)

[rank: 0] Seed set to 0


In [None]:
# Train with the optimiser settings from `trainer_params`; early stopping is
# enabled and no checkpoints are written.
model.train(
    max_epochs=200,
    batch_size=512,
    plan_kwargs=trainer_params,
    early_stopping=True,
    early_stopping_patience=20,
    check_val_every_n_epoch=10,
    num_workers=10,
    enable_checkpointing=False
)

/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  self.pid = os.fork()


Epoch 2/200:   0%|          | 1/200 [00:37<2:03:25, 37.21s/it, v_num=1, val_generative_mean_accuracy=0.156, val_generative_var_accuracy=-60.6, val_biolord_metric=-30.2, val_LOSS_KEYS.RECONSTRUCTION=266, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=238]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3/200:   1%|          | 2/200 [01:03<1:41:58, 30.90s/it, v_num=1, val_generative_mean_accuracy=0.185, val_generative_var_accuracy=-40, val_biolord_metric=-19.9, val_LOSS_KEYS.RECONSTRUCTION=253, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=222, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=217, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4/200:   2%|▏         | 3/200 [01:29<1:33:41, 28.54s/it, v_num=1, val_generative_mean_accuracy=0.28, val_generative_var_accuracy=-15.1, val_biolord_metric=-7.41, val_LOSS_KEYS.RECONSTRUCTION=242, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=207, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=203, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5/200:   2%|▏         | 4/200 [01:55<1:30:21, 27.66s/it, v_num=1, val_generative_mean_accuracy=0.523, val_generative_var_accuracy=-2.47, val_biolord_metric=-0.972, val_LOSS_KEYS.RECONSTRUCTION=215, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=192, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=183, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6/200:   2%|▎         | 5/200 [02:21<1:28:11, 27.14s/it, v_num=1, val_generative_mean_accuracy=0.589, val_generative_var_accuracy=-0.642, val_biolord_metric=-0.0269, val_LOSS_KEYS.RECONSTRUCTION=209, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=179, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=169, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7/200:   3%|▎         | 6/200 [02:48<1:26:47, 26.84s/it, v_num=1, val_generative_mean_accuracy=0.613, val_generative_var_accuracy=-0.149, val_biolord_metric=0.232, val_LOSS_KEYS.RECONSTRUCTION=206, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=167, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=160, unknown_attribute_penalty_loss=1.03e+5]  

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8/200:   4%|▎         | 7/200 [03:13<1:24:49, 26.37s/it, v_num=1, val_generative_mean_accuracy=0.672, val_generative_var_accuracy=0.338, val_biolord_metric=0.505, val_LOSS_KEYS.RECONSTRUCTION=199, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=155, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=158, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9/200:   4%|▍         | 8/200 [03:38<1:22:50, 25.89s/it, v_num=1, val_generative_mean_accuracy=0.747, val_generative_var_accuracy=0.581, val_biolord_metric=0.664, val_LOSS_KEYS.RECONSTRUCTION=190, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=144, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=158, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 10/200:   4%|▍         | 9/200 [04:03<1:21:26, 25.58s/it, v_num=1, val_generative_mean_accuracy=0.748, val_generative_var_accuracy=0.592, val_biolord_metric=0.67, val_LOSS_KEYS.RECONSTRUCTION=190, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=133, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 11/200:   5%|▌         | 10/200 [04:28<1:20:11, 25.32s/it, v_num=1, val_generative_mean_accuracy=0.726, val_generative_var_accuracy=0.521, val_biolord_metric=0.624, val_LOSS_KEYS.RECONSTRUCTION=194, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=124, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 12/200:   6%|▌         | 11/200 [04:52<1:19:07, 25.12s/it, v_num=1, val_generative_mean_accuracy=0.814, val_generative_var_accuracy=0.758, val_biolord_metric=0.786, val_LOSS_KEYS.RECONSTRUCTION=182, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=115, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 13/200:   6%|▌         | 12/200 [05:17<1:18:26, 25.04s/it, v_num=1, val_generative_mean_accuracy=0.81, val_generative_var_accuracy=0.764, val_biolord_metric=0.787, val_LOSS_KEYS.RECONSTRUCTION=181, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=106, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 14/200:   6%|▋         | 13/200 [05:42<1:18:01, 25.03s/it, v_num=1, val_generative_mean_accuracy=0.846, val_generative_var_accuracy=0.814, val_biolord_metric=0.83, val_LOSS_KEYS.RECONSTRUCTION=178, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=98, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()


In [None]:
def bool2idx(x):
    """
    Convert a boolean mask `x` into the axis-0 positions of its True entries.
    """
    nonzero_per_axis = np.nonzero(x)
    return nonzero_per_axis[0]

def repeat_n(x, n):
    """
    Flatten the Tensor `x` into a single row and stack `n` copies of that row
    along axis 0. The result lives on the GPU when CUDA is available, otherwise
    on the CPU.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    # Move the single row first so only one copy crosses devices, then replicate.
    single_row = x.to(target_device).view(1, -1)
    return single_row.repeat(n, 1)


In [None]:
# Cells of the "test" split with obs["control"] == 1.
is_test_control = (adata.obs["split"] == "test") & (adata.obs["control"] == 1)
idx_test_control = np.where(is_test_control)[0]
adata_test_control = adata[idx_test_control].copy()

# Cells of the "ood" split with obs["control"] == 0.
is_ood_non_control = (adata.obs["split"] == "ood") & (adata.obs["control"] == 0)
idx_ood = np.where(is_ood_non_control)[0]
adata_ood = adata[idx_ood].copy()

# Model-ready tensors for the OOD cells.
dataset_ood = model.get_dataset(adata_ood)

In [None]:
# NOTE(review): this repeats the last statement of the previous cell; one of the
# two calls is redundant.
dataset_ood = model.get_dataset(adata_ood)

In [None]:
import pandas as pd
from tqdm import tqdm

def compute_prediction(
    model,
    adata,
    dataset,
    adata_control,
    n_obs=500
):
    """Sample predicted expression for every condition present in `adata`.

    For each unique value of ``adata.obs["condition"]`` the control cells of the
    matching cell line (the part of the condition string before the first
    ``"_"``) provide the expression input, while all remaining tensors are taken
    from the first cell of that condition in `dataset` and repeated `n_obs`
    times. The combined tensors are passed to ``model.module.get_expression``
    and one sample is drawn from the returned normal distribution.

    Parameters
    ----------
    model
        Object exposing ``get_dataset``, ``device`` and ``module.get_expression``.
    adata
        Cells to predict; only ``obs["condition"]`` is read. Its rows must be in
        the same order as the rows of `dataset`.
    dataset
        Mapping of tensors describing `adata` (here: ``model.get_dataset(adata)``).
    adata_control
        Control cells with an ``obs["cell_type"]`` column; subset per cell line.
    n_obs
        Number of times the attribute row of a condition is repeated.
        NOTE(review): presumably this has to equal the number of control cells
        per cell line, otherwise the tensors have different row counts - confirm.

    Returns
    -------
    dict
        Maps each condition to a NumPy array of sampled expression values.
    """
    # Unique conditions plus the position of the first cell of each one. This
    # does not depend on whether the conditions happen to be sorted or unique.
    conditions = np.asarray(adata.obs["condition"].values)
    unique_conditions, first_rows = np.unique(conditions, return_index=True)

    layer = "X" if "X" in dataset else "layers"

    # Control tensors depend only on the cell line, so build them once each.
    control_by_cell_line = {}
    predictions_dict = {}
    for cell_drug_dose_comb, first_row in zip(tqdm(unique_conditions), first_rows):
        cur_cell_line = cell_drug_dose_comb.split("_")[0]
        if cur_cell_line not in control_by_cell_line:
            control_by_cell_line[cur_cell_line] = model.get_dataset(
                adata_control[adata_control.obs["cell_type"] == cur_cell_line]
            )
        dataset_control = control_by_cell_line[cur_cell_line]

        idx = int(first_row)

        # Expression and cell indices come from the controls ...
        dataset_comb = {
            layer: dataset_control[layer].to(model.device),
            "ind_x": dataset_control["ind_x"].to(model.device),
        }
        # ... every other tensor is taken from the condition's first cell.
        for key in dataset_control:
            if key not in (layer, "ind_x"):
                dataset_comb[key] = repeat_n(dataset[key][idx, :], n_obs)

        pred_mean, pred_std = model.module.get_expression(dataset_comb)
        samples = torch.normal(pred_mean, pred_std)

        predictions_dict[cell_drug_dose_comb] = samples.detach().cpu().numpy()
    return predictions_dict

In [None]:
# Predict every OOD condition, using the test-split control cells as input.
biolord_prediction = compute_prediction(
    model=model,
    adata=adata_ood,
    dataset=dataset_ood,
    adata_control=adata_test_control)

In [None]:
import anndata as ad
# Gather the per-condition prediction arrays together with one condition label
# per predicted row.
all_data = []
conditions = []
for condition, array in biolord_prediction.items():
    n_rows = array.shape[0]
    all_data += [array]
    conditions += [condition] * n_rows

# One matrix holding all predictions, rows ordered as in `conditions`.
all_data_array = np.vstack(all_data)

# Row annotations (.obs): the condition each predicted row belongs to.
obs_data = pd.DataFrame({'condition': conditions})

# Wrap the predictions and their annotations in an AnnData object.
adata_ood_result = ad.AnnData(X=all_data_array, obs=obs_data)

In [None]:
# Persist the OOD predictions for the selected `ood_split`.
adata_ood_result.write_h5ad(f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_ood_300_{ood_split}.h5ad")

In [None]:
# Cells of the "test" split with obs["control"] == 1.
is_test_control = (adata.obs["split"] == "test") & (adata.obs["control"] == 1)
idx_test_control = np.where(is_test_control)[0]
adata_test_control = adata[idx_test_control].copy()

# Cells of the "test" split with obs["control"] == 0.
is_test_non_control = (adata.obs["split"] == "test") & (adata.obs["control"] == 0)
idx_test = np.where(is_test_non_control)[0]
adata_test = adata[idx_test].copy()

# Model-ready tensors for the non-control test cells.
dataset_test = model.get_dataset(adata_test)

In [None]:
# NOTE(review): this repeats the last statement of the previous cell; one of the
# two calls is redundant.
dataset_test = model.get_dataset(adata_test)

In [None]:
# Predict every non-control test condition, using the test-split control cells as
# input. This overwrites the OOD predictions held in `biolord_prediction`.
biolord_prediction = compute_prediction(
    model=model,
    adata=adata_test,
    dataset=dataset_test,
    adata_control=adata_test_control)

In [None]:
# Gather the per-condition prediction arrays together with one condition label
# per predicted row.
all_data = []
conditions = []
for condition, array in biolord_prediction.items():
    n_rows = array.shape[0]
    all_data += [array]
    conditions += [condition] * n_rows

# One matrix holding all predictions, rows ordered as in `conditions`.
all_data_array = np.vstack(all_data)

# Row annotations (.obs): the condition each predicted row belongs to.
obs_data = pd.DataFrame({'condition': conditions})

# Wrap the predictions and their annotations in an AnnData object.
adata_test_result = ad.AnnData(X=all_data_array, obs=obs_data)

In [None]:
# Persist the test-split predictions for the selected `ood_split`.
adata_test_result.write_h5ad(f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_test_300_{ood_split}.h5ad")