**Requirements:**
* Trained models

**Outputs:** 
* none 
___
# Imports

In [3]:
import matplotlib
import umap.plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sn

from utils import load_config, load_dataset, load_smiles, load_model, compute_drug_embeddings, compute_pred
from compert.data import load_dataset_splits

# Global plotting look for the notebook: the two style sheets are applied in
# the listed order, then individual rcParams are overridden, and finally
# seaborn scales the plot elements for the "poster" context.
matplotlib.style.use(["fivethirtyeight", "seaborn-talk"])
matplotlib.rcParams.update(
    {
        'font.family': "monospace",
        'figure.dpi': 60,
    }
)
plt.rcParams['savefig.facecolor'] = 'white'  # saved figures get a white background
sn.set_context("poster")

Using backend: pytorch[22:00:19] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: /home/icb/leon.hetzel/miniconda3/envs/chemical_CPA/lib/python3.7/site-packages/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.1.so: cannot open shared object file: No such file or directory



In [4]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Load and analyse model 
* Define `seml_collection` and `model_hash` to load data and model

In [5]:
# Registry of the trained runs this notebook can analyse.
# Every entry keeps the seml collection together with the checkpoint hashes of
# the pretrained run ("config.model.load_pretrained": true) and the run trained
# from scratch ("config.model.load_pretrained": false), so a hash can never be
# combined with the wrong collection by commenting lines in or out.
MODEL_RUNS = {
    # split_ho_pathway, append_ae_layer: true
    "finetuning_num_genes/split_ho_pathway": {
        "seml_collection": "finetuning_num_genes",
        "pretrained": "70290e4f42ac4cb19246fafa0b75ccb6",
        "scratch": "00e7e9c7979f90d1325f25f9ff4e3fcb",
    },
    # split_ood_finetuning, append_ae_layer: true
    "finetuning_num_genes/split_ood_finetuning": {
        "seml_collection": "finetuning_num_genes",
        "pretrained": "bd001c8d557edffe9df9e6bf09dc4120",
        "scratch": "6e9d00880375aa450a8e5de60250659f",
    },
    # rdkit; split_ood_finetuning, append_ae_layer: false
    "sciplex_hparam/rdkit": {
        "seml_collection": "sciplex_hparam",
        "pretrained": "d9ee464c93a0d2d947e9115f8d834f22",
        "scratch": "0a929eab639127e304271036fe478e0b",
    },
    # grover; split_ood_finetuning, append_ae_layer: false
    "sciplex_hparam/grover": {
        "seml_collection": "sciplex_hparam",
        "pretrained": "bacf2e0b3f9dee9078a97c5216bf7f1c",
        "scratch": "d635df7c184dfff217e09ca93395604b",
    },
}

# The single place to switch between runs; must be a key of MODEL_RUNS.
SELECTED_RUN = "sciplex_hparam/grover"

if SELECTED_RUN not in MODEL_RUNS:
    raise KeyError(
        f"Unknown run {SELECTED_RUN!r}; choose one of {sorted(MODEL_RUNS)}"
    )

# Names consumed by the rest of the notebook.
seml_collection = MODEL_RUNS[SELECTED_RUN]["seml_collection"]
model_hash_pretrained = MODEL_RUNS[SELECTED_RUN]["pretrained"]
model_hash_scratch = MODEL_RUNS[SELECTED_RUN]["scratch"]

## Load config

In [6]:
# Look up the stored configuration of the pretrained run and load the dataset
# that this configuration points to.
config = load_config(seml_collection, model_hash_pretrained)
dataset, key_dict = load_dataset(config)
# Store the dataset's `n_vars` in the config; `config` is later passed to
# load_model (presumably the model needs this size — confirm in utils).
config['dataset']['n_vars'] = dataset.n_vars

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

### Load smiles info

In [7]:
# load_smiles returns three objects; going by the target names these are the
# sorted unique canonical SMILES and SMILES->pathway / SMILES->drug lookups
# (verify against utils.load_smiles).
canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles(config, dataset, key_dict)

#### Define which drugs should be annotated with list `ood_drugs`

In [8]:
# Collect the distinct `condition` values of all observations whose split
# column (named by the config's `split_key`) is 'ood'.
ood_split_column = config["dataset"]["data_params"]["split_key"]
ood_observation_mask = dataset.obs[ood_split_column].isin(['ood'])
ood_drugs = dataset.obs.condition[ood_observation_mask].unique().to_list()

#### Get pathway level 2 annotation for clustering of drug embeddings

In [9]:
# Two lookups derived from the observation table:
#   smiles_to_pw_level2_map: SMILES          -> pathway_level_2 value
#   pw1_to_pw2:              pathway_level_1 -> set of pathway_level_2 values
smiles_to_pw_level2_map = {}
pw1_to_pw2 = {}

pathway_columns = ['SMILES', 'pathway_level_1', 'pathway_level_2']
for (smiles, level1, level2), _group_obs in dataset.obs.groupby(pathway_columns):
    smiles_to_pw_level2_map[smiles] = level2
    # Create the set on first sight of a level-1 pathway, then add to it.
    pw1_to_pw2.setdefault(level1, set()).add(level2)

In [10]:
# Level-1 pathways of interest, expanded into all of their level-2 pathways.
groups = ["Epigenetic regulation"]

groups_pw2 = []
for level1_pathway in groups:
    groups_pw2.extend(pw1_to_pw2[level1_pathway])
groups_pw2

['DNA methylation',
 'Histone acetylation',
 'Histone demethylase',
 'Bromodomain',
 'Histone deacetylation',
 'Histone methylation']

## Load dataset splits

In [11]:
# Display the data parameters stored in the loaded config.
config['dataset']['data_params']

{'covariate_keys': 'cell_type',
 'dataset_path': '/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete_lincs_genes.h5ad',
 'degs_key': 'lincs_DEGs',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_ood_finetuning',
 'use_drugs_idx': True}

In [12]:
# `data_params` is the same dict object as config['dataset']['data_params']
# (no copy is made), so changing it also changes `config`.
data_params = config['dataset']['data_params']

# Optional override of the split column (disabled). If enabled it would also
# modify `config`, see above.
# data_params['split_key'] = 'split_ho_epigenetic'

# All data parameters are forwarded as keyword arguments. The result is indexed
# by split name further down (e.g. 'ood', 'test_control').
datasets = load_dataset_splits(**data_params, return_dataset=False)

___
## Pretrained model

In [13]:
# Dosages and cell lines handed to compute_pred for both models below.
# NOTE(review): the printed condition names end in 0.1 / 1.0, so these raw
# dosages are presumably normalised somewhere downstream — confirm.
dosages = [1e3, 1e4]
cell_lines = ["A549", "K562", 'MCF7']  # all three cell lines
# Alternative: restrict the evaluation to a single cell line
# cell_lines = ['MCF7']

In [14]:
# Re-read the config of the pretrained run, add the dataset's `n_vars` and load
# the model; load_model returns the model and a second object stored as its
# embedding (going by the name — see utils.load_model).
config = load_config(seml_collection, model_hash_pretrained)
config['dataset']['n_vars'] = dataset.n_vars
model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [22]:
# Evaluate the pretrained model on the 'ood' split. The first return value is
# kept as `drug_r2_pretrained` (one score per printed condition); the second is
# discarded.
pretrained_pred_kwargs = dict(
    genes_control=datasets['test_control'].genes,
    cell_lines=cell_lines,
    dosages=dosages,
)
drug_r2_pretrained, _ = compute_pred(model_pretrained, datasets['ood'], **pretrained_pred_kwargs)

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

A549_CUDC-101_0.1: 0.73
A549_CUDC-101_1.0: 0.55
A549_CUDC-907_0.1: 0.20
A549_CUDC-907_1.0: 0.12
A549_Dacinostat_0.1: 0.61
A549_Dacinostat_1.0: 0.55
A549_Givinostat_0.1: 0.30
A549_Givinostat_1.0: 0.55
A549_Hesperadin_0.1: 0.69
A549_Hesperadin_1.0: 0.69
A549_Pirarubicin_0.1: 0.24
A549_Pirarubicin_1.0: 0.08
A549_Raltitrexed_0.1: -0.19
A549_Raltitrexed_1.0: -0.46
A549_Tanespimycin_0.1: 0.37
A549_Tanespimycin_1.0: 0.35
A549_Trametinib_0.1: 0.40
A549_Trametinib_1.0: 0.39
K562_CUDC-101_0.1: 0.45
K562_CUDC-101_1.0: -0.11
K562_CUDC-907_0.1: 0.37
K562_CUDC-907_1.0: 0.22
K562_Dacinostat_0.1: 0.45
K562_Dacinostat_1.0: 0.08
K562_Givinostat_0.1: 0.18
K562_Givinostat_1.0: 0.43
K562_Hesperadin_0.1: 0.27
K562_Hesperadin_1.0: 0.21
K562_Pirarubicin_0.1: 0.54
K562_Pirarubicin_1.0: 0.50
K562_Raltitrexed_0.1: 0.37
K562_Raltitrexed_1.0: 0.30
K562_Tanespimycin_0.1: 0.23
K562_Tanespimycin_1.0: -0.17
K562_Trametinib_0.1: 0.36
K562_Trametinib_1.0: 0.35
MCF7_CUDC-101_0.1: 0.69
MCF7_CUDC-101_1.0: 0.56
MCF7_CUDC-90

## Non-pretrained model

In [16]:
# Same loading steps for the run trained without pretraining.
# Note: from here on `config` refers to the from-scratch run, not the pretrained one.
config = load_config(seml_collection, model_hash_scratch)
config['dataset']['n_vars'] = dataset.n_vars
model_scratch, embedding_scratch = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [20]:
# Evaluate the non-pretrained model on the same 'ood' split with the same
# dosages and cell lines; only the first return value is kept.
scratch_pred_result = compute_pred(
    model_scratch,
    datasets['ood'],
    genes_control=datasets['test_control'].genes,
    dosages=dosages,
    cell_lines=cell_lines,
)
drug_r2_scratch, _ = scratch_pred_result

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

A549_CUDC-101_0.1: 0.95
A549_CUDC-101_1.0: 0.70
A549_CUDC-907_0.1: 0.66
A549_CUDC-907_1.0: 0.34
A549_Dacinostat_0.1: 0.25
A549_Dacinostat_1.0: 0.12
A549_Givinostat_0.1: 0.82
A549_Givinostat_1.0: 0.18
A549_Hesperadin_0.1: 0.89
A549_Hesperadin_1.0: 0.90
A549_Pirarubicin_0.1: 0.82
A549_Pirarubicin_1.0: 0.58
A549_Raltitrexed_0.1: 0.29
A549_Raltitrexed_1.0: 0.02
A549_Tanespimycin_0.1: 0.78
A549_Tanespimycin_1.0: 0.74
A549_Trametinib_0.1: 0.66
A549_Trametinib_1.0: 0.66
K562_CUDC-101_0.1: 0.23
K562_CUDC-101_1.0: -0.62
K562_CUDC-907_0.1: 0.31
K562_CUDC-907_1.0: 0.16
K562_Dacinostat_0.1: -0.54
K562_Dacinostat_1.0: -1.41
K562_Givinostat_0.1: 0.42
K562_Givinostat_1.0: 0.41
K562_Hesperadin_0.1: 0.33
K562_Hesperadin_1.0: 0.24
K562_Pirarubicin_0.1: 0.51
K562_Pirarubicin_1.0: 0.60
K562_Raltitrexed_0.1: -0.06
K562_Raltitrexed_1.0: -0.20
K562_Tanespimycin_0.1: -0.37
K562_Tanespimycin_1.0: -0.87
K562_Trametinib_0.1: -0.20
K562_Trametinib_1.0: -0.26
MCF7_CUDC-101_0.1: 0.36
MCF7_CUDC-101_1.0: 0.31
MCF7_CU

In [21]:
# Mean score of the from-scratch model; negative scores count as 0.
clipped_r2_scratch = [0 if 0 > score else score for score in drug_r2_scratch.values()]
np.mean(clipped_r2_scratch)

0.3166501146775705

In [23]:
# Mean score of the pretrained model; negative scores count as 0.
clipped_r2_pretrained = [0 if 0 > score else score for score in drug_r2_pretrained.values()]
np.mean(clipped_r2_pretrained)

0.39955469303660923

In [28]:
from utils import evaluate_r2

In [31]:
# Run the project's evaluate_r2 on the pretrained model for the 'ood' split,
# using the genes of the 'test_control' split as control input. The meaning of
# the returned values is defined in utils.evaluate_r2.
evaluate_r2(model_pretrained, datasets['ood'], datasets['test_control'].genes)

Number of different r2 computations: 108


[0.2497908925568616,
 0.40579023129410213,
 0.22323442002137503,
 -0.1854762131417239]

In [35]:
from compert.paths import CHECKPOINT_DIR
import torch 

# Checkpoint file of the pretrained run: <CHECKPOINT_DIR>/<model hash>.pt
model_checkp = CHECKPOINT_DIR / (model_hash_pretrained + ".pt")

# The checkpoint holds CUDA tensors. Without `map_location`, torch.load raises
# on a machine that has no GPU, so fall back to the CPU there. With a GPU
# available the default placement (None) is kept.
# NOTE: torch.load unpickles the file — only open checkpoints you trust.
checkpoint_map_location = None if torch.cuda.is_available() else "cpu"

# The checkpoint is a 4-tuple: model state dict, covariate state dicts,
# constructor arguments and training history (per the unpacked names).
state_dict, cov_state_dicts, init_args, history = torch.load(
    model_checkp, map_location=checkpoint_map_location
)

In [None]:
# Display the repr of the loaded pretrained model.
model_pretrained

In [36]:
# Inspect the covariate state dicts read from the checkpoint.
# NOTE(review): this prints every tensor in full; consider showing only
# parameter names and shapes to keep the notebook readable.
cov_state_dicts

[OrderedDict([('network.0.weight',
               tensor([[-0.0325,  0.1343,  0.0268,  ...,  0.1718, -0.0321,  0.2684],
                       [ 0.0702,  0.0787,  0.0593,  ..., -0.0144,  0.0668,  0.1410],
                       [-0.0208,  0.0151,  0.0995,  ...,  0.1422,  0.0176, -0.1239],
                       ...,
                       [-0.1085,  0.0709, -0.1513,  ...,  0.0088,  0.1379,  0.0973],
                       [-0.0739, -0.1386, -0.0178,  ...,  0.0274,  0.0142,  0.0949],
                       [-0.0413, -0.0946, -0.0621,  ..., -0.2739,  0.0046,  0.0533]],
                      device='cuda:0')),
              ('network.0.bias',
               tensor([ 2.4672e-06,  1.1429e-05,  3.8087e-07, -2.4739e-06,  8.0045e-06,
                        8.7927e-07,  2.7636e-06, -4.3956e-06,  8.6761e-06,  5.4314e-06,
                       -6.2532e-06, -3.2854e-06,  3.3048e-06,  3.5233e-06,  4.2617e-06,
                       -4.5976e-06, -3.7156e-06, -6.8080e-06, -9.9235e-07,  2.2717e-06,


In [41]:
# Weight of the first covariate embedding held by the in-memory pretrained model.
model_pretrained.covariates_embeddings[0].weight

Parameter containing:
tensor([[ 0.4081,  0.3767, -1.1169, -0.4985,  0.2052,  1.9294,  0.0779,  0.3978,
          0.5137,  1.5418,  0.7749,  0.5571, -0.5488, -1.4260, -0.9397, -0.6951,
         -1.5591, -0.3072,  0.5017,  1.4675, -0.2127, -1.1622,  1.4627, -1.6516,
          0.0434,  0.1914, -0.7639,  1.1491,  0.2505, -0.7223, -0.0294, -0.3321],
        [ 0.5505,  0.0382, -1.3970, -0.5779, -0.0952, -0.2434, -0.1063,  0.7139,
         -1.4399,  1.2624, -1.3002,  0.2700, -0.0426, -1.1687, -0.3934, -1.2991,
         -0.5546,  1.0597, -0.5303, -0.5718,  1.8390, -1.9639,  1.6057,  0.3113,
          1.1219, -1.1024, -1.0911,  0.9484, -0.1188, -0.4019,  0.1167, -0.3188],
        [-0.7843, -1.9332, -1.1973,  0.3887, -0.3791,  2.4663, -0.0663,  0.0252,
         -0.2218,  2.2960, -1.2462,  0.0498,  0.8647,  0.0201,  1.5937,  2.0901,
          1.7925,  0.2062, -0.0372,  0.9941, -1.7321, -1.6416,  1.5243,  1.1358,
         -0.5476, -1.3976, -0.3468,  0.3913,  1.5127, -1.3529, -0.1891,  0.4813]],
  

In [45]:
# Checkpoint entries whose name contains 'emb'.
embedding_param_names = list(filter(lambda param_name: 'emb' in param_name, state_dict.keys()))
embedding_param_names

['drug_embeddings.weight',
 'drug_embedding_encoder.network.0.weight',
 'drug_embedding_encoder.network.0.bias',
 'drug_embedding_encoder.network.1.weight',
 'drug_embedding_encoder.network.1.bias',
 'drug_embedding_encoder.network.1.running_mean',
 'drug_embedding_encoder.network.1.running_var',
 'drug_embedding_encoder.network.1.num_batches_tracked',
 'drug_embedding_encoder.network.3.weight',
 'drug_embedding_encoder.network.3.bias',
 'drug_embedding_encoder.network.4.weight',
 'drug_embedding_encoder.network.4.bias',
 'drug_embedding_encoder.network.4.running_mean',
 'drug_embedding_encoder.network.4.running_var',
 'drug_embedding_encoder.network.4.num_batches_tracked',
 'drug_embedding_encoder.network.6.weight',
 'drug_embedding_encoder.network.6.bias',
 'drug_embedding_encoder.network.7.weight',
 'drug_embedding_encoder.network.7.bias',
 'drug_embedding_encoder.network.7.running_mean',
 'drug_embedding_encoder.network.7.running_var',
 'drug_embedding_encoder.network.7.num_batches