In [1]:
import warnings

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from scipy.stats import ttest_rel

import biolord

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the pre-split sciplex AnnData files (train / test / ood).
_SCIPLEX_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
_SPLIT_FILE_TEMPLATE = "adata_{}_biolord_split_300.h5ad"

adata_train_path = _SCIPLEX_DIR + "/" + _SPLIT_FILE_TEMPLATE.format("train")
adata_test_path = _SCIPLEX_DIR + "/" + _SPLIT_FILE_TEMPLATE.format("test")
adata_ood_path = _SCIPLEX_DIR + "/" + _SPLIT_FILE_TEMPLATE.format("ood")

In [3]:
# Load the three pre-split AnnData objects from disk, in train / test / ood order.
adata_train, adata_test, adata_ood = (
    sc.read(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)



In [4]:
# Merge the three splits into a single AnnData; `label`/`keys` record the origin
# of every cell in obs["split"].  `index_unique` appends the split key to each
# observation name so that names occurring in more than one split no longer
# collide (the plain concat emitted a duplicate-obs-names warning).
adata = anndata.concat(
    (adata_train, adata_test, adata_ood),
    label="split",
    keys=["train", "test", "ood"],
    index_unique="-",
)

  utils.warn_names_duplicates("obs")


In [5]:
frac_valid = 0.1  # fraction of the original training cells held out for validation
split_seed = 0  # fixed seed so the train/validation assignment is reproducible

# Generator used by `create_split` when no explicit one is passed.  It is
# re-created every time this cell runs, so re-running the cell (and the
# subsequent `apply`) yields the same assignment again.
_split_rng = np.random.default_rng(split_seed)


def create_split(x, rng=None):
    """Assign one observation to the internal train/validation split.

    Parameters
    ----------
    x
        Mapping-like row (e.g. one row of ``adata.obs``) with a ``"split"`` entry.
    rng
        Optional ``numpy.random.Generator`` used for the random draw; defaults
        to the seeded generator defined next to this function.

    Returns
    -------
    str
        ``"other"`` when the row does not belong to the ``"train"`` split,
        otherwise ``"train_valid"`` with probability ``frac_valid`` and
        ``"train_train"`` with probability ``1 - frac_valid``.
    """
    if x["split"] != "train":
        return "other"
    generator = _split_rng if rng is None else rng
    # Generator.random() is uniform on [0, 1), so this branch has probability frac_valid.
    if generator.random() < frac_valid:
        return "train_valid"
    return "train_train"


# Label every cell row by row and store the result as the split column that the
# model is later pointed at via `split_key="new_split"`.
new_split_labels = adata.obs.apply(create_split, axis=1)
adata.obs["new_split"] = new_split_labels

In [6]:
# Scale the dose by its maximum so the largest dose becomes 1 (following biolord repr.).
dose_raw = adata.obs["dose"].astype("float")
dose = dose_raw / np.max(dose_raw)

In [7]:
# Append the normalised dose as one extra column to the right of the ECFP matrix.
dose_column = dose.to_numpy()[:, np.newaxis]
adata.obsm["ecfp_dose"] = np.hstack((adata.obsm["ecfp"], dose_column))

In [8]:
# Register which obsm/obs entries the model receives as ordered and categorical
# attributes; no retrieval attribute is used.
attribute_keys = {
    "ordered_attributes_keys": ["ecfp_dose"],
    "categorical_attributes_keys": ["cell_type"],
    "retrieval_attribute_key": None,
}
biolord.Biolord.setup_anndata(adata, **attribute_keys)

[34mINFO    [0m Generating sequential column names                                                                        


In [9]:
# Hyper-parameters passed to the biolord module via `module_params=`.
module_params = {
    # decoder
    "decoder_width": 4096,
    "decoder_depth": 4,
    # attribute networks
    "attribute_dropout_rate": 0.1,
    "attribute_nn_width": 2048,
    "attribute_nn_depth": 2,
    # unknown-attribute regularisation, likelihood and latent sizes
    "unknown_attribute_noise_param": 2e1,
    "unknown_attribute_penalty": 1e-1,
    "gene_likelihood": "normal",
    "n_latent_attribute_ordered": 256,
    "n_latent_attribute_categorical": 3,
    "reconstruction_penalty": 1e4,
    # normalisation layers are switched off
    "use_batch_norm": False,
    "use_layer_norm": False,
}

# Optimiser / scheduler settings handed to `model.train` through `plan_kwargs=`.
trainer_params = {
    "latent_lr": 1e-4,
    "latent_wd": 1e-4,
    "decoder_lr": 1e-4,
    "decoder_wd": 1e-4,
    "attribute_nn_lr": 1e-2,
    "attribute_nn_wd": 4e-8,
    "cosine_scheduler": True,
    "scheduler_final_lr": 1e-5,
    "step_size_lr": 45,
}

In [10]:
# Names of the obs column and of its values that define the train / validation /
# test partitions seen by the model.
split_config = {
    "split_key": "new_split",
    "train_split": "train_train",
    "valid_split": "train_valid",
    "test_split": "other",
}

# Build the model on the combined AnnData with the hyper-parameters defined above.
model = biolord.Biolord(
    adata=adata,
    n_latent=256,
    model_name="sciplex3",
    module_params=module_params,
    train_classifiers=False,
    **split_config,
)

[rank: 0] Seed set to 0


In [None]:
# Training run configuration: up to 200 epochs, validation every 10 epochs,
# early stopping with a patience of 20, no checkpoints written.
train_config = {
    "max_epochs": 200,
    "batch_size": 512,
    "plan_kwargs": trainer_params,
    "early_stopping": True,
    "early_stopping_patience": 20,
    "check_val_every_n_epoch": 10,
    "num_workers": 10,
    "enable_checkpointing": False,
}
model.train(**train_config)

/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly u

Epoch 2/200:   0%| | 1/200 [00:48<2:41:15, 48.62s/it, v_num=1, val_generative_mean_accuracy=0.18, val_generative_var_accuracy=-67

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3/200:   1%| | 2/200 [01:30<2:27:18, 44.64s/it, v_num=1, val_generative_mean_accuracy=0.357, val_generative_var_accuracy=-9

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4/200:   2%| | 3/200 [02:12<2:21:58, 43.24s/it, v_num=1, val_generative_mean_accuracy=0.457, val_generative_var_accuracy=-5

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5/200:   2%| | 4/200 [02:53<2:19:16, 42.64s/it, v_num=1, val_generative_mean_accuracy=0.57, val_generative_var_accuracy=-2.

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6/200:   2%| | 5/200 [03:35<2:17:45, 42.39s/it, v_num=1, val_generative_mean_accuracy=0.645, val_generative_var_accuracy=-0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7/200:   3%| | 6/200 [04:17<2:16:28, 42.21s/it, v_num=1, val_generative_mean_accuracy=0.707, val_generative_var_accuracy=-0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8/200:   4%| | 7/200 [04:59<2:15:32, 42.14s/it, v_num=1, val_generative_mean_accuracy=0.769, val_generative_var_accuracy=0.

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9/200:   4%| | 8/200 [05:41<2:15:03, 42.21s/it, v_num=1, val_generative_mean_accuracy=0.797, val_generative_var_accuracy=0.

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 10/200:   4%| | 9/200 [06:23<2:14:06, 42.13s/it, v_num=1, val_generative_mean_accuracy=0.799, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 11/200:   5%| | 10/200 [07:05<2:13:10, 42.06s/it, v_num=1, val_generative_mean_accuracy=0.862, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 12/200:   6%| | 11/200 [07:47<2:12:11, 41.97s/it, v_num=1, val_generative_mean_accuracy=0.871, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 13/200:   6%| | 12/200 [08:29<2:11:27, 41.96s/it, v_num=1, val_generative_mean_accuracy=0.918, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 14/200:   6%| | 13/200 [09:11<2:10:32, 41.89s/it, v_num=1, val_generative_mean_accuracy=0.932, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 15/200:   7%| | 14/200 [09:52<2:09:45, 41.86s/it, v_num=1, val_generative_mean_accuracy=0.924, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 16/200:   8%| | 15/200 [10:35<2:09:14, 41.92s/it, v_num=1, val_generative_mean_accuracy=0.909, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 17/200:   8%| | 16/200 [11:17<2:08:36, 41.93s/it, v_num=1, val_generative_mean_accuracy=0.943, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 18/200:   8%| | 17/200 [11:59<2:08:02, 41.98s/it, v_num=1, val_generative_mean_accuracy=0.959, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 19/200:   9%| | 18/200 [12:41<2:07:18, 41.97s/it, v_num=1, val_generative_mean_accuracy=0.951, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 20/200:  10%| | 19/200 [13:22<2:06:04, 41.79s/it, v_num=1, val_generative_mean_accuracy=0.956, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 21/200:  10%| | 20/200 [14:03<2:05:09, 41.72s/it, v_num=1, val_generative_mean_accuracy=0.968, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 22/200:  10%| | 21/200 [14:46<2:05:02, 41.92s/it, v_num=1, val_generative_mean_accuracy=0.966, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 23/200:  11%| | 22/200 [15:28<2:04:13, 41.87s/it, v_num=1, val_generative_mean_accuracy=0.957, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 24/200:  12%| | 23/200 [16:10<2:03:39, 41.92s/it, v_num=1, val_generative_mean_accuracy=0.974, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 25/200:  12%| | 24/200 [16:51<2:02:48, 41.87s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 26/200:  12%|▏| 25/200 [17:33<2:01:59, 41.83s/it, v_num=1, val_generative_mean_accuracy=0.975, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 27/200:  13%|▏| 26/200 [18:15<2:01:02, 41.74s/it, v_num=1, val_generative_mean_accuracy=0.972, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 28/200:  14%|▏| 27/200 [18:56<2:00:21, 41.74s/it, v_num=1, val_generative_mean_accuracy=0.973, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 29/200:  14%|▏| 28/200 [19:39<1:59:57, 41.85s/it, v_num=1, val_generative_mean_accuracy=0.977, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 30/200:  14%|▏| 29/200 [20:20<1:59:14, 41.84s/it, v_num=1, val_generative_mean_accuracy=0.973, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 31/200:  15%|▏| 30/200 [21:03<1:59:37, 42.22s/it, v_num=1, val_generative_mean_accuracy=0.975, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 32/200:  16%|▏| 31/200 [21:46<1:59:01, 42.26s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 33/200:  16%|▏| 32/200 [22:28<1:58:15, 42.23s/it, v_num=1, val_generative_mean_accuracy=0.975, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 34/200:  16%|▏| 33/200 [23:10<1:57:22, 42.17s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 35/200:  17%|▏| 34/200 [23:52<1:56:46, 42.21s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 36/200:  18%|▏| 35/200 [24:34<1:55:52, 42.14s/it, v_num=1, val_generative_mean_accuracy=0.976, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 37/200:  18%|▏| 36/200 [25:16<1:55:00, 42.07s/it, v_num=1, val_generative_mean_accuracy=0.98, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 38/200:  18%|▏| 37/200 [25:59<1:54:46, 42.25s/it, v_num=1, val_generative_mean_accuracy=0.985, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 39/200:  19%|▏| 38/200 [26:41<1:53:51, 42.17s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 40/200:  20%|▏| 39/200 [27:23<1:53:16, 42.22s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 41/200:  20%|▏| 40/200 [28:05<1:52:22, 42.14s/it, v_num=1, val_generative_mean_accuracy=0.978, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 42/200:  20%|▏| 41/200 [28:48<1:52:12, 42.34s/it, v_num=1, val_generative_mean_accuracy=0.981, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 43/200:  21%|▏| 42/200 [29:30<1:51:33, 42.36s/it, v_num=1, val_generative_mean_accuracy=0.982, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 44/200:  22%|▏| 43/200 [30:13<1:50:52, 42.37s/it, v_num=1, val_generative_mean_accuracy=0.98, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 45/200:  22%|▏| 44/200 [30:55<1:50:12, 42.38s/it, v_num=1, val_generative_mean_accuracy=0.982, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 46/200:  22%|▏| 45/200 [31:37<1:49:25, 42.36s/it, v_num=1, val_generative_mean_accuracy=0.977, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 47/200:  23%|▏| 46/200 [32:20<1:48:32, 42.29s/it, v_num=1, val_generative_mean_accuracy=0.98, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 48/200:  24%|▏| 47/200 [33:02<1:47:46, 42.27s/it, v_num=1, val_generative_mean_accuracy=0.983, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 49/200:  24%|▏| 48/200 [33:45<1:47:30, 42.44s/it, v_num=1, val_generative_mean_accuracy=0.985, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 50/200:  24%|▏| 49/200 [34:27<1:46:36, 42.36s/it, v_num=1, val_generative_mean_accuracy=0.983, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 51/200:  25%|▎| 50/200 [35:09<1:45:45, 42.30s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 52/200:  26%|▎| 51/200 [35:52<1:45:24, 42.44s/it, v_num=1, val_generative_mean_accuracy=0.984, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 53/200:  26%|▎| 52/200 [36:34<1:44:28, 42.35s/it, v_num=1, val_generative_mean_accuracy=0.984, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 54/200:  26%|▎| 53/200 [37:16<1:43:28, 42.23s/it, v_num=1, val_generative_mean_accuracy=0.981, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 55/200:  27%|▎| 54/200 [37:58<1:42:55, 42.30s/it, v_num=1, val_generative_mean_accuracy=0.982, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 56/200:  28%|▎| 55/200 [38:41<1:42:20, 42.35s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 57/200:  28%|▎| 56/200 [39:23<1:41:33, 42.32s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 58/200:  28%|▎| 57/200 [40:05<1:40:42, 42.26s/it, v_num=1, val_generative_mean_accuracy=0.985, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 59/200:  29%|▎| 58/200 [40:52<1:42:58, 43.51s/it, v_num=1, val_generative_mean_accuracy=0.983, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 60/200:  30%|▎| 59/200 [41:34<1:41:28, 43.18s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 61/200:  30%|▎| 60/200 [42:17<1:40:19, 43.00s/it, v_num=1, val_generative_mean_accuracy=0.982, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 62/200:  30%|▎| 61/200 [42:59<1:39:22, 42.89s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 63/200:  31%|▎| 62/200 [43:41<1:38:04, 42.64s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 64/200:  32%|▎| 63/200 [44:23<1:36:48, 42.39s/it, v_num=1, val_generative_mean_accuracy=0.984, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 65/200:  32%|▎| 64/200 [45:05<1:35:51, 42.29s/it, v_num=1, val_generative_mean_accuracy=0.984, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 66/200:  32%|▎| 65/200 [45:47<1:34:57, 42.21s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 67/200:  33%|▎| 66/200 [46:29<1:34:22, 42.26s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 68/200:  34%|▎| 67/200 [47:12<1:33:32, 42.20s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 69/200:  34%|▎| 68/200 [47:54<1:33:12, 42.37s/it, v_num=1, val_generative_mean_accuracy=0.988, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 70/200:  34%|▎| 69/200 [48:37<1:32:43, 42.47s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 71/200:  35%|▎| 70/200 [49:19<1:31:46, 42.36s/it, v_num=1, val_generative_mean_accuracy=0.988, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 72/200:  36%|▎| 71/200 [50:01<1:30:56, 42.30s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 73/200:  36%|▎| 72/200 [50:43<1:30:07, 42.25s/it, v_num=1, val_generative_mean_accuracy=0.989, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 74/200:  36%|▎| 73/200 [51:26<1:29:32, 42.30s/it, v_num=1, val_generative_mean_accuracy=0.986, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 75/200:  37%|▎| 74/200 [52:08<1:28:48, 42.29s/it, v_num=1, val_generative_mean_accuracy=0.99, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 76/200:  38%|▍| 75/200 [52:50<1:27:49, 42.16s/it, v_num=1, val_generative_mean_accuracy=0.985, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 77/200:  38%|▍| 76/200 [53:32<1:27:12, 42.20s/it, v_num=1, val_generative_mean_accuracy=0.988, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 78/200:  38%|▍| 77/200 [54:14<1:26:22, 42.13s/it, v_num=1, val_generative_mean_accuracy=0.985, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 79/200:  39%|▍| 78/200 [54:56<1:25:35, 42.09s/it, v_num=1, val_generative_mean_accuracy=0.99, val_generative_var_accuracy=0

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 80/200:  40%|▍| 79/200 [55:39<1:24:59, 42.15s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 81/200:  40%|▍| 80/200 [56:21<1:24:29, 42.25s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 82/200:  40%|▍| 81/200 [57:03<1:23:39, 42.18s/it, v_num=1, val_generative_mean_accuracy=0.987, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 83/200:  41%|▍| 82/200 [57:45<1:22:46, 42.08s/it, v_num=1, val_generative_mean_accuracy=0.991, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 84/200:  42%|▍| 83/200 [58:27<1:21:54, 42.00s/it, v_num=1, val_generative_mean_accuracy=0.989, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 85/200:  42%|▍| 84/200 [59:09<1:21:38, 42.23s/it, v_num=1, val_generative_mean_accuracy=0.988, val_generative_var_accuracy=

  self.pid = os.fork()
  self.pid = os.fork()


In [None]:
def bool2idx(x):
    """
    Map a boolean array `x` to the positions (along its first axis) at which
    it is True.
    """
    true_positions = np.nonzero(x)
    return true_positions[0]

def repeat_n(x, n, device=None):
    """
    Returns an n-times repeated version of the Tensor x,
    repetition dimension is axis 0

    Parameters
    ----------
    x
        Tensor to replicate; it is flattened into a single row first.
    n
        Number of rows in the result.
    device
        Target device. When omitted, "cuda" is used if
        ``torch.cuda.is_available()`` and "cpu" otherwise.

    Returns
    -------
    Tensor of shape ``(n, x.numel())`` on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # copy tensor to device BEFORE replicating it n times, so only one row is
    # transferred; `reshape` (unlike `view`) also accepts non-contiguous input
    return x.to(device).reshape(1, -1).repeat(n, 1)


In [None]:
# Row masks over the combined AnnData.
split_labels = adata.obs["split"]
is_control = adata.obs["control"] == 1
is_perturbed = adata.obs["control"] == 0

# Control cells of the test split (later passed to `compute_prediction` as controls).
idx_test_control = np.flatnonzero(((split_labels == "test") & is_control).to_numpy())
adata_test_control = adata[idx_test_control].copy()

# Non-control cells of the OOD split, plus their model-ready dataset.
idx_ood = np.flatnonzero(((split_labels == "ood") & is_perturbed).to_numpy())
adata_ood = adata[idx_ood].copy()
dataset_ood = model.get_dataset(adata_ood)

In [None]:
# NOTE(review): this is the same statement as the last line of the preceding
# cell, so it rebuilds `dataset_ood` from the same `adata_ood` — presumably
# redundant; confirm before removing.
dataset_ood = model.get_dataset(adata_ood)

In [None]:
import pandas as pd
from tqdm import tqdm

def compute_prediction(
    model,
    adata,
    dataset,
    adata_control,
    n_obs=500
):
    """Sample predicted expression for every condition present in ``adata``.

    For each unique value of ``adata.obs["condition"]`` the expression of the
    control cells of the matching cell line is combined with the attribute
    tensors of the first cell carrying that condition. Expression is then drawn
    from a normal distribution with the mean / std returned by
    ``model.module.get_expression``.

    Parameters
    ----------
    model
        Trained model exposing ``get_dataset``, ``device`` and
        ``module.get_expression``.
    adata
        AnnData whose ``obs["condition"]`` strings start with the cell line
        followed by ``"_"``.
    dataset
        Result of ``model.get_dataset(adata)``; row ``i`` must describe cell
        ``i`` of ``adata``.
    adata_control
        AnnData with the control cells; ``obs["cell_type"]`` is used to pick
        the controls of one cell line.
    n_obs
        Maximum number of control cells used per condition. ``None`` uses all
        controls of the cell line.

    Returns
    -------
    dict
        Maps each condition string to a numpy array of sampled expression with
        one row per control cell used.
    """
    # Unique conditions in sorted order, together with the row of their first occurrence.
    condition_labels = np.asarray(adata.obs["condition"].astype(str))
    unique_conditions, first_rows = np.unique(condition_labels, return_index=True)

    layer = "X" if "X" in dataset else "layers"

    # The control dataset depends only on the cell line, so build it once per
    # cell line instead of once per condition.
    control_datasets = {}
    predictions_dict = {}
    for cell_drug_dose_comb, idx in tqdm(
        zip(unique_conditions, first_rows), total=len(unique_conditions)
    ):
        cell_drug_dose_comb = str(cell_drug_dose_comb)
        idx = int(idx)

        cur_cell_line = cell_drug_dose_comb.split("_")[0]
        if cur_cell_line not in control_datasets:
            # Use the `adata_control` argument (not a notebook-level global).
            control_datasets[cur_cell_line] = model.get_dataset(
                adata_control[adata_control.obs.cell_type == cur_cell_line]
            )
        dataset_control = control_datasets[cur_cell_line]

        # Keep the number of control rows and the number of repeated attribute
        # rows equal, whatever the number of available controls is.
        n_controls = dataset_control[layer].shape[0]
        n_cells = n_controls if n_obs is None else min(n_obs, n_controls)

        dataset_comb = {
            layer: dataset_control[layer][:n_cells].to(model.device),
            "ind_x": dataset_control["ind_x"][:n_cells].to(model.device),
        }
        # Every remaining entry is taken from the perturbed cell and replicated
        # once per control cell.
        for key in dataset_control:
            if key not in (layer, "ind_x"):
                dataset_comb[key] = repeat_n(dataset[key][idx, :], n_cells).to(model.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred_mean, pred_std = model.module.get_expression(dataset_comb)
            samples = torch.normal(pred_mean, pred_std)

        predictions_dict[cell_drug_dose_comb] = samples.detach().cpu().numpy()
    return predictions_dict

In [None]:
# Predict the OOD conditions, using the test-split control cells as controls.
ood_inputs = {
    "adata": adata_ood,
    "dataset": dataset_ood,
    "adata_control": adata_test_control,
}
biolord_prediction = compute_prediction(model=model, **ood_inputs)

In [None]:
import anndata as ad

# One array per condition, in the iteration order of the prediction dict.
all_data = list(biolord_prediction.values())

# One condition label per predicted row, in the same order as the arrays.
conditions = [
    condition
    for condition, array in biolord_prediction.items()
    for _ in range(array.shape[0])
]

# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# Create a DataFrame for the .obs attribute
obs_data = pd.DataFrame({"condition": conditions})

# Create the Anndata object
adata_ood_result = ad.AnnData(X=all_data_array, obs=obs_data)

In [None]:
# Persist the OOD predictions as an h5ad file.
ood_output_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_ood_300.h5ad"
adata_ood_result.write_h5ad(ood_output_path)

In [None]:
# Cells of the test split, separated into controls and non-controls.
in_test_split = adata.obs["split"] == "test"
control_flag = adata.obs["control"]

idx_test_control = np.flatnonzero((in_test_split & (control_flag == 1)).to_numpy())
adata_test_control = adata[idx_test_control].copy()

idx_test = np.flatnonzero((in_test_split & (control_flag == 0)).to_numpy())
adata_test = adata[idx_test].copy()

# Model-ready dataset for the non-control test cells.
dataset_test = model.get_dataset(adata_test)

In [None]:
# NOTE(review): this is the same statement as the last line of the preceding
# cell, so it rebuilds `dataset_test` from the same `adata_test` — presumably
# redundant; confirm before removing.
dataset_test = model.get_dataset(adata_test)

In [None]:
# Predict the test-split conditions, using the test-split control cells as controls.
test_inputs = {
    "adata": adata_test,
    "dataset": dataset_test,
    "adata_control": adata_test_control,
}
biolord_prediction = compute_prediction(model=model, **test_inputs)

In [None]:
# Collect the per-condition arrays and one condition label per predicted row.
all_data = [array for array in biolord_prediction.values()]
conditions = []
for condition, array in biolord_prediction.items():
    n_rows = array.shape[0]
    conditions += [condition] * n_rows

# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# Create a DataFrame for the .obs attribute
obs_data = pd.DataFrame({"condition": conditions})

# Create the Anndata object
adata_test_result = ad.AnnData(X=all_data_array, obs=obs_data)

In [None]:
# Persist the test-split predictions as an h5ad file.
test_output_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_test_300.h5ad"
adata_test_result.write_h5ad(test_output_path)