In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import numpy as np
import scanpy as sc
import pandas as pd
import anndata
import torch

import biolord

from tqdm.auto import tqdm
from scipy.sparse import issparse

In [3]:
# from biolord_reproducibility/utils/utils_perturbations.py
def bool2idx(x):
    """
    Returns the indices of the True-valued entries in a boolean array `x`
    """
    return np.where(x)[0]

def repeat_n(x, n):
    """combo_seen2
    Returns an n-times repeated version of the Tensor x,
    repetition dimension is axis 0
    """
    # copy tensor to device BEFORE replicating it n times
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return x.to(device).view(1, -1).repeat(n, 1)

In [4]:
# from biolord_reproducibility/scripts/biolord/norman/norman_optimal_config.py
varying_arg = {
    "seed": 42,
    "unknown_attribute_noise_param": 0.2, 
    "use_batch_norm": False,
    "use_layer_norm": False, 
    "step_size_lr": 45, 
    "attribute_dropout_rate": 0.0, 
    "cosine_scheduler":True,
    "scheduler_final_lr":1e-5,
    "n_latent":32, 
    "n_latent_attribute_ordered": 32,
    "reconstruction_penalty": 10000.0,
    "attribute_nn_width": 64,
    "attribute_nn_depth" :2, 
    "attribute_nn_lr": 0.001, 
    "attribute_nn_wd": 4e-8,
    "latent_lr": 0.01,
    "latent_wd": 0.00001,
    "decoder_width": 32,
    "decoder_depth": 2,  
    "decoder_activation": True,
    "attribute_nn_activation": True,
    "unknown_attributes": False,
    "decoder_lr": 0.01,
    "decoder_wd": 0.01,
    "max_epochs":200,
    "early_stopping_patience": 200,
    "ordered_attributes_key": "perturbation_neighbors1",
    "n_latent_attribute_categorical": 16,
    "unknown_attribute_penalty": 10000.0,
}

module_params = {
    "attribute_nn_width":  varying_arg["attribute_nn_width"],
    "attribute_nn_depth": varying_arg["attribute_nn_depth"],
    "use_batch_norm": varying_arg["use_batch_norm"],
    "use_layer_norm": varying_arg["use_layer_norm"],
    "attribute_dropout_rate":  varying_arg["attribute_dropout_rate"],
    "unknown_attribute_noise_param": varying_arg["unknown_attribute_noise_param"],
    "seed": varying_arg["seed"],
    "n_latent_attribute_ordered": varying_arg["n_latent_attribute_ordered"],
    "n_latent_attribute_categorical": varying_arg["n_latent_attribute_categorical"],
    "reconstruction_penalty": varying_arg["reconstruction_penalty"],
    "unknown_attribute_penalty": varying_arg["unknown_attribute_penalty"],
    "decoder_width": varying_arg["decoder_width"],
    "decoder_depth": varying_arg["decoder_depth"],
    "decoder_activation": varying_arg["decoder_activation"],
    "attribute_nn_activation": varying_arg["attribute_nn_activation"],
    "unknown_attributes": varying_arg["unknown_attributes"],
}

trainer_params = {
    "n_epochs_warmup": 0,
    "latent_lr": varying_arg["latent_lr"],
    "latent_wd": varying_arg["latent_wd"],
    "attribute_nn_lr": varying_arg["attribute_nn_lr"],
    "attribute_nn_wd": varying_arg["attribute_nn_wd"],
    "step_size_lr": varying_arg["step_size_lr"],
    "cosine_scheduler": varying_arg["cosine_scheduler"],
    "scheduler_final_lr": varying_arg["scheduler_final_lr"],
    "decoder_lr": varying_arg["decoder_lr"],
    "decoder_wd": varying_arg["decoder_wd"]
}

In [21]:
path_in = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019"
adata = sc.read_h5ad(os.path.join(path_in, "norman_preprocessed_adata", "adata_all.h5ad"))

for split_idx in tqdm(range(5)):
    adata_train = sc.read_h5ad(os.path.join(path_in, "norman_preprocessed_adata", f"adata_train_pca_3_split_{split_idx}.h5ad"))
    adata_val = sc.read_h5ad(os.path.join(path_in, "norman_preprocessed_adata", f"adata_val_pca_3_split_{split_idx}.h5ad"))
    adata_test = sc.read_h5ad(os.path.join(path_in, "norman_preprocessed_adata", f"adata_test_pca_3_split_{split_idx}.h5ad"))

    train_conditions = adata_train.obs.condition.cat.categories
    test_conditions = adata_val.obs.condition.cat.categories
    ood_conditions = adata_test.obs.condition.cat.categories

    train_idcs = adata.obs.condition.isin(train_conditions)
    test_idcs = adata.obs.condition.isin(test_conditions)
    ood_idcs = adata.obs.condition.isin(ood_conditions)

    adata.obs.loc[train_idcs, f"split{split_idx}"] = "train"
    adata.obs.loc[test_idcs, f"split{split_idx}"] = "test"
    adata.obs.loc[ood_idcs, f"split{split_idx}"] = "ood"

    _seed = split_idx + 1
    path_to_splits = os.path.join(path_in, f"splits/norman2019_simulation_{_seed}_0.75.pkl")
    split_dict = pd.read_pickle(path_to_splits)

    print(
        "train",
        len(np.intersect1d(
            adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "train", "condition"].values.unique(),
            split_dict["train"]
        )),
        len(split_dict["train"]), 
        len(adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "train", "condition"].values.unique())
    )
    
    print(
        "val",
        len(np.intersect1d(
            adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "test", "condition"].values.unique(),
            split_dict["val"]
        )),
        len(split_dict["val"]), 
        len(adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "test", "condition"].values.unique())
    )
    
    print(
        "test",
        len(np.intersect1d(
            adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "ood", "condition"].values.unique(),
            split_dict["test"]
        )),
        len(split_dict["test"]), 
        len(adata.obs.loc[adata.obs.loc[:, f"split{split_idx}"] == "ood", "condition"].values.unique())
    )

  0%|          | 0/5 [00:00<?, ?it/s]

train 138 139 138
val 31 31 31
test 107 107 108
train 139 140 139
val 20 20 20
test 117 117 118
train 132 133 132
val 35 35 35
test 109 109 110
train 139 140 139
val 26 26 26
test 111 111 112
train 130 131 130
val 29 29 29
test 117 117 118


In [22]:
adata.obs.loc[:, ["perturbation", "perturbation_rep"]] = adata.obs.condition.str.split("+", expand=True).rename(
    {0:"perturbation", 1: "perturbation_rep"}, axis=1
).fillna("ctrl")

adata.obs.perturbation = adata.obs.perturbation.astype(object)
adata.obs.perturbation_rep = adata.obs.perturbation_rep.astype(object)
switch_mask = (adata.obs.perturbation == "ctrl") & (adata.obs.perturbation_rep != "ctrl")    
_perturbation = adata.obs.loc[switch_mask, "perturbation"].values
_perturbation_rep = adata.obs.loc[switch_mask, "perturbation_rep"].values
adata.obs.loc[switch_mask, "perturbation"] = _perturbation_rep
adata.obs.loc[switch_mask, "perturbation_rep"] = _perturbation
adata.obs.perturbation = adata.obs.perturbation.astype("category")
adata.obs.perturbation_rep = adata.obs.perturbation_rep.astype("category")

In [23]:
# following:
# https://github.com/nitzanlab/biolord_reproducibility/blob/main/notebooks/perturbations/norman/1_perturbations_norman_preprocessing.ipynb

go_path = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/go_essential_norman.csv"
gene_path = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/essential_norman.pkl"
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)

df = df[df["source"].isin(gene_list)]
def get_map(pert):
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    tmp.loc[df[df.target == pert].source.values, :] = df[df.target == pert].importance.values[:, np.newaxis]
    return tmp.values.flatten()

pert2neighbor =  {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}    
adata.uns["pert2neighbor"] = pert2neighbor

pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
keep_idx = pert2neighbor.sum(0) > 0
keep_idx1 = pert2neighbor.sum(0) > 1
keep_idx2 = pert2neighbor.sum(0) > 2
keep_idx3 = pert2neighbor.sum(0) > 3

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten() 
df_perts_expression = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []
single_pert_val = []
double_perts = []
for pert in adata.obs["condition"].cat.categories:
    if len(pert.split("+")) == 1:
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = pert.split("+")
        if p2 == "ctrl":
            single_pert_val.append(p1)
        else:
            single_pert_val.append(p2)
    else:
        double_perts.append(pert)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(
    df_perts_expression.set_index("condition").loc[single_perts_condition].values, 
    index=single_pert_val
)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])
df_singleperts_emb1 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx1] for p1 in df_singleperts_expression.index])
df_singleperts_emb2 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx2] for p1 in df_singleperts_expression.index])
df_singleperts_emb3 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx3] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

df_doubleperts_expression = df_perts_expression.set_index("condition").loc[double_perts].values
df_doubleperts_condition = pd.Index(double_perts)
adata_single = anndata.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb
adata_single.obsm["perturbation_neighbors1"] = df_singleperts_emb1
adata_single.obsm["perturbation_neighbors2"] = df_singleperts_emb2
adata_single.obsm["perturbation_neighbors3"] = df_singleperts_emb3

for split_seed in range(5):
    adata_single.obs[f"split{split_seed}"] = None
    for cat in ["train","test","ood"]:
        cat_idx = adata_single.obs["condition"].isin(
            adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories
        )
        adata_single.obs.loc[cat_idx ,f"split{split_seed}"] = cat

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
  df_perts_expression = df_perts_expression.groupby(["condition"]).mean()


In [24]:
np.random.seed(42)

ordered_attributes_key = varying_arg["ordered_attributes_key"]

biolord.Biolord.setup_anndata(
    adata_single,
    ordered_attributes_keys=[ordered_attributes_key],
    categorical_attributes_keys=None,
    retrieval_attribute_key=None,
)

[34mINFO    [0m Generating sequential column names                                                                        


In [25]:
def get_predictions(
    obs_train: pd.DataFrame, 
    obs_test: pd.DataFrame, 
    adata_single: anndata.AnnData,  
    ordered_attributes_key: str,
    df_singleperts_expression: pd.DataFrame, 
    df_single_pert_val: pd.core.indexes.base.Index,
) -> anndata.AnnData:

    train_perturbations = obs_train.condition.cat.categories
    test_perturbations = obs_test.condition.cat.categories

    obs_list, X_pred_list = [], []
    
    for perturbation in tqdm(test_perturbations):
        if perturbation == "ctrl": 
            # skip unperturbed condition
            continue
        assert perturbation not in train_perturbations
        obs_this_perturbation = obs_test.loc[obs_test.condition == perturbation]
        obs_list.append(obs_this_perturbation)
        num_samples_this_perturbation = obs_this_perturbation.shape[0]

        print(perturbation, num_samples_this_perturbation)

        adata_control = adata_single[adata_single.obs["condition"] == "ctrl"].copy()
        n_obs = adata_control.shape[0] # will be 1
        # dict with keys: 'X', 'ind_x' and 'perturbation_neighbors1'
        # X has shape [1, 5045]
        dataset_control: dict = model.get_dataset(adata_control)
        
        # dict with keys: 'X', 'ind_x' and 'perturbation_neighbors1'
        # X has shape [149, 5045]
        dataset_reference: dict = model.get_dataset(adata_single)
        
        if "ctrl" in perturbation:
    
            # single perturbation
            
            # copy ctrl cells
            dataset_pred = dataset_control.copy()
            # add 'perturbation_neighbors1' information
            idx_ref =  bool2idx(adata_single.obs["condition"] == perturbation)[0]
            dataset_pred[ordered_attributes_key] = repeat_n(dataset_reference[ordered_attributes_key][idx_ref, :], n_obs)
            pred_mean, pred_std = model.module.get_expression(dataset_pred)
            samples = torch.normal(
                pred_mean.repeat(num_samples_this_perturbation, 1),
                pred_std.repeat(num_samples_this_perturbation, 1)
            ).detach().cpu().numpy()
            X_pred_list.append(samples)
    
        else:
            # double perturbation
            test_preds_add = []
            for p in perturbation.split("+"):
                if p in train_perturbations:
                    # the perturbation corresponds to a perturbation that was seen as single perturbation during training
                    # instead of predicting the response, we hence use the training data of the single perturbation
                    test_predsp = df_singleperts_expression.values[df_single_pert_val.isin([p]), :]
                    test_preds_add.append(test_predsp.repeat(num_samples_this_perturbation, 1))
                    
                else:
                    # the perturbation does not correspond to any single perturbation seen during training
                    idx_ref =  bool2idx(adata_single.obs["perts_name"].isin([p]))[0]
                    dataset_pred = dataset_control.copy()
                    dataset_pred[ordered_attributes_key] = repeat_n(dataset_reference[ordered_attributes_key][idx_ref, :], n_obs)
                    pred_mean, pred_std = model.module.get_expression(dataset_pred)
                    _samples = torch.normal(
                        pred_mean.repeat(num_samples_this_perturbation, 1),
                        pred_std.repeat(num_samples_this_perturbation, 1)
                    ).detach().cpu().numpy()
                    test_preds_add.append(_samples)
                    
            samples = test_preds_add[0] + test_preds_add[1] - ctrl
            X_pred_list.append(samples)

    adata_pred = anndata.AnnData(
        X=np.vstack(X_pred_list),
        obs=pd.concat(obs_list),
    )
    return adata_pred

In [26]:
results = {}

In [33]:
for split_idx in range(5):

    model = biolord.Biolord(
            adata=adata_single,
            n_latent=varying_arg["n_latent"],
            model_name="norman",
            module_params=module_params,
            train_classifiers=False,
            split_key=f"split{split_idx}"
    )
    
    model.train(
        max_epochs=int(varying_arg["max_epochs"]),
        batch_size=32,
        plan_kwargs=trainer_params,
        early_stopping=True,
        early_stopping_patience=int(varying_arg["early_stopping_patience"]),
        check_val_every_n_epoch=5,
        num_workers=1,
        enable_checkpointing=False
    )

    adata_test_pred = get_predictions(
        adata[adata.obs.loc[:, f"split{split_idx}"] == "train"].obs,
        adata[adata.obs.loc[:, f"split{split_idx}"] == "test"].obs,
        adata_single, 
        ordered_attributes_key,
        df_singleperts_expression, 
        df_single_pert_val,
    )

    adata_ood_pred = get_predictions(
        adata[adata.obs.loc[:, f"split{split_idx}"] == "train"].obs,
        adata[adata.obs.loc[:, f"split{split_idx}"] == "ood"].obs,
        adata_single,  
        ordered_attributes_key,
        df_singleperts_expression, 
        df_single_pert_val,
    )

    ### add subgroup annotations to adata_ood_pred ###

    train_mask = adata_single.obs.loc[:, f"split{split_idx}"] == "train"
    train_genes = adata_single[train_mask].obs.condition.str.replace("+ctrl", "").str.replace("ctrl+", "").values

    mask_single_perturbation = adata_ood_pred.obs.condition.str.contains("ctrl")
    mask_double_perturbation_seen_0 = (
        ~adata_ood_pred.obs.condition.str.contains("ctrl") & 
        ~adata_ood_pred.obs.gene_1.isin(train_genes) & 
        ~adata_ood_pred.obs.gene_2.isin(train_genes)
    )
    mask_double_perturbation_seen_1 = (
        ~adata_ood_pred.obs.condition.str.contains("ctrl") & 
        (
            (adata_ood_pred.obs.gene_1.isin(train_genes) & ~adata_ood_pred.obs.gene_2.isin(train_genes)) | 
            (~adata_ood_pred.obs.gene_1.isin(train_genes) & adata_ood_pred.obs.gene_2.isin(train_genes))
        )
    )
    mask_double_perturbation_seen_2 = (
        ~adata_ood_pred.obs.condition.str.contains("ctrl") & 
        adata_ood_pred.obs.gene_1.isin(train_genes) & 
        adata_ood_pred.obs.gene_2.isin(train_genes)
    )
    adata_ood_pred.obs.loc[mask_single_perturbation, "subgroup"] = "single"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"

    display(adata_ood_pred.obs.subgroup.value_counts())

    results[split_idx] = {
        "test": adata_test_pred, 
        "ood": adata_ood_pred,
    }

[rank: 0] Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-p

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/31 [00:00<?, ?it/s]

ARRDC3+ctrl 405
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ZBTB1 283
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C19orf26+ctrl 480
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CITED1+ctrl 169
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CNNM4+ctrl 376
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ELMSAN1+ctrl 353
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ETS2+MAP7D1 265
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
FEV+CBFA2T3 

  0%|          | 0/108 [00:00<?, ?it/s]

AHR+KLF1 412
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ARID1A+ctrl 182
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+BAK1 153
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+ctrl 463
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+TGFBR2 382
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCORL1+ctrl 456
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+FOXL2 49
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+ct

subgroup
double_seen_1    11417
single           10817
double_seen_2     4593
double_seen_0     1927
Name: count, dtype: int64

[rank: 0] Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-p

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/20 [00:00<?, ?it/s]

C3orf72+ctrl 217
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CEBPB+OSR2 188
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CELF2+ctrl 388
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
DLX2+ctrl 316
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
FEV+CBFA2T3 153
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
IGDCC3+ZBTB25 99
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
JUN+CEBPB 52
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
JUN+ctrl 235
[34m

  0%|          | 0/118 [00:00<?, ?it/s]

BAK1+ctrl 534
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+BAK1 153
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ctrl 393
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+SAMD1 240
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ZBTB1 283
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+FOXL2 49
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+CNN1 288
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+UBASH3A 50
[34m

subgroup
double_seen_1    13066
single           12151
double_seen_2     4964
double_seen_0     4091
Name: count, dtype: int64

[rank: 0] Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-p

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/35 [00:00<?, ?it/s]

BAK1+ctrl 534
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+TGFBR2 382
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ctrl 393
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+SAMD1 240
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+TGFBR2 156
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CEBPE+CEBPB 111
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
DUSP9+ETS2 698
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ETS2+IKZF3 388
[

  0%|          | 0/110 [00:00<?, ?it/s]

AHR+FEV 264
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
AHR+KLF1 412
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
AHR+ctrl 479
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ARID1A+ctrl 182
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ARRDC3+ctrl 405
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ATL1+ctrl 305
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+BAK1 153
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ZBTB1 283
[34mINFO

subgroup
double_seen_1    13048
single           10181
double_seen_0     4449
double_seen_2     2592
Name: count, dtype: int64

[rank: 0] Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-p

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/26 [00:00<?, ?it/s]

BAK1+ctrl 534
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+BAK1 153
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+PTPN9 234
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+PTPN12 257
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CDKN1B+CDKN1A 98
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CDKN1B+ctrl 268
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CEBPB+MAPK1 337
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CEBPB+PTPN12 266

  0%|          | 0/112 [00:00<?, ?it/s]

AHR+FEV 264
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
AHR+KLF1 412
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
AHR+ctrl 479
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ARID1A+ctrl 182
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
ATL1+ctrl 305
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BCL2L11+TGFBR2 382
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+SAMD1 240
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+FOXL2 49
[34mI

subgroup
double_seen_1    13587
single           11185
double_seen_2     2815
double_seen_0     1873
Name: count, dtype: int64

[rank: 0] Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-p

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/29 [00:00<?, ?it/s]

CBFA2T3+ctrl 288
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+CNN1 288
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+ctrl 538
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+PTPN12 257
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+TGFBR2 156
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+UBASH3B 326
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CEBPE+CNN1 194
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CNN1+ctrl 236
[34mIN

  0%|          | 0/118 [00:00<?, ?it/s]

AHR+KLF1 412
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ctrl 393
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+SAMD1 240
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
BPGM+ZBTB1 283
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+FOXL2 49
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C3orf72+ctrl 217
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
C19orf26+ctrl 480
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
CBL+PTPN9 234
[

subgroup
double_seen_1    14554
single           13020
double_seen_2     3934
double_seen_0     1587
Name: count, dtype: int64

In [None]:
OUT_DIR =  "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/results"

In [39]:
# save predictions

for split_seed in results.keys():

    adata[adata.obs.loc[:, f"split{split_seed}"] == "train"].write_h5ad(os.path.join(OUT_DIR, f"adata_train_{split_seed}.h5ad"))
    results[split_seed]["test"].write_h5ad(os.path.join(OUT_DIR, f"biolord_output_test_{split_seed}.h5ad"))

    # save entire ood adata
    results[split_seed]["ood"].write_h5ad(os.path.join(OUT_DIR, f"biolord_output_ood_{split_seed}.h5ad"))
    
    # save per subgroup ood adatas
    for subgroup in ["single", "double_seen_0", "double_seen_1", "double_seen_2"]:
        results[split_seed]["ood"][results[split_seed]["ood"].obs.loc[:, "subgroup"] == subgroup].write_h5ad(
            os.path.join(OUT_DIR, f"biolord_output_ood_{subgroup}_split_{split_seed}.h5ad")
        )

  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c


In [40]:
# save ground truth data

for split_seed in results.keys():
    adata[adata.obs.loc[:, f"split{split_seed}"] == "train"].write_h5ad(os.path.join(OUT_DIR, f"adata_train_{split_seed}.h5ad"))
    adata[adata.obs.loc[:, f"split{split_seed}"] == "test"].write_h5ad(os.path.join(OUT_DIR, f"adata_test_{split_seed}.h5ad"))

    for subgroup in ["single", "double_seen_0", "double_seen_1", "double_seen_2"]:
        
        _adata_ood = results[split_seed]["ood"]
        subgroup_conditions = _adata_ood[_adata_ood.obs.loc[:, "subgroup"] == subgroup].obs.condition.unique()
        select = (
            (adata.obs.loc[:, f"split{split_seed}"] == "ood") & 
            adata.obs.condition.isin(subgroup_conditions)
        )
        adata[select].write_h5ad(os.path.join(OUT_DIR, f"adata_ood_{subgroup}_split_{split_seed}.h5ad"))

  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[k