In [None]:
import warnings

import anndata
import anndata as ad
import biolord
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from scipy.stats import ttest_rel
from tqdm import tqdm

In [None]:
ood_split = 4  # integer suffix of the split files read below and of the prediction files written at the end

In [None]:
# All three pre-computed splits live side by side and differ only in their name and split suffix.
sciplex_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
adata_train_path, adata_test_path, adata_ood_path = (
    f"{sciplex_dir}/adata_{part}_{ood_split}.h5ad" for part in ("train", "test", "ood")
)

In [None]:
# Load the three splits in train, test, OOD order.
adata_train, adata_test, adata_ood = (
    sc.read(path) for path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [None]:
# Stack all splits into one object; obs["split"] records the file each cell came from.
adata = anndata.concat(
    [adata_train, adata_test, adata_ood],
    keys=["train", "test", "ood"],
    label="split",
)

In [None]:
frac_valid = 0.1
# Seeded generator so the train/valid split is identical across kernel restarts.
split_seed = 0
split_rng = np.random.default_rng(split_seed)


def create_split(x, rng=None):
    """Assign a single ``obs`` row to a training sub-split.

    Rows whose ``split`` is not ``"train"`` become ``"other"``. Training rows become
    ``"train_valid"`` with probability ``frac_valid`` and ``"train_train"`` otherwise.

    Parameters
    ----------
    x
        One row of ``adata.obs`` (anything indexable by ``"split"``).
    rng
        ``numpy.random.Generator`` used for the draw; defaults to ``split_rng``.

    Returns
    -------
    str
        ``"train_train"``, ``"train_valid"`` or ``"other"``.
    """
    if x["split"] != "train":
        return "other"
    rng = split_rng if rng is None else rng
    is_train = rng.choice(2, p=[frac_valid, 1 - frac_valid])
    if is_train:
        return "train_train"
    return "train_valid"


# Vectorised equivalent of ``adata.obs.apply(create_split, axis=1)``: a single draw
# for the whole table instead of one Python call per cell.
is_train_split = (adata.obs["split"] == "train").to_numpy()
is_train_train = split_rng.random(adata.n_obs) >= frac_valid  # True with prob. 1 - frac_valid
adata.obs["new_split"] = np.where(
    is_train_split,
    np.where(is_train_train, "train_train", "train_valid"),
    "other",
)

In [None]:
dose_values = adata.obs["dose"].astype("float")
# Divide by the largest dose in the data, following the biolord representation.
dose = dose_values / np.max(dose_values)

In [None]:
# Append the normalised dose as one extra column after the ECFP features.
adata.obsm["ecfp_dose"] = np.hstack([adata.obsm["ecfp"], dose.to_numpy()[:, np.newaxis]])

In [None]:
# Register the ECFP+dose vector as an ordered attribute and the cell line as a categorical one.
biolord.Biolord.setup_anndata(
    adata,
    categorical_attributes_keys=["cell_type"],
    ordered_attributes_keys=["ecfp_dose"],
    retrieval_attribute_key=None,
)

In [None]:
# Architecture / loss hyper-parameters passed to the biolord module.
module_params = {
    # decoder
    "decoder_width": 4096,
    "decoder_depth": 4,
    # attribute encoders
    "attribute_dropout_rate": 0.1,
    "attribute_nn_width": 2048,
    "attribute_nn_depth": 2,
    # latent / likelihood settings
    "unknown_attribute_noise_param": 2e+1,
    "unknown_attribute_penalty": 1e-1,
    "gene_likelihood": "normal",
    "n_latent_attribute_ordered": 256,
    "n_latent_attribute_categorical": 3,
    "reconstruction_penalty": 1e+4,
    "use_batch_norm": False,
    "use_layer_norm": False,
}

# Optimiser / scheduler settings passed as ``plan_kwargs`` to training.
trainer_params = {
    "latent_lr": 1e-4,
    "latent_wd": 1e-4,
    "decoder_lr": 1e-4,
    "decoder_wd": 1e-4,
    "attribute_nn_lr": 1e-2,
    "attribute_nn_wd": 4e-8,
    "cosine_scheduler": True,
    "scheduler_final_lr": 1e-5,
    "step_size_lr": 45,
}

In [None]:
# Map the train/valid/test roles onto the labels stored in obs["new_split"].
split_kwargs = dict(
    split_key="new_split",
    train_split="train_train",
    valid_split="train_valid",
    test_split="other",
)
model = biolord.Biolord(
    adata=adata,
    n_latent=256,
    model_name="sciplex3",
    module_params=module_params,
    train_classifiers=False,
    **split_kwargs,
)

In [None]:
# Up to 200 epochs, validating every 10 epochs, with early stopping enabled.
# NOTE(review): if scvi-tools forwards `early_stopping_patience` to Lightning's
# EarlyStopping, patience counts validation checks, not epochs: 20 checks x 10 epochs
# = 200 epochs = max_epochs, so early stopping could never trigger — confirm.
model.train(
    max_epochs=200,
    check_val_every_n_epoch=10,
    early_stopping=True,
    early_stopping_patience=20,
    batch_size=512,
    num_workers=10,
    enable_checkpointing=False,
    plan_kwargs=trainer_params,
)

In [None]:
def bool2idx(x):
    """
    Returns the integer positions selected by `x`.

    `x` is normally a boolean array, in which case the indices of its True-valued
    entries are returned. `pandas.Index.get_loc` can also return a `slice` (label
    on a monotonic index) or a single integer (unique label). Those are expanded to
    the positions they denote, so `get_loc` results can be passed straight in
    (plain `np.where` would silently return `[0]` for them).
    """
    if isinstance(x, slice):
        start = 0 if x.start is None else x.start
        step = 1 if x.step is None else x.step
        return np.arange(start, x.stop, step)
    # Python bools are ints; keep treating them (and np.bool_) as masks.
    if isinstance(x, (int, np.integer)) and not isinstance(x, (bool, np.bool_)):
        return np.array([x])
    return np.where(x)[0]

def repeat_n(x, n):
    """
    Returns an n-times repeated version of the Tensor x, flattened to one row,
    repetition dimension is axis 0 (result shape: (n, x.numel())).
    """
    # copy tensor to device BEFORE replicating it n times
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # reshape (not view) so non-contiguous inputs are accepted too
    return x.to(device).reshape(1, -1).repeat(n, 1)


In [None]:
split_col = adata.obs["split"]
control_col = adata.obs["control"]

# Control cells of the test split: the unperturbed population paired with predictions.
idx_test_control = np.flatnonzero((split_col == "test") & (control_col == 1))
adata_test_control = adata[idx_test_control].copy()

# Perturbed (control == 0) cells of the OOD split; this rebinds `adata_ood`.
idx_ood = np.flatnonzero((split_col == "ood") & (control_col == 0))
adata_ood = adata[idx_ood].copy()
dataset_ood = model.get_dataset(adata_ood)

In [None]:
# NOTE(review): repeats the `model.get_dataset(adata_ood)` call made at the end of the previous cell; redundant.
dataset_ood = model.get_dataset(adata_ood)

In [None]:
import pandas as pd
from tqdm import tqdm

def compute_prediction(
    model,
    adata,
    dataset,
    adata_control,
    n_obs=500
):
    """Sample biolord predictions for every condition in ``adata.obs["condition"]``.

    For each condition, the control cells of the matching cell line in
    ``adata_control`` keep their own expression layer and ``ind_x``. Every other
    attribute is copied from the first ``adata`` cell with that condition. The batch
    goes through ``model.module.get_expression`` and one sample per control cell is
    drawn from the predicted normal distribution.

    Parameters
    ----------
    model
        Trained biolord model (provides ``get_dataset``, ``device`` and ``module``).
    adata
        Perturbed cells whose conditions are predicted.
    dataset
        ``model.get_dataset(adata)``; rows must be aligned with ``adata.obs``.
    adata_control
        Control cells, subset per cell line via ``obs["cell_type"]``.
    n_obs
        Maximum number of control cells used per condition: the first ``n_obs`` of
        the matching cell line. ``None`` uses all of them.

    Returns
    -------
    dict
        Condition name -> numpy array of samples, one row per control cell used.
    """
    conditions = np.asarray(adata.obs["condition"].values)

    # NOTE(review): assumes the encoded cell_type codes are 0=A549, 1=K562, 2=MCF7 — confirm.
    cl_dict = {
        torch.Tensor([0.]): "A549",
        torch.Tensor([1.]): "K562",
        torch.Tensor([2.]): "MCF7",
    }

    cell_lines = ["A549", "K562", "MCF7"]

    layer = "X" if "X" in dataset else "layers"
    # The control dataset depends only on the cell line, so build it once per line.
    control_datasets = {}
    predictions_dict = {}
    with torch.no_grad():
        for cell_drug_dose_comb in tqdm(np.unique(conditions)):
            # First cell carrying this condition; its attributes drive the prediction.
            idx = np.flatnonzero(conditions == cell_drug_dose_comb)[0]

            # Skip conditions whose encoded cell line is not in `cell_lines`.
            if any(
                (tensor == dataset["cell_type"][idx]).all() and cl not in cell_lines
                for tensor, cl in cl_dict.items()
            ):
                continue

            # NOTE(review): assumes condition names start with "<cell line>_" — confirm.
            cur_cell_line = cell_drug_dose_comb.split("_")[0]
            if cur_cell_line not in control_datasets:
                control_datasets[cur_cell_line] = model.get_dataset(
                    adata_control[adata_control.obs.cell_type == cur_cell_line]
                )
            dataset_control = control_datasets[cur_cell_line]

            # The repeated attributes must have exactly as many rows as the control batch.
            n_ctrl = dataset_control[layer].shape[0]
            n_use = n_ctrl if n_obs is None else min(n_obs, n_ctrl)
            if n_use == 0:
                warnings.warn(
                    f"No control cells for cell line {cur_cell_line!r}; "
                    f"skipping {cell_drug_dose_comb!r}."
                )
                continue

            dataset_comb = {
                layer: dataset_control[layer][:n_use].to(model.device),
                "ind_x": dataset_control["ind_x"][:n_use].to(model.device),
            }
            for key in dataset_control:
                if key not in (layer, "ind_x"):
                    dataset_comb[key] = repeat_n(dataset[key][idx, :], n_use)

            pred_mean, pred_std = model.module.get_expression(dataset_comb)
            samples = torch.normal(pred_mean, pred_std)

            predictions_dict[cell_drug_dose_comb] = samples.detach().cpu().numpy()
    return predictions_dict

In [None]:
import torch
# Predict every OOD condition, using test-split control cells as the starting population.
biolord_prediction = compute_prediction(
    model=model,
    adata_control=adata_test_control,
    adata=adata_ood,
    dataset=dataset_ood,
)

In [None]:
import anndata as ad
def predictions_to_adata(predictions):
    """Stack per-condition prediction arrays into a single AnnData.

    Parameters
    ----------
    predictions
        Mapping of condition name -> 2-D array with one row per predicted cell
        (as returned by ``compute_prediction``).

    Returns
    -------
    anndata.AnnData
        ``X`` is the row-wise concatenation of the arrays in insertion order, and
        ``obs["condition"]`` labels every row with its condition.

    Raises
    ------
    ValueError
        If ``predictions`` is empty.
    """
    if not predictions:
        raise ValueError("No predictions to assemble: `predictions` is empty.")

    condition_labels = []
    for condition, array in predictions.items():
        condition_labels.extend([condition] * array.shape[0])

    obs = pd.DataFrame(
        {"condition": condition_labels},
        # String obs names up front avoid AnnData's implicit int -> str index conversion.
        index=pd.RangeIndex(len(condition_labels)).astype(str),
    )
    return ad.AnnData(X=np.vstack(list(predictions.values())), obs=obs)


adata_ood_result = predictions_to_adata(biolord_prediction)

In [None]:
ood_output_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_ood_{ood_split}.h5ad"
adata_ood_result.write_h5ad(ood_output_path)

In [None]:
# --- In-distribution test split: same prediction pipeline as for the OOD split ---

In [None]:
split_col = adata.obs["split"]
control_col = adata.obs["control"]

# Control cells of the test split (same selection as for the OOD predictions).
idx_test_control = np.flatnonzero((split_col == "test") & (control_col == 1))
adata_test_control = adata[idx_test_control].copy()

# Perturbed (control == 0) cells of the test split; this rebinds `adata_test`.
idx_test = np.flatnonzero((split_col == "test") & (control_col == 0))
adata_test = adata[idx_test].copy()
dataset_test = model.get_dataset(adata_test)

In [None]:
# NOTE(review): repeats the `model.get_dataset(adata_test)` call made at the end of the previous cell; redundant.
dataset_test = model.get_dataset(adata_test)

In [None]:
# Predict every in-distribution test condition from the test-split control cells.
biolord_prediction = compute_prediction(
    model=model,
    adata_control=adata_test_control,
    adata=adata_test,
    dataset=dataset_test,
)

In [None]:
# Flatten the per-condition test predictions into one AnnData (one obs row per sampled cell).
if not biolord_prediction:
    raise ValueError("No test predictions to assemble: `biolord_prediction` is empty.")

all_data = list(biolord_prediction.values())
conditions = [
    condition
    for condition, array in biolord_prediction.items()
    for _ in range(array.shape[0])
]

# Stack all data vertically to create a single array
all_data_array = np.vstack(all_data)

# String obs names up front avoid AnnData's implicit int -> str index conversion.
obs_data = pd.DataFrame(
    {"condition": conditions},
    index=pd.RangeIndex(len(conditions)).astype(str),
)

adata_test_result = ad.AnnData(X=all_data_array, obs=obs_data)

In [None]:
test_output_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_test_{ood_split}.h5ad"
adata_test_result.write_h5ad(test_output_path)