In [1]:
import os
import torch


from main import get_model
from dataset.dataset import LoaderConstructor, create_alicewonderland_dataset, create_rocstories_dataset

from datasets import load_dataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook-wide settings: which corpus to load and how batches/models are sized.
cfg = {
    'dataset': 'alicewonderland',
    'batch_size': 3,
    'max_length': 20,
    'embed_dim': 512,
    'min_text_length': 100,
}

# Load the dataset
if "wikitext" in cfg['dataset']:
    dataset = load_dataset("wikitext", f"{cfg['dataset']}-raw-v1")
    # Keep only rows whose text is longer than the configured minimum.
    for split in dataset.keys():
        dataset[split] = dataset[split].filter(
            lambda x: len(x["text"]) > cfg['min_text_length']
        )
elif cfg['dataset'] == "rocstories":
    dataset = create_rocstories_dataset(os.getcwd())
elif cfg['dataset'] == "alicewonderland":
    dataset = create_alicewonderland_dataset(os.getcwd())
else:
    # Fail here with a clear message instead of a NameError on `dataset` below.
    raise ValueError(
        f"Unsupported dataset {cfg['dataset']!r}; expected a wikitext variant, "
        "'rocstories' or 'alicewonderland'."
    )

# Construct the dataloaders
lc = LoaderConstructor(
    dataset=dataset,
    batch_size=cfg['batch_size'],
    max_length=cfg['max_length'],
    labels_sequence=False,
    # Alice in Wonderland uses min_freq=1; every other dataset uses 3.
    min_freq=1 if cfg['dataset'] == "alicewonderland" else 3,
)
loaders = {}
for loader in ["train", "validation", "test"]:
    loaders[loader] = lc.construct_loader(split=loader)

# Sizes needed to rebuild the models, plus the device used for inference.
input_size = loaders["train"].dataset.input_size
vocab_size = lc.vocab_size
output_size = lc.output_size
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def initialize_model(model_name):
    """Build the requested architecture and load its trained weights.

    Args:
        model_name: Key selecting both the architecture passed to
            ``get_model`` and the checkpoint file; the checkpoint table
            below has entries for 'lstm', 'transformer' and 'xlstm'.

    Returns:
        The model returned by ``get_model`` with the checkpoint's state dict
        loaded, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If ``model_name`` has no entry in the checkpoint table
            (assuming ``get_model`` itself accepted the name).
    """
    model = get_model(
        model=model_name,
        vocab_size=vocab_size,
        embed_dim=cfg['embed_dim'],
        seq_len=input_size,
        output_dim=output_size,
        device=device,
    )

    # Checkpoint paths are relative to the current working directory.
    model_weights = {
        'lstm': f"trained_models/lstm_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
        'transformer': f"trained_models/transformer_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
        'xlstm': f"trained_models/xlstm_{cfg['dataset']}_lr=0_0001_lastepoch.pt",
    }

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved from CUDA tensors still loads on a CPU-only machine.
    state_dict = torch.load(model_weights[model_name], map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model

In [4]:
def evaluate(model, batch, test=False, n_next_words=5):
    """Print the model's prediction for each sample of ``batch``.

    For every sample the decoded input, the decoded label and the decoded
    arg-max prediction are printed. When ``test`` is true, the model is also
    run on a fixed set of hand-written prompts and asked to continue each
    one for ``n_next_words`` steps, feeding every predicted token back in.

    Relies on the notebook-level ``lc``, ``device`` and ``output_size``.

    Args:
        model: Callable model; ``model(inputs)`` must return a tensor that
            can be viewed as ``(-1, output_size)``.
        batch: Mapping with ``"input_ids"`` and ``"labels"`` tensors.
        test: If true, also run the prompt-continuation demo.
        n_next_words: Number of tokens to generate per prompt.

    Returns:
        None. All results are printed.
    """
    # Forward pass
    inputs, labels = batch["input_ids"].to(device), batch["labels"].to(device)
    labels = labels.contiguous().view(-1)
    # Inference only: the output is just arg-maxed and printed, so skip
    # autograd bookkeeping (consistent with the generation loop below).
    with torch.no_grad():
        output = model(inputs).view(-1, output_size)

    # NOTE(review): indexing labels[i] / output[i] per sample assumes one
    # label and one output row per input sequence — confirm against the loader.
    for i in range(batch["input_ids"].shape[0]):
        input_text = lc.tokenizer.decode(inputs[i].tolist(), target=False)
        target_text = lc.tokenizer.decode(labels[i].unsqueeze(0).tolist(), target=True)[0]
        predicted_text = lc.tokenizer.decode(
            torch.argmax(output[i], dim=-1).unsqueeze(0).tolist(), target=True
        )[0]
        print(f"Input sentence: {' '.join(input_text)}")
        print(f"Target sentence: {target_text}")
        print(f"Predicted sentence: {predicted_text}")
        print()

    # Predict next words
    if test:
        mini_dataset = [
            {"text": "She was an extremely fast runner, she went to the forest every day to practice and be", "predictions": []},
            {"text": "The sun went down, making long shadows as the kids played by the big old tree when her mum", "predictions": []},
            {"text": "I would have gotten the new job, but my work wasn't good", "predictions": []},
            {"text": "She read a bedtime story with her parents and then went to sleep to rest for the next", "predictions": []},
            {"text": "Jason didn't understand why his parents wouldn't let him buy his favourite", "predictions": []},
            {"text": "Verena is the most beautiful girl I have ever seen. I love her blue eyes", "predictions": []},
        ]

        # Tokenize the texts
        tokenised_samples = lc.tokenizer.create_tokens(mini_dataset)
        tokenised_samples = lc.tokenizer.pad_sequences(tokenised_samples)
        encodings = lc.tokenizer.encode(tokenised_samples)

        # No labels in this case: the first column is dropped.
        # NOTE(review): presumably this trims to the model's input length — confirm.
        inputs = encodings[:, 1:].to(device)
        decoded_inputs = [lc.tokenizer.decode(sample.tolist(), target=False) for sample in inputs]

        with torch.no_grad():
            for _ in range(n_next_words):
                output = model(inputs)
                next_token = torch.argmax(output, dim=-1)

                # Decode the tokens to get the words
                decoded_tokens = lc.tokenizer.decode(next_token.tolist(), target=True)

                # Encode the tokens to get the input for the next iteration
                encoded_tokens = lc.tokenizer.encode([decoded_tokens]).to(device)

                # Slide the window: drop the oldest token, append the new one
                inputs = torch.cat([inputs[:, 1:], encoded_tokens.reshape(-1, 1)], dim=1)

                for i, token in enumerate(decoded_tokens):
                    mini_dataset[i]["predictions"].append(token)

        for sample, input_sentence in zip(mini_dataset, decoded_inputs):
            print(f"Input sentence: {' '.join(input_sentence)}")
            print(f"Predicted next words: {' '.join(sample['predictions'])}")
            print()


In [5]:
# Run every trained model on one and the same batch taken from the train loader.
banner = '*' * 50
batch = next(iter(loaders["train"]))
for model_name in ("xlstm", "lstm", "transformer"):
    # Header: banner, model name, banner - each on its own line.
    print(banner, f"Evaluating model: {model_name}", banner, sep="\n")
    model = initialize_model(model_name)
    evaluate(model, batch)
    print(banner)

**************************************************
Evaluating model: xlstm
**************************************************
Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> like after the candle is blown out for she could not
Target sentence: remember
Predicted sentence: remember

Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> before seen a rabbit with either a waistcoat pocket or a watch
Target sentence: to
Predicted sentence: to

Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pocket and looked at it and then hurried on alice started
Target sentence: to
Predicted sentence: to

**************************************************
**************************************************
Evaluating model: lstm
**************************************************
Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> like after the candle is blown out for she could not
Target sentence: remember
Predicted sentence: of

Inp

In [8]:
# Run every trained model on one batch from the test loader, then let each
# one continue the hand-written prompts for five tokens.
rule = '*' * 50
test_iterator = iter(loaders["test"])
batch = next(test_iterator)
for model_name in ("xlstm", "lstm", "transformer"):
    header = "\n".join([rule, f"Evaluating model: {model_name}", rule])
    print(header)
    model = initialize_model(model_name)
    evaluate(model, batch, test=True, n_next_words=5)
    print(rule)

**************************************************
Evaluating model: xlstm
**************************************************
Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> and so she went on taking first one side and then the
Target sentence: other
Predicted sentence: cake

Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> remarking as it went one side will make you grow <oov>
Target sentence: and
Predicted sentence: i

Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> but it goes on they all <oov> from him to you
Target sentence: said
Predicted sentence: re

Input sentence: <pad> <pad> <pad> she was an extremely fast <oov> she went to the <oov> every day to practice and be
Predicted next words: going are old woman boon

Input sentence: <pad> the sun went down making long <oov> as the <oov> <oov> by the <oov> old tree when her <oov>
Predicted next words: i eat i eat i

Input sentence: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <p