In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scgen

from chemCPA.paths import CHECKPOINT_DIR, DATA_DIR

# sc.set_figure_params(dpi=300, frameon=False)
# sc.logging.print_header()
%load_ext lab_black
%load_ext autoreload
%autoreload 2

Global seed set to 0
  PyTreeDef = type(jax.tree_structure(None))


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
dose = 1.0

if dose == 0.1:
    suffix = ""
elif dose == 1.0:
    suffix = "_high_dose"
adata = sc.read(DATA_DIR / f"adata_baseline{suffix}.h5ad")

In [3]:
split = "split_baseline_A549"
df_ood = adata.obs.loc[
    adata.obs[split] == "ood", ["cell_type", "condition"]
].drop_duplicates()

df_ood

Unnamed: 0_level_0,cell_type,condition
index,Unnamed: 1_level_1,Unnamed: 2_level_1
A01_E09_RT_BC_122_Lig_BC_104-0-0-0,A549,Hesperadin
A01_E09_RT_BC_129_Lig_BC_15-0-0-0,A549,Flavopiridol
A01_E09_RT_BC_31_Lig_BC_257-0-0-0,A549,Belinostat
A01_E09_RT_BC_328_Lig_BC_51-0-0-0,A549,TAK-901
A01_E09_RT_BC_337_Lig_BC_71-0-0-0,A549,Quisinostat
A01_F10_RT_BC_149_Lig_BC_165-0-0-0,A549,Alvespimycin
A03_E09_RT_BC_209_Lig_BC_281-0-0-0,A549,Givinostat
A03_E09_RT_BC_367_Lig_BC_80-0-0-0,A549,Tanespimycin
A04_E09_RT_BC_76_Lig_BC_322-0-0-0,A549,Dacinostat


In [4]:
splits = [c for c in adata.obs.columns if "baseline" in c]

splits

['split_baseline_A549', 'split_baseline_MCF7', 'split_baseline_K562']

In [5]:
split_model_dict = dict(
    split_baseline_A549=CHECKPOINT_DIR / f"scgen_sciplex_A549{suffix}.pt",
    split_baseline_K562=CHECKPOINT_DIR / f"scgen_sciplex_K562{suffix}.pt",
    split_baseline_MCF7=CHECKPOINT_DIR / f"scgen_sciplex_MCF7{suffix}.pt",
)

### Train scGen

In [6]:
def train_scgen(split, path=None):
    adata_train = adata[adata.obs[split] == "train"].copy()

    scgen.SCGEN.setup_anndata(
        adata_train, batch_key="condition", labels_key="cell_type"
    )

    model = scgen.SCGEN(adata_train)

    model.train(
        max_epochs=50,
        batch_size=128,
        early_stopping=True,
        early_stopping_patience=25,
        plan_kwargs=dict(n_epochs_kl_warmup=45),
    )
    if path:
        # fname = CHECKPOINT_DIR/f"scgen_sciplex_{split.split('_')[-1]}.pt"
        fname = path
        model.save(fname, overwrite=True)
    print(f"Model saved at \n\t f{fname}")
    del model

In [7]:
retrain = False

for split, model_path in split_model_dict.items():
    if not model_path.exists() or retrain:
        train_scgen(split, model_path)
    else:
        print(f"Model for {split} already exists at:\n\t {model_path}")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Epoch 29/50:  58%|█████▊    | 29/50 [04:31<03:16,  9.37s/it, loss=15.5, v_num=1]
Monitored metric elbo_validation did not improve in the last 25 records. Best score: 398.693. Signaling Trainer to stop.
Model saved at 
	 f/nfs/staff-ssd/hetzell/code/chemCPA_v2/project_folder/checkpoints/scgen_sciplex_A549_high_dose.pt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Epoch 26/50:  52%|█████▏    | 26/50 [04:01<03:42,  9.29s/it, loss=15.6, v_num=1]
Monitored metric elbo_validation did not improve in the last 25 records. Best score: 291.633. Signaling Trainer to stop.
Model saved at 
	 f/nfs/staff-ssd/hetzell/code/chemCPA_v2/project_folder/checkpoints/scgen_sciplex_K562_high_dose.pt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Epoch 27/50:  54%|█████▍    | 27/50 [04:02<03:26,  9.00s/it, loss=15.7, v_num=1]
Monitored metric elbo_validation did not improve in the last 25 records. Best score: 356.003. Signaling Trainer to stop.
Model saved at 
	 f/nfs/staff-ssd/hetzell/code/chemCPA_v2/project_folder/checkpoints/scgen_sciplex_MCF7_high_dose.pt


### Compute predictions

In [9]:
import torch

from chemCPA.train import compute_r2


def compute_prediction(
    split, model, adata, use_DEGs=False, degs_key="lincs_DEGs", dose=0.1
):
    drug_r2 = {}
    ood_idx = adata.obs[split] == "ood"
    df_ood = adata.obs.loc[ood_idx, ["cell_type", "condition"]].drop_duplicates()
    for _, (ct, condition) in df_ood.iterrows():
        cell_drug_dose_comb = f"{ct}_{condition}_{dose}"
        ctrl_idx = (
            adata.obs[[split, "condition", "cell_type"]]
            .isin(["test", "control", ct])
            .prod(1)
            .astype(bool)
        )
        y_idx = (
            adata.obs[[split, "cell_type", "condition"]]
            .isin(["ood", ct, condition])
            .prod(1)
            .astype(bool)
        )
        y_true = adata[y_idx].X.A

        # adata_pred, _ = model.predict(
        #     ctrl_key='control',
        #     stim_key=condition,
        #     celltype_to_predict=ct,
        #     )
        adata_pred, _ = model.predict(
            ctrl_key="control",
            stim_key=condition,
            adata_to_predict=adata[ctrl_idx].copy(),
        )

        y_pred = adata_pred.X

        y_pred = torch.Tensor(y_pred).mean(0)
        y_true = torch.Tensor(y_true).mean(0)

        if use_DEGs:
            degs = adata.uns[degs_key][f"{ct}_{condition}_{dose}"]
            idx_de = adata.var_names.isin(degs)
            r2_m_de = compute_r2(y_true[idx_de].cuda(), y_pred[idx_de].cuda())
            drug_r2[cell_drug_dose_comb] = max(r2_m_de, 0.0)
        else:
            r2_m = compute_r2(y_true.cuda(), y_pred.cuda())
            drug_r2[cell_drug_dose_comb] = max(r2_m, 0.0)

    return drug_r2

In [10]:
scgen.SCGEN.setup_anndata(adata)
predictions = []
for split, model_path in split_model_dict.items():
    _adata = adata[adata.obs[split] == "train"].copy()
    model = scgen.SCGEN.load(model_path, _adata)
    for use_DEGs in [False, True]:
        preds = compute_prediction(
            split=split,
            model=model,
            adata=adata,
            use_DEGs=use_DEGs,
        )
        preds = pd.DataFrame.from_dict(preds, orient="index", columns=["R2"])

        preds["model"] = f"scGen_{split.split('_')[-1]}_{dose}"
        preds["genes"] = "degs" if use_DEGs else "all"
        predictions.append(preds)

[34mINFO    [0m File [35m/nfs/staff-ssd/hetzell/code/chemCPA_v2/project_folder/checkpoints/scgen_sciplex[0m
         [35m_A549_high_dose.pt/[0m[95mmodel.pt[0m already downloaded                                      


  utils.warn_names_duplicates("obs")


[34mINFO    [0m Received view of anndata, making copy.                                              
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup       
[34mINFO    [0m Received view of anndata, making copy.                                              
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup       
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                  
[34mINFO    [0m Received view of anndata, making copy.                                              
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup       
[34mINFO    [0m Received view of anndata, making copy.                                              
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup       
[34mINFO    [0m AnnData object appears to be a copy. Attempting to tran

In [11]:
predictions = pd.concat(predictions)
predictions.reset_index(inplace=True)
predictions["cell_type"] = predictions["index"].apply(lambda s: s.split("_")[0])
predictions["condition"] = predictions["index"].apply(lambda s: s.split("_")[1])
predictions["dose"] = f"{dose}"
predictions["model_ct"] = predictions["model"]
predictions["model"] = predictions["model"].apply(lambda s: s.split("_")[0])

In [12]:
predictions

Unnamed: 0,index,R2,model,genes,cell_type,condition,dose,model_ct
0,A549_Hesperadin_0.1,0.793252,scGen,all,A549,Hesperadin,1.0,scGen_A549_1.0
1,A549_Flavopiridol_0.1,0.615121,scGen,all,A549,Flavopiridol,1.0,scGen_A549_1.0
2,A549_Belinostat_0.1,0.67256,scGen,all,A549,Belinostat,1.0,scGen_A549_1.0
3,A549_TAK-901_0.1,0.839423,scGen,all,A549,TAK-901,1.0,scGen_A549_1.0
4,A549_Quisinostat_0.1,0.623886,scGen,all,A549,Quisinostat,1.0,scGen_A549_1.0
5,A549_Alvespimycin_0.1,0.730094,scGen,all,A549,Alvespimycin,1.0,scGen_A549_1.0
6,A549_Givinostat_0.1,0.661892,scGen,all,A549,Givinostat,1.0,scGen_A549_1.0
7,A549_Tanespimycin_0.1,0.764184,scGen,all,A549,Tanespimycin,1.0,scGen_A549_1.0
8,A549_Dacinostat_0.1,0.651836,scGen,all,A549,Dacinostat,1.0,scGen_A549_1.0
9,A549_Hesperadin_0.1,0.830219,scGen,degs,A549,Hesperadin,1.0,scGen_A549_1.0


In [15]:
predictions.groupby(["model", "genes"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,R2
model,genes,Unnamed: 2_level_1
scGen,all,0.621697
scGen,degs,0.471889


In [14]:
predictions.to_parquet(f"scgen_predictions{suffix}.parquet")