In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../../")

from tokenizers import Tokenizer
import sys

import matplotlib.pyplot as plt
import numpy as np
import collections
import torch

from ChEmbed.data import chembldb, smiles_dataset, chembed_tokenize, shakespeare_dataset
from ChEmbed.training import trainer
from ChEmbed.modules import simple_rnn
import attr

from ChEmbed import plots, utilities

from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torch
import attrs
from torch import nn, optim


In [23]:
def decode_indices_to_string(encoded_indices: list, idx_to_char_mapping: dict[int, str]):
    decoded = ''.join([idx_to_char_mapping[int(inx)] for inx in encoded_indices])
    return decoded

def encode_string_to_indices(smiles_string: str, char_to_idx_mapping: dict[str, int]):
    encoded = [char_to_idx_mapping[c] for c in smiles_string]
    return encoded

In [6]:
chembl_raw = chembldb.ChemblDB()
chembl_smiles = chembl_raw._load_or_download()["canonical_smiles"].to_list()

In [8]:
chembl_smiles[:10]

['Cc1cc(-c2csc(N=C(N)N)n2)cn1C',
 'CC[C@H](C)[C@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)[C@@H](N)CCSC)[C@@H](C)O)C(=O)NCC(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)N[C@@H](CC(N)=O)C(=O)NCC(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCN=C(N)N)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCN=C(N)N)C(=O)NCC(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)NCC(=O)N1CCC[C@H]1C(=O)N1CCC[C@H]1C(=O)NCC(=O)N[C@@H](CO)C(=O)N[C@@H](CCCN=C(N)N)C(N)=O',
 'CCCC[C@@H]1NC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)[C@H](CCC(=O)O)NC(=O)[C@H](CCCN=C(N)N)NC(=O)[C@H](CC(C)C)NC(=O)[C@H](CC(C)C)NC(=O)[C@H](Cc2c[nH]cn2)NC(=O)[C@H](N)Cc2ccccc2)C(C)C)CCC(=O)NCCCC[C@@H](C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](C)C(=O)N[C@@H](Cc2c[nH]cn2)C(=O)N[C@@H](CO)C(=O)N[C@@H](CC(N)=O)C(=O)N[C@@H](CCCN=C(N)N)C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCC

In [101]:
chembl_mini = smiles_dataset.CharacterLevelSMILES(
    smiles_list = chembl_smiles[:100000],
    length = 512,
    batch_size = 128
)

In [116]:
len(chembl_mini.all_smiles)

5511270

In [117]:
model = simple_rnn.simpleRNN(
    # Mandatory
    num_hiddens = 128,
    vocab_size = len(chembl_mini.characters),
    # tuning
    learning_rate = 0.05,
    weight_decay = 0.05
)

In [118]:
model_trainer = trainer.Trainer(max_epochs=32, init_random=None, clip_grads_norm=1.0)
model_trainer.fit(model, chembl_mini)

Training batch 8/85... (Epoch 1/32)

Epoch 1/32: Train Loss: 3.0980, Val Loss: 2.7657
Epoch 2/32: Train Loss: 2.7189, Val Loss: 2.7990
Epoch 3/32: Train Loss: 2.7110, Val Loss: 2.6420
Training batch 73/85... (Epoch 4/32)

KeyboardInterrupt: 

In [82]:
def simple_generate(prefix, num_chars, model, char_to_idx_mapping, idx_to_char_mapping, device=None):
    """
    Simple character-by-character generation function.
    """

    def decode_indices_to_string(encoded_indices: list, idx_to_char_mapping: dict[int, str]):
        decoded = ''.join([idx_to_char_mapping[int(inx)] for inx in encoded_indices])
        return decoded

    def encode_string_to_indices(smiles_string: str, char_to_idx_mapping: dict[str, int]):
        encoded = [char_to_idx_mapping[c] for c in smiles_string]
        return encoded

    model.eval()
    generated = prefix
    
    with torch.no_grad():
        for i in range(num_chars):
            # Encode current text
            encoded = torch.nn.functional.one_hot(torch.tensor(encode_string_to_indices(generated, char_to_idx_mapping)), num_classes=len(char_to_idx_mapping))
            input_tensor = torch.tensor(encoded, device=device, dtype=torch.float32)
            
            # Get prediction
            output = model(input_tensor.unsqueeze(0))  # Add batch dim
            
            # Get most likely next token
            next_token = output[0, -1, :].argmax().item()
            
            # Decode and append
            next_char = decode_indices_to_string([next_token], idx_to_char_mapping)
            generated += next_char
            
            # print(f"Step {i+1}: Added '{next_char}' -> '{generated}'")
            
    return generated

In [46]:
print(simple_generate("CO", 30, model, chembl_mini.char_to_idx, chembl_mini.idx_to_char, device='cuda'))

CO)CCCCCCCCCCCCCCCCCCCCCCCCCCCCC


  input_tensor = torch.tensor(encoded, device=device, dtype=torch.float32)
