In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import sys
import torch
from data import datasets
from training import trainer
from modules import ind_generator

import plots, utilities

In [None]:
# Instantiate the dataset used for training and generation below.
# NOTE(review): judging by the class name this is a character-level SMILES
# dataset with ChEMBL indication labels — confirm in the `data.datasets` module.
dataset = datasets.CharSMILESChEMBLIndications()

In [None]:
# Display the `all_data` attribute of the dataset for a quick visual check.
# NOTE(review): if this is a large table, consider showing only a slice.
dataset.all_data

In [None]:
# Number of items the dataset reports via __len__.
len(dataset)

In [None]:
# Shape of the second element of the first sample.
# Presumably this is the indications vector — TODO confirm against the dataset class.
dataset[0][1].shape

In [None]:
# Build the indication-conditioned SMILES generator.
# Vocabulary and indication counts are taken from the dataset so the model
# dimensions always match the data it is trained on.
model = ind_generator.SmilesIndGeneratorRNN(
    # sizes derived from the dataset
    vocab_size=dataset.vocab_size,
    num_indications=dataset.num_indications,
    # network capacity
    num_layers=2,
    num_hiddens=256,
    # dropout settings
    rnn_dropout=0.2,
    output_dropout=0.2,
    # optimiser settings
    weight_decay=1e-3,
    learning_rate=1e-4,
)

In [None]:
# Configure the trainer and fit the model on the dataset.
# NOTE(review): `clip_grads_norm=1.0` presumably caps the gradient norm and
# `init_random=None` presumably means "no fixed seed" — confirm in
# `training.trainer`; if so, runs of this cell are not reproducible.
model_trainer = trainer.Trainer(max_epochs=16, init_random=None, clip_grads_norm=1.0)
model_trainer.fit(model, dataset)

In [None]:
# Pick the sample at index 1 for a quick sanity check of the trained model.
# The following cells index it as test_batch[0], test_batch[1] and test_batch[2].
test_batch = dataset[1]

In [None]:
# Run the single sample through the trained model as a batch of size 1.
# Switch to eval mode first: the model was built with dropout, and leaving it
# in training mode would make this prediction stochastic.
model.eval()

# Add a leading batch dimension and move both inputs to the trainer's device.
sample_inputs = test_batch[0].unsqueeze(0).to(device=model_trainer.device)
sample_indications = test_batch[1].unsqueeze(0).to(device=model_trainer.device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output, _ = model(sample_inputs, sample_indications)

In [None]:
# Decode the index sequence stored in test_batch[2] back into characters.
reference_indices = test_batch[2].cpu().numpy()
reference_smiles = ''.join(dataset.idx_to_char[idx] for idx in reference_indices)
print(f"Input SMILES: {reference_smiles}")

# Decode the model's most likely character at every position.
predicted_indices = output.argmax(dim=-1).squeeze().cpu().numpy()
predicted_smiles = ''.join(dataset.idx_to_char[idx] for idx in predicted_indices)
print(f"Prediction: {predicted_smiles}")

In [None]:
# Print the positions at which the second element of the sample equals 1.
# Presumably these are the indices of the sample's active indications — TODO confirm.
print(torch.where(test_batch[1] == 1)[0])

In [None]:
# List the indication names known to the dataset; the generation call at the
# end of the notebook selects one of them by position.
print(dataset.indications_names)

In [None]:
def simple_generate(prefix, num_chars, model, indications_tensor, char_to_idx_mapping, idx_to_char_mapping, temperature=0.0, device=None):
    """
    Generate a string character by character, continuing ``prefix``.

    The recurrent state is initialised from ``indications_tensor`` through
    ``model.init_state`` and then warmed up on every prefix character except
    the last one. The last prefix character is the first input of the
    generation loop, so every character is fed to the model exactly once.

    Args:
        prefix: Seed string. Every character must be a key of
            ``char_to_idx_mapping`` (a ``KeyError`` is raised otherwise).
            May be empty, in which case index 0 is used as the first input.
        num_chars: Maximum number of characters to append to ``prefix``.
        model: Object exposing ``eval()``, ``init_state(indications)`` and
            ``model(x, state=state) -> (output, state)``. ``x`` is a one-hot
            float tensor of shape ``(1, seq_len, vocab_size)``; ``output`` is
            indexed as ``[0, -1, :]`` to obtain the scores for the next
            character.
        indications_tensor: Conditioning tensor for a single sample; a batch
            dimension is added before it is passed to ``model.init_state``.
        char_to_idx_mapping: Maps each character to its vocabulary index; its
            length is used as the one-hot width.
        idx_to_char_mapping: Maps a vocabulary index back to its character.
        temperature: If greater than 0, the next character is sampled from
            ``softmax(scores / temperature)``; otherwise the highest-scoring
            character is taken.
        device: Device for all tensors created here. When ``None`` it is
            taken from the model's parameters, or from ``indications_tensor``
            if the model exposes no parameters.

    Returns:
        ``prefix`` followed by the generated characters. Generation stops
        after ``num_chars`` characters or as soon as the model emits ``' '``
        or ``''`` (treated as end-of-sequence and not included).

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()

    if device is None:
        # Tensors built below must live where the model does; otherwise a
        # model on an accelerator would receive CPU inputs.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            device = indications_tensor.device

    vocab_size = len(char_to_idx_mapping)

    def _one_hot_batch(indices):
        # (seq_len,) indices -> (1, seq_len, vocab_size) float one-hot on `device`.
        encoded = torch.nn.functional.one_hot(
            torch.tensor(indices),
            num_classes=vocab_size,
        ).float().to(device)
        return encoded.unsqueeze(0)

    # Encode the whole prefix up front so unknown characters fail early.
    prefix_indices = [char_to_idx_mapping[c] for c in prefix]
    generated = prefix

    # Initialize state with indications (batch dimension added).
    state = model.init_state(indications_tensor.unsqueeze(0).to(device))

    with torch.no_grad():
        # Warm up on everything except the last prefix character. The last
        # one is fed by the first loop iteration; running it here as well
        # would advance the state twice for the same character.
        warmup_indices = prefix_indices[:-1]
        if warmup_indices:
            _, state = model(_one_hot_batch(warmup_indices), state=state)

        for _ in range(num_chars):
            # Feed the most recent character, or index 0 when starting from
            # an empty prefix.
            last_char_idx = char_to_idx_mapping[generated[-1]] if generated else 0

            output, state = model(_one_hot_batch([last_char_idx]), state=state)
            next_scores = output[0, -1, :]

            if temperature > 0:
                # Temperature-scaled sampling.
                probabilities = torch.softmax(next_scores / temperature, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1).item()
            else:
                # Greedy decoding.
                next_token = next_scores.argmax().item()

            next_char = idx_to_char_mapping[int(next_token)]

            if next_char == ' ' or next_char == '':  # EOS token
                break

            generated += next_char

    return generated

In [None]:
# Greedy (temperature 0) continuation of a SMILES prefix, conditioned on the
# indication at position 4 of the dataset's indication list.
simple_generate(
    prefix="C1cccc2c(c1)",
    num_chars=50,
    model=model,
    indications_tensor=dataset.get_indications_tensor(
        dataset.indications_names[4]
    ).to(model_trainer.device),
    char_to_idx_mapping=dataset.char_to_idx,
    idx_to_char_mapping=dataset.idx_to_char,
    device=model_trainer.device,
    temperature=0.0,
)