In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import sys
import torch
from data import datasets
from training import trainer
from modules import ind_generator

import datetime
import plots, utilities
from IPython.display import Image, display, Markdown

In [None]:
# Instantiate the project's CharSMILESChEMBLIndications dataset with a batch size of 128.
dataset = datasets.CharSMILESChEMBLIndications(batch_size=128)

In [None]:
# Show the dataset's `all_data` attribute as the cell output.
dataset.all_data

In [None]:
# Length of the dataset as reported by its __len__.
len(dataset)

In [None]:
# Shape of the second element of the first dataset item
# (presumably the indications tensor — TODO confirm against the dataset class).
dataset[0][1].shape

In [None]:
# Gather the generator's hyperparameters in one place, then build the model from them.
model_hyperparameters = dict(
    vocab_size=dataset.vocab_size,
    num_indications=dataset.num_indications,
    num_hiddens=256,
    num_layers=5,
    learning_rate=1e-3,
    weight_decay=1e-4,
    output_dropout=0.3,
    rnn_dropout=0.3,
)
model = ind_generator.SmilesIndGeneratorRNN(**model_hyperparameters)

In [None]:

# train_new = True  # Set false to load a pre-trained model
# save_model = True  # If training a new model, do we want to save it?

# if train_new:
#     model_trainer = trainer.Trainer(max_epochs=16, init_random=None, clip_grads_norm=2.0)
#     model_trainer.fit(model, dataset)

#     if save_model:
#         utilities.save_model_weights("Chembl-Mini-", model, dataset)
# else: 
#     model.load_weights(
#         path = load_model_path,
#     )
#     device="cuda" if torch.cuda.is_available() else "cpu"
#     model.to(device)

In [None]:
# Configure the project's Trainer for 64 epochs with gradient-norm clipping at 10.0
# (going by the argument names), then fit the model on the dataset.
# NOTE(review): init_random=None presumably leaves seeding at the Trainer's default — confirm.
model_trainer = trainer.Trainer(max_epochs=64, init_random=None, clip_grads_norm=10.0)
model_trainer.fit(model, dataset)

In [None]:
# Save the trained model through the project helper.
# NOTE(review): "Chembl-Ind-" looks like a filename prefix — confirm in utilities.save_model_weights.
utilities.save_model_weights("Chembl-Ind-", model, dataset)

In [None]:
# Extract the averaged training/validation losses from the trainer metadata
# and plot both curves with a logarithmic y-axis.
losses = utilities.extract_training_losses(model_trainer.metadata)
fig, ax = plots.plot_training_validation_loss(
    training_losses=losses["avg_train_losses"],
    validation_losses=losses["avg_val_losses"],
)
ax.set_yscale("log")

In [None]:
# Pick one dataset item (index 5) to sanity-check the trained model on.
test_batch = dataset[5]

In [None]:
# Shapes of the first three elements of the selected item.
test_batch[0].shape, test_batch[1].shape, test_batch[2].shape

In [None]:
# Run the model on the selected item, adding a leading batch dimension to both
# inputs and moving them to the trainer's device.
# eval() disables dropout and no_grad() skips autograd bookkeeping, so the
# prediction is deterministic and does not keep a computation graph alive.
model.eval()
with torch.no_grad():
    output, _ = model(
        test_batch[0].unsqueeze(0).to(device=model_trainer.device),
        test_batch[1].unsqueeze(0).to(device=model_trainer.device),
    )

In [None]:
# Decode the third element of the item and the model's per-step argmax prediction
# back into characters using the dataset's index-to-character mapping.
# NOTE(review): test_batch[2] is labelled "Input SMILES" here but may actually be
# the target sequence — confirm against the dataset's item layout.
print(f"Input SMILES: {''.join([dataset.idx_to_char[c] for c in test_batch[2].cpu().numpy()])}")
print(f"Prediction: {''.join([dataset.idx_to_char[c] for c in output.argmax(dim=-1).squeeze().cpu().numpy()])}")

In [None]:
# Find the positions where the item's second tensor equals 1
# (presumably the indices of the active indications — TODO confirm).
print(torch.where(test_batch[1] == 1)[0])

In [None]:
# Print the dataset's full list of indication names.
print(dataset.indications_names)

In [None]:
def simple_generate(prefix, num_chars, model, indications_tensor, char_to_idx_mapping, idx_to_char_mapping, temperature = 0.0, device=None):
    """
    Generate a string one character at a time from an indication-conditioned model.

    The recurrent state is initialised from ``indications_tensor`` and warmed up
    on the prefix. Characters are then produced one per step, each step feeding
    the most recent character back into the model, until ``num_chars`` total
    characters are reached or the end-of-sequence character is decoded.

    Every character is fed to the model exactly once: the state is warmed up on
    all prefix characters except the last, because the generation loop itself
    feeds the last character of the running string on its first step.

    Args:
        prefix: Seed string; may be empty. Every character must be a key of
            ``char_to_idx_mapping``.
        num_chars: Upper bound on the length of the returned string, prefix included.
        model: Object providing ``eval()``, ``init_state(indications)`` and
            ``model(x, state=...) -> (output, state)``, where ``output`` is
            indexed as ``[batch, step, vocab]``.
        indications_tensor: Conditioning tensor for a single sample; a batch
            dimension is added here before it is passed to ``init_state``.
        char_to_idx_mapping: Maps each character to its vocabulary index; its
            length is used as the one-hot width.
        idx_to_char_mapping: Maps vocabulary indices back to characters.
        temperature: When > 0, the next token is sampled from the softmax of
            the logits divided by this value; otherwise the argmax is taken.
        device: Device the input tensors are moved to.

    Returns:
        The prefix followed by the generated characters. The end-of-sequence
        character ('£') is never included.

    Note:
        The model is left in eval mode after the call.
    """
    vocab_size = len(char_to_idx_mapping)

    def one_hot_batch(indices):
        """One-hot encode token indices as a float tensor with a leading batch dimension."""
        encoded = torch.nn.functional.one_hot(torch.tensor(indices), num_classes=vocab_size)
        return encoded.float().to(device).unsqueeze(0)

    model.eval()
    generated = prefix

    with torch.no_grad():
        # Initialise the recurrent state from the indications (batch dim added).
        state = model.init_state(indications_tensor.unsqueeze(0).to(device))

        # Encode the whole prefix up front so unknown characters fail immediately,
        # but only warm the state up on prefix[:-1]: the last prefix character is
        # fed by the first loop iteration and must not be consumed twice.
        prefix_encoded = [char_to_idx_mapping[c] for c in prefix]
        context_encoded = prefix_encoded[:-1]
        if len(context_encoded) > 0:
            _, state = model(one_hot_batch(context_encoded), state=state)

        for _ in range(num_chars - len(prefix)):
            if len(generated) > 0:
                last_char_idx = char_to_idx_mapping[generated[-1]]
            else:
                # NOTE(review): index 0 acts as the start token when there is no
                # prefix — confirm this matches how training sequences begin.
                last_char_idx = 0

            # Feed the single most recent character and read the logits of the last step.
            output, state = model(one_hot_batch([last_char_idx]), state=state)
            logits = output[0, -1, :]

            if temperature > 0:
                # Temperature-scaled sampling.
                probabilities = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1).item()
            else:
                # Greedy decoding when temperature is 0 (or negative).
                next_token = logits.argmax().item()

            next_char = idx_to_char_mapping[int(next_token)]

            # '£' is the end-of-sequence character; an empty decode also stops generation.
            if next_char == '£' or next_char == '':
                break

            generated += next_char

    return generated

def robust_generate(generate_function, max_attempts: int, **kwargs):
    """
    Call ``generate_function`` repeatedly until it yields a valid SMILES string.

    Args:
        generate_function: Callable invoked as ``generate_function(**kwargs)``
            on every attempt; expected to return a candidate SMILES string.
        max_attempts: Maximum number of calls to ``generate_function``.
        **kwargs: Forwarded unchanged to ``generate_function``.

    Returns:
        The first candidate accepted by ``utilities.validate_smiles_string``,
        or ``None`` when no attempt produced a valid string. Returning ``None``
        rather than the last rejected candidate lets callers tell failure apart
        from success with a simple truthiness / ``is None`` check.
    """
    for _ in range(max_attempts):
        candidate = generate_function(**kwargs)
        if utilities.validate_smiles_string(candidate):
            return candidate

    print(f"Could not generate valid molecular sample in {max_attempts} attemtps. Aborting.")
    return None



In [None]:
# Preview every tenth indication name (idx counts positions in the sliced list, not the full one).
for idx, name in enumerate(dataset.indications_names[::10]):
    print(name)

# Let's generate some medications

In [None]:
n_chars = 100  # NOTE: not read by this cell; generation length is capped by num_chars=500 below

n_valid = 0
n_invalid = 0
images = []
outputs = []
# Positions of every tenth indication in the FULL name list, plus the last entry.
# Enumerating the sliced list would yield 0, 1, 2, ... and select the first
# names instead of every tenth one.
mesh_indices = list(range(0, len(dataset.indications_names), 10)) + [-1]
max_attempts = 5
for i in mesh_indices:

    output = robust_generate(
        simple_generate,
        max_attempts=max_attempts,
        prefix="",
        indications_tensor = dataset.get_indications_tensor(dataset.indications_names[i]).to(model_trainer.device),
        num_chars=500,
        model=model,
        char_to_idx_mapping=dataset.char_to_idx,
        idx_to_char_mapping=dataset.idx_to_char,
        temperature=0.7,
        device=model_trainer.device

    )

    # Only count and draw outputs that actually validate; anything else
    # (None, an empty string, or an invalid string) is recorded as a failure
    # so that images/outputs stay aligned with mesh_indices.
    if output and utilities.validate_smiles_string(output):
        images.append(utilities.draw_molecule(output))
        outputs.append(output)
        n_valid += 1
    else:
        n_invalid += 1
        print("Generated SMILES is not valid.")
        images.append(None)
        outputs.append(None)

In [None]:
# Summary heading followed by one section per requested indication:
# its name, the generated SMILES string and the rendered molecule image.
display(
    Markdown(f"# Generated {n_valid} valid molecules and {n_invalid} invalid SMILES strings out of {len(mesh_indices)} requested molecules."),
    Markdown("## Generated Molecules"),
)
for i, mesh_idx in enumerate(mesh_indices):
    display(
        Markdown(f"### Indication - {dataset.indications_names[mesh_idx]}"),
        Markdown(f"**SMILES:** {outputs[i]}"),
        images[i],
    )

## Sanity check: the "mesh heading other" indication should generate essentially random molecules

In [None]:


n_chars = 100  # NOTE: not read by this cell; generation length is capped by num_chars=500 below

n_valid = 0
n_invalid = 0
images = []
outputs = []
# Ten samples, all conditioned on the last indication in the list.
mesh_indices = [-1] * 10
max_attempts = 5

for i in mesh_indices:

    output = robust_generate(
        simple_generate,
        max_attempts=max_attempts,
        prefix="",
        indications_tensor = dataset.get_indications_tensor(dataset.indications_names[i]).to(model_trainer.device),
        num_chars=500,
        model=model,
        char_to_idx_mapping=dataset.char_to_idx,
        idx_to_char_mapping=dataset.idx_to_char,
        temperature=0.7,
        device=model_trainer.device

    )

    # Only count and draw outputs that actually validate; anything else
    # (None, an empty string, or an invalid string) is recorded as a failure
    # so that images/outputs stay aligned with mesh_indices.
    if output and utilities.validate_smiles_string(output):
        images.append(utilities.draw_molecule(output))
        outputs.append(output)
        n_valid += 1
    else:
        n_invalid += 1
        print("Generated SMILES is not valid.")
        images.append(None)
        outputs.append(None)

In [None]:
# Summary heading followed by one section per sample:
# the indication name, the generated SMILES string and the rendered molecule image.
display(
    Markdown(f"# Generated {n_valid} valid molecules and {n_invalid} invalid SMILES strings out of {len(mesh_indices)} requested molecules."),
    Markdown("## Generated Molecules"),
)
for i, mesh_idx in enumerate(mesh_indices):
    display(
        Markdown(f"### Indication - {dataset.indications_names[mesh_idx]}"),
        Markdown(f"**SMILES:** {outputs[i]}"),
        images[i],
    )