In [None]:
import argparse
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from datasets import load_dataset, Dataset

In [None]:
# Use the GPU when one is available, but fall back to CPU instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (DataLoader shuffling, dropout) so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

MODEL_NAME = "distilbert/distilgpt2"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

## AG News

In [None]:
LABEL_TEXTS = ["world", "sports", "business", "sci/tech"]


def add_label_text(batch):
    """Attach the class name for ``batch["label"]`` as a ``label_text`` field.

    Used with ``Dataset.map`` without ``batched=True``, so ``batch`` is a single
    example and ``batch["label"]`` is the class index into ``LABEL_TEXTS``.
    The example is updated in place and returned.
    """
    label_id = batch["label"]
    batch["label_text"] = LABEL_TEXTS[label_id]
    return batch


# Download AG News and add the human-readable class name to every example.
raw = load_dataset("ag_news").map(add_label_text)

In [None]:
from torch.utils.data import DataLoader

MAX_INPUT = MAX_TARGET = 256
PREFIX = "classify: "

# Left padding keeps every prompt flush against the right edge, so in a batched
# `generate` call the new tokens start at the same column for every row.
tokenizer.padding_side = "left"
# The GPT-2 family ships without a pad token; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize(batch):
    """Tokenize a batch of AG News examples for ``Dataset.map(batched=True)``.

    Inputs are ``PREFIX + text``, truncated and padded to ``MAX_INPUT`` tokens.
    Targets are the class words in ``batch["label_text"]``, truncated and padded
    to ``MAX_TARGET`` tokens. Every padding position in the targets becomes -100
    so the loss ignores it.

    Returns the tokenizer's encoding of the inputs with an extra ``labels`` list.

    NOTE(review): this is an encoder-decoder recipe applied to a decoder-only model.
    The target tokens are never appended to ``input_ids``. Inputs and targets are
    independently left-padded to the same length (MAX_INPUT == MAX_TARGET), so the
    class-word tokens line up with the *last prompt tokens*. The causal-LM loss
    (which shifts labels by one) therefore trains the model to emit the class word
    in place of the final article tokens. It does not train the model to emit it
    after the prompt, where ``generate`` looks for it.

    A correct setup concatenates prompt + target into ``input_ids`` and masks the
    prompt part of ``labels`` with -100. It also keeps a prompt-only copy for
    generation. That requires changing the training and evaluation cells together.
    """
    inputs = [PREFIX + x for x in batch["text"]]

    model_inputs = tokenizer(
        inputs,
        truncation=True,
        padding="max_length",
        max_length=MAX_INPUT,
    )
    # `as_target_tokenizer()` is deprecated in transformers. GPT-2's tokenizer has no
    # target-specific mode, so a plain call yields the same ids.
    labels = tokenizer(
        batch["label_text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_TARGET,
    )

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in labels["input_ids"]
    ]
    return model_inputs

def _prepare_split(split, n_rows):
    """Tokenize the first ``n_rows`` rows of ``raw[split]`` and return a torch-formatted dataset.

    Selecting before mapping avoids tokenizing the whole split (and padding every
    row to 256 + 256 tokens) when only a small prefix is used. The resulting rows
    are identical. The integer ``label`` column is kept for the generation eval.
    """
    subset = raw[split].select(range(n_rows)).map(tokenize, batched=True)
    subset = subset.remove_columns(["text", "label_text"])
    subset.set_format(type="torch")
    return subset


train_dataset = _prepare_split("train", 5000)
val_dataset = _prepare_split("test", 1000)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)


In [None]:
from torch.optim import AdamW

# Fine-tune on AG News with early stopping on validation loss. The best weights
# are checkpointed to agnews-best.pth and restored once training ends.
optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 10

best_val_loss = float("inf")
patience = 1
patience_counter = 0
saved_best = False  # True once *this* run has written agnews-best.pth

for epoch in range(num_epochs):

    # -----------------------
    # TRAINING
    # -----------------------
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} training"):
        input_ids = batch["input_ids"].to(model.device)
        attn = batch["attention_mask"].to(model.device)
        labels = batch["labels"].to(model.device)

        out = model(input_ids, attention_mask=attn, labels=labels)
        loss = out.loss

        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_train_loss = total_train_loss / len(train_loader)

    # -----------------------
    # VALIDATION
    # -----------------------
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} validation"):
            input_ids = batch["input_ids"].to(model.device)
            attn = batch["attention_mask"].to(model.device)
            labels = batch["labels"].to(model.device)

            out = model(input_ids, attention_mask=attn, labels=labels)
            total_val_loss += out.loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    print(f"\nEpoch {epoch}: Train={avg_train_loss:.4f}  Val={avg_val_loss:.4f}")

    # -----------------------
    # EARLY STOPPING
    # -----------------------
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "agnews-best.pth")
        saved_best = True
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

# When training ends on a non-improving epoch (always the case after early
# stopping), the in-memory weights are not the best ones. Restore the checkpoint
# written by this run; the flag avoids loading a stale file from an earlier run.
if saved_best:
    model.load_state_dict(torch.load("agnews-best.pth", map_location=model.device))

print("Training complete.")


In [None]:
def map_prediction_to_label(text, label_texts=None):
    """Map a generated continuation to one of the AG News class names.

    Args:
        text: decoded model output. It is stripped and lower-cased before matching.
        label_texts: candidate class names (lower-case). Defaults to ``LABEL_TEXTS``.

    Returns:
        The first candidate contained in ``text``. Failing that, the first candidate
        that either starts with the first word of ``text`` or whose first three
        characters begin that word. ``None`` if nothing matches, including for
        empty or whitespace-only output.
    """
    if label_texts is None:
        label_texts = LABEL_TEXTS
    t = text.strip().lower()
    for lbl in label_texts:
        if lbl in t:
            return lbl
    if not t:
        # With an empty first word, `lbl.startswith("")` is true for every label
        # and the first class would be returned for a blank generation.
        return None
    first = t.split()[0]
    for lbl in label_texts:
        if lbl.startswith(first) or first.startswith(lbl[:3]):
            return lbl
    return None

# Score the model by generating a continuation for each validation prompt and
# mapping it to a class name.
model.eval()
correct = 0
total = 0

eval_loader = DataLoader(val_dataset, batch_size=4)

with torch.no_grad():
    for batch in tqdm(eval_loader, desc="Evaluating generation"):
        input_ids = batch["input_ids"].to(model.device)
        attn = batch["attention_mask"].to(model.device)

        gen = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=16,
            num_beams=4,
            pad_token_id=tokenizer.pad_token_id,
        )

        # A decoder-only model returns prompt + continuation. The prompts are
        # left-padded to a common width, so every column from input_ids.shape[1]
        # onward is newly generated. Decoding the whole sequence would let label
        # words inside the article itself count as the prediction.
        new_tokens = gen[:, input_ids.shape[1]:]
        outputs = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        preds = [map_prediction_to_label(x) for x in outputs]
        golds = [LABEL_TEXTS[l] for l in batch["label"].tolist()]

        correct += sum(p == g for p, g in zip(preds, golds))
        total += len(golds)

accuracy = 100 * correct / total if total else 0.0
print(f"Generation Accuracy: {accuracy:.2f}%  ({correct}/{total})")


## TinyStories

In [None]:
def tokenize_fn(batch):
    """Tokenize a batch of stories, truncating each to 256 tokens (no padding here)."""
    return tokenizer(batch["text"], max_length=256, truncation=True)


ds = load_dataset("roneneldan/TinyStories", split="train[:5000]")
tokenized = ds.map(tokenize_fn, batched=True)

# Hold out 10% of the 5,000 stories for validation, with a fixed seed for a stable split.
train_val = tokenized.train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = train_val["train"], train_val["test"]

In [None]:
from torch.utils.data import DataLoader

# Ensure a pad id exists for pad_sequence below; fall back to EOS when the tokenizer has none.
if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

def collate_fn(batch, pad_token_id=None, max_length=256):
    """Right-pad a list of tokenized examples into a causal-LM batch.

    Args:
        batch: list of dicts, each with an ``"input_ids"`` sequence of token ids.
        pad_token_id: id used for padding. Defaults to ``tokenizer.pad_token_id``.
        max_length: each sequence is truncated to at most this many tokens.

    Returns:
        Dict with three long tensors of shape (batch, longest_seq):
        - ``input_ids``: the padded token ids.
        - ``attention_mask``: 1 on real tokens, 0 on padding.
        - ``labels``: ``input_ids`` with padding set to -100, so the LM loss ignores it.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    seqs = [
        torch.tensor(example["input_ids"][:max_length], dtype=torch.long)
        for example in batch
    ]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        seqs,
        batch_first=True,
        padding_value=pad_token_id,
    )
    # Build the mask from the true lengths rather than by comparing against
    # pad_token_id: the pad id may also occur inside a sequence (e.g. when it is EOS).
    lengths = torch.tensor([s.numel() for s in seqs])
    attention_mask = (
        torch.arange(input_ids.size(1))[None, :] < lengths[:, None]
    ).long()
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Batches of 2 stories, padded on the fly by collate_fn; only training batches are shuffled.
loader_kwargs = {"batch_size": 2, "collate_fn": collate_fn}
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
import torch
from torch.optim import AdamW
from tqdm import tqdm

# Continue fine-tuning on TinyStories (plain next-token prediction) with early
# stopping. The best weights go to best_model.pth and are restored at the end.
optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 5

best_val_loss = float("inf")
patience = 1
patience_counter = 0
saved_best = False  # True once *this* run has written best_model.pth


def _lm_batch(batch):
    """Move a collated batch to the model device; return (input_ids, attention_mask, labels).

    Uses ``batch["attention_mask"]`` when the collate function provides it.
    Otherwise, positions equal to ``tokenizer.pad_token_id`` are treated as padding.
    Padding gets label -100 so the LM loss does not train on it.
    """
    input_ids = batch["input_ids"].to(model.device)
    attn = batch.get("attention_mask")
    if attn is None:
        attn = (input_ids != tokenizer.pad_token_id).long()
    else:
        attn = attn.to(model.device)
    return input_ids, attn, input_ids.masked_fill(attn == 0, -100)


for epoch in range(num_epochs):
    # -----------------------
    # TRAINING
    # -----------------------
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} training"):
        input_ids, attn, labels = _lm_batch(batch)

        outputs = model(input_ids, attention_mask=attn, labels=labels)
        loss = outputs.loss
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_train_loss = total_train_loss / len(train_loader)

    # -----------------------
    # VALIDATION
    # -----------------------
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} validation"):
            input_ids, attn, labels = _lm_batch(batch)

            outputs = model(input_ids, attention_mask=attn, labels=labels)
            total_val_loss += outputs.loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    print(f"\nEpoch {epoch}:")
    print(f"  Train loss: {avg_train_loss:.4f}")
    print(f"  Val loss: {avg_val_loss:.4f}")

    # -----------------------
    # EARLY STOPPING
    # -----------------------
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")
        saved_best = True
    else:
        patience_counter += 1

        if patience_counter >= patience:
            break

# After early stopping, the in-memory weights come from a non-improving epoch;
# restore the best checkpoint from this run so later cells use it.
if saved_best:
    model.load_state_dict(torch.load("best_model.pth", map_location=model.device))

print("Training complete.")


In [None]:
import torch
import numpy as np
from tqdm import tqdm

def compute_ppl(model, loader, device=None, pad_token_id=None):
    """Token-level perplexity of a causal LM over ``loader``.

    Each batch is a dict with ``input_ids``. Padding is excluded from the average.
    The batch's ``attention_mask`` is used when present; otherwise positions equal
    to ``pad_token_id`` are excluded (when given). With neither, every position counts.

    Args:
        model: causal LM whose output exposes ``.logits`` as (batch, seq, vocab).
        loader: iterable of dict batches.
        device: where to run. Defaults to ``model.device``.
        pad_token_id: padding id, used only when a batch has no attention mask.

    Returns:
        exp(mean next-token negative log-likelihood) over all non-padding targets.

    Raises:
        ValueError: if the loader yields no non-padding target tokens.
    """
    if device is None:
        device = model.device
    model.eval()
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)

    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in tqdm(loader):
            ids = batch["input_ids"].to(device)

            mask = batch.get("attention_mask")
            if mask is not None:
                mask = mask.to(device)
            elif pad_token_id is not None:
                mask = (ids != pad_token_id).long()

            out = model(ids, attention_mask=mask, use_cache=False)
            logits = out.logits

            # Position t predicts token t+1; padded targets are marked -100 and ignored.
            targets = ids if mask is None else ids.masked_fill(mask == 0, -100)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = targets[:, 1:].contiguous()

            total_loss += loss_fn(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            ).item()
            total_tokens += int((shift_labels != -100).sum())

    if total_tokens == 0:
        raise ValueError("compute_ppl: loader yielded no non-padding target tokens")
    mean_nll = total_loss / total_tokens
    return float(np.exp(mean_nll))

ppl = compute_ppl(model, val_loader, pad_token_id=tokenizer.pad_token_id)
print("PPL:", ppl)

## TweetEval

In [None]:
from datasets import load_dataset

# Load TweetEval sentiment
raw = load_dataset("tweet_eval", "sentiment")

# Map numeric labels to strings
label_map = {0: "negative", 1: "neutral", 2: "positive"}


def build_text(example):
    """Format one labelled tweet as ``Tweet: <text>`` followed by ``Sentiment: <word>``."""
    sentiment_word = label_map[example["label"]]
    return f"Tweet: {example['text']}\nSentiment: {sentiment_word}"


train_raw = raw["train"].select(range(5000))
val_raw = raw["validation"]

train_texts = list(map(build_text, train_raw))
val_texts = list(map(build_text, val_raw))

In [None]:
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_texts(texts):
    """Tokenize a list of strings: truncate to 128 tokens, left-pad to the longest.

    Deliberately not named ``tokenize``. Reusing that name would silently rebind the
    AG News ``tokenize(batch)`` mapper, so re-running that cell would then call this
    function instead.
    """
    return tokenizer(texts, truncation=True, padding=True, max_length=128)


train_tokens = tokenize_texts(train_texts)
val_tokens = tokenize_texts(val_texts)


In [None]:
import torch

class TD(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized sequences.

    ``toks`` must support ``toks["input_ids"]`` and ``toks["attention_mask"]``,
    for example the encoding returned by the tokenizer. Each item is returned as a
    dict of ``torch.long`` tensors built on access.
    """

    def __init__(self, toks):
        self.input_ids, self.attn = toks["input_ids"], toks["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        columns = (("input_ids", self.input_ids), ("attention_mask", self.attn))
        return {name: torch.tensor(values[idx], dtype=torch.long) for name, values in columns}

# Wrap the tokenized splits so a DataLoader can index them example by example.
train_ds, val_ds = TD(train_tokens), TD(val_tokens)

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training split so validation batches come out in a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=4)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=4)

In [None]:
import torch
from torch.optim import AdamW
from tqdm import tqdm

# Fine-tune on "Tweet: ...\nSentiment: <word>" strings with early stopping.
# The best weights go to tweeteval-best.pth and are restored at the end.
optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 10

best_val_loss = float("inf")
patience = 1
patience_counter = 0
saved_best = False  # True once *this* run has written tweeteval-best.pth

for epoch in range(num_epochs):
    # -----------------------
    # TRAINING
    # -----------------------
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} training"):
        input_ids = batch["input_ids"].to(model.device)
        attn      = batch["attention_mask"].to(model.device)
        # Sequences are left-padded with the pad id (EOS here). Mask padding out of
        # the targets. Otherwise the loss rewards emitting the pad id across the
        # padding and predicting the first real token from a padding slot.
        labels    = input_ids.masked_fill(attn == 0, -100)

        out = model(input_ids, attention_mask=attn, labels=labels)
        loss = out.loss

        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_train_loss = total_train_loss / len(train_loader)

    # -----------------------
    # VALIDATION
    # -----------------------
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} validation"):
            input_ids = batch["input_ids"].to(model.device)
            attn      = batch["attention_mask"].to(model.device)
            labels    = input_ids.masked_fill(attn == 0, -100)

            out = model(input_ids, attention_mask=attn, labels=labels)
            total_val_loss += out.loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    print(f"\nEpoch {epoch}: Train={avg_train_loss:.4f}  Val={avg_val_loss:.4f}")

    # -----------------------
    # EARLY STOPPING
    # -----------------------
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "tweeteval-best.pth")
        saved_best = True
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break

# After early stopping, the in-memory weights come from a non-improving epoch;
# restore the best checkpoint from this run so the evaluation uses it.
if saved_best:
    model.load_state_dict(torch.load("tweeteval-best.pth", map_location=model.device))

print("Training complete.")


In [None]:
from sklearn.metrics import accuracy_score

# Sentiment word -> TweetEval class id. Kept separate from `label_map` (id -> word),
# which `build_text` still needs; rebinding `label_map` here would break that function.
SENTIMENT_TO_ID = {"negative": 0, "neutral": 1, "positive": 2}


def build_infer_prompt(text):
    """Inference prompt: the training format, ending right before the sentiment word."""
    return f"Tweet: {text}\nSentiment:"


def extract_sentiment_from_text(txt):
    """Return the class id of the earliest sentiment word in ``txt``, or -1 if none.

    Matching is case-insensitive. Taking the earliest occurrence (rather than
    checking words in a fixed order) matters when the model keeps generating past
    the label. For example, "negative ... positive" is read as negative.
    """
    lowered = txt.lower()
    best_id, best_pos = -1, len(lowered) + 1
    for word, label_id in SENTIMENT_TO_ID.items():
        pos = lowered.find(word)
        if pos != -1 and pos < best_pos:
            best_id, best_pos = label_id, pos
    return best_id


# Generate a sentiment word for every test tweet and score it against the gold label.
model.eval()

preds = []
labels = []
n_unparsed = 0

with torch.no_grad():
    for ex in tqdm(raw["test"]):
        prompt = build_infer_prompt(ex["text"])
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        gen = model.generate(
            **inputs,
            max_new_tokens=5,
            pad_token_id=tokenizer.eos_token_id
        )

        # decode only the generated part
        gen_text = tokenizer.decode(gen[0][inputs["input_ids"].size(1):])
        pred_label = extract_sentiment_from_text(gen_text)

        # Keep unparseable generations (-1) so they count as errors. Dropping them
        # would report accuracy only on the tweets the model happened to answer.
        if pred_label == -1:
            n_unparsed += 1

        preds.append(pred_label)
        labels.append(ex["label"])

acc = accuracy_score(labels, preds) if labels else float("nan")
print("Generative accuracy:", acc)
print(f"Unparseable generations: {n_unparsed}/{len(labels)}")
