In [1]:
import json

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import umap.plot
from anndata import AnnData
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm
from utils import (
    compute_drug_embeddings,
    compute_pred,
    compute_pred_ctrl,
    # load_config,
    load_dataset,
    load_model,
    load_smiles,
)

from chemCPA.data import load_dataset_splits
from chemCPA.paths import FIGURE_DIR, PROJECT_DIR, ROOT, DATA_DIR

In [2]:
# Notebook tooling: auto-format cells with black (lab_black) and make IPython
# reload edited imported modules (e.g. `utils`) before executing code.
%load_ext lab_black
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Output toggles.
#   BLACK   -> use matplotlib's dark_background theme instead of the light one.
#   SAVEFIG -> not read in the cells shown here; presumably gates figure export.
BLACK = SAVEFIG = False

In [4]:
# Global plot styling: dark theme when BLACK is set, otherwise a light theme
# saved on a white background. Font, dpi and seaborn context apply to both.
if BLACK:
    plt.style.use("dark_background")
else:
    matplotlib.style.use("fivethirtyeight")
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
    # deprecated the old "seaborn-*" names, which later releases dropped.
    # Prefer the new name and fall back to the old one on older matplotlib.
    matplotlib.style.use(
        "seaborn-v0_8-talk"
        if "seaborn-v0_8-talk" in plt.style.available
        else "seaborn-talk"
    )
    plt.rcParams["savefig.facecolor"] = "white"
    sns.set_style("whitegrid")

matplotlib.rcParams["font.family"] = "monospace"
matplotlib.rcParams["figure.dpi"] = 300
sns.set_context("poster")

In [5]:
# seml collection holding every run compared in this notebook.
seml_collection = "baseline_comparison"

# Config hashes of the individual runs, grouped by cell line:
# CPA (no chemical prior), chemCPA with pretraining, chemCPA without.

# --- A549 ---
cpa_A549 = "044c4dba0c8719985c3622834f2cbd58"
chemCPA_A549_pretrained = "3326f900c45faaf99ca4400f78c58847"
chemCPA_A549 = "8779ff45830000c6bc8e22023bb1cb2c"

# --- K562 ---
cpa_K562 = "5ea85d5dd7abd5962d1d3eeff1b8c1ff"
chemCPA_K562_pretrained = "6388fa373386c11e40dceb5e2e8a113d"
chemCPA_K562 = "34fd06018d6e2662ccd5da7a16b57334"

# --- MCF7 ---
cpa_MCF7 = "4d98e7d857f497d870e19c6d12175aaa"
chemCPA_MCF7_pretrained = "2075a457bafdca5948ab671b77757974"
chemCPA_MCF7 = "6ad52ba3939397521c5050ca1dd89a4c"

In [6]:
def load_config(seml_collection, model_hash, project_dir=None):
    """Return the config of one run from a seml collection exported as JSON.

    Reads ``<project_dir>/<seml_collection>.json``, a list of records that each
    carry a ``"config_hash"`` and a ``"config"`` entry. Returns the ``"config"``
    dict of the record whose hash equals ``model_hash``, with ``"config_hash"``
    copied into it.

    Args:
        seml_collection: Base name (without ``.json``) of the exported collection.
        model_hash: ``config_hash`` of the run to look up.
        project_dir: Directory containing the export. Defaults to ``PROJECT_DIR``.

    Returns:
        The run's config dict. It is freshly parsed on every call, so callers
        may mutate it.

    Raises:
        ValueError: If no record has ``config_hash == model_hash``. Previously
            this surfaced as a confusing ``UnboundLocalError``.
    """
    from pathlib import Path

    base_dir = PROJECT_DIR if project_dir is None else Path(project_dir)
    file_path = base_dir / f"{seml_collection}.json"

    with open(file_path) as f:
        file_data = json.load(f)

    matches = [record for record in file_data if record["config_hash"] == model_hash]
    if not matches:
        raise ValueError(f"No config with hash {model_hash!r} in {file_path}")

    # If a hash ever appears twice, keep the previous behaviour of using the last one.
    record = matches[-1]
    config = record["config"]
    config["config_hash"] = record["config_hash"]
    return config

In [7]:
# Load the config of the `chemCPA_MCF7_pretrained` run. The globals created in
# this cell (`dataset`, `canon_smiles_unique_sorted`) are also read by the
# prediction helpers defined further down.
config = load_config(seml_collection, chemCPA_MCF7_pretrained)

# Anchor the stored dataset / embedding paths at ROOT
# (presumably saved relative to the repository root — TODO confirm).
config["dataset"]["data_params"]["dataset_path"] = (
    ROOT / config["dataset"]["data_params"]["dataset_path"]
)
config["model"]["embedding"]["directory"] = (
    ROOT / config["model"]["embedding"]["directory"]
)

dataset, key_dict = load_dataset(config)

# Record the loaded dataset's variable count in the config
# (presumably consumed by load_model to size the model — verify in utils).
config["dataset"]["n_vars"] = dataset.n_vars

# NOTE(review): the meaning of the positional `True` flag is defined in
# utils.load_smiles; passing it by keyword would make this call self-explanatory.
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles(
    config, dataset, key_dict, True
)
model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/9 [00:00<?, ?it/s]

In [8]:
# Dataset splits of the reference run loaded above. The prediction helpers below
# build their own splits per run; this global one is used for inspection
# (e.g. the `test_control` shape check).
data_params = config["dataset"]["data_params"]
datasets = load_dataset_splits(**data_params, return_dataset=False)

In [9]:
# Evaluation grid. It mirrors the helpers' defaults; the prediction calls below
# do not pass these globals explicitly.
dosages = [float(10**k) for k in (1, 2, 3, 4)]  # 1e1 .. 1e4
cell_lines = "A549 K562 MCF7".split()
use_DEGs = True

In [78]:
def get_baseline_predictions(
    hash,
    seml_collection="baseline_comparison",
    smiles=None,
    dosages=None,
    cell_lines=None,
    use_DEGs=False,
    verbose=False,
    name_tag=None,
):
    """Score the control-based baseline on the OOD split of one seml run.

    The run's config is used only to rebuild its dataset splits. The prediction
    itself comes from ``compute_pred_ctrl``, using the ``test_control`` cells
    (not ``training_control``) as the control set.

    Args:
        hash: ``config_hash`` of the run whose data split is evaluated.
        seml_collection: Name of the exported seml collection (see ``load_config``).
        smiles: Unused. Accepted for signature parity with ``get_model_predictions``.
        dosages: Dosages forwarded to ``compute_pred_ctrl``.
            Defaults to ``[1e1, 1e2, 1e3, 1e4]``.
        cell_lines: Cell lines forwarded to ``compute_pred_ctrl``.
            Defaults to ``["A549", "K562", "MCF7"]``.
        use_DEGs: Forwarded to ``compute_pred_ctrl``. Also sets the ``"genes"``
            column to ``"degs"`` or ``"all"``.
        verbose: Forwarded to ``compute_pred_ctrl``.
        name_tag: If truthy, stored in a ``"model"`` column.

    Returns:
        DataFrame indexed by the keys returned by ``compute_pred_ctrl``, with an
        ``"R2"`` column, plus ``"model"`` (when ``name_tag`` is given) and ``"genes"``.

    Note:
        Reads the notebook-global ``dataset`` for ``n_vars``.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if dosages is None:
        dosages = [1e1, 1e2, 1e3, 1e4]
    if cell_lines is None:
        cell_lines = ["A549", "K562", "MCF7"]

    config = load_config(seml_collection, hash)
    config["dataset"]["n_vars"] = dataset.n_vars
    # Anchor stored paths at ROOT, as done for the reference run.
    data_params = config["dataset"]["data_params"]
    data_params["dataset_path"] = ROOT / data_params["dataset_path"]
    embedding_params = config["model"]["embedding"]
    embedding_params["directory"] = ROOT / embedding_params["directory"]
    datasets = load_dataset_splits(**data_params, return_dataset=False)

    predictions, _ = compute_pred_ctrl(
        datasets["ood"],
        dataset_ctrl=datasets["test_control"],
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=use_DEGs,
        verbose=verbose,
    )

    predictions = pd.DataFrame.from_dict(predictions, orient="index", columns=["R2"])
    if name_tag:
        predictions["model"] = name_tag
    predictions["genes"] = "degs" if use_DEGs else "all"
    return predictions

In [79]:
seml_collection = "baseline_comparison"

# The baseline only needs a run's data split, so it reuses the CPA runs' config
# hashes. Reference the cpa_* constants instead of repeating the hash literals,
# so the two lists cannot drift apart.
base_hash_dict = dict(
    baseline_A549=cpa_A549,
    baseline_K562=cpa_K562,
    baseline_MCF7=cpa_MCF7,
)

In [80]:
# Baseline scores for every split: all runs on DEGs first, then all runs on all
# genes. `predictions` is a list of per-run DataFrames; model scores are
# appended to it later.
predictions = [
    get_baseline_predictions(_hash, name_tag=name_tag, use_DEGs=use_degs)
    for use_degs in (True, False)
    for name_tag, _hash in base_hash_dict.items()
]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

In [81]:
# Inspect the collected per-run baseline frames (still a list at this point).
predictions

[                             R2          model genes
 A549_Alvespimycin_0.1  0.729878  baseline_A549  degs
 A549_Belinostat_0.1    0.621798  baseline_A549  degs
 A549_Dacinostat_0.1    0.213912  baseline_A549  degs
 A549_Flavopiridol_0.1  0.839678  baseline_A549  degs
 A549_Givinostat_0.1    0.785217  baseline_A549  degs
 A549_Hesperadin_0.1    0.846727  baseline_A549  degs
 A549_Quisinostat_0.1   0.064278  baseline_A549  degs
 A549_TAK-901_0.1       0.761692  baseline_A549  degs
 A549_Tanespimycin_0.1  0.800248  baseline_A549  degs,
                              R2          model genes
 K562_Alvespimycin_0.1  0.528042  baseline_K562  degs
 K562_Belinostat_0.1    0.778544  baseline_K562  degs
 K562_Dacinostat_0.1    0.000000  baseline_K562  degs
 K562_Flavopiridol_0.1  0.000000  baseline_K562  degs
 K562_Givinostat_0.1    0.702251  baseline_K562  degs
 K562_Hesperadin_0.1    0.743025  baseline_K562  degs
 K562_Quisinostat_0.1   0.000000  baseline_K562  degs
 K562_TAK-901_0.1       0.6

In [82]:
# Shape of the test-control expression tensor. The 1132 rows match the
# test-split control count in the crosstabs below; columns are presumably genes.
datasets["test_control"].genes.size()

torch.Size([1132, 977])

In [83]:
# Sanity check of the A549 split: cells per split by control flag. The OOD split
# has no control cells, which is why the test controls serve as control set.
adata = sc.read(DATA_DIR / "adata_baseline.h5ad")
pd.crosstab(adata.obs["control"], adata.obs["split_baseline_A549"])

split_baseline_A549,ood,test,train
control,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,814,10368,88094
1,0,1132,11872


In [84]:
# Same split sanity check for K562.
pd.crosstab(adata.obs["control"], adata.obs["split_baseline_K562"])

split_baseline_K562,ood,test,train
control,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,716,10417,88143
1,0,1132,11872


In [85]:
# Same split sanity check for MCF7.
pd.crosstab(adata.obs["control"], adata.obs["split_baseline_MCF7"])

split_baseline_MCF7,ood,test,train
control,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,1662,9944,87670
1,0,1132,11872


In [86]:
def get_model_predictions(
    hash,
    seml_collection="baseline_comparison",
    smiles=None,
    dosages=None,
    cell_lines=None,
    use_DEGs=False,
    verbose=False,
    name_tag=None,
):
    """Score the model of one seml run on the OOD split of its data split.

    Loads the run's config and model, then evaluates it with ``compute_pred``,
    using the ``test_control`` genes (not ``training_control``) as control input.

    Args:
        hash: ``config_hash`` of the run to evaluate.
        seml_collection: Name of the exported seml collection (see ``load_config``).
        smiles: SMILES passed to ``load_model``.
            Defaults to the notebook-global ``canon_smiles_unique_sorted``.
        dosages: Dosages forwarded to ``compute_pred``.
            Defaults to ``[1e1, 1e2, 1e3, 1e4]``.
        cell_lines: Cell lines forwarded to ``compute_pred``.
            Defaults to ``["A549", "K562", "MCF7"]``.
        use_DEGs: Forwarded to ``compute_pred``. Also sets the ``"genes"``
            column to ``"degs"`` or ``"all"``.
        verbose: Forwarded to ``compute_pred``.
        name_tag: If truthy, stored in a ``"model"`` column.

    Returns:
        DataFrame indexed by the keys returned by ``compute_pred``, with an
        ``"R2"`` column, plus ``"model"`` (when ``name_tag`` is given) and ``"genes"``.

    Note:
        Reads the notebook-global ``dataset`` for ``n_vars``.
    """
    if smiles is None:
        smiles = canon_smiles_unique_sorted
    # Fresh lists per call instead of shared mutable default arguments.
    if dosages is None:
        dosages = [1e1, 1e2, 1e3, 1e4]
    if cell_lines is None:
        cell_lines = ["A549", "K562", "MCF7"]

    config = load_config(seml_collection, hash)
    config["dataset"]["n_vars"] = dataset.n_vars
    # Anchor stored paths at ROOT, as done for the reference run.
    data_params = config["dataset"]["data_params"]
    data_params["dataset_path"] = ROOT / data_params["dataset_path"]
    embedding_params = config["model"]["embedding"]
    embedding_params["directory"] = ROOT / embedding_params["directory"]
    datasets = load_dataset_splits(**data_params, return_dataset=False)

    # The drug embedding returned alongside the model is not needed here.
    model, _ = load_model(config, smiles)
    predictions, _ = compute_pred(
        model,
        datasets["ood"],
        genes_control=datasets["test_control"].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=use_DEGs,
        verbose=verbose,
    )

    predictions = pd.DataFrame.from_dict(predictions, orient="index", columns=["R2"])
    if name_tag:
        predictions["model"] = name_tag
    predictions["genes"] = "degs" if use_DEGs else "all"
    return predictions

In [87]:
seml_collection = "baseline_comparison"

# Runs to evaluate, keyed by tag. The tag is later reduced by `rename_model`,
# so keep the "<model>_<cell_line>[_<variant>]" form. Values reference the
# hash constants defined earlier instead of repeating the literals.
hash_dict = dict(
    cpa_A549=cpa_A549,
    cpa_K562=cpa_K562,
    cpa_MCF7=cpa_MCF7,
    chemCPA_A549_pretrained=chemCPA_A549_pretrained,
    chemCPA_A549=chemCPA_A549,
    chemCPA_K562_pretrained=chemCPA_K562_pretrained,
    chemCPA_K562=chemCPA_K562,
    chemCPA_MCF7_pretrained=chemCPA_MCF7_pretrained,
    chemCPA_MCF7=chemCPA_MCF7,
)

In [88]:
# df_degs = pd.DataFrame.from_dict(drug_r2_baseline_degs, orient="index", columns=["R2"])
# df_degs["model"] = "baseline"
# df_degs["genes"] = "degs"

# df_all = pd.DataFrame.from_dict(drug_r2_baseline_all, orient="index", columns=["R2"])
# df_all["model"] = "baseline"
# df_all["genes"] = "all"

# predictions = [df_degs, df_all]

In [89]:
# Append model scores for every run: all runs on DEGs first, then all runs on
# all genes. Every evaluation finishes before the single `extend`, so a failure
# part-way leaves `predictions` untouched.
# NOTE: re-running this cell still appends a second copy of these frames.
predictions.extend(
    [
        get_model_predictions(_hash, name_tag=name_tag, use_DEGs=use_degs)
        for use_degs in (True, False)
        for name_tag, _hash in hash_dict.items()
    ]
)

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

CPA model without chemical prior.
['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

In [90]:
def rename_model(str):
    """Drop the cell-line token from a per-run model tag.

    ``"<model>_<cell_line>"`` becomes ``"<model>"``, and
    ``"<model>_<cell_line>_<variant>"`` becomes ``"<model>_<variant>"``.
    For example, ``"chemCPA_A549_pretrained"`` becomes ``"chemCPA_pretrained"``.

    Args:
        str: Model tag. The name shadows the builtin but is kept for
            compatibility with keyword callers.

    Returns:
        The tag without its cell-line component.

    Raises:
        ValueError: If the tag does not have two or three ``_``-separated parts.
    """
    parts = str.split("_")
    if len(parts) == 2:
        return parts[0]
    if len(parts) == 3:
        return f"{parts[0]}_{parts[2]}"
    # Explicit error rather than `assert`, which is stripped under `python -O`.
    raise ValueError(
        f"Unexpected model tag {str!r}: expected 2 or 3 '_'-separated parts"
    )

In [91]:
# Flatten the per-run frames into one table. Each index key has the form
# "<cell_type>_<condition>_<dose>". Cell type runs up to the first "_" and dose
# comes after the last "_", so a condition containing "_" is still parsed intact.
# `model_ct` keeps the full run tag; `model` drops the cell line (rename_model).
# NOTE: this rebinds `predictions` from a list to a DataFrame. Re-running the
# cell requires re-running the cells that build the list first.
predictions = (
    pd.concat(predictions)
    .reset_index()
    .pipe(
        lambda df: df.join(
            df["index"].str.extract(
                r"^(?P<cell_type>[^_]+)_(?P<condition>.+)_(?P<dose>[^_]+)$"
            )
        )
    )
    .assign(
        model_ct=lambda df: df["model"],
        model=lambda df: df["model"].apply(rename_model),
    )
)
assert predictions[["cell_type", "condition", "dose"]].notna().all().all(), (
    "unexpected prediction key format; expected '<cell_type>_<condition>_<dose>'"
)

In [92]:
# Tidy predictions table: one row per (run, condition, gene set).
predictions

Unnamed: 0,index,R2,model,genes,cell_type,condition,dose,model_ct
0,A549_Alvespimycin_0.1,0.729878,baseline,degs,A549,Alvespimycin,0.1,baseline_A549
1,A549_Belinostat_0.1,0.621798,baseline,degs,A549,Belinostat,0.1,baseline_A549
2,A549_Dacinostat_0.1,0.213912,baseline,degs,A549,Dacinostat,0.1,baseline_A549
3,A549_Flavopiridol_0.1,0.839678,baseline,degs,A549,Flavopiridol,0.1,baseline_A549
4,A549_Givinostat_0.1,0.785217,baseline,degs,A549,Givinostat,0.1,baseline_A549
...,...,...,...,...,...,...,...,...
211,MCF7_Givinostat_0.1,0.925738,chemCPA,all,MCF7,Givinostat,0.1,chemCPA_MCF7
212,MCF7_Hesperadin_0.1,0.945307,chemCPA,all,MCF7,Hesperadin,0.1,chemCPA_MCF7
213,MCF7_Quisinostat_0.1,0.286649,chemCPA,all,MCF7,Quisinostat,0.1,chemCPA_MCF7
214,MCF7_TAK-901_0.1,0.953843,chemCPA,all,MCF7,TAK-901,0.1,chemCPA_MCF7


In [93]:
# Mean R2 per model and gene set. Select R2 explicitly: the frame also holds
# string columns (index, cell_type, condition, dose, model_ct), which pandas >= 2.0
# refuses to average (older versions silently dropped them).
predictions.groupby(["model", "genes"])[["R2"]].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,R2
model,genes,Unnamed: 2_level_1
baseline,all,0.692736
baseline,degs,0.509725
chemCPA,all,0.738044
chemCPA,degs,0.604628
chemCPA_pretrained,all,0.768633
chemCPA_pretrained,degs,0.682686
cpa,all,0.717369
cpa,degs,0.541184


In [94]:
# Persist the tidy predictions table.
# NOTE(review): relative path, so the file lands in the kernel's current working
# directory; consider anchoring it at a configured directory (e.g. PROJECT_DIR).
predictions.to_parquet("cpa_predictions.parquet")

In [95]:
# Mean R2 per model and gene set, restricted to R2 because pandas >= 2.0 raises
# on the frame's string columns (index, cell_type, condition, dose, model_ct).
predictions.groupby(["model", "genes"])[["R2"]].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,R2
model,genes,Unnamed: 2_level_1
baseline,all,0.692736
baseline,degs,0.509725
chemCPA,all,0.738044
chemCPA,degs,0.604628
chemCPA_pretrained,all,0.768633
chemCPA_pretrained,degs,0.682686
cpa,all,0.717369
cpa,degs,0.541184


In [96]:
# Spread of R2 per model and gene set, restricted to R2 because pandas >= 2.0
# raises on the frame's string columns (index, cell_type, condition, dose, model_ct).
predictions.groupby(["model", "genes"])[["R2"]].std()

Unnamed: 0_level_0,Unnamed: 1_level_0,R2
model,genes,Unnamed: 2_level_1
baseline,all,0.280294
baseline,degs,0.337252
chemCPA,all,0.264198
chemCPA,degs,0.311684
chemCPA_pretrained,all,0.247264
chemCPA_pretrained,degs,0.259599
cpa,all,0.27941
cpa,degs,0.348539


In [47]:
# predictions = pd.read_parquet("baseline_predictions.parquet")