In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import numpy as np
import scanpy as sc
import pandas as pd
import anndata
import torch

import biolord

from tqdm.auto import tqdm
from scipy.sparse import issparse

##### Setup model hyperparams & helper functions

In [3]:
# from biolord_reproducibility/utils/utils_perturbations.py
def bool2idx(x):
    """
    Translate a boolean mask `x` into the positions of its True entries.

    Only the first-axis indices are returned: for a one-dimensional mask this
    is simply the list of True positions, for higher-rank input it is the row
    index of every True element.
    """
    mask = np.asarray(x)
    first_axis_hits = mask.nonzero()[0]
    return first_axis_hits

def repeat_n(x, n):
    """
    Tile the Tensor `x` into `n` identical rows.

    `x` is flattened into a single row of shape (1, numel) and that row is
    repeated `n` times along axis 0, giving a tensor of shape (n, numel).
    The result lives on the GPU when CUDA is available, otherwise on the CPU.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    # transfer first, tile second: only a single row has to be moved
    single_row = x.to(target_device).view(1, -1)
    return single_row.repeat(n, 1)

In [4]:
# from biolord_reproducibility/scripts/biolord/norman/norman_optimal_config.py
# Single source of truth for all hyperparameters; the module and trainer
# dictionaries further down are sliced out of it.
varying_arg = dict(
    seed=42,
    unknown_attribute_noise_param=0.2,
    use_batch_norm=False,
    use_layer_norm=False,
    step_size_lr=45,
    attribute_dropout_rate=0.0,
    cosine_scheduler=True,
    scheduler_final_lr=1e-5,
    n_latent=32,
    n_latent_attribute_ordered=32,
    reconstruction_penalty=10000.0,
    attribute_nn_width=64,
    attribute_nn_depth=2,
    attribute_nn_lr=0.001,
    attribute_nn_wd=4e-8,
    latent_lr=0.01,
    latent_wd=0.00001,
    decoder_width=32,
    decoder_depth=2,
    decoder_activation=True,
    attribute_nn_activation=True,
    unknown_attributes=False,
    decoder_lr=0.01,
    decoder_wd=0.01,
    max_epochs=200,
    early_stopping_patience=200,
    ordered_attributes_key="perturbation_neighbors1",
    n_latent_attribute_categorical=16,
    unknown_attribute_penalty=10000.0,
)

# Entries of `varying_arg` that are handed to the model via `module_params`.
_MODULE_PARAM_KEYS = (
    "attribute_nn_width",
    "attribute_nn_depth",
    "use_batch_norm",
    "use_layer_norm",
    "attribute_dropout_rate",
    "unknown_attribute_noise_param",
    "seed",
    "n_latent_attribute_ordered",
    "n_latent_attribute_categorical",
    "reconstruction_penalty",
    "unknown_attribute_penalty",
    "decoder_width",
    "decoder_depth",
    "decoder_activation",
    "attribute_nn_activation",
    "unknown_attributes",
)
module_params = {key: varying_arg[key] for key in _MODULE_PARAM_KEYS}

# Entries of `varying_arg` that are handed to training via `plan_kwargs`;
# the warm-up length is fixed and not part of `varying_arg`.
_TRAINER_PARAM_KEYS = (
    "latent_lr",
    "latent_wd",
    "attribute_nn_lr",
    "attribute_nn_wd",
    "step_size_lr",
    "cosine_scheduler",
    "scheduler_final_lr",
    "decoder_lr",
    "decoder_wd",
)
trainer_params = {"n_epochs_warmup": 0}
trainer_params.update((key, varying_arg[key]) for key in _TRAINER_PARAM_KEYS)

##### biolord expects a single AnnData object whose `obs` records, for every split index, the partition each cell belongs to (the train/val/test files are labelled `train`/`test`/`ood` below)

In [5]:
# Root directory of the preprocessed Norman et al. (2019) data. The original
# location remains the default; set the NORMAN_DATA_DIR environment variable
# to run the notebook against a different copy of the data.
path_in = os.environ.get(
    "NORMAN_DATA_DIR",
    "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019",
)
preprocessed_dir = os.path.join(path_in, "norman_preprocessed_adata")
adata = sc.read_h5ad(os.path.join(preprocessed_dir, "adata_all.h5ad"))

# split index -> number of control cells in that split's test file
ctrl_samples_per_split = {}

for split_idx in tqdm(range(5)):
    adata_train = sc.read_h5ad(os.path.join(preprocessed_dir, f"adata_train_pca_50_split_{split_idx}.h5ad"))
    adata_val = sc.read_h5ad(os.path.join(preprocessed_dir, f"adata_val_pca_50_split_{split_idx}.h5ad"))
    adata_test = sc.read_h5ad(os.path.join(preprocessed_dir, f"adata_test_pca_50_split_{split_idx}.h5ad"))

    # will be used to determine number of samples to take from biolord
    ctrl_samples_per_split[split_idx] = int(adata_test.obs.control.sum())

    # Naming follows biolord from here on: the validation file becomes the
    # "test" partition and the test file becomes the "ood" partition.
    train_conditions = adata_train.obs.condition.cat.categories
    test_conditions = adata_val.obs.condition.cat.categories
    ood_conditions = adata_test.obs.condition.cat.categories

    train_idcs = adata.obs.condition.isin(train_conditions)
    test_idcs = adata.obs.condition.isin(test_conditions)
    ood_idcs = adata.obs.condition.isin(ood_conditions)

    # The three assignments run in a fixed order, so a condition that occurs in
    # more than one file keeps the label written last.
    # NOTE(review): if every file contains "ctrl" cells, "ctrl" ends up labelled
    # "ood" and is therefore absent from the "train" partition - confirm this is intended.
    split_col = f"split{split_idx}"
    adata.obs.loc[train_idcs, split_col] = "train"
    adata.obs.loc[test_idcs, split_col] = "test"
    adata.obs.loc[ood_idcs, split_col] = "ood"

  0%|          | 0/5 [00:00<?, ?it/s]

In [7]:
# Inspect the number of control cells counted in each split's test file.
ctrl_samples_per_split

{0: 500, 1: 500, 2: 500, 3: 500, 4: 500}

In [8]:
# Break every "A+B" condition label into its two components. The plain "ctrl"
# label has no second component, so the missing value is filled with "ctrl".
condition_parts = (
    adata.obs.condition.str.split("+", expand=True)
    .rename({0: "perturbation", 1: "perturbation_rep"}, axis=1)
    .fillna("ctrl")
)
adata.obs.loc[:, ["perturbation", "perturbation_rep"]] = condition_parts

# Use plain object columns while values are moved between the two columns.
adata.obs["perturbation"] = adata.obs["perturbation"].astype(object)
adata.obs["perturbation_rep"] = adata.obs["perturbation_rep"].astype(object)

# Rows of the form "ctrl+GENE" hold the gene in the second column; swap the two
# entries so that "ctrl" can only appear first for the unperturbed condition.
switch_mask = (adata.obs["perturbation"] == "ctrl") & (adata.obs["perturbation_rep"] != "ctrl")
first_slot_before_swap = adata.obs.loc[switch_mask, "perturbation"].values
second_slot_before_swap = adata.obs.loc[switch_mask, "perturbation_rep"].values
adata.obs.loc[switch_mask, "perturbation"] = second_slot_before_swap
adata.obs.loc[switch_mask, "perturbation_rep"] = first_slot_before_swap

adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")
adata.obs["perturbation_rep"] = adata.obs["perturbation_rep"].astype("category")

In [9]:
# following:
# https://github.com/nitzanlab/biolord_reproducibility/blob/main/notebooks/perturbations/norman/1_perturbations_norman_preprocessing.ipynb

# Both input files sit directly inside the data directory `path_in`.
go_path = os.path.join(path_in, "go_essential_norman.csv")
gene_path = os.path.join(path_in, "essential_norman.pkl")

# Rows kept per target gene are the N_GO_NEIGHBORS most important ones plus one more.
# NOTE(review): the extra row presumably accounts for the gene itself - confirm
# against the reference notebook linked above.
N_GO_NEIGHBORS = 20

df = pd.read_csv(go_path)
# Pick the highest-importance rows of every target. Selecting on the
# `importance` column alone yields the same rows, in the same order, as
# `groupby("target").apply(lambda x: x.nlargest(...))`, without handing the
# grouping column to `apply` (the recorded run echoes that line as a warning).
top_row_labels = (
    df.groupby("target")["importance"]
    .nlargest(N_GO_NEIGHBORS + 1)
    .index.get_level_values(-1)
)
df = df.loc[top_row_labels].reset_index(drop=True)

# NOTE: unpickling can execute arbitrary code - only load a trusted file here.
with open(gene_path, "rb") as f:
    gene_list = pickle.load(f)

# Keep only the rows whose source gene is part of `gene_list`.
df = df[df["source"].isin(gene_list)]
def get_map(pert):
    """
    Build the importance vector of the perturbation `pert`.

    The vector has one entry per gene in the global `gene_list`. Entries are
    zero except for the genes listed as `source` in the rows of the global
    `df` whose `target` equals `pert`; those receive the row's `importance`.
    """
    rows_for_pert = df[df.target == pert]
    weights = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    weights.loc[rows_for_pert.source.values, :] = rows_for_pert.importance.values[:, np.newaxis]
    return weights.values.flatten()

# Importance vector (see `get_map`) for every category of the `perturbation` column.
pert2neighbor = {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}
adata.uns["pert2neighbor"] = pert2neighbor

# Stack the vectors row-wise and keep the gene columns whose importance,
# summed over all perturbations, exceeds increasingly strict thresholds.
pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
neighbor_importance_total = pert2neighbor.sum(0)
keep_idx = neighbor_importance_total > 0
keep_idx1 = neighbor_importance_total > 1
keep_idx2 = neighbor_importance_total > 2
keep_idx3 = neighbor_importance_total > 3

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
# Mean expression over all cells of the unperturbed ("ctrl") condition.
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten()

# Mean expression per condition. `adata.X` is densified only when it really is
# sparse, so a dense matrix no longer fails on the missing `.todense()` method.
expression_dense = adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
df_perts_expression = pd.DataFrame(expression_dense, index=adata.obs_names, columns=adata.var_names)
df_perts_expression["condition"] = adata.obs["condition"]
# `observed=False` spells out the grouping behaviour this cell has relied on so
# far (one row per category of `condition`); the recorded run echoes this line
# as a warning when the argument is left implicit.
df_perts_expression = df_perts_expression.groupby(["condition"], observed=False).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []  # condition labels containing "ctrl" and one gene
single_pert_val = []         # the gene named in each of those labels
double_perts = []            # condition labels naming two genes

soeren_perts = []            # every condition label, in category order

for pert in adata.obs["condition"].cat.categories:

    soeren_perts.append(pert)

    parts = pert.split("+")
    if len(parts) == 1:
        # labels without "+" (the bare "ctrl" condition) are added after the loop
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = parts
        single_pert_val.append(p1 if p2 == "ctrl" else p2)
    else:
        double_perts.append(pert)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

perts_expression_by_condition = df_perts_expression.set_index("condition")

# One row of mean expression per single perturbation (plus "ctrl"), indexed by gene.
df_singleperts_expression = pd.DataFrame(
    perts_expression_by_condition.loc[single_perts_condition].values,
    index=single_pert_val,
)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])
df_singleperts_emb1 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx1] for p1 in df_singleperts_expression.index])
df_singleperts_emb2 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx2] for p1 in df_singleperts_expression.index])
df_singleperts_emb3 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx3] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

df_doubleperts_expression = perts_expression_by_condition.loc[double_perts].values
df_doubleperts_condition = pd.Index(double_perts)

# Pseudo-bulk AnnData: one observation per single perturbation (plus "ctrl").
adata_single = anndata.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb
adata_single.obsm["perturbation_neighbors1"] = df_singleperts_emb1
adata_single.obsm["perturbation_neighbors2"] = df_singleperts_emb2
adata_single.obsm["perturbation_neighbors3"] = df_singleperts_emb3

# Copy the per-cell split labels of `adata` onto the pseudo-bulk rows.
for split_seed in range(5):
    split_col = f"split{split_seed}"
    adata_single.obs[split_col] = None
    for cat in ["train", "test", "ood"]:
        # Conditions that actually carry this label. Reading them from `obs`
        # directly avoids creating an AnnData view per label and does not
        # depend on unused categories being dropped when subsetting.
        conditions_in_cat = list(adata.obs.loc[adata.obs[split_col] == cat, "condition"].unique())
        cat_idx = adata_single.obs["condition"].isin(conditions_in_cat)
        adata_single.obs.loc[cat_idx, split_col] = cat

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
  df_perts_expression = df_perts_expression.groupby(["condition"]).mean()


In [10]:
# Mean expression per single perturbation; rows are indexed by the perturbed gene (plus "ctrl").
df_singleperts_expression

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,5035,5036,5037,5038,5039,5040,5041,5042,5043,5044
AHR,0.0,0.002820,0.013193,0.000000,0.200627,0.315682,0.0,0.000000,0.0,0.003837,...,3.681569,4.803782,3.206625,0.904401,3.898038,1.696639,1.788903,3.210524,0.000000,0.0
ARID1A,0.0,0.000000,0.007757,0.000000,0.115991,0.223723,0.0,0.000000,0.0,0.000000,...,3.839409,4.973530,3.437334,0.911261,4.240818,1.855815,1.409777,3.526829,0.000000,0.0
ARRDC3,0.0,0.002222,0.006851,0.000000,0.209191,0.324964,0.0,0.000000,0.0,0.000000,...,4.088739,5.197714,3.632900,1.197712,4.499700,2.153821,1.671886,3.774733,0.000000,0.0
ATL1,0.0,0.002599,0.007228,0.000000,0.171573,0.240020,0.0,0.000000,0.0,0.000000,...,3.962917,5.155320,3.464915,0.919592,4.401278,2.019156,1.205559,3.596434,0.002044,0.0
BAK1,0.0,0.001722,0.010174,0.000000,0.220381,0.297742,0.0,0.000876,0.0,0.001351,...,4.079533,5.116052,3.559192,1.230604,4.496326,2.134926,1.588939,3.720367,0.002062,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZBTB10,0.0,0.026599,0.045591,0.000000,0.143753,0.301114,0.0,0.000000,0.0,0.000000,...,4.069921,5.216878,3.545341,0.876884,4.454348,2.109554,1.094485,3.710337,0.005313,0.0
ZBTB25,0.0,0.001546,0.003796,0.000000,0.165104,0.310133,0.0,0.000000,0.0,0.000000,...,4.055984,5.174850,3.546826,1.108888,4.501656,2.157745,1.484247,3.775636,0.005404,0.0
ZC3HAV1,0.0,0.001125,0.004168,0.000000,0.189943,0.286803,0.0,0.000000,0.0,0.000000,...,4.168334,5.188219,3.576361,1.242347,4.580150,2.190308,1.560414,3.759454,0.000000,0.0
ZNF318,0.0,0.004154,0.004655,0.001078,0.249694,0.405123,0.0,0.002013,0.0,0.001179,...,4.107763,5.203619,3.574353,1.192111,4.552012,2.183263,1.386212,3.788765,0.001898,0.0


In [11]:
# Seed NumPy's global RNG from the configuration rather than repeating the
# literal, so it stays in sync with the seed passed to the model.
np.random.seed(varying_arg["seed"])

ordered_attributes_key = varying_arg["ordered_attributes_key"]

# Register `adata_single` with biolord. The `.obsm` entry named by
# `ordered_attributes_key` is the only attribute handed to the model; no
# categorical or retrieval attributes are used.
biolord.Biolord.setup_anndata(
    adata_single,
    ordered_attributes_keys=[ordered_attributes_key],
    categorical_attributes_keys=None,
    retrieval_attribute_key=None,
)

[34mINFO    [0m Generating sequential column names                                                                        


In [12]:
def get_predictions(
    model: biolord.Biolord,
    ctrl: np.ndarray,
    obs_train: pd.DataFrame, 
    obs_test: pd.DataFrame, 
    adata_single: anndata.AnnData,  
    ordered_attributes_key: str,
    df_singleperts_expression: pd.DataFrame, 
    df_single_pert_val: pd.core.indexes.base.Index,
    num_samples: int,
) -> anndata.AnnData:
    """
    Sample predicted expression profiles for every perturbation in `obs_test`.

    Single perturbations (condition labels containing "ctrl") are predicted by
    giving the control profile the attribute vector of the perturbation and
    sampling around the model output. Double perturbations are composed
    additively: the two per-gene profiles are summed and `ctrl` is subtracted.
    For a gene that occurs in a training condition, the stored mean profile
    from `df_singleperts_expression` is used in place of a model prediction.

    Parameters
    ----------
    model
        Trained biolord model.
    ctrl
        Mean control expression, subtracted when two profiles are added.
    obs_train, obs_test
        Cell annotations of the training partition and of the partition to
        predict; both need a categorical `condition` column.
    adata_single
        Pseudo-bulk AnnData with `condition` and `perts_name` in `.obs`.
    ordered_attributes_key
        Key of the ordered attribute in the datasets returned by the model.
    df_singleperts_expression, df_single_pert_val
        Mean expression per single perturbation and the matching gene names
        (same row order).
    num_samples
        Number of profiles generated per perturbation.

    Returns
    -------
    AnnData holding `num_samples` rows per predicted perturbation. `.obs`
    repeats the (unique) annotation row of the perturbation and gains an
    `assumed_subgroup` column; observation names are unique.
    """
    # Work with the conditions that are actually present, even if the frames
    # still list categories of cells that were filtered out.
    train_perturbations = obs_train.condition.cat.remove_unused_categories().cat.categories
    test_perturbations = obs_test.condition.cat.remove_unused_categories().cat.categories

    # Every name that occurs in any training condition, single or double.
    # NOTE(review): a gene seen only inside a double perturbation during
    # training also counts as "seen" here, and its stored single-perturbation
    # mean is then used below - confirm this is the intended definition.
    train_perturbations_single = np.unique(
        obs_train.condition.str.split("+", expand=True).loc[:, [0, 1]].values.flatten()
    )

    # None of the following depends on the perturbation, so it is built once
    # rather than once per loop iteration.
    adata_control = adata_single[adata_single.obs["condition"] == "ctrl"].copy()
    n_obs = adata_control.shape[0]  # 1 in this notebook
    # dict with keys 'X', 'ind_x' and 'perturbation_neighbors1'; X is [1, 5045] in this notebook
    dataset_control: dict = model.get_dataset(adata_control.copy())
    # same keys; X is [149, 5045] in this notebook
    dataset_reference: dict = model.get_dataset(adata_single)

    def _sample_for_reference_row(idx_ref):
        """Predict from the control cells with the attribute of row `idx_ref`, then sample."""
        dataset_pred = dataset_control.copy()
        dataset_pred[ordered_attributes_key] = repeat_n(
            dataset_reference[ordered_attributes_key][idx_ref, :], n_obs
        )
        # NOTE(review): the second return value is used as the standard
        # deviation of the normal distribution - confirm it is not a variance.
        pred_mean, pred_std = model.module.get_expression(dataset_pred)
        return torch.normal(
            pred_mean.repeat(num_samples, 1),  # pred_mean is a torch tensor
            pred_std.repeat(num_samples, 1),  # pred_std is a torch tensor
        ).detach().cpu().numpy()

    obs_list, X_pred_list = [], []

    for perturbation in tqdm(test_perturbations):
        if perturbation == "ctrl":
            # skip unperturbed condition
            continue
        assert perturbation not in train_perturbations, f"{perturbation} is also a training condition"

        # All cells of one perturbation must share a single annotation row,
        # which is then repeated once per generated sample.
        obs_this_perturbation = obs_test.loc[obs_test.condition == perturbation].copy()
        unique_obs_this_perturbation = obs_this_perturbation.drop_duplicates()
        assert unique_obs_this_perturbation.shape[0] == 1, f"annotation of {perturbation} is not unique"
        obs_this_perturbation = pd.concat([unique_obs_this_perturbation] * num_samples, ignore_index=True)

        if "ctrl" in perturbation:
            # single perturbation: look up its attribute row by condition label
            idx_ref = bool2idx(adata_single.obs["condition"] == perturbation)[0]
            samples = _sample_for_reference_row(idx_ref)
            assumed_subgroup = "single"
        else:
            # double perturbation: one profile per gene, combined additively
            part_status, part_preds = [], []
            for p in perturbation.split("+"):
                if p in train_perturbations_single:
                    # the gene occurs in a training condition: use the stored
                    # mean profile of its single perturbation
                    seen_profile = df_singleperts_expression.values[df_single_pert_val.isin([p]), :]
                    seen_profile = seen_profile[0, :].reshape(1, -1)
                    part_preds.append(seen_profile.repeat(num_samples, axis=0))  # numpy array
                    part_status.append("seen")
                else:
                    # the gene occurs in no training condition: predict it
                    idx_ref = bool2idx(adata_single.obs["perts_name"].isin([p]))[0]
                    part_preds.append(_sample_for_reference_row(idx_ref))
                    part_status.append("unseen")

            # "double_seen_0" / "double_seen_1" / "double_seen_2", by number of seen genes
            assumed_subgroup = f"double_seen_{part_status.count('seen')}"
            samples = part_preds[0] + part_preds[1] - ctrl

        obs_this_perturbation.loc[:, "assumed_subgroup"] = assumed_subgroup
        obs_list.append(obs_this_perturbation)
        X_pred_list.append(samples)

    # Renumber across perturbations so that observation names are unique
    # (the per-perturbation frames each start at 0 again).
    obs_pred = pd.concat(obs_list, ignore_index=True)
    obs_pred.index = obs_pred.index.astype(str)

    adata_pred = anndata.AnnData(
        X=np.vstack(X_pred_list),
        obs=obs_pred,
    )
    return adata_pred

In [13]:
# Filled by the training loop: split index -> {"test": AnnData, "ood": AnnData} of predictions.
results = {}

In [14]:
# Train one model per split, predict the "test" and "ood" partitions and
# annotate the ood predictions with their subgroup.
for split_idx in range(5):
    split_col = f"split{split_idx}"

    model = biolord.Biolord(
        adata=adata_single,
        n_latent=varying_arg["n_latent"],
        model_name="norman",
        module_params=module_params,
        train_classifiers=False,
        split_key=split_col,
    )

    model.train(
        max_epochs=int(varying_arg["max_epochs"]),
        batch_size=32,
        plan_kwargs=trainer_params,
        early_stopping=True,
        early_stopping_patience=int(varying_arg["early_stopping_patience"]),
        check_val_every_n_epoch=5,
        num_workers=1,
        enable_checkpointing=False,
    )

    # `.obs` is taken from an AnnData subset (not from `adata.obs` directly),
    # exactly as before, so the frames passed on are built the same way.
    obs_train = adata[adata.obs.loc[:, split_col] == "train"].obs
    predictions = {}
    for partition in ("test", "ood"):
        predictions[partition] = get_predictions(
            model=model,
            ctrl=ctrl,
            obs_train=obs_train,
            obs_test=adata[adata.obs.loc[:, split_col] == partition].obs,
            adata_single=adata_single,
            ordered_attributes_key=ordered_attributes_key,
            df_singleperts_expression=df_singleperts_expression,
            df_single_pert_val=df_single_pert_val,
            num_samples=ctrl_samples_per_split[split_idx],
        )
    adata_test_pred = predictions["test"]
    adata_ood_pred = predictions["ood"]

    ### add subgroup annotations to adata_ood_pred ###

    # Genes whose single perturbation belongs to the training partition.
    # `regex=False` makes the literal replacement explicit: "+ctrl" is not a
    # valid regular expression, so this must not depend on the pandas default.
    train_mask = adata_single.obs.loc[:, split_col] == "train"
    train_genes = (
        adata_single[train_mask].obs.condition
        .str.replace("+ctrl", "", regex=False)
        .str.replace("ctrl+", "", regex=False)
        .values
    )

    is_single = adata_ood_pred.obs.condition.str.contains("ctrl")
    is_double = ~is_single
    gene_1_seen = adata_ood_pred.obs.gene_1.isin(train_genes)
    gene_2_seen = adata_ood_pred.obs.gene_2.isin(train_genes)

    mask_single_perturbation = is_single
    mask_double_perturbation_seen_0 = is_double & ~gene_1_seen & ~gene_2_seen
    # exactly one of the two genes was seen
    mask_double_perturbation_seen_1 = is_double & (gene_1_seen ^ gene_2_seen)
    mask_double_perturbation_seen_2 = is_double & gene_1_seen & gene_2_seen

    adata_ood_pred.obs.loc[mask_single_perturbation, "subgroup"] = "single"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
    adata_ood_pred.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"

    display(adata_ood_pred.obs.subgroup.value_counts())

    results[split_idx] = {
        "test": adata_test_pred,
        "ood": adata_ood_pred,
    }

Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should s

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/31 [00:00<?, ?it/s]

ARRDC3+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ATL1+ctrl', 'BAK1+ctrl', 'BPGM+ctrl',
       'BPGM+SAMD1', 'CBFA2T3+ctrl', 'CEBPA+ctrl', 'CEBPE+CEBPA', 'CEBPE+ctrl',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+UBASH3A', 'ZBTB25+ctrl', 'ZC3HAV1+CEBPA', 'ZC3HAV1+CEBPE',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=138)
samples.shape (500, 5045)
BPGM+ZBTB1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ATL1+ctrl', 'BAK1+ctrl', 'BPGM+ctrl',
       'BPGM+SAMD1', 'CBFA2T3+ctrl', 'CEBPA+ctrl', 'CEBPE+CEBPA', 'CEBPE+ctrl',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+UBASH3A', 'ZBTB25+ctrl', 'ZC3HAV1+CEBPA', 'ZC3H

  utils.warn_names_duplicates("obs")


  0%|          | 0/108 [00:00<?, ?it/s]

AHR+KLF1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ATL1+ctrl', 'BAK1+ctrl', 'BPGM+ctrl',
       'BPGM+SAMD1', 'CBFA2T3+ctrl', 'CEBPA+ctrl', 'CEBPE+CEBPA', 'CEBPE+ctrl',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+UBASH3A', 'ZBTB25+ctrl', 'ZC3HAV1+CEBPA', 'ZC3HAV1+CEBPE',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=138)
test_predsp.shape (1, 5045) (500, 5045)
pred_mean.shape torch.Size([5045])
pred_std.shape torch.Size([5045])
_samples.shape (500, 5045)
['seen', 'unseen']
test_preds_add[0].shape (500, 5045)
test_preds_add[1].shape (500, 5045)
ctrl.shape (5045,)
ARID1A+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ATL1+ctrl', 'BAK1+ctrl', 'BPGM+ctrl',
       

  utils.warn_names_duplicates("obs")


subgroup
double_seen_1    22000
single           18000
double_seen_2     7500
double_seen_0     6000
Name: count, dtype: int64

Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/li

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/20 [00:00<?, ?it/s]

C3orf72+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl',
       'ATL1+ctrl', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2', 'BCORL1+ctrl',
       'C19orf26+ctrl',
       ...
       'TP73+ctrl', 'TSC22D1+ctrl', 'UBASH3B+ctrl', 'UBASH3B+OSR2',
       'UBASH3B+PTPN12', 'ZBTB1+ctrl', 'ZBTB25+ctrl', 'ZC3HAV1+CEBPE',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=139)
samples.shape (500, 5045)
CEBPB+OSR2 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl',
       'ATL1+ctrl', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2', 'BCORL1+ctrl',
       'C19orf26+ctrl',
       ...
       'TP73+ctrl', 'TSC22D1+ctrl', 'UBASH3B+ctrl', 'UBASH3B+OSR2',
       'UBASH3B+PTPN12', 'ZBTB1+ctrl', '

  utils.warn_names_duplicates("obs")


  0%|          | 0/118 [00:00<?, ?it/s]

BAK1+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl',
       'ATL1+ctrl', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2', 'BCORL1+ctrl',
       'C19orf26+ctrl',
       ...
       'TP73+ctrl', 'TSC22D1+ctrl', 'UBASH3B+ctrl', 'UBASH3B+OSR2',
       'UBASH3B+PTPN12', 'ZBTB1+ctrl', 'ZBTB25+ctrl', 'ZC3HAV1+CEBPE',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=139)
samples.shape (500, 5045)
BCL2L11+BAK1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl',
       'ATL1+ctrl', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2', 'BCORL1+ctrl',
       'C19orf26+ctrl',
       ...
       'TP73+ctrl', 'TSC22D1+ctrl', 'UBASH3B+ctrl', 'UBASH3B+OSR2',
       'UBASH3B+PTPN12', 'ZBTB1+ctrl', 'Z

  utils.warn_names_duplicates("obs")


subgroup
double_seen_1    25500
single           18500
double_seen_2     7500
double_seen_0     7000
Name: count, dtype: int64

Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/li

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/35 [00:00<?, ?it/s]

BAK1+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['BCL2L11+ctrl', 'BCORL1+ctrl', 'C3orf72+FOXL2', 'C3orf72+ctrl',
       'C19orf26+ctrl', 'CBL+ctrl', 'CBL+PTPN9', 'CBL+UBASH3A', 'CDKN1A+ctrl',
       'CDKN1B+CDKN1A',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'ZBTB1+ctrl', 'ZBTB25+ctrl',
       'ZC3HAV1+CEBPA', 'ZC3HAV1+CEBPE', 'ZC3HAV1+HOXC13', 'ZC3HAV1+ctrl',
       'ZNF318+FOXL2', 'ZNF318+ctrl'],
      dtype='object', length=132)
samples.shape (500, 5045)
BCL2L11+TGFBR2 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['BCL2L11+ctrl', 'BCORL1+ctrl', 'C3orf72+FOXL2', 'C3orf72+ctrl',
       'C19orf26+ctrl', 'CBL+ctrl', 'CBL+PTPN9', 'CBL+UBASH3A', 'CDKN1A+ctrl',
       'CDKN1B+CDKN1A',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'ZBTB1+ctrl', 'ZBTB25+ctrl',
       'ZC3HAV1+C

  utils.warn_names_duplicates("obs")


  0%|          | 0/110 [00:00<?, ?it/s]

AHR+FEV 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['BCL2L11+ctrl', 'BCORL1+ctrl', 'C3orf72+FOXL2', 'C3orf72+ctrl',
       'C19orf26+ctrl', 'CBL+ctrl', 'CBL+PTPN9', 'CBL+UBASH3A', 'CDKN1A+ctrl',
       'CDKN1B+CDKN1A',
       ...
       'TSC22D1+ctrl', 'UBASH3A+ctrl', 'ZBTB1+ctrl', 'ZBTB25+ctrl',
       'ZC3HAV1+CEBPA', 'ZC3HAV1+CEBPE', 'ZC3HAV1+HOXC13', 'ZC3HAV1+ctrl',
       'ZNF318+FOXL2', 'ZNF318+ctrl'],
      dtype='object', length=132)
pred_mean.shape torch.Size([5045])
pred_std.shape torch.Size([5045])
_samples.shape (500, 5045)
test_predsp.shape (1, 5045) (500, 5045)
['unseen', 'seen']
test_preds_add[0].shape (500, 5045)
test_preds_add[1].shape (500, 5045)
ctrl.shape (5045,)
AHR+KLF1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['BCL2L11+ctrl', 'BCORL1+ctrl', 'C3orf72+FOXL2', 'C3

  utils.warn_names_duplicates("obs")


subgroup
double_seen_1    26000
single           17500
double_seen_0     6500
double_seen_2     4500
Name: count, dtype: int64

Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/li

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/26 [00:00<?, ?it/s]

BAK1+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['ARRDC3+ctrl', 'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+ctrl', 'BPGM+ZBTB1',
       'C19orf26+ctrl', 'CBFA2T3+ctrl', 'CBL+ctrl', 'CBL+UBASH3B',
       'CDKN1A+ctrl',
       ...
       'TBX3+ctrl', 'TSC22D1+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+PTPN9', 'ZBTB1+ctrl', 'ZBTB10+DLX2', 'ZBTB10+ctrl',
       'ZNF318+FOXL2', 'ZNF318+ctrl'],
      dtype='object', length=139)
samples.shape (500, 5045)
BCL2L11+BAK1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['ARRDC3+ctrl', 'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+ctrl', 'BPGM+ZBTB1',
       'C19orf26+ctrl', 'CBFA2T3+ctrl', 'CBL+ctrl', 'CBL+UBASH3B',
       'CDKN1A+ctrl',
       ...
       'TBX3+ctrl', 'TSC22D1+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+PTPN9', 'ZBTB1+ctrl'

  utils.warn_names_duplicates("obs")


  0%|          | 0/112 [00:00<?, ?it/s]

AHR+FEV 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['ARRDC3+ctrl', 'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+ctrl', 'BPGM+ZBTB1',
       'C19orf26+ctrl', 'CBFA2T3+ctrl', 'CBL+ctrl', 'CBL+UBASH3B',
       'CDKN1A+ctrl',
       ...
       'TBX3+ctrl', 'TSC22D1+ctrl', 'UBASH3B+CNN1', 'UBASH3B+ctrl',
       'UBASH3B+PTPN9', 'ZBTB1+ctrl', 'ZBTB10+DLX2', 'ZBTB10+ctrl',
       'ZNF318+FOXL2', 'ZNF318+ctrl'],
      dtype='object', length=139)
pred_mean.shape torch.Size([5045])
pred_std.shape torch.Size([5045])
_samples.shape (500, 5045)
test_predsp.shape (1, 5045) (500, 5045)
['unseen', 'seen']
test_preds_add[0].shape (500, 5045)
test_preds_add[1].shape (500, 5045)
ctrl.shape (5045,)
AHR+KLF1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['ARRDC3+ctrl', 'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+ctrl', 'BPG

  utils.warn_names_duplicates("obs")


subgroup
double_seen_1    26500
single           18500
double_seen_2     6000
double_seen_0     4500
Name: count, dtype: int64

Seed set to 42
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/haicu/soeren.becker/miniconda3/envs/env_biolor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/haicu/soeren.becker/miniconda3/envs/env_biolord/lib/python3.12/site-packages/li

Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


  0%|          | 0/29 [00:00<?, ?it/s]

CBFA2T3+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl', 'ATL1+ctrl',
       'BAK1+ctrl', 'BCL2L11+BAK1', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2',
       'BCORL1+ctrl',
       ...
       'TBX3+TBX2', 'TGFBR2+ETS2', 'TGFBR2+IGDCC3', 'TGFBR2+ctrl',
       'TGFBR2+PRTG', 'UBASH3B+ctrl', 'UBASH3B+PTPN12', 'ZBTB1+ctrl',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=130)
samples.shape (500, 5045)
CBL+CNN1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl', 'ATL1+ctrl',
       'BAK1+ctrl', 'BCL2L11+BAK1', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2',
       'BCORL1+ctrl',
       ...
       'TBX3+TBX2', 'TGFBR2+ETS2', 'TGFBR2+IGDCC3', 'TGFBR2+ctrl',
       'TGFBR2+PRTG', 'UBASH3B+ctrl', 'UBASH3

  utils.warn_names_duplicates("obs")


  0%|          | 0/118 [00:00<?, ?it/s]

AHR+KLF1 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl', 'ATL1+ctrl',
       'BAK1+ctrl', 'BCL2L11+BAK1', 'BCL2L11+ctrl', 'BCL2L11+TGFBR2',
       'BCORL1+ctrl',
       ...
       'TBX3+TBX2', 'TGFBR2+ETS2', 'TGFBR2+IGDCC3', 'TGFBR2+ctrl',
       'TGFBR2+PRTG', 'UBASH3B+ctrl', 'UBASH3B+PTPN12', 'ZBTB1+ctrl',
       'ZC3HAV1+ctrl', 'ZNF318+ctrl'],
      dtype='object', length=130)
test_predsp.shape (1, 5045) (500, 5045)
pred_mean.shape torch.Size([5045])
pred_std.shape torch.Size([5045])
_samples.shape (500, 5045)
['seen', 'unseen']
test_preds_add[0].shape (500, 5045)
test_preds_add[1].shape (500, 5045)
ctrl.shape (5045,)
BPGM+ctrl 500 (500, 17)
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        
Index(['AHR+FEV', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl', 'ATL1+ctrl'

  utils.warn_names_duplicates("obs")


subgroup
double_seen_1    27500
single           20000
double_seen_2     6500
double_seen_0     4500
Name: count, dtype: int64

In [15]:
# Directory that the prediction and ground-truth .h5ad files below are written to;
# it is created if it does not exist yet.
OUT_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/results_debug_biolord"
os.makedirs(name=OUT_DIR, exist_ok=True)

In [16]:
# save predictions

for split_seed in results.keys():
    # Predictions for this split seed: "test" and "ood" entries are written below.
    split_results = results[split_seed]

    # Boolean-mask indexing returns an AnnData *view*. Writing a view makes anndata
    # modify the view's .obs while converting string columns, which appears to be the
    # source of the repeated `df[key] = c` warnings in this cell's output. Copying
    # first turns the selection into a standalone object.
    # NOTE(review): this file uses the prefix "biolord_" while the prediction files
    # below use "biolord2_" — confirm which one is intended.
    adata[adata.obs.loc[:, f"split{split_seed}"] == "train"].copy().write_h5ad(
        os.path.join(OUT_DIR, f"biolord_adata_train_{split_seed}.h5ad")
    )
    split_results["test"].write_h5ad(os.path.join(OUT_DIR, f"biolord2_adata_pred_test_{split_seed}.h5ad"))

    # save entire ood adata
    split_results["ood"].write_h5ad(os.path.join(OUT_DIR, f"biolord2_adata_pred_ood_{split_seed}.h5ad"))

    # save per subgroup ood adatas
    for subgroup in ["single", "double_seen_0", "double_seen_1", "double_seen_2"]:
        in_subgroup = split_results["ood"].obs.loc[:, "subgroup"] == subgroup
        split_results["ood"][in_subgroup].copy().write_h5ad(
            os.path.join(OUT_DIR, f"biolord2_adata_pred_ood_{subgroup}_split_{split_seed}.h5ad")
        )

  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c


In [17]:
# save ground truth data

for split_seed in results.keys():
    # Write one ground-truth file per split label ("train", "test", "ood") of this seed.
    # The boolean-mask selection is an AnnData view; .copy() materialises it before
    # writing so that anndata does not have to modify a view while converting string
    # columns (this appears to be what triggers the `df[key] = c` warnings).
    for split_name in ("train", "test", "ood"):
        adata[adata.obs.loc[:, f"split{split_seed}"] == split_name].copy().write_h5ad(
            os.path.join(OUT_DIR, f"adata_{split_name}_{split_seed}.h5ad")
        )

    # for subgroup in ["single", "double_seen_0", "double_seen_1", "double_seen_2"]:
        
    #     _adata_ood = results[split_seed]["ood"]
    #     subgroup_conditions = _adata_ood[_adata_ood.obs.loc[:, "subgroup"] == subgroup].obs.condition.unique()
    #     select = (
    #         (adata.obs.loc[:, f"split{split_seed}"] == "ood") & 
    #         adata.obs.condition.isin(subgroup_conditions)
    #     )
    #     adata[select].write_h5ad(os.path.join(OUT_DIR, f"adata_ood_{subgroup}_split_{split_seed}.h5ad"))

  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[k

In [None]:
# Distinct conditions present in the OOD predictions stored under key 0 of `results`.
results[0]["ood"].obs["condition"].unique()

In [None]:
# Folder containing the GEARS prediction files, one .h5ad per OOD subgroup.
base_path_gears = "/lustre/groups/ml01/workspace/leander.dony/projects/cellflow/241106_gears/evaluate_gears/norman/output/adatas_gearsoptimal"
# File names, listed in the same order as the subgroup labels they are paired with below.
paths_gears = [
    "norman_gearsoptimal_seed1_test_unseensingle_pred.h5ad",
    "norman_gearsoptimal_seed1_test_comboseen0_pred.h5ad",
    "norman_gearsoptimal_seed1_test_comboseen1_pred.h5ad",
    "norman_gearsoptimal_seed1_test_comboseen2_pred.h5ad",
]
# Pair every file with its subgroup label by position.
subgroup_by_path = dict(
    zip(paths_gears, ["single", "double_seen_0", "double_seen_1", "double_seen_2"])
)

# Load each file, tag all of its cells with the subgroup label, and collect the objects.
adatas_gears = []
for path, subgroup in subgroup_by_path.items():
    _tmp = sc.read_h5ad(os.path.join(base_path_gears, path))
    _tmp.obs.loc[:, "subgroup"] = subgroup
    adatas_gears.append(_tmp)

In [131]:
# Stack the per-subgroup GEARS AnnData objects along the observation axis
# (axis and join are the library defaults, spelled out for the reader).
adata_gears = sc.concat(adatas_gears, axis=0, join="inner")

In [None]:
# Alphabetically sorted distinct conditions in the GEARS predictions.
np.array(sorted(adata_gears.obs["condition"].unique()))

In [None]:
# Alphabetically sorted distinct conditions in the OOD predictions under key 0 of `results`.
np.array(sorted(results[0]["ood"].obs["condition"].unique()))

In [150]:
# Replace the `condition` column of the OOD predictions with its plain-string version.
# This mutates results[0]["ood"].obs in place.
results[0]["ood"].obs["condition"] = results[0]["ood"].obs["condition"].astype(str)

In [None]:
# Per-cell (condition, subgroup) pairs of the OOD predictions, ordered by condition.
results[0]["ood"].obs[["condition", "subgroup"]].sort_values("condition")

In [None]:
# GEARS observation table ordered by condition.
adata_gears.obs.sort_values("condition")

In [175]:
# Distinct (condition, subgroup) rows of the OOD predictions under key 0; the subgroup
# column is renamed so it can be told apart from the GEARS column of the same name.
df_bio = (
    results[0]["ood"]
    .obs[["condition", "subgroup"]]
    .sort_values(by="condition")
    .rename(columns={"subgroup": "subgroup_bio"})
    .drop_duplicates()
)
# Distinct rows of the full GEARS observation table, ordered by condition.
df_gears = (
    adata_gears.obs
    .sort_values(by="condition")
    .drop_duplicates()
)

In [181]:
# Sorted distinct values appearing in either the gene_1 or the gene_2 column of adata_train.obs.
seen_genes = np.unique(
    [*adata_train.obs["gene_1"].tolist(), *adata_train.obs["gene_2"].tolist()]
)

In [None]:
# Membership of AHR and KLF1 (in that order) in `seen_genes`.
tuple(gene in seen_genes for gene in ("AHR", "KLF1"))

In [None]:
# Join the two per-condition tables on `condition` and show the result without
# pandas truncating rows or columns.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(pd.merge(df_bio, df_gears, on="condition"))

In [185]:
# Hard-coded list of condition labels used as the GEARS training conditions.
# Entries are either the bare label 'ctrl' or two names joined by '+'; single-gene
# perturbations occur in both orders ('X+ctrl' and 'ctrl+X').
# NOTE(review): where this list was copied from (which GEARS split / seed) is not
# visible in this notebook — confirm provenance.
train_conds_gears = ['TSC22D1+ctrl',
 'ctrl',
 'CEBPE+RUNX1T1',
 'MAML2+ctrl',
 'ctrl+CEBPE',
 'SGK1+TBX3',
 'ctrl+FOXA1',
 'FOXA3+FOXA1',
 'ETS2+IGDCC3',
 'GLB1L2+ctrl',
 'MAP2K6+IKZF3',
 'BAK1+ctrl',
 'FEV+ctrl',
 'MAP2K3+SLC38A2',
 'ctrl+ETS2',
 'ctrl+FEV',
 'ctrl+SET',
 'TBX3+ctrl',
 'LHX1+ctrl',
 'RREB1+ctrl',
 'ZNF318+ctrl',
 'ctrl+ZBTB25',
 'MAP4K5+ctrl',
 'UBASH3B+ctrl',
 'SLC6A9+ctrl',
 'MIDN+ctrl',
 'DLX2+ctrl',
 'CBFA2T3+ctrl',
 'HES7+ctrl',
 'SET+CEBPE',
 'IGDCC3+ZBTB25',
 'AHR+ctrl',
 'FOXO4+ctrl',
 'ctrl+CBFA2T3',
 'ctrl+RUNX1T1',
 'POU3F2+ctrl',
 'ctrl+CNN1',
 'IGDCC3+MAPK1',
 'MAP2K3+ctrl',
 'MAP4K3+ctrl',
 'ZBTB25+ctrl',
 'ZC3HAV1+CEBPE',
 'UBASH3B+UBASH3A',
 'MAP2K3+MAP2K6',
 'PTPN1+ctrl',
 'RUNX1T1+ctrl',
 'PTPN12+ctrl',
 'TP73+ctrl',
 'ctrl+MAP7D1',
 'FOSB+ctrl',
 'MAPK1+ctrl',
 'IRF1+ctrl',
 'TMSB4X+BAK1',
 'BPGM+SAMD1',
 'IKZF3+ctrl',
 'HOXB9+ctrl',
 'ctrl+HOXC13',
 'MAPK1+IKZF3',
 'ctrl+UBASH3B',
 'ctrl+HOXB9',
 'ETS2+ctrl',
 'CLDN6+ctrl',
 'FOXA3+ctrl',
 'CEBPE+ctrl',
 'KIF18B+KIF2C',
 'ctrl+SAMD1',
 'COL1A1+ctrl',
 'PTPN12+UBASH3A',
 'FOXF1+ctrl',
 'FEV+MAP7D1',
 'PLK4+ctrl',
 'BPGM+ctrl',
 'LYL1+ctrl',
 'ctrl+MAP2K6',
 'SGK1+ctrl',
 'MAPK1+TGFBR2',
 'ctrl+DLX2',
 'MAP2K6+ctrl',
 'ctrl+TBX3',
 'CNN1+ctrl',
 'ctrl+CEBPA',
 'HNF4A+ctrl',
 'MAP7D1+ctrl',
 'PTPN12+SNAI1',
 'KMT2A+ctrl',
 'CNN1+UBASH3A',
 'IGDCC3+ctrl',
 'ISL2+ctrl',
 'TGFBR2+IGDCC3',
 'TMSB4X+ctrl',
 'KIF2C+ctrl',
 'ctrl+CLDN6',
 'ctrl+KIF2C',
 'IRF1+SET',
 'CSRNP1+ctrl',
 'CEBPE+CEBPA',
 'ctrl+UBASH3A',
 'NCL+ctrl',
 'ctrl+BAK1',
 'ctrl+IKZF3',
 'FOXF1+HOXB9',
 'UBASH3B+CNN1',
 'ZC3HAV1+ctrl',
 'SET+ctrl',
 'FOSB+UBASH3B',
 'SNAI1+UBASH3B',
 'ctrl+STIL',
 'HOXC13+ctrl',
 'ATL1+ctrl',
 'CEBPE+PTPN12',
 'CEBPA+ctrl',
 'NIT1+ctrl',
 'SAMD1+UBASH3B',
 'TGFBR2+ctrl',
 'SAMD1+TGFBR2',
 'FOXA1+ctrl',
 'SAMD1+ctrl',
 'ctrl+MAPK1',
 'UBASH3A+ctrl',
 'AHR+FEV',
 'ETS2+IKZF3',
 'ctrl+ISL2',
 'ctrl+SLC38A2',
 'PTPN12+ZBTB25',
 'ctrl+SNAI1',
 'HOXA13+ctrl',
 'ctrl+FOXF1',
 'ctrl+PTPN12',
 'SAMD1+PTPN12',
 'HK2+ctrl',
 'ctrl+IGDCC3',
 'ctrl+TGFBR2',
 'FOXA3+FOXF1',
 'ZC3HAV1+CEBPA',
 'KIF18B+ctrl',
 'SNAI1+ctrl',
 'FOXA1+FOXF1',
 'PLK4+STIL',
 'STIL+ctrl']

In [186]:
# Hard-coded list of condition labels used as the GEARS out-of-distribution (held-out)
# conditions; each entry is two names joined by '+', with 'ctrl' standing in for the
# missing partner of a single-gene perturbation.
# NOTE(review): where this list was copied from (which GEARS split / seed) is not
# visible in this notebook — confirm provenance.
ood_conds_gears = ['CBL+PTPN9',
 'DUSP9+ctrl',
 'MAP2K6+SPI1',
 'UBASH3B+PTPN12',
 'BCORL1+ctrl',
 'MEIS1+ctrl',
 'CBL+ctrl',
 'KLF1+FOXA1',
 'TBX3+TBX2',
 'SLC4A1+ctrl',
 'DUSP9+MAPK1',
 'COL2A1+ctrl',
 'CEBPE+KLF1',
 'UBASH3B+OSR2',
 'UBASH3B+ZBTB25',
 'DUSP9+ETS2',
 'ZNF318+FOXL2',
 'UBASH3B+PTPN9',
 'S1PR2+ctrl',
 'CELF2+ctrl',
 'JUN+CEBPA',
 'CDKN1A+ctrl',
 'ctrl+MEIS1',
 'MAPK1+PRTG',
 'MAP2K3+IKZF3',
 'KLF1+COL2A1',
 'PTPN12+OSR2',
 'ETS2+CEBPE',
 'POU3F2+FOXL2',
 'DUSP9+PRTG',
 'CKS1B+ctrl',
 'BCL2L11+TGFBR2',
 'AHR+KLF1',
 'CEBPB+CEBPA',
 'PRTG+ctrl',
 'ETS2+CNN1',
 'C3orf72+ctrl',
 'CNN1+MAPK1',
 'FOXL2+MEIS1',
 'FOXL2+ctrl',
 'FOSB+CEBPE',
 'PTPN12+PTPN9',
 'FOSB+CEBPB',
 'ctrl+CDKN1A',
 'BCL2L11+BAK1',
 'FOXA3+HOXB9',
 'ARID1A+ctrl',
 'ctrl+COL2A1',
 'CEBPE+CNN1',
 'ZC3HAV1+HOXC13',
 'CBL+CNN1',
 'ZBTB10+PTPN12',
 'CBL+UBASH3B',
 'CEBPB+PTPN12',
 'BCL2L11+ctrl',
 'OSR2+ctrl',
 'ctrl+SPI1',
 'CEBPB+MAPK1',
 'ETS2+MAPK1',
 'DUSP9+IGDCC3',
 'CEBPB+ctrl',
 'CBL+PTPN12',
 'CEBPB+OSR2',
 'ctrl+PRTG',
 'SGK1+S1PR2',
 'DUSP9+KLF1',
 'CDKN1B+ctrl',
 'FEV+ISL2',
 'JUN+ctrl',
 'POU3F2+CBFA2T3',
 'FOXA1+HOXB9',
 'ZBTB10+ctrl',
 'CEBPE+SPI1',
 'PTPN13+ctrl',
 'CBL+TGFBR2',
 'FOXA1+FOXL2',
 'FOXF1+FOXL2',
 'ETS2+PRTG',
 'PTPN9+ctrl',
 'LYL1+CEBPB',
 'DUSP9+SNAI1',
 'ctrl+CEBPB',
 'TGFBR2+PRTG',
 'PRDM1+ctrl',
 'FOSB+OSR2',
 'FOXL2+HOXB9',
 'ctrl+PTPN9',
 'ctrl+OSR2',
 'ZBTB10+ELMSAN1',
 'JUN+CEBPB',
 'ZBTB10+SNAI1',
 'ctrl+FOXL2',
 'CEBPE+CEBPB',
 'PRDM1+CBFA2T3',
 'FOXA3+FOXL2',
 'CDKN1C+CDKN1B',
 'CDKN1C+CDKN1A',
 'SPI1+ctrl',
 'EGR1+ctrl',
 'ZBTB10+DLX2',
 'CBL+UBASH3A',
 'SNAI1+DLX2',
 'IGDCC3+PRTG',
 'CDKN1C+ctrl',
 'ctrl+CDKN1B',
 'CDKN1B+CDKN1A',
 'C3orf72+FOXL2']

In [None]:
# Number of entries in the hard-coded GEARS OOD condition list.
len(ood_conds_gears)

In [None]:
# How many distinct conditions occur both in df_gears and in the hard-coded OOD list.
np.intersect1d(df_gears["condition"], ood_conds_gears).size

In [None]:
# Display the de-duplicated, condition-sorted GEARS observation table.
df_gears

In [None]:
# Split every GEARS training condition of the form "A+B" into its two parts.
df_train_conds_gears = pd.DataFrame(train_conds_gears, columns=["condition"])
df_train_conds_gears.loc[:, ["gene_1", "gene_2"]] = df_train_conds_gears.condition.str.split("+", expand=True).rename({0: "gene_1", 1: "gene_2"}, axis=1)
display(df_train_conds_gears)

# Collect every name that occurs on either side of a training condition.
# A condition without "+" (the bare 'ctrl' entry) has no second part, so its gene_2 is
# missing; drop missing values here, otherwise the str conversion below would turn
# them into the literal string 'None' and add a bogus entry to train_conds_all.
train_conds = (
    df_train_conds_gears.gene_1.dropna().tolist()
    + df_train_conds_gears.gene_2.dropna().tolist()
)
train_conds_all = np.unique(np.array(train_conds, dtype=str))
train_conds_all

In [None]:
# Does KLF1 appear on either side of any GEARS training condition?
"KLF1" in train_conds_all

In [None]:
# Display the de-duplicated, condition-sorted GEARS observation table.
df_gears

In [128]:
# Two-column array of sorted distinct conditions: column 0 from the OOD predictions
# under key 0 of `results`, column 1 from the GEARS predictions. Both sides must
# contain the same number of distinct conditions, otherwise stacking raises.
# NOTE(review): the execution counts show this cell (128) ran before `adata_gears`
# was built (131) — confirm it still works top-to-bottom on a fresh kernel.
compare_conds = np.column_stack(
    (
        sorted(results[0]["ood"].obs.condition.unique()),
        sorted(adata_gears.obs.condition.unique()),
    )
)

In [None]:
# Row-wise check that the two sorted condition columns agree.
compare_conds[:, 0] == compare_conds[:, 1]