In [93]:
# Locations of the pre-split sciplex AnnData files (biolord split 30).
# All three files live in the same directory and differ only in the subset name.
_SCIPLEX_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"

adata_train_path, adata_test_path, adata_ood_path = (
    f"{_SCIPLEX_DIR}/adata_{subset}_biolord_split_30.h5ad"
    for subset in ("train", "test", "ood")
)

In [94]:
import anndata as ad
import numpy as np
import biolord

from parameters_sciplex3 import module_params, trainer_params
from utils_perturbation_sciplex3 import compute_prediction

In [101]:
# Read the training-split AnnData from `adata_train_path`.
ad_train = ad.read_h5ad(adata_train_path)

In [102]:
# Read the test and OOD splits (in that order) from their configured paths.
ad_test, ad_ood = map(ad.read_h5ad, (adata_test_path, adata_ood_path))

In [None]:
# Join the train, test and OOD splits into a single AnnData (cells stacked along obs).
# `AnnData.concatenate` is deprecated and emits a FutureWarning; `anndata.concat` is
# the supported replacement. The arguments reproduce the legacy method's defaults:
#   - join="inner":     keep only the variables shared by all three objects
#   - label="batch":    add an obs column "batch" identifying the source object ("0"/"1"/"2")
#   - index_unique="-": append "-<batch>" to obs names so they stay unique
#   - merge="same":     keep var annotations that are identical across the objects
# NOTE(review): unlike the legacy method, obs columns are inner-joined as well, so this
# assumes the three splits carry the same obs columns (e.g. "split", "control") — confirm.
adata = ad.concat(
    [ad_train, ad_test, ad_ood],
    join="inner",
    label="batch",
    index_unique="-",
    merge="same",
)

  adata = ad_train.concatenate(ad_test, ad_ood)


In [None]:
# Register `adata` with biolord before building the model:
# "ecfp" is passed as the (only) ordered attribute and "cell_type" as the (only)
# categorical attribute; no retrieval attribute key is given.
# NOTE(review): presumably "ecfp" holds per-cell drug fingerprint features — confirm
# where it is stored on the AnnData.
biolord.Biolord.setup_anndata(
    adata,
    ordered_attributes_keys=["ecfp"],
    categorical_attributes_keys=["cell_type"],
    retrieval_attribute_key=None,
)

In [None]:
# Instantiate the Biolord model on the registered `adata`.
# `module_params` is imported from parameters_sciplex3; classifier training is turned off.
# NOTE(review): `split_key="split"` presumably tells the model which obs column assigns
# cells to the training/validation partitions — confirm against the biolord API.
model = biolord.Biolord(
    adata=adata,
    n_latent=256,
    model_name="sciplex3",
    module_params=module_params,
    train_classifiers=False,
    split_key="split",
)

In [None]:
# Train for at most 200 epochs with batch size 512; `trainer_params` (imported from
# parameters_sciplex3) is forwarded as the training-plan kwargs. Checkpointing is disabled.
# NOTE(review): if early-stopping patience is counted in validation checks (as in
# Lightning's EarlyStopping) rather than in epochs, then validating every 10 epochs with
# patience 20 needs 200+ epochs without improvement, so early stopping could never
# trigger within max_epochs=200 — confirm this is intended.
model.train(
    max_epochs=200,
    batch_size=512,
    plan_kwargs=trainer_params,
    early_stopping=True,
    early_stopping_patience=20,
    check_val_every_n_epoch=10,
    num_workers=10,
    enable_checkpointing=False
)

In [None]:
# Persist the trained model to a fixed directory on the shared filesystem.
# NOTE(review): presumably this fails on a re-run if the directory already exists —
# confirm against the save API before running the notebook end to end again.
model.save("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_model_biolordsplit")

In [98]:
import copy

In [100]:
# Keep an independent deep copy of the current `model` under the name `model_30`.
# NOTE(review): this cell's execution count (100) precedes the data-loading cells
# (101, 102), so it was last run against earlier kernel state; `model_30` is not
# used anywhere else in the visible cells.
model_30 = copy.deepcopy(model)

In [61]:
# Boolean masks over the cells of `adata`:
#   - cells of the "test" split that are flagged control == 1
#   - all cells of the "ood" split
is_test_control = (adata.obs["split"] == "test") & (adata.obs["control"] == 1)
is_ood = adata.obs["split"] == "ood"

# Positional indices of the selected cells.
idx_test_control = np.flatnonzero(is_test_control)
idx_ood = np.flatnonzero(is_ood)

# Materialise independent copies of the two subsets rather than views of `adata`.
adata_test_control = adata[idx_test_control].copy()
adata_ood = adata[idx_ood].copy()

In [30]:
# narrow to a single drug dose cell type for testing
# adata_ood = adata_ood[(adata_ood.obs.cov_drug_dose_name == "A549_Quisinostat_0.001")]

In [62]:
# Convert both AnnData subsets into the model's dataset representation
# (test-split controls first, then the OOD cells).
dataset_control, dataset_ood = map(model.get_dataset, (adata_test_control, adata_ood))

[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


In [32]:
# Run the `compute_prediction` helper (imported from utils_perturbation_sciplex3) on the
# OOD cells, passing the test-split control dataset as `dataset_control`.
# The result is consumed below as a mapping: it is iterated with `.items()` and each
# value is treated as a tensor (`.cpu().numpy()`).
biolord_prediction = compute_prediction(
    model=model,
    adata=adata_ood,
    dataset=dataset_ood,
    dataset_control=dataset_control)

199it [00:20,  9.93it/s]


In [70]:
# Step 1: collect one predicted vector per condition.
# Each prediction tensor is moved to the CPU and converted to numpy; only its first
# row is kept, because the rows within a condition are treated as near-identical.
conditions = list(biolord_prediction)
all_data = np.vstack(
    [prediction.cpu().numpy()[0] for prediction in biolord_prediction.values()]
)

# Step 2: wrap the matrix (one row per condition) in an AnnData and label each row
# with the condition it came from.
adata_output = ad.AnnData(all_data)
adata_output.obs['condition'] = conditions

In [78]:
# Write the per-condition predictions (one row per condition) to an .h5ad file.
adata_output.write_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/biolord_output_adata_ood_30.h5ad")

In [None]:
# pred_orig = compute_prediction_orig(model,
#     adata_ood,
#     dataset_ood,
#     dataset_control=dataset_control)

In [84]:
# # code for computing pred
# import numpy as np
# import pandas as pd
# import torch
# from torchmetrics import R2Score
# from tqdm import tqdm
# 
# 
# def bool2idx(x):
#     """
#     Returns the indices of the True-valued entries in a boolean array `x`
#     """
#     return np.where(x)[0]
# 
# def repeat_n(x, n):
#     """
#     Returns an n-times repeated version of the Tensor x,
#     repetition dimension is axis 0
#     """
#     # copy tensor to device BEFORE replicating it n times
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     return x.to(device).view(1, -1).repeat(n, 1)
# 
# def compute_r2(y_true, y_pred):
#     """
#     Computes the r2 score for `y_true` and `y_pred` after clamping
#     `y_pred` to [-3e12, 3e12]. NOTE: no `-1` fallback for nan values is implemented here.
#     """
#     y_pred = torch.clamp(y_pred, -3e12, 3e12)
#     metric = R2Score().to(y_true.device)
#     metric.update(y_pred, y_true)  # same as sklearn.r2_score(y_true, y_pred)
#     return metric.compute().item()
# 
# def compute_prediction_orig(
#     model,
#     adata,
#     dataset,
#     cell_lines=None,
#     dataset_control=None,
#     use_DEGs=True,
#     verbose=True
# ):
#     pert_categories_index = pd.Index(adata.obs["condition"].values, dtype="category")
#     allowed_cell_lines = []
# 
#     cl_dict = {
#         torch.Tensor([0.]): "A549",
#         torch.Tensor([1.]): "K562",
#         torch.Tensor([2.]): "MCF7",
#     }
# 
#     if cell_lines is None:
#         cell_lines = ["A549", "K562", "MCF7"]
# 
#     print(cell_lines)
#     layer = "X" if "X" in dataset else "layers"
#     predictions_dict = {}
#     drug_r2 = {}
#     for cell_drug_dose_comb, category_count in tqdm(
#         zip(*np.unique(pert_categories_index.values, return_counts=True))
#     ):
#         # estimate metrics only for reasonably-sized drug/cell-type combos
#         if category_count <= 5:
#             continue
#         # doesn't make sense to evaluate DMSO (=control) as a perturbation
#         if (
#             "dmso" in cell_drug_dose_comb.lower()
#             or "control" in cell_drug_dose_comb.lower()
#         ):
#             continue
# 
#         # adata.var_names is the list of gene names
#         # adata.uns["all_DEGs"] is a dict, containing a list of all differentially-expressed
#         # genes for every cell_drug_dose combination.
# 
# 
#         bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
#         idx_all = bool2idx(bool_category)
#         idx = idx_all[0]
#         y_true = dataset[layer][idx_all, :].to(model.device)
#         
#                     
#         dataset_comb = {}
#         if dataset_control is None:
#             n_obs = y_true.size(0).to(model.device)
#             for key, val in dataset.items():
#                 dataset_comb[key] = val[idx_all].to(model.device)
#         else:
#             n_obs = dataset_control[layer].size(0)
#             dataset_comb[layer] = dataset_control[layer].to(model.device)
#             dataset_comb["ind_x"] = dataset_control["ind_x"].to(model.device)
#             for key in dataset_control:
#                 if key not in [layer, "ind_x"]:
#                     dataset_comb[key] = repeat_n(dataset[key][idx, :], n_obs)
# 
#         stop = False
#         for tensor, cl in cl_dict.items():
#             if (tensor == dataset["cell_type"][idx]).all():
#                 if cl not in cell_lines:
#                     stop = True
#         if stop:
#             continue
#             
#         pred, _ = model.module.get_expression(dataset_comb)
# 
#         y_pred = pred.mean(0)
#         y_true = y_true.mean(0)
# 
#         r2_m = compute_r2(y_true.cuda(), y_pred.cuda())
#         print(f"{cell_drug_dose_comb}: {r2_m:.2f}") if verbose else None
#         drug_r2[cell_drug_dose_comb] = r2_m
# 
#         predictions_dict[cell_drug_dose_comb] = [y_true, y_pred]
#     return drug_r2, predictions_dict
# 
# 
# def compute_baseline(
#     model,
#     adata,
#     dataset,
#     cell_lines=None,
#     dataset_control=None,
#     use_DEGs=True,
#     verbose=True,
# ):
#     pert_categories_index = pd.Index(adata.obs["cov_drug_dose_name"].values, dtype="category")
#     allowed_cell_lines = []
# 
#     cl_dict = {
#         torch.Tensor([0.]): "A549",
#         torch.Tensor([1.]): "K562", 
#         torch.Tensor([2.]): "MCF7",
#     }
#     
#     cl_dict_op = {
#         "A549":torch.Tensor([0.]),
#         "K562": torch.Tensor([1.]),
#         "MCF7": torch.Tensor([2.]),
#     }
# 
#     if cell_lines is None:
#         cell_lines = ["A549", "K562", "MCF7"]
# 
#     print(cell_lines)
# 
#     layer = "X" if "X" in dataset else "layers"
#     predictions_dict = {}
#     drug_r2 = {}
#     for cell_drug_dose_comb, category_count in tqdm(
#         zip(*np.unique(pert_categories_index.values, return_counts=True))
#     ):
#         # estimate metrics only for reasonably-sized drug/cell-type combos
#         if category_count <= 5:
#             continue
# 
#         # doesn't make sense to evaluate DMSO (=control) as a perturbation
#         if (
#             "dmso" in cell_drug_dose_comb.lower()
#             or "control" in cell_drug_dose_comb.lower()
#         ):
#             continue
# 
#         # adata.var_names is the list of gene names
#         # adata.uns["all_DEGs"] is a dict, containing a list of all differentially-expressed
#         # genes for every cell_drug_dose combination.
#         bool_de = adata.var_names.isin(
#             np.array(adata.uns["all_DEGs"][cell_drug_dose_comb])
#         )
#         idx_de = bool2idx(bool_de)
# 
#         # need at least two genes to be able to calc r2 score
#         if len(idx_de) < 2:
#             continue
# 
#         bool_category = pert_categories_index.get_loc(cell_drug_dose_comb)
#         idx_all = bool2idx(bool_category)
#         idx = idx_all[0]
#         y_true = dataset[layer][idx_all, :].to(model.device)
#         
#         cov_name = cell_drug_dose_comb.split("_")[0]
#         cond = bool2idx(dataset_control["cell_type"] == cl_dict_op[cov_name])
#         y_pred = dataset_control[layer][cond, :].to(model.device)
# 
#         stop = False
#         for tensor, cl in cl_dict.items():
#             if (tensor == dataset["cell_type"][idx]).all():
#                 if cl not in cell_lines:
#                     stop = True
#         if stop:
#             continue
#             
#         y_pred = y_pred.mean(0)
#         y_true = y_true.mean(0)
#         if use_DEGs:
#             r2_m_de = compute_r2(y_true[idx_de].cuda(), y_pred[idx_de].cuda())
#             print(f"{cell_drug_dose_comb}: {r2_m_de:.2f}") if verbose else None
#             drug_r2[cell_drug_dose_comb] = r2_m_de
#         else:
#             r2_m = compute_r2(y_true.cuda(), y_pred.cuda())
#             print(f"{cell_drug_dose_comb}: {r2_m:.2f}") if verbose else None
#             drug_r2[cell_drug_dose_comb] = r2_m
# 
#         predictions_dict[cell_drug_dose_comb] = [y_true, y_pred, idx_de]
#     
#     return drug_r2, predictions_dict
# 
# 
# def create_df(res):
#     dfs_ = []
#     for key_, res_ in res.items():
#         df_ = pd.DataFrame.from_dict(res_, orient="index", columns=["r2_de"])
#         df_["type"] = key_
#         dfs_.append(df_)
#         
#     df = pd.concat(dfs_)
# 
#     df["r2_de"] = df["r2_de"].apply(lambda x: max(x,0))
#     df["cell_line"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[0]).values
#     df["drug"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[1]).values
#     df["dose"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[2]).values
#     df["dose"] = df["dose"].astype(float)
# 
#     df["combination"] = df.index.values
#     df = df.reset_index()
#     return df