In [7]:
import os
import torch


from main import get_model
from dataset.dataset import LoaderConstructor, create_alicewonderland_dataset, create_rocstories_dataset

from datasets import load_dataset


In [8]:
# Run configuration shared by the data pipeline and the model constructors.
cfg = {
    'dataset': 'rocstories',
    'batch_size': 3,
    'max_length': 20,
    'embed_dim': 512,
    'min_text_length': 100,
    'tokenizer': 'torchtext'
}

# Load the dataset
if "wikitext" in cfg['dataset']:
    # The Hugging Face config name is built as "<dataset>-raw-v1".
    dataset = load_dataset("wikitext", f"{cfg['dataset']}-raw-v1")
    for split in dataset.keys():
        # Keep only rows whose text is longer than the configured minimum.
        # cfg is a plain dict, so the threshold is read by key, not attribute.
        dataset[split] = dataset[split].filter(
            lambda x: len(x["text"]) > cfg['min_text_length']
        )

elif cfg['dataset'] == "rocstories":
    dataset = create_rocstories_dataset(os.getcwd())
elif cfg['dataset'] == "alicewonderland":
    dataset = create_alicewonderland_dataset(os.getcwd())
else:
    # Fail here with a clear message instead of a NameError on `dataset` below.
    raise ValueError(
        f"Unknown dataset {cfg['dataset']!r}; expected a 'wikitext*' name, "
        "'rocstories' or 'alicewonderland'."
    )

# Construct the dataloaders
lc = LoaderConstructor(
    dataset=dataset,
    batch_size=cfg['batch_size'],
    max_length=cfg['max_length'],
    tokenizer_type=cfg['tokenizer'],
    labels_sequence=False,
)
loaders = {}
for loader in ["train", "validation", "test"]:
    loaders[loader] = lc.construct_loader(split=loader)

# Sizes needed to build the models, taken from the constructed pipeline.
input_size = loaders["train"].dataset.input_size
vocab_size = lc.vocab_size
output_size = lc.output_size
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
def initialize_model(model_name):
    """Build the named model, load its trained weights and switch to eval mode.

    Args:
        model_name: One of 'lstm', 'transformer' or 'xlstm'; used both as the
            `model` argument of `get_model` and as the key selecting the
            checkpoint file.

    Returns:
        The model, moved to `device` and set to evaluation mode.

    Raises:
        KeyError: If `model_name` has no entry in the checkpoint table.
    """
    model = get_model(
        model=model_name,
        vocab_size=vocab_size,
        embed_dim=cfg['embed_dim'],
        seq_len=input_size,
        output_dim=output_size,
        device=device,
    )

    # Checkpoint paths are relative to the current working directory.
    model_weights = {
        'lstm': f"trained_models/lstm_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
        'transformer': f"trained_models/transformer_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
        'xlstm': f"trained_models/xlstm_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
    }

    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only machine) onto the device in use here.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_weights[model_name], map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model

In [16]:
def evaluate(model, batch, test=False, n_next_words=5):
    """Print model predictions for a batch and, optionally, for fixed prompts.

    For every sample in `batch` the decoded input, the decoded target and the
    decoded arg-max prediction are printed. When `test` is true, a small set of
    hard-coded prompts is additionally continued greedily for `n_next_words`
    steps and the generated words are printed.

    Args:
        model: Callable mapping a tensor of input ids to output scores.
        batch: Mapping with "input_ids" and "labels" tensors.
            # presumably one label per sample (loader built with
            # labels_sequence=False) - TODO confirm
        test: If true, also run greedy next-word prediction on the prompts.
        n_next_words: Number of words to generate per prompt when `test` is set.

    Returns:
        None. All results are printed.
    """
    inputs, labels = batch["input_ids"].to(device), batch["labels"].to(device)
    labels = labels.contiguous().view(-1)

    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        output = model(inputs).view(-1, output_size)

    for i in range(inputs.shape[0]):
        print(f"Input sentence: {lc.tokenizer.decode(inputs[i].tolist(), target=False)}")
        print(
            f"Target sentence: {lc.tokenizer.decode(labels[i].unsqueeze(0).tolist(), target=True)}"
        )
        print(
            f"Predicted sentence: {lc.tokenizer.decode(torch.argmax(output[i], dim=-1).unsqueeze(0).tolist(), target=True)}"
        )
        print()

    # Predict next words
    if test:
        # NOTE(review): "mathemathics" is misspelled; left unchanged because it
        # is model input and editing it would change the generated text.
        mini_dataset = [
            {"text": "The quick brown fox jumps over the lazy dog and then went", "predictions": []},
            {"text": "I can't wait to", "predictions": []},
            {"text": "The capital of France is Paris, a city full of", "predictions": []},
            {"text": "The best way to learn mathemathics", "predictions": []},
            {"text": "I am so mad, my boss is always", "predictions": []},
        ]

        # Tokenize the texts
        tokenised_samples = lc.tokenizer.create_tokens(mini_dataset)
        tokenised_samples = lc.tokenizer.pad_sequences(tokenised_samples)
        encodings = lc.tokenizer.encode(tokenised_samples)

        # No labels in this case
        # The first column of the encodings is dropped.
        # presumably to match the model's input length - TODO confirm
        inputs = encodings[:, 1:].to(device)
        with torch.no_grad():
            for _ in range(n_next_words):
                output = model(inputs)
                # Greedy decoding: take the highest-scoring token per prompt.
                next_token = torch.argmax(output, dim=-1)
                for i in range(len(mini_dataset)):
                    mini_dataset[i]["predictions"].append(
                        lc.tokenizer.decode(next_token[i].unsqueeze(0).tolist(), target=True)
                    )
                # Slide the window: drop the oldest token, append the new one.
                inputs = torch.cat([inputs[:, 1:], next_token.reshape(-1, 1)], dim=1)

        for sample in mini_dataset:
            print(f"Input sentence: {sample['text']}")
            print(f"Predicted next words: {' '.join(sample['predictions'])}")
            print()


In [21]:
# Qualitative check: run every trained model on the first training batch.
batch = next(iter(loaders["train"]))
banner = '*' * 50
for model_name in ("xlstm", "lstm", "transformer"):
    # Header: banner line, model name, banner line.
    print(banner, f"Evaluating model: {model_name}", banner, sep="\n")
    model = initialize_model(model_name)
    evaluate(model, batch)
    print(banner)

**************************************************
Evaluating model: xlstm
**************************************************
Input sentence: jane was asleep when a thunderstorm started outside she didn realize that the electricity went out when she woke up
Target sentence: it
Predicted sentence: it

Input sentence: tim and jessica were in a long term relationship together one day jessica ended the relationship tim cried and wept
Target sentence: for
Predicted sentence: for

Input sentence: his bed shook he turned on the news and learned about the <oov> jake was calm and thought it was
Target sentence: interesting
Predicted sentence: interesting

**************************************************
**************************************************
Evaluating model: lstm
**************************************************
Input sentence: jane was asleep when a thunderstorm started outside she didn realize that the electricity went out when she woke up
Target sentence: it
Predicted sentenc

In [20]:
# Run every trained model on the first test batch, then let each one continue
# the fixed prompts for 20 words.
batch = next(iter(loaders["test"]))
banner = '*' * 50
for model_name in ("xlstm", "lstm", "transformer"):
    # Header: banner line, model name, banner line.
    print(banner, f"Evaluating model: {model_name}", banner, sep="\n")
    model = initialize_model(model_name)
    evaluate(model, batch, test=True, n_next_words=20)
    print(banner)

**************************************************
Evaluating model: xlstm
**************************************************
Input sentence: stormy asked if i could be her new mama she asked because she was little and didn know who her
Target sentence: mother
Predicted sentence: phone

Input sentence: mary was an animal care specialist at a dog and cat rescue shelter she could usually work with even the
Target sentence: most
Predicted sentence: one

Input sentence: fred exercised a lot and was a very healthy person he enjoyed running every evening one evening he felt a
Target sentence: little
Predicted sentence: cold

Input sentence: The quick brown fox jumps over the lazy dog and then went
Predicted next words: to had the were to had the were to but as up everyone with got son was talk last the

Input sentence: I can't wait to
Predicted next words: never to now about one to she at died to could his to made a job a food ryan later

Input sentence: The capital of France is Paris, a cit