# CS336 Assignment 1 — Training Loop

本 notebook 用 `training_functions.ipynb` 里实现的组件，加上你之前写好的 TransformerLM 模型，跑一个完整训练循环。

主要功能：

* memmap 加载 tokenized 数据
* 采样 batch
* 混合精度前向 + 反向
* AdamW 更新
* 余弦学习率调度
* 梯度裁剪
* 训练/验证 loss 日志
* Matplotlib 实时曲线可视化（loss 下降情况）
* checkpoint 保存
* （可选）wandb 日志

In [3]:
from __future__ import annotations
import os
import math
import time
from typing import Dict, Any, Optional, Tuple, List

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# === bring in what we wrote in training_functions.ipynb ===
# You can "import" them if you made a .py file,
# or if you're running in one notebook environment, just run that notebook first and reuse.
# from training_functions import (
#     cross_entropy_loss,
#     AdamW,
#     cosine_lr_schedule,
#     clip_gradients,
#     get_batch,
#     save_checkpoint,
#     load_checkpoint,
#     TrainingConfig,
# )

# Your Transformer LM
# from transformerLM import TransformerLM  # <-- TODO: adjust to your actual module name

## 0. 准备数据 (memmap + sanity check)

作业里建议把整个数据集（比如 TinyStories tokenized）保存成一个 `uint16` 的 `.npy`，
然后用 `np.memmap` 或 `np.load(..., mmap_mode='r')` 读取，避免一次性吃满内存。

我们会准备：

* `train_data`
* `val_data`

注意：`vocab_size` 应该和模型一致。
`context_len` 要 ≤ 你模型声明的 `context_length`。

In [4]:
def load_token_dataset(path: str) -> np.memmap:
    """Open a tokenized dataset stored as ``.npy`` without reading it into RAM.

    Args:
        path: Path to a ``.npy`` file of token IDs (``uint16``, ``int32``, ...).

    Returns:
        A read-only memory-mapped view of the stored array, exactly what
        ``np.load(path, mmap_mode="r")`` produces.
    """
    # Optional sanity check on the vocabulary range, e.g.:
    #     print("max token id:", int(load_token_dataset(path).max()))
    return np.load(path, mmap_mode="r")

## 1. 评估函数 evaluate_model_loss(model, data, config, num_batches)

我们不想每次 eval 都跑全量数据，所以：

* 随机抽 `num_batches`
* 计算平均 loss (无梯度、eval 模式)
* 返回这个标量，方便 logging / 画图

In [5]:
@torch.no_grad()
def evaluate_model_loss(
    model: nn.Module,
    data: np.ndarray,
    config: TrainingConfig,
    num_batches: int = 20,
) -> float:
    """Estimate the mean next-token loss of ``model`` on ``data``.

    Draws ``num_batches`` random batches with ``get_batch`` (so the estimate is
    stochastic), runs the model without gradients in eval mode, and averages the
    per-batch losses returned by ``cross_entropy_loss``. The model's previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: Language model; its output is treated as logits of shape (B, T, V).
        data: Token-ID array to sample batches from (e.g. a memmap).
        config: Supplies ``device``, ``batch_size`` and ``context_len``.
        num_batches: Number of random batches to average over; must be >= 1.

    Returns:
        Mean loss over the sampled batches as a Python float.

    Raises:
        ValueError: If ``num_batches`` is smaller than 1.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        for _ in range(num_batches):
            x, y = get_batch(
                data,
                batch_size=config.batch_size,
                context_len=config.context_len,
                device=config.device,
            )
            logits = model(x)  # (B, T, vocab)
            # y is already the next-token target aligned with x; flatten (B, T)
            # into one batch dimension for the loss. reshape (unlike view) also
            # works when the model returns a non-contiguous tensor.
            vocab = logits.shape[-1]
            loss = cross_entropy_loss(
                logits.reshape(-1, vocab),
                y.reshape(-1),
            )
            total_loss += loss.item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    return float(total_loss / num_batches)

## 2. plotter: 实时更新训练/验证 loss 曲线

我们用一个小的 helper 来维护 lists，然后每隔几步刷新一张图。
注意：Jupyter 里 `%matplotlib inline` 后端画出的是静态图。如果想动态刷新，可以先调用 `clear_output(wait=True)` 清掉上一次的输出，再重新画图。
下面实现一个简单的在线曲线记录类。

In [6]:
from IPython.display import clear_output

class LivePlotter:
    """Accumulate train/val loss points and redraw them as one figure.

    ``log_train`` / ``log_eval`` append ``(step, loss)`` pairs; ``draw`` clears
    the current cell output (``clear_output(wait=True)``) and plots both curves,
    so in Jupyter the previous plot is replaced in place.
    """

    def __init__(self):
        self.train_steps = []
        self.train_losses = []
        self.eval_steps = []
        self.eval_losses = []

    def log_train(self, step, loss):
        """Record one training-loss point at ``step``."""
        self.train_steps.append(step)
        self.train_losses.append(loss)

    def log_eval(self, step, loss):
        """Record one validation-loss point at ``step``."""
        self.eval_steps.append(step)
        self.eval_losses.append(loss)

    def draw(self):
        """Replace the cell output with a fresh plot of all recorded losses."""
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(6, 4))
        has_curve = False
        if self.train_steps:
            ax.plot(self.train_steps, self.train_losses, label="train_loss")
            has_curve = True
        if self.eval_steps:
            ax.plot(self.eval_steps, self.eval_losses, label="val_loss")
            has_curve = True
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        if has_curve:
            # legend() with no labelled artists only emits a UserWarning.
            ax.legend()
        ax.set_title("Training Progress")
        plt.show()
        # draw() runs many times per training run; close the figure so pyplot
        # does not keep every one alive (and warn about too many open figures).
        plt.close(fig)

## 3. create_model_and_optimizer(config, vocab_size, context_length, num_layers, d_model, num_heads, d_ff, ckpt_path_to_resume)

* 构建 TransformerLM
* 构建 AdamW
* GradScaler (混合精度)
* （可选）如果给了 checkpoint 路径且文件存在，从中恢复模型和优化器状态
* 返回 `(model, optimizer, scaler, start_step)`

In [7]:
def create_model_and_optimizer(
    config: TrainingConfig,
    vocab_size: int,
    context_length: int,
    num_layers: int,
    d_model: int,
    num_heads: int,
    d_ff: Optional[int] = None,  # if your model needs explicit ff dim
    ckpt_path_to_resume: Optional[str] = None,
):
    """Build the TransformerLM, its AdamW optimizer and an AMP GradScaler.

    Args:
        config: Supplies ``device``, the AdamW hyper-parameters (``lr_max``,
            ``betas``, ``eps``, ``weight_decay``) and ``mixed_precision``
            (which enables/disables the GradScaler).
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff:
            Forwarded unchanged to ``TransformerLM``.
        ckpt_path_to_resume: Optional checkpoint path. If the file exists, the
            model and optimizer state are restored with ``load_checkpoint``.
            If a path is given but missing, a message is printed and training
            starts from scratch.

    Returns:
        ``(model, optimizer, scaler, start_step)``; ``start_step`` is the value
        returned by ``load_checkpoint`` (0 when not resuming).
    """
    device = config.device

    # 1. Model
    model = TransformerLM(
        vocab_size=vocab_size,
        context_length=context_length,
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        device=device,
    )
    model.to(device)

    # 2. Optimizer. The initial lr is overwritten every step by the scheduler.
    optimizer = AdamW(
        model.parameters(),
        lr=config.lr_max,
        betas=config.betas,
        eps=config.eps,
        weight_decay=config.weight_decay,
    )

    # 3. GradScaler for AMP (a no-op when mixed precision is disabled).
    scaler = GradScaler(enabled=config.mixed_precision)

    # 4. Optionally resume from a checkpoint.
    start_step = 0
    if ckpt_path_to_resume is not None:
        if os.path.exists(ckpt_path_to_resume):
            print(f"[resume] Loading checkpoint from {ckpt_path_to_resume}")
            # NOTE(review): assumes load_checkpoint returns the stored step count.
            start_step = load_checkpoint(model, optimizer, ckpt_path_to_resume)
            # GradScaler state is not part of the checkpoint; it restarts fresh.
        else:
            # Do not silently start from scratch when a resume was requested.
            print(
                f"[resume] Checkpoint not found at {ckpt_path_to_resume}; "
                "training from scratch"
            )

    return model, optimizer, scaler, start_step

## 4. train_loop(...)

核心训练循环。
每个 step：

1. 抽 batch
2. 计算当前 step 的学习率并塞回 optimizer
3. 前向 (AMP autocast)
4. loss 计算 (我们用自己写的 cross_entropy_loss)
5. backward (scaled if AMP)
6. gradient clipping
7. `scaler.step(optimizer)` + `scaler.update()`（AMP 版本的 optimizer.step；`zero_grad` 在第 5 步 backward 之前调用）

另外：

* 每隔 `log_every` 记录 train loss
* 每隔 `eval_every` 在 val 上跑 evaluate_model_loss
* 每隔 `ckpt_every` 保存 checkpoint
* 用 LivePlotter 实时画图

In [8]:
def train_loop(
    train_data: np.ndarray,
    val_data: np.ndarray,
    config: TrainingConfig,
    vocab_size: int,
    context_length: int,
    num_layers: int,
    d_model: int,
    num_heads: int,
    d_ff: Optional[int] = None,
    resume_ckpt: Optional[str] = None,
    save_prefix: str = "checkpoint_step",
    use_wandb: bool = False,
):
    """Train a TransformerLM on ``train_data`` and return the trained model.

    Each step:
      1. Sample a batch.
      2. Set the cosine-scheduled learning rate on every optimizer param group.
      3. Run the forward pass under ``torch.autocast`` (enabled by
         ``config.mixed_precision``).
      4. Back-propagate through the GradScaler.
      5. Unscale and clip the gradients.
      6. Step the optimizer via the scaler.

    Every ``config.log_every`` steps, the train loss, lr and grad-norm are
    logged (progress bar, LivePlotter, optionally wandb) and the plot is
    redrawn. Every ``config.eval_every`` steps (step > 0), the validation loss
    is estimated with ``evaluate_model_loss``. Every ``config.ckpt_every``
    steps (step > 0), a checkpoint is written to ``config.ckpt_dir``. A final
    ``{save_prefix}_final.pt`` checkpoint is always written at the end.

    Checkpoints store the number of completed steps: ``cur_step + 1`` for
    periodic ones and ``config.total_steps`` for the final one. Resuming from
    ``resume_ckpt`` therefore continues with the next step instead of
    repeating one.

    Args:
        train_data / val_data: Token-ID arrays sampled by ``get_batch``.
        config: Training hyper-parameters, device, logging and checkpoint settings.
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff:
            Model hyper-parameters forwarded to ``create_model_and_optimizer``.
        resume_ckpt: Optional checkpoint path to resume from.
        save_prefix: File-name prefix for checkpoints in ``config.ckpt_dir``.
        use_wandb: If True, also log metrics to Weights & Biases.

    Returns:
        The trained model.
    """
    device = config.device
    # torch.autocast wants the bare device type ("cuda", "cpu", ...), not e.g. "cuda:0".
    device_type = torch.device(device).type
    os.makedirs(config.ckpt_dir, exist_ok=True)

    # Optional wandb (imported once, only when requested).
    if use_wandb:
        import wandb

        wandb.init(project="cs336_assignment1", name=config.run_name, config=dict(
            batch_size=config.batch_size,
            context_len=config.context_len,
            lr_max=config.lr_max,
            lr_min=config.lr_min,
            warmup_iters=config.warmup_iters,
            cosine_iters=config.cosine_iters,
            weight_decay=config.weight_decay,
            grad_clip_norm=config.grad_clip_norm,
            mixed_precision=config.mixed_precision,
            model=dict(
                vocab_size=vocab_size,
                context_length=context_length,
                num_layers=num_layers,
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            )
        ))

    # Model / optimizer / grad scaler (and the step to start from when resuming).
    model, optimizer, scaler, step = create_model_and_optimizer(
        config,
        vocab_size=vocab_size,
        context_length=context_length,
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        ckpt_path_to_resume=resume_ckpt,
    )

    plotter = LivePlotter()

    pbar = tqdm(range(step, config.total_steps), initial=step, total=config.total_steps)
    for cur_step in pbar:
        # === 1. Sample a batch ===
        x, y = get_batch(
            train_data,
            batch_size=config.batch_size,
            context_len=config.context_len,
            device=device,
        )

        # === 2. LR schedule: write the lr for this step into every param group ===
        lr_t = cosine_lr_schedule(
            t=cur_step,
            lr_max=config.lr_max,
            lr_min=config.lr_min,
            warmup_iters=config.warmup_iters,
            cosine_iters=config.cosine_iters,
        )
        for g in optimizer.param_groups:
            g["lr"] = lr_t

        # === 3. Forward + loss ===
        # torch.autocast (unlike torch.cuda.amp.autocast) accepts device_type.
        with torch.autocast(device_type=device_type, enabled=config.mixed_precision):
            logits = model(x)  # (B, T, vocab)
            vocab = logits.shape[-1]
            # Compute the loss in float32 even when autocast produced half-precision
            # logits; reshape also handles non-contiguous outputs.
            loss = cross_entropy_loss(
                logits.float().reshape(-1, vocab),
                y.reshape(-1),
            )

        # === 4. Backward ===
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()

        # === 5. Gradient clipping ===
        # Unscale first so clipping sees the true gradient magnitudes (AMP best practice).
        scaler.unscale_(optimizer)
        grad_norm = clip_gradients(model.parameters(), config.grad_clip_norm)

        # === 6. Optimizer step (AMP aware) ===
        scaler.step(optimizer)
        scaler.update()

        is_log_step = (cur_step % config.log_every) == 0

        # === logging ===
        if is_log_step:
            train_loss_value = float(loss.item())
            # NOTE(review): clip_gradients is defined elsewhere and may return a
            # tensor, a float or None; normalize so formatting/logging never fails.
            grad_norm_value = float("nan") if grad_norm is None else float(grad_norm)
            plotter.log_train(cur_step, train_loss_value)

            pbar.set_description(
                f"step {cur_step} "
                f"loss {train_loss_value:.4f} "
                f"lr {lr_t:.2e} "
                f"gnorm {grad_norm_value:.2f}"
            )

            if use_wandb:
                wandb.log({
                    "train/loss": train_loss_value,
                    "train/lr": lr_t,
                    "train/grad_norm": grad_norm_value,
                    "step": cur_step,
                })

        # === eval ===
        if (cur_step % config.eval_every) == 0 and cur_step > 0:
            val_loss_value = evaluate_model_loss(model, val_data, config)
            plotter.log_eval(cur_step, val_loss_value)

            if use_wandb:
                wandb.log({
                    "val/loss": val_loss_value,
                    "step": cur_step,
                })

        # === redraw the plot (after eval so a new val point shows up) ===
        if is_log_step:
            plotter.draw()

        # === checkpoint ===
        if (cur_step % config.ckpt_every) == 0 and cur_step > 0:
            ckpt_path = os.path.join(
                config.ckpt_dir, f"{save_prefix}_{cur_step}.pt"
            )
            # Store the number of completed steps (matches the final checkpoint,
            # which stores total_steps), so a resume starts at the next step.
            save_checkpoint(model, optimizer, cur_step + 1, ckpt_path)

    # Training finished: always write a final checkpoint.
    final_ckpt = os.path.join(config.ckpt_dir, f"{save_prefix}_final.pt")
    save_checkpoint(model, optimizer, config.total_steps, final_ckpt)

    # One last full curve.
    plotter.draw()

    if use_wandb:
        wandb.finish()

    return model

## 5. 使用示例

下面是一个最小可跑示例（debug规模），假设：

* 你有两个 `.npy`：`../datasets/TinyStories/tokens_train.npy` 和 `../datasets/TinyStories/tokens_valid.npy`
* 这些是 uint16 token ids
* 你的 `TransformerLM` 支持这些超参

注意：第一次调试可以把 `total_steps`、`context_len`、`batch_size` 都调小，确保 loop 能跑通、loss 会下降。


In [None]:
# 1. Load the memmapped token datasets (TinyStories token ids).
train_data = load_token_dataset("../datasets/TinyStories/tokens_train.npy")
val_data = load_token_dataset("../datasets/TinyStories/tokens_valid.npy")

# 2. Training config. Fall back to CPU (with AMP disabled) when no GPU is
#    available, so this debug run works on either device.
use_cuda = torch.cuda.is_available()
cfg = TrainingConfig(
    batch_size=32,
    context_len=128,
    total_steps=2000,
    log_every=20,
    eval_every=200,
    ckpt_every=500,
    lr_max=3e-4,
    lr_min=3e-5,
    warmup_iters=100,
    cosine_iters=2000,
    weight_decay=0.1,
    grad_clip_norm=1.0,
    device="cuda" if use_cuda else "cpu",
    mixed_precision=use_cuda,
    ckpt_dir="./checkpoints_tinystories",
    run_name="tinystories_debug_run",
)

# 3. Model hyper-parameters (a small model so it runs on CPU or GPU).
vocab_size = 10_000  # e.g. tokenizer vocab size; must cover every token id
context_length = cfg.context_len
num_layers = 4
d_model = 256
num_heads = 4
d_ff = int((8 / 3) * d_model)  # or whatever your implementation expects

trained_model = train_loop(
    train_data=train_data,
    val_data=val_data,
    config=cfg,
    vocab_size=vocab_size,
    context_length=context_length,
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    resume_ckpt=None,            # or "./checkpoints_tinystories/checkpoint_step_1000.pt"
    save_prefix="tinystories_lm",
    use_wandb=False,             # turn on if you want wandb logging
)

## 检查数据

In [13]:
import numpy as np

# Sanity check: print the first 100 token ids of the validation split.
# To inspect the training split too, uncomment the two lines below.
# train_data = np.load("../datasets/TinyStories/tokens_train.npy", mmap_mode="r")
# print(train_data.dtype, train_data.shape)
valid_data = np.load(
    "../datasets/TinyStories/tokens_valid.npy",
    mmap_mode="r",
)
print("valid example:", valid_data[:100])

valid example: [ 83 112 111 116  46  32  83 112 111 116  32 115  97 119  32 116 104 101
  32 115 104 105 110 121  32  99  97 114  32  97 110 100  32 115  97 105
 100  44  32  34  87 111 119  44  32  75 105 116 116 121  44  32 121 111
 117 114  32  99  97 114  32 105 115  32 115 111  32  98 114 105 103 104
 116  32  97 110 100  32  99 108 101  97 110  32  75 105 116 116 121  32
 115 109 105 108 101 100  32  97 110 100]
