In [1]:
# Reload edited modules (e.g. the local ChEmbed package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project root importable (presumably the directory that contains the
# `ChEmbed` package imported below). The relative entry is resolved against the
# kernel's current working directory at import time.
# Guarded so re-running this cell does not keep appending duplicate entries.
REPO_ROOT = "../../"
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [3]:
from tokenizers import Tokenizer
import sys

import matplotlib.pyplot as plt
import numpy as np
import collections
import torch

from ChEmbed.data import chembldb, smiles_dataset, chembed_tokenize
from ChEmbed.training import trainer
from ChEmbed.modules import simple_rnn
import attr

from ChEmbed import plots, utilities

from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.0
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [4]:
# Fetch the ChEMBL table and keep only its canonical SMILES column.
# NOTE(review): `_load_or_download` is a private ChemblDB method; consider a public loader.
chembl_raw = chembldb.ChemblDB()
chembl_smiles = (
    chembl_raw
    ._load_or_download()
    ["canonical_smiles"]
)

In [5]:
# Date-stamped tokenizer file (presumably fitted on ChEMBL SMILES, per its name);
# path is relative to the notebook's working directory.
TOKENIZER_PATH = "../data/tokenizers/tokenizer-chembldb-16-06-2025.json"
tokenizer = chembed_tokenize.load_chembed_tokenizer(filepath=TOKENIZER_PATH)

In [6]:
# Disabled alternatives: `SMILESDataset` over all ChEMBL SMILES and over a 5,000-molecule subset.
# chembl_dataset = smiles_dataset.SMILESDataset(
#     smiles_list = chembl_smiles,
#     tokenizer = tokenizer
# )

# chembl_mini = smiles_dataset.SMILESDataset(
#     smiles_list = chembl_smiles[:5000],
#     tokenizer = tokenizer
# )

In [7]:
# Small training set built from the first N_MOLECULES SMILES strings.
# SEQ_LENGTH is forwarded as `length` (presumably the per-example sequence
# length — confirm in SMILESDatasetContinuous).
N_MOLECULES = 25_000
SEQ_LENGTH = 1024
chembl_mini = smiles_dataset.SMILESDatasetContinuous(
    smiles_list=chembl_smiles[:N_MOLECULES],
    tokenizer=tokenizer,
    length=SEQ_LENGTH,
)

In [8]:
# Inspect the first example. The recorded output is a pair of tensors (one 2-D, one 1-D),
# presumably (one-hot inputs, targets) — confirm in SMILESDatasetContinuous.__getitem__.
chembl_mini[0]

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

In [9]:
# Optimisation hyperparameters (the tunable part of the config).
RNN_TUNING = {"learning_rate": 0.2, "weight_decay": 0.08}

model = simple_rnn.simpleRNN(
    num_hiddens=128,            # required
    vocab_size=len(tokenizer),  # required; matches the tokenizer vocabulary
    **RNN_TUNING,
)



In [10]:
def predict(input, num_preds, model, dataset, device=None):
    """Greedily extend a SMILES prefix with ``num_preds`` model-chosen tokens.

    The prefix is fed one character at a time ("warm-up"). After the last
    prefix character, each step takes the argmax of the model output and
    decodes that id with ``dataset.tokenizer``.

    Args:
        input: Non-empty SMILES prefix. It is consumed character by character,
            so multi-character tokens in the prefix are not kept together.
        num_preds: Number of tokens to generate after the prefix (>= 0).
        model: Callable invoked as ``model(X, state)``; its output is indexed
            as ``out[:, -1]`` and argmax-ed to choose the next token id.
        dataset: Object providing ``encode_smiles_to_one_hot(str)`` and a
            ``tokenizer`` with ``decode(int)``.
        device: Device for the encoded input tensor (``None`` = torch default).

    Returns:
        The prefix followed by the decoded predictions, joined into one string.

    Raises:
        ValueError: If ``input`` is empty or ``num_preds`` is negative.
    """
    if not input:
        raise ValueError("predict() needs a non-empty prefix")
    if num_preds < 0:
        raise ValueError(f"num_preds must be >= 0, got {num_preds}")

    # NOTE(review): `state` is never reassigned, so every model call receives
    # state=None and sees only the single previous token. The warm-up calls
    # therefore cannot influence later predictions (unless the model keeps
    # internal state), which is consistent with the repetitive samples this
    # notebook produced. TODO: thread the recurrent state once the model API
    # returns it.
    state = None
    outputs = [input[0]]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(len(input) + num_preds - 1):
            # as_tensor avoids the copy-construct UserWarning torch.tensor()
            # raises when the encoder already returns a tensor, and skips an
            # unnecessary copy when no dtype/device conversion is required.
            X = torch.as_tensor(dataset.encode_smiles_to_one_hot(outputs[i]), device=device)
            rnn_outputs = model(X, state)

            if i < len(input) - 1:
                # Warm-up: force the next prefix character instead of a prediction.
                outputs.append(input[i + 1])
            else:
                # NOTE(review): argmax is taken over the *flattened* slice
                # `rnn_outputs[:, -1]`; this is a token id only if that slice is a
                # single vocab-sized logit vector (e.g. output shaped
                # (1, seq, vocab)) — confirm against simpleRNN's output layout.
                token_id = int(rnn_outputs[:, -1].argmax())
                outputs.append(dataset.tokenizer.decode(token_id))
    return ''.join(outputs)
            

In [11]:
# Sanity check before training: extend the prefix "CCC" by 30 predicted tokens on CPU.
sample_untrained = predict("CCC", 30, model, chembl_mini, device="cpu")
print(sample_untrained)

  X = torch.tensor(dataset.encode_smiles_to_one_hot(outputs[i]), device=device)


CCCC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1OC1O


In [16]:
# FIXME: the recorded run of this cell failed inside `fit` with
# "TypeError: unsupported operand type(s) for +=: 'float' and 'NoneType'",
# i.e. a float accumulator received None (presumably a missing per-batch or
# validation loss — confirm in Trainer.fit).
model_trainer = trainer.Trainer(
    max_epochs=16,
    init_random=None,
    clip_grads_norm=3,
)
model_trainer.fit(model, chembl_mini)

Training batch 5/5... (Epoch 1/16)

TypeError: unsupported operand type(s) for +=: 'float' and 'NoneType'

In [120]:
# Generate on whatever device the model's parameters live on, instead of a
# hardcoded "cuda" that fails on CPU-only machines and mismatches a model left
# on CPU. Assumes `model` is a torch.nn.Module (exposes .parameters()); falls
# back to CPU for a parameterless model.
model_device = next((p.device for p in model.parameters()), torch.device("cpu"))
print(predict("C", 30, model, chembl_mini, device=model_device))

['C', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2', 'OCc2ccccc2']
COCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2OCc2ccccc2


  X = torch.tensor(dataset.encode_smiles_to_one_hot(outputs[i]), device=device)


In [None]:
# Training/validation loss curves. Requires a completed `model_trainer.fit(...)`
# (the recorded training run failed, so this cell has no saved output).
losses = utilities.extract_training_losses(model_trainer.metadata)
train_curve = losses['avg_train_losses']
val_curve = losses['avg_val_losses']
fig, ax = plots.plot_training_validation_loss(train_curve, val_curve)
ax.set_yscale('log')  # log-scaled y-axis