**Requirements:**
* Trained models

**Outputs:** 
* none 
___
# Imports

In [8]:
import matplotlib
import umap.plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sn

from utils import load_config, load_dataset, load_smiles, load_model, compute_drug_embeddings, compute_pred
from compert.data import load_dataset_splits

# Global plotting style for every figure in this notebook.
matplotlib.style.use("fivethirtyeight")
# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and 3.8 removed
# the old names, so use whichever spelling this installation provides.
_talk_style = "seaborn-v0_8-talk" if "seaborn-v0_8-talk" in matplotlib.style.available else "seaborn-talk"
matplotlib.style.use(_talk_style)
matplotlib.rcParams['font.family'] = "monospace"
matplotlib.rcParams['figure.dpi'] = 60
# Force an opaque white background on saved figures, whatever the styles above set.
matplotlib.pyplot.rcParams['savefig.facecolor'] = 'white'
sn.set_context("poster")

In [2]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each cell runs,
# so edits to the project modules imported above (utils, compert) are picked up live.
%load_ext autoreload
%autoreload 2

# Load and analyse models
* Define `seml_collection` and the model hashes (`model_hash_pretrained`, `model_hash_scratch`) to load the data and the pretrained and non-pretrained models

In [17]:
seml_collection = "finetuning_num_genes"

# Trained-model hashes per data split (all runs with append_ae_layer: true).
#   "pretrained": "config.model.load_pretrained": true
#   "scratch":    "config.model.load_pretrained": false
model_hashes_by_split = {
    "split_ho_pathway": {
        "pretrained": "70290e4f42ac4cb19246fafa0b75ccb6",
        "scratch": "ed3bc586a5fcfe3c4dbb0157cd67d0d9",
    },
    "split_ood_finetuning": {
        "pretrained": "bd001c8d557edffe9df9e6bf09dc4120",
        "scratch": "6e9d00880375aa450a8e5de60250659f",
    },
}

# Split whose models are analysed below; must be a key of `model_hashes_by_split`.
split_name = "split_ood_finetuning"
model_hash_pretrained = model_hashes_by_split[split_name]["pretrained"]
model_hash_scratch = model_hashes_by_split[split_name]["scratch"]

## Load config

In [18]:
config = load_config(seml_collection, model_hash_pretrained)
dataset, key_dict = load_dataset(config)
# Store the gene count of the loaded data in the config (presumably consumed by load_model).
config["dataset"]["n_vars"] = dataset.n_vars

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

### Load SMILES info

In [19]:
# Unpack the SMILES vocabulary and the SMILES -> pathway / drug lookup tables.
(
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
) = load_smiles(config, dataset, key_dict)

#### Define which drugs should be annotated, using the list `ood_drugs` (drugs in the `ood` partition of the split)

In [23]:
# Conditions (drugs) assigned to the 'ood' partition of the configured split column.
ood_drugs = (
    dataset.obs.loc[
        dataset.obs[config["dataset"]["data_params"]["split_key"]].isin(["ood"]),
        "condition",
    ]
    .unique()
    .to_list()
)

#### Get pathway level 2 annotation for clustering of drug embeddings

In [24]:
# SMILES -> level-2 pathway, and level-1 pathway -> set of its level-2 pathways.
smiles_to_pw_level2_map = {}
pw1_to_pw2 = {}

# Only the group keys matter here; the per-group frames are ignored.
for (smiles, pw_level1, pw_level2), _group in dataset.obs.groupby(['SMILES', 'pathway_level_1', 'pathway_level_2']):
    smiles_to_pw_level2_map[smiles] = pw_level2
    pw1_to_pw2.setdefault(pw_level1, set()).add(pw_level2)

In [25]:
groups = ["Epigenetic regulation"]  # level-1 pathways to inspect

# Level-2 pathways under the selected level-1 pathways (order follows set iteration).
groups_pw2 = [
    level2
    for level1 in groups
    for level2 in pw1_to_pw2[level1]
]
groups_pw2

['Histone methylation',
 'DNA methylation',
 'Bromodomain',
 'Histone deacetylation',
 'Histone demethylase',
 'Histone acetylation']

## Load dataset splits

In [26]:
# Display the data/split parameters stored with the loaded model config.
config["dataset"]["data_params"]

{'covariate_keys': 'cell_type',
 'dataset_path': '/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete.h5ad',
 'degs_key': 'all_DEGs',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_ood_finetuning',
 'use_drugs_idx': True}

In [97]:
# Shallow-copy the data parameters so an optional split override does not mutate the loaded config.
data_params = dict(config['dataset']['data_params'])

# Set to another split column of dataset.obs (e.g. 'split_ho_epigenetic') to evaluate on that
# split instead; None keeps the split_key stored in the model config.
split_key_override = None
if split_key_override is not None:
    data_params['split_key'] = split_key_override

datasets = load_dataset_splits(**data_params, return_dataset=False)

___
## Pretrained model

In [98]:
# Dose values passed to compute_pred; presumably in the dataset's raw dose units — TODO confirm.
dosages = [10.0, 100.0, 1000.0, 10000.0]

In [109]:
config = load_config(seml_collection, model_hash_pretrained)
# Gene count of the loaded data, presumably needed by load_model to size the model.
config["dataset"]["n_vars"] = dataset.n_vars
model_pretrained, embedding_pretrained = load_model(
    config, canon_smiles_unique_sorted
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [110]:
# Per-condition R² of the pretrained model on the OOD split (second return value unused).
drug_r2_pretrained, _ = compute_pred(
    model_pretrained,
    datasets["ood"],
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
)

0it [00:00, ?it/s]

MCF7_CUDC-101_0.001: 0.77
MCF7_CUDC-101_0.01: 0.78
MCF7_CUDC-101_0.1: 0.77
MCF7_CUDC-101_1.0: 0.60
MCF7_CUDC-907_0.001: 0.66
MCF7_CUDC-907_0.01: 0.67
MCF7_CUDC-907_0.1: 0.58
MCF7_CUDC-907_1.0: -0.39
MCF7_Dacinostat_0.001: 0.73
MCF7_Dacinostat_0.01: 0.72
MCF7_Dacinostat_0.1: 0.12
MCF7_Dacinostat_1.0: -0.10
MCF7_Givinostat_0.001: 0.77
MCF7_Givinostat_0.01: 0.78
MCF7_Givinostat_0.1: 0.84
MCF7_Givinostat_1.0: -0.24
MCF7_Hesperadin_0.001: 0.72
MCF7_Hesperadin_0.01: 0.74
MCF7_Hesperadin_0.1: 0.83
MCF7_Hesperadin_1.0: 0.80
MCF7_Pirarubicin_0.001: 0.61
MCF7_Pirarubicin_0.01: 0.62
MCF7_Pirarubicin_0.1: 0.62
MCF7_Pirarubicin_1.0: 0.59
MCF7_Raltitrexed_0.001: 0.62
MCF7_Raltitrexed_0.01: 0.60
MCF7_Raltitrexed_0.1: 0.57
MCF7_Raltitrexed_1.0: 0.57
MCF7_Tanespimycin_0.001: 0.65
MCF7_Tanespimycin_0.01: 0.60
MCF7_Tanespimycin_0.1: 0.59
MCF7_Tanespimycin_1.0: 0.55
MCF7_Trametinib_0.001: 0.68
MCF7_Trametinib_0.01: 0.67
MCF7_Trametinib_0.1: 0.72
MCF7_Trametinib_1.0: 0.75


## Non-pretrained model

In [111]:
config = load_config(seml_collection, model_hash_scratch)
# Gene count of the loaded data, presumably needed by load_model to size the model.
config["dataset"]["n_vars"] = dataset.n_vars
model_scratch, embedding_scratch = load_model(
    config, canon_smiles_unique_sorted
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [112]:
# Per-condition R² of the non-pretrained (from-scratch) model on the OOD split.
drug_r2_scratch, _ = compute_pred(
    model_scratch,
    datasets["ood"],
    genes_control=datasets["test_control"].genes,
    dosages=dosages,
)

0it [00:00, ?it/s]

MCF7_CUDC-101_0.001: 0.55
MCF7_CUDC-101_0.01: 0.54
MCF7_CUDC-101_0.1: 0.50
MCF7_CUDC-101_1.0: 0.25
MCF7_CUDC-907_0.001: 0.34
MCF7_CUDC-907_0.01: 0.32
MCF7_CUDC-907_0.1: 0.16
MCF7_CUDC-907_1.0: -0.39
MCF7_Dacinostat_0.001: 0.39
MCF7_Dacinostat_0.01: 0.27
MCF7_Dacinostat_0.1: -0.28
MCF7_Dacinostat_1.0: -0.41
MCF7_Givinostat_0.001: 0.44
MCF7_Givinostat_0.01: 0.45
MCF7_Givinostat_0.1: 0.15
MCF7_Givinostat_1.0: -0.49
MCF7_Hesperadin_0.001: 0.46
MCF7_Hesperadin_0.01: 0.47
MCF7_Hesperadin_0.1: 0.18
MCF7_Hesperadin_1.0: 0.40
MCF7_Pirarubicin_0.001: 0.39
MCF7_Pirarubicin_0.01: 0.40
MCF7_Pirarubicin_0.1: 0.43
MCF7_Pirarubicin_1.0: 0.32
MCF7_Raltitrexed_0.001: 0.33
MCF7_Raltitrexed_0.01: 0.31
MCF7_Raltitrexed_0.1: 0.22
MCF7_Raltitrexed_1.0: 0.23
MCF7_Tanespimycin_0.001: 0.20
MCF7_Tanespimycin_0.01: 0.21
MCF7_Tanespimycin_0.1: 0.32
MCF7_Tanespimycin_1.0: 0.22
MCF7_Trametinib_0.001: 0.05
MCF7_Trametinib_0.01: -0.01
MCF7_Trametinib_0.1: 0.24
MCF7_Trametinib_1.0: 0.28


In [103]:
# Drugs in the 'ood' partition of the split that was actually loaded above, instead of a
# hard-coded split column, so this list stays in sync if another split is selected.
dataset.obs.loc[dataset.obs[data_params['split_key']] == 'ood', 'condition'].unique().to_list()

['Raltitrexed',
 'Trametinib',
 'Hesperadin',
 'CUDC-101',
 'Dacinostat',
 'Pirarubicin',
 'CUDC-907',
 'Tanespimycin',
 'Givinostat']

In [113]:
# Mean R² of the from-scratch model over all conditions, negative scores clipped to 0.
np.mean([0 if r2 < 0 else r2 for r2 in drug_r2_scratch.values()])

0.18929382165273032

In [114]:
# Mean R² of the pretrained model over all conditions, negative scores clipped to 0.
np.mean([0 if r2 < 0 else r2 for r2 in drug_r2_pretrained.values()])

0.4283696214358012

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.6065,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])