In [None]:
# Reload imported project modules automatically so edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import sys
import torch
from data import datasets
from training import trainer
from modules import bidirectional_ind_generator
from analysis import similarity

import plots, utilities
from IPython.display import display, Markdown
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


In [None]:

# Load the extended ChEMBL indications dataset with a batch size of 128.
# Alternative previously used here: datasets.CharSMILESChEMBLIndications(batch_size=128)
dataset = datasets.ChEMBLIndicationsExtended(batch_size=128)

# Model training

In [None]:
# Hyperparameters common to both configurations; the two variants below differ only
# in depth (num_layers) and optimiser settings (learning_rate, weight_decay).
shared_model_kwargs = dict(
    vocab_size=len(dataset.vocab),
    num_indications=dataset.num_indications,
    num_hiddens=256,
    output_dropout=0.4,
    rnn_dropout=0.4,
    state_dropout=0.4,
)

# Full-size configuration.
config = bidirectional_ind_generator.SmilesBidirectionalIndGeneratorRNNConfig(
    **shared_model_kwargs,
    num_layers=4,
    learning_rate=1e-3,
    weight_decay=1e-5,
)

# Smaller configuration. Produces pretty bad models, but fine for debugging.
config_mini = bidirectional_ind_generator.SmilesBidirectionalIndGeneratorRNNConfig(
    **shared_model_kwargs,
    num_layers=2,
    learning_rate=1e-4,
    weight_decay=1e-3,
)

# NOTE(review): the debugging configuration is currently selected. The training cell
# saves the trained model by default, so set this to False before a run whose weights
# you intend to keep.
use_mini_config = True
model = bidirectional_ind_generator.SmilesBidirectionalIndGeneratorRNN(
    config_mini if use_mini_config else config
)

In [None]:

# Either train a fresh model (optionally saving its weights) or restore saved weights.
# NOTE(review): this checkpoint's filename mentions SmilesIndGeneratorRNN /
# CharSMILESChEMBLIndications, not the bidirectional model and extended dataset built
# above; confirm the state dict matches the current architecture before loading it.
load_model_path = "../models/Chembl-Ind-SmilesIndGeneratorRNN-CharSMILESChEMBLIndications-2025-07-31-08-01-18.pt"

train_new = True  # Set False to load the pre-trained model at load_model_path instead
save_model = True  # If training a new model, whether to save its weights

# Resolve the device before branching so `device` is defined on both paths.
device = "cuda" if torch.cuda.is_available() else "cpu"

if train_new:
    model_trainer = trainer.Trainer(max_epochs=64, init_random=None, clip_grads_norm=10.0)
    model_trainer.fit(model, dataset)

    if save_model:
        utilities.save_model_weights("Chembl-Ind-", model, dataset)
else:
    # map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
    model.load_state_dict(
        state_dict=torch.load(load_model_path, weights_only=True, map_location=device),
    )
    model.to(device)

In [None]:
# Plot training vs. validation loss with a log-scaled y-axis. Training metadata only
# exists when a model was trained in this session, so the plot is skipped otherwise.
if train_new:
    losses = utilities.extract_training_losses(model_trainer.metadata)
    fig, ax = plots.plot_training_validation_loss(
        training_losses=losses["avg_train_losses"],
        validation_losses=losses["avg_val_losses"],
    )
    ax.set_yscale("log")

In [None]:
# Sanity check on a single dataset item: run one forward pass and decode the
# greedy (argmax) prediction.
test_batch = dataset[20]

# Use whichever device the model's parameters live on instead of assuming CUDA,
# so this cell also runs on CPU-only machines.
model_device = next(model.parameters()).device

# Inference only: disable dropout and gradient tracking, then restore the model's
# previous train/eval mode so a later re-run of the training cell is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output, _ = model((
            test_batch[0].unsqueeze(0).to(device=model_device),
            test_batch[1].unsqueeze(0).to(device=model_device),
        ))
finally:
    model.train(was_training)

# Greedy decode once and reuse it for both printouts below.
predicted_tokens = output.argmax(dim=-1).squeeze().cpu().numpy()

# Report the output's shape rather than dumping the full logits tensor.
print(f"Output shape: {tuple(output.shape)}")

print(f"(Seq Len, vocab size): {test_batch[0].shape}, (Indications): {test_batch[1].shape}, (Seq Len): {test_batch[2].shape}")
print(f"Input SMILES: {dataset.vocab.decode_tokens(test_batch[2].cpu().numpy())}")

print(f"Prediction: {dataset.vocab.decode_tokens(predicted_tokens)}")
print(f"Prediction Encoded: {predicted_tokens}")

In [None]:
# Sample a grid of molecules conditioned on a single indication, then render them
# under a heading named after that indication.

# Alternative: indication = "mesh_heading_Infections"
indication = "mesh_heading_Bacterial Infections"
# Drop the "mesh_heading_" column prefix to get a readable label for the heading/captions.
indication_name = indication.replace("mesh_heading_", "")

outputs = utilities.generate_samples(
    dataset=dataset,
    model=model,
    indication_name=indication,  # full name, prefix included
    prompt=dataset.vocab.bos.char,  # prompt is only the beginning-of-sentence special character
    rows=5,
    cols=5,
    max_attempts=10,
    max_generate=100,
    temperature=0.7,
)

display(Markdown(f"## {indication_name}"))
display(
    utilities.draw_molecules_as_grid_from_smiles(
        canonical_smiles=outputs,
        names=[indication_name] * len(outputs),
    )
)