# Train sensitivity model with Vanilla VAE as DVAE

## Setup

In [None]:
%run ./utils/imports.py

import utils.utils as utils
from models import VanillaVAE, SensitivityModelVanillaVAE, modules

## Load the data

In [None]:
# Directory that holds all input CSV files
dataset_dir = "path/to/files"

# Read the three input tables from `dataset_dir`, in this order:
#   1. sensitivity table
#   2. cell lines biological data
#   3. drugs SMILES vector representations (Mol2Vec)
sensitivity_table, cell_lines_biological_data, drugs_mol2vec_reprs = (
    pd.read_csv(os.path.join(dataset_dir, csv_name))
    for csv_name in (
        "sensitivity_table.csv",
        "cell_lines_biological_data_from_deers.csv",
        "drugs_Mol2Vec_reprs.csv",
    )
)

In [None]:
# Load the guiding data for the drugs (inhibition profiles)
NO_TRUE_CLUSTER_LABELS = 3  # not referenced anywhere else in this cell
drugs_inhib_profiles = pd.read_csv("path/to/file")

# Build one ID -> index mapper per feature table
cell_line_ID_to_index_mapper = utils.get_ID_to_idx_mapper(
    cell_lines_biological_data, id_col="cell_line_id"
)
drugs_ID_to_smiles_rep_index_mapper = utils.get_ID_to_idx_mapper(
    drugs_mol2vec_reprs, id_col="PubChem CID"
)
drugs_ID_to_inhib_profiles_index_mapper = utils.get_ID_to_idx_mapper(
    drugs_inhib_profiles, id_col="PubChem CID"
)

# Assemble the main dataset out of the sensitivity table and the three
# feature tables. `.values[:, 1:]` drops the first column of every feature
# table (presumably the ID column -- confirm against the CSV layout).
full_dataset = utils.DatasetThreeTables(
    sensitivity_table,
    cell_lines_biological_data.values[:, 1:],
    drugs_mol2vec_reprs.values[:, 1:],
    drugs_inhib_profiles.values[:, 1:],
    cell_line_ID_to_index_mapper,
    drugs_ID_to_smiles_rep_index_mapper,
    drugs_ID_to_inhib_profiles_index_mapper,
    drug_ID_name="PubChem CID",
    cell_line_ID_name="COSMIC_ID",
    guiding_data_class_name="guiding_data_class",
    sensitivity_metric="LN_IC50",
    drug_ID_index=1,
    cell_line_ID_index=3,
    sensitivity_metric_index=4,
)

# Shuffled dataloader over the drug data, built from the Mol2Vec
# representations and the inhibition profiles
VAE_BATCH_SIZE = 8
vae_dataset = utils.get_vae_dataset(drugs_mol2vec_reprs, drugs_inhib_profiles)
vae_dataloader = DataLoader(
    vae_dataset,
    batch_size=VAE_BATCH_SIZE,
    shuffle=True,
)

## Run the model multiple times with different data splits

In [None]:
# One experiment run is performed per data split seed
SPLIT_SEEDS = [11, 13, 26, 76, 92]

# Train/test split size and dataloader batch sizes
NUM_TEST_CELL_LINES = 100
BATCH_SIZE_TRAIN = 128
BATCH_SIZE_TEST = 512

# Dimensionalities of the model inputs
DRUG_INPUT_DIM = 300
DRUG_GUIDING_DIM = 294
CELL_LINE_INPUT_DIM = 241

# Dimensionalities of the latent spaces
DRUG_LATENT_DIM = 10
CELL_LINE_LATENT_DIM = 10

# NN layers: every decoder mirrors the layer widths of its encoder; the drug
# guiding decoder differs from the drug input decoder only in its output width
DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
DRUG_INPUT_DECODER_LAYERS = DRUG_ENCODER_LAYERS[::-1]
DRUG_GUIDING_DECODER_LAYERS = DRUG_INPUT_DECODER_LAYERS[:-1] + (DRUG_GUIDING_DIM,)
CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
CELL_LINE_DECODER_LAYERS = CELL_LINE_ENCODER_LAYERS[::-1]

# Configuration of the sensitivity prediction network; its input width is the
# sum of the drug and cell line latent dimensionalities
sensitivity_prediction_network_config = {
    "layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
    "learning_rate": 0.0005,
    "l2_term": 0,
    "dropout_rate1": 0.5,
    "dropout_rate2": 0.5,
}

# Training schedule
NUM_EPOCHS = 200
SAVE_CHECKPOINT_EVERY_N_EPOCHS = 10
FREEZE_EPOCH = 150
AFTER_FREEZE_LR = 0.001
STEP_SIZE = 10  # Step for learning rate shrinkage
GAMMA = 0.1  # Shrinkage factor for learning rate

# Train one freshly initialised model per data split seed. Every run logs to
# its own TensorBoard directory and stores the sensitivity network config
# next to the logs.
for exp_run, split_seed in enumerate(SPLIT_SEEDS):
    # Split the full dataset with the current seed. The returned cell line
    # lists are not used further inside this loop.
    dataset_train, dataset_test, train_cell_lines, test_cell_lines = full_dataset.train_test_split(
        NUM_TEST_CELL_LINES, seed=split_seed, return_cell_lines=True)

    # Create corresponding DataLoaders
    dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
    dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE_TEST)

    # Seed the global RNGs before any model is created. The top-level
    # `pl.seed_everything` is the public entry point; the
    # `pl.utilities.seed` path is deprecated.
    pl.seed_everything(split_seed)

    # Establish drug model.
    # `var_transformation` computes sqrt(exp(x)) of its input.
    drug_vanilla_vae = VanillaVAE(
        DRUG_ENCODER_LAYERS, DRUG_INPUT_DECODER_LAYERS, DRUG_GUIDING_DECODER_LAYERS,
        var_transformation=lambda x: torch.exp(x) ** 0.5, learning_rate=0.0005,
        loss_function_weights=(1., 1., 1., 1., 0.0), batch_norm=False, optimizer="adam",
        encoder_dropout_rate=0, decoders_dropout_rate=0, clip_guiding_rec=False,
    )

    # Establish cell line model
    cell_line_aen = modules.AutoencoderConfigurable(CELL_LINE_ENCODER_LAYERS, CELL_LINE_DECODER_LAYERS)

    # Forward network
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(
        sensitivity_prediction_network_config)

    # Assemble the model
    model = SensitivityModelVanillaVAE(drug_vanilla_vae, cell_line_aen, sensitivity_prediction_network,
                                       vae_training_num_epochs=100,
                                       vae_training_step_rate=1000,
                                       vae_dataloader=vae_dataloader,
                                       learning_rate=0.0005)

    # Establish logger. The directory is built with os.path.join so that the
    # runs end up under "final_runs/<model_name>" on every platform; a
    # hard-coded backslash is only a path separator on Windows.
    dropout_rate1 = sensitivity_prediction_network_config["dropout_rate1"]
    dropout_rate2 = sensitivity_prediction_network_config["dropout_rate2"]
    model_name = f"Vanilla_VAE__IP__dr={dropout_rate1}_{dropout_rate2}__gam={GAMMA}"
    tb_logger = pl_loggers.TensorBoardLogger(os.path.join("final_runs", model_name),
                                             name=f"run_{exp_run}_split_seed_{split_seed}")

    # Establish callbacks
    # NOTE(review): presumably freezes part of the model at FREEZE_EPOCH and
    # continues with AFTER_FREEZE_LR, decayed by GAMMA every STEP_SIZE epochs
    # -- confirm in utils.FreezingCallback.
    freezing_callback = utils.FreezingCallback(freeze_epoch=FREEZE_EPOCH, new_learning_rate=AFTER_FREEZE_LR,
                                               step_size=STEP_SIZE, gamma=GAMMA)

    # Overwrite default checkpoint callback: save every
    # SAVE_CHECKPOINT_EVERY_N_EPOCHS epochs; `save_top_k` equals the number of
    # such saves, so none of the periodic checkpoints is discarded.
    checkpoint_callback = ModelCheckpoint(monitor="val_sensitivity_pred_rmse",
                                          every_n_epochs=SAVE_CHECKPOINT_EVERY_N_EPOCHS,
                                          every_n_train_steps=None, train_time_interval=None,
                                          save_top_k=NUM_EPOCHS // SAVE_CHECKPOINT_EVERY_N_EPOCHS)

    # Establish trainer (CPU only)
    trainer = pl.Trainer(max_epochs=NUM_EPOCHS, logger=tb_logger, gpus=0,
                         callbacks=[freezing_callback, checkpoint_callback])

    # The test dataloader doubles as the validation dataloader
    trainer.fit(model, dataloader_train, dataloader_test)

    # Save hyperparams next to the logs of this run
    config_path = os.path.join(trainer.log_dir, "sensitivity_prediction_network_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(sensitivity_prediction_network_config, f)