In [1]:
%load_ext autoreload
%autoreload 2

In [105]:
import sys

sys.path.append("../../")

from tokenizers import Tokenizer
import sys

import matplotlib.pyplot as plt
import numpy as np
import collections
import torch

from ChEmbed.data import chembldb, smiles_dataset, chembed_tokenize
from ChEmbed.training import trainer
from ChEmbed.modules import simple_rnn
import attr

from ChEmbed import plots, utilities

from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torch
import attrs
from torch import nn, optim


In [23]:
def decode_indices_to_string(encoded_indices: list, idx_to_char_mapping: dict[int, str]):
    decoded = ''.join([idx_to_char_mapping[int(inx)] for inx in encoded_indices])
    return decoded

def encode_string_to_indices(smiles_string: str, char_to_idx_mapping: dict[str, int]):
    encoded = [char_to_idx_mapping[c] for c in smiles_string]
    return encoded

In [6]:
chembl_raw = chembldb.ChemblDB()
chembl_smiles = chembl_raw._load_or_download()["canonical_smiles"].to_list()

In [8]:
chembl_smiles[:10]

['Cc1cc(-c2csc(N=C(N)N)n2)cn1C',
 'CC[C@H](C)[C@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)[C@@H](N)CCSC)[C@@H](C)O)C(=O)NCC(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)N[C@@H](CC(N)=O)C(=O)NCC(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCN=C(N)N)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCN=C(N)N)C(=O)NCC(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)NCC(=O)N1CCC[C@H]1C(=O)N1CCC[C@H]1C(=O)NCC(=O)N[C@@H](CO)C(=O)N[C@@H](CCCN=C(N)N)C(N)=O',
 'CCCC[C@@H]1NC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)[C@H](CCC(=O)O)NC(=O)[C@H](CCCN=C(N)N)NC(=O)[C@H](CC(C)C)NC(=O)[C@H](CC(C)C)NC(=O)[C@H](Cc2c[nH]cn2)NC(=O)[C@H](N)Cc2ccccc2)C(C)C)CCC(=O)NCCCC[C@@H](C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](C)C(=O)N[C@@H](Cc2c[nH]cn2)C(=O)N[C@@H](CO)C(=O)N[C@@H](CC(N)=O)C(=O)N[C@@H](CCCN=C(N)N)C(=O)N[C@@H](CCCCN)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCC

In [101]:
chembl_mini = smiles_dataset.CharacterLevelSMILES(
    smiles_list = chembl_smiles[:100000],
    length = 512,
    batch_size = 128
)

In [102]:
len(chembl_mini.all_smiles)

5511270

In [103]:
model = simple_rnn.simpleRNN(
    # Mandatory
    num_hiddens = 128,
    vocab_size = len(chembl_mini.characters),
    # tuning
    learning_rate = 0.05,
    weight_decay = 0.05
)

In [None]:
@attrs.define(eq=False)
class simpleRNN(nn.Module):

    num_hiddens: int
    vocab_size: int

    learning_rate: float = 0.1
    weight_decay: float = 0.01

    def __attrs_post_init__(self):
        super().__init__()
        # RNN layer that takes one-hot encoded input and produces hidden states
        self.rnn = nn.RNN(self.vocab_size, self.num_hiddens, batch_first=True)
        # Linear layer to project hidden states to vocabulary size for prediction
        self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        # Initialize parameters
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Initialize model parameters using Xavier initialization"""
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, inputs, state=None):
        """
        Forward pass through the RNN
        Args:
            inputs: One-hot encoded input tensor of shape (batch_size, seq_len, vocab_size)
            state: Optional hidden state from previous time step
        Returns:
            output: Predictions of shape (batch_size, seq_len, vocab_size)
        """
        # Pass through RNN layer
        rnn_output, hidden_state = self.rnn(inputs, state)
        # Project to vocabulary size
        output = self.linear(rnn_output)
        return output

    def loss(self, y_hat, y):
        """
        Compute cross-entropy loss between predictions and targets
        Args:
            y_hat: Predictions of shape (batch_size * seq_len, vocab_size)
            y: Target indices of shape (batch_size * seq_len)
        Returns:
            loss: Cross-entropy loss
        """
        loss_fn = nn.CrossEntropyLoss()
        return loss_fn(y_hat, y)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def training_step(self, batch):
        """
        Training step for one batch
        Args:
            batch: Tuple of (input_tensor, target_tensor)
        Returns:
            loss: Training loss for this batch
        """
        inputs, targets = batch
        # Forward pass
        predictions = self.forward(inputs)
        
        # Reshape for loss computation
        # predictions: (batch_size, seq_len, vocab_size) -> (batch_size * seq_len, vocab_size)
        # targets: (batch_size, seq_len) -> (batch_size * seq_len)
        predictions = predictions.reshape(-1, self.vocab_size)
        targets = targets.reshape(-1)
        
        return self.loss(predictions, targets)

    def validation_step(self, batch):
        """
        Validation step for one batch
        Args:
            batch: Tuple of (input_tensor, target_tensor)
        Returns:
            loss: Validation loss for this batch
        """
        inputs, targets = batch
        # Forward pass
        predictions = self.forward(inputs)
        
        # Reshape for loss computation
        # predictions: (batch_size, seq_len, vocab_size) -> (batch_size * seq_len, vocab_size)
        # targets: (batch_size, seq_len) -> (batch_size * seq_len)
        predictions = predictions.reshape(-1, self.vocab_size)
        targets = targets.reshape(-1)
        
        return self.loss(predictions, targets)

In [109]:
model_trainer = trainer.Trainer(max_epochs=32, init_random=None, clip_grads_norm=1.0)
model_trainer.fit(model, chembl_mini)

Epoch 1/32: Train Loss: 3.2996, Val Loss: 2.6527
Epoch 1/32: Train Loss: 3.2996, Val Loss: 2.6527
Training batch 85/85... (Epoch 2/32)

KeyboardInterrupt: 

In [82]:
def simple_generate(prefix, num_chars, model, char_to_idx_mapping, idx_to_char_mapping, device=None):
    """
    Simple character-by-character generation function.
    """

    def decode_indices_to_string(encoded_indices: list, idx_to_char_mapping: dict[int, str]):
        decoded = ''.join([idx_to_char_mapping[int(inx)] for inx in encoded_indices])
        return decoded

    def encode_string_to_indices(smiles_string: str, char_to_idx_mapping: dict[str, int]):
        encoded = [char_to_idx_mapping[c] for c in smiles_string]
        return encoded

    model.eval()
    generated = prefix
    
    with torch.no_grad():
        for i in range(num_chars):
            # Encode current text
            encoded = torch.nn.functional.one_hot(torch.tensor(encode_string_to_indices(generated, char_to_idx_mapping)), num_classes=len(char_to_idx_mapping))
            input_tensor = torch.tensor(encoded, device=device, dtype=torch.float32)
            
            # Get prediction
            output = model(input_tensor.unsqueeze(0))  # Add batch dim
            
            # Get most likely next token
            next_token = output[0, -1, :].argmax().item()
            
            # Decode and append
            next_char = decode_indices_to_string([next_token], idx_to_char_mapping)
            generated += next_char
            
            # print(f"Step {i+1}: Added '{next_char}' -> '{generated}'")
            
    return generated

In [46]:
print(simple_generate("CO", 30, model, chembl_mini.char_to_idx, chembl_mini.idx_to_char, device='cuda'))

CO)CCCCCCCCCCCCCCCCCCCCCCCCCCCCC


  input_tensor = torch.tensor(encoded, device=device, dtype=torch.float32)


In [107]:
# Test the completed simpleRNN class
test_model = simpleRNN(
    num_hiddens = 128,
    vocab_size = len(chembl_mini.characters),
    learning_rate = 0.05,
    weight_decay = 0.05
)

print(f"Model created successfully with {sum(p.numel() for p in test_model.parameters())} parameters")
print(f"Vocab size: {test_model.vocab_size}")
print(f"Hidden units: {test_model.num_hiddens}")

Model created successfully with 29490 parameters
Vocab size: 50
Hidden units: 128
