# CS336 Assignment 1 — Training Functions

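This notebook defines the building blocks that the training loop needs:

* `cross_entropy_loss`
* `AdamW` (a hand-written optimizer; `torch.optim.AdamW` is not allowed)
* Learning-rate schedule: cosine annealing with warmup
* Gradient clipping
* Batch sampling (`get_batch`)
* Checkpoint saving / loading
* `TrainingConfig` (for use by the main training loop later)

These implementations follow the assignment rules:

* No ready-made implementations from `torch.nn.functional` or `torch.optim` (the `Optimizer` base class is the one exception)
* Numerical stability is handled explicitly (the maximum is subtracted from the logits)
* Device placement is handled explicitly (we pass `device="cuda"`)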

In [1]:
from __future__ import annotations
import math
import os
import io
import time
from dataclasses import dataclass
from typing import Optional, Iterable, Tuple, List

import torch
import numpy as np

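## 1. cross_entropy_loss

Cross-entropy for a single position is defined as:
ℓ = -log softmax(logits)[target]

Numerically stable formulation:

* Take the maximum over the vocabulary dimension: `m = logits.max(dim=-1).values`
* `logsumexp = m + log(sum(exp(logits - m)))`
* Log-probability of the target token = the logit gathered at the target index minus `logsumexp`
* Negate and average over the batch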
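
To see why the maximum is subtracted first, here is a minimal pure-Python sketch on a hypothetical three-way "vocabulary" with very large logits. The values are made up for illustration. In plain Python the naive form raises `OverflowError`; with float32 tensors the same overflow would show up as `inf` instead.

```python
import math

# Hypothetical logits, large enough that exp() overflows a double.
logits = [1000.0, 999.0, 998.0]

# Naive log-sum-exp: math.exp(1000.0) is out of range.
try:
    naive = math.log(sum(math.exp(x) for x in logits))
except OverflowError:
    naive = None  # the naive form is unusable here

# Stable log-sum-exp: shift by the maximum so every exponent is <= 0.
m = max(logits)
stable = m + math.log(sum(math.exp(x - m) for x in logits))

# Cross-entropy when the target is index 0.
loss_target_0 = -(logits[0] - stable)
print(naive, round(loss_target_0, 6))
```

The shift does not change the result mathematically, because the factor `exp(m)` cancels between the numerator and the denominator of the softmax.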

Notes:

* logits shape: `(..., vocab_size)`
* targets shape: `(...)` (the same leading dimensions, without the trailing `vocab_size` dimension)
* The function returns a scalar loss (the mean over all positions)

In [2]:
def cross_entropy_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Compute average cross-entropy over all batch positions.

    logits: (..., vocab_size) float32/float16/bfloat16
    targets: (...) int64, same leading shape as logits without vocab dim

    returns: scalar tensor (mean loss)
    """
    # Upcast to float32 so that exp/log stay accurate for float16/bfloat16
    # inputs; the returned loss is float32.
    logits_f32 = logits.to(torch.float32)

    # max over vocab dimension (last dim)
    max_logits, _ = torch.max(logits_f32, dim=-1, keepdim=True)  # (..., 1)

    # shift for stability
    shifted = logits_f32 - max_logits  # (..., vocab)

    # logsumexp = log(sum(exp(shifted)))
    sum_exp = torch.sum(torch.exp(shifted), dim=-1, keepdim=False)  # (...)
    logsumexp = torch.log(sum_exp) + max_logits.squeeze(-1)        # (...)

    # gather the logit of the true target
    # targets shape (...) -> need to align
    true_logits = torch.take_along_dim(
        logits_f32, targets.unsqueeze(-1), dim=-1
    ).squeeze(-1)  # (...)

    # nll = -(true_logit - logsumexp)
    nll = -(true_logits - logsumexp)  # (...,)

    # average over all positions in batch
    loss = nll.mean()
    return loss


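## 2. Custom AdamW

Following the algorithm in the assignment handout:

* Maintain the first moment m and the second moment v (stored in `self.state[p]`)
* Bias correction:
  α_t = α * sqrt(1-β2^t)/(1-β1^t)
* Parameter update:
  θ ← θ - α_t * m / (sqrt(v)+eps)
  followed by weight decay: θ ← θ - α * weight_decay * θ

Notes:

* `torch.optim.Optimizer` is used as the base class (this is allowed)
* The ready-made Adam/AdamW implementations are not allowed
* `zero_grad()` is not reimplemented; the usual `optimizer.zero_grad()` call goes to the inherited base-class method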
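
As a sanity check on the update rule, the first step can be worked by hand for a single scalar parameter. The sketch below uses made-up values for the parameter and gradient. It illustrates a known property of Adam: at t = 1 the bias-corrected update has magnitude close to the learning rate, whatever the gradient's scale.

```python
import math

# Hypothetical scalar problem: one parameter and one gradient value.
theta, grad = 1.0, 0.5
lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.95, 1e-8, 0.1

m, v, t = 0.0, 0.0, 1
m = beta1 * m + (1 - beta1) * grad          # first moment
v = beta2 * v + (1 - beta2) * grad * grad   # second moment

# Bias-corrected step size alpha_t.
lr_t = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)

theta_before = theta
theta -= lr_t * m / (math.sqrt(v) + eps)    # Adam update
adam_delta = theta_before - theta           # close to lr on the first step

theta -= lr * weight_decay * theta          # decoupled decay uses the base lr
print(adam_delta, theta)
```

Note that the weight-decay term is scaled by the base learning rate α rather than by α_t, which is what makes the decay "decoupled" from the adaptive step.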

In [3]:
class AdamW(torch.optim.Optimizer):
    """
    Minimal AdamW optimizer from scratch (decoupled weight decay).
    Matches the algorithm described in the handout.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0:
            raise ValueError(f"Invalid eps: {eps}")
        if weight_decay < 0:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Iterate over param groups
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                state["step"] += 1
                t = state["step"]

                # Update biased first moment estimate
                exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
                # Update biased second raw moment estimate
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

                # Bias correction
                bias_correction1 = 1 - beta1 ** t
                bias_correction2 = 1 - beta2 ** t
                # Compute step size
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                # denom = sqrt(v) + eps
                denom = exp_avg_sq.sqrt().add_(eps)

                # Adam update
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay
                if wd != 0:
                    p.data.add_(p.data, alpha=-lr * wd)

        return loss


## 3. Learning-rate schedule (cosine annealing + warmup)

Following the handout:

* t < T_w: linear warmup to α_max
* T_w ≤ t ≤ T_c: cosine decay to α_min
* t > T_c: constant α_min
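
Written out as a single piecewise formula, the three cases read as follows. This assumes the warmup starts from 0 at t = 0 and that T_c > T_w:

$$
\alpha_t =
\begin{cases}
\dfrac{t}{T_w}\,\alpha_{\max}, & t < T_w \\[6pt]
\alpha_{\min} + \dfrac{1}{2}\left(1 + \cos\!\left(\pi\,\dfrac{t - T_w}{T_c - T_w}\right)\right)\left(\alpha_{\max} - \alpha_{\min}\right), & T_w \le t \le T_c \\[6pt]
\alpha_{\min}, & t > T_c
\end{cases}
$$

At t = T_w the cosine term equals 1, giving α_max; at t = T_c it equals 0, giving α_min. The three pieces therefore join continuously.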

This is implemented as a function, `cosine_lr_schedule`, of the step index. The training loop calls it at every step and writes the result into `optimizer.param_groups[i]["lr"]`.
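
The per-step update pattern can be sketched as below. Two things here are stand-ins, not part of the assignment code: `toy_schedule` is a hypothetical warmup-only schedule, and `torch.optim.SGD` is used only so the snippet runs on its own.

```python
import torch

def toy_schedule(step: int) -> float:
    """Hypothetical stand-in schedule: linear warmup over 4 steps to 0.1."""
    return 0.1 * min(1.0, step / 4)

param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([param], lr=0.0)  # stand-in optimizer for the demo

lrs = []
for step in range(6):
    # Overwrite the learning rate of every param group before stepping.
    lr = toy_schedule(step)
    for group in optimizer.param_groups:
        group["lr"] = lr

    param.grad = torch.ones_like(param)  # fake gradient of all ones
    optimizer.step()
    optimizer.zero_grad()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

Because the optimizer reads `group["lr"]` inside `step()`, setting it just before the call is enough; no scheduler object is required.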


In [4]:
def cosine_lr_schedule(
    t: int,
    lr_max: float,
    lr_min: float,
    warmup_iters: int,
    cosine_iters: int,
) -> float:
    """
    Return lr_t for iteration t.
    warmup_iters: Tw
    cosine_iters: Tc
    """
    if t < warmup_iters:
        # linear warmup 0 -> lr_max
        return lr_max * (t / max(1, warmup_iters))

    if t <= cosine_iters:
        # cosine anneal from lr_max -> lr_min
        progress = (t - warmup_iters) / max(1, (cosine_iters - warmup_iters))
        # cosine from 0..pi
        cosine_term = 0.5 * (1.0 + math.cos(math.pi * progress))
        return lr_min + (lr_max - lr_min) * cosine_term

    # post-anneal
    return lr_min


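## 4. Gradient clipping

Approach:

1. Compute the global L2 norm over the gradients of all parameters
2. If it exceeds max_norm, scale every p.grad by the same factor
3. Add a small eps to the denominator to avoid division by zero

Note: this function modifies the gradients **in place**, as the assignment requires.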
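
A small worked example of the three steps, using two made-up gradient tensors that together form the vector (3, 4, 12) with global norm 13:

```python
import math
import torch

# Hypothetical gradients of two parameters; the global vector is (3, 4, 12).
grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
max_norm, eps = 1.0, 1e-6

# Step 1: one L2 norm over all gradients, not one norm per tensor.
total_norm = math.sqrt(sum(float((g ** 2).sum()) for g in grads))

# Steps 2 and 3: scale every gradient in place by the same factor.
if total_norm > max_norm:
    scale = max_norm / (total_norm + eps)
    for g in grads:
        g.mul_(scale)

clipped_norm = math.sqrt(sum(float((g ** 2).sum()) for g in grads))
print(total_norm, clipped_norm)
```

Using one shared factor preserves the direction of the overall gradient; only its length changes, ending up just under `max_norm` because of the eps in the denominator.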

In [5]:
@torch.no_grad()
def clip_gradients(params: Iterable[torch.nn.Parameter], max_norm: float, eps: float = 1e-6):
    """
    In-place gradient clipping to max global L2 norm.
    """
    # Materialize the gradients once: `params` may be a generator such as
    # model.parameters(), which a second loop would find already exhausted.
    grads = [p.grad for p in params if p.grad is not None]

    # accumulate the squared global L2 norm in float32
    total_norm_sq = 0.0
    for g in grads:
        total_norm_sq += float(torch.sum(g.to(torch.float32) ** 2))

    total_norm = math.sqrt(total_norm_sq)

    if total_norm > max_norm:
        scale = max_norm / (total_norm + eps)
        for g in grads:
            g.mul_(scale)
    # return the unclipped norm for logging
    return total_norm


## 5. get_batch: sampling random batches from one long token sequence

Inputs:

* `data`: numpy array / memmap of uint16 (token IDs)
* `batch_size`
* `context_len` (m)
* `device` (e.g. "cuda")

Outputs:

* inputs: shape (B, m)
* targets: shape (B, m)
  Both tensors are placed on `device`.

Sampling:

* Pick a random start index i
* x[i : i+m] → input
* x[i+1 : i+m+1] → target
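
The index arithmetic is easy to get off by one, so here is a minimal NumPy sketch on a hypothetical 10-token stream. It shows the last valid start index and the one-position shift between input and target. `rng.integers` is used here purely for illustration; like `np.random.randint`, its upper bound is exclusive.

```python
import numpy as np

data = np.arange(10, dtype=np.uint16)  # hypothetical token stream 0..9
context_len = 4
n = len(data)

# The target slice ends at i + context_len + 1, which must be <= n.
last_valid_start = n - context_len - 1

i = last_valid_start
x = data[i : i + context_len]          # input window
y = data[i + 1 : i + 1 + context_len]  # same window shifted by one token

# Sampling starts: the exclusive upper bound is n - context_len.
rng = np.random.default_rng(0)
starts = rng.integers(0, n - context_len, size=1000)
print(x, y, starts.min(), starts.max())
```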

In [6]:
def get_batch(
    data: np.ndarray,
    batch_size: int,
    context_len: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Randomly slice subsequences of length context_len from data,
    and return (inputs, targets), each (B, context_len) on device.
    data is a 1-D array of token IDs, e.g. uint16.
    """
    n = len(data)
    # Valid starts are 0 .. n - context_len - 1, so that the target slice
    # data[i + 1 : i + 1 + context_len] stays in bounds. randint's upper
    # bound is exclusive, hence high = n - context_len.
    starts = np.random.randint(0, n - context_len, size=(batch_size,))
    # build batches
    x_batch = np.stack([data[i : i + context_len] for i in starts])           # (B, m)
    y_batch = np.stack([data[i + 1 : i + 1 + context_len] for i in starts])   # (B, m)

    # Cast to int64 in NumPy first; some PyTorch versions cannot convert
    # uint16 arrays directly.
    x_t = torch.from_numpy(x_batch.astype(np.int64)).to(device)
    y_t = torch.from_numpy(y_batch.astype(np.int64)).to(device)
    return x_t, y_t


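## 6. Checkpoint saving / loading

We need to be able to:

* Save the model weights, the optimizer state, and the current step/iteration
* Resume training from a checkpoint later

This uses:

* `model.state_dict()`
* `optimizer.state_dict()`
* `torch.save / torch.load`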
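
A round trip through these three calls can be sketched without touching the disk, since `torch.save` and `torch.load` also accept file-like objects. The tiny `torch.nn.Linear` model and the `torch.optim.SGD` optimizer below are stand-ins chosen only to keep the example self-contained.

```python
import io
import torch

# Stand-in model and optimizer (not the assignment's own classes).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the optimizer actually has state (momentum buffers).
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# Save into an in-memory buffer instead of a file path.
buffer = io.BytesIO()
torch.save(
    {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "iteration": 123,
    },
    buffer,
)

# Restore into freshly constructed objects.
buffer.seek(0)
payload = torch.load(buffer, map_location="cpu")
restored = torch.nn.Linear(4, 2)
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1, momentum=0.9)
restored.load_state_dict(payload["model_state"])
restored_opt.load_state_dict(payload["optimizer_state"])
print(payload["iteration"])
```

Restoring the optimizer state matters for stateful optimizers: without it, the moment estimates would restart from zero after resuming.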

In [7]:
def save_checkpoint(model, optimizer, iteration: int, out_path: str):
    """
    Save model/optimizer/iteration to a single file.
    out_path: str or Path-like
    """
    payload = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "iteration": iteration,
    }
    torch.save(payload, out_path)


def load_checkpoint(model, optimizer, src_path: str) -> int:
    """
    Load from checkpoint file.
    Returns the iteration we loaded.
    """
    payload = torch.load(src_path, map_location="cpu")
    model.load_state_dict(payload["model_state"])
    optimizer.load_state_dict(payload["optimizer_state"])
    iteration = payload["iteration"]
    return iteration


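## 7. TrainingConfig

A small dataclass that gathers the hyperparameters used by the training loop in one place, so that the later main-loop notebook can reference them.

It covers:

* Data / batching settings and step counts
* Optimizer hyperparameters
* Scheduler hyperparameters
* Logging / checkpointing settings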


In [8]:
@dataclass
class TrainingConfig:
    # data / batching
    batch_size: int = 32
    context_len: int = 256

    # training steps
    total_steps: int = 10_000
    log_every: int = 50
    eval_every: int = 500
    ckpt_every: int = 1000

    # optimizer (AdamW)
    lr_max: float = 3e-4       # peak lr
    lr_min: float = 3e-5       # final lr after cosine
    betas: Tuple[float, float] = (0.9, 0.95)
    eps: float = 1e-8
    weight_decay: float = 0.1
    grad_clip_norm: float = 1.0

    # scheduler
    warmup_iters: int = 200
    cosine_iters: int = 10_000  # after this, lr = lr_min

    # misc
    device: str = "cuda"
    mixed_precision: bool = True
    ckpt_dir: str = "./checkpoints"
    run_name: str = "cs336_lm_run"
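
One way a config like this might be used is sketched below. `MiniConfig` is a hypothetical, trimmed-down stand-in, so the snippet runs on its own. `dataclasses.replace` derives a variant for a sweep without mutating the base config, and `asdict` turns it into a plain dict for logging.

```python
from dataclasses import asdict, dataclass, replace

@dataclass
class MiniConfig:
    """Hypothetical trimmed-down stand-in for TrainingConfig."""
    batch_size: int = 32
    lr_max: float = 3e-4
    warmup_iters: int = 200

base = MiniConfig()

# Derive a variant with one field overridden; `base` is left untouched.
sweep = replace(base, lr_max=1e-3)

# Plain dict, convenient for printing or writing next to a checkpoint.
logged = asdict(sweep)
print(logged)
```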
