In [None]:
from transformers import GPT2LMHeadModel, AutoTokenizer

# Load the fine-tuned GPT-2 checkpoint and the tokenizer saved alongside it.
model = GPT2LMHeadModel.from_pretrained("./gpt2-finetuned-nwp-final")
tokenizer = AutoTokenizer.from_pretrained("./gpt2-finetuned-nwp-final")

# GPT-2 tokenizers have no padding token by default, and the tokenization cell
# pads with padding="max_length". Fall back to the EOS token only when the saved
# tokenizer does not already define one, so a checkpoint that was saved with its
# own pad token is left untouched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.eval() # Set model to evaluation mode

In [None]:
from datasets import load_dataset

# WikiText-2 dataset, as published on the Hugging Face Hub.
ds = load_dataset("wikitext", "wikitext-2-v1")

def tokenize(examples):
    """Tokenize a batch of raw texts and attach next-token-prediction labels.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``; only the
            ``"text"`` column is read.

    Returns:
        The tokenizer output (every sequence padded/truncated to 256 tokens)
        with an extra ``"labels"`` entry: a copy of ``input_ids`` in which
        padded positions are replaced by -100.
    """
    encoded = tokenizer(examples["text"], max_length=256, padding="max_length", truncation=True)

    # Without a "labels" column Trainer.predict has no label_ids to return, so
    # there is nothing to score predictions against. The attention mask (not
    # the token id) decides what counts as padding, because the pad token may
    # be the same id as a genuine EOS token. -100 is the value the
    # `!= -100` mask in compute_metrics filters on.
    encoded["labels"] = [
        [token_id if is_real else -100 for token_id, is_real in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

tokenized_dataset = ds.map(tokenize, batched=True)

In [None]:
# Use a smaller subset for testing: a fixed-seed shuffle followed by the first
# 100 rows. `Dataset` has no `.range()` method; row subsets are taken with
# `.select(indices)`.
test_ds = tokenized_dataset["test"].shuffle(seed=42).select(range(100))

In [None]:
from transformers import Trainer, TrainingArguments
from transformers import DefaultDataCollator
import torch  # cell-local import: only needed for the CUDA availability check below

data_collator = DefaultDataCollator()

# Half-precision evaluation is only requested when a CUDA device is present;
# TrainingArguments rejects fp16=True on a machine without a supported
# accelerator, which would make this cell fail on CPU-only setups.
# remove_unused_columns=False keeps every dataset column in the batch handed
# to the collator.
training_args = TrainingArguments(
    output_dir="./dummy",
    per_device_eval_batch_size=1,
    fp16=torch.cuda.is_available(),
    eval_accumulation_steps=8,
    remove_unused_columns=False,
)

# The Trainer is used purely as an inference harness here (no train dataset).
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
)

# Predict on test set
outputs = trainer.predict(test_ds)

In [None]:
from sklearn.metrics import top_k_accuracy_score
import torch
import numpy as np
import math

def compute_metrics(eval_pred, k=5):
    """Compute top-k next-token accuracy from model logits and labels.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` has a trailing
            vocabulary axis preceded by a sequence axis; ``labels`` has the
            same leading shape without the vocabulary axis. Positions whose
            label is -100 are excluded from the score.
        k: Number of highest-scoring tokens that count as a correct
            prediction. Defaults to 5.

    Returns:
        A dict with one entry, ``"top{k}_accuracy"`` (``"top5_accuracy"`` by
        default): the fraction of scored positions whose true next token is
        among the k highest logits. The value is NaN when no position is
        scored (every shifted label is -100).
    """
    logits, labels = eval_pred
    metric_name = f"top{k}_accuracy"

    # as_tensor avoids copying data that is already a tensor or a NumPy array.
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)

    # Shift so model predicts token t+1: logits at position t are scored
    # against the label at position t+1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    # Mask out padding / ignored positions.
    valid = shift_labels != -100
    if not valid.any():
        # Nothing to score; an average over zero positions is undefined.
        return {metric_name: float("nan")}

    # Select the scored rows first so only they are ranked.
    y_pred = shift_logits[valid]  # (n_scored, vocab)
    y_true = shift_labels[valid]  # (n_scored,)

    # Rank in float32: half-precision top-k is not available on every CPU build.
    if y_pred.dtype in (torch.float16, torch.bfloat16):
        y_pred = y_pred.float()

    # The vocabulary size comes from the logits themselves, so a checkpoint
    # whose vocabulary differs from stock GPT-2 (e.g. an added pad token) is
    # scored correctly instead of failing on a hard-coded 50257.
    top_k = min(k, y_pred.size(-1))
    top_ids = y_pred.topk(top_k, dim=-1).indices
    hits = (top_ids == y_true.unsqueeze(-1)).any(dim=-1)

    return {
        metric_name: hits.float().mean().item()
    }

In [None]:
from evaluate import load

# Load the `evaluate` perplexity metric.
perplexity = load("perplexity")

# Keep only entries that contain something other than whitespace.
raw_test_texts = [t for t in test_ds["text"] if t.strip() != ""]

# The metric loads the model itself from `model_id`, so the same local
# checkpoint directory used for `model` is passed here.
results = perplexity.compute(
    predictions=raw_test_texts,
    model_id="./gpt2-finetuned-nwp-final",
)

# The metric reports per-text values under "perplexities" and their average
# under "mean_perplexity"; there is no "perplexity" key.
print("Perplexity:", results["mean_perplexity"])