# Train the sensitivity model with a GMM VAE as the DVAE

## Setup

In [1]:
import numpy as np
import pandas as pd
# NOTE(review): names used later in this notebook without an explicit import here (os, json, torch,
# nn, DataLoader, pl, pl_loggers, ModelCheckpoint) presumably come from this script — TODO confirm,
# or import them explicitly so the notebook's dependencies are visible.
%run ./utils/imports.py

import utils.utils as utils
from models import GMMVAE, SensitivityModelGMMVAE, modules

import sys
# Machine-specific absolute path: adds the hgraph2graph directory to the import search path.
# Adjust when running on another machine.
sys.path.append('/home/adam/Projects/vadeers/code/gmm-vae-compounds/models/hgraph2graph/')

import rdkit.Chem as Chem
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load the data

In [2]:
# General path to the prepared dataset files. The default is a machine-specific absolute path;
# set the VADEERS_DATASET_DIR environment variable to point elsewhere without editing the notebook.
dataset_dir = os.environ.get("VADEERS_DATASET_DIR",
                             "/home/adam/Projects/vadeers/data/Ready Datasets/Baseline Dataset/")

# Sensitivity table
sensitivity_table = pd.read_csv(os.path.join(dataset_dir, "sensitivity_table.csv"))

# Cell lines biological data
cell_lines_biological_data = pd.read_csv(os.path.join(dataset_dir, "cell_lines_biological_data_from_deers.csv"))

# Drugs SMILES vector representations (Mol2Vec)
drugs_mol2vec_reprs = pd.read_csv(os.path.join(dataset_dir, "drugs_Mol2Vec_reprs.csv"))

In [3]:
### Load appropriate data

# Load drugs inhibition profiles. The file name encodes the number of guiding cluster labels,
# so it is built from NO_TRUE_CLUSTER_LABELS to keep the constant and the loaded file in sync.
NO_TRUE_CLUSTER_LABELS = 3
drugs_inhib_profiles = pd.read_csv(os.path.join(
    dataset_dir, f"drugs_inhib_profiles_with_{NO_TRUE_CLUSTER_LABELS}_guiding_cluster_labels.csv"))

# Create mappers from IDs to indexes
cell_line_ID_to_index_mapper = utils.get_ID_to_idx_mapper(cell_lines_biological_data, id_col="cell_line_id")
drugs_ID_to_smiles_rep_index_mapper = utils.get_ID_to_idx_mapper(drugs_mol2vec_reprs, id_col="PubChem CID")
drugs_ID_to_inhib_profiles_index_mapper = utils.get_ID_to_idx_mapper(drugs_inhib_profiles, id_col="PubChem CID")

# Create main dataset.
# `.values[:, 1:]` drops the first column of each feature table (presumably the ID column — TODO confirm).
# NOTE(review): drug_ID_index / cell_line_ID_index / sensitivity_metric_index look like column
# positions in `sensitivity_table` — TODO confirm against utils.DatasetThreeTables.
full_dataset = utils.DatasetThreeTables(sensitivity_table,
                                        cell_lines_biological_data.values[:, 1:],
                                        drugs_mol2vec_reprs.values[:, 1:],
                                        drugs_inhib_profiles.values[:, 1:],
                                        cell_line_ID_to_index_mapper,
                                        drugs_ID_to_smiles_rep_index_mapper,
                                        drugs_ID_to_inhib_profiles_index_mapper,
                                        drug_ID_name="PubChem CID",
                                        cell_line_ID_name="COSMIC_ID",
                                        guiding_data_class_name="guiding_data_class",
                                        sensitivity_metric="LN_IC50",
                                        drug_ID_index=1,
                                        cell_line_ID_index=3,
                                        sensitivity_metric_index=4)

# Create VAE dataloader (drug-only data for the drug VAE)
VAE_BATCH_SIZE = 8
vae_dataset = utils.get_vae_dataset(drugs_mol2vec_reprs, drugs_inhib_profiles)
vae_dataloader = DataLoader(vae_dataset, batch_size=VAE_BATCH_SIZE, shuffle=True)

In [4]:
# Sanity check: display the item at index 0 of the combined dataset
full_dataset[0]

(array([ 1.0462272e+00, -3.2771523e+00, -6.3824110e+00,  1.8789635e+00,
        -7.9489590e-01, -5.8907150e+00, -8.2104870e+00,  3.5316072e+00,
         1.5746067e+00, -1.0245272e+00,  1.1576798e+00,  1.2552512e+00,
        -1.4733623e+00, -3.3957162e-01, -2.3549979e+00, -4.5432060e+00,
         2.7046566e+00, -1.7262821e+00, -6.2296480e+00,  9.3133620e+00,
         6.7901673e+00,  8.7795410e+00,  1.2105804e+01,  7.7198935e+00,
        -8.0074170e+00,  3.9709947e+00, -1.4577461e+00, -6.6870420e+00,
         1.7495483e-02, -4.4180405e-01,  8.2790650e+00, -4.8904943e+00,
        -4.6256530e+00, -7.1233892e+00,  2.5176482e+00, -4.1075134e-01,
        -3.1874676e+00, -3.3411288e-01,  1.1363386e+01,  3.2968955e+00,
        -5.5781290e-01,  1.4544309e+00, -1.7068732e+00, -9.8873830e-01,
        -9.4704820e+00,  1.0955430e+01,  8.6190460e-01,  9.6832680e+00,
        -4.9702187e+00,  2.5784788e+00,  3.8157666e+00, -1.1041448e+01,
         3.4110174e-01, -2.4691017e+00, -1.5018735e+01, -5.61865

## Set up the model

In [5]:
# Sensitivity model with GMM VAE
# Input dimensionalities
DRUG_INPUT_DIM = 300   # presumably the Mol2Vec vector length — TODO confirm against drugs_mol2vec_reprs
DRUG_GUIDING_DIM = 294   # presumably the inhibition-profile length — TODO confirm against drugs_inhib_profiles
CELL_LINE_INPUT_DIM = 241

# Latent spaces dimensionalities
DRUG_LATENT_DIM = 10
CELL_LINE_LATENT_DIM = 10

# NN layer sizes: encoders go input -> hidden -> latent, decoders mirror them latent -> hidden -> output.
# These constants are the single source of truth used by the config dict below.
DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
DRUG_INPUT_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_INPUT_DIM)
DRUG_GUIDING_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_GUIDING_DIM)
CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
CELL_LINE_DECODER_LAYERS = (CELL_LINE_LATENT_DIM, 64, 128, CELL_LINE_INPUT_DIM)

# One latent GMM component per guiding cluster label
NO_GMM_COMPONENTS = NO_TRUE_CLUSTER_LABELS

# Transformation applied to the encoders' raw variance output: exp(x) ** 0.5 == exp(x / 2).
# NOTE(review): presumably x is a log-variance, making the result a standard deviation — TODO confirm.
var_transformation = lambda x: torch.exp(x) ** 0.5

# Establish config dict
whole_model_config = {"drug_latent_dim": DRUG_LATENT_DIM,
                      "cell_line_latent_dim": CELL_LINE_LATENT_DIM,
                      "no_gmm_components": NO_GMM_COMPONENTS,
                      "components_std": 1.,
                      "drug_encoder_layers": DRUG_ENCODER_LAYERS,
                      "drug_input_decoder_layers": DRUG_INPUT_DECODER_LAYERS,
                      "drug_guiding_decoder_layers": DRUG_GUIDING_DECODER_LAYERS,
                      "cell_line_encoder_layers": CELL_LINE_ENCODER_LAYERS,
                      "cell_line_decoder_layers": CELL_LINE_DECODER_LAYERS,
                      "vae_loss_function_weights": (1., 1., 1., 1., 0.),
                      "vae_var_transformation": "standard",
                      "optimizer": "adam",
                      "learning_rate": 0.0005,
                      "aen_reconstruction_weight": 1.,
                      "sensitivity_prediction_weight": 1.,
                      "l2_term": 0.,
                      "pretraining_vae": False,
                      "batch_size": 128,
                      "mixed_training": True,
                      "vae_training_num_epochs": 100,
                      "vae_training_step_rate": 1000,
                      "drug_model_learning_rate": 0.0005,
                      "vae_loader_batch_size": VAE_BATCH_SIZE,
                      "clip_guiding_rec": False,
                      "guiding_clip_min": 0,
                      "guiding_clip_max": 100}

# Establish sensitivity prediction network config; its input is the concatenation of the drug
# and cell line latent vectors, hence the summed first layer size.
sensitivity_prediction_network_config = {"layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
                                         "learning_rate": 0.0005,
                                         "l2_term": 0,
                                         "dropout_rate1": 0.5,
                                         "dropout_rate2": 0.5}

## Run the model multiple times with different data splits

In [6]:
import inspect

# Data split seeds: one independent training run per seed, each with its own train/test split
SPLIT_SEEDS = [11, 13, 26, 76, 92]

# Data split and loaders hyperparameters
NUM_TEST_CELL_LINES = 100
BATCH_SIZE_TRAIN = 128
BATCH_SIZE_TEST = 512

# Training hyperparameters
NUM_EPOCHS = 200
SAVE_CHECKPOINT_EVERY_N_EPOCHS = 10
FREEZE_EPOCH = 150
AFTER_FREEZE_LR = 0.001
STEP_SIZE = 10   # Step for learning rate shrinkage
GAMMA = 0.1   # Shrinkage factor for learning rate

# Root directory for TensorBoard logs and checkpoints. Paths are joined with os.path.join:
# a hard-coded Windows separator (rf"final_runs\{...}") becomes part of a single directory
# name on Linux instead of creating a subdirectory.
LOG_ROOT_DIR = "final_runs"


def describe_callable(fn):
    """Return a human-readable description of `fn` suitable for a JSON config.

    `str()` of a lambda only yields "<function <lambda> at 0x...>", which says nothing about the
    transformation and changes between runs, so the source text is preferred; `repr` is the
    fallback when the source is unavailable.
    """
    try:
        return inspect.getsource(fn).strip()
    except (OSError, TypeError):
        return repr(fn)


def save_json(obj, directory, filename):
    """Serialize `obj` as JSON to `directory/filename`."""
    with open(os.path.join(directory, filename), "w") as f:
        json.dump(obj, f)


for exp_run, split_seed in enumerate(SPLIT_SEEDS):
    dataset_train, dataset_test, train_cell_lines, test_cell_lines = full_dataset.train_test_split(
        NUM_TEST_CELL_LINES, seed=split_seed, return_cell_lines=True)
    # Create corresponding DataLoaders
    dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
    dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE_TEST)

    # Seed all RNGs so model initialization and shuffling are reproducible per run
    # (public top-level API rather than the internal `pl.utilities.seed` path)
    pl.seed_everything(split_seed)

    # Establish drug model
    drug_gmm_vae = GMMVAE(whole_model_config["drug_encoder_layers"], whole_model_config["drug_input_decoder_layers"],
                          whole_model_config["drug_guiding_decoder_layers"],
                          whole_model_config["no_gmm_components"],
                          components_std=whole_model_config["components_std"],
                          var_transformation=var_transformation,
                          learning_rate=whole_model_config["drug_model_learning_rate"],
                          loss_function_weights=whole_model_config["vae_loss_function_weights"],
                          batch_norm=False, optimizer="adam",
                          encoder_dropout_rate=0, decoders_dropout_rate=0,
                          clip_guiding_rec=whole_model_config["clip_guiding_rec"],
                          guiding_clip_min=whole_model_config["guiding_clip_min"],
                          guiding_clip_max=whole_model_config["guiding_clip_max"])

    # Set up trainable components stds (one per GMM component and latent dimension) -
    # comment out the line below to keep fixed isotropic covariance matrices in the GMM
    drug_gmm_vae.stds = nn.Parameter(data=torch.ones(whole_model_config["no_gmm_components"], drug_gmm_vae.latent_dim), requires_grad=True)

    # Establish cell line model
    cell_line_aen = modules.AutoencoderConfigurable(whole_model_config["cell_line_encoder_layers"], whole_model_config["cell_line_decoder_layers"])

    # Three-layer variant
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(sensitivity_prediction_network_config)

    # Additional VAE training on `vae_dataloader` is only enabled when "mixed_training" is set;
    # passing None disables that additional training.
    # NOTE(review): with the VAE loader attached, fit() failed at epoch 0 with
    # "MisconfigurationException: You are trying to `self.log()` but it is not managed by the
    # `Trainer` control flow" — presumably the extra VAE step calls `self.log` on the drug model
    # outside the Trainer's hooks. TODO: fix in SensitivityModelGMMVAE, or set "mixed_training" to False.
    model_vae_dataloader = vae_dataloader if whole_model_config["mixed_training"] else None

    # Assemble the model
    model = SensitivityModelGMMVAE(drug_gmm_vae, cell_line_aen, sensitivity_prediction_network,
                                   learning_rate=whole_model_config["learning_rate"],
                                   aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
                                   sensitivity_loss_weight=whole_model_config["sensitivity_prediction_weight"],
                                   vae_dataloader=model_vae_dataloader)

    # Train the model
    # Establish logger
    model_name = f"GMM_VAE__IP__no_comps={NO_GMM_COMPONENTS}__trained_comp_std"
    tb_logger = pl_loggers.TensorBoardLogger(os.path.join(LOG_ROOT_DIR, model_name),
                                             name=f"run_{exp_run}_split_seed_{split_seed}")

    # Establish callbacks (freezing / learning-rate schedule semantics live in utils.FreezingCallback)
    freezing_callback = utils.FreezingCallback(freeze_epoch=FREEZE_EPOCH, new_learning_rate=AFTER_FREEZE_LR, step_size=STEP_SIZE, gamma=GAMMA)

    # Overwrite default checkpoint callback: checkpoint every SAVE_CHECKPOINT_EVERY_N_EPOCHS epochs and
    # keep up to NUM_EPOCHS // SAVE_CHECKPOINT_EVERY_N_EPOCHS best checkpoints by validation RMSE
    checkpoint_callback = ModelCheckpoint(monitor="val_sensitivity_pred_rmse", every_n_epochs=SAVE_CHECKPOINT_EVERY_N_EPOCHS, every_n_train_steps=None, train_time_interval=None,
                                          save_top_k=NUM_EPOCHS // SAVE_CHECKPOINT_EVERY_N_EPOCHS)

    # Establish trainer (CPU only; `accelerator="cpu"` replaces the deprecated `gpus=0`)
    trainer = pl.Trainer(max_epochs=NUM_EPOCHS, logger=tb_logger, accelerator="cpu",
                         callbacks=[freezing_callback, checkpoint_callback])

    trainer.fit(model, dataloader_train, dataloader_test)

    # Save hyperparams. Build a per-run copy so the shared `whole_model_config` defined above
    # is not mutated by the loop (previously its "vae_var_transformation" got overwritten).
    run_config = {**whole_model_config,
                  "vae_var_transformation": describe_callable(var_transformation),
                  "num_epochs": NUM_EPOCHS,
                  "freeze_epoch": FREEZE_EPOCH,
                  "after_freeze_lr": AFTER_FREEZE_LR,
                  "step_size": STEP_SIZE,
                  "gamma": GAMMA,
                  "split_seed": split_seed}

    save_json(run_config, trainer.log_dir, "whole_model_config.json")
    save_json(sensitivity_prediction_network_config, trainer.log_dir, "sensitivity_prediction_network_config.json")

Global seed set to 11
  rank_zero_deprecation(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: final_runs\GMM_VAE__IP__no_comps=3__trained_comp_std/run_0_split_seed_11

  | Name                           | Type                                      | Params
---------------------------------------------------------------------------------------------
0 | drug_model                     | GMMVAE                                    | 142 K 
1 | cell_line_model                | AutoencoderConfigurable                   | 80.0 K
2 | sensitivity_prediction_network | FeedForwardThreeLayersConfigurableDropout | 175 K 
---------------------------------------------------------------------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.592     Total estimated model params size (MB)


                                                                            

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/1552 [00:00<?, ?it/s] 

MisconfigurationException: You are trying to `self.log()` but it is not managed by the `Trainer` control flow