In [4]:
import pandas as pd
import numpy as np

import scanpy as sc
import biolord

import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from scipy.stats import ttest_rel
import anndata

In [2]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_biolord_split_30.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_biolord_split_30.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_biolord_split_30.h5ad"

In [3]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)



In [5]:
adata = anndata.concat((adata_train, adata_test, adata_ood), label="split", keys=["train", "test", "ood"])

  utils.warn_names_duplicates("obs")


In [6]:
adata.obsm

AxisArrays with keys: X_pca, cell_line_emb, ecfp, ecfp_cell_line, ecfp_cell_line_dose, ecfp_cell_line_dose_more_dose, ecfp_cell_line_logdose, ecfp_cell_line_logdose_more_dose

In [113]:
dose = adata.obs["dose"].astype("float") / np.max(adata.obs["dose"].astype("float")) # following 

In [117]:
adata.obsm["ecfp_dose"] = np.concatenate((adata.obsm["ecfp"], dose.values[:,None]), axis=1)

In [11]:
biolord.Biolord.setup_anndata(
    adata,
    ordered_attributes_keys=["ecfp_logdose"],
    categorical_attributes_keys=["cell_type"],
    retrieval_attribute_key=None,
)

[34mINFO    [0m Generating sequential column names                                                                        


In [27]:
module_params= dict(
decoder_width = 4096,
decoder_depth = 4,

attribute_dropout_rate = 0.1,
attribute_nn_width = 2048,
attribute_nn_depth = 2,

unknown_attribute_noise_param = 2e+1,
unknown_attribute_penalty = 1e-1,
gene_likelihood = "normal",
n_latent_attribute_ordered = 256,
n_latent_attribute_categorical = 3,
reconstruction_penalty = 1e+4,
use_batch_norm = False,
use_layer_norm = False,)

trainer_params=dict(
latent_lr = 1e-4,
latent_wd = 1e-4,
decoder_lr = 1e-4,
decoder_wd = 1e-4,
attribute_nn_lr = 1e-2,
attribute_nn_wd = 4e-8,
cosine_scheduler = True,
scheduler_final_lr = 1e-5,
step_size_lr = 45,
)

In [28]:
model = biolord.Biolord(
    adata=adata,
    n_latent=256,
    model_name="sciplex3",
    module_params=module_params,
    train_classifiers=False,
    split_key="split",
)

[rank: 0] Seed set to 0


In [29]:
model.train(
    max_epochs=200,
    batch_size=512,
    plan_kwargs=trainer_params,
    early_stopping=True,
    early_stopping_patience=20,
    check_val_every_n_epoch=10,
    num_workers=10,
    enable_checkpointing=False
)

/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_biol ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  self.pid = os.fork()


Epoch 2/200:   0%|          | 1/200 [00:51<2:48:30, 50.81s/it, v_num=1, val_generative_mean_accuracy=0.102, val_generative_var_accuracy=-60.8, val_biolord_metric=-30.4, val_LOSS_KEYS.RECONSTRUCTION=326, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=232]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3/200:   1%|          | 2/200 [01:30<2:25:50, 44.19s/it, v_num=1, val_generative_mean_accuracy=0.336, val_generative_var_accuracy=-10.7, val_biolord_metric=-5.17, val_LOSS_KEYS.RECONSTRUCTION=290, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=211, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=212, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4/200:   2%|▏         | 3/200 [02:09<2:17:33, 41.90s/it, v_num=1, val_generative_mean_accuracy=0.465, val_generative_var_accuracy=-4.76, val_biolord_metric=-2.15, val_LOSS_KEYS.RECONSTRUCTION=275, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=192, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=172, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5/200:   2%|▏         | 4/200 [02:49<2:13:59, 41.02s/it, v_num=1, val_generative_mean_accuracy=0.574, val_generative_var_accuracy=-1.96, val_biolord_metric=-0.695, val_LOSS_KEYS.RECONSTRUCTION=262, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=174, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6/200:   2%|▎         | 5/200 [03:28<2:11:13, 40.38s/it, v_num=1, val_generative_mean_accuracy=0.662, val_generative_var_accuracy=-0.562, val_biolord_metric=0.0497, val_LOSS_KEYS.RECONSTRUCTION=250, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=158, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7/200:   3%|▎         | 6/200 [04:07<2:09:29, 40.05s/it, v_num=1, val_generative_mean_accuracy=0.729, val_generative_var_accuracy=-0.194, val_biolord_metric=0.268, val_LOSS_KEYS.RECONSTRUCTION=241, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=143, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=157, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8/200:   4%|▎         | 7/200 [04:47<2:08:30, 39.95s/it, v_num=1, val_generative_mean_accuracy=0.781, val_generative_var_accuracy=0.245, val_biolord_metric=0.513, val_LOSS_KEYS.RECONSTRUCTION=235, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=130, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9/200:   4%|▍         | 8/200 [05:27<2:07:29, 39.84s/it, v_num=1, val_generative_mean_accuracy=0.832, val_generative_var_accuracy=0.443, val_biolord_metric=0.638, val_LOSS_KEYS.RECONSTRUCTION=228, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=117, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 10/200:   4%|▍         | 9/200 [06:06<2:06:38, 39.78s/it, v_num=1, val_generative_mean_accuracy=0.865, val_generative_var_accuracy=0.65, val_biolord_metric=0.757, val_LOSS_KEYS.RECONSTRUCTION=224, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=105, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 11/200:   5%|▌         | 10/200 [06:46<2:05:59, 39.78s/it, v_num=1, val_generative_mean_accuracy=0.867, val_generative_var_accuracy=0.506, val_biolord_metric=0.687, val_LOSS_KEYS.RECONSTRUCTION=223, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=94.8, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 12/200:   6%|▌         | 11/200 [07:26<2:05:14, 39.76s/it, v_num=1, val_generative_mean_accuracy=0.895, val_generative_var_accuracy=0.722, val_biolord_metric=0.809, val_LOSS_KEYS.RECONSTRUCTION=219, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=85.1, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 13/200:   6%|▌         | 12/200 [08:05<2:04:17, 39.67s/it, v_num=1, val_generative_mean_accuracy=0.924, val_generative_var_accuracy=0.778, val_biolord_metric=0.851, val_LOSS_KEYS.RECONSTRUCTION=215, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=76.2, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 14/200:   6%|▋         | 13/200 [08:45<2:03:37, 39.67s/it, v_num=1, val_generative_mean_accuracy=0.927, val_generative_var_accuracy=0.669, val_biolord_metric=0.798, val_LOSS_KEYS.RECONSTRUCTION=215, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=68.1, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 15/200:   7%|▋         | 14/200 [09:24<2:02:46, 39.61s/it, v_num=1, val_generative_mean_accuracy=0.927, val_generative_var_accuracy=0.731, val_biolord_metric=0.829, val_LOSS_KEYS.RECONSTRUCTION=215, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=60.8, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 16/200:   8%|▊         | 15/200 [10:04<2:02:06, 39.60s/it, v_num=1, val_generative_mean_accuracy=0.941, val_generative_var_accuracy=0.694, val_biolord_metric=0.818, val_LOSS_KEYS.RECONSTRUCTION=213, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=54.1, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 17/200:   8%|▊         | 16/200 [10:44<2:01:29, 39.62s/it, v_num=1, val_generative_mean_accuracy=0.953, val_generative_var_accuracy=0.74, val_biolord_metric=0.846, val_LOSS_KEYS.RECONSTRUCTION=211, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=48, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]   

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 18/200:   8%|▊         | 17/200 [11:23<2:00:57, 39.66s/it, v_num=1, val_generative_mean_accuracy=0.949, val_generative_var_accuracy=0.745, val_biolord_metric=0.847, val_LOSS_KEYS.RECONSTRUCTION=212, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=42.5, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 19/200:   9%|▉         | 18/200 [12:03<2:00:09, 39.61s/it, v_num=1, val_generative_mean_accuracy=0.953, val_generative_var_accuracy=0.773, val_biolord_metric=0.863, val_LOSS_KEYS.RECONSTRUCTION=211, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=37.6, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=156, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 20/200:  10%|▉         | 19/200 [12:43<1:59:31, 39.62s/it, v_num=1, val_generative_mean_accuracy=0.959, val_generative_var_accuracy=0.749, val_biolord_metric=0.854, val_LOSS_KEYS.RECONSTRUCTION=211, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=33.2, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 21/200:  10%|█         | 20/200 [13:22<1:58:52, 39.63s/it, v_num=1, val_generative_mean_accuracy=0.956, val_generative_var_accuracy=0.722, val_biolord_metric=0.839, val_LOSS_KEYS.RECONSTRUCTION=211, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=29.2, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 22/200:  10%|█         | 21/200 [14:02<1:58:12, 39.62s/it, v_num=1, val_generative_mean_accuracy=0.963, val_generative_var_accuracy=0.692, val_biolord_metric=0.827, val_LOSS_KEYS.RECONSTRUCTION=210, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=25.6, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 23/200:  11%|█         | 22/200 [14:42<1:57:42, 39.68s/it, v_num=1, val_generative_mean_accuracy=0.963, val_generative_var_accuracy=0.719, val_biolord_metric=0.841, val_LOSS_KEYS.RECONSTRUCTION=210, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=22.4, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 24/200:  12%|█▏        | 23/200 [15:22<1:57:12, 39.73s/it, v_num=1, val_generative_mean_accuracy=0.953, val_generative_var_accuracy=0.889, val_biolord_metric=0.921, val_LOSS_KEYS.RECONSTRUCTION=211, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=19.6, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 25/200:  12%|█▏        | 24/200 [16:01<1:56:26, 39.70s/it, v_num=1, val_generative_mean_accuracy=0.97, val_generative_var_accuracy=0.841, val_biolord_metric=0.906, val_LOSS_KEYS.RECONSTRUCTION=209, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=17.1, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 26/200:  12%|█▎        | 25/200 [16:41<1:55:52, 39.73s/it, v_num=1, val_generative_mean_accuracy=0.965, val_generative_var_accuracy=0.795, val_biolord_metric=0.88, val_LOSS_KEYS.RECONSTRUCTION=209, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=14.9, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 27/200:  13%|█▎        | 26/200 [17:21<1:55:16, 39.75s/it, v_num=1, val_generative_mean_accuracy=0.972, val_generative_var_accuracy=0.887, val_biolord_metric=0.929, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=12.9, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 28/200:  14%|█▎        | 27/200 [18:00<1:54:31, 39.72s/it, v_num=1, val_generative_mean_accuracy=0.971, val_generative_var_accuracy=0.843, val_biolord_metric=0.907, val_LOSS_KEYS.RECONSTRUCTION=209, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=11.2, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 29/200:  14%|█▍        | 28/200 [18:40<1:54:04, 39.79s/it, v_num=1, val_generative_mean_accuracy=0.974, val_generative_var_accuracy=0.917, val_biolord_metric=0.945, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=9.63, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 30/200:  14%|█▍        | 29/200 [19:20<1:53:29, 39.82s/it, v_num=1, val_generative_mean_accuracy=0.97, val_generative_var_accuracy=0.883, val_biolord_metric=0.927, val_LOSS_KEYS.RECONSTRUCTION=209, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=8.29, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 31/200:  15%|█▌        | 30/200 [20:00<1:52:54, 39.85s/it, v_num=1, val_generative_mean_accuracy=0.975, val_generative_var_accuracy=0.921, val_biolord_metric=0.948, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=7.12, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.03e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 32/200:  16%|█▌        | 31/200 [20:40<1:52:17, 39.87s/it, v_num=1, val_generative_mean_accuracy=0.974, val_generative_var_accuracy=0.888, val_biolord_metric=0.931, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=6.1, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5] 

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 33/200:  16%|█▌        | 32/200 [21:20<1:51:45, 39.92s/it, v_num=1, val_generative_mean_accuracy=0.977, val_generative_var_accuracy=0.915, val_biolord_metric=0.946, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=5.21, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 34/200:  16%|█▋        | 33/200 [22:00<1:51:17, 39.99s/it, v_num=1, val_generative_mean_accuracy=0.977, val_generative_var_accuracy=0.949, val_biolord_metric=0.963, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=4.44, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 35/200:  17%|█▋        | 34/200 [22:40<1:50:33, 39.96s/it, v_num=1, val_generative_mean_accuracy=0.98, val_generative_var_accuracy=0.95, val_biolord_metric=0.965, val_LOSS_KEYS.RECONSTRUCTION=207, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=3.77, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]  

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 36/200:  18%|█▊        | 35/200 [23:20<1:49:45, 39.91s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=0.961, val_biolord_metric=0.97, val_LOSS_KEYS.RECONSTRUCTION=207, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=3.2, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 37/200:  18%|█▊        | 36/200 [24:00<1:49:14, 39.97s/it, v_num=1, val_generative_mean_accuracy=0.975, val_generative_var_accuracy=0.951, val_biolord_metric=0.963, val_LOSS_KEYS.RECONSTRUCTION=208, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=2.7, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 38/200:  18%|█▊        | 37/200 [24:40<1:48:36, 39.98s/it, v_num=1, val_generative_mean_accuracy=0.979, val_generative_var_accuracy=0.955, val_biolord_metric=0.967, val_LOSS_KEYS.RECONSTRUCTION=207, val_LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY=2.28, generative_mean_accuracy=0, generative_var_accuracy=0, biolord_metric=0, reconstruction_loss=155, unknown_attribute_penalty_loss=1.02e+5]

  self.pid = os.fork()
  self.pid = os.fork()
/home/icb/dominik.klein/mambaforge/envs/ot_pert_biolord/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
  self.pid = os.fork()


In [31]:
def bool2idx(x):
    """
    Returns the indices of the True-valued entries in a boolean array `x`
    """
    return np.where(x)[0]

def repeat_n(x, n):
    """
    Returns an n-times repeated version of the Tensor x,
    repetition dimension is axis 0
    """
    # copy tensor to device BEFORE replicating it n times
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return x.to(device).view(1, -1).repeat(n, 1)


In [37]:
idx_test_control = np.where(
    (adata.obs["split"] == "test") & (adata.obs["control"] == 1)
)[0]

adata_test_control = adata[idx_test_control].copy()

idx_ood = np.where((adata.obs["split"] == "ood"))[0]

adata_ood = adata[idx_ood].copy()
dataset_ood = model.get_dataset(adata_ood)

[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


In [38]:
dataset_ood = model.get_dataset(adata_ood)

In [43]:
import torch
from tqdm import tqdm

In [45]:
dataset = dataset_ood

In [109]:
pert_categories_index = pd.Index(adata_ood.obs["condition"].values, dtype="category")
allowed_cell_lines = []
n_obs = 500

cl_dict = {
    torch.Tensor([0.]): "A549",
    torch.Tensor([1.]): "K562",
    torch.Tensor([2.]): "MCF7",
}

layer = "X" if "X" in dataset else "layers"
cell_lines = ["A549", "K562", "MCF7"]

predictions_dict = {}
for cell_drug_dose_comb in tqdm(
    np.unique(pert_categories_index.values)):

    cell_type_idx = np.unique(dataset["cell_type"][idx_all])
    assert len(cell_type_idx) == 1
    dataset_control = model.get_dataset(adata_test_control[adata_test_control.obs["_scvi_cell_type"]==cell_type_idx[0]])
    
    bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
    idx_all = bool2idx(bool_category)
    idx = idx_all[0]  
    dataset_comb = {}

    
    dataset_comb[layer] = dataset_control[layer].to(model.device)
    dataset_comb["ind_x"] = dataset_control["ind_x"].to(model.device)
    for key in dataset_control:
        if key not in [layer, "ind_x"]:
            dataset_comb[key] = repeat_n(dataset[key][idx, :], n_obs)

    stop = False
    for tensor, cl in cl_dict.items():
        if (tensor == dataset["cell_type"][idx]).all():
            if cl not in cell_lines:
                stop = True
    if stop:
        continue
        
    pred_mean, pred_std = model.module.get_expression(dataset_comb)
    samples = torch.normal(pred_mean, pred_std)

    predictions_dict[cell_drug_dose_comb] = samples
    break



  0%|          | 0/40 [00:00<?, ?it/s][A

[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


  0%|          | 0/40 [00:05<?, ?it/s]


In [None]:
# either convert dict to adata layers, or just save it as a dictionary. the latter is fine, too.

In [110]:
predictions_dict

{'A549_Alvespimycin_(17-DMAG)_HCl_10.0': tensor([[ 2.5184e-01,  5.2117e-01,  1.3084e-01,  ..., -1.9506e-05,
           1.9830e-04, -5.7395e-04],
         [ 2.7082e-01,  7.6124e-01,  4.0564e-01,  ..., -1.7216e-05,
           2.0655e-04, -7.6847e-04],
         [ 1.7310e-01,  2.2424e-01,  3.2515e-01,  ..., -1.8057e-05,
           2.0517e-04, -1.0013e-03],
         ...,
         [ 1.5208e-01, -1.0546e-01, -5.5587e-02,  ..., -1.6713e-05,
           2.0690e-04, -9.8138e-04],
         [-2.1371e-01,  2.3126e-01,  6.0327e-02,  ..., -1.7787e-05,
           2.0536e-04, -1.3881e-03],
         [ 1.8769e-01,  3.1853e-02,  3.0973e-01,  ..., -1.7503e-05,
           2.0499e-04, -1.1228e-03]], device='cuda:0')}