# Evaluation Notebook of Trained DRP Models

### Imports and Paths

In [None]:
import numpy as np
import pandas as pd
import joblib
from joblib import Parallel, delayed
import torch
from torch.utils.data import DataLoader
from torchmetrics.functional.regression import (
    pearson_corrcoef,
    spearman_corrcoef,
    mean_squared_error,
    mean_absolute_error,
)
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

import os
from collections import defaultdict


from src import DrugResponseDataset
from src import (
    DrugResponseModelTokens,
    DrugResponseModelLegacy,
    DrugResponseLightningModule,
)

# Seed both NumPy's and PyTorch's global generators with the same value.
seed = 0
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

In [None]:
# Serialized compound-embedding dictionaries, one file per molecular representation.
(
    smiles_embeddings_path,
    selfies_embeddings_path,
    smiles_tokens_embeddings_path,
    smiles_fingerprints_embeddings_path,
) = (
    os.path.join("data", file_name)
    for file_name in (
        "smiles_embeddings_dict.joblib",
        "selfies_embeddings_dict.joblib",
        "smiles_tokens_embeddings_dict.joblib",
        "smiles_fingerprints_embeddings_dict.joblib",
    )
)

# Gene-expression table and the train/test drug-response tables.
ccl_ge_path, drp_test_path, drp_train_path = (
    os.path.join("data", file_name)
    for file_name in ("ge_filtered_scaled.csv", "drp_test.csv", "drp_train.csv")
)

In [None]:
# Load the test-split drug-response table; the trailing expression displays its row count.
drp_test_df = pd.read_csv(filepath_or_buffer=drp_test_path)
drp_test_df.shape[0]

In [None]:
# Distinct compound strings present in the test table, for both representations.
selfies_test_unique, smiles_test_unique = (
    drp_test_df[column].unique() for column in ("selfies", "smiles")
)
n_unique_selfies = len(selfies_test_unique)
n_unique_smiles = len(smiles_test_unique)
print(f"Number of unique selfies in test set: {n_unique_selfies}")
print(f"Number of unique smiles in test set: {n_unique_smiles}")

In [None]:
# Mean number of test rows per distinct SELFIES string, rounded for display.
entries_per_unique_selfies = len(drp_test_df) / len(selfies_test_unique)
print(f"Average number of entries per unique selfies: {entries_per_unique_selfies:.0f}")

### Evaluation Functions

In [None]:
def compute_eval_metrics(model, test_loader, metrics, device, return_predictions=False):
    """
    Run `model` over every batch of `test_loader` and score the predictions.

    Args:
    - model (torch.nn.Module): The model to be evaluated; it is switched to eval mode.
    - test_loader (DataLoader): Yields batches with the keys "cpd_embeddings",
      "ccl_ge_embeddings" and "label".
    - metrics (list of functions): Metric callables taking (predictions, labels).
    - device (torch.device): Device the two embedding tensors are moved to.
    - return_predictions (bool): Whether to also return predictions and ground truths.

    Returns:
    - list: One entry per metric, in the order of `metrics`. A metric whose
      `__name__` is "mean_squared_error" is evaluated per batch and reported as a
      tuple (mean of the batch values, variance of the batch values). Every other
      metric is evaluated once on all concatenated predictions and reported as a float.
    - If `return_predictions` is True, a tuple (results, predictions, labels) with
      the last two as NumPy arrays.
    """
    model.eval()

    def _is_mse(metric_fn):
        # The MSE metric is recognised purely by its function name.
        return metric_fn.__name__ == "mean_squared_error"

    batch_outputs, batch_labels, batch_mse_values = [], [], []

    for batch in tqdm(test_loader):
        with torch.no_grad():
            compounds = batch["cpd_embeddings"].to(device)
            cell_lines = batch["ccl_ge_embeddings"].to(device)
            targets = batch["label"].cpu()
            predicted = model(compounds, cell_lines).cpu()

            batch_outputs.append(predicted)
            batch_labels.append(targets)

            # Only MSE is tracked batch by batch.
            batch_mse_values.extend(
                metric_fn(predicted, targets).item()
                for metric_fn in metrics
                if _is_mse(metric_fn)
            )

    stacked_outputs = torch.cat(batch_outputs, dim=0)
    stacked_labels = torch.cat(batch_labels, dim=0)

    metric_results = [
        (np.mean(batch_mse_values), np.var(batch_mse_values))
        if _is_mse(metric_fn)
        else metric_fn(stacked_outputs, stacked_labels).item()
        for metric_fn in metrics
    ]

    if not return_predictions:
        return metric_results
    return metric_results, stacked_outputs.numpy(), stacked_labels.numpy()


def compute_per_compound_metrics(
    model, test_loader, metrics, device, return_predictions=False
):
    """
    Compute evaluation metrics separately for each compound in a dataset.

    Predictions and labels are grouped by the "cpd_name" field of every batch and
    each group is then scored on its own.

    Args:
    - model (torch.nn.Module): The model to be evaluated; it is switched to eval mode.
    - test_loader (DataLoader): Yields batches with the keys "cpd_embeddings",
      "ccl_ge_embeddings", "label" and "cpd_name".
    - metrics (list of functions): Metric callables taking (predictions, labels).
    - device (torch.device): Device the two embedding tensors are moved to.
    - return_predictions (bool): Whether to also return per-compound means and
      standard deviations of the predictions and of the ground truths.

    Returns:
    - dict: Maps each compound name to a dict holding "mean_squared_error",
      "root_mean_squared_error" and "std_of_squared_errors", plus one entry per
      non-MSE metric keyed by the metric's `__name__`. The three MSE-related values
      are 0 placeholders when no metric named "mean_squared_error" was requested.
    - If `return_predictions` is True, a tuple (compound_metrics, predictions_dict)
      where predictions_dict maps each compound name to a dict with the keys
      "outputs", "outputs_std", "labels" and "labels_std".
    """
    model.eval()

    # These depend only on `metrics`, so evaluate them once instead of per compound.
    wants_mse = any(metric.__name__ == "mean_squared_error" for metric in metrics)
    other_metrics = [
        metric for metric in metrics if metric.__name__ != "mean_squared_error"
    ]

    compound_data = defaultdict(lambda: {"outputs": [], "labels": []})
    compound_metrics = {}
    predictions_dict = {} if return_predictions else None

    # Inference only: a single no-grad scope covers the whole pass.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            cpd_emb = batch["cpd_embeddings"].to(device)
            ccl_emb = batch["ccl_ge_embeddings"].to(device)
            labels = batch["label"].cpu()
            outputs = model(cpd_emb, ccl_emb).cpu()
            cpd_names = batch["cpd_name"]

            for name, output, label in zip(cpd_names, outputs, labels):
                compound_data[name]["outputs"].append(output.numpy())
                compound_data[name]["labels"].append(label.numpy())

    for cpd_name, data in compound_data.items():
        cpd_outputs = np.array(data["outputs"])
        cpd_labels = np.array(data["labels"])

        if return_predictions:
            predictions_dict[cpd_name] = {
                "outputs": cpd_outputs.mean(),
                "outputs_std": cpd_outputs.std(),
                "labels": cpd_labels.mean(),
                "labels_std": cpd_labels.std(),
            }

        if wants_mse:
            squared_errors = (cpd_outputs - cpd_labels) ** 2
            mse_mean = squared_errors.mean()
            mse_std = squared_errors.std()
            rmse = np.sqrt(mse_mean)
        else:
            mse_mean, mse_std, rmse = 0, 0, 0

        compound_metrics[cpd_name] = {
            "mean_squared_error": mse_mean,
            "root_mean_squared_error": rmse,
            "std_of_squared_errors": mse_std,
        }

        if other_metrics:
            # Convert once per compound rather than once per metric.
            outputs_tensor = torch.from_numpy(cpd_outputs)
            labels_tensor = torch.from_numpy(cpd_labels)
            for metric in other_metrics:
                compound_metrics[cpd_name][metric.__name__] = metric(
                    outputs_tensor, labels_tensor
                ).item()

    if return_predictions:
        return compound_metrics, predictions_dict
    return compound_metrics


def plot_preds_vs_labels(
    ground_truth, predictions, residuals=False, figsize=(8, 6), fontsize=14, title=None
):
    """
    Scatter predicted values, or residuals, against the ground truth.

    Args:
    - ground_truth (array-like): Ground truth values.
    - predictions (array-like): Predicted values.
    - residuals (bool): Whether to plot residuals (ground truth - predictions)
      instead of predictions.
    - figsize (tuple): Size of the figure.
    - fontsize (int): Font size for the axis labels; the title uses fontsize + 2.
    - title (str): Title of the plot; no title is drawn when None.
    """
    # Coerce to arrays so plain Python sequences support the subtraction and the
    # .min()/.max() calls below, as the "array-like" contract promises.
    ground_truth = np.asarray(ground_truth)
    predictions = np.asarray(predictions)

    with plt.style.context("seaborn-v0_8-whitegrid"):
        plt.figure(figsize=figsize)

        y_axis_data = ground_truth - predictions if residuals else predictions

        plt.scatter(ground_truth, y_axis_data, alpha=0.5)

        if residuals:
            # Zero line: points on it are perfect predictions.
            plt.axhline(y=0, color="r", linestyle="--", lw=2)
            plt.ylabel("Residuals (Ground Truth - Predictions)", fontsize=fontsize)
        else:
            # Identity line spanning the ground-truth range.
            lowest, highest = ground_truth.min(), ground_truth.max()
            plt.plot([lowest, highest], [lowest, highest], "r--", lw=2)
            plt.ylabel("Predictions", fontsize=fontsize)

        plt.xlabel("Ground Truth", fontsize=fontsize)

        if title is not None:
            plt.title(title, fontsize=fontsize + 2)

        plt.grid(True)
        plt.show()


def plot_predictions_vs_ground_truth(
    preds_dict,
    figsize=(10, 6),
    errorbar_color='lightgray',
    errorbar_alpha=0.5,
    point_color='blue',
    line_color='k',
    xlabel='Ground Truth Mean',
    ylabel='Predicted Mean',
    title='Predicted vs Ground Truth Values with Error Bars',
    legend_loc='upper left',
    fontsize=14
):
    """
    Plot per-compound mean predictions against mean ground truths with error bars.

    Args:
    - preds_dict (dict): A dictionary where keys are compound names and values are dictionaries
                         containing 'outputs', 'outputs_std', 'labels', and 'labels_std'.
    - figsize (tuple): Size of the figure.
    - errorbar_color (str): Color for the error bars.
    - errorbar_alpha (float): Alpha passed to `plt.errorbar`.
    - point_color (str): Color of the points.
    - line_color (str): Color of the perfect predictions line.
    - xlabel (str): Label for the x-axis.
    - ylabel (str): Label for the y-axis.
    - title (str): Title of the plot; no title is drawn when None.
    - legend_loc (str): Location of the legend.
    - fontsize (int): Font size for labels and legend; the title uses fontsize + 2.

    Raises:
    - ValueError: If `preds_dict` is empty.
    """
    if not preds_dict:
        # Without this guard min()/max() below fail with an opaque message.
        raise ValueError("preds_dict is empty: nothing to plot")

    # Extracting data for plotting
    per_compound = list(preds_dict.values())
    means = [entry['outputs'] for entry in per_compound]
    stds = [entry['outputs_std'] for entry in per_compound]
    ground_truths = [entry['labels'] for entry in per_compound]
    ground_truths_stds = [entry['labels_std'] for entry in per_compound]

    # Plotting
    plt.figure(figsize=figsize)
    plt.errorbar(
        ground_truths, means, xerr=ground_truths_stds, yerr=stds, fmt='o',
        ecolor=errorbar_color, elinewidth=2, alpha=errorbar_alpha,
        color=point_color, label='Compounds'
    )

    # Line representing perfect predictions, spanning both axes' value range
    min_val = min(min(ground_truths), min(means))
    max_val = max(max(ground_truths), max(means))
    plt.plot(
        [min_val, max_val], [min_val, max_val],
        linestyle='--', color=line_color, label='Perfect predictions'
    )

    plt.xlabel(xlabel, fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
    # Callers pass title=None to suppress the title; skip the call explicitly
    # rather than relying on how matplotlib treats a None title.
    if title is not None:
        plt.title(title, fontsize=fontsize + 2)
    plt.legend(loc=legend_loc, fontsize=fontsize)
    plt.grid(True)
    plt.show()

def plot_ground_truth_distribution(
    ground_truth_values,
    ground_truth_values2=None,
    bins=30,
    show_kde=True,
    figsize=(10, 6),
    fontsize=14,
    labels=("Set 1", "Set 2"),
    title="Distribution of Ground Truth Values",
    legend_loc="upper left",
    fill_alpha=0.5,
):
    """
    Plot a percentage histogram of ground truth values, optionally overlaying a second set.

    Args:
    - ground_truth_values (array): Array of ground truth values.
    - ground_truth_values2 (array, optional): Second array of ground truth values,
      drawn in orange on the same axes.
    - bins (int): Number of bins in the histogram.
    - show_kde (bool): Whether to show the KDE line.
    - figsize (tuple): Size of the figure.
    - fontsize (int): Font size for labels and legend; the title uses fontsize + 2.
    - labels (sequence of str): Legend labels for the first and second set. Only
      indexed, so a list or a tuple both work.
    - title (str): Title of the plot; no title is drawn when None.
    - legend_loc (str): Location of the legend.
    - fill_alpha (float): Opacity for the fill color.
    """

    plt.figure(figsize=figsize)

    # Settings shared by both histograms so the two layers stay comparable.
    shared_style = dict(
        bins=bins,
        kde=show_kde,
        stat="percent",
        element="step",
        fill=True,
        alpha=fill_alpha,
    )

    # Plot the first set of ground truth values
    sns.histplot(ground_truth_values, label=labels[0], **shared_style)

    # Plot the second set of ground truth values, if provided
    if ground_truth_values2 is not None:
        sns.histplot(
            ground_truth_values2, color="orange", label=labels[1], **shared_style
        )

    plt.xlabel("Ground Truth Value", fontsize=fontsize)
    plt.ylabel("Percentage", fontsize=fontsize)
    # Callers pass title=None to suppress the title; skip the call explicitly
    # rather than relying on how matplotlib treats a None title.
    if title is not None:
        plt.title(title, fontsize=fontsize + 2)
    plt.legend(fontsize=fontsize, loc=legend_loc)
    plt.grid(True)
    plt.show()


def extract_labels_parallel(dataset, n_jobs=-1):
    """
    Collect every label of a dataset plus the mean label of each compound.

    The dataset is iterated sequentially; only the per-compound averaging is
    dispatched through joblib.

    Args:
    - dataset (Dataset): Iterable of entries exposing "label" (supports `.item()`)
      and "cpd_name".
    - n_jobs (int): The number of jobs to run in parallel. -1 means using all processors.

    Returns:
    - tuple: Two NumPy arrays:
        1. All label values, in dataset order.
        2. The average label per compound, ordered by first appearance of the compound.
    """
    labels_by_compound = defaultdict(list)
    flat_labels = []

    for sample in dataset:
        value = sample["label"].item()
        name = sample["cpd_name"]

        flat_labels.append(value)
        labels_by_compound[name].append(value)

    # One averaging task per compound.
    averaging_tasks = (
        delayed(np.mean)(group) for group in labels_by_compound.values()
    )
    per_compound_means = Parallel(n_jobs=n_jobs)(averaging_tasks)

    return np.array(flat_labels), np.array(per_compound_means)

def percentage_values_in_range(values, low, high):
    """
    Calculate the percentage of values that fall in the half-open range [low, high).

    Args:
    - values (array-like): One-dimensional sequence of values to be checked.
    - low (float): The lower bound of the range (inclusive).
    - high (float): The upper bound of the range (exclusive).

    Returns:
    - float: The percentage (0-100) of values with low <= value < high.
      Returns 0.0 for an empty input instead of dividing by zero.
    """
    total = len(values)
    if total == 0:
        return 0.0
    # Count matches directly; no need to materialise the matching values.
    in_range = sum(1 for value in values if low <= value < high)
    return (in_range / total) * 100

## Comparison of the 2D Embeddings Spaces

### SELFIES Embeddings

In [None]:
# NOTE(review): joblib.load unpickles the file; only load embedding files you trust.
selfies_embeddings_dict = joblib.load(selfies_embeddings_path)

selfies_embeddings = np.array(list(selfies_embeddings_dict.values()))

# Average over axis 1 to get one vector per compound
# (presumably the token/sequence axis — TODO confirm against the embedding files).
X = np.mean(selfies_embeddings, axis=1)

# Pin the random state so the layout does not depend on which cells ran before.
tsne = TSNE(random_state=seed)
X_2d = tsne.fit_transform(X)

# Build the key -> row lookup once; calling list(...).index(...) per test
# compound rebuilt the key list and scanned it linearly every time.
selfies_row_by_key = {key: row for row, key in enumerate(selfies_embeddings_dict)}
test_indices = [selfies_row_by_key[selfies] for selfies in selfies_test_unique]

with plt.style.context("seaborn-v0_8-whitegrid"):
    plt.figure(figsize=(10, 8))
    # First layer draws every embedding in the dictionary; the test compounds
    # are then drawn on top in orange.
    sns.scatterplot(x=X_2d[:, 0], y=X_2d[:, 1], label="Train Embeddings")
    sns.scatterplot(
        x=X_2d[test_indices, 0],
        y=X_2d[test_indices, 1],
        color="orange",
        label="Test Embeddings",
    )

    plt.legend()
    plt.show()

### SMILES Embeddings

In [None]:
# NOTE(review): joblib.load unpickles the file; only load embedding files you trust.
smiles_embeddings_dict = joblib.load(smiles_embeddings_path)

smiles_embeddings = np.array(list(smiles_embeddings_dict.values()))

# Average over axis 1 to get one vector per compound
# (presumably the token/sequence axis — TODO confirm against the embedding files).
X = np.mean(smiles_embeddings, axis=1)

# Pin the random state so the layout does not depend on which cells ran before.
tsne = TSNE(random_state=seed)
X_2d = tsne.fit_transform(X)

# Build the key -> row lookup once; calling list(...).index(...) per test
# compound rebuilt the key list and scanned it linearly every time.
smiles_row_by_key = {key: row for row, key in enumerate(smiles_embeddings_dict)}
test_indices = [smiles_row_by_key[smiles] for smiles in smiles_test_unique]

with plt.style.context("seaborn-v0_8-whitegrid"):
    plt.figure(figsize=(10, 8))
    # First layer draws every embedding in the dictionary; the test compounds
    # are then drawn on top in orange.
    sns.scatterplot(x=X_2d[:, 0], y=X_2d[:, 1], label="Train Embeddings")
    sns.scatterplot(
        x=X_2d[test_indices, 0],
        y=X_2d[test_indices, 1],
        color="orange",
        label="Test Embeddings",
    )

    plt.legend()
    plt.show()

## Labels Distributions

In [None]:
# Both splits share the SELFIES embeddings and the gene-expression table;
# only the drug-response CSV differs. The test split is built first.
test_dataset, train_dataset = (
    DrugResponseDataset(
        cpd_embeddings_path=selfies_embeddings_path,
        ccl_ge_path=ccl_ge_path,
        drp_path=split_path,
        cpd_type="selfies",
    )
    for split_path in (drp_test_path, drp_train_path)
)

# Each call returns (all labels, mean label per compound).
test_labels_pair = extract_labels_parallel(dataset=test_dataset)
test_all_labels, test_avg_label_values_per_compound = test_labels_pair

train_labels_pair = extract_labels_parallel(dataset=train_dataset)
train_all_labels, train_avg_label_values_per_compound = train_labels_pair

### Per-compound distribution of label values

In [None]:
# Overlay the per-compound mean labels of the test and train splits.
plot_ground_truth_distribution(
    ground_truth_values=test_avg_label_values_per_compound,
    ground_truth_values2=train_avg_label_values_per_compound,
    labels=["Test Set", "Train Set"],
    title=None,
    bins=30,
    show_kde=False,
    fill_alpha=0.3,
    figsize=(10, 6),
    legend_loc="upper left",
)

In [None]:
# Percentage of training compounds whose mean label lies in [0, 0.4).
percentage_values_in_range(
    values=train_avg_label_values_per_compound, low=0, high=0.4
)

### Cross-compound distribution of label values

In [None]:
# Overlay the individual (not per-compound-averaged) labels of both splits.
plot_ground_truth_distribution(
    ground_truth_values=test_all_labels,
    ground_truth_values2=train_all_labels,
    labels=["Test Set", "Train Set"],
    title=None,
    bins=30,
    show_kde=False,
    fill_alpha=0.3,
    figsize=(10, 6),
    legend_loc="upper left",
)

## I. Cross-Compounds Evaluation

### SELFIES-based Trained Model Evaluation

In [None]:
# Test split paired with the SELFIES compound embeddings.
dataset = DrugResponseDataset(
    drp_path=drp_test_path,
    ccl_ge_path=ccl_ge_path,
    cpd_embeddings_path=selfies_embeddings_path,
    cpd_type="selfies",
)

# Unshuffled, so predictions come back in dataset order.
test_loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=8)

len(dataset)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
selfies_checkpoint = "checkpoints/selfies_model_20231127-053956/best_selfies_model_20231127-053956_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
selfies_state_dict = torch.load(selfies_checkpoint, map_location=device)

# The architecture must match the checkpoint: load_state_dict raises on
# missing or unexpected keys.
selfies_model = DrugResponseModelLegacy(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=6,
    transformer_layers=6,
).to(device)

selfies_model.load_state_dict(selfies_state_dict)

In [None]:
# Order matters: the printing cell indexes the results list by position.
metrics = [
    mean_squared_error,  # results[0]
    pearson_corrcoef,  # results[1]
    spearman_corrcoef,  # results[2]
    mean_absolute_error,  # results[3]
]
selfies_results, preds, labels = compute_eval_metrics(
    model=selfies_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

In [None]:
# compute_eval_metrics reports MSE as (mean, variance) over the batch-level values.
selfies_mse_mean, selfies_mse_var = selfies_results[0]
# Show the spread as a standard deviation: a variance after "+/-" is in squared
# units and cannot be compared with the mean.
print(
    f"Mean Squared Error: {selfies_mse_mean:.4f} +/- {np.sqrt(selfies_mse_var):.4f}"
)
print(f"Mean Absolute Error: {selfies_results[3]:.4f}")
print(f"Pearson Correlation: {selfies_results[1]:.4f}")
print(f"Spearman Correlation: {selfies_results[2]:.4f}")

### SMILES-based Trained Model Evaluation

In [None]:
# Test split paired with the SMILES compound embeddings.
dataset = DrugResponseDataset(
    drp_path=drp_test_path,
    ccl_ge_path=ccl_ge_path,
    cpd_embeddings_path=smiles_embeddings_path,
    cpd_type="smiles",
)

# Unshuffled, so predictions come back in dataset order.
test_loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=8)

len(dataset)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
smiles_checkpoint = "checkpoints/smiles_model_20231123-190031/best_smiles_model_20231123-190031_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
smiles_state_dict = torch.load(smiles_checkpoint, map_location=device)

In [None]:
# Architecture settings for the SMILES model.
smiles_model_config = dict(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=6,
    transformer_layers=6,
)
smiles_model = DrugResponseModelLegacy(**smiles_model_config).to(device)

# Restore the trained weights; the returned key report is displayed by the cell.
smiles_model.load_state_dict(smiles_state_dict)

In [None]:
# Order matters: the printing cell indexes the results list by position.
metrics = [
    mean_squared_error,  # results[0]
    pearson_corrcoef,  # results[1]
    spearman_corrcoef,  # results[2]
    mean_absolute_error,  # results[3]
]
smiles_results, preds, labels = compute_eval_metrics(
    model=smiles_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

In [None]:
# compute_eval_metrics reports MSE as (mean, variance) over the batch-level values.
smiles_mse_mean, smiles_mse_var = smiles_results[0]
# Show the spread as a standard deviation: a variance after "+/-" is in squared
# units and cannot be compared with the mean.
print(f"Mean Squared Error: {smiles_mse_mean:.4f} +/- {np.sqrt(smiles_mse_var):.4f}")
print(f"Mean Absolute Error: {smiles_results[3]:.4f}")
print(f"Pearson Correlation: {smiles_results[1]:.4f}")
print(f"Spearman Correlation: {smiles_results[2]:.4f}")

### SMILES Tokens-based Trained Model Evaluation

In [None]:
# Test split paired with the SMILES token embeddings (embed_tokens enabled).
dataset = DrugResponseDataset(
    drp_path=drp_test_path,
    ccl_ge_path=ccl_ge_path,
    cpd_embeddings_path=smiles_tokens_embeddings_path,
    cpd_type="smiles",
    embed_tokens=True,
)

# Unshuffled, so predictions come back in dataset order.
test_loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=8)

len(dataset)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
smiles_tokens_checkpoint = "checkpoints/smiles_tokens_model_20231204-183804/best_smiles_tokens_model_20231204-183804_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
smiles_tokens_state_dict = torch.load(smiles_tokens_checkpoint, map_location=device)

# The architecture must match the checkpoint: load_state_dict raises on
# missing or unexpected keys.
smiles_tokens_model = DrugResponseModelTokens(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=4,
    transformer_layers=4,
).to(device)

smiles_tokens_model.load_state_dict(smiles_tokens_state_dict)

In [None]:
# Order matters: the printing cell indexes the results list by position.
metrics = [
    mean_squared_error,  # results[0]
    pearson_corrcoef,  # results[1]
    spearman_corrcoef,  # results[2]
    mean_absolute_error,  # results[3]
]
smiles_tokens_results, smiles_tokens_preds, smiles_tokens_labels = compute_eval_metrics(
    model=smiles_tokens_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

In [None]:
# compute_eval_metrics reports MSE as (mean, variance) over the batch-level values.
smiles_tokens_mse_mean, smiles_tokens_mse_var = smiles_tokens_results[0]
# Show the spread as a standard deviation: a variance after "+/-" is in squared
# units and cannot be compared with the mean.
print(
    f"Mean Squared Error: {smiles_tokens_mse_mean:.4f} +/- {np.sqrt(smiles_tokens_mse_var):.4f}"
)
print(f"Mean Absolute Error: {smiles_tokens_results[3]:.4f}")
print(f"Pearson Correlation: {smiles_tokens_results[1]:.4f}")
print(f"Spearman Correlation: {smiles_tokens_results[2]:.4f}")

### Morgan Fingerprints-based Trained Model Evaluation

In [None]:
# Test split paired with the SMILES fingerprint embeddings.
dataset = DrugResponseDataset(
    drp_path=drp_test_path,
    ccl_ge_path=ccl_ge_path,
    cpd_embeddings_path=smiles_fingerprints_embeddings_path,
    cpd_type="smiles",
)

# Unshuffled, so predictions come back in dataset order.
test_loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=8)

len(dataset)

In [None]:
# map_location restores the weights onto the evaluation device whatever device
# the checkpoint was saved from.
model_pl = DrugResponseLightningModule.load_from_checkpoint(
    "checkpoints/smiles_fingerprints_model_20231212-073658/epoch=119-val_loss=0.0003.ckpt",
    map_location=device,
)
# The evaluation helpers move every batch to `device`, so the wrapped network
# must be there too, like the other models in this notebook.
fingerprints_model = model_pl.model.to(device)

In [None]:
# Order matters: the printing cell indexes the results list by position.
metrics = [
    mean_squared_error,  # results[0]
    pearson_corrcoef,  # results[1]
    spearman_corrcoef,  # results[2]
    mean_absolute_error,  # results[3]
]
fingerprints_eval = compute_eval_metrics(
    model=fingerprints_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)
(
    smiles_fingerprints_results,
    smiles_fingerprints_preds,
    smiles_fingerprints_labels,
) = fingerprints_eval

In [None]:
# compute_eval_metrics reports MSE as (mean, variance) over the batch-level values.
fingerprints_mse_mean, fingerprints_mse_var = smiles_fingerprints_results[0]
# Show the spread as a standard deviation: a variance after "+/-" is in squared
# units and cannot be compared with the mean.
print(
    f"Mean Squared Error: {fingerprints_mse_mean:.4f} +/- {np.sqrt(fingerprints_mse_var):.4f}"
)
print(f"Mean Absolute Error: {smiles_fingerprints_results[3]:.4f}")
print(f"Pearson Correlation: {smiles_fingerprints_results[1]:.4f}")
print(f"Spearman Correlation: {smiles_fingerprints_results[2]:.4f}")

## II. Per-compound evaluation

### SELFIES-based Trained Model Evaluation

In [None]:
# SELFIES datasets for both splits; only the drug-response CSV differs.
# The test split is built first.
test_dataset, train_dataset = (
    DrugResponseDataset(
        cpd_embeddings_path=selfies_embeddings_path,
        ccl_ge_path=ccl_ge_path,
        drp_path=split_path,
        cpd_type="selfies",
    )
    for split_path in (drp_test_path, drp_train_path)
)

# Identical, unshuffled loader settings for both splits.
test_loader, train_loader = (
    DataLoader(split_dataset, batch_size=256, shuffle=False, num_workers=8)
    for split_dataset in (test_dataset, train_dataset)
)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
selfies_checkpoint = "checkpoints/selfies_model_20231127-053956/best_selfies_model_20231127-053956_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
selfies_state_dict = torch.load(selfies_checkpoint, map_location=device)

# The architecture must match the checkpoint: load_state_dict raises on
# missing or unexpected keys.
selfies_model = DrugResponseModelLegacy(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=6,
    transformer_layers=6,
).to(device)

selfies_model.load_state_dict(selfies_state_dict)

In [None]:
metrics = [mean_squared_error, pearson_corrcoef, spearman_corrcoef]

# Each call returns (metrics per compound, prediction statistics per compound).
(
    test_selfies_results,
    test_per_compound_preds_selfies,
) = compute_per_compound_metrics(
    selfies_model, test_loader, metrics, device, return_predictions=True
)

(
    train_selfies_results,
    train_per_compound_preds_selfies,
) = compute_per_compound_metrics(
    selfies_model, train_loader, metrics, device, return_predictions=True
)
# Backward-compatible alias for the original, misspelled variable name.
train_seflies_results = train_selfies_results

In [None]:
# Per-compound mean prediction vs mean label on the test split (SELFIES model).
plot_predictions_vs_ground_truth(
    preds_dict=test_per_compound_preds_selfies,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

In [None]:
# Per-compound mean prediction vs mean label on the train split (SELFIES model).
plot_predictions_vs_ground_truth(
    preds_dict=train_per_compound_preds_selfies,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

### SMILES-based Trained Model Evaluation

In [None]:
# SMILES datasets for both splits; only the drug-response CSV differs.
# The test split is built first.
test_dataset, train_dataset = (
    DrugResponseDataset(
        cpd_embeddings_path=smiles_embeddings_path,
        ccl_ge_path=ccl_ge_path,
        drp_path=split_path,
        cpd_type="smiles",
    )
    for split_path in (drp_test_path, drp_train_path)
)

# Identical, unshuffled loader settings for both splits.
test_loader, train_loader = (
    DataLoader(split_dataset, batch_size=256, shuffle=False, num_workers=8)
    for split_dataset in (test_dataset, train_dataset)
)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
smiles_checkpoint = "checkpoints/smiles_model_20231123-190031/best_smiles_model_20231123-190031_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
smiles_state_dict = torch.load(smiles_checkpoint, map_location=device)

# The architecture must match the checkpoint: load_state_dict raises on
# missing or unexpected keys.
smiles_model = DrugResponseModelLegacy(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=6,
    transformer_layers=6,
).to(device)

smiles_model.load_state_dict(smiles_state_dict)

In [None]:
metrics = [mean_squared_error, pearson_corrcoef, spearman_corrcoef]

# Each call returns (metrics per compound, prediction statistics per compound).
test_smiles_results, test_per_compound_preds_smiles = compute_per_compound_metrics(
    model=smiles_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

train_smiles_results, train_per_compound_preds_smiles = compute_per_compound_metrics(
    model=smiles_model,
    test_loader=train_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

In [None]:
# Per-compound mean prediction vs mean label on the test split (SMILES model).
plot_predictions_vs_ground_truth(
    preds_dict=test_per_compound_preds_smiles,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

In [None]:
# Per-compound mean prediction vs mean label on the train split (SMILES model).
plot_predictions_vs_ground_truth(
    preds_dict=train_per_compound_preds_smiles,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)


### Token-based model

In [None]:
# SMILES-token datasets for both splits; only the drug-response CSV differs.
# The test split is built first.
test_dataset, train_dataset = (
    DrugResponseDataset(
        cpd_embeddings_path=smiles_tokens_embeddings_path,
        ccl_ge_path=ccl_ge_path,
        drp_path=split_path,
        cpd_type="smiles",
        embed_tokens=True,
    )
    for split_path in (drp_test_path, drp_train_path)
)

# Identical, unshuffled loader settings for both splits.
test_loader, train_loader = (
    DataLoader(split_dataset, batch_size=256, shuffle=False, num_workers=8)
    for split_dataset in (test_dataset, train_dataset)
)

In [None]:
# The path is a single string; wrapping it in a one-argument os.path.join was a no-op.
smiles_tokens_checkpoint = "checkpoints/smiles_tokens_model_20231204-183804/best_smiles_tokens_model_20231204-183804_trained.pt"

# map_location remaps the stored tensors onto the evaluation device, so the
# notebook's CPU fallback still works for a checkpoint saved on a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
smiles_tokens_state_dict = torch.load(smiles_tokens_checkpoint, map_location=device)

# The architecture must match the checkpoint: load_state_dict raises on
# missing or unexpected keys.
smiles_tokens_model = DrugResponseModelTokens(
    cpd_sequence_length=256,
    cpd_embedding_dim=768,
    ccl_embedding_dim=6136,
    hidden_dim=1020,
    transformer_heads=4,
    transformer_layers=4,
).to(device)

smiles_tokens_model.load_state_dict(smiles_tokens_state_dict)

In [None]:
metrics = [mean_squared_error, pearson_corrcoef, spearman_corrcoef]

# Each call returns (metrics per compound, prediction statistics per compound).
test_smiles_tokens_results, test_per_compound_preds_tokens = compute_per_compound_metrics(
    model=smiles_tokens_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

train_smiles_tokens_results, train_per_compound_preds_tokens = compute_per_compound_metrics(
    model=smiles_tokens_model,
    test_loader=train_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)

In [None]:
# Per-compound mean prediction vs mean label on the test split (token model).
plot_predictions_vs_ground_truth(
    preds_dict=test_per_compound_preds_tokens,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

In [None]:
# Per-compound mean prediction vs mean label on the train split (token model).
plot_predictions_vs_ground_truth(
    preds_dict=train_per_compound_preds_tokens,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)


### Morgan Fingerprints based model

In [None]:
# Fingerprint datasets for both splits; only the drug-response CSV differs.
# The test split is built first.
test_dataset, train_dataset = (
    DrugResponseDataset(
        cpd_embeddings_path=smiles_fingerprints_embeddings_path,
        ccl_ge_path=ccl_ge_path,
        drp_path=split_path,
        cpd_type="smiles",
    )
    for split_path in (drp_test_path, drp_train_path)
)

# Identical, unshuffled loader settings for both splits.
test_loader, train_loader = (
    DataLoader(split_dataset, batch_size=256, shuffle=False, num_workers=8)
    for split_dataset in (test_dataset, train_dataset)
)

In [None]:
# map_location restores the weights onto the evaluation device whatever device
# the checkpoint was saved from.
model_pl = DrugResponseLightningModule.load_from_checkpoint(
    "checkpoints/smiles_fingerprints_model_20231212-073658/epoch=119-val_loss=0.0003.ckpt",
    map_location=device,
)
# The evaluation helpers move every batch to `device`, so the wrapped network
# must be there too, like the other models in this notebook.
fingerprints_model = model_pl.model.to(device)

In [None]:
metrics = [mean_squared_error, pearson_corrcoef, spearman_corrcoef]

# Each call returns (metrics per compound, prediction statistics per compound).
fingerprints_test_eval = compute_per_compound_metrics(
    model=fingerprints_model,
    test_loader=test_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)
(
    test_smiles_fingerprints_results,
    test_per_compound_preds_fingerprints,
) = fingerprints_test_eval

fingerprints_train_eval = compute_per_compound_metrics(
    model=fingerprints_model,
    test_loader=train_loader,
    metrics=metrics,
    device=device,
    return_predictions=True,
)
(
    train_smiles_fingerprints_results,
    train_per_compound_preds_fingerprints,
) = fingerprints_train_eval

In [None]:
# Per-compound mean prediction vs mean label on the test split (fingerprint model).
plot_predictions_vs_ground_truth(
    preds_dict=test_per_compound_preds_fingerprints,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

In [None]:
# Per-compound mean prediction vs mean label on the train split (fingerprint model).
plot_predictions_vs_ground_truth(
    preds_dict=train_per_compound_preds_fingerprints,
    xlabel='Mean Ground Truth AUC Value',
    ylabel='Mean Predicted AUC Value',
    title=None,
    fontsize=14,
)

## III. Training and Validation Curves

Let's plot all the training and validation curves on the same plots.

In [None]:
# --- SELFIES model: two logged runs, concatenated in chronological file order ---
selfies_training_loss_1 = "data/training_results/selfies_selfies_model_20231124-092533_selfies_model_version_0.csv"
selfies_training_loss_2 = "data/training_results/selfies_selfies_model_20231127-053956_selfies_model_version_0.csv"
selfies_training_loss_1_df, selfies_training_loss_2_df = (
    pd.read_csv(csv_path)
    for csv_path in (selfies_training_loss_1, selfies_training_loss_2)
)
selfies_training_loss_df = pd.concat(
    [selfies_training_loss_1_df, selfies_training_loss_2_df]
)

selfies_val_loss_1 = "data/training_results/selfies_selfies_model_20231124-092533_selfies_model_version_0_val.csv"
selfies_val_loss_2 = "data/training_results/selfies_selfies_model_20231127-053956_selfies_model_version_0_val.csv"
selfies_val_loss_1_df, selfies_val_loss_2_df = (
    pd.read_csv(csv_path) for csv_path in (selfies_val_loss_1, selfies_val_loss_2)
)
selfies_val_loss_df = pd.concat([selfies_val_loss_1_df, selfies_val_loss_2_df])

# --- SMILES model ---
smiles_training_loss = "data/training_results/smiles_smiles_model_20231123-190031_smiles_model_version_0.csv"
smiles_val_loss = "data/training_results/smiles_smiles_model_20231123-190031_smiles_model_version_0_val.csv"
smiles_training_loss_df, smiles_val_loss_df = (
    pd.read_csv(csv_path) for csv_path in (smiles_training_loss, smiles_val_loss)
)

# --- SMILES tokens model ---
smiles_tokens_training_loss = "data/training_results/smiles_tokens_smiles_tokens_model_20231204-183804_smiles_tokens_model_version_0.csv"
smiles_tokens_val_loss = "data/training_results/smiles_tokens_smiles_tokens_model_20231204-183804_smiles_tokens_model_version_0_val.csv"
smiles_tokens_training_loss_df, smiles_tokens_val_loss_df = (
    pd.read_csv(csv_path)
    for csv_path in (smiles_tokens_training_loss, smiles_tokens_val_loss)
)

# --- Fingerprints model ---
fingerprints_training_loss = "data/training_results/fingerprints_smiles_fingerprints_model_20231212-073658_smiles_fingerprints_model_version_0.csv"
fingerprints_val_loss = "data/training_results/fingerprints_smiles_fingerprints_model_20231212-073658_smiles_fingerprints_model_version_0_val.csv"
fingerprints_training_loss_df, fingerprints_val_loss_df = (
    pd.read_csv(csv_path)
    for csv_path in (fingerprints_training_loss, fingerprints_val_loss)
)

In [None]:
# Training-loss curves of the four models on shared axes.
with plt.style.context("seaborn-v0_8-whitegrid"):
    plt.figure(figsize=(10, 6))

    training_curves = (
        (selfies_training_loss_df, "SELFIES Embeddings"),
        (smiles_training_loss_df, "SMILES Embeddings"),
        (smiles_tokens_training_loss_df, "SMILES Tokens"),
        (fingerprints_training_loss_df, "Morgan Fingerprints"),
    )
    for loss_df, curve_label in training_curves:
        plt.plot(loss_df["Step"], loss_df["Value"], label=curve_label)

    plt.xlabel("Steps", fontsize=14)
    plt.ylabel("MSE Loss", fontsize=14)
    # Fixed y-range so this figure and the validation one share a scale.
    plt.ylim(0, 0.06)
    plt.legend(fontsize=14)
    plt.show()

In [None]:
# Validation-loss curves of the four models on shared axes.
with plt.style.context("seaborn-v0_8-whitegrid"):
    plt.figure(figsize=(10, 6))

    validation_curves = (
        (selfies_val_loss_df, "SELFIES Embeddings"),
        (smiles_val_loss_df, "SMILES Embeddings"),
        (smiles_tokens_val_loss_df, "SMILES Tokens"),
        (fingerprints_val_loss_df, "Morgan Fingerprints"),
    )
    for loss_df, curve_label in validation_curves:
        plt.plot(loss_df["Step"], loss_df["Value"], label=curve_label)

    plt.xlabel("Steps", fontsize=14)
    plt.ylabel("MSE Loss", fontsize=14)
    # Fixed y-range so this figure and the training one share a scale.
    plt.ylim(0, 0.06)
    plt.legend(fontsize=14)
    plt.show()