# Evaluate models' predictive performance

## Setup

In [None]:
%run ./utils/imports.py

import utils.utils as utils
from models import VanillaVAE, GMMVAE, SensitivityModelVanillaVAE, SensitivityModelGMMVAE, modules

## Load the data

In [None]:
# Directory that holds the input CSV tables
dataset_dir = "path/to/files"

# Constant declared alongside the inhibition profiles; it is not referenced
# anywhere else in this cell
NO_TRUE_CLUSTER_LABELS = 3

# --- Raw tables -------------------------------------------------------------
# Drug / cell line sensitivity measurements
sensitivity_table = pd.read_csv(
    os.path.join(dataset_dir, "sensitivity_table.csv")
)
# Biological features of the cell lines
cell_lines_biological_data = pd.read_csv(
    os.path.join(dataset_dir, "cell_lines_biological_data.csv")
)
# Mol2Vec vector representations of the drugs' SMILES
drugs_mol2vec_reprs = pd.read_csv(
    os.path.join(dataset_dir, "drugs_Mol2Vec_reprs.csv")
)
# Drugs' inhibition profiles (read from its own path, not from dataset_dir)
drugs_inhib_profiles = pd.read_csv("path/to/file")

# --- ID -> row index mappers ------------------------------------------------
cell_line_ID_to_index_mapper = utils.get_ID_to_idx_mapper(
    cell_lines_biological_data, id_col="cell_line_id"
)
drugs_ID_to_smiles_rep_index_mapper = utils.get_ID_to_idx_mapper(
    drugs_mol2vec_reprs, id_col="PubChem CID"
)
drugs_ID_to_inhib_profiles_index_mapper = utils.get_ID_to_idx_mapper(
    drugs_inhib_profiles, id_col="PubChem CID"
)

# --- Main dataset -----------------------------------------------------------
# The first column of every feature table is dropped before it is handed to
# the dataset (presumably the ID column -- confirm against the CSV layout).
full_dataset = utils.DatasetThreeTables(
    sensitivity_table,
    cell_lines_biological_data.values[:, 1:],
    drugs_mol2vec_reprs.values[:, 1:],
    drugs_inhib_profiles.values[:, 1:],
    cell_line_ID_to_index_mapper,
    drugs_ID_to_smiles_rep_index_mapper,
    drugs_ID_to_inhib_profiles_index_mapper,
    drug_ID_name="PubChem CID",
    cell_line_ID_name="COSMIC_ID",
    guiding_data_class_name="guiding_data_class",
    sensitivity_metric="LN_IC50",
    drug_ID_index=1,
    cell_line_ID_index=3,
    sensitivity_metric_index=4,
)

## Establish the model to assess

In [None]:
# Establish model directory and type
# model_dir is listed by the evaluation loop: each entry is treated as one
# experimental run that contains a "version_0/checkpoints" sub-directory.
model_dir = "path/to/model"
# Selects which model-loading branch the evaluation loop takes; the values
# checked there are "GMM" and "Vanilla".
MODEL_TYPE = "Vanilla"

## Evaluate predictive performance

In [None]:
# Initialize dict for storing performance metrics; every key becomes a column
# of results_df and receives exactly one value per experimental run
metrics_dict = {"run": [],
                "split_seed": [],
                "rmse_test": [],
                "pearson_test": [],
                "guiding_data_rec_rmse": []}
# Number of cell lines passed to train_test_split for every run
NUM_TEST_CELL_LINES = 100

for exp in sorted(os.listdir(model_dir)):   # Iterate over experimental runs
    # Stray files in model_dir (e.g. OS metadata) are not runs and would break
    # the name parsing below, so only directories are evaluated
    if not os.path.isdir(os.path.join(model_dir, exp)):
        continue
    version_dir = os.path.join(model_dir, exp, "version_0")

    # Load datasets
    # The run directory name is parsed on "_": the second token is the run
    # number and the last token is the seed used for the train/test split
    name_tokens = exp.split("_")
    run = int(name_tokens[1])
    split_seed = int(name_tokens[-1])
    dataset_train, dataset_test, train_cell_lines, test_cell_lines = full_dataset.train_test_split(
        NUM_TEST_CELL_LINES, seed=split_seed, return_cell_lines=True)
    # Each loader yields its whole dataset as a single batch; only the test
    # loader is consumed further down in this cell
    dataloader_train = DataLoader(dataset_train, batch_size=len(dataset_train),
                                  shuffle=False)
    dataloader_test = DataLoader(dataset_test, batch_size=len(dataset_test),
                                 shuffle=False)

    # Pick the checkpoint whose file name carries the largest integer after
    # the last "=" (presumably the training step -- confirm against the
    # checkpoint naming scheme). The key strips a 5-character extension, so
    # only ".ckpt" files are considered.
    checkpoints_dir = os.path.join(version_dir, "checkpoints")
    checkpoint_files = [f for f in os.listdir(checkpoints_dir) if f.endswith(".ckpt")]
    latest_checkpoint = max(checkpoint_files, key=lambda x: int(x.split("=")[-1][:-5]))
    checkpoint_path = os.path.join(checkpoints_dir, latest_checkpoint)

    # Load the model
    if MODEL_TYPE == "GMM":
        # The GMM architecture is rebuilt from the configs saved next to the run
        with open(os.path.join(version_dir, "whole_model_config.json"), "r") as f:
            whole_model_config = json.load(f)
        with open(os.path.join(version_dir, "sensitivity_prediction_network_config.json"), "r") as f:
            sensitivity_prediction_network_config = json.load(f)

        # Load GMMVAE sensitivity model from config dict
        var_transformation = lambda x: torch.exp(x) ** 0.5

        # Establish drug model
        drug_gmm_vae = GMMVAE(whole_model_config["drug_encoder_layers"],
                              whole_model_config["drug_input_decoder_layers"],
                              whole_model_config["drug_guiding_decoder_layers"],
                              whole_model_config["no_gmm_components"],
                              var_transformation=var_transformation,
                              learning_rate=whole_model_config["learning_rate"],
                              loss_function_weights=whole_model_config["vae_loss_function_weights"],
                              batch_norm=False, optimizer="adam",
                              encoder_dropout_rate=0, decoders_dropout_rate=0,
                              clip_guiding_rec=True)

        # Establish cell line model
        cell_line_aen = modules.AutoencoderConfigurable(whole_model_config["cell_line_encoder_layers"],
                                                        whole_model_config["cell_line_decoder_layers"])

        # Establish sensitivity prediction network
        sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

        model = SensitivityModelGMMVAE.load_from_checkpoint(
            checkpoint_path,
            drug_model=drug_gmm_vae,
            cell_line_model=cell_line_aen,
            sensitivity_prediction_network=sensitivity_prediction_network,
            learning_rate=whole_model_config["learning_rate"],
            aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
            sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"])
    elif MODEL_TYPE == "Vanilla":
        # Setup the model's hyperparams (hard-coded; no config file is read
        # for this model type)
        DRUG_INPUT_DIM = 300
        DRUG_GUIDING_DIM = 294
        CELL_LINE_INPUT_DIM = 241

        DRUG_LATENT_DIM = 10
        CELL_LINE_LATENT_DIM = 10

        DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
        DRUG_INPUT_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_INPUT_DIM)
        DRUG_GUIDING_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_GUIDING_DIM)

        CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
        CELL_LINE_DECODER_LAYERS = (CELL_LINE_LATENT_DIM, 64, 128, CELL_LINE_INPUT_DIM)

        # Establish sensitivity prediction network
        sensitivity_prediction_network_config = {"layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
                                                 "learning_rate": 0.001,
                                                 "l2_term": 0,
                                                 "dropout_rate1": 0.5,
                                                 "dropout_rate2": 0.5}
        # Establish drug model
        drug_vanilla_vae = VanillaVAE(DRUG_ENCODER_LAYERS, DRUG_INPUT_DECODER_LAYERS, DRUG_GUIDING_DECODER_LAYERS,
                                      var_transformation=lambda x: torch.exp(x) ** 0.5, learning_rate=0.0005,
                                      loss_function_weights=(1., 1., 1., 1., 0.0), batch_norm=False, optimizer="adam",
                                      encoder_dropout_rate=0, decoders_dropout_rate=0, clip_guiding_rec=True)

        # Establish cell line model
        cell_line_aen = modules.AutoencoderConfigurable(CELL_LINE_ENCODER_LAYERS, CELL_LINE_DECODER_LAYERS)

        # Forward network
        sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(sensitivity_prediction_network_config)

        model = SensitivityModelVanillaVAE.load_from_checkpoint(
            checkpoint_path,
            drug_model=drug_vanilla_vae,
            cell_line_model=cell_line_aen,
            sensitivity_prediction_network=sensitivity_prediction_network,
            vae_training_num_epochs=100,
            vae_training_step_rate=1000,
            vae_dataloader=None,
            learning_rate=0.0005)
    else:
        # Without this guard an unknown type would leave `model` undefined on
        # the first run, or silently re-evaluate the previous run's model
        raise ValueError(f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected 'GMM' or 'Vanilla'")

    # Test data: the loader's single batch is the whole test set
    smiles_data_batch, inhib_profiles_batch, guiding_data_class, cell_line_data_batch, responses, pubchem_id, cosmic_id = next(iter(dataloader_test))

    # Perform inference; gradients are not needed for evaluation, so the
    # autograd graph is not built for the full-batch forward pass
    model.eval()
    with torch.no_grad():
        vae_out, aen_out, sensitivity_pred = model(smiles_data_batch.float(), cell_line_data_batch.float())

        # Evaluate inhibition profiles reconstruction; the second element of
        # the VAE output is used as the guiding data reconstruction
        guiding_data_rec = vae_out[1]
        guiding_data_rec_rmse = model.drug_model.mse_loss_with_nans(guiding_data_rec, inhib_profiles_batch).item() ** 0.5

    # Evaluate sensitivity predictive performance on flattened 1-D arrays
    responses_np = responses.detach().cpu().numpy().reshape(-1)
    sensitivity_pred_np = sensitivity_pred.detach().cpu().numpy().reshape(-1)
    rmse_test = metrics.mean_squared_error(responses_np, sensitivity_pred_np) ** 0.5
    # pearsonr returns (correlation, p-value); only the correlation is stored
    # so that the "pearson_test" column is numeric
    pearson_test = pearsonr(responses_np, sensitivity_pred_np)[0]

    metrics_dict["run"].append(run)
    metrics_dict["split_seed"].append(split_seed)
    metrics_dict["rmse_test"].append(rmse_test)
    metrics_dict["pearson_test"].append(pearson_test)
    metrics_dict["guiding_data_rec_rmse"].append(guiding_data_rec_rmse)

# Create a DataFrame from metrics (one row per evaluated run)
results_df = pd.DataFrame(data=metrics_dict)

In [None]:
# Save results df
# The model's name is the second component of model_dir. Paths written with
# backslashes are split exactly as before; paths without any backslash are
# split on "/" instead of raising IndexError, and a single-component path
# falls back to that component.
path_separator = "\\" if "\\" in model_dir else "/"
model_dir_parts = model_dir.split(path_separator)
name = model_dir_parts[1] if len(model_dir_parts) > 1 else model_dir_parts[0]

# Make sure the output directory exists before writing into it
results_dir = os.path.join("final_runs", "Results")
os.makedirs(results_dir, exist_ok=True)
results_df.to_csv(os.path.join(results_dir, f"{name}_results_df.csv"), index=False)