In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import sys
import torch
from data import datasets
from training import trainer
from modules import ind_generator

import datetime
import plots, utilities
from IPython.display import Image, display, Markdown
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from rdkit import Chem
from rdkit.Chem import Draw

In [None]:
# Build the `CharSMILESChEMBLIndications` dataset, configured with a batch size of 128.
# Every later cell reads its data, vocabulary and indication names from this object.
dataset = datasets.CharSMILESChEMBLIndications(batch_size=128)

## What are some of the most common indications in our dataset?

In [None]:
# Rank the indications by how many rows of the dataset are flagged with them.
# Every column other than "canonical_smiles" is summed as an indication column.
indication_columns = dataset.all_data.drop(columns=["canonical_smiles"])
most_frequent_indications = indication_columns.sum(axis=0).sort_values(ascending=False)
# Ranked column names; later cells index into this list to pick indications to generate for.
most_frequent_indications_names = most_frequent_indications.index.to_list()

bar_positions = np.arange(len(most_frequent_indications))

fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(20, 5))
ax.bar(bar_positions, most_frequent_indications.to_numpy())
ax.set_xticks(bar_positions)
# Assigning the return value keeps the list of tick-label objects out of the cell output.
labels = ax.set_xticklabels(
    [heading.replace("mesh_heading_", "") for heading in most_frequent_indications_names],
    rotation=90,
)
# Log scale so that rarely used indications remain visible next to the most frequent ones.
ax.set_yscale("log")

# A figure should be readable on its own when the notebook is skimmed.
ax.set_xlabel("Indication (MeSH heading)")
ax.set_ylabel("Number of molecules (log scale)")
ax.set_title("Indication frequency in the dataset");

## What do the molecules look like for our most common indication?

In [None]:
# Draw a small grid of example structures for each of the most frequent indications.
rows, cols = 3, 3          # grid layout; at most rows * cols molecules are drawn per indication
n_top_indications = 10     # how many of the most frequent indications to illustrate

for indications_filter_index in range(n_top_indications):
    indications_filter_name = most_frequent_indications.index[indications_filter_index]

    # Keep only the rows flagged for this indication (the indication column is used as a
    # boolean row mask), then relabel the SMILES column with the indication name.
    filtered_molecules = dataset.all_data.filter(items=["canonical_smiles", indications_filter_name])
    filtered_molecules = filtered_molecules[filtered_molecules[indications_filter_name]].drop(columns=[indications_filter_name])
    filtered_molecules = filtered_molecules.rename(columns={"canonical_smiles": indications_filter_name})

    example_smiles = filtered_molecules[indications_filter_name].iloc[:rows * cols]
    flat_mols = [Chem.MolFromSmiles(smiles) for smiles in example_smiles]
    flat_names = [str(position + 1) for position in range(len(flat_mols))]

    # Chunk by slicing rather than by indexing: slicing tolerates an indication with fewer
    # than rows * cols molecules, where indexing mols[i + j] would raise IndexError.
    mols = [flat_mols[start:start + rows] for start in range(0, len(flat_mols), rows)]
    names = [flat_names[start:start + rows] for start in range(0, len(flat_names), rows)]

    # Report the number of structures actually drawn (equals rows * cols when enough exist).
    display(Markdown(f"## {len(flat_mols)} example chemical structures for {indications_filter_name}"))
    if flat_mols:
        display(Draw.MolsMatrixToGridImage(mols, legendsMatrix=names))

# Model training

In [None]:
# Tunable hyperparameters, grouped so they can be read and edited in one place.
model_hyperparameters = dict(
    num_hiddens=256,
    num_layers=5,
    learning_rate=1e-3,
    weight_decay=1e-5,
    output_dropout=0.4,
    rnn_dropout=0.4,
    state_dropout=0.4,
)

# The vocabulary size and number of indications are taken from the dataset so the
# model always matches the data it is used with.
model = ind_generator.SmilesIndGeneratorRNN(
    vocab_size=dataset.vocab_size,
    num_indications=dataset.num_indications,
    **model_hyperparameters,
)

In [None]:

# Checkpoint that is loaded when `train_new` is False.
load_model_path = "../models/Chembl-Ind-SmilesIndGeneratorRNN-CharSMILESChEMBLIndications-2025-07-30-17-50-04.pt"

train_new = False  # Set to False to load a pre-trained model instead of training one
save_model = True  # If training a new model, do we want to save it?

if train_new:
    model_trainer = trainer.Trainer(max_epochs=128, init_random=None, clip_grads_norm=10.0)
    model_trainer.fit(model, dataset)
    # Define `device` on this branch too, so later cells can rely on it whichever path ran.
    # Assumes the trainer leaves the model on `model_trainer.device` - TODO confirm.
    device = model_trainer.device

    if save_model:
        utilities.save_model_weights("Chembl-Ind-", model, dataset)
else:
    model.load_weights(
        path=load_model_path,
    )
    # Fall back to the CPU when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

In [None]:
# Loss curves exist only when the model was trained in this session;
# when a checkpoint was loaded there is no trainer metadata to plot.
if train_new:
    losses = utilities.extract_training_losses(model_trainer.metadata)
    train_curve = losses["avg_train_losses"]
    val_curve = losses["avg_val_losses"]

    fig, ax = plots.plot_training_validation_loss(
        training_losses=train_curve,
        validation_losses=val_curve,
    )
    ax.set_yscale("log")

In [None]:

# Smoke test: push a single dataset item through the model and compare the
# arg-max prediction with the reference sequence stored in the item.
test_batch = dataset[20]

# Use the device the model was placed on instead of hard-coding "cuda", which
# fails on CPU-only machines even though the loading cell falls back to "cpu".
run_device = model_trainer.device if train_new else device

# unsqueeze(0) adds the leading batch dimension for this single example.
output, _ = model(
    test_batch[0].unsqueeze(0).to(device=run_device),
    test_batch[1].unsqueeze(0).to(device=run_device),
)

print(f"(Seq Len, vocab size): {test_batch[0].shape}, (Indications): {test_batch[1].shape}, (Seq Len): {test_batch[2].shape}")
print(f"Input SMILES: {''.join([dataset.idx_to_char[c] for c in test_batch[2].cpu().numpy()])}")
print(f"Prediction: {''.join([dataset.idx_to_char[c] for c in output.argmax(dim=-1).squeeze().cpu().numpy()])}")

# Let's generate some medications

In [None]:
# Generate one molecule for each of the (rows * cols) most frequent indications.
rows, cols = 3, 3

n_chars = 100   # not read by this cell (generation uses num_chars=500); kept so notebook globals are unchanged
n_valid = 0     # number of generations that produced a usable SMILES string
n_invalid = 0   # number of generations that failed after all attempts
images = []     # not read by this cell; kept so notebook globals are unchanged
outputs = []    # one entry per indication: the generated SMILES, or None on failure

# Get the indices for the (rows * cols) most common drug indications
mesh_indices = [dataset.indications_names.index(indication_name) for indication_name in most_frequent_indications_names[:rows * cols]]

# Use the device the model lives on instead of hard-coding "cuda", which fails on CPU-only machines.
generation_device = model_trainer.device if train_new else device

# Build the training-set lookup once; rebuilding a list inside the loop made every
# membership test a full scan of the dataset.
training_smiles = set(dataset.all_data["canonical_smiles"])

# Generation does not always yield a valid molecule, so `robust_generate` is allowed
# up to `max_attempts` tries per indication.
max_attempts = 5
for idx in mesh_indices:

    output = ind_generator.robust_generate(
        ind_generator.simple_generate,
        max_attempts=max_attempts,
        prefix="",
        indications_tensor=dataset.get_indications_tensor(dataset.indications_names[idx]).to(generation_device),
        num_chars=500,
        model=model,
        char_to_idx_mapping=dataset.char_to_idx,
        idx_to_char_mapping=dataset.idx_to_char,
        temperature=0.7,
        device=generation_device,
    )

    if output:
        # Warn if this exact output is in the training set
        if output in training_smiles:
            print("\n WARNING: Exact output found in training, overfitting? \n")
        outputs.append(output)
        n_valid += 1
    else:
        # A falsy result is treated as a failed generation; keep a None placeholder so
        # `outputs` stays aligned with `mesh_indices`.
        n_invalid += 1
        print("Generated SMILES is not valid.")
        outputs.append(None)

In [None]:
# Draw the generated molecules in a grid, labelled with the indication they were generated for.
mols = []
names = []
for i, smiles in enumerate(outputs):
    # Failed generations are stored as None; Chem.MolFromSmiles(None) raises, so keep a
    # None placeholder instead, which leaves that grid cell empty.
    mols.append(Chem.MolFromSmiles(smiles) if smiles else None)
    names.append(dataset.indications_names[mesh_indices[i]].replace("mesh_heading_", ""))

# Split the flat lists into grid rows of `rows` entries each.
mols = [mols[start:start + rows] for start in range(0, len(mols), rows)]
names = [names[start:start + rows] for start in range(0, len(names), rows)]

display(Draw.MolsMatrixToGridImage(mols, legendsMatrix=names))

## Sanity check... "mesh heading other" should generate essentially random molecules

In [None]:
# Generate (rows * cols) molecules that are all conditioned on the same indication.
rows, cols = 3, 3

n_chars = 100   # not read by this cell (generation uses num_chars=500); kept so notebook globals are unchanged
n_valid = 0     # number of generations that produced a usable SMILES string
n_invalid = 0   # number of generations that failed after all attempts
images = []     # not read by this cell; kept so notebook globals are unchanged
outputs = []    # one entry per grid cell: the generated SMILES, or None on failure

# Repeat the index of the fifth most frequent indication once per grid cell.
# NOTE(review): the section heading describes this as the "other" category - confirm that
# most_frequent_indications_names[4] really is that column for the current dataset.
mesh_indices = [dataset.indications_names.index(most_frequent_indications_names[4])] * (rows * cols)

# `model_trainer` only exists when a model was trained in this session, so resolve the
# device in a way that also works when a checkpoint was loaded.
generation_device = model_trainer.device if train_new else device

# Build the training-set lookup once; rebuilding a list inside the loop made every
# membership test a full scan of the dataset.
training_smiles = set(dataset.all_data["canonical_smiles"])

# Generation does not always yield a valid molecule, so `robust_generate` is allowed
# up to `max_attempts` tries per molecule.
max_attempts = 5
for idx in mesh_indices:

    output = ind_generator.robust_generate(
        ind_generator.simple_generate,
        max_attempts=max_attempts,
        prefix="",
        indications_tensor=dataset.get_indications_tensor(dataset.indications_names[idx]).to(generation_device),
        num_chars=500,
        model=model,
        char_to_idx_mapping=dataset.char_to_idx,
        idx_to_char_mapping=dataset.idx_to_char,
        temperature=0.7,
        device=generation_device,
    )

    if output:
        # Warn if this exact output is in the training set
        if output in training_smiles:
            print("\n WARNING: Exact output found in training, overfitting? \n")
        outputs.append(output)
        n_valid += 1
    else:
        # A falsy result is treated as a failed generation; keep a None placeholder so
        # `outputs` stays aligned with `mesh_indices`.
        n_invalid += 1
        print("Generated SMILES is not valid.")
        outputs.append(None)

In [None]:
# Draw the generated molecules in a grid, labelled with the indication they were generated for.
mols = []
names = []
for i, smiles in enumerate(outputs):
    # Failed generations are stored as None; Chem.MolFromSmiles(None) raises, so keep a
    # None placeholder instead, which leaves that grid cell empty.
    mols.append(Chem.MolFromSmiles(smiles) if smiles else None)
    names.append(dataset.indications_names[mesh_indices[i]].replace("mesh_heading_", ""))

# Split the flat lists into grid rows of `rows` entries each.
mols = [mols[start:start + rows] for start in range(0, len(mols), rows)]
names = [names[start:start + rows] for start in range(0, len(names), rows)]

display(Draw.MolsMatrixToGridImage(mols, legendsMatrix=names))

In [None]:
# Generate (rows * cols) molecules for a single indication, seeding every generation
# with the SMILES prefix "N[C@H]".
rows, cols = 3, 3

n_chars = 100   # not read by this cell (generation uses num_chars=500); kept so notebook globals are unchanged
n_valid = 0     # number of generations that produced a usable SMILES string
n_invalid = 0   # number of generations that failed after all attempts
images = []     # not read by this cell; kept so notebook globals are unchanged
outputs = []    # one entry per grid cell: the generated SMILES, or None on failure

# Repeat the index of the eighth most frequent indication once per grid cell.
# NOTE(review): the previous comment here called this the "other" category, which looks
# copy-pasted from the cell that uses index 4 - confirm which indication index 7 refers to.
mesh_indices = [dataset.indications_names.index(most_frequent_indications_names[7])] * (rows * cols)

# `model_trainer` only exists when a model was trained in this session, so resolve the
# device in a way that also works when a checkpoint was loaded.
generation_device = model_trainer.device if train_new else device

# Build the training-set lookup once; rebuilding a list inside the loop made every
# membership test a full scan of the dataset.
training_smiles = set(dataset.all_data["canonical_smiles"])

# Generation does not always yield a valid molecule, so `robust_generate` is allowed
# up to `max_attempts` tries per molecule.
max_attempts = 5
for idx in mesh_indices:

    output = ind_generator.robust_generate(
        ind_generator.simple_generate,
        max_attempts=max_attempts,
        prefix="N[C@H]",
        indications_tensor=dataset.get_indications_tensor(dataset.indications_names[idx]).to(generation_device),
        num_chars=500,
        model=model,
        char_to_idx_mapping=dataset.char_to_idx,
        idx_to_char_mapping=dataset.idx_to_char,
        temperature=0.7,
        device=generation_device,
    )

    if output:
        # Warn if this exact output is in the training set
        if output in training_smiles:
            print("\n WARNING: Exact output found in training, overfitting? \n")
        outputs.append(output)
        n_valid += 1
    else:
        # A falsy result is treated as a failed generation; keep a None placeholder so
        # `outputs` stays aligned with `mesh_indices`.
        n_invalid += 1
        print("Generated SMILES is not valid.")
        outputs.append(None)

In [None]:
# Draw the generated molecules in a grid, labelled with the indication they were generated for.
mols = []
names = []
for i, smiles in enumerate(outputs):
    # Failed generations are stored as None; Chem.MolFromSmiles(None) raises, so keep a
    # None placeholder instead, which leaves that grid cell empty.
    mols.append(Chem.MolFromSmiles(smiles) if smiles else None)
    names.append(dataset.indications_names[mesh_indices[i]].replace("mesh_heading_", ""))

# Split the flat lists into grid rows of `rows` entries each.
mols = [mols[start:start + rows] for start in range(0, len(mols), rows)]
names = [names[start:start + rows] for start in range(0, len(names), rows)]

display(Draw.MolsMatrixToGridImage(mols, legendsMatrix=names))

In [None]:
# # Fun outputs:

#  'C(C)OCC(=O)OCC(=O)O[C@]1(OC(=O)CC)[C@@H](C)C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(F)[C@@H](O)C[C@@]21C',
#  'C(=O)O[C@H]1C(=O)[C@@]2(C)[C@H]([C@H](OC(=O)c3ccccc3)[C@]3(O)C[C@H](OC(=O)[C@H](O)[C@@H](NC(=O)c4ccccc4)c4ccccc4)C(C)=C1C3(C)C)[C@]1(OC(C)=O)CO[C@@H]1C[C@@H]2O',
#  'N[C@H](CC(C)C)C(=O)N[C@H]1C(=O)N[C@@H](CC(N)=O)C(=O)N[C@H]2C(=O)N[C@H]3C(=O)N[C@H](C(=O)N[C@H](C(=O)O)c4cc(O)cc(O)c4-c4cc3ccc4O)[C@H](O)c3ccc(c(Cl)c3)Oc3cc2cc(c3O[C@@H]2O[C@H](CO)[C@@H](O)[C@H](O)[C@H]2O[C@H]2C[C@](C)(N)[C@H](O)[C@H](C)O2)Oc2ccc(cc2Cl)[C@H]1O',
#  'C(C)(C)NC(=O)[C@H]1CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)[C@H]3C(=O)C[C@@]21C',
#  'N(C)[C@@H]1C(O)=C(C(N)=O)C(=O)[C@@]2(O)C(O)=C3C(=O)c4c(O)cccc4[C@@](C)(O)[C@H]3C[C@@H]12',
#  'C(C)C[C@H](NC(=O)[C@@H](Cc1c[nH]c2ccccc12)NC(=O)[C@H](Cc1c[nH]cn1)NC(=O)[C@@H]1CCC(=O)N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N1CCC[C@H]1C(=O)NNC(N)=O',
#  '[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(F)[C@@H](O)C[C@]2(C)[C@@]1(O)C(=O)COP(=O)([O-])[O-].[Na+].[Na+]',
#  'CCC(CCC)C(=O)O',
#  'C(=O)O[C@@]12CO[C@@H]1C[C@H](O)[C@@]1(C)C(=O)[C@H](O)C3=C(C)[C@@H](OC(=O)[C@H](O)[C@@H](NC(=O)OC(C)(C)C)c4ccccc4)C[C@@](O)([C@@H](OC(=O)c4ccccc4)[C@H]21)C3(C)C.O.O.O'

In [None]:
# Parse one SMILES string and add explicit hydrogens to the resulting molecule;
# `mol` is the molecule embedded in 3D and rendered by the following cells.
mol = Chem.AddHs(
    Chem.MolFromSmiles('N[C@H](CC(C)C)C(=O)N[C@H]1C(=O)N[C@@H](CC(N)=O)C(=O)N[C@H]2C(=O)N[C@H]3C(=O)N[C@H](C(=O)N[C@H](C(=O)O)c4cc(O)cc(O)c4-c4cc3ccc4O)[C@H](O)c3ccc(c(Cl)c3)Oc3cc2cc(c3O[C@@H]2O[C@H](CO)[C@@H](O)[C@H](O)[C@H]2O[C@H]2C[C@](C)(N)[C@H](O)[C@H](C)O2)Oc2ccc(cc2Cl)[C@H]1O')
)

In [None]:
from rdkit.Chem import AllChem
import py3Dmol

# Embed a 3D conformer with ETKDGv3. The fixed seed makes the embedding reproducible.
# No 2D coordinates are computed first: the 3D embedding replaces any existing
# conformer, so that step had no effect on the result.
params = AllChem.ETKDGv3()
params.randomSeed = 0xf00d
conformer_id = AllChem.EmbedMolecule(mol, params)

# EmbedMolecule returns -1 when no conformer could be generated; stop here with a clear
# message instead of letting the force-field step fail on a molecule with no 3D coordinates.
if conformer_id < 0:
    raise RuntimeError("RDKit could not embed a 3D conformer for this molecule.")

# Relax the embedded geometry with the MMFF force field. The status code is left as the
# cell's last expression so it is still shown as the cell output
# (0 = converged, 1 = more iterations needed, -1 = force field could not be set up).
mmff_status = AllChem.MMFFOptimizeMolecule(mol)
mmff_status

In [None]:
# Hand the RDKit molecule to py3Dmol as MOL-block text and draw it as sticks with small spheres.
mol_block = Chem.MolToMolBlock(mol)
render_style = {"stick": {}, "sphere": {"scale": 0.3}}

view = py3Dmol.view(data=mol_block, style=render_style)
# Kept as the last expression so the viewer is displayed as the cell output.
view.zoomTo()
