In [None]:
# notebook/00_setup.ipynb
'''
Load the GPT-2 tokenizer and model, plus small subsets of WikiText-2 and of the English SST-2 sentiment-classification dataset (GLUE).
Run a baseline evaluation of pretrained GPT-2: perplexity on WikiText-2 and zero-shot prompt scoring (accuracy) on SST-2.
'''
import math
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
# Prefer the GPU when CUDA is available; everything below runs on `device`.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)

# --------- Load model & tokenizer ----------
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load weights, move them to the target device, and switch to evaluation mode
# (disables dropout) since this notebook only evaluates the model.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)
model.eval()

![AutoModel + AutoTokenizer diagram](../media/automodalcausallm_tokenizer.png)

In [None]:
# GPT-2 ships without a pad token; alias it to EOS so sequences can be batched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    # Aliasing adds no new vocabulary entry, so len(tokenizer) equals the current
    # embedding size and this resize leaves the weights as they are.
    model.resize_token_embeddings(new_num_tokens=len(tokenizer))

# ---------- Helper functions ----------
def batchify_tokenized(batch, pad_token_id=None):
    """Collate tokenized examples into a right-padded batch.

    Args:
        batch: list of dicts, each holding an ``"input_ids"`` list of token ids.
        pad_token_id: id used to fill padding positions. Defaults to the notebook
            tokenizer's ``pad_token_id``.

    Returns:
        ``(input_ids, attention_mask)``: two ``(batch_size, max_len)`` long tensors.
        The mask is 1 at real-token positions and 0 at padding positions.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in batch]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_list,
        batch_first=True,
        padding_value=pad_token_id,
    )
    # Derive the mask from true lengths instead of comparing against pad_token_id:
    # the pad id can also be a real token (GPT-2 pads with EOS here), and a
    # value comparison would silently mask those genuine tokens out.
    lengths = torch.tensor([len(ids) for ids in input_ids_list], dtype=torch.long)
    positions = torch.arange(input_ids.size(1), dtype=torch.long)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
    return input_ids, attention_mask


# Tokenize examples (variable-length sequences, capped at the model's context size)
def tokenize_wiki(example, max_length=None):
    """Tokenize one WikiText row into token ids, without adding special tokens.

    Args:
        example: a dataset row with a ``"text"`` field.
        max_length: maximum number of tokens to keep. Defaults to the tokenizer's
            ``model_max_length``. GPT-2 has a fixed number of learned positions,
            so a longer input would fail in the forward pass.

    Returns:
        dict with a single ``"input_ids"`` list (truncated to ``max_length``).
    """
    if max_length is None:
        max_length = tokenizer.model_max_length
    enc = tokenizer(
        example["text"],
        add_special_tokens=False,
        truncation=True,
        max_length=max_length,
    )
    return {"input_ids": enc["input_ids"]}

In [None]:
# ---------- WikiText-2 Loading ----------
# Load a small subset of the validation split. wikitext-2-raw-v1 keeps the original whitespace.
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

# Subsample: choose a small number for quick runs on Colab; increase if needed.
NUM_WIKI_EXAMPLES = 512
wikitext = wikitext.select(range(min(len(wikitext), NUM_WIKI_EXAMPLES)))

# A row needs at least 2 tokens to yield a next-token prediction. WikiText rows
# include blank lines (empty strings) that tokenize to nothing: they add no scored
# tokens, and a batch made only of them would reach the model as a zero-length input.
MIN_WIKI_TOKENS = 2
wikitext_tok = (
    wikitext
    .map(tokenize_wiki, remove_columns=wikitext.column_names)
    .filter(lambda ex: len(ex["input_ids"]) >= MIN_WIKI_TOKENS)
)

# DataLoader (no shuffling, so the evaluation order is deterministic)
BATCH_SIZE_WIKI = 8
wiki_loader = DataLoader(wikitext_tok, batch_size=BATCH_SIZE_WIKI, collate_fn=batchify_tokenized)

In [None]:
# ---------- WikiText-2 Perplexity ----------
# Corpus perplexity = exp(total NLL / number of predicted tokens).
# Only real (non-padding) tokens are scored as targets, and a sequence of length L
# contributes L - 1 predictions because its first token has no preceding context.
total_nll = 0.0
total_tokens = 0

with torch.no_grad():
    for input_ids, attention_mask in wiki_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Logits at position t predict the token at t + 1, so targets are shifted left by one.
        # The pad id can coincide with a real token id (GPT-2 pads with EOS here), so padded
        # targets are replaced by the cross-entropy ignore index rather than scored.
        shift_labels = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
        n_tokens = int((shift_labels != -100).sum().item())
        if n_tokens == 0:
            continue  # nothing scorable in this batch (only empty / single-token rows)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        shift_logits = logits[:, :-1, :]
        # Sum (not mean) so batches with different token counts combine exactly.
        batch_nll = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)).float(),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )

        total_nll += batch_nll.item()
        total_tokens += n_tokens

if total_tokens == 0:
    raise RuntimeError("No scorable WikiText tokens: every sequence had fewer than 2 tokens.")

per_token_loss = total_nll / total_tokens
perplexity = math.exp(per_token_loss)
print(f"WikiText-2 subset tokens: {total_tokens}")
print(f"Per-token cross-entropy (nats): {per_token_loss:.4f}")
print(f"Perplexity: {perplexity:.4f}")

In [None]:
# ---------- SST-2 Zero-shot prompt scoring ----------
# Load the GLUE SST-2 validation split and keep at most NUM_SST_EXAMPLES rows,
# taken from the start in dataset order.
NUM_SST_EXAMPLES = 512   # keep manageable; change as desired
sst = load_dataset("glue", "sst2", split="validation")
sst = sst.select(range(min(NUM_SST_EXAMPLES, len(sst))))

# Simple prompt template and label texts
def make_prompt(sentence):
    """Wrap a sentence in the zero-shot sentiment question template.

    The result ends with "Answer:" so a candidate label can be appended directly.
    Edit the wording here to test prompt sensitivity.
    """
    template = "Sentence: {}\nQuestion: Is the sentiment of the sentence positive or negative?\nAnswer:"
    return template.format(sentence)


# Candidate answers, in the order the scoring loop reads them (index 0 = positive,
# index 1 = negative). The leading space usually helps BPE for GPT-2, since the
# label follows "Answer:" as a new word.
label_texts = [" positive", " negative"]

# Tokenize prompt-only (we will append labels and compute log-probs)
def tokenize_prompt(example):
    """Tokenize the label-free prompt for one SST-2 row.

    Returns the prompt token ids (no special tokens) and passes the gold label
    through unchanged; candidate labels are appended later, at collate time.
    """
    encoded_ids = tokenizer(make_prompt(example["sentence"]), add_special_tokens=False)["input_ids"]
    return {"prompt_ids": encoded_ids, "label": example["label"]}

sst_tok = sst.map(tokenize_prompt, remove_columns=sst.column_names)

# Collate: for each example in a batch we will create two sequences: prompt+label_positive, prompt+label_negative
def collate_prompt_label(batch, label_token_ids=None, pad_token_id=None):
    """Expand each row into one ``prompt + label`` sequence per candidate label, then pad.

    Sequences are emitted grouped by example and, within an example, in candidate
    order, so the output holds ``len(batch) * n_labels`` rows.

    Args:
        batch: list of dicts with ``"prompt_ids"`` (list of ints) and ``"label"``.
        label_token_ids: optional pre-tokenized candidate labels (a list of token-id
            lists). Defaults to tokenizing ``label_texts`` with the notebook tokenizer.
        pad_token_id: padding id. Defaults to ``tokenizer.pad_token_id``.

    Returns:
        dict with:
          - ``input_ids`` / ``attention_mask``: right-padded ``(n_seq, max_len)`` long tensors;
          - ``meta``: one ``(example_idx, label_idx, prompt_len, label_len)`` tuple per sequence;
          - ``orig_labels``: the gold label of each example, in batch order.
    """
    if label_token_ids is None:
        # Loop-invariant work: tokenize the candidate labels once per batch, not once per row.
        label_token_ids = [tokenizer(text, add_special_tokens=False)["input_ids"] for text in label_texts]
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    seqs = []
    seq_meta = []  # (example_idx, label_idx, prompt_len, label_len)
    for i, row in enumerate(batch):
        prompt_ids = list(row["prompt_ids"])
        for label_idx, label_ids in enumerate(label_token_ids):
            seqs.append(torch.tensor(prompt_ids + list(label_ids), dtype=torch.long))
            seq_meta.append((i, label_idx, len(prompt_ids), len(label_ids)))

    input_ids = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)
    # Build the mask from true lengths: the pad id (EOS for GPT-2 here) may also be a real token.
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    positions = torch.arange(input_ids.size(1), dtype=torch.long)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "meta": seq_meta,
        "orig_labels": [row["label"] for row in batch],
    }

BATCH_SIZE_SST = 16
sst_loader = DataLoader(sst_tok, batch_size=BATCH_SIZE_SST, collate_fn=collate_prompt_label)

correct = 0
total = 0

with torch.no_grad():
    for batch in sst_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        meta = batch["meta"]
        orig_labels = batch["orig_labels"]  # length = batch_size

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # (N_seq, seq_len, vocab)

        # Score each prompt+label sequence by the summed log-prob of its label tokens.
        # The label token at absolute position j is predicted by the logits at j - 1, so only
        # the label_len rows starting at prompt_len - 1 are needed. Normalizing just those rows
        # avoids materializing a full (N_seq, seq_len, vocab) log-softmax tensor, and gathering
        # all label tokens at once avoids a device sync per token.
        seq_scores = []  # one score per sequence, in `meta` order
        for seq_idx, (_example_idx, _label_idx, prompt_len, label_len) in enumerate(meta):
            # prompt_len >= 1 because the prompt template is never empty.
            start = prompt_len - 1
            pred_rows = logits[seq_idx, start:start + label_len].float()
            targets = input_ids[seq_idx, prompt_len:prompt_len + label_len]
            token_log_probs = F.log_softmax(pred_rows, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
            seq_scores.append(token_log_probs.sum().item())

        # seq_scores has two entries per example, in label_texts order:
        # index 0 is " positive", index 1 is " negative".
        for i, gold in enumerate(orig_labels):
            score_pos = seq_scores[2 * i]
            score_neg = seq_scores[2 * i + 1]
            # SST-2 labels: 1 positive, 0 negative (a tie counts as negative).
            pred_label = 1 if score_pos > score_neg else 0
            correct += int(pred_label == gold)
            total += 1

accuracy = correct / total if total > 0 else 0.0
print(f"SST-2 zero-shot accuracy on subset ({total} examples): {accuracy:.4%}")