In [28]:
import os
import torch
import torch.nn.functional as F
import torchmetrics
import math
import json
import pandas as pd

from main import get_model
from dataset.dataset import (
    LoaderConstructor,
    create_poetryfoundation_dataset,
    create_story_txt_dataset
)

In [29]:
# Evaluation settings, one entry per dataset; every other setting is shared.
cfg = [
    dict(
        dataset=dataset_name,
        batch_size=3,
        max_length=50,
        embed_dim=512,
        min_text_length=50,
    )
    for dataset_name in ("poetryfoundation", "story_txt")
]

def load_and_construct(cfg):
    """Load ``cfg['dataset']`` and build its train/validation/test dataloaders.

    Args:
        cfg: dict with keys "dataset" ("poetryfoundation" or "story_txt"),
            "batch_size", "max_length" and, optionally, "min_freq" (forwarded to
            ``LoaderConstructor``; defaults to 3). NOTE(review): ``min_freq`` presumably
            controls the vocabulary cut-off, so it should match the value used at
            training time — a vocabulary mismatch shows up later as size-mismatch
            errors in ``load_state_dict``.

    Returns:
        Tuple ``(dataset, lc, loaders, input_size, vocab_size, output_size, device)``.
        ``loaders`` maps "train"/"validation"/"test" to the loaders built by ``lc``.
        ``device`` is the first available of "cuda", "mps", "cpu".

    Raises:
        ValueError: if ``cfg['dataset']`` is not a known dataset name.
        AttributeError: if the train dataset exposes no ``input_size``.
    """
    print(f"\nEvaluating dataset: {cfg['dataset']}")

    # Resolve the dataset name before touching any loader machinery.
    if cfg["dataset"] == "poetryfoundation":
        dataset = create_poetryfoundation_dataset(os.getcwd())
    elif cfg["dataset"] == "story_txt":
        dataset = create_story_txt_dataset(os.getcwd())
    else:
        raise ValueError(f"Unknown dataset: {cfg['dataset']}")

    lc = LoaderConstructor(
        dataset=dataset,
        batch_size=cfg["batch_size"],
        max_length=cfg["max_length"],
        labels_sequence=False,
        min_freq=cfg.get("min_freq", 3),
    )

    loaders = {split: lc.construct_loader(split=split) for split in ("train", "validation", "test")}

    input_size = getattr(loaders["train"].dataset, "input_size", None)
    if input_size is None:
        raise AttributeError("Dataset object does not have 'input_size' attribute. Check implementation.")

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    return dataset, lc, loaders, input_size, lc.vocab_size, lc.output_size, device


In [30]:
def get_best_model_weights(
    dataset,
    model_types=("lstm", "transformer"),
    lr_tags=("0_001", "0_0001", "0_0005"),
    weights_dir="trained_models",
):
    """Pick, per model type, the checkpoint with the highest validation accuracy.

    For every ``model_type`` and learning-rate tag, reads
    ``{weights_dir}/{model_type}_{dataset}_lr={tag}_stats.json`` if it exists.
    It then compares every record's ``val_accuracy``.

    Args:
        dataset: dataset name used in the stats file names (e.g. "story_txt").
        model_types: architectures to look up.
        lr_tags: learning-rate tags as they appear in the file names ("." written as "_").
        weights_dir: directory holding the stats files.

    Returns:
        dict mapping model type -> the ``best_model`` value of its best record.
        Model types without any readable stats record are omitted.
    """
    best_models = {}

    for model_name in model_types:
        best_model = None
        best_accuracy = float("-inf")

        for lr_tag in lr_tags:
            stats_file = os.path.join(weights_dir, f"{model_name}_{dataset}_lr={lr_tag}_stats.json")
            if not os.path.exists(stats_file):
                continue
            with open(stats_file, "r") as f:
                stats = json.load(f)

            # Stats are normally nested ({key: {"val_accuracy": ..., "best_model": ...}});
            # a single flat record is accepted as well.
            records = [stats] if "val_accuracy" in stats else stats.values()
            for details in records:
                # Skip non-dict entries and records without a usable accuracy instead of crashing.
                if not isinstance(details, dict) or details.get("val_accuracy") is None:
                    continue
                if details["val_accuracy"] > best_accuracy:
                    best_accuracy = details["val_accuracy"]
                    best_model = details["best_model"]

        if best_model:
            best_models[model_name] = best_model

    return best_models

In [31]:
# Best checkpoint path per dataset and architecture, looked up as model_weights[dataset][model_name].
model_weights = {
    dataset_name: get_best_model_weights(dataset_name)
    for dataset_name in ("poetryfoundation", "story_txt")
}

def initialize_model(model_name, vocab_size, input_size, output_size, device, cfg, weights=None):
    """Build ``model_name`` via ``get_model`` and load its best checkpoint for ``cfg['dataset']``.

    Args:
        model_name: architecture name understood by ``get_model`` (e.g. "lstm").
        vocab_size, input_size, output_size: model dimensions from ``load_and_construct``.
        device: device to build and load the model on.
        cfg: dict providing "dataset" and "embed_dim".
        weights: optional mapping ``{dataset: {model_name: checkpoint_path}}``.
            Defaults to the module-level ``model_weights``.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        KeyError: no checkpoint is known for this dataset/model pair.
        RuntimeError: the checkpoint's tensor shapes do not match the freshly built model.
    """
    weights = model_weights if weights is None else weights
    dataset_name = cfg["dataset"]

    # Fail fast, before building the model, if no checkpoint was found.
    try:
        best_model_weight = weights[dataset_name][model_name]
    except KeyError:
        raise KeyError(
            f"No trained checkpoint found for model {model_name!r} on dataset {dataset_name!r}"
        ) from None

    model = get_model(
        model=model_name,
        vocab_size=vocab_size,
        embed_dim=cfg["embed_dim"],
        seq_len=input_size,
        output_dim=output_size,
        device=device,
    )

    # map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (CPU/MPS).
    # NOTE(review): consider weights_only=True on torch versions that support it.
    state_dict = torch.load(best_model_weight, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        raise RuntimeError(
            f"Checkpoint {best_model_weight!r} does not fit the {model_name} built for "
            f"{dataset_name!r} (vocab_size={vocab_size}, output_size={output_size}). "
            "The vocabulary may differ from the one used at training time; check the "
            f"dataset/loader settings.\n{err}"
        ) from err

    model.to(device).eval()
    return model


In [32]:
def calculate_perplexity(loss):
    """Convert a mean cross-entropy loss (natural log) to perplexity, ``exp(loss)``.

    Returns ``math.inf`` instead of raising ``OverflowError`` when the loss is too
    large for ``math.exp`` (above roughly 709), e.g. for an untrained or diverged model.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return math.inf

# Prompts used by the next-word demo when the caller supplies none.
_DEFAULT_PROMPTS = (
    "The moonlight danced upon the silent lake",
    "A whisper of wind through the autumn leaves",
    "The poet's heart was filled with endless wonder",
    "Beneath the stars, she walked alone",
    "His voice was music, soft and deep",
)


def _generate_next_words(model, lc, prompts, n_next_words, device):
    """Greedily extend each prompt by ``n_next_words`` tokens, print, and return the results.

    Returns a list of ``{"text": prompt, "predictions": [token, ...]}`` dicts in prompt order.
    """
    samples = [{"text": text, "predictions": []} for text in prompts]

    # Tokenize, pad and encode with the loader's own tokenizer.
    tokenised_samples = lc.tokenizer.create_tokens(samples)
    tokenised_samples = lc.tokenizer.pad_sequences(tokenised_samples)
    encodings = lc.tokenizer.encode(tokenised_samples)

    # No labels here: drop the first column so the window has the model's input length.
    # NOTE(review): assumes pad_sequences yields one column more than the model expects — confirm.
    inputs = encodings[:, 1:].to(device)
    decoded_inputs = [lc.tokenizer.decode(sample.tolist(), target=False) for sample in inputs]

    with torch.no_grad():
        for _ in range(n_next_words):
            next_token = torch.argmax(model(inputs), dim=-1)

            # Decode predictions with target=True, then re-encode them as input ids.
            # The output and input vocabularies are sized separately and may differ.
            decoded_tokens = lc.tokenizer.decode(next_token.tolist(), target=True)
            encoded_tokens = lc.tokenizer.encode([decoded_tokens]).to(device)

            # Slide the context window: drop the oldest token and append the prediction.
            inputs = torch.cat([inputs[:, 1:], encoded_tokens.reshape(-1, 1)], dim=1)

            for i, token in enumerate(decoded_tokens):
                samples[i]["predictions"].append(token)

    for sample, input_sentence in zip(samples, decoded_inputs):
        print(f"Input: {' '.join(input_sentence)}")
        print(f"Predicted next words: {' '.join(sample['predictions'])}")
        print()
    return samples


def evaluate(model, batch, output_size, device, lc, test=False, n_next_words=5, prompts=None, top_k=3):
    """Evaluate ``model`` on one batch and optionally show greedy next-word predictions.

    Prints loss, accuracy, perplexity and top-k accuracy for ``batch``.

    Args:
        model: network mapping ``batch["input_ids"]`` to logits that reshape to ``(-1, output_size)``.
        batch: dict with "input_ids" and "labels" tensors.
        output_size: number of output classes.
        device: device the model lives on.
        lc: loader constructor whose ``tokenizer`` is used by the generation demo.
        test: if True, also generate ``n_next_words`` continuations for ``prompts``.
        n_next_words: number of tokens to generate per prompt.
        prompts: iterable of prompt strings; defaults to five short poetic lines.
        top_k: k used for the top-k accuracy metric (should not exceed ``output_size``).

    Returns:
        dict with "loss", "accuracy", "perplexity" and "top_k_accuracy".
        When ``test`` is True it also has "predictions": a list of
        ``{"text": ..., "predictions": [...]}`` dicts.
    """
    criterion = torch.nn.CrossEntropyLoss()
    accuracy_metric = torchmetrics.Accuracy(task="multiclass", num_classes=output_size).to(device)
    top_k_metric = torchmetrics.Accuracy(task="multiclass", num_classes=output_size, top_k=top_k).to(device)

    model.eval()
    inputs, labels = batch["input_ids"].to(device), batch["labels"].to(device)
    labels = labels.reshape(-1)

    # Pure evaluation: no autograd graph is needed for the forward pass or the metrics.
    with torch.no_grad():
        output = model(inputs).reshape(-1, output_size)
        loss = criterion(output, labels).item()
        accuracy = accuracy_metric(torch.argmax(output, dim=-1), labels).item()
        top_k_accuracy = top_k_metric(output, labels).item()
    perplexity = calculate_perplexity(loss)

    print(f"Loss: {loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Perplexity: {perplexity:.4f}")
    print(f"Top-{top_k} Accuracy: {top_k_accuracy:.4f}")

    results = {
        "loss": loss,
        "accuracy": accuracy,
        "perplexity": perplexity,
        "top_k_accuracy": top_k_accuracy,
    }
    if test:
        results["predictions"] = _generate_next_words(
            model, lc, _DEFAULT_PROMPTS if prompts is None else prompts, n_next_words, device
        )
    return results

In [33]:
# **Run Evaluation**
# Sanity check on a single *training* batch (not a held-out estimate) for each architecture.
poetry_cfg = cfg[0]
dataset, lc, loaders, input_size, vocab_size, output_size, device = load_and_construct(poetry_cfg)
batch = next(iter(loaders["train"]))
separator = "*" * 50
for model_name in ("lstm", "transformer"):
    print(separator, f"Evaluating model: {model_name}", separator, sep="\n")
    model = initialize_model(model_name, vocab_size, input_size, output_size, device, poetry_cfg)
    evaluate(model, batch, output_size, device, lc)
    print(separator)



Evaluating dataset: poetryfoundation
**************************************************
Evaluating model: lstm
**************************************************


RuntimeError: Error(s) in loading state_dict for LSTM:
	size mismatch for embedding.weight: copying a param with shape torch.Size([2060, 512]) from checkpoint, the shape in current model is torch.Size([2138, 512]).
	size mismatch for fc.weight: copying a param with shape torch.Size([2058, 512]) from checkpoint, the shape in current model is torch.Size([2136, 512]).
	size mismatch for fc.bias: copying a param with shape torch.Size([2058]) from checkpoint, the shape in current model is torch.Size([2136]).

In [None]:
# One *test* batch per architecture, plus 3 greedily generated words per prompt.
# Relies on lc/loaders/sizes built by the poetry evaluation cell above.
batch = next(iter(loaders["test"]))
stars = "*" * 50
for model_name in ("lstm", "transformer"):
    print(stars, f"Evaluating model: {model_name}", stars, sep="\n")
    model = initialize_model(model_name, vocab_size, input_size, output_size, device, cfg[0])
    evaluate(model, batch, output_size, device, lc, test=True, n_next_words=3)
    print(stars)

**************************************************
Evaluating model: lstm
**************************************************
Loss: 7.8521
Accuracy: 0.0000
Perplexity: 2571.0934
Top-3 Accuracy: 0.0000
Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> the moonlight <oov> upon the silent lake
Predicted next words: song life life

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> a <oov> of wind through the autumn leaves
Predicted next words: world year day

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <p

In [None]:
# **Run Evaluation** for story dataset
# Note: rebinds dataset/lc/loaders/sizes, replacing the poetry ones from earlier cells.
story_cfg = cfg[1]
dataset, lc, loaders, input_size, vocab_size, output_size, device = load_and_construct(story_cfg)
batch = next(iter(loaders["train"]))
separator = "*" * 50
for model_name in ("lstm", "transformer"):
    print(separator, f"Evaluating model: {model_name}", separator, sep="\n")
    model = initialize_model(model_name, vocab_size, input_size, output_size, device, story_cfg)
    evaluate(model, batch, output_size, device, lc)
    print(separator)


Evaluating dataset: story_txt
**************************************************
Evaluating model: transformer
**************************************************
Loss: 0.1580
Accuracy: 1.0000
Perplexity: 1.1712
Top-3 Accuracy: 1.0000
**************************************************


In [None]:
# One *test* batch of the story dataset per architecture, plus 3 generated words per prompt.
# Relies on lc/loaders/sizes built by the story evaluation cell above.
batch = next(iter(loaders["test"]))
stars = "*" * 50
for model_name in ("lstm", "transformer"):
    print(stars, f"Evaluating model: {model_name}", stars, sep="\n")
    model = initialize_model(model_name, vocab_size, input_size, output_size, device, cfg[1])
    evaluate(model, batch, output_size, device, lc, test=True, n_next_words=3)
    print(stars)

**************************************************
Evaluating model: transformer
**************************************************
Loss: 0.4494
Accuracy: 0.6667
Perplexity: 1.5673
Top-3 Accuracy: 1.0000
Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <oov> <oov> upon <oov> <oov>
Predicted next words: . ] ?

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> whisper wind <oov> leaves
Predicted next words: ? ? ?

Input: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <p