In [1]:
import os
import torch

from main import get_model
from dataset.dataset import (
    LoaderConstructor,
    create_poetryfoundation_dataset,
    create_poems_txt_dataset,
)

In [None]:
# Evaluation settings. "dataset" also selects which checkpoint files are
# loaded later, because the checkpoint paths embed this name.
cfg = {
    "dataset": "poetryfoundation",
    "batch_size": 3,
    "max_length": 50,
    "embed_dim": 512,
    # NOTE(review): not read anywhere in the visible cells — confirm whether
    # the dataset factories are supposed to receive it.
    "min_text_length": 50,
}

# Load the dataset
if cfg["dataset"] == "poetryfoundation":
    dataset = create_poetryfoundation_dataset(os.getcwd())
elif cfg["dataset"] == "poems_txt":
    dataset = create_poems_txt_dataset(os.getcwd())
else:
    # Fail here with a clear message instead of hitting a NameError on
    # `dataset` a few lines below.
    raise ValueError(
        f"Unsupported dataset {cfg['dataset']!r}; "
        "expected 'poetryfoundation' or 'poems_txt'."
    )

# Construct the dataloaders
lc = LoaderConstructor(
    dataset=dataset,
    batch_size=cfg["batch_size"],
    max_length=cfg["max_length"],
    labels_sequence=False,
    # presumably the minimum token frequency for vocabulary inclusion — TODO confirm
    min_freq=3,
)
loaders = {}
for loader in ["train", "validation", "test"]:
    loaders[loader] = lc.construct_loader(split=loader)

# Sizes needed to rebuild the models exactly as they were trained.
input_size = loaders["train"].dataset.input_size
vocab_size = lc.vocab_size
output_size = lc.output_size
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def initialize_model(model_name):
    """Build the requested model and load its trained weights.

    Args:
        model_name: Selects both the architecture passed to ``get_model`` and
            the checkpoint file; "lstm" and "transformer" have checkpoints.

    Returns:
        The model with its state dict loaded, moved to ``device`` and put
        in evaluation mode.

    Raises:
        KeyError: If ``model_name`` has no entry in the checkpoint table.
    """
    model = get_model(
        model=model_name,
        vocab_size=vocab_size,
        embed_dim=cfg["embed_dim"],
        seq_len=input_size,
        output_dim=output_size,
        device=device,
    )

    # TODO: consider evaluating lstm_poetryfoundation_lr=0_001_best.pt instead
    # of the last-epoch checkpoint, and revisit the learning rate.
    model_weights = {
        "lstm": f"trained_models/lstm_{cfg['dataset']}_lr=0_001_lastepoch.pt",
        "transformer": f"trained_models/transformer_{cfg['dataset']}_lr=0_001_lastepoch.pt",
    }

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved from a CUDA run still loads on a CPU-only machine.
    state_dict = torch.load(model_weights[model_name], map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model


In [4]:
def evaluate(model, batch, test=False, n_next_words=5):
    """Print the model's predictions for a batch and, optionally, for new lines.

    For every sample in ``batch`` the decoded input, the target word and the
    predicted word are printed. When ``test`` is true, the model additionally
    extends five hard-coded poetic lines by ``n_next_words`` words each, using
    greedy (argmax) decoding, and prints the result.

    Args:
        model: Model to run; called as ``model(inputs)``.
        batch: Mapping with "input_ids" and "labels" tensors.
        test: If true, also run the free-running next-word generation demo.
        n_next_words: Number of words to generate per line in the demo.

    Returns:
        None. All results are printed.
    """
    inputs, labels = batch["input_ids"].to(device), batch["labels"].to(device)
    labels = labels.contiguous().view(-1)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(inputs).view(-1, output_size)

    # Indexing labels[i] / output[i] per sample assumes one target per sample
    # (the loaders are built with labels_sequence=False) — TODO confirm.
    for i in range(inputs.shape[0]):
        input_text = lc.tokenizer.decode(inputs[i].tolist(), target=False)
        target_text = lc.tokenizer.decode(labels[i].unsqueeze(0).tolist(), target=True)[0]
        predicted_text = lc.tokenizer.decode(
            torch.argmax(output[i], dim=-1).unsqueeze(0).tolist(), target=True
        )[0]

        print(f"Input: {' '.join(input_text)}")
        print(f"Target: {target_text}")
        print(f"Predicted: {predicted_text}")
        print()

    # Predict next words for new poetic lines
    if test:
        mini_dataset = [
            {"text": "The moonlight danced upon the silent lake", "predictions": []},
            {"text": "A whisper of wind through the autumn leaves", "predictions": []},
            {"text": "The poet's heart was filled with endless wonder", "predictions": []},
            {"text": "Beneath the stars, she walked alone", "predictions": []},
            {"text": "His voice was music, soft and deep", "predictions": []},
        ]

        # Tokenize the texts
        tokenised_samples = lc.tokenizer.create_tokens(mini_dataset)
        tokenised_samples = lc.tokenizer.pad_sequences(tokenised_samples)
        encodings = lc.tokenizer.encode(tokenised_samples)

        # No labels in this case. The first column is dropped — presumably so
        # the width matches the model's input length; TODO confirm.
        inputs = encodings[:, 1:].to(device)
        decoded_inputs = [lc.tokenizer.decode(sample.tolist(), target=False) for sample in inputs]

        with torch.no_grad():
            for _ in range(n_next_words):
                output = model(inputs)
                next_token = torch.argmax(output, dim=-1)

                # Decode the tokens to get the words
                decoded_tokens = lc.tokenizer.decode(next_token.tolist(), target=True)

                # Encode the tokens to get the input for the next iteration.
                # The round trip through text presumably converts target ids
                # into input ids — TODO confirm.
                encoded_tokens = lc.tokenizer.encode([decoded_tokens]).to(device)

                # Slide the window: drop the oldest token, append the new one.
                inputs = torch.cat([inputs[:, 1:], encoded_tokens.reshape(-1, 1)], dim=1)

                for i, token in enumerate(decoded_tokens):
                    mini_dataset[i]["predictions"].append(token)

        for sample, input_sentence in zip(mini_dataset, decoded_inputs):
            print(f"Input: {' '.join(input_sentence)}")
            print(f"Predicted next words: {' '.join(sample['predictions'])}")
            print()

In [None]:
# **Run Evaluation**
batch = next(iter(loaders["train"]))
for model_name in ["lstm", "transformer"]:
    print("*" * 50)
    print(f"Evaluating model: {model_name}")
    print("*" * 50)
    model = initialize_model(model_name)
    evaluate(model, batch)
    print("*" * 50)


**************************************************
Evaluating model: lstm
**************************************************


Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> i ll open the
Target: window
Predicted: window

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> witch
Target: wife
Predicted: doctor

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> between the house and the
Target: hill
Predicted: hill

**************************************************
**************************************************
Evaluating model: transformer
**************************************************
Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> i ll open the
Target: window
Predicted: window

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> witch
Target: wife
Predicted: doctor

Input: <pad> <pad> <pad> <pad> <pad> <pad> <

In [None]:
# Compare both trained models on one batch from the test split, then let each
# model continue the hand-written demo lines by five words.
separator = "*" * 50
batch = next(iter(loaders["test"]))
for model_name in ("lstm", "transformer"):
    # Header: separator, model name, separator — one line each.
    print(separator, f"Evaluating model: {model_name}", separator, sep="\n")
    model = initialize_model(model_name)
    evaluate(model, batch, test=True, n_next_words=5)
    print(separator)