**Requirements:**
* Trained models
* GROVER: 
     * fine-tuned:      `'a50dc68191a3776694ce8f34ad55e7e0'` 
     * non-pretrained: `'51b81b77079c1060aedb0ee2259008ca'`
* RDKit: 
     * fine-tuned:      `'27b401db1845eea26c102fb614df9c33'` 
     * non-pretrained: `'cbf9e956049fce00dbcebdfc1aeb67fe'`
* JT-VAE: 
     * fine-tuned:      `'f9e328d21bff64c5541f81ae1303c279'` 
     * non-pretrained: `'d273bf129f3a866a4d02f8925cf5cc8d'`

All results in this notebook use setting 1 (shared gene sets).

**Outputs:**
* **Table 2** 
* Supplement Table 9
___
# Imports

In [1]:
import matplotlib
import umap.plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns

from utils import load_config, load_dataset, load_smiles, load_model, compute_drug_embeddings, compute_pred, compute_pred_ctrl
from chemCPA.data import load_dataset_splits
from chemCPA.paths import FIGURE_DIR

matplotlib.style.use("fivethirtyeight")
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and newer
# releases reject the old names, so use whichever name this installation provides.
matplotlib.style.use(
    "seaborn-talk" if "seaborn-talk" in matplotlib.style.available else "seaborn-v0_8-talk"
)
matplotlib.rcParams['font.family'] = "monospace"
matplotlib.rcParams['figure.dpi'] = 300
matplotlib.pyplot.rcParams['savefig.facecolor'] = 'white'
# seaborn settings are applied last and take precedence over the styles above.
sns.set_style("whitegrid")
sns.set_context("poster")

[15:41:02] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: /home/icb/leon.hetzel/miniconda3/envs/chemical_CPA/lib/python3.7/site-packages/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.1.so: cannot open shared object file: No such file or directory
Using backend: pytorch


In [2]:
# Automatically reload edited project modules (e.g. `utils`, `chemCPA`) before each cell runs.
%load_ext autoreload
%autoreload 2

# Load model configs and dataset
* Define `seml_collection` and the `model_hash_*` variables used to load the data and the models


**Info**
* split:            `split_ood_finetuning`  
* append_ae_layer:  `False`

In [3]:
seml_collection = "finetuning_num_genes"

# Model hashes, as (fine-tuned, non-pretrained) pairs per drug embedding.
model_hash_pretrained_rdkit, model_hash_scratch_rdkit = (
    "27b401db1845eea26c102fb614df9c33",
    # NOTE(review): the Requirements list gives this hash for the GROVER
    # non-pretrained model and 'cbf9e956049fce00dbcebdfc1aeb67fe' for RDKit.
    # Confirm which model the RDKit non-pretrained row of Table 2 should use.
    "51b81b77079c1060aedb0ee2259008ca",
)
model_hash_pretrained_grover, model_hash_scratch_grover = (
    "a50dc68191a3776694ce8f34ad55e7e0",
    # NOTE(review): differs from the GROVER non-pretrained hash in the
    # Requirements list ('51b81b77079c1060aedb0ee2259008ca'). Confirm.
    "0807497c5407f4e0c8a52207f36a185f",
)
model_hash_pretrained_jtvae, model_hash_scratch_jtvae = (
    "f9e328d21bff64c5541f81ae1303c279",
    "d273bf129f3a866a4d02f8925cf5cc8d",
)

## Load config and SMILES

In [4]:
# The dataset is loaded once, via the fine-tuned RDKit model's config. Every
# model evaluated below uses this same `dataset`.
config = load_config(seml_collection, model_hash_pretrained_rdkit)
dataset, key_dict = load_dataset(config)
config["dataset"]["n_vars"] = dataset.n_vars

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
# Canonical SMILES (sorted) plus lookups from SMILES to pathway and to drug name.
(
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
) = load_smiles(config, dataset, key_dict, True)

Collect the drugs of the out-of-distribution (`ood`) split in `ood_drugs`

In [6]:
# Drugs whose cells fall in the 'ood' split. Both the split column and the
# perturbation column are read from the dataset config rather than hard-coded.
ood_mask = dataset.obs[config["dataset"]["data_params"]["split_key"]].isin(['ood'])
ood_drugs = (
    dataset.obs[config["dataset"]["data_params"]["perturbation_key"]][ood_mask]
    .unique()
    .to_list()
)

## Load dataset splits

In [7]:
# Show the data parameters used to build the dataset splits below.
config["dataset"]["data_params"]

{'covariate_keys': 'cell_type',
 'dataset_path': '/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/sciplex_complete_middle_subset_lincs_genes.h5ad',
 'degs_key': 'lincs_DEGs',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_ood_finetuning',
 'use_drugs_idx': True}

In [8]:
# Build the subsets used below (e.g. datasets['ood'], datasets['test_control'])
# from the data parameters shown above.
data_params = config["dataset"]["data_params"]
datasets = load_dataset_splits(
    **data_params,
    return_dataset=False,
)

____
# Run models
## Baseline model

In [9]:
# Evaluation grid shared by the baseline and every model below.
dosages = [10.0, 100.0, 1000.0, 10000.0]
cell_lines = ["A549", "K562", "MCF7"]
# NOTE(review): not referenced by the prediction cells in this notebook, which
# pass use_DEGs explicitly.
use_DEGs = True

In [10]:
# Control-only baseline on the ood split: first with use_DEGs=True, then with
# use_DEGs=False (all genes). Only the first returned value is kept.
drug_r2_baseline_degs, drug_r2_baseline_all = [
    compute_pred_ctrl(
        dataset=datasets['ood'],
        dataset_ctrl=datasets['test_control'],
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

## RDKit

In [11]:
# Display the drugs that belong to the 'ood' split.
ood_drugs

['Hesperadin',
 'CUDC-101',
 'Raltitrexed',
 'Trametinib',
 'Dacinostat',
 'CUDC-907',
 'Pirarubicin',
 'Tanespimycin',
 'Givinostat']

### Pretrained

In [12]:
# Fine-tuned RDKit model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_pretrained_rdkit)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_pretrained_rdkit,
    embedding_pretrained_rdkit,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
# Fine-tuned RDKit model on the ood split: use_DEGs=True first, then all genes.
drug_r2_pretrained_degs_rdkit, drug_r2_pretrained_all_rdkit = [
    compute_pred(
        model_pretrained_rdkit,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

### Non-pretrained model

In [14]:
# Non-pretrained RDKit model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_scratch_rdkit)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_scratch_rdkit,
    embedding_scratch_rdkit,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
# Non-pretrained RDKit model on the ood split: use_DEGs=True first, then all genes.
drug_r2_scratch_degs_rdkit, drug_r2_scratch_all_rdkit = [
    compute_pred(
        model_scratch_rdkit,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

## GROVER

### Pretrained

In [16]:
# Fine-tuned GROVER model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_pretrained_grover)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_pretrained_grover,
    embedding_pretrained_grover,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [17]:
# Fine-tuned GROVER model on the ood split: use_DEGs=True first, then all genes.
drug_r2_pretrained_degs_grover, drug_r2_pretrained_all_grover = [
    compute_pred(
        model_pretrained_grover,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

### Non-pretrained model

In [18]:
# Non-pretrained GROVER model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_scratch_grover)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_scratch_grover,
    embedding_scratch_grover,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [19]:
# Non-pretrained GROVER model on the ood split: use_DEGs=True first, then all genes.
drug_r2_scratch_degs_grover, drug_r2_scratch_all_grover = [
    compute_pred(
        model_scratch_grover,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

## JT-VAE

### Pretrained

In [20]:
# Fine-tuned JT-VAE model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_pretrained_jtvae)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_pretrained_jtvae,
    embedding_pretrained_jtvae,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [21]:
# Fine-tuned JT-VAE model on the ood split: use_DEGs=True first, then all genes.
drug_r2_pretrained_degs_jtvae, drug_r2_pretrained_all_jtvae = [
    compute_pred(
        model_pretrained_jtvae,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

### Non-pretrained model

In [22]:
# Non-pretrained JT-VAE model. `n_vars` is copied from the shared dataset before the model is loaded.
config = load_config(seml_collection, model_hash_scratch_jtvae)
config["dataset"]["n_vars"] = dataset.n_vars
(
    model_scratch_jtvae,
    embedding_scratch_jtvae,
) = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [23]:
# Non-pretrained JT-VAE model on the ood split: use_DEGs=True first, then all genes.
drug_r2_scratch_degs_jtvae, drug_r2_scratch_all_jtvae = [
    compute_pred(
        model_scratch_jtvae,
        datasets['ood'],
        genes_control=datasets['test_control'].genes,
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=degs_only,
        verbose=False,
    )[0]
    for degs_only in (True, False)
]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

# Combine results and create dataframe

In [24]:
def _split_condition_key(key):
    """Split a ``"<cell_line>_<drug>_<dose>"`` key into ``(cell_line, drug, dose_str)``.

    The cell line is taken before the first ``_`` and the dose after the last
    one, so drug names that themselves contain ``_`` stay intact.
    """
    cell_line, rest = key.split('_', 1)
    drug, dose = rest.rsplit('_', 1)
    return cell_line, drug, dose


def create_df(drug_r2_baseline, 
              drug_r2_pretrained_rdkit, 
              drug_r2_scratch_rdkit,
              drug_r2_pretrained_grover,
              drug_r2_scratch_grover,
              drug_r2_pretrained_jtvae,
              drug_r2_scratch_jtvae,
             ):
    """Stack the per-condition r2 scores of all models into one long DataFrame.

    Each ``drug_r2_*`` argument is a dict mapping a condition key of the form
    ``"<cell_line>_<drug>_<dose>"`` to a score. It is the first value returned
    by ``compute_pred`` / ``compute_pred_ctrl``.

    Returns
    -------
    pandas.DataFrame
        One row per (model, condition), in argument order, with columns:

        * ``index``: the condition key.
        * ``r2_de``: the score, with negative values clipped to 0. Despite its
          name, this column also holds all-gene scores when called with the
          ``*_all`` results.
        * ``type``: ``'baseline'``, ``'pretrained'`` or ``'non-pretrained'``.
        * ``model``: ``'baseline'``, ``'rdkit'``, ``'grover'`` or ``'jtvae'``.
        * ``cell_line`` and ``drug``: parsed from the key.
        * ``dose``: parsed from the key, as float.
    """
    # (scores, type, model) for every block of rows, in output order.
    sources = [
        (drug_r2_baseline, 'baseline', 'baseline'),
        (drug_r2_pretrained_rdkit, 'pretrained', 'rdkit'),
        (drug_r2_scratch_rdkit, 'non-pretrained', 'rdkit'),
        (drug_r2_pretrained_grover, 'pretrained', 'grover'),
        (drug_r2_scratch_grover, 'non-pretrained', 'grover'),
        (drug_r2_pretrained_jtvae, 'pretrained', 'jtvae'),
        (drug_r2_scratch_jtvae, 'non-pretrained', 'jtvae'),
    ]
    frames = []
    for drug_r2, model_type, model_name in sources:
        frame = pd.DataFrame.from_dict(drug_r2, orient='index', columns=['r2_de'])
        frame['type'] = model_type
        frame['model'] = model_name
        frames.append(frame)
    df = pd.concat(frames)

    # Negative r2 scores are floored at 0 before any averaging downstream.
    df['r2_de'] = df['r2_de'].clip(lower=0)

    # The index holds duplicate keys (one per model), so assign plain lists
    # rather than index-aligned Series.
    parts = [_split_condition_key(key) for key in df.index]
    df['cell_line'] = [cell_line for cell_line, _, _ in parts]
    df['drug'] = [drug for _, drug, _ in parts]
    df['dose'] = [dose for _, _, dose in parts]
    df['dose'] = df['dose'].astype(float)

    df = df.reset_index()
    return df

In [25]:
# One long results table per gene set: DEG-restricted scores and all-gene scores.
df_degs, df_all = (
    create_df(
        drug_r2_baseline_degs,
        drug_r2_pretrained_degs_rdkit,
        drug_r2_scratch_degs_rdkit,
        drug_r2_pretrained_degs_grover,
        drug_r2_scratch_degs_grover,
        drug_r2_pretrained_degs_jtvae,
        drug_r2_scratch_degs_jtvae,
    ),
    create_df(
        drug_r2_baseline_all,
        drug_r2_pretrained_all_rdkit,
        drug_r2_scratch_all_rdkit,
        drug_r2_pretrained_all_grover,
        drug_r2_scratch_all_grover,
        drug_r2_pretrained_all_jtvae,
        drug_r2_scratch_all_jtvae,
    ),
)

## Compute mean and median $r^2$ at dose 1.0, on DEGs and on all genes

In [26]:
# Mean DEG r2 per (model, type), restricted to groups with dose == 1.0.
r2_degs_mean = []
for group_key, group in df_degs.groupby(['model', 'type', 'dose']):
    if group_key[2] != 1.0:
        continue
    group_mean = group.r2_de.mean()
    print(f"Model: {group_key}, R2 mean: {group_mean}")
    r2_degs_mean.append(group_mean)

Model: ('baseline', 'baseline', 1.0), R2 mean: 0.42459163400861955
Model: ('grover', 'non-pretrained', 1.0), R2 mean: 0.4746559880397938
Model: ('grover', 'pretrained', 1.0), R2 mean: 0.556008603837755
Model: ('jtvae', 'non-pretrained', 1.0), R2 mean: 0.4393050957609106
Model: ('jtvae', 'pretrained', 1.0), R2 mean: 0.5283902839378074
Model: ('rdkit', 'non-pretrained', 1.0), R2 mean: 0.47309815221362644
Model: ('rdkit', 'pretrained', 1.0), R2 mean: 0.6430796980857849


In [27]:
# Mean all-gene r2 per (model, type), restricted to groups with dose == 1.0.
r2_all_mean = []
for group_key, group in df_all.groupby(['model', 'type', 'dose']):
    *_, group_dose = group_key
    if group_dose != 1.0:
        continue
    group_mean = group.r2_de.mean()
    print(f"Model: {group_key}, R2 mean: {group_mean}")
    r2_all_mean.append(group_mean)

Model: ('baseline', 'baseline', 1.0), R2 mean: 0.6262997168081778
Model: ('grover', 'non-pretrained', 1.0), R2 mean: 0.653501175068043
Model: ('grover', 'pretrained', 1.0), R2 mean: 0.731699303344444
Model: ('jtvae', 'non-pretrained', 1.0), R2 mean: 0.5941466004760177
Model: ('jtvae', 'pretrained', 1.0), R2 mean: 0.7148480856860125
Model: ('rdkit', 'non-pretrained', 1.0), R2 mean: 0.6593914716332047
Model: ('rdkit', 'pretrained', 1.0), R2 mean: 0.7794026335080465


In [28]:
# Median DEG r2 per (model, type), restricted to groups with dose == 1.0.
r2_degs_median = []
for group_key, group in df_degs.groupby(['model', 'type', 'dose']):
    if group_key[2] != 1.0:
        continue
    group_median = group.r2_de.median()
    print(f"Model: {group_key}, R2 median: {group_median}")
    r2_degs_median.append(group_median)

Model: ('baseline', 'baseline', 1.0), R2 median: 0.4295828938484192
Model: ('grover', 'non-pretrained', 1.0), R2 median: 0.5277886390686035
Model: ('grover', 'pretrained', 1.0), R2 median: 0.5857417583465576
Model: ('jtvae', 'non-pretrained', 1.0), R2 median: 0.5022945404052734
Model: ('jtvae', 'pretrained', 1.0), R2 median: 0.5144805908203125
Model: ('rdkit', 'non-pretrained', 1.0), R2 median: 0.5223346948623657
Model: ('rdkit', 'pretrained', 1.0), R2 median: 0.7218122482299805


In [29]:
# Median all-gene r2 per (model, type) at dose == 1.0. Also records the
# (model, type) labels that become the rows of Table 2.
r2_all_median = []
model = []
model_type = []
for group_key, group in df_all.groupby(['model', 'type', 'dose']):
    name, kind, group_dose = group_key
    if group_dose != 1.0:
        continue
    group_median = group.r2_de.median()
    print(f"Model: {group_key}, R2 median: {group_median}")
    r2_all_median.append(group_median)
    model.append(name)
    model_type.append(kind)

Model: ('baseline', 'baseline', 1.0), R2 median: 0.7536529302597046
Model: ('grover', 'non-pretrained', 1.0), R2 median: 0.7543074488639832
Model: ('grover', 'pretrained', 1.0), R2 median: 0.8012929558753967
Model: ('jtvae', 'non-pretrained', 1.0), R2 median: 0.7209659814834595
Model: ('jtvae', 'pretrained', 1.0), R2 median: 0.7891685962677002
Model: ('rdkit', 'non-pretrained', 1.0), R2 median: 0.7765571475028992
Model: ('rdkit', 'pretrained', 1.0), R2 median: 0.8393895626068115


# Compute Table 2

In [30]:
# Assemble Table 2 from the per-(model, type) summaries computed above. Each
# list was filled while iterating `groupby(['model', 'type', 'dose'])`, which
# sorts its keys by default. Provided df_degs and df_all contain the same
# (model, type) groups, position i therefore refers to the same row in every list.
df_dict = {
    "Model": model, 
    "Type": model_type,
    "Mean $r^2$ all": r2_all_mean,
    "Mean $r^2$ DEGs": r2_degs_mean,
    "Median $r^2$ all": r2_all_median,
    "Median $r^2$ DEGs": r2_degs_median
}

# Guard against stale kernel state. For example, a later loop rebinding `model`
# to a string would otherwise be silently broadcast to every row by pandas.
for column_name, column_values in df_dict.items():
    if not isinstance(column_values, list) or len(column_values) != len(r2_all_median):
        raise ValueError(
            f"Column {column_name!r} is not a list aligned with the other summaries; "
            "re-run the summary cells above in order."
        )

df = pd.DataFrame.from_dict(df_dict)
df = df.set_index('Model')

In [31]:
# escape=False keeps the "$r^2$" headers as LaTeX math. With escaping they are
# emitted as literal text (`\$r\textasciicircum 2\$`).
print(df.to_latex(float_format="%.2f", escape=False))

\begin{tabular}{llrrrr}
\toprule
{} &            Type &  Mean \$r\textasciicircum 2\$ all &  Mean \$r\textasciicircum 2\$ DEGs &  Median \$r\textasciicircum 2\$ all &  Median \$r\textasciicircum 2\$ DEGs \\
Model    &                 &                 &                  &                   &                    \\
\midrule
baseline &        baseline &            0.63 &             0.42 &              0.75 &               0.43 \\
grover   &  non-pretrained &            0.65 &             0.47 &              0.75 &               0.53 \\
grover   &      pretrained &            0.73 &             0.56 &              0.80 &               0.59 \\
jtvae    &  non-pretrained &            0.59 &             0.44 &              0.72 &               0.50 \\
jtvae    &      pretrained &            0.71 &             0.53 &              0.79 &               0.51 \\
rdkit    &  non-pretrained &            0.66 &             0.47 &              0.78 &               0.52 \\
rdkit    &      pretrained &

____
# Compute Supplement Table 9

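Calculations: paired t-tests (`scipy.stats.ttest_rel`) at dose 1.0. Each fine-tuned model is compared against the baseline and against its non-pretrained counterpart, on all genes and on DEGs.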

In [32]:
dose = 1.0


def _paired_pvalue(results, model_name, reference, dose):
    """Return the paired t-test p-value of fine-tuned `model_name` against `reference`.

    Parameters
    ----------
    results : pandas.DataFrame
        Output of ``create_df`` (``df_all`` or ``df_degs``).
    model_name : str
        ``'rdkit'``, ``'grover'`` or ``'jtvae'``.
    reference : str
        ``'baseline'`` compares against the baseline model.
        ``'non-pretrained'`` compares against the same model trained from scratch.
    dose : float
        Only rows with this dose are used.

    Scores are paired on their condition key (the ``index`` column), not on row
    order. The test therefore does not depend on both result dicts having been
    filled in the same order.

    Raises
    ------
    ValueError
        If the two sides do not share exactly the same condition keys.
    """
    at_dose = results[results.dose == dose]
    pretrained = at_dose[(at_dose.model == model_name) & (at_dose.type == 'pretrained')]
    if reference == 'baseline':
        other = at_dose[at_dose.model == 'baseline']
    else:
        other = at_dose[(at_dose.model == model_name) & (at_dose.type == reference)]

    paired = pd.concat(
        [
            pretrained.set_index('index').r2_de.rename('pretrained'),
            other.set_index('index').r2_de.rename('reference'),
        ],
        axis=1,
        join='inner',
    )
    if not len(paired) == len(pretrained) == len(other):
        raise ValueError(
            f"{model_name} vs {reference}: conditions do not match one-to-one "
            f"({len(pretrained)} vs {len(other)} rows, {len(paired)} shared)."
        )
    return scipy.stats.ttest_rel(paired['pretrained'], paired['reference']).pvalue


models = []
gene_set = []
p_values = [] 
vs_models = []

# Loop variables are deliberately not called `model`, so the Table 2 `model`
# list is left untouched.
for model_name in ['rdkit', 'grover', 'jtvae']:
    for reference in ['baseline', 'non-pretrained']:
        for gene_set_name, results in [('all genes', df_all), ('DEGs', df_degs)]:
            models.append(model_name)
            gene_set.append(gene_set_name)
            p_values.append(_paired_pvalue(results, model_name, reference, dose))
            vs_models.append(reference)

In [33]:
# Supplement Table 9: one row per (model, reference, gene set) p-value.
df_dict = dict(zip(
    ["Model $G$", "Against", "Gene set", "p-value"],
    [models, vs_models, gene_set, p_values],
))

df = pd.DataFrame.from_dict(df_dict).set_index('Model $G$')

Print table

In [34]:
# escape=False keeps the "$G$" in the index header as LaTeX math. With escaping
# it is emitted as literal text (`Model \$G\$`).
print(df.to_latex(float_format="%.4f", escape=False))

\begin{tabular}{lllr}
\toprule
{} &         Against &   Gene set &  p-value \\
Model \$G\$ &                 &            &          \\
\midrule
rdkit     &        baseline &  all genes &   0.0014 \\
rdkit     &        baseline &       DEGs &   0.0007 \\
rdkit     &  non-pretrained &  all genes &   0.0058 \\
rdkit     &  non-pretrained &       DEGs &   0.0051 \\
grover    &        baseline &  all genes &   0.0009 \\
grover    &        baseline &       DEGs &   0.0003 \\
grover    &  non-pretrained &  all genes &   0.0025 \\
grover    &  non-pretrained &       DEGs &   0.0067 \\
jtvae     &        baseline &  all genes &   0.0047 \\
jtvae     &        baseline &       DEGs &   0.0021 \\
jtvae     &  non-pretrained &  all genes &   0.0001 \\
jtvae     &  non-pretrained &       DEGs &   0.0086 \\
\bottomrule
\end{tabular}



____