In [None]:
# Enable IPython's autoreload extension so that edits to the imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os

sys.path.append("../")

import torch

from data import chembldb, datasets
from training import trainer
from modules import simple_rnn
import datetime
import plots, utilities
from IPython.display import Image, display, Markdown

import torch


## First, we load the SMILES strings representing the molecules in the ChEMBL database.

In [None]:
# The complete dataset is huge, so keep only one SMILES string out of every
# `every_nth` entries.
every_nth = 10
chembl_raw = chembldb.ChemblDBChemreps()
# NOTE(review): `_load_or_download` is a private helper; it presumably returns a
# table with a "canonical_smiles" column — confirm against data.chembldb.
chembl_smiles = (
    chembl_raw._load_or_download()["canonical_smiles"]
    .to_list()[::every_nth]
)

In [None]:
# `chembl_smiles` is the subsampled list of SMILES strings built in the previous
# cell; show the first ten entries as a quick sanity check.
chembl_smiles[:10]

In [None]:
# Wrap the subsampled SMILES strings in the project's character-level dataset.
# NOTE(review): `length` is presumably the fixed sequence length per sample and
# `batch_size` the loader batch size — confirm in datasets.CharacterLevelSMILES.
chembl = datasets.CharacterLevelSMILES(
    smiles_list=chembl_smiles,
    length=256,
    batch_size=128,
)

## Next, let's define a model. We'll be using a character-level LSTM model to predict the next character in a SMILES string based on the previous characters.

In [None]:
# Build the character-level LSTM. The vocabulary size is taken from the dataset's
# character set so the model always matches the data it is trained on.
model = simple_rnn.simpleLSTM(
    # Required arguments
    num_hiddens=512,
    vocab_size=len(chembl.characters),
    # Tunable hyperparameters
    learning_rate=1e-3,
    weight_decay=1e-4,
    num_layers=5,
    output_dropout=0.2,
    rnn_dropout=0.2,
)

In [None]:
# Location of the pre-trained weights, read only when `train_new` is False.
load_model_path = "../models/character_level_rnn_generator.pth"

train_new = True  # Set to False to load a pre-trained model instead of training
save_model = True  # When training a new model, should it be saved afterwards?

if not train_new:
    # Reuse stored weights, then move the model to the GPU when one is available.
    model.load_weights(path=load_model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
else:
    # Note: `model_trainer` only exists when this training branch has run.
    model_trainer = trainer.Trainer(
        max_epochs=16,
        init_random=None,
        clip_grads_norm=2.0,
    )
    model_trainer.fit(model, chembl)

    if save_model:
        utilities.save_model_weights("Chembl-Mini-", model, chembl)

In [None]:
# Loss curves only exist when a model was trained in this session: `model_trainer`
# is created solely in the `train_new` branch of the previous cell. Without this
# guard, running the notebook with `train_new = False` raises a NameError here.
if train_new:
    losses = utilities.extract_training_losses(
        metadata=model_trainer.metadata,
    )
    fig, ax = plots.plot_training_validation_loss(
        training_losses=losses["train_losses"],
        validation_losses=losses["val_losses"],
    )
else:
    print("Pre-trained weights were loaded; there is no training history to plot.")

## Below, we ask the model to generate 20 totally random molecules, by providing it with a blank (single-space) seed string.

In [None]:
n_chars = 100  # upper bound on the number of characters requested per molecule
n_attempts = 20  # how many molecules to sample

# Pick the GPU when present and fall back to the CPU otherwise (the same rule the
# weight-loading cell uses), so this cell also runs on machines without CUDA.
generation_device = "cuda" if torch.cuda.is_available() else "cpu"

n_valid = 0
n_invalid = 0
images = []  # drawings of the valid molecules, shown in a later cell
for _ in range(n_attempts):
    # Seed with a single space and sample at a fairly low temperature.
    output = simple_rnn.simple_generate(
        " ",
        n_chars,
        model,
        chembl.char_to_idx,
        chembl.idx_to_char,
        temperature=0.5,
        device=generation_device,
    )
    print(f"Requested up to {n_chars} characters, got: {len(output)}")
    print(f"Molecule Canonical SMILES: {output}")

    valid = utilities.validate_smiles_string(output)

    if not valid:
        n_invalid += 1
        print("Generated SMILES is not valid.")
    else:
        n_valid += 1
        print("Generated SMILES is valid.")
        # Only valid strings are drawn.
        images.append(utilities.draw_molecule(output))

## Visualising the attempts at generating molecules

For such a simple model, the results of this test are rather impressive. The model has learned to generate valid SMILES strings, which correspond to real molecules. In order to do this, the model will have implicitly "learned" the correct valences for atoms, common functional groups, and the rules of SMILES syntax, which is not a trivial task.

We can see that the model has learned to generate a variety of different molecules, some of which are quite complex. The model is able to generate molecules with rings, branches, and various functional groups, all while adhering to the rules of SMILES syntax. Occasionally included are rarer functional groups which are common in some pharmaceuticals, such as trifluoromethyl groups (CF3).

Of course, because SMILES strings are syntactically rigid, often with long-range dependencies, a simple model like this smaller LSTM will sometimes generate molecules which are _almost_ but not quite valid. Models become better at avoiding these syntactic errors with longer training and larger models.

Above, we can see that the model has learned to generate valid SMILES strings, but we can also see information on the types of errors the model makes. These fall into two categories: syntactic errors, where the model generates a string which is not valid SMILES, and semantic errors, where the model generates a syntactically well-formed SMILES string which does not correspond to a real molecule.

Syntactic: e.g. the model has a tendency to open parentheses but not close them, or to open rings without specifying where they close.

```
SMILES Parse Error: extra open parentheses while parsing: CC(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CCCN=C(N)N)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCC
SMILES Parse Error: check for mistakes around position 96:
(CC(C)C)C(=O)N[C@@H](CCCC
```

Semantic: e.g. occasionally generating molecules where atoms have incorrect valences (e.g. F with two bonds):

```
SMILES Parse Error: Failed parsing SMILES ' HC(=O)N[C@H](C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O''
Explicit valence for atom # 0 F, 2, is greater than permitted
```

The model also appears to have a real desire to generate long, saturated hydrocarbon chains. The training data does include some of these, so it's not surprising that the model has learned to generate them. If in doubt, the next character is probably just a saturated carbon atom.


In [None]:
# Derive the attempt count from the tallies instead of hard-coding it; the
# previous message claimed 5 attempts although the generation loop runs 20 times.
total_attempts = n_valid + n_invalid
display(Markdown(f"# Generated {n_valid} valid and {n_invalid} invalid SMILES strings out of {total_attempts} attempts."))
display(Markdown("## Generated Molecules"))
# Show the drawing of every valid molecule collected by the generation loop.
for img in images:
    display(img)

## Generating random molecules is fun, but a real example of where this sort of model might be useful would be the generation of molecules with specific properties...

In other models, we'll work to condition the initial state of the model with a vector of desired properties extracted from CHEMBL, but for now, we could say that we hope to generate _variants_ of an existing molecule, but with random, physically plausible changes to the structure.

Let's take caffeine, every scientist's favourite; without it, no work would get done. It is described by the following SMILES string `CN1C=NC2=C1C(=O)N(C(=O)N2C)C`.

We can generate a series of caffeine-like molecules by providing the model with a portion of the caffeine SMILES and asking it to predict the next character in the SMILES string.

In [None]:
n_chars = 100  # upper bound on the number of characters requested per molecule
n_attempts = 20  # how many molecules to sample

# Pick the GPU when present and fall back to the CPU otherwise (the same rule the
# weight-loading cell uses), so this cell also runs on machines without CUDA.
generation_device = "cuda" if torch.cuda.is_available() else "cpu"

n_valid = 0
n_invalid = 0
images = []  # drawings of the valid molecules, shown in a later cell
for _ in range(n_attempts):
    # Seed with the leading fragment of the caffeine SMILES and let the model
    # complete it; the temperature is higher than in the unseeded run above.
    output = simple_rnn.simple_generate(
        "CN1C=NC2=C1C(=O)N",
        n_chars,
        model,
        chembl.char_to_idx,
        chembl.idx_to_char,
        temperature=0.8,
        device=generation_device,
    )
    print(f"Requested up to {n_chars} characters, got: {len(output)}")
    print(f"Molecule Canonical SMILES: {output}")

    valid = utilities.validate_smiles_string(output)

    if not valid:
        n_invalid += 1
        print("Generated SMILES is not valid.")
    else:
        n_valid += 1
        print("Generated SMILES is valid.")
        # Only valid strings are drawn.
        images.append(utilities.draw_molecule(output))

In [None]:
# Derive the attempt count from the tallies instead of hard-coding it; the
# previous message claimed 5 attempts although the generation loop runs 20 times.
total_attempts = n_valid + n_invalid
display(Markdown(f"# Generated {n_valid} valid caffiene-like molecules and {n_invalid} invalid SMILES strings out of {total_attempts} attempts."))
display(Markdown("## Generated Molecules"))
# Show the drawing of every valid molecule collected by the generation loop.
for img in images:
    display(img)