In [None]:
# Re-import edited project modules (data, training, modules, plots, utilities) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

from tokenizers import Tokenizer
import sys

import matplotlib.pyplot as plt
import numpy as np
import collections
import torch

from data import datasets
from training import trainer
from modules import ind_generator
import attr

import plots, utilities

from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# Dataset object that supplies the vocabulary size, indication names/tensors, character<->index
# mappings and padding index used by the model, trainer and generation cells below.
dataset = datasets.CharSMILESChEMBLIndications()

In [None]:
# Recurrent SMILES generator sized to the dataset's vocabulary and number of indications.
# Optimiser settings (learning rate, weight decay) are passed to the model itself here;
# presumably the model builds its own optimiser — confirm in modules.ind_generator.
model = ind_generator.SmilesIndGeneratorRNN(
    vocab_size=dataset.vocab_size,
    num_indications=dataset.num_indications,
    num_hiddens=250,
    num_layers=2,
    learning_rate=0.001,
    weight_decay=0.01,
)

In [None]:
# Train for 4 epochs. clip_grads_norm presumably caps the gradient norm (a common safeguard
# for RNN training) — confirm in training.trainer.
model_trainer = trainer.Trainer(
    max_epochs=4,
    init_random=None,
    clip_grads_norm=1.0,
)
model_trainer.fit(model, dataset)

In [None]:
# Show the available indication names; a bare last expression gets Jupyter's rich display.
dataset.indications_names

In [None]:
def generate_sequence(prefix, indications_tensor, num_chars, model, char_to_idx_mapping, idx_to_char_mapping, temperature=0.0, device=None, padding_index=None):
    """
    Generate a string character by character, conditioned on an indications tensor.

    At every step the whole text generated so far is one-hot encoded and fed to
    ``model`` together with the *initial* conditioning state built from
    ``indications_tensor``. The state returned by the model is deliberately not
    carried into the next step. Because the full sequence is re-fed each time,
    reusing the returned state would make the model see the prefix twice, or
    would hand it a hidden state where it was first given the indications.

    Args:
        prefix: Non-empty seed string; every character must be a key of
            ``char_to_idx_mapping``.
        indications_tensor: Conditioning tensor for a single example; a batch
            dimension is added here with ``unsqueeze(0)``.
        num_chars: Maximum number of characters to append.
        model: Module called as ``model(inputs, state)`` that returns
            ``(logits, new_state)`` with logits indexed ``[batch, step, vocab]``.
        char_to_idx_mapping: Character -> token index. Its length is used as the
            one-hot width.
        idx_to_char_mapping: Token index -> character (dict or list).
        temperature: If ``> 0``, sample from ``softmax(logits / temperature)``.
            Otherwise (``<= 0``) pick the argmax (greedy decoding).
        device: Device for the model inputs. Defaults to the device of the model's
            parameters, or to ``indications_tensor``'s device if the model has none.
        padding_index: Token index that ends generation. Defaults to the
            notebook-level ``dataset.padding_index`` for backward compatibility.

    Returns:
        ``prefix`` followed by the generated characters. Generation stops early
        when the padding token, or a token mapped to ``""``, is produced.

    Raises:
        ValueError: If ``prefix`` is empty (the model needs at least one step).
        KeyError: If ``prefix`` contains a character missing from the mapping.
    """
    if not prefix:
        raise ValueError("prefix must contain at least one character")
    if padding_index is None:
        # Backward-compatible fallback to the notebook-global dataset.
        padding_index = dataset.padding_index
    if device is None:
        # Keep inputs and conditioning state on the same device as the model.
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = indications_tensor.device

    vocab_size = len(char_to_idx_mapping)
    initial_state = indications_tensor.unsqueeze(0).to(device)
    generated = prefix

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_chars):
                indices = torch.tensor(
                    [char_to_idx_mapping[c] for c in generated], device=device
                )
                inputs = torch.nn.functional.one_hot(indices, num_classes=vocab_size).float()

                # Re-feed the full sequence from the initial conditioning state
                # (presumably mirroring how whole sequences are seen in training);
                # the returned state is intentionally discarded.
                output, _ = model(inputs.unsqueeze(0), initial_state)
                last_logits = output[0, -1, :]

                if temperature > 0:
                    probabilities = torch.softmax(last_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probabilities, num_samples=1).item()
                else:
                    next_token = last_logits.argmax().item()

                next_char = idx_to_char_mapping[next_token]
                if next_token == padding_index or next_char == "":
                    break
                generated += next_char
    finally:
        # Restore the caller's train/eval mode instead of leaving the model in eval.
        model.train(was_training)

    return generated

In [None]:
# Greedy (temperature=0.0 -> argmax) generation of up to 100 characters from the seed "C",
# conditioned on the first indication name the dataset lists.
generate_sequence(
    prefix="C",
    indications_tensor=dataset.get_indications_tensor(dataset.indications_names[0]),
    num_chars=100,
    model=model,
    char_to_idx_mapping=dataset.char_to_idx,
    idx_to_char_mapping=dataset.idx_to_char,
    temperature=0.0,
    device=model_trainer.device,
)