# Week 1 — Optimizers as State Machines: Memory, Weight Decay, Gradient Accumulation, Checkpointing

**Promise:** By the end, you can  
1) write each update rule from memory,  
2) compute (and *measure*) optimizer memory overhead from first principles,  
3) reproduce/verify PyTorch behavior with minimal code, and  
4) explain the “why” verbally without hand-waving.

> Scope note: This notebook intentionally **does not** re-teach “momentum” and “second-order momentum” as concepts. It treats them as known and focuses on **comparisons, memory/state, and PyTorch semantics**.


## 0 — Agenda, deliverables, and “whiteboard readiness”

### 0.1 Mentor checklist → notebook map

| Mentor item | Where answered | Proof artifact |
|---|---|---|
| SGD vs. SGD with momentum (memory footprint, #params) | §2 | `optimizer.state` inspection + byte counting |
| SGD vs. Adam (memory footprint, #params) | §3 | state inspection + byte counting |
| Weight decay in Adam vs AdamW | §4 | scalar “unit test” + PyTorch flag semantics |
| Gradient accumulation (what + how in PyTorch) | §5 | correctness test vs “true big batch” |
| Activation / gradient checkpointing | §6 | peak memory + runtime comparison |

### 0.2 Deliverables

- **One-page cheat sheet** at the end: update rules + memory multipliers + “when-to-use”.
- **“Prove it” appendix:** tiny scalar examples showing that Adam’s and AdamW’s weight decay differ.
- **PyTorch verification suite:** code cells that inspect `optimizer.state_dict()` and (if CUDA) measure peak memory.

### 0.3 Reading links (primary sources)

- Lightly blog (decision heuristics, resource framing):  
  https://www.lightly.ai/blog/which-optimizer-should-i-use-for-my-machine-learning-project
- PyTorch docs (SGD / Adam / AdamW / checkpoint):  
  - https://docs.pytorch.org/docs/stable/generated/torch.optim.SGD.html  
  - https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html  
  - https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html  
  - https://docs.pytorch.org/docs/stable/checkpoint.html
- Adam paper (Kingma & Ba, 2014): https://arxiv.org/abs/1412.6980  
- Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2017): https://arxiv.org/abs/1711.05101  

> Ground rule for this notebook: if “PyTorch semantics” matters, we trust **the official docs + a runnable micro-test** over folklore.


In [None]:
# Setup
import os, time, math, inspect
from dataclasses import dataclass
from typing import Dict, Any, Iterable, Tuple, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

print("torch version:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda device:", torch.cuda.get_device_name(0))


## 1 — First principles: what costs memory during training?

Think of training as a small ecosystem of tensors:

- **Parameters**: \(\theta\) (persist across steps)
- **Gradients**: \(g = \nabla_\theta L\) (persist unless cleared)
- **Optimizer state**: extra persistent tensors per parameter (e.g. momentum buffers, Adam moments)
- **Activations**: intermediate tensors saved for backward (often the *dominant* memory term)

### 1.1 A clean accounting model

Let

- \(P\) = total number of scalar parameters (sum of `.numel()`)
- \(b\) = bytes per scalar (fp32 → 4, bf16/fp16 → 2)

Then (very roughly):

- params memory \(\approx P \cdot b\)
- grads memory \(\approx P \cdot b\) (if grads are materialized)
- optimizer state depends on the optimizer: **SGD: 0 (no momentum) or \(P\) (momentum)**, **Adam: \(2P\)** scalars, etc.
- activations depend on batch/sequence/depth; checkpointing targets *this* term.
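A worked example of this accounting for a 128 → 256 → 10 MLP (the same shapes as the `TinyMLP` defined later), assuming fp32 (4 bytes per scalar) and counting only params, grads and optimizer state. `linear_numel` is a hypothetical helper for this notebook, not a PyTorch API.

```python
def linear_numel(d_in: int, d_out: int) -> int:
    # nn.Linear(d_in, d_out) holds a (d_out, d_in) weight matrix plus a d_out bias vector
    return d_in * d_out + d_out

# Shapes of a 128 -> 256 -> 10 MLP (ReLU has no parameters)
P_tiny = linear_numel(128, 256) + linear_numel(256, 10)
bytes_fp32 = 4

for name, state_multiplier in [("SGD", 0), ("SGD+momentum", 1), ("Adam", 2)]:
    # params (1x) + grads (1x) + optimizer state (state_multiplier x)
    total = (2 + state_multiplier) * P_tiny * bytes_fp32
    print(f"{name:13s} P={P_tiny}  persistent ≈ {total:,} bytes")
```

The parameter count is the same in all three rows; only the state multiplier changes.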

### 1.2 PyTorch nuance: `zero_grad(set_to_none=True)`

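PyTorch’s `Optimizer.zero_grad()` can set gradients to `None` instead of filling them with zeros (`set_to_none=True`, the default since PyTorch 2.0). This:
- typically **reduces memory** and can be faster,
- but changes behavior: optimizers **skip** parameters whose grad is `None`, whereas a zero-filled grad still goes through the update (so weight decay or momentum can still move the parameter).

The `1.3` demo cell below shows the memory side: gradients exist after `backward()`, disappear with `set_to_none=True`, and persist as zero-filled tensors with `set_to_none=False`.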
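A minimal check of the skip behavior, assuming a single scalar parameter and SGD with `weight_decay > 0` (so that a zero gradient still produces a nonzero update). `step_with_grad` is a hypothetical helper.

```python
import torch

def step_with_grad(grad):
    # One scalar parameter; weight_decay > 0 means even a zero gradient produces an update
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = torch.optim.SGD([p], lr=0.1, weight_decay=0.1)
    p.grad = grad
    opt.step()
    return p.item()

theta_none = step_with_grad(None)            # grad is None -> SGD skips this parameter
theta_zero = step_with_grad(torch.zeros(1))  # grad is 0 -> update uses 0 + 0.1 * 1.0
print("grad=None :", theta_none)  # 1.0
print("grad=zeros:", theta_zero)  # ~0.99
```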


In [None]:
# Helper utilities: byte counting and state inspection

def pretty_bytes(n: int) -> str:
    # Human-readable byte formatting
    suffixes = ["B", "KB", "MB", "GB", "TB"]
    x = float(n)
    for s in suffixes:
        if x < 1024 or s == suffixes[-1]:
            return f"{x:.2f} {s}"
        x /= 1024.0

def tensor_nbytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

def params_nbytes(model: nn.Module) -> int:
    return sum(tensor_nbytes(p.data) for p in model.parameters())

def grads_nbytes(model: nn.Module) -> int:
    total = 0
    for p in model.parameters():
        if p.grad is not None:
            total += tensor_nbytes(p.grad)
    return total

def optimizer_state_nbytes(optim: torch.optim.Optimizer) -> int:
    # internal state uses Parameter objects as keys
    total = 0
    for p, st in optim.state.items():
        for k, v in st.items():
            if torch.is_tensor(v):
                total += tensor_nbytes(v)
            elif isinstance(v, (list, tuple)):
                for item in v:
                    if torch.is_tensor(item):
                        total += tensor_nbytes(item)
    return total

def optimizer_state_summary(optim: torch.optim.Optimizer, max_items: int = 12) -> List[Tuple[str, Tuple[int, ...], str]]:
    rows = []
    for p, st in optim.state.items():
        for k, v in st.items():
            if torch.is_tensor(v):
                rows.append((k, tuple(v.shape), str(v.dtype)))
    # Sort for stable display
    rows.sort(key=lambda x: (x[0], x[1]))
    return rows[:max_items]

def report_memory(model: nn.Module, optim: torch.optim.Optimizer, label: str = "") -> None:
    P = count_params(model)
    print(f"--- {label} ---")
    print("P (numel):", P)
    print("params:", pretty_bytes(params_nbytes(model)))
    print("grads :", pretty_bytes(grads_nbytes(model)))
    print("state :", pretty_bytes(optimizer_state_nbytes(optim)))
    print("state keys sample:", optimizer_state_summary(optim))
    print()

def cuda_peak_bytes() -> Optional[int]:
    if not torch.cuda.is_available():
        return None
    return torch.cuda.max_memory_allocated()

def measure_peak_cuda(fn, *args, **kwargs) -> Optional[int]:
    if not torch.cuda.is_available():
        print("CUDA not available; skipping peak-memory measurement.")
        return None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn(*args, **kwargs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

print("Optimizer.zero_grad signature:", inspect.signature(torch.optim.Optimizer.zero_grad))


In [None]:
# Tiny demo model
class TinyMLP(nn.Module):
    def __init__(self, d_in=128, d_h=256, d_out=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h),
            nn.ReLU(),
            nn.Linear(d_h, d_out),
        )
    def forward(self, x):
        return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyMLP().to(device)

x = torch.randn(32, 128, device=device)
y = torch.randint(0, 10, (32,), device=device)

crit = nn.CrossEntropyLoss()

# We'll use a placeholder optimizer for now
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

report_memory(model, optim, "initial (before any backward/step)")


In [None]:
# 1.3 Demo: zero_grad(set_to_none=True) vs zeros

# First backward pass -> grads materialize
optim.zero_grad(set_to_none=True)
loss = crit(model(x), y)
loss.backward()
print("After backward: grads_nbytes =", pretty_bytes(grads_nbytes(model)))

# Set grads to None
optim.zero_grad(set_to_none=True)
none_count = sum(1 for p in model.parameters() if p.grad is None)
print("After zero_grad(set_to_none=True): #None grads =", none_count, "out of", len(list(model.parameters())))
print("grads_nbytes now =", pretty_bytes(grads_nbytes(model)))

# Backward again
loss = crit(model(x), y)
loss.backward()

# Set grads to zeros in place (the existing grad tensors stay allocated)
optim.zero_grad(set_to_none=False)
zero_count = sum(1 for p in model.parameters() if (p.grad is not None and torch.all(p.grad == 0)))
print("After zero_grad(set_to_none=False): grads exist for all params?",
      all(p.grad is not None for p in model.parameters()))
print("#grads that are all-zero right now (should be all of them):", zero_count)
print("grads_nbytes now =", pretty_bytes(grads_nbytes(model)))


## 2 — SGD vs SGD with momentum (comparisons, not re-teaching momentum)

### 2.1 Update rule (PyTorch doc version)

PyTorch’s SGD algorithm (with optional momentum) conceptually does:

- optional L2 penalty: \(g_t \leftarrow g_t + \lambda \theta_{t-1}\)
- optional momentum buffer \(b_t\)
- parameter update: \(\theta_t \leftarrow \theta_{t-1} - \gamma \, g_t\) (after any momentum/Nesterov transform)

(Exact ordering/notation is in the official docs; PyTorch also notes its momentum differs subtly from some textbook variants.)
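A minimal sketch, assuming the update rule as written in the PyTorch SGD docs (the buffer starts as the gradient on the first step, no dampening, no Nesterov): reimplement it by hand and compare against `torch.optim.SGD` on a few random gradients. Note that the learning rate multiplies the buffer instead of being folded into it, which is the kind of subtle difference from textbook variants that the docs mention. `manual_sgd_momentum` is a hypothetical helper, not a PyTorch API.

```python
import torch

def manual_sgd_momentum(theta, grads, lr, momentum, dampening=0.0, weight_decay=0.0):
    buf = None
    for g in grads:
        g = g + weight_decay * theta              # optional L2 term, using the current theta
        if buf is None:
            buf = g.clone()                       # first step: buffer starts as the gradient
        else:
            buf = momentum * buf + (1 - dampening) * g
        theta = theta - lr * buf                  # lr multiplies the buffer (PyTorch convention)
    return theta

torch.manual_seed(0)
sgd_grads = [torch.randn(3) for _ in range(5)]
theta_init = torch.randn(3)

p_sgd = torch.nn.Parameter(theta_init.clone())
opt_sgd_check = torch.optim.SGD([p_sgd], lr=0.1, momentum=0.9, weight_decay=0.01)
for g in sgd_grads:
    p_sgd.grad = g.clone()
    opt_sgd_check.step()

theta_manual_sgd = manual_sgd_momentum(theta_init.clone(), sgd_grads,
                                       lr=0.1, momentum=0.9, weight_decay=0.01)
print("matches torch.optim.SGD:", torch.allclose(p_sgd.detach(), theta_manual_sgd, atol=1e-6))
```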

### 2.2 “#params” answered precisely

- **Trainable parameters**: the model’s \(\theta\). This number is **identical** for SGD, SGD+momentum, Adam, AdamW.
- **Optimizer state**: extra persistent tensors (not trainable parameters) that increase:
  - GPU memory
  - checkpoint size
  - optimizer step compute
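To make the checkpoint-size point concrete, a small sketch under assumed settings: one `nn.Linear(256, 256)`, one optimizer step, and the optimizer's `state_dict()` serialized with `torch.save` into an in-memory buffer. The trainable parameter count is identical in all three runs; only the serialized optimizer state grows. `saved_bytes` and `optimizer_checkpoint_bytes` are hypothetical helpers.

```python
import io
import torch
import torch.nn as nn

def saved_bytes(obj) -> int:
    buf = io.BytesIO()
    torch.save(obj, buf)
    return len(buf.getvalue())

def optimizer_checkpoint_bytes(opt_cls, **opt_kwargs) -> int:
    torch.manual_seed(0)
    net = nn.Linear(256, 256)
    opt = opt_cls(net.parameters(), **opt_kwargs)
    net(torch.randn(4, 256)).sum().backward()
    opt.step()  # optimizer state is allocated lazily on the first step
    return saved_bytes(opt.state_dict())

P_lin = 256 * 256 + 256  # trainable scalars: identical for every optimizer below
ckpt_sizes = {
    "SGD": optimizer_checkpoint_bytes(torch.optim.SGD, lr=0.1),
    "SGD+momentum": optimizer_checkpoint_bytes(torch.optim.SGD, lr=0.1, momentum=0.9),
    "Adam": optimizer_checkpoint_bytes(torch.optim.Adam, lr=1e-3),
}
for name, n in ckpt_sizes.items():
    print(f"{name:13s} optimizer state_dict ≈ {n:,} bytes  (fp32 params = {4 * P_lin:,} bytes)")
```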

### 2.3 Memory footprint (theory)

Let \(P\) be parameter scalars.

Persistent scalars (very rough; ignoring activations):

- SGD (no momentum): params \(P\) + grads \(P\) + state \(0\) → \(\approx 2P\)
- SGD + momentum: params \(P\) + grads \(P\) + momentum buffer \(P\) → \(\approx 3P\)

We’ll *verify* by inspecting `optimizer.state` after one step (state is often allocated lazily).


In [None]:
# 2.4 Memory footprint comparison: measured via optimizer.state

torch.manual_seed(0)

model = TinyMLP().to(device)
x = torch.randn(32, 128, device=device)
y = torch.randint(0, 10, (32,), device=device)
crit = nn.CrossEntropyLoss()

def one_step(model: nn.Module, optim: torch.optim.Optimizer, set_to_none: bool = True):
    optim.zero_grad(set_to_none=set_to_none)
    loss = crit(model(x), y)
    loss.backward()
    optim.step()
    return loss.item()

# SGD without momentum
optim_sgd = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.0)
report_memory(model, optim_sgd, "SGD before step")
one_step(model, optim_sgd)
report_memory(model, optim_sgd, "SGD after 1 step")

# SGD with momentum (fresh model so the two runs don't share parameters)
model2 = TinyMLP().to(device)
optim_mom = torch.optim.SGD(model2.parameters(), lr=1e-2, momentum=0.9)
report_memory(model2, optim_mom, "SGD+momentum before step")
one_step(model2, optim_mom)
report_memory(model2, optim_mom, "SGD+momentum after 1 step")

print("Note: momentum creates per-parameter state keys like 'momentum_buffer'.")


## 3 — SGD vs Adam (memory footprint + practical framing)

### 3.1 Adam as a state machine (minimal)

Adam keeps two per-parameter exponential moving averages (EMAs):
- \(m_t\): EMA of gradients (first moment estimate)
- \(v_t\): EMA of squared gradients (second moment estimate)

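The update is a normalized step built from the bias-corrected moments \(\hat m_t = m_t / (1-\beta_1^t)\) and \(\hat v_t = v_t / (1-\beta_2^t)\):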
\[
\theta_t \leftarrow \theta_{t-1} - \gamma \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}.
\]
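A sketch that writes the update out in full, including the bias corrections, and checks it against `torch.optim.Adam` with default settings (no weight decay, no AMSGrad). `manual_adam` is a hypothetical helper; `m` and `v` correspond to PyTorch's `exp_avg` and `exp_avg_sq` state entries.

```python
import torch

def manual_adam(theta, grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m = torch.zeros_like(theta)  # PyTorch state key: exp_avg
    v = torch.zeros_like(theta)  # PyTorch state key: exp_avg_sq
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias corrections undo the zero initialization
        v_hat = v / (1 - beta2 ** t)
        theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta

torch.manual_seed(0)
adam_grads = [torch.randn(4) for _ in range(10)]
theta_init_adam = torch.randn(4)

p_adam = torch.nn.Parameter(theta_init_adam.clone())
opt_adam_check = torch.optim.Adam([p_adam], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
for g in adam_grads:
    p_adam.grad = g.clone()
    opt_adam_check.step()

theta_manual_adam = manual_adam(theta_init_adam.clone(), adam_grads)
print("matches torch.optim.Adam:", torch.allclose(p_adam.detach(), theta_manual_adam, atol=1e-6))
print("state keys:", sorted(opt_adam_check.state[p_adam].keys()))
```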

### 3.2 Memory footprint (theory)

- Adam state: \(m\) and \(v\) → **+2P** scalars  
- Total persistent ≈ params \(P\) + grads \(P\) + state \(2P\) → \(\approx 4P\)

If AMSGrad is enabled, there’s an extra \(v_{\max}\) buffer: **+P** more (≈ \(5P\)).
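A quick check of the AMSGrad claim, assuming a single 1000-element parameter: count the per-parameter state elements after one step (the scalar `step` counter is excluded). `adam_state_numel` is a hypothetical helper.

```python
import torch

def adam_state_numel(amsgrad: bool) -> int:
    p = torch.nn.Parameter(torch.zeros(1000))
    opt = torch.optim.Adam([p], lr=1e-3, amsgrad=amsgrad)
    p.grad = torch.ones_like(p)
    opt.step()  # state is allocated lazily on the first step
    # Count per-parameter buffers only (skip the scalar 'step' counter)
    return sum(v.numel() for v in opt.state[p].values()
               if torch.is_tensor(v) and v.numel() == p.numel())

print("Adam state scalars:        ", adam_state_numel(amsgrad=False))  # 2 * 1000
print("Adam+AMSGrad state scalars:", adam_state_numel(amsgrad=True))   # 3 * 1000
```

The extra 1000 scalars with AMSGrad come from the `max_exp_avg_sq` buffer.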

### 3.3 Important PyTorch nuance (peak memory)

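PyTorch notes that the `foreach` optimizer implementations (the default when all tensors are on CUDA) can use **~sizeof(params)** extra *peak* memory because their intermediates are tensorlists rather than one tensor at a time. This shows up in peak CUDA memory measurements even though persistent state size is unchanged; passing `foreach=False` to the optimizer constructor selects the per-parameter loop if peak memory is tight.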


In [None]:
# 3.4 Memory footprint comparison: SGD+momentum vs Adam (measured)

torch.manual_seed(0)

model_sgd = TinyMLP().to(device)
model_adam = TinyMLP().to(device)

optim_sgd = torch.optim.SGD(model_sgd.parameters(), lr=1e-2, momentum=0.9)
optim_adam = torch.optim.Adam(model_adam.parameters(), lr=1e-3)

def train_step(model, optim):
    optim.zero_grad(set_to_none=True)
    loss = crit(model(x), y)
    loss.backward()
    optim.step()
    return loss.item()

report_memory(model_sgd, optim_sgd, "SGD+mom before")
train_step(model_sgd, optim_sgd)
report_memory(model_sgd, optim_sgd, "SGD+mom after 1 step")

report_memory(model_adam, optim_adam, "Adam before")
train_step(model_adam, optim_adam)
report_memory(model_adam, optim_adam, "Adam after 1 step")

print("Adam state keys are typically: exp_avg, exp_avg_sq (and step).")


## 4 — Weight decay in Adam vs AdamW (the “don’t mess this up” section)

This is where a lot of people confidently say wrong things.

### 4.1 Two concepts that people conflate

**L2 regularization (penalty in the objective)**  
Optimize:
\[
L(\theta) + \frac{\lambda}{2}\|\theta\|^2
\]
Then:
\[
\nabla_\theta \left( L(\theta) + \frac{\lambda}{2}\|\theta\|^2 \right)
= \nabla_\theta L(\theta) + \lambda \theta.
\]
So it **adds** \(\lambda \theta\) into the gradient pipeline.

**Decoupled weight decay (shrink weights directly)**  
Apply:
\[
\theta \leftarrow (1 - \gamma\lambda)\theta
\]
*separately* from the gradient step.

### 4.2 Equivalence for vanilla SGD

SGD with L2:
\[
\theta \leftarrow \theta - \gamma (g + \lambda\theta)
= (1-\gamma\lambda)\theta - \gamma g.
\]
So for *plain SGD*, “L2 penalty” and “weight decay” are effectively the same transformation.
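The equivalence can be checked numerically, assuming PyTorch's SGD without momentum: the L2-style `weight_decay` update should coincide with "shrink by \((1-\gamma\lambda)\), then take the plain gradient step". The numbers below are arbitrary.

```python
import torch

lr, wd = 0.1, 0.01
theta_init_l2 = torch.tensor([1.0, -2.0, 0.5])
grad_l2 = torch.tensor([0.3, 0.1, -0.2])

p_l2 = torch.nn.Parameter(theta_init_l2.clone())
opt_l2 = torch.optim.SGD([p_l2], lr=lr, weight_decay=wd)  # adds wd * theta to the gradient
p_l2.grad = grad_l2.clone()
opt_l2.step()

# Decoupled form: shrink first, then take the plain gradient step
theta_decoupled_sgd = (1 - lr * wd) * theta_init_l2 - lr * grad_l2
print("L2-style SGD == decoupled form:", torch.allclose(p_l2.detach(), theta_decoupled_sgd))
```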

### 4.3 Non-equivalence for Adam

If you inject \(\lambda\theta\) into Adam’s gradient path, it gets adapted/normalized by \(1/\sqrt{v}\), so the effective shrink is **parameter-wise and history-dependent**.

AdamW decouples: decay is applied directly to \(\theta\), and **does not accumulate** into the moment estimates.
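A pure-Python sketch of the "parameter-wise" effect, assuming a scalar Adam written out by hand (`scalar_adam` is hypothetical, not PyTorch's implementation) and a zero task gradient so only decay acts. With L2-style decay, \(\lambda\theta\) is divided by its own running magnitude, so every weight moves by roughly `lr` per step regardless of its size: a small weight loses a large fraction, a large one barely changes. With decoupled decay both shrink by the same factor \((1-\gamma\lambda)^{\text{steps}}\).

```python
import math

def scalar_adam(theta, grad, steps, lr=0.01, wd=0.1, decoupled=False,
                b1=0.9, b2=0.999, eps=1e-8):
    """Scalar Adam with L2-style decay (decoupled=False) or AdamW-style decay (True)."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad
        if decoupled:
            theta *= 1 - lr * wd   # AdamW: shrink theta directly; m and v never see it
        else:
            g = g + wd * theta     # Adam + L2: decay enters m, v and the normalization
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

# Pure decay (grad = 0) on a small and a large weight: fraction of each weight that survives
ratios = {}
for decoupled in (False, True):
    for theta0 in (1.0, 100.0):
        ratios[(decoupled, theta0)] = scalar_adam(theta0, 0.0, steps=10, decoupled=decoupled) / theta0

for (decoupled, theta0), r in ratios.items():
    name = "AdamW  " if decoupled else "Adam+L2"
    print(f"{name} theta0={theta0:>6}: theta_final/theta0 = {r:.6f}")
```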

### 4.4 PyTorch reality check

In current PyTorch docs:

- `torch.optim.Adam(..., weight_decay=λ, decoupled_weight_decay=False)` treats `weight_decay` as an **L2 penalty** (added into the gradient).
- `torch.optim.Adam(..., decoupled_weight_decay=True)` is documented as **equivalent to AdamW**.
- `torch.optim.AdamW` applies decay as a separate parameter shrink step and “does not accumulate in the momentum nor variance.”

We’ll verify with a scalar micro-test.


In [None]:
# 4.5 Scalar “unit test”: show Adam(L2-style) vs AdamW(decoupled) differs

def scalar_step(optim_ctor, *, lr=0.1, wd=0.1, betas=(0.0, 0.0), eps=1e-8, grad_value=0.0, decoupled=None):
    p = torch.nn.Parameter(torch.tensor([1.0], device=device))
    # Build optimizer with flexible signature
    kwargs = dict(lr=lr, betas=betas, eps=eps, weight_decay=wd)
    if decoupled is not None:
        kwargs["decoupled_weight_decay"] = decoupled
    optim = optim_ctor([p], **kwargs)
    # Manually assign a grad
    p.grad = torch.tensor([grad_value], device=device)
    optim.step()
    return float(p.detach().cpu().item())

# Adam: L2-style weight decay (the default). decoupled_weight_decay is not passed here,
# so this line also runs on PyTorch versions whose Adam has no such argument.
theta_adam_l2 = scalar_step(torch.optim.Adam, grad_value=0.0)

# Adam: decoupled weight decay (if supported) — should match AdamW behavior
supports_decoupled = "decoupled_weight_decay" in inspect.signature(torch.optim.Adam).parameters
theta_adam_decoupled = None
if supports_decoupled:
    theta_adam_decoupled = scalar_step(torch.optim.Adam, decoupled=True, grad_value=0.0)

# AdamW: decoupled
theta_adamw = scalar_step(torch.optim.AdamW, grad_value=0.0)

print("Initial θ = 1.0, grad = 0.0, lr=0.1, wd=0.1, betas=(0,0)")
print("Adam (L2-style wd):        θ ->", theta_adam_l2)
if supports_decoupled:
    print("Adam (decoupled wd=True): θ ->", theta_adam_decoupled)
print("AdamW (decoupled):        θ ->", theta_adamw)

print()
print("Interpretation:")
print("- Adam(L2-style): wd*θ goes through Adam's normalization, so with grad=0 the step is ~lr regardless of wd (θ -> ~0.9 here).")
print("- AdamW/decoupled: θ shrinks multiplicatively by (1 - lr*wd) = 0.99 when grad=0.")


In [None]:
# 4.5b Same test but with a nonzero gradient (to see combined effects)

theta_adam_l2_g = scalar_step(torch.optim.Adam, grad_value=0.1)  # L2-style (default)
theta_adamw_g = scalar_step(torch.optim.AdamW, grad_value=0.1)
print("Initial θ = 1.0, grad = 0.1, lr=0.1, wd=0.1, betas=(0,0)")
print("Adam (L2-style wd): θ ->", theta_adam_l2_g)  # g' = 0.1 + 0.1*1.0 = 0.2; normalized step ≈ lr -> ≈ 0.9
print("AdamW (decoupled):  θ ->", theta_adamw_g)    # shrink to 0.99, then normalized step ≈ lr -> ≈ 0.89


### 4.6 Practical pattern: exclude weight decay on biases and normalization layers

Common training practice (especially for Transformers) is to apply weight decay to “weight matrices” but **not** to:
- bias terms
- LayerNorm / BatchNorm parameters (scale/shift)

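Reason: those parameters act like *calibration knobs* (offsets and per-channel scales), and pulling them toward zero rarely helps and can hurt.

Below is a simple, name-based parameter-group builder you can reuse. It matches substrings such as `norm`, `bn` and `ln` in parameter names, so inspect the resulting groups for your own architecture: an unrelated parameter whose name happens to contain one of these substrings would be excluded from decay too.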


In [None]:
# Parameter-group utility: decay vs no-decay (bias + norm excluded)

def build_param_groups_for_weight_decay(model: nn.Module, weight_decay: float) -> List[Dict[str, Any]]:
    decay_params = []
    no_decay_params = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_bias = name.endswith(".bias")
        is_norm = any(nd in name.lower() for nd in ["bn", "batchnorm", "layernorm", "ln", "norm"])
        if is_bias or is_norm:
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

m = TinyMLP().to(device)
groups = build_param_groups_for_weight_decay(m, weight_decay=0.01)
print("# decay params:", sum(p.numel() for p in groups[0]["params"]))
print("# no-decay params:", sum(p.numel() for p in groups[1]["params"]))

opt = torch.optim.AdamW(groups, lr=1e-3)
print("Param group weight_decays:", [g["weight_decay"] for g in opt.param_groups])
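An alternative sketch that avoids string matching on names: decay only parameters with `ndim >= 2` (weight matrices, conv kernels), which excludes biases and norm scale/shift vectors. This is a common heuristic, not a PyTorch rule; `param_groups_by_ndim` is a hypothetical helper and the toy network is an assumption.

```python
import torch
import torch.nn as nn

def param_groups_by_ndim(model: nn.Module, weight_decay: float):
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # Heuristic: matrices/kernels get decay; 1-D tensors (biases, norm affine params) do not
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

ln_net = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
ndim_groups = param_groups_by_ndim(ln_net, weight_decay=0.01)
opt_ndim = torch.optim.AdamW(ndim_groups, lr=1e-3)
n_decay = sum(p.numel() for p in ndim_groups[0]["params"])
n_no_decay = sum(p.numel() for p in ndim_groups[1]["params"])
print("# decay:", n_decay, " # no-decay:", n_no_decay)
```

Here the two Linear weight matrices (16·32 + 32·4 = 640 scalars) are decayed, while the biases and LayerNorm scale/shift (100 scalars) are not. One trade-off: embedding tables are 2-D and would be decayed under this rule.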


## 5 — Gradient accumulation (what problem + how in PyTorch)

### 5.1 The problem it solves

You want a larger **effective batch size** (for stability, variance reduction, or to match a paper), but your GPU can’t fit that batch’s activations.

**Gradient accumulation** simulates a batch of size \(B = K \cdot b\) by splitting it into \(K\) microbatches of size \(b\), accumulating gradients, then stepping once.

### 5.2 The math you must get right

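If your loss is a *mean over the batch* (typical), then to match the gradient of the big batch you want:

- either sum per-example losses over all microbatches and divide once by \(B\),
- or equivalently, when all \(K\) microbatches have the same size \(b\): **divide each microbatch loss by \(K\)** before `backward()`. With unequal sizes, weight each microbatch mean by \(b_i / B\) instead.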
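A pure-Python check of this arithmetic on made-up per-example losses. The accumulated gradient is the gradient of the weighted sum of microbatch mean losses, so the weights must make that sum equal the big-batch mean; `accumulated` is a hypothetical helper.

```python
import math

# Per-example losses for one "big batch" of B = 6 examples (made-up numbers)
losses = [0.5, 1.5, 2.0, 0.0, 3.0, 1.0]
B_demo = len(losses)
big_batch_mean = sum(losses) / B_demo

def accumulated(microbatches, weights):
    # Weighted sum of microbatch mean losses (what the accumulated .grad differentiates)
    return sum(w * sum(mb) / len(mb) for mb, w in zip(microbatches, weights))

equal = [losses[0:2], losses[2:4], losses[4:6]]  # K = 3 microbatches of size 2
K_eq = len(equal)
print("big-batch mean        :", big_batch_mean)
print("equal sizes, loss / K :", accumulated(equal, [1 / K_eq] * K_eq))

unequal = [losses[0:1], losses[1:6]]             # sizes 1 and 5
print("unequal, loss / K     :", accumulated(unequal, [0.5, 0.5]))  # 1.0: wrong
print("unequal, weight b_i/B :", accumulated(unequal, [len(mb) / B_demo for mb in unequal]))
print("weighted matches:", math.isclose(accumulated(unequal, [len(mb) / B_demo for mb in unequal]),
                                        big_batch_mean))
```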

### 5.3 Canonical PyTorch loop

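```python
optimizer.zero_grad(set_to_none=True)
for xb, yb in microbatches:            # K microbatches of equal size
    loss = loss_fn(model(xb), yb) / K  # scale so summed grads match the big-batch mean
    loss.backward()                    # grads accumulate into p.grad
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```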
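A self-contained, runnable version of the loop, assuming a stream of equal-size microbatches and an optimizer step every K of them. The toy model (`acc_model`) and synthetic data are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
acc_model = nn.Linear(8, 1)
acc_opt = torch.optim.SGD(acc_model.parameters(), lr=0.1)
mse = nn.MSELoss()

K = 4  # microbatches per optimizer step
microbatches = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]  # synthetic data

acc_opt.zero_grad(set_to_none=True)
num_steps = 0
for i, (xb, yb) in enumerate(microbatches):
    loss = mse(acc_model(xb), yb) / K  # scale so the accumulated grad is the big-batch mean
    loss.backward()                    # adds into p.grad
    if (i + 1) % K == 0:
        # clip here (after accumulation, before stepping) if you clip
        acc_opt.step()
        acc_opt.zero_grad(set_to_none=True)
        num_steps += 1

print("optimizer steps:", num_steps)  # 8 microbatches / K=4 -> 2 steps
```

If the number of microbatches is not a multiple of K, the leftover gradients are never applied; either drop the partial cycle or step on it with the losses rescaled by its actual size.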

### 5.4 Verification test (must-have)

Below we compare gradients from:
1) one big batch, vs  
2) K microbatches with accumulation

on a deterministic model (no dropout/BatchNorm).


In [None]:
# 5.4 Correctness proof: big batch vs gradient accumulation

torch.manual_seed(123)

# Simple model with no stochastic layers
model_big = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
model_acc = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
# Force identical init
model_acc.load_state_dict(model_big.state_dict())

loss_fn = nn.MSELoss(reduction="mean")

B = 64
K = 4
b = B // K

X = torch.randn(B, 16, device=device)
T = torch.randn(B, 4, device=device)

# --- Big batch gradients ---
opt_big = torch.optim.SGD(model_big.parameters(), lr=0.1)
opt_big.zero_grad(set_to_none=True)
loss_big = loss_fn(model_big(X), T)
loss_big.backward()
grads_big = [p.grad.detach().clone() for p in model_big.parameters()]

# --- Accumulation gradients ---
opt_acc = torch.optim.SGD(model_acc.parameters(), lr=0.1)
opt_acc.zero_grad(set_to_none=True)
for i in range(K):
    xb = X[i*b:(i+1)*b]
    tb = T[i*b:(i+1)*b]
    loss = loss_fn(model_acc(xb), tb) / K
    loss.backward()

grads_acc = [p.grad.detach().clone() for p in model_acc.parameters()]

# Compare
max_abs = 0.0
for g1, g2 in zip(grads_big, grads_acc):
    max_abs = max(max_abs, (g1 - g2).abs().max().item())

print("loss_big:", float(loss_big.detach().cpu().item()))
print("max |grad_big - grad_acc|:", max_abs)
print("Expected: ~0 (numerical noise only).")


### 5.5 Gotchas (short but sharp)

- **BatchNorm**: accumulation does *not* make BN see the full big batch; BN stats are computed per microbatch.
- **Dropout / randomness**: big-batch vs microbatch equivalence assumes deterministic forward; with dropout, use a fixed RNG strategy if you need strict equivalence.
- **AMP / GradScaler**: typically you scale each micro-loss (already divided by \(K\)), call `scaler.scale(loss).backward()`, and only `scaler.step(optimizer)` once per accumulation cycle.
- **Gradient clipping**: clip *after* accumulation (i.e., right before stepping).
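The BatchNorm gotcha is easy to see directly. A sketch, assuming the same big-batch vs. accumulation comparison as in 5.4 but on CPU, with a hypothetical helper `grads_big_vs_accumulated`: without BatchNorm the gradients agree to float precision; with `BatchNorm1d` in training mode each microbatch is normalized with its own statistics, so they do not.

```python
import torch
import torch.nn as nn

def mlp():
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

def mlp_bn():
    return nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 4))

def grads_big_vs_accumulated(make_model, K: int = 4, B: int = 64) -> float:
    torch.manual_seed(0)
    ref, acc = make_model(), make_model()
    acc.load_state_dict(ref.state_dict())          # identical initial weights
    X, T = torch.randn(B, 16), torch.randn(B, 4)
    mse = nn.MSELoss()
    mse(ref(X), T).backward()                      # one big batch
    for xb, tb in zip(X.chunk(K), T.chunk(K)):     # K equal microbatches
        (mse(acc(xb), tb) / K).backward()
    return max((p1.grad - p2.grad).abs().max().item()
               for p1, p2 in zip(ref.parameters(), acc.parameters()))

print("no BatchNorm  :", grads_big_vs_accumulated(mlp))     # float noise only
print("with BatchNorm:", grads_big_vs_accumulated(mlp_bn))  # clearly nonzero
```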


In [None]:
# 5.5 (Optional) AMP + accumulation sketch (runs only if CUDA)

from contextlib import nullcontext

def amp_accumulation_sketch(model, optimizer, X, T, K: int, loss_fn=nn.MSELoss()):
    use_amp = X.is_cuda
    # torch.cuda.amp.GradScaler still works, but newer releases deprecate it in favor of
    # torch.amp.GradScaler("cuda", ...)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    optimizer.zero_grad(set_to_none=True)
    B = X.shape[0]
    b = B // K

    for i in range(K):
        xb = X[i*b:(i+1)*b]
        tb = T[i*b:(i+1)*b]
        # fp16 autocast on CUDA; a no-op context (plain fp32) elsewhere
        ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if use_amp else nullcontext()
        with ctx:
            loss = loss_fn(model(xb), tb) / K
        scaler.scale(loss).backward()  # accumulate scaled grads; no step yet

    # unscale once per accumulation cycle, before clipping (if you clip)
    scaler.unscale_(optimizer)
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    scaler.step(optimizer)  # skips the step if inf/NaN grads were found
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

print("This cell defines a sketch; it doesn't run a full training loop.")


## 6 — Activation / gradient checkpointing (the activation-memory lever)

### 6.1 Why activations dominate

Autograd needs intermediate tensors from the forward pass to compute gradients in the backward pass. Those saved activations can dominate memory, especially for:
- deep nets
- large batch sizes
- long sequences (Transformers)
- large feature maps (vision)
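A back-of-the-envelope count for a stack of `depth` (Linear(d, d), GELU) blocks, assuming fp32 and that each block keeps two batch × d tensors alive for backward (the Linear's input and the GELU's input). `mlp_memory_bytes` is a hypothetical helper; gradients add another parameter-sized term that is not shown.

```python
def mlp_memory_bytes(batch: int, d: int, depth: int, bytes_per_scalar: int = 4):
    """Rough estimate: (parameter bytes, saved-activation bytes) for `depth` Linear+GELU blocks.

    Assumption: each block saves two batch x d tensors for backward; weights are not
    counted as activations.
    """
    params = depth * (d * d + d) * bytes_per_scalar
    activations = 2 * batch * d * depth * bytes_per_scalar
    return params, activations

for batch in (8, 4096):
    params, acts = mlp_memory_bytes(batch, d=1024, depth=14)
    print(f"batch={batch:5d}: params {params / 2**20:7.1f} MiB, "
          f"saved activations {acts / 2**20:7.1f} MiB")
```

At a tiny batch the parameters dominate and checkpointing has little to save; at a large batch the activation term is several times larger than the parameters.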

### 6.2 What checkpointing does

Activation checkpointing trades **compute for memory**:
- forward pass in checkpointed region does **not** save intermediates
- backward pass **recomputes** the forward for that region to recover needed activations

### 6.3 PyTorch API reality (important)

`torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=...)` has **two implementations**:
- `use_reentrant=True` (older “reentrant autograd” variant)
- `use_reentrant=False` (newer non-reentrant variant)

PyTorch docs recommend **`use_reentrant=False`** and ask you to pass the argument explicitly; recent versions warn when it is omitted, and the docs state that a future version will raise an error.

We’ll run a peak-memory + runtime comparison (if CUDA is available).


In [None]:
# 6.4 Measurement experiment: checkpointing vs no checkpointing

from torch.utils.checkpoint import checkpoint

class DeepMLP(nn.Module):
    def __init__(self, d=2048, depth=12):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(nn.Linear(d, d))
            layers.append(nn.GELU())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)

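def forward_backward_step(model: nn.Module, x: torch.Tensor, use_ckpt: bool):
    model.train()
    if use_ckpt:
        # Checkpoint each (Linear, GELU) block separately: the forward pass keeps only each
        # block's input, and backward recomputes one block at a time. Checkpointing the whole
        # stack as ONE region would recompute every activation at once during backward, so
        # peak memory would barely drop. nn.Module objects are callables, so a Sequential
        # slice can be passed to checkpoint directly.
        out = x
        for i in range(0, len(model.seq), 2):
            out = checkpoint(model.seq[i:i + 2], out, use_reentrant=False)
    else:
        out = model(x)
    loss = out.pow(2).mean()  # simple scalar loss
    loss.backward()
    return float(loss.detach().cpu().item())

# Saved activations grow with batch * width * depth, parameters with width^2 * depth, so the
# batch must be large enough for activations to dominate; sizes shrink on CPU to stay runnable.
on_gpu = torch.cuda.is_available()
d = 1024 if on_gpu else 512
depth = 14 if on_gpu else 8
batch = 4096 if on_gpu else 256

model_plain = DeepMLP(d=d, depth=depth).to(device)
model_ckpt = DeepMLP(d=d, depth=depth).to(device)
model_ckpt.load_state_dict(model_plain.state_dict())

# With use_reentrant=False the input does not need requires_grad=True
x_big = torch.randn(batch, d, device=device)

# Warmup (both code paths)
for _ in range(2):
    for net, use_ckpt in ((model_plain, False), (model_ckpt, True)):
        net.zero_grad(set_to_none=True)
        forward_backward_step(net, x_big, use_ckpt=use_ckpt)
        net.zero_grad(set_to_none=True)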

def run_plain():
    model_plain.zero_grad(set_to_none=True)
    forward_backward_step(model_plain, x_big, use_ckpt=False)

def run_ckpt():
    model_ckpt.zero_grad(set_to_none=True)
    forward_backward_step(model_ckpt, x_big, use_ckpt=True)

# Measure time + peak memory
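def timed(fn, n=5):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't charge earlier queued GPU work to this timing
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n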

t_plain = timed(run_plain, n=5)
t_ckpt = timed(run_ckpt, n=5)

peak_plain = measure_peak_cuda(run_plain)
peak_ckpt = measure_peak_cuda(run_ckpt)

print("avg step time (plain):", t_plain, "s")
print("avg step time (ckpt): ", t_ckpt, "s")

if peak_plain is not None:
    print("peak CUDA (plain):", pretty_bytes(peak_plain))
    print("peak CUDA (ckpt): ", pretty_bytes(peak_ckpt))
    print("Δpeak (plain-ckpt):", pretty_bytes(peak_plain - peak_ckpt))
else:
    print("CUDA not available; memory comparison skipped. On GPU you should see lower peak with checkpointing, but slower time.")


### 6.5 Determinism and RNG gotcha

Checkpointing may re-run forward segments during backward, which can advance RNG state differently than a non-checkpointed run.

PyTorch’s checkpoint implementation stashes/restores RNG state by default (`preserve_rng_state=True`) to make dropout outputs match non-checkpointed behavior, but that can add overhead. If you don’t need that equivalence, you can set `preserve_rng_state=False` (with careful thought).

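Also note: if the checkpointed function moves tensors to a device that is neither the current device nor the device of one of its tensor arguments, the docs say deterministic output relative to a non-checkpointed run is not guaranteed.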
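A sketch of the RNG point, assuming a CPU-only block with `nn.Dropout` and the default `preserve_rng_state=True`: seeding before each run draws the same dropout mask in the plain and the checkpointed forward, and checkpoint replays that mask during recomputation, so the gradients match. `dropout_grads` is a hypothetical helper.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
drop_block = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 16)).train()
x_drop = torch.randn(8, 16)

def dropout_grads(use_ckpt: bool, seed: int = 123):
    drop_block.zero_grad(set_to_none=True)
    torch.manual_seed(seed)  # same RNG state before the forward -> same dropout mask
    if use_ckpt:
        # preserve_rng_state=True (the default): RNG state is stashed and replayed on recompute
        out = checkpoint(drop_block, x_drop, use_reentrant=False, preserve_rng_state=True)
    else:
        out = drop_block(x_drop)
    out.sum().backward()
    return [p.grad.clone() for p in drop_block.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(dropout_grads(False), dropout_grads(True)))
print("dropout grads match with preserve_rng_state=True:", same)
```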


## 7 — Synthesis: one-page cheat sheet + meeting talk tracks

### 7.1 One-page cheat sheet (memorize-ready)

Below is the “state machine summary.” Print this. Tape it to your forehead (optional).


#### Update rules (conceptual)

Let \(g = \nabla_\theta L(\theta)\), learning rate \(\gamma\), weight decay \(\lambda\).

**SGD (PyTorch semantics)**  
Optional L2 penalty:
\[
g \leftarrow g + \lambda \theta
\]
Then:
\[
\theta \leftarrow \theta - \gamma g
\]
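Momentum (coefficient \(\mu\), dampening \(\tau\)) adds a buffer \(b\): \(b \leftarrow g\) on the first step, then \(b \leftarrow \mu b + (1-\tau) g\); the update becomes \(\theta \leftarrow \theta - \gamma b\) (Nesterov uses \(g + \mu b\) in place of \(b\)).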

**Adam (L2-style `weight_decay` by default)**  
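Maintains \(m\) and \(v\) (both initialized to 0); at step \(t\):
- optional L2: \(g \leftarrow g + \lambda \theta\)
- \(m \leftarrow \beta_1 m + (1-\beta_1) g\), \(v \leftarrow \beta_2 v + (1-\beta_2) g^2\)
- bias correction: \(\hat m = m / (1-\beta_1^t)\), \(\hat v = v / (1-\beta_2^t)\)
- \(\theta \leftarrow \theta - \gamma \hat m / (\sqrt{\hat v} + \epsilon)\)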

**AdamW (decoupled)**  
First:
\[
\theta \leftarrow \theta - \gamma \lambda \theta
\]
Then the Adam normalized gradient step (without mixing decay into \(m,v\)).

#### Memory multipliers (persistent state, in units of P scalars)

Ignoring activations:

| Optimizer | Extra optimizer state | Total persistent (params+grads+state) |
|---|---:|---:|
| SGD (no momentum) | ~0P | ~2P |
| SGD + momentum | ~1P | ~3P |
| Adam / AdamW | ~2P | ~4P |
| Adam/AdamW + AMSGrad | ~3P | ~5P |

> Peak memory can differ from this due to activations and temporary optimizer intermediates (e.g., `foreach`).


### 7.2 “Explain in 90 seconds” scripts

**SGD vs SGD+momentum**  
SGD stores just parameters and their gradients. With momentum, it also stores a per-parameter velocity/momentum buffer. That buffer costs ~1× parameter memory. It often improves optimization by smoothing noisy gradients and accelerating consistent directions, but the model parameter count is unchanged — only optimizer state grows.

**Adam vs SGD**  
Adam adds two EMAs per parameter (first and second moment). That’s +2× parameter memory in optimizer state. It adapts step sizes per parameter based on historical gradient magnitudes, often converging faster with less hyperparameter tuning, but it’s heavier in memory and sometimes needs careful regularization for good generalization.

**Adam vs AdamW**  
In Adam (default), `weight_decay` behaves like an L2 penalty injected into the gradient pipeline; in adaptive methods, this makes the effective decay depend on the adaptive normalization. AdamW decouples weight decay, shrinking weights directly without contaminating moment estimates. This difference is easy to see in a scalar test.

**Gradient accumulation**  
It simulates a larger batch by summing/averaging gradients across microbatches and stepping once. If you divide each micro-loss by K, you match the large-batch gradient (for deterministic models). It’s not identical when layers depend on batch statistics (BatchNorm) or randomness (dropout) unless you manage those effects.

**Checkpointing**  
It reduces activation memory by not saving intermediates in some forward regions. Backward recomputes those regions, so it costs extra compute but can unlock larger models/batches.


### 7.3 “Mentor might ask” readiness drills (do these without a laptop)

- If a model has \(P\) parameters in fp32, what’s the optimizer-state memory for:
  - SGD vs SGD+momentum vs AdamW?
- Show in 1D why AdamW decay differs from Adam’s L2-style decay.
- Does gradient accumulation perfectly simulate a large batch? If not, what breaks equivalence?
- Why does checkpointing help activations but not optimizer state?

> If you can answer those quickly and precisely, you’re in a very good place for the meeting.


## Appendix A — Quick memory calculator

The helper below gives you a *back-of-the-envelope* memory estimate from \(P\), dtype bytes, and optimizer choice.  
(It ignores activation memory, which is often dominant.)


In [None]:
# Appendix A: simple memory estimator

@dataclass
class MemoryEstimate:
    P: int
    bytes_per_scalar: int
    params: int
    grads: int
    state: int
    total: int

def estimate_persistent_memory(P: int, bytes_per_scalar: int, state_multiplier: int) -> MemoryEstimate:
    params = P * bytes_per_scalar
    grads = P * bytes_per_scalar
    state = state_multiplier * P * bytes_per_scalar
    total = params + grads + state
    return MemoryEstimate(P, bytes_per_scalar, params, grads, state, total)

def print_est(name: str, est: MemoryEstimate):
    print(f"{name}:")
    print("  params:", pretty_bytes(est.params))
    print("  grads :", pretty_bytes(est.grads))
    print("  state :", pretty_bytes(est.state))
    print("  total :", pretty_bytes(est.total))
    print()

# Example: 1 billion params in bf16 (2 bytes) with AdamW state stored in bf16 (best-case)
P = 1_000_000_000
b = 2

print("Assuming optimizer state uses same dtype bytes as params (best-case; many fused/AMP setups use fp32 state).")
print_est("SGD", estimate_persistent_memory(P, b, state_multiplier=0))
print_est("SGD+momentum", estimate_persistent_memory(P, b, state_multiplier=1))
print_est("AdamW", estimate_persistent_memory(P, b, state_multiplier=2))

print("If state is fp32 while params are bf16, state bytes double relative to params.")
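The same arithmetic for one common mixed-precision layout (an assumption; other layouts exist): bf16 parameters and gradients, plus an fp32 master copy of the weights and fp32 Adam moments. That comes to 2 + 2 + 4 + 4 + 4 = 16 bytes per parameter. `mixed_precision_adam_bytes` is a hypothetical helper.

```python
def mixed_precision_adam_bytes(P: int) -> dict:
    """Assumed layout: bf16 params + bf16 grads, plus fp32 master params and fp32 Adam m, v."""
    return {
        "bf16 params": 2 * P,
        "bf16 grads": 2 * P,
        "fp32 master params": 4 * P,
        "fp32 exp_avg (m)": 4 * P,
        "fp32 exp_avg_sq (v)": 4 * P,
    }

P = 1_000_000_000
parts = mixed_precision_adam_bytes(P)
for name, n in parts.items():
    print(f"{name:20s} {n / 2**30:6.2f} GiB")
total = sum(parts.values())
print(f"{'total':20s} {total / 2**30:6.2f} GiB  ({total // P} bytes/param)")
```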


## Appendix B — Checklist: what to rehearse before the meeting

- You can explain: “optimizer state vs trainable parameters” cleanly.
- You can derive: SGD L2 ⇔ weight decay equivalence.
- You can explain: why Adam breaks that equivalence and why AdamW decouples.
- You can code (from memory): correct gradient accumulation loop (loss/K + step every K).
- You can explain: checkpointing saves activations, not optimizer state — and why it costs extra compute.

That’s the whole game. Define the state, write the update, compute the memory, then verify with a tiny experiment.
