## Task 4:
Task remarks: GPT. Hand in by 31.08.

* If something is underspecified, make the decision yourself
* Well-documented code
* Submission format
    * Notebook (incl. pdf) or GitHub readme (submit pdf with link to repo) as a technical report of what we did
        * Nice narrative and way to navigate code, not a scientific paper
        * Include plots (loss, perplexity scores, hyperparameters, etc.)
        * Optionally include pseudocode
        * Qualitative analysis nice to have, e.g., add and evaluate generated text in the report
        * Can add appendix for additional plots
* Hand in every milestone, starting from UNIX comments
* Removed in-between milestone of causal-self attention
* Everything together in one file
* Compare the models from each milestone, report perplexity for all
    * Old-school n-gram
    * Best neural n-gram
    * GPT


**GPT itself**
* Hyperparameter tuning: not all of them are needed; choose what is most interesting and explain why
    * Number of merges in BPE (no complete grid search; take the three merge counts with the best n-gram perplexity and test those for GPT)
    * Regularisation
    * How small can we make neural embedding
    * Do not change optimiser
* General remarks
    * Transformer blocks from scratch would be beyond 1.0, not required
    * Implement causal self-attention yourself, do not use ready-made PyTorch version
    * For computing perplexity: Implementing teacher forcing annealing is necessary
for good generation performance, but we don’t have to do it for our assignment
* Reminders
    * Skip weight initialisation and optimiser configuration
        * Can use standard PyTorch initialisation → just get transformer
parameters and add them when initialising the optimiser
    * Remember to change device selection, currently “cuda”, you might want “mps” or
“cpu”
    * Configs: make n_embd smaller, don’t change betas and weight decay (unless
you want to), can change batch size, chunk size, n_head, n_layer
    * Specify temperature and top-k parameters for generate function
    * Activation function used in MLP: not ReLU as in slides but GELU (available in PyTorch as nn.GELU)
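
The remark about implementing causal self-attention yourself comes down to one masking step: before the softmax, every score for a future position is replaced by negative infinity so that it receives zero weight. The following is a minimal, dependency-free sketch of that step only. The function name `causal_attention_weights` and the toy scores are illustrative assumptions; the real model would do the same thing on batched tensors.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over scores[i][j] with future positions (j > i) masked out."""
    T = len(scores)
    weights = []
    for i in range(T):
        # Positions j > i become -inf, so exp(-inf) = 0 after the softmax.
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(T)]
        m = max(masked)  # subtract the row maximum for numerical stability
        exps = [math.exp(s - m) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Toy 3x3 score matrix (query position i, key position j); values are arbitrary.
scores = [[0.0, 1.0, 2.0],
          [1.0, 1.0, 5.0],
          [0.5, 0.5, 0.5]]
w = causal_attention_weights(scores)
for row in w:
    print([round(x, 3) for x in row])
```

Each row sums to one and is zero above the diagonal, so position i can only attend to positions 0..i.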

In [None]:
"""
GPT (Transformer) training pipeline for Shakespeare with BPE tokenizer
====================================================================

Goals
-----
- PyTorch implementation of a small GPT (nanoGPT-style) with clean structure.
- Re-use your existing BPE merges and token conventions (</w>, <bos>, <eos>).
- Detailed logging, file outputs, checkpoints, CSV logs, and PNG loss plots.
- Validation + test perplexity.
- Sample text generation at checkpoints.

Directory layout (inputs & outputs)
----------------------------------
Inputs (must exist):
- Corpus/
    Shakespeare_clean_train.txt
    Shakespeare_clean_valid.txt
    Shakespeare_clean_test.txt
- Generated_tokens/
    (one of) bpe_merges with k = {k}.txt, standard_bpe_merges_k{k}.txt, ...

Outputs (this script will create):
- runs/gpt_{timestamp}_k{k}/
    config.json
    tokenizer.json
    train_encoded.pt, valid_encoded.pt, test_encoded.pt
    history.csv
    history.png
    final_results.json
    ckpt_step{...}.pt (model + optimizer + scaler + config)
    samples/
        step{...}_eval.txt

Usage
-----
python gpt_shakespeare_trainer.py --k 1600 --batch_size 64 --block_size 128 --n_layer 4 --n_head 4 --n_embd 256 --max_steps 2000

Notes
-----
- Designed for small models/datasets. Mixed precision is optional (amp).
- If no CUDA, training runs on CPU (slower but fine for tiny configs).
"""

import os
import re
import json
import math
import time
import random
import argparse
import sys
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict, Iterable, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib
matplotlib.use("Agg")  # headless
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams["font.family"] = "DejaVu Sans"  # Matplotlib’s default bundled font

# ================================ Constants ================================
CORPUS_DIR = "Corpus"
GENERATED_DIR = "Generated_tokens"
WORD_END = "</w>"
EOS = "<eos>"
BOS = "<bos>"
_wsre = re.compile(r"\s+")
#random.seed(42)

@dataclass
class TrainConfig:
    seed: int = 42
    k: int = 1000
    batch_size: int = 32
    block_size: int = 128
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 128
    dropout: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.0
    max_steps: int = 1000
    eval_interval: int = 200
    eval_batches: int = 20
    ckpt_interval: int = 500
    warmup_steps: int = 100
    grad_clip: float = 1.0
    amp: bool = True
    no_amp: bool = False 


def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser()
    parser.add_argument("--k", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--n_embd", type=int, default=128)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)
    parser.add_argument("--max_steps", type=int, default=5000)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--eval_interval", type=int, default=500)
    parser.add_argument("--eval_batches", type=int, default=20)
    parser.add_argument("--ckpt_interval", type=int, default=1000)
    parser.add_argument("--no_amp", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    if "ipykernel_launcher" in sys.argv[0]:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    cfg = TrainConfig(**vars(args))
    cfg.amp = not cfg.no_amp  # --no_amp disables mixed precision
    return cfg

def get_args():
    parser = argparse.ArgumentParser()

    # Core training params
    parser.add_argument("--k", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--n_embd", type=int, default=128)
    parser.add_argument("--dropout", type=float, default=0.1)

    # Optimization
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)
    parser.add_argument("--max_steps", type=int, default=5000)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--grad_clip", type=float, default=1.0)

    # Evaluation/checkpoints
    parser.add_argument("--eval_interval", type=int, default=500)
    parser.add_argument("--eval_batches", type=int, default=20)
    parser.add_argument("--ckpt_interval", type=int, default=1000)

    # Misc
    parser.add_argument("--no_amp", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    # Use parse_known_args to ignore --f=...json
    if "ipykernel_launcher" in sys.argv[0]:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    return args

# ============================= BPE Tokenizer ===============================

def find_merges_file(k: int, verbose: bool = True) -> str:
    candidates = [
        os.path.join(GENERATED_DIR, f"bpe_merges with k = {k}.txt"),
        os.path.join(GENERATED_DIR, f"standard_bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"aggressive_clean_bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"bpe_merges_k{k}_webtext_clean.txt"),
    ]
    for path in candidates:
        if os.path.exists(path):
            if verbose:
                print(f"[Found] Using merges file: {path}")
            return path
    raise FileNotFoundError(f"No merges file found for k={k}. Tried: {candidates}")

def load_merges(merges_path: str) -> List[Tuple[str, str]]:
    merges = []
    with open(merges_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                merges.append((parts[0], parts[1]))
    return merges

def words_from_text(text: str, lowercase: bool = True) -> List[str]:
    if lowercase:
        text = text.lower()
    return [w for w in _wsre.split(text.strip()) if w]

def apply_merges_to_word(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    symbols = tuple(list(word) + [WORD_END])
    for a, b in merges:
        out = []
        i, L = 0, len(symbols)
        while i < L:
            if i < L-1 and symbols[i] == a and symbols[i+1] == b:
                out.append(a + b); i += 2
            else:
                out.append(symbols[i]); i += 1
        symbols = tuple(out)
    return list(symbols)

def tokenize_lines_with_merges(text: str, merges: List[Tuple[str, str]]) -> List[List[str]]:
    token_lines: List[List[str]] = []
    for line in text.strip().splitlines():
        words = words_from_text(line)
        if not words:
            continue
        toks: List[str] = []
        for w in words:
            toks.extend(apply_merges_to_word(w, merges))
        toks.append(EOS)
        token_lines.append(toks)
    return token_lines

# Convert tokens to ids, build vocab
class BPETokenizer:
    def __init__(self, merges: List[Tuple[str, str]], extra_tokens: Optional[List[str]] = None):
        self.merges = merges
        self.extra_tokens = extra_tokens or []
        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: List[str] = []

    def build_vocab_from_texts(self, texts: Dict[str, str]):
        vocab = set()
        for name, txt in texts.items():
            for line in tokenize_lines_with_merges(txt, self.merges):
                vocab.update(line)
        vocab.update(self.extra_tokens)
        # Deterministic order
        self.id_to_token = sorted(vocab)
        self.token_to_id = {t: i for i, t in enumerate(self.id_to_token)}

    def encode_words(self, words: Iterable[str]) -> List[str]:
        toks: List[str] = []
        for w in (w.lower() for w in words):
            toks.extend(apply_merges_to_word(w, self.merges))
        return toks

    def encode_lines(self, text: str) -> List[List[int]]:
        lines_tok = tokenize_lines_with_merges(text, self.merges)
        ids_lines: List[List[int]] = []
        for line in lines_tok:
            ids_lines.append([self.token_to_id[t] for t in line if t in self.token_to_id])
        return ids_lines

    def decode_tokens(self, token_stream: List[str]) -> List[str]:
        words: List[str] = []
        buf: List[str] = []
        for t in token_stream:
            if t == EOS:
                break
            buf.append(t)
            if t.endswith(WORD_END):
                chars: List[str] = []
                for sub in buf:
                    if sub.endswith(WORD_END):
                        chars.extend(list(sub[:-len(WORD_END)]))
                    else:
                        chars.extend(list(sub))
                words.append("".join(chars))
                buf = []
        if buf:
            chars = []
            for sub in buf:
                if sub.endswith(WORD_END):
                    chars.extend(list(sub[:-len(WORD_END)]))
                else:
                    chars.extend(list(sub))
            if chars:
                words.append("".join(chars))
        return words

    def save(self, path: str):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "id_to_token": self.id_to_token,
                "token_to_id": self.token_to_id,
                "extra_tokens": self.extra_tokens,
                "merges": self.merges,  # needed so a reloaded tokenizer can still encode prompts
            }, f, ensure_ascii=False, indent=2)

    @staticmethod
    def load(path: str) -> "BPETokenizer":
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        # JSON stores each merge pair as a list; convert back to tuples
        merges = [tuple(m) for m in obj.get("merges", [])]
        tok = BPETokenizer(merges=merges, extra_tokens=obj.get("extra_tokens", []))
        tok.id_to_token = obj["id_to_token"]
        tok.token_to_id = {k: int(v) for k, v in obj["token_to_id"].items()}
        return tok

# ============================ Data preparation ==============================

def read_text(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

@dataclass
class EncodedSplits:
    train: torch.Tensor
    valid: torch.Tensor
    test: torch.Tensor


def build_or_load_encoded(run_dir: str, k: int) -> Tuple[BPETokenizer, EncodedSplits]:
    enc_train_path = os.path.join(run_dir, "train_encoded.pt")
    enc_valid_path = os.path.join(run_dir, "valid_encoded.pt")
    enc_test_path  = os.path.join(run_dir, "test_encoded.pt")
    tok_path       = os.path.join(run_dir, "tokenizer.json")

    if all(os.path.exists(p) for p in [enc_train_path, enc_valid_path, enc_test_path, tok_path]):
        print("[Load] Using cached encoded splits + tokenizer")
        tokenizer = BPETokenizer.load(tok_path)
        train_ids = torch.load(enc_train_path)
        valid_ids = torch.load(enc_valid_path)
        test_ids  = torch.load(enc_test_path)
        return tokenizer, EncodedSplits(train_ids, valid_ids, test_ids)

    # Build from raw
    merges_path = find_merges_file(k, verbose=True)
    merges = load_merges(merges_path)

    train_txt = read_text(os.path.join(CORPUS_DIR, "Shakespeare_clean_train.txt"))
    valid_txt = read_text(os.path.join(CORPUS_DIR, "Shakespeare_clean_valid.txt"))
    test_txt  = read_text(os.path.join(CORPUS_DIR, "Shakespeare_clean_test.txt"))

    tokenizer = BPETokenizer(merges=merges, extra_tokens=[BOS, EOS])
    tokenizer.build_vocab_from_texts({"train": train_txt, "valid": valid_txt, "test": test_txt})

    def flatten(lines: List[List[int]]) -> List[int]:
        flat = []
        for ln in lines: flat.extend(ln)
        return flat

    train_ids = torch.tensor(flatten(tokenizer.encode_lines(train_txt)), dtype=torch.long)
    valid_ids = torch.tensor(flatten(tokenizer.encode_lines(valid_txt)), dtype=torch.long)
    test_ids  = torch.tensor(flatten(tokenizer.encode_lines(test_txt)),  dtype=torch.long)

    os.makedirs(run_dir, exist_ok=True)
    tokenizer.save(tok_path)
    torch.save(train_ids, enc_train_path)
    torch.save(valid_ids, enc_valid_path)
    torch.save(test_ids,  enc_test_path)
    print(f"[Save] Encoded splits to {run_dir}")
    print(f"[Info] Vocab size = {len(tokenizer.id_to_token)}")
    return tokenizer, EncodedSplits(train_ids, valid_ids, test_ids)

# ================================ Dataset ==================================

class GPTDataset(Dataset):
    def __init__(self, ids: torch.Tensor, block_size: int):
        self.ids = ids
        self.block_size = block_size
        # we will sample random start positions in __getitem__

    def __len__(self) -> int:
        # approximate number of sequences of length block_size we can draw
        return max(1, len(self.ids) - self.block_size)

    def __getitem__(self, idx):
        # ignore idx and sample randomly to add stochasticity
        i = random.randint(0, len(self.ids) - self.block_size - 1)
        x = self.ids[i : i + self.block_size]
        y = self.ids[i + 1 : i + self.block_size + 1]
        return x, y

# ============================== GPT Model ==================================

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(n_embd, n_embd)
        # causal mask
        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(vocab_size, n_embd),
            'wpe': nn.Embedding(block_size, n_embd),
            'drop': nn.Dropout(dropout),
            'h': nn.ModuleList([Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]),
            'ln_f': nn.LayerNorm(n_embd),
        })
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.block_size
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer['wte'](idx)
        pos_emb = self.transformer['wpe'](pos)[None, :, :]
        x = self.transformer['drop'](tok_emb + pos_emb)
        for block in self.transformer['h']:
            x = block(x)
        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)
        return logits

# ============================ Training utilities ============================

def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

# ============================== Evaluation ==================================

def evaluate(model: GPT, loader: DataLoader, device: torch.device, max_batches: int) -> float:
    model.eval()
    losses = []
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))
            losses.append(loss.item())
    model.train()
    return float(sum(losses) / max(1, len(losses)))

# ============================== Generation ==================================

@torch.no_grad()  # sampling needs no autograd graph
def generate(model: GPT, start_tokens: List[int], max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None) -> List[int]:
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    idx = torch.tensor(start_tokens, dtype=torch.long, device=device)[None, :]
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.block_size:]  # crop the context to the last block_size tokens
        logits = model(idx_cond)
        logits = logits[:, -1, :] / max(1e-8, temperature)
        if top_k is not None:
            # keep only the top_k largest logits; everything else gets zero probability
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_id), dim=1)
    if was_training:
        model.train()  # restore training mode so dropout stays active in the training loop
    return idx[0].tolist()

# ============================== Plot / CSV ==================================

def save_plot_and_csv(run_dir, history):
    """Save training/validation loss plot and history CSV."""
    # make sure run_dir and subdirs exist
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "samples"), exist_ok=True)

    csv_path = os.path.join(run_dir, "history.csv")
    png_path = os.path.join(run_dir, "history.png")

    # Save CSV
    df = pd.DataFrame(history)
    df.to_csv(csv_path, index=False)

    # Try plotting
    try:
        # Force matplotlib to use a safe font
        matplotlib.rcParams["font.family"] = "DejaVu Sans"

        steps = [h["step"] for h in history]
        train_loss = [h["train_loss"] for h in history]
        val_loss = [h["val_loss"] for h in history]

        plt.figure(figsize=(8, 5))
        plt.plot(steps, train_loss, label="train_loss")
        plt.plot(steps, val_loss, label="val_loss")
        plt.xlabel("step")
        plt.ylabel("loss")
        plt.legend()
        plt.tight_layout()
        plt.savefig(png_path)
        plt.close()
        print(f"[Plot] Saved to {png_path}")

    except Exception as e:
        print(f"[Plot Warning] Could not generate plot: {e}")


# ============================== Main training ===============================

def train_and_eval_with_logging(cfg: TrainConfig):
    # Seed both RNGs: GPTDataset draws its windows with the `random` module
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join("runs", f"gpt_{ts}_k{cfg.k}")
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "samples"), exist_ok=True)

    # Save config
    with open(os.path.join(run_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(asdict(cfg), f, indent=2)

    # Encode or load
    tokenizer, splits = build_or_load_encoded(run_dir, cfg.k)
    vocab_size = len(tokenizer.id_to_token)

    # Datasets/loaders
    train_ds = GPTDataset(splits.train, cfg.block_size)
    valid_ds = GPTDataset(splits.valid, cfg.block_size)
    test_ds  = GPTDataset(splits.test,  cfg.block_size)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=True)
    test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size, shuffle=False, drop_last=True)

    # Model
    model = GPT(vocab_size=vocab_size, block_size=cfg.block_size, n_layer=cfg.n_layer, n_head=cfg.n_head, n_embd=cfg.n_embd, dropout=cfg.dropout)
    model.to(device)

    print(f"[Info] Device: {device} | Parameters: {count_parameters(model)/1e6:.2f}M | Vocab={vocab_size} | block={cfg.block_size}")

    # Optimizer / Scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    lr_sched = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: min(1.0, step / max(1, cfg.warmup_steps))
    )

    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp and device.type == "cuda")
    loss_fn = nn.CrossEntropyLoss()

    history = []
    running_loss = 0.0

    model.train()
    step = 0
    while step < cfg.max_steps:
        for xb, yb in train_loader:
            step += 1
            xb = xb.to(device)
            yb = yb.to(device)

            with torch.amp.autocast("cuda", enabled=scaler.is_enabled()):
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))

            scaler.scale(loss).backward()
            if cfg.grad_clip:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            lr_sched.step()

            running_loss = 0.9 * running_loss + 0.1 * loss.item() if step > 1 else loss.item()

            if step % 50 == 0:
                print(f"[Step {step:5d}] train_loss={loss.item():.4f} (ema {running_loss:.4f})")


            # Eval
            if step % cfg.eval_interval == 0 or step == cfg.max_steps:
                val_loss = evaluate(model, valid_loader, device, cfg.eval_batches)
                val_ppl = math.exp(val_loss)
                print(f"[Eval  {step:5d}] val_loss={val_loss:.4f} | val_ppl={val_ppl:.2f}")
                history.append({"step": step, "train_loss": running_loss, "val_loss": val_loss, "val_ppl": val_ppl})
                save_plot_and_csv(run_dir, history)

                # save extrinsic evaluation with fixed prompts
                prompts = ["To be or not to be", "Once upon a midnight dreary"]
                eval_path = os.path.join(run_dir, f"samples/step{step}_eval.txt")
                with open(eval_path, "w", encoding="utf-8") as f:
                    for prompt in prompts:
                        start_ids = tokenizer.encode_words(prompt.split())
                        start_ids = [tokenizer.token_to_id.get(tok, 0) for tok in start_ids]
                        f.write(f"\nPrompt: {prompt}\n")

                        # top_k=1 keeps only the arg-max token, i.e. greedy decoding
                        greedy_ids = generate(model, start_tokens=start_ids, max_new_tokens=40, temperature=1.0, top_k=1)
                        greedy_text = " ".join([tokenizer.id_to_token[i] for i in greedy_ids])
                        f.write("Greedy: " + greedy_text + "\n")

                        topk_ids = generate(model, start_tokens=start_ids, max_new_tokens=40, temperature=0.8, top_k=50)
                        topk_text = " ".join([tokenizer.id_to_token[i] for i in topk_ids])
                        f.write("Top-k: " + topk_text + "\n")
                print(f"[Sample Eval] saved → {eval_path}")

            # Checkpoint + sample
            if step % cfg.ckpt_interval == 0 or step == cfg.max_steps:
                ckpt_path = os.path.join(run_dir, f"ckpt_step{step}.pt")
                torch.save({
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scaler_state": scaler.state_dict(),
                    "config": asdict(cfg),
                    "vocab_size": vocab_size,
                    "step": step,
                }, ckpt_path)
                print(f"[Save] checkpoint → {ckpt_path}")

            if step >= cfg.max_steps:
                break

    # Final test evaluation
    val_loss = evaluate(model, valid_loader, device, cfg.eval_batches)
    val_ppl = math.exp(val_loss)
    test_loss = evaluate(model, test_loader, device, cfg.eval_batches)
    test_ppl = math.exp(test_loss)
    
    print(f"[Final Val] loss={val_loss:.4f} | ppl={val_ppl:.2f}")
    print(f"[Final Test] loss={test_loss:.4f} | ppl={test_ppl:.2f}")

    # Save final metrics
    results = {
        "k": cfg.k,
        "n_embd": cfg.n_embd,
        "dropout": cfg.dropout,
        "val_loss": val_loss,
        "val_ppl": val_ppl,
        "test_loss": test_loss,
        "test_ppl": test_ppl,
    }
    with open(os.path.join(run_dir, "final_results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    return results


if __name__ == "__main__":
    cfg = parse_args()   # returns a TrainConfig object
    results = train_and_eval_with_logging(cfg)
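    # The final metrics have no per-step history, so write them as a one-row CSV instead of plotting
    final_dir = os.path.join("runs", f"gpt_final_k{cfg.k}")
    os.makedirs(final_dir, exist_ok=True)
    pd.DataFrame([results]).to_csv(os.path.join(final_dir, "history.csv"), index=False)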



[Found] Using merges file: Generated_tokens\bpe_merges with k = 1000.txt
[Save] Encoded splits to runs\gpt_20250831_200113_k1000
[Info] Vocab size = 1028
[Info] Device: cpu | Parameters: 1.06M | Vocab=1028 | block=64
[Step    50] train_loss=6.2732 (ema 6.4469)
[Step   100] train_loss=5.5120 (ema 5.6544)
[Step   150] train_loss=5.0238 (ema 5.0975)
[Step   200] train_loss=4.6748 (ema 4.8054)
[Step   250] train_loss=4.6635 (ema 4.6314)
[Step   300] train_loss=4.5805 (ema 4.5115)
[Step   350] train_loss=4.3110 (ema 4.4233)
[Step   400] train_loss=4.2795 (ema 4.3166)
[Step   450] train_loss=4.2358 (ema 4.2783)
[Step   500] train_loss=4.2132 (ema 4.2182)
[Eval    500] val_loss=4.1863 | val_ppl=65.78
[Sample Eval] saved → runs\gpt_20250831_200113_k1000\samples/step500_eval.txt
[Step   550] train_loss=3.9659 (ema 4.1180)
[Step   600] train_loss=4.0263 (ema 4.0882)
[Step   650] train_loss=4.0694 (ema 4.0503)
[Step   700] train_loss=3.9535 (ema 4.0074)
[Step   750] train_loss=3.9800 (ema 3.9815)

In [None]:
# ---------------- Training & Evaluation ----------------
def train_and_eval_for_gpt(k, n_embd, dropout, run_dir, num_steps=2000, batch_size=32, lr=3e-4, eval_batches=50):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Load data & tokenizer ---
    tokenizer, splits = build_or_load_encoded(run_dir, k)
    vocab_size = len(tokenizer.token_to_id)

    train_ds = GPTDataset(splits.train, block_size=64)
    val_ds = GPTDataset(splits.valid, block_size=64)
    test_ds = GPTDataset(splits.test, block_size=64)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=True)

    # --- Build model ---
    model = GPT(
        vocab_size=vocab_size,
        block_size=64,
        n_layer=4,
        n_head=4,
        n_embd=n_embd,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # --- Training loop: one optimiser step per mini-batch ---
    print(f"Training GPT (k={k}, n_embd={n_embd}, dropout={dropout})...")
    model.train()
    batches = iter(train_loader)
    for step in range(1, num_steps + 1):
        try:
            xb, yb = next(batches)
        except StopIteration:  # restart the loader once it is exhausted
            batches = iter(train_loader)
            xb, yb = next(batches)
        xb, yb = xb.to(device), yb.to(device)

        logits = model(xb)
        loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 200 == 0:
            val_loss = evaluate(model, val_loader, device, max_batches=eval_batches)
            val_ppl = math.exp(val_loss)
            print(f"Step {step}: train_loss={loss.item():.4f}, val_loss={val_loss:.4f}, val_ppl={val_ppl:.2f}")

    # --- Final evaluation ---
    val_loss = evaluate(model, val_loader, device, max_batches=eval_batches)
    test_loss = evaluate(model, test_loader, device, max_batches=eval_batches)
    val_ppl = math.exp(val_loss)
    test_ppl = math.exp(test_loss)

    print(f"Final [Val] loss={val_loss:.4f}, ppl={val_ppl:.2f}")
    print(f"Final [Test] loss={test_loss:.4f}, ppl={test_ppl:.2f}")

    # --- Extrinsic: Generation ---
    prompts = ["To be or not to be", "Once upon a midnight dreary"]
    for prompt in prompts:
        start_ids = tokenizer.encode_words(prompt.split())
        start_ids = [tokenizer.token_to_id.get(tok, 0) for tok in start_ids]

        print("\nPrompt:", prompt)
        print("Greedy:")
        # top_k=1 keeps only the arg-max token, i.e. greedy decoding
        greedy_ids = generate(model, start_tokens=start_ids, max_new_tokens=40, temperature=1.0, top_k=1)
        print(" ".join(tokenizer.decode_tokens([tokenizer.id_to_token[i] for i in greedy_ids])))

        print("Top-k:")
        topk_ids = generate(model, start_tokens=start_ids, max_new_tokens=40, temperature=0.8, top_k=50)
        print(" ".join(tokenizer.decode_tokens([tokenizer.id_to_token[i] for i in topk_ids])))

    return {
        "k": k,
        "n_embd": n_embd,
        "dropout": dropout,
        "val_loss": val_loss,
        "val_ppl": val_ppl,
        "test_loss": test_loss,
        "test_ppl": test_ppl,
    }
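
The generation calls above are controlled by `temperature` and `top_k`. To make their effect concrete, here is a small sketch of a single sampling step in plain Python: logits are divided by the temperature, everything outside the k largest values is masked, and one index is drawn from the resulting softmax. The function name `sample_top_k` and the toy logits are assumptions for illustration; the notebook's `generate` does the equivalent on tensors.

```python
import math
import random

def sample_top_k(logits, temperature=1.0, top_k=None, rng=random):
    """Draw one index from logits after temperature scaling and optional top-k filtering."""
    scaled = [l / max(temperature, 1e-8) for l in logits]
    if top_k is not None:
        k = min(top_k, len(scaled))
        cutoff = sorted(scaled, reverse=True)[k - 1]  # k-th largest value
        scaled = [s if s >= cutoff else float("-inf") for s in scaled]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 0.5, 1.0, -1.0]  # toy values; index 0 is the most likely token
rng = random.Random(0)

# top_k=1 leaves only the arg-max, so sampling becomes greedy decoding.
greedy = [sample_top_k(logits, top_k=1, rng=rng) for _ in range(5)]

# top_k=2 restricts sampling to the two largest logits (indices 0 and 2).
top2 = {sample_top_k(logits, temperature=0.8, top_k=2, rng=rng) for _ in range(200)}
print(greedy, sorted(top2))
```

A temperature below 1 sharpens the distribution towards the most likely token, while a temperature above 1 flattens it; with `top_k=None` and temperature 1.0 the step samples from the full, unmodified distribution rather than decoding greedily.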

In [None]:
results = []
for k in [200, 500, 1000]:   # your top 3 from n-grams
    for n_embd in [64, 128]:
        # One cache directory per k: the cached tokenizer and encoded splits depend on the merges
        res = train_and_eval_for_gpt(k, n_embd, dropout=0.1, run_dir=f"runs/gpt_exp_k{k}")
        results.append(res)
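
Once the sweep has filled `results`, the runs can be compared side by side in the report. One possible way to do that without extra dependencies is sketched below: it sorts the result dictionaries by validation perplexity and renders a plain-text table. The helper `format_results` is hypothetical, and the numbers in `demo` are placeholders that only exercise the formatter; they are not measured results.

```python
def format_results(results):
    """Return a plain-text table of sweep results, best validation perplexity first."""
    rows = sorted(results, key=lambda r: r["val_ppl"])
    lines = [f"{'k':>6} {'n_embd':>7} {'dropout':>8} {'val_ppl':>9} {'test_ppl':>9}"]
    for r in rows:
        lines.append(
            f"{r['k']:>6} {r['n_embd']:>7} {r['dropout']:>8.2f} "
            f"{r['val_ppl']:>9.2f} {r['test_ppl']:>9.2f}"
        )
    return "\n".join(lines)

# Placeholder entries with the same keys that train_and_eval_for_gpt returns.
demo = [
    {"k": 200, "n_embd": 64, "dropout": 0.1, "val_ppl": 30.0, "test_ppl": 31.0},
    {"k": 500, "n_embd": 128, "dropout": 0.1, "val_ppl": 25.0, "test_ppl": 26.0},
]
table = format_results(demo)
print(table)
```

Calling `format_results(results)` on the real list would give one row per (k, n_embd) combination, which can be pasted into the report next to the n-gram baselines.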


In [None]:
## probably not used

"""# sample generation from random position of valid set
start = random.randint(0, max(0, len(splits.valid) - cfg.block_size - 1))
prefix = splits.valid[start:start+min(32, cfg.block_size)].tolist()
gen_ids = generate(model, start_tokens=prefix, max_new_tokens=60, temperature=0.8, top_k=50)
sample_txt = " ".join([tokenizer.id_to_token[i] for i in gen_ids])
sample_path = os.path.join(run_dir, "samples", f"step{step}_sample.txt")
with open(sample_path, "w", encoding="utf-8") as f:
    f.write(sample_txt)
print(f"[Sample] saved → {sample_path}")"""