**Requirements:**
* Trained models

**Outputs:** 
* none 
___
# Imports

In [1]:
import matplotlib
import umap.plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sn

from utils import load_config, load_dataset, load_smiles, load_model, compute_drug_embeddings, compute_pred
from compert.data import load_dataset_splits

matplotlib.style.use("fivethirtyeight")
matplotlib.style.use("seaborn-talk")
matplotlib.rcParams['font.family'] = "monospace"
matplotlib.rcParams['figure.dpi'] = 60
matplotlib.pyplot.rcParams['savefig.facecolor'] = 'white'
sn.set_context("poster")

Using backend: pytorch[10:41:50] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: /home/icb/leon.hetzel/miniconda3/envs/chemical_CPA/lib/python3.7/site-packages/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.1.so: cannot open shared object file: No such file or directory



In [2]:
%load_ext autoreload
%autoreload 2

# Load and analyse model 
* Define `seml_collection` and `model_hash` to load data and model

In [3]:
seml_collection = "finetuning_num_genes"

# split_ho_pathway, append_ae_layer: true
model_hash_pretrained = "70290e4f42ac4cb19246fafa0b75ccb6" # "config.model.load_pretrained": true, 
model_hash_scratch = "ed3bc586a5fcfe3c4dbb0157cd67d0d9" # "config.model.load_pretrained": false, 

# split_ood_finetuning, append_ae_layer: true
model_hash_pretrained = "bd001c8d557edffe9df9e6bf09dc4120" # "config.model.load_pretrained": true, 
model_hash_scratch = "6e9d00880375aa450a8e5de60250659f" # "config.model.load_pretrained": false, 

## Load config

In [4]:
config = load_config(seml_collection, model_hash_pretrained)
dataset, key_dict = load_dataset(config)
config['dataset']['n_vars'] = dataset.n_vars

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

### Load smiles info

In [5]:
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles(config, dataset, key_dict)

#### Define which drugs should be annotaded with list `ood_drugs`

In [6]:
ood_drugs = dataset.obs.condition[dataset.obs[config["dataset"]["data_params"]["split_key"]].isin(['ood'])].unique().to_list()

#### Get pathway level 2 annotation for clustering of drug embeddings

In [7]:
smiles_to_pw_level2_map = {}
pw1_to_pw2 = {}

for (drug, pw1, pw2), df in dataset.obs.groupby(['SMILES', 'pathway_level_1', 'pathway_level_2']): 
    smiles_to_pw_level2_map[drug] = pw2
    if pw1 in pw1_to_pw2:
        pw1_to_pw2[pw1].add(pw2)
    else: 
        pw1_to_pw2[pw1] = {pw2}

In [8]:
groups = ["Epigenetic regulation"]

groups_pw2 = [pw2 for pw in groups for pw2 in pw1_to_pw2[pw]]
groups_pw2

['DNA methylation',
 'Histone methylation',
 'Bromodomain',
 'Histone demethylase',
 'Histone deacetylation',
 'Histone acetylation']

## Load dataset splits

In [9]:
config['dataset']['data_params']

{'covariate_keys': 'cell_type',
 'dataset_path': '/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete.h5ad',
 'degs_key': 'all_DEGs',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_ood_finetuning',
 'use_drugs_idx': True}

In [10]:
data_params = config['dataset']['data_params']

# #Overwrite split_key 
# data_params['split_key'] = 'split_ho_epigenetic'

datasets = load_dataset_splits(**data_params, return_dataset=False)

___
## Pretrained model

In [11]:
dosages = [1e1, 1e2, 1e3, 1e4]

In [12]:
config = load_config(seml_collection, model_hash_pretrained)
config['dataset']['n_vars'] = dataset.n_vars
model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
drug_r2_pretrained, _ = compute_pred(model_pretrained, 
                                     datasets['ood'], 
                                     genes_control=datasets['test_control'].genes, 
                                     dosages=dosages)

0it [00:00, ?it/s]

MCF7_CUDC-101_0.001: 0.82
MCF7_CUDC-101_0.01: 0.83
MCF7_CUDC-101_0.1: 0.82
MCF7_CUDC-101_1.0: 0.63
MCF7_CUDC-907_0.001: 0.70
MCF7_CUDC-907_0.01: 0.71
MCF7_CUDC-907_0.1: 0.61
MCF7_CUDC-907_1.0: -0.39
MCF7_Dacinostat_0.001: 0.72
MCF7_Dacinostat_0.01: 0.72
MCF7_Dacinostat_0.1: 0.13
MCF7_Dacinostat_1.0: -0.09
MCF7_Givinostat_0.001: 0.76
MCF7_Givinostat_0.01: 0.77
MCF7_Givinostat_0.1: 0.84
MCF7_Givinostat_1.0: -0.22
MCF7_Hesperadin_0.001: 0.78
MCF7_Hesperadin_0.01: 0.80
MCF7_Hesperadin_0.1: 0.87
MCF7_Hesperadin_1.0: 0.80
MCF7_Pirarubicin_0.001: 0.68
MCF7_Pirarubicin_0.01: 0.69
MCF7_Pirarubicin_0.1: 0.69
MCF7_Pirarubicin_1.0: 0.64
MCF7_Raltitrexed_0.001: 0.70
MCF7_Raltitrexed_0.01: 0.68
MCF7_Raltitrexed_0.1: 0.66
MCF7_Raltitrexed_1.0: 0.66
MCF7_Tanespimycin_0.001: 0.73
MCF7_Tanespimycin_0.01: 0.67
MCF7_Tanespimycin_0.1: 0.64
MCF7_Tanespimycin_1.0: 0.59
MCF7_Trametinib_0.001: 0.72
MCF7_Trametinib_0.01: 0.72
MCF7_Trametinib_0.1: 0.74
MCF7_Trametinib_1.0: 0.75


## Non-pretrained model

In [14]:
config = load_config(seml_collection, model_hash_scratch)
config['dataset']['n_vars'] = dataset.n_vars
model_scratch, embedding_scratch = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
drug_r2_scratch, _ = compute_pred(model_scratch,
                                  datasets['ood'],
                                  genes_control=datasets['test_control'].genes, 
                                  dosages=dosages) # non-pretrained

0it [00:00, ?it/s]

MCF7_CUDC-101_0.001: 0.56
MCF7_CUDC-101_0.01: 0.55
MCF7_CUDC-101_0.1: 0.50
MCF7_CUDC-101_1.0: 0.24
MCF7_CUDC-907_0.001: 0.35
MCF7_CUDC-907_0.01: 0.33
MCF7_CUDC-907_0.1: 0.16
MCF7_CUDC-907_1.0: -0.39
MCF7_Dacinostat_0.001: 0.38
MCF7_Dacinostat_0.01: 0.27
MCF7_Dacinostat_0.1: -0.29
MCF7_Dacinostat_1.0: -0.41
MCF7_Givinostat_0.001: 0.45
MCF7_Givinostat_0.01: 0.46
MCF7_Givinostat_0.1: 0.14
MCF7_Givinostat_1.0: -0.49
MCF7_Hesperadin_0.001: 0.47
MCF7_Hesperadin_0.01: 0.48
MCF7_Hesperadin_0.1: 0.16
MCF7_Hesperadin_1.0: 0.39
MCF7_Pirarubicin_0.001: 0.40
MCF7_Pirarubicin_0.01: 0.40
MCF7_Pirarubicin_0.1: 0.44
MCF7_Pirarubicin_1.0: 0.32
MCF7_Raltitrexed_0.001: 0.35
MCF7_Raltitrexed_0.01: 0.34
MCF7_Raltitrexed_0.1: 0.23
MCF7_Raltitrexed_1.0: 0.25
MCF7_Tanespimycin_0.001: 0.21
MCF7_Tanespimycin_0.01: 0.22
MCF7_Tanespimycin_0.1: 0.33
MCF7_Tanespimycin_1.0: 0.22
MCF7_Trametinib_0.001: 0.05
MCF7_Trametinib_0.01: -0.02
MCF7_Trametinib_0.1: 0.26
MCF7_Trametinib_1.0: 0.29


In [16]:
dataset.obs.loc[dataset.obs.split_ood_finetuning=='ood', 'condition'].unique().to_list()

['Raltitrexed',
 'Trametinib',
 'Hesperadin',
 'CUDC-101',
 'Dacinostat',
 'Pirarubicin',
 'CUDC-907',
 'Tanespimycin',
 'Givinostat']

In [17]:
np.mean([max(v, 0) for v in drug_r2_scratch.values()])

0.19084201256434122

In [18]:
np.mean([max(v, 0) for v in drug_r2_pretrained.values()])

0.45148704449335736

____

In [19]:
import torch 
import json

file_name = "test.pt"
torch.save(
            (
                model_scratch.state_dict(),
                # adversary covariates are saved as a list attr on the autoencoder
                # which PyTorch doesn't include in the autoencoder's state dict
                [
                    adversary_covariates.state_dict()
                    for adversary_covariates in model_scratch.adversary_covariates
                ],
#                 TODO I haven't checked that this actually works
                [
                  covariate_embedding.state_dict()
                  for covariate_embedding in model_scratch.covariates_embeddings
                ],
                model_scratch.init_args,
                model_scratch.history,
            ),
            file_name,
        )
pjson = lambda s: print(json.dumps(s), flush=True)
pjson({"model_saved": file_name})

{"model_saved": "test.pt"}


In [20]:
def load_torch_model(file_path, append_ae_layer=False): 
    dumped_model = torch.load(file_path)
    if len(dumped_model) == 3:
        # old version
        state_dict, model_config, history = dumped_model
    else:
        # new version
        assert len(dumped_model) == 5
        (
            state_dict,
            adversary_cov_state_dicts,
            cov_embeddings_state_dicts,
            model_config,
            history,
        ) = dumped_model
        assert len(cov_embeddings_state_dicts) == 1
        print("hi")
#     # sanity check
#     if append_ae_layer:
#         assert model_config["num_genes"] < self.datasets["training"].num_genes
#     else:
#         assert model_config["num_genes"] == self.datasets["training"].num_genes
#     assert model_config["use_drugs_idx"]
#     keys = list(state_dict.keys())
#     for key in keys:
#         # remove all components which we will train from scratch
#         # the drug embedding is saved in the state_dict for some reason, but we don't need it
#         if key.startswith("adversary_drugs") or key == "drug_embeddings.weight":
#             state_dict.pop(key)
#     if self.embedding_model_type == "vanilla":
#         # for Vanilla CPA, we also train the amortized doser & drug_embedding_encoder anew
#         keys = list(state_dict.keys())
#         for key in keys:
#             if key.startswith("dosers") or key.startswith("drug_embedding_encoder"):
#                 state_dict.pop(key)
    return state_dict, cov_embeddings_state_dicts, model_config

In [32]:
model_scratch.drug_embeddings.weight

Parameter containing:
tensor([[ 0.0595, -1.1332, -1.0637,  ..., -0.1452, -0.4641,  1.3959],
        [-0.5587, -1.0031, -0.9086,  ...,  1.2048, -0.4641,  1.3008],
        [-1.4123,  3.1108,  2.3815,  ..., -0.1452, -0.4641, -2.7933],
        ...,
        [-1.0110,  0.8632,  0.4577,  ...,  1.2048, -0.4641, -1.5177],
        [ 1.4706, -1.4112, -1.6018,  ..., -0.1452,  1.9109,  0.2541],
        [ 2.4432, -2.6194, -2.7603,  ..., -0.1452, -0.4641, -0.9532]])

In [22]:
append_ae_layer = False
append_layer_width = (
    datasets["training"].num_genes if append_ae_layer else None
)
in_out_size = (
    model_config["num_genes"]
    if append_ae_layer
    else datasets["training"].num_genes
)
# idea: Reconstruct the ComPert model as pretrained (hence the "old" in_out_size)
# then add the append_layer (the "new" in_out_size)

from compert.embedding import get_chemical_representation
embedding = get_chemical_representation(
    smiles=canon_smiles_unique_sorted,
    embedding_model=config["model"]["embedding"]["model"],
    data_dir=config["model"]["embedding"]["directory"],
    device="cuda",
)
state_dict, cov_embeddings_state_dicts, model_config = load_torch_model("test.pt")
append_layer_width = (
    config["dataset"]["n_vars"]
    if (config["model"]["append_ae_layer"] and config["model"]["load_pretrained"])
    else None
)

if config["model"]["embedding"]["model"] != "vanilla":
    state_dict.pop("drug_embeddings.weight")

from compert.model import ComPert
autoencoder = ComPert(
    **model_config, drug_embeddings=embedding, append_layer_width=append_layer_width
)
incomp_keys = autoencoder.load_state_dict(state_dict, strict=False)
for embedding, state_dict in zip(
    autoencoder.covariates_embeddings, cov_embeddings_state_dicts
):
    embedding.load_state_dict(state_dict)
autoencoder.eval()
print(
    f"INCOMP_KEYS (make sure these contain what you expected):\n{incomp_keys}"
)

hi
INCOMP_KEYS (make sure these contain what you expected):
_IncompatibleKeys(missing_keys=['drug_embeddings.weight'], unexpected_keys=[])


In [33]:
x = torch.randn((3,2000))

(model_scratch.encoder.network[0](x) == autoencoder.encoder.network[0](x)).all()

tensor(True)

In [34]:
drug_r2_scratch, _ = compute_pred(autoencoder,
                                  datasets['ood'],
                                  genes_control=datasets['test_control'].genes, 
                                  dosages=dosages) # non-pretrained

0it [00:00, ?it/s]

MCF7_CUDC-101_0.001: 0.56
MCF7_CUDC-101_0.01: 0.55
MCF7_CUDC-101_0.1: 0.50
MCF7_CUDC-101_1.0: 0.24
MCF7_CUDC-907_0.001: 0.35
MCF7_CUDC-907_0.01: 0.33
MCF7_CUDC-907_0.1: 0.16
MCF7_CUDC-907_1.0: -0.39
MCF7_Dacinostat_0.001: 0.38
MCF7_Dacinostat_0.01: 0.27
MCF7_Dacinostat_0.1: -0.29
MCF7_Dacinostat_1.0: -0.41
MCF7_Givinostat_0.001: 0.45
MCF7_Givinostat_0.01: 0.46
MCF7_Givinostat_0.1: 0.14
MCF7_Givinostat_1.0: -0.49
MCF7_Hesperadin_0.001: 0.47
MCF7_Hesperadin_0.01: 0.48
MCF7_Hesperadin_0.1: 0.16
MCF7_Hesperadin_1.0: 0.39
MCF7_Pirarubicin_0.001: 0.40
MCF7_Pirarubicin_0.01: 0.40
MCF7_Pirarubicin_0.1: 0.44
MCF7_Pirarubicin_1.0: 0.32
MCF7_Raltitrexed_0.001: 0.35
MCF7_Raltitrexed_0.01: 0.34
MCF7_Raltitrexed_0.1: 0.23
MCF7_Raltitrexed_1.0: 0.25
MCF7_Tanespimycin_0.001: 0.21
MCF7_Tanespimycin_0.01: 0.22
MCF7_Tanespimycin_0.1: 0.33
MCF7_Tanespimycin_1.0: 0.22
MCF7_Trametinib_0.001: 0.05
MCF7_Trametinib_0.01: -0.02
MCF7_Trametinib_0.1: 0.26
MCF7_Trametinib_1.0: 0.29
