# Autocast / AMP in PyTorch: a deep practical reference

This notebook is a **learning tool** for understanding and *actually feeling* the impact of `autocast` during training.

It is organized into **three sections**:

1. **Theory** (floating-point range/precision, underflow/overflow, what AMP is really doing)
2. **What the literature says** (paper- and doc-driven mental models, *written explanations*, no experiments)
3. **Practicalities** (hands-on experiments + graphs: loss curves, stability, gradient underflow, and dtype flow through a tiny LLM)

**Primary goal:** after completing this notebook, you should be able to answer (and debug) questions like:

- Why does FP16 often need **loss scaling**, but BF16 often does not?
- What does `autocast` *actually* do per operation (matmul vs softmax vs layernorm)?
- Why do people talk about **FP32 master weights** and **optimizer state precision**?
- How can you *see* autocast happening inside a transformer forward pass?
- What fails if you try to “just train in half precision everywhere”?

---

## How to use this notebook

- Read the markdown, then run the code cells.
- Most experiments are designed to run in a few minutes on a single GPU.
- CPU-only runs are supported for the *conceptual* demos, but some mixed-precision behaviors (and speedups) are fundamentally GPU-driven.

---

## Table of contents (high-level)

- **Section 1 — Theory**
  - Floating-point: range vs precision
  - FP16 vs BF16 vs FP32 (tables you can trust)
  - Underflow/overflow and why training cares
  - What AMP is (autocast + grad scaling)
  - Master weights, optimizer state, and accumulation

- **Section 2 — What the literature says**
  - Mixed Precision Training (Micikevicius et al.)
  - NVIDIA mixed precision guidance
  - BF16 design intent
  - PyTorch AMP operator policy
  - LLM training stacks (FSDP/ZeRO) and where AMP fits

- **Section 3 — Practicalities**
  - A minimal AMP training loop
  - Build an operator policy table *from your local PyTorch*
  - Visualize gradient underflow + the effect of loss scaling
  - Train a tiny causal transformer (LLM) under different precision regimes
  - Plot and interpret loss/time/scale curves

---

## Quick glossary

- **AMP**: automatic mixed precision (in PyTorch: `torch.amp`)
- **autocast**: chooses an op-specific dtype policy inside a context manager
- **GradScaler / loss scaling**: rescales loss to avoid FP16 gradient underflow
- **master weights**: keep weights in FP32 for updates, cast for compute
- **underflow**: magnitude too small → becomes 0 (or subnormal/denormal)
- **overflow**: magnitude too large → becomes `inf`

## Prerequisites

You need:

- Python 3.10+
- PyTorch 2.x
- `matplotlib`, `numpy`, `pandas`

### Install (CPU-only quick start)

```bash
pip install torch numpy pandas matplotlib
```

### Install (CUDA)

Install the correct PyTorch + CUDA build from the official PyTorch instructions for your platform.

---

This notebook is written to *degrade gracefully*:

- If BF16 is not supported on your GPU, BF16 experiments will be skipped.
- If FP16 training without scaling explodes (it often does), we record that as a result rather than pretending it “worked”.


Note: on Apple Silicon, this notebook defaults to CPU (AMP/autocast behavior is best-defined on CUDA/CPU).

In [None]:
# Core imports + environment report
import os
import math
import time
import random
from dataclasses import dataclass
from contextlib import nullcontext
from datetime import date

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Plot defaults
plt.rcParams.update({
    "figure.figsize": (10, 4),
    "axes.grid": True,
    "axes.spines.top": False,
    "axes.spines.right": False,
})

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "PyTorch is required for this notebook. Install it, restart the kernel, then re-run.\n"
        "CPU-only (quick start): pip install torch\n"
        "CUDA: install the correct wheel from the official PyTorch site."
    ) from e

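# torch.autocast(device_type=..., dtype=...) works for both CUDA and CPU in PyTorch 2.x.
autocast = torch.autocast
# Prefer the device-generic torch.amp.GradScaler (newer releases); early 2.x builds
# have torch.amp but not torch.amp.GradScaler, so fall back to the CUDA-only class.
if hasattr(torch.amp, "GradScaler"):
    GradScaler = torch.amp.GradScaler
else:
    GradScaler = torch.cuda.amp.GradScaler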


def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(0)

# Device selection: CUDA if available, else CPU.
# (MPS exists on Apple, but this notebook is intentionally CUDA/CPU-centric.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Date:", date.today().isoformat())
print("Torch:", torch.__version__)
print("Device:", device)
if device.type == "cuda":
    print("CUDA:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
    print("BF16 supported:", torch.cuda.is_bf16_supported())

# Section 1 — Theory

The core trick of autocast is simple to state:

> **Run the *right* operations in lower precision for speed/memory, while keeping *numerically sensitive* operations in FP32.**

But to understand *why* this works (and when it doesn’t), we need to understand what floating-point formats can and cannot represent.

## 1.0 Where floating-point lives during training (and why autocast exists)

A single training step can be decomposed into:

1. **Forward**: parameters + activations → logits
2. **Loss**: logits + targets → scalar loss
3. **Backward**: loss → gradients for parameters
4. **Optimizer update**: parameters + gradients (+ optimizer state) → new parameters

Different tensors have different numeric requirements:

| Object | Typical AMP dtype | Why |
|---|---|---|
| Activations / intermediate matmul results | FP16/BF16 (where safe) | saves memory + uses Tensor Cores |
| Softmax / LayerNorm stats / big reductions | FP32 (often) | protects against overflow + rounding accumulation |
| Parameter gradients | often FP32 *storage* (even if compute is mixed) | stable updates + compatibility with optimizers |
| Parameters (“master weights”) | FP32 | prevents update stagnation |
| Optimizer state (Adam moments) | FP32 | long-horizon accumulation is precision-sensitive |

Autocast’s job is mostly about **(1) and (2)**: choose per-op dtypes.
Grad scaling’s job is mostly about **(3)** when FP16 is involved: keep gradients from underflowing.

## 1.1 Floating-point is “range + precision”, not just “more bits = better”

A binary floating-point number is roughly:

$$
(-1)^{\text{sign}} \times (1.\text{mantissa}) \times 2^{\text{exponent}}
$$

The bit budget is split across:

- **Exponent bits** → *range* (how large/small magnitudes you can represent)
- **Mantissa (fraction) bits** → *precision* (how many significant bits you keep)

For deep learning training, the key question is not “can I store 3.14159?” but:

- Can I represent **tiny gradients** without them becoming 0 (underflow)?
- Can I represent **large activations** without them becoming `inf` (overflow)?
- Can I sum many numbers without destroying meaning via rounding?

These failure modes show up differently in FP16 and BF16.
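A minimal NumPy sketch of all three questions in FP16. It uses `np.float16` (IEEE half precision) as a stand-in for GPU FP16, and the specific values are illustrative choices:

```python
import numpy as np

def fp16_failure_modes() -> dict:
    """Show underflow, overflow, and rounding in IEEE half precision (np.float16)."""
    with np.errstate(over="ignore"):  # silence the overflow warning for the inf case
        underflow = np.float16(1e-8)     # below FP16's smallest subnormal (~6e-8): rounds to 0
        overflow = np.float16(70000.0)   # above FP16's largest finite value (65504): becomes inf
    # Near 2048 the FP16 spacing is 2, so 2048 + 1 lands exactly halfway and rounds back (ties-to-even).
    rounded = np.float16(2048.0) + np.float16(1.0)
    return {"underflow": float(underflow), "overflow": float(overflow), "rounding": float(rounded)}

print(fp16_failure_modes())
# {'underflow': 0.0, 'overflow': inf, 'rounding': 2048.0}
```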

### 1.1.1 IEEE 754 anatomy (sign, exponent bias, hidden bit)

A *normalized* binary floating-point value is encoded as:

- **sign bit** $s$
- **exponent field** $E$ (stored with a **bias**)
- **mantissa / fraction field** $m$ (read as a binary fraction, so $0 \le m < 1$)

For normalized numbers:

$$
\text{value} = (-1)^s \times (1 + m) \times 2^{(E - \text{bias})}
$$

Key details:

- The leading `1.` is **implicit** (the “hidden bit”), which is why you get 1 extra bit of precision.
- Exponent all-zeros and all-ones are **reserved**:
  - `E=0`: subnormals / zero
  - `E=all ones`: `inf` / `nan`

The bias is what lets exponents represent both negative and positive powers.
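Here is a hand-worked FP16 example (bias 15, 10 fraction bits) that serves as a cross-check for the decoder cell that follows. π lies in $[2, 4)$, so its unbiased exponent is 1:

$$
\pi = 1.5707963\ldots \times 2^{1}, \qquad E = 1 + 15 = 16 = 10000_2
$$

$$
m \times 2^{10} = 0.5707963\ldots \times 1024 \approx 584.495 \;\xrightarrow{\text{round}}\; 584 = 1001001000_2
$$

$$
\text{value} = (-1)^0 \times \left(1 + \tfrac{584}{1024}\right) \times 2^{16-15} = 1.5703125 \times 2 = 3.140625
$$

So FP16 stores π as `0 10000 1001001000`, about 0.03% below the true value.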

We’ll decode π in FP32/FP16 (and BF16) to make this concrete.

In [None]:
# Bit-level decoding demo (pi) for float32/float16/(optional) bfloat16
import struct


def bits_f32(x: float) -> str:
    (u32,) = struct.unpack(">I", struct.pack(">f", float(x)))
    return f"{u32:032b}"


def bits_f16(x: float) -> str:
    u16 = np.frombuffer(np.float16(x).tobytes(), dtype=np.uint16)[0]
    return f"{int(u16):016b}"


def bits_bf16(x: float) -> str:
    t = torch.tensor(float(x), dtype=torch.bfloat16)
    i16 = int(t.view(torch.int16).item()) & 0xFFFF
    return f"{i16:016b}"


def decode(bits: str, exp_bits: int, mant_bits: int, bias: int):
    s = int(bits[0], 2)
    E = int(bits[1 : 1 + exp_bits], 2)
    M_bits = bits[1 + exp_bits :]
    assert len(M_bits) == mant_bits

    if E == 0:
        kind = "subnormal/zero"
        exp = 1 - bias
        mant = sum(int(b) * (2 ** (-(i + 1))) for i, b in enumerate(M_bits))
        val = ((-1) ** s) * (mant) * (2 ** exp)
        return kind, s, E, exp, mant, val

    if E == (2**exp_bits - 1):
        kind = "inf/nan"
        return kind, s, E, None, None, None

    kind = "normal"
    exp = E - bias
    mant = sum(int(b) * (2 ** (-(i + 1))) for i, b in enumerate(M_bits))
    val = ((-1) ** s) * (1.0 + mant) * (2 ** exp)
    return kind, s, E, exp, mant, val


x = math.pi

rows = []

b32 = bits_f32(x)
kind, s, E, exp, mant, val = decode(b32, exp_bits=8, mant_bits=23, bias=127)
rows.append({"dtype": "float32", "bits": b32, "kind": kind, "sign": s, "E(raw)": E, "exp": exp, "mant": mant, "decoded": val})

b16 = bits_f16(x)
kind, s, E, exp, mant, val = decode(b16, exp_bits=5, mant_bits=10, bias=15)
rows.append({"dtype": "float16", "bits": b16, "kind": kind, "sign": s, "E(raw)": E, "exp": exp, "mant": mant, "decoded": val})

bbf = bits_bf16(x)
kind, s, E, exp, mant, val = decode(bbf, exp_bits=8, mant_bits=7, bias=127)
rows.append({"dtype": "bfloat16", "bits": bbf, "kind": kind, "sign": s, "E(raw)": E, "exp": exp, "mant": mant, "decoded": val})

pd.DataFrame(rows)

### 1.1.2 Normal vs subnormal numbers (and “flush-to-zero”)

**Subnormals** extend representable values closer to 0 by giving up the implicit leading `1.`.

- They matter because gradients can be very small.
- But they can be slow on some hardware.

Many compute paths enable **FTZ/DAZ** (“flush-to-zero” / “denormals-are-zero”), which means extremely small values become exactly 0.

The practical lesson:

- It is not enough to know the *spec* of a dtype.
- You also need to know what your hardware/kernel path does with denormals.

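Let’s check whether `torch.nextafter(0, 1)`, the smallest positive subnormal, comes out nonzero for each dtype on your device. (This tests whether the value is representable. It does not test whether arithmetic kernels flush results to zero.)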

In [None]:
# Subnormal survival probe (device-dependent)

def subnormal_survives(dtype: torch.dtype, device: torch.device):
    z = torch.tensor(0.0, dtype=dtype, device=device)
    o = torch.tensor(1.0, dtype=dtype, device=device)
    sub = torch.nextafter(z, o)
    return {
        "dtype": str(dtype),
        "device": device.type,
        "nextafter(0,1)": float(sub) if sub.numel() == 1 else None,
        "is_zero": bool((sub == 0).item()),
    }

rows = []
for dt in [torch.float16, torch.bfloat16, torch.float32]:
    try:
        rows.append(subnormal_survives(dt, device))
    except Exception as e:
        rows.append({"dtype": str(dt), "device": device.type, "error": type(e).__name__})

pd.DataFrame(rows)
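The probe above only checks representability. To see flush-to-zero change an *arithmetic* result, PyTorch exposes `torch.set_flush_denormal` for the CPU. It returns `False` when the CPU does not support the mode, and in that case this sketch simply reports it. The choice of dividing float32's smallest normal by 4 is an arbitrary way to produce a subnormal:

```python
import torch

def quarter_smallest_normal(flush: bool) -> tuple[bool, float]:
    """Divide float32's smallest normal number by 4 with CPU flush-to-zero on or off."""
    supported = torch.set_flush_denormal(flush)
    try:
        tiny = torch.tensor(torch.finfo(torch.float32).tiny)  # ~1.18e-38, smallest normal float32
        result = float(tiny / 4)  # the exact answer (~2.9e-39) is a subnormal
    finally:
        torch.set_flush_denormal(False)  # restore PyTorch's default (no flushing)
    return supported, result

for flush in (False, True):
    supported, value = quarter_smallest_normal(flush)
    print(f"flush={flush}  supported={supported}  tiny/4={value:.3e}")
```

If flushing is supported, the second line should show `0.000e+00`. The same value that survives with flushing off is silently zeroed with it on.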

### 1.1.3 ULP: spacing grows with magnitude

A float format has *roughly constant relative precision* but *variable absolute precision*.

- Near 1.0, FP16 spacing is ~`1e-3`.
- Near 1024, spacing is ~`1`.

This is the concrete reason “tiny updates disappear” when weights are stored in low precision.
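A quick NumPy cross-check of those two spacings, assuming `np.float16` matches GPU FP16 (both are IEEE binary16). `np.spacing` returns the distance to the next representable value:

```python
import numpy as np

def fp16_spacing(x: float) -> float:
    """Distance from x to the next larger FP16 value (one ULP at x)."""
    return float(np.spacing(np.float16(x)))

for x in (1.0, 1024.0, 2048.0):
    print(f"ULP_fp16({x:g}) = {fp16_spacing(x)}")

# An update smaller than half a ULP is rounded away entirely:
w = np.float16(1024.0)
print("1024 + 0.4 == 1024 in FP16:", w + np.float16(0.4) == w)
```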

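Let’s plot ULP as a function of magnitude. (The FP16 curve stops at $2^{15}$ because larger powers of two exceed FP16’s maximum of 65504 and become `inf`.)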

In [None]:
# ULP vs magnitude

def ulp(x: torch.Tensor, dtype: torch.dtype):
    x = x.to(dtype)
    return (torch.nextafter(x, x * 2) - x).abs().to(torch.float32)

# Use exact powers of two to avoid extra rounding.
ks = torch.arange(-10, 21, device=device)
x = (2.0 ** ks).to(torch.float32)

plt.figure(figsize=(10, 4))
for dt in [torch.float16, torch.bfloat16, torch.float32]:
    if device.type == "cpu" and dt is torch.float16:
        continue
    u = ulp(x, dt).cpu().numpy()
    plt.plot(ks.cpu().numpy(), np.log2(u + 1e-30), marker="o", label=str(dt))

plt.title("log2(ULP) vs log2(|x|) for powers of two")
plt.xlabel("log2(|x|)")
plt.ylabel("log2(ULP(x))")
plt.legend();

## 1.2 FP16 vs BF16 vs FP32 (the table you should memorize)

### Bit layouts (IEEE-like)

| dtype | total bits | exponent bits | mantissa bits | “precision bits” (incl. hidden 1) | what it’s good for |
|---|---:|---:|---:|---:|---|
| FP16 (IEEE half) | 16 | 5 | 10 | 11 | fast Tensor Core compute, but narrow range → underflow/overflow risks |
| BF16 | 16 | 8 | 7 | 8 | FP32-like range, lower precision; often “drop-in” for training |
| FP32 (single) | 32 | 8 | 23 | 24 | stable baseline; slower/more memory |

### Two immediate consequences

1. **FP16 has more mantissa bits than BF16** → better precision.
2. **BF16 has the same exponent width as FP32** → dramatically better *range* than FP16.

So:

- FP16 fails *first* due to **range** (underflow/overflow).
- BF16 fails *first* due to **precision** (rounding / accumulation error).

Autocast exists to route computations so that you get the performance of 16-bit compute *without* the worst numeric failure modes.
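NumPy has no BF16 type, so this sketch emulates BF16 by rounding a value's float32 bit pattern to its top 16 bits (round-to-nearest-even). The helper `to_bf16` is hypothetical, and the sketch assumes finite, non-NaN inputs that fit in float32. Comparing it with `np.float16` makes the range-vs-precision trade concrete:

```python
import struct
import numpy as np

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even) and return it as a Python float.

    Sketch only: assumes x fits in float32 and is not NaN.
    """
    (u32,) = struct.unpack(">I", struct.pack(">f", x))  # float32 bit pattern
    lsb = (u32 >> 16) & 1                                # lowest bit that BF16 keeps
    u32 = (u32 + 0x7FFF + lsb) & 0xFFFF0000              # round to nearest even, drop low 16 bits
    (y,) = struct.unpack(">f", struct.pack(">I", u32))
    return y

with np.errstate(over="ignore"):
    for x in (1e-8, 1.0 + 2**-9, 70000.0):
        print(f"x={x!r}  fp16={float(np.float16(x))!r}  bf16={to_bf16(x)!r}")
```

Here, FP16 loses `1e-8` (range) and overflows at `70000`, while BF16 loses the `2**-9` detail next to 1.0 (precision).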

In [None]:
# A numeric “format facts” table from your local PyTorch

def smallest_subnormal(dtype: torch.dtype) -> float:
    z = torch.tensor(0.0, dtype=dtype)
    o = torch.tensor(1.0, dtype=dtype)
    return float(torch.nextafter(z, o))


def ulp_at_one(dtype: torch.dtype) -> float:
    one = torch.tensor(1.0, dtype=dtype)
    nxt = torch.nextafter(one, one + one)
    return float(nxt - one)


def dtype_row(name: str, dtype: torch.dtype, exp_bits: int, mant_bits: int, exp_min: int, exp_max: int):
    fi = torch.finfo(dtype)
    return {
        "dtype": name,
        "bits": fi.bits,
        "exp_bits": exp_bits,
        "mant_bits": mant_bits,
        "precision_bits": mant_bits + 1,
        "approx_decimal_digits": round((mant_bits + 1) * math.log10(2), 2),
        "exp_min(normal)": exp_min,
        "exp_max(normal)": exp_max,
        "eps": float(fi.eps),
        "ulp(1.0)": ulp_at_one(dtype),
        "min_normal": float(fi.tiny),
        "min_subnormal": smallest_subnormal(dtype),
        "max_finite": float(fi.max),
    }


dtype_info = pd.DataFrame([
    dtype_row("float16", torch.float16, exp_bits=5, mant_bits=10, exp_min=-14, exp_max=15),
    dtype_row("bfloat16", torch.bfloat16, exp_bits=8, mant_bits=7, exp_min=-126, exp_max=127),
    dtype_row("float32", torch.float32, exp_bits=8, mant_bits=23, exp_min=-126, exp_max=127),
]).set_index("dtype")

dtype_info

### Interpret the table

- `eps` / `ulp(1.0)` are *precision* around 1.0.
  - BF16’s spacing near 1.0 is ~`0.0078125`.
  - FP16’s spacing near 1.0 is ~`0.0009765625`.
  - FP32’s spacing near 1.0 is ~`1.19e-07`.

- `min_normal` and `min_subnormal` are the smallest magnitudes you can represent.
  - FP16’s smallest normal is around `6e-5`, and smallest subnormal around `6e-8`.
  - BF16’s smallest normal is around `1e-38`, which is *enormously* smaller.

This is why FP16 can silently turn small gradients into exact 0, while BF16 usually won’t.
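These limits follow directly from the bit layouts. The smallest normal has $E = 1$, and the smallest subnormal keeps only the last fraction bit:

$$
\text{FP16: } \min_{\text{normal}} = 2^{1-15} = 2^{-14} \approx 6.10\times10^{-5},\quad
\min_{\text{subnormal}} = 2^{-14}\times2^{-10} = 2^{-24} \approx 5.96\times10^{-8},\quad
\max = (2-2^{-10})\times2^{15} = 65504
$$

$$
\text{BF16: } \min_{\text{normal}} = 2^{1-127} = 2^{-126} \approx 1.18\times10^{-38},\quad
\min_{\text{subnormal}} = 2^{-126}\times2^{-7} = 2^{-133} \approx 9.18\times10^{-41}
$$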

## 1.3 The three numeric disasters that show up during training

### (A) Underflow (values become 0)

- Common in **gradients**, especially late in training or in deep nets with tiny signals.
- Most harmful in FP16 due to narrow exponent range.

### (B) Overflow (values become `inf`)

- Common in **activations** (e.g., exponentials), attention logits, or badly-initialized models.
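- FP16 overflows much earlier than BF16/FP32. For example, `exp(x)` exceeds FP16’s largest finite value (65504) once $x > \ln 65504 \approx 11.09$, while BF16/FP32 stay finite up to $x \approx 88.7$.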

### (C) Accumulation / cancellation error

Even when values are in range, precision limits can corrupt sums/products.

Classic example: adding many small numbers to a large accumulator.

- FP16/BF16 have very limited mantissas.
- For reductions (layernorm statistics, softmax normalization, large sums), frameworks often keep accumulation in FP32.

Autocast is partly about **preventing A/B** (range disasters), and partly about routing sensitive reductions so **C** doesn’t destroy training.
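Here is a NumPy sketch of case (C). It sums 10,000 FP16 addends of `1e-4` (true total ≈ 1.0) twice: once with an FP16 running sum and once with an FP32 running sum. The sizes are arbitrary illustration choices. The loop is deliberately sequential because `np.sum` uses pairwise summation, which would hide the effect:

```python
import numpy as np

def accumulate(n: int, step: float, acc_dtype) -> float:
    """Sequentially add `step` n times, rounding the running sum to acc_dtype after every add."""
    acc = acc_dtype(0.0)
    inc = np.float16(step)  # the addends themselves are FP16 in both cases
    for _ in range(n):
        acc = acc_dtype(acc + acc_dtype(inc))
    return float(acc)

fp16_acc = accumulate(10_000, 1e-4, np.float16)  # stalls once 1e-4 drops below half an FP16 ULP
fp32_acc = accumulate(10_000, 1e-4, np.float32)  # FP32 accumulator, same FP16 inputs
print(f"FP16 accumulator: {fp16_acc:.4f}   FP32 accumulator: {fp32_acc:.4f}")
```

This is the reasoning behind "low-precision inputs, FP32 accumulation" for reductions.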

In [None]:
# A quick visualization: where exp() overflows by dtype
# This is not an AMP experiment; it's a numeric intuition builder.

x = torch.linspace(-20, 20, 400, device=device)

def safe_exp(x, dtype):
    y = torch.exp(x.to(dtype))
    return y.to(torch.float32).cpu().numpy()

ys = {
    "float32": safe_exp(x, torch.float32),
}

# float16 exp is meaningful on CUDA; on CPU it may be slower / less supported.
try:
    ys["float16"] = safe_exp(x, torch.float16)
except Exception as e:
    ys["float16"] = None
    print("float16 exp skipped:", type(e).__name__, str(e)[:120])

try:
    ys["bfloat16"] = safe_exp(x, torch.bfloat16)
except Exception as e:
    ys["bfloat16"] = None
    print("bfloat16 exp skipped:", type(e).__name__, str(e)[:120])

plt.figure()
for k, y in ys.items():
    if y is None:
        continue
    plt.plot(x.cpu().numpy(), np.log10(np.clip(y, 1e-30, 1e30)), label=k)
plt.title("log10(exp(x)) computed in different dtypes")
plt.xlabel("x")
plt.ylabel("log10(exp(x)) (clipped)")
plt.legend();

### 1.3.1 Loss functions are “log-sum-exp machines” (and dtype matters)

Many deep learning losses contain exponentials and logs.

A classic example is the logistic / softplus component:

$$
\log(1 + e^x)
$$

- The naïve formula overflows quickly in FP16.
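- Stable implementations avoid overflow by never exponentiating a large positive number. PyTorch’s `F.softplus`, for instance, returns `x` directly once `x` exceeds its `threshold` argument (default 20).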
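One standard rewrite uses the identity $\log(1+e^x) = \max(x, 0) + \log(1 + e^{-|x|})$, where the exponent is never positive. This is a different technique from PyTorch’s threshold switch. The NumPy sketch below applies it in FP16 (`np.float16` stands in for GPU FP16):

```python
import numpy as np

def naive_softplus_np(x):
    return np.log1p(np.exp(x))

def stable_softplus_np(x):
    # log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|}); exp() only ever sees values <= 0.
    return np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))

x = np.array([-20.0, 0.0, 20.0], dtype=np.float16)
with np.errstate(over="ignore"):
    print("naive :", naive_softplus_np(x))  # exp(20) overflows FP16 -> inf
print("stable:", stable_softplus_np(x))
```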

We’ll compare the naïve and stable versions across dtypes.

In [None]:
# Naïve vs stable softplus across dtypes

def naive_softplus(x: torch.Tensor):
    return torch.log1p(torch.exp(x))

x = torch.linspace(-80, 80, 2000, device=device)
ref = F.softplus(x.double()).float()  # high-precision reference

plt.figure(figsize=(10, 4))
for dt in [torch.float16, torch.bfloat16, torch.float32]:
    if device.type == "cpu" and dt is torch.float16:
        continue
    y_naive = naive_softplus(x.to(dt)).float()
    y_stable = F.softplus(x.to(dt)).float()
    err_naive = (y_naive - ref).abs().cpu().numpy()
    err_stable = (y_stable - ref).abs().cpu().numpy()

    plt.plot(x.cpu().numpy(), np.log10(err_naive + 1e-12), label=f"naive {dt}")
    plt.plot(x.cpu().numpy(), np.log10(err_stable + 1e-12), linestyle="--", label=f"stable {dt}")

plt.title("log10(|error|) vs x for softplus implementations")
plt.xlabel("x")
plt.ylabel("log10 absolute error (vs FP64-stable ref)")
plt.legend(ncols=2);

### 1.3.2 Softmax: the overflow trap and the stability rewrite

Softmax is everywhere in transformers (attention).

Naïve softmax is:

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

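This can overflow in low precision because `exp(x)` grows very fast. In FP16, `exp(x)` is already `inf` for $x > \ln 65504 \approx 11.09$, and `inf / inf` then produces `nan`.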

The stable rewrite subtracts the maximum value before exponentiating:

$$
\text{softmax}(x) = \text{softmax}(x - \max(x))
$$
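The rewrite is exact because the factor $e^{-c}$ cancels between numerator and denominator. Below is a NumPy sketch in FP16, using the same logits as the experiment below; `np.float16` stands in for GPU FP16:

```python
import numpy as np

def naive_softmax_np(x):
    ex = np.exp(x)
    return ex / ex.sum()

def stable_softmax_np(x):
    # e^{x_i - c} / sum_j e^{x_j - c} == e^{x_i} / sum_j e^{x_j} for any constant c.
    # Choosing c = max(x) makes every exponent <= 0, so exp() cannot overflow.
    ex = np.exp(x - x.max())
    return ex / ex.sum()

logits = np.array([0.0, 20.0, 40.0, 80.0], dtype=np.float16)
with np.errstate(over="ignore", invalid="ignore"):
    print("naive :", naive_softmax_np(logits))  # inf / inf -> nan in FP16
print("stable:", stable_softmax_np(logits))
```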

PyTorch’s `F.softmax` uses a stable implementation.
We’ll demonstrate the difference.

In [None]:
# Naïve vs stable softmax across dtypes

def naive_softmax(x: torch.Tensor, dim: int = -1):
    ex = torch.exp(x)
    return ex / ex.sum(dim=dim, keepdim=True)

logits = torch.tensor([0.0, 20.0, 40.0, 80.0], device=device)
ref = F.softmax(logits.double(), dim=0).float()

rows = []
for dt in [torch.float16, torch.bfloat16, torch.float32]:
    if device.type == "cpu" and dt is torch.float16:
        continue

    x = logits.to(dt)

    try:
        naive = naive_softmax(x, dim=0).float()
        naive_ok = bool(torch.isfinite(naive).all().item())
    except Exception:
        naive = torch.full_like(ref, float("nan"))
        naive_ok = False

    stable = F.softmax(x, dim=0).float()
    stable_ok = bool(torch.isfinite(stable).all().item())

    rows.append({
        "dtype": str(dt),
        "naive_finite": naive_ok,
        "stable_finite": stable_ok,
        "max_abs_err(stable vs ref)": float((stable - ref).abs().max().item()),
    })

pd.DataFrame(rows)

## 1.4 What AMP actually is

In PyTorch, AMP is two tools:

1. **`autocast`** (forward + loss)
   - A context manager that applies a **per-operation dtype policy**.
   - It may cast inputs/weights *for the operation*.
   - It does **not** permanently change your model parameters unless *you* cast them.

2. **`GradScaler`** (backward + optimizer step)
   - Primarily for **FP16 training**.
   - Rescales the loss (and therefore gradients) to prevent gradient underflow.
   - Dynamically adjusts the scale to avoid overflow.

A clean mental model:

- `autocast` protects you from **bad forward dtypes**.
- `GradScaler` protects you from **bad backward magnitudes** (FP16 underflow).
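The claim that autocast does not change your parameters is easy to check. This small sketch uses CPU bfloat16 autocast only so that it runs without a GPU. The tiny `nn.Linear` model and the squared-output loss are arbitrary choices. Inside autocast the `linear` output is BF16, but the parameters and their gradients remain FP32:

```python
import torch
import torch.nn as nn

def autocast_dtype_report() -> dict:
    """Run one CPU bfloat16 autocast forward/backward and record the dtype of each object."""
    torch.manual_seed(0)
    model = nn.Linear(8, 4)  # parameters are created in float32
    x = torch.randn(2, 8)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(x)  # linear is on CPU autocast's lower-precision list
    loss = out.float().pow(2).mean()  # loss computed in float32, outside autocast
    loss.backward()
    return {
        "param": model.weight.dtype,
        "activation": out.dtype,
        "loss": loss.dtype,
        "grad": model.weight.grad.dtype,
    }

print(autocast_dtype_report())
```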

## 1.5 The canonical AMP training loop (four conceptual changes)

Start with FP32 training:

```python
optimizer.zero_grad(set_to_none=True)
logits = model(x)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
```

AMP adds:

1. Wrap forward + loss in autocast.
2. Scale the loss before backward.
3. Step via the scaler (it unscales + checks for inf/nan).
4. Update the scaler.

```python
scaler = GradScaler()

optimizer.zero_grad(set_to_none=True)
with autocast(device_type="cuda", dtype=torch.float16):
    logits = model(x)
    loss = loss_fn(logits, y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

That’s the “small code change” people talk about.
But the *reason* it works is the theory above.

## 1.6 Master weights and optimizer state (why “just casting the model” is not the same)

There are **three** numeric objects in training that matter:

1. **Parameters** (weights)
2. **Gradients**
3. **Optimizer state** (e.g., Adam moments)

A classic mixed precision recipe is:

- Keep a **master copy of weights in FP32**.
- Do forward/backward compute in FP16/BF16 where safe.
- Accumulate/maintain optimizer state (Adam moments) in FP32.

Why keep FP32 master weights?

- Because 16-bit formats have coarse spacing.
- A small update $\Delta w$ can be *below the spacing* at the magnitude of $w$ (more precisely, smaller than half an ULP). Rounding then returns the same weight, so it does not change (stagnation).

We will demonstrate this in Section 3 with a tiny, concrete example.
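As a preview, here is a minimal NumPy sketch. The hypothetical helper `apply_updates` applies the same SGD-style decrement to a weight stored in FP16 and to an FP32 master weight. The values (`w = 1.0`, step `1e-4`, 1000 steps) are illustration choices:

```python
import numpy as np

def apply_updates(n_steps: int, lr_times_grad: float) -> tuple[float, float]:
    """Apply the same small decrement to an FP16 weight and to an FP32 master weight."""
    w16 = np.float16(1.0)
    w32 = np.float32(1.0)
    delta = np.float32(lr_times_grad)
    for _ in range(n_steps):
        w16 = np.float16(w16 - np.float16(delta))  # update stored directly in FP16
        w32 = np.float32(w32 - delta)              # update applied to the FP32 master copy
    return float(w16), float(w32)

w16, w32 = apply_updates(1000, 1e-4)
print(f"FP16 weight: {w16}   FP32 master: {w32:.4f}   FP16 working copy: {float(np.float16(w32)):.4f}")
```

Just below 1.0, the FP16 spacing is $2^{-11} \approx 4.9\times10^{-4}$, so each `1e-4` step is less than half an ULP and is rounded away. The FP32 master accumulates all 1000 steps, and the FP16 copy made for compute still reflects them.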

## 1.7 Autocast is an operator policy (not a global cast)

Autocast does not “turn the whole model into FP16”.
Instead it says (conceptually):

- Matmul/linear/convolution → run in lower precision if possible (Tensor Cores)
- Softmax, layernorm, large reductions, many losses → run in FP32 for stability

The exact policy is PyTorch-version- and backend-dependent.

In Section 3 we will *build a policy table by executing real ops on your machine*.
That is the most reliable way to learn it.

## Section 1 summary (cheat sheet)

- FP16: better precision than BF16 but narrow range → needs **loss scaling** and careful op policies.
- BF16: FP32-like range → often trains without scaling, but is coarse → reductions need care.
- AMP = `autocast` (forward policy) + `GradScaler` (backward magnitude control).
- Good AMP training typically keeps **optimizer state in FP32**, often keeps **master weights in FP32**.

# Section 2 — What the literature says (no experiments)

This section is intentionally *written*: the point is to build a paper-and-doc-driven mental model that you can carry into real training code.

Think of this section as: “what each source contributes, and how it maps to PyTorch AMP”.

## 2.0 A reading map (what each source gives you)

| Source | What it gives you | How it maps to this notebook |
|---|---|---|
| Micikevicius et al. (2017) | the canonical mixed precision recipe (master weights + loss scaling) | Sections 1.4–1.6 and the scaling experiments |
| NVIDIA mixed precision guidance | engineering intuition + failure modes | the underflow/overflow + “sensitive ops in FP32” story |
| PyTorch `torch.amp` docs | the actual API + gotchas | `autocast` + `GradScaler` loops in Section 3 |
| PyTorch autocast op reference | *the* per-op policy | we probe the policy empirically in Section 3.2 |
| BF16 background (TPU/GPU) | why BF16 often works without scaling | Section 1.2 + BF16 training run |
| Distributed training docs (FSDP/ZeRO/DeepSpeed) | where dtypes live in large systems | Section 2.5 mental model |

If you only read one thing: read the mixed precision paper, then read the PyTorch autocast op reference.

## 2.1 Micikevicius et al. (2017): *Mixed Precision Training*

**Problem:** FP16 is fast, but naïvely training in FP16 breaks due to underflow/overflow and update precision.

**Key ideas:**

- Keep **FP32 master weights** for the update.
- Use **loss scaling** to shift gradient magnitudes into FP16’s representable range.
- Use mixed precision: compute some ops in FP16, keep others in FP32.

**What to remember:**

- The “magic” is not just `half()`; it is the *combination* of (a) op-level policies, (b) scaling, (c) FP32 updates.
- AMP frameworks are automations of this recipe.

Reference: Micikevicius et al., arXiv:1710.03740.

## 2.2 NVIDIA mixed precision guidance (engineering view)

A useful engineering decomposition:

1. **Tensor Core compute** wants FP16/BF16 inputs.
2. **Some ops are numerically sensitive** (exp/softmax, normalization stats, large reductions).
3. FP16’s range is narrow, so gradients can underflow.

This viewpoint is why modern stacks:

- Run matmuls in FP16/BF16.
- Accumulate many reductions in FP32.
- Use dynamic loss scaling for FP16.

The practical outcome is the “few lines” AMP recipe, but the reason it works is the numeric analysis.

## 2.3 BF16: “FP32 range, fewer bits of precision”

BF16 was designed to keep FP32’s exponent range (and with it most of FP32’s training stability) while using 16-bit storage and compute.

The mental model:

- BF16 is usually not endangered by underflow in the same way as FP16.
- BF16 can still introduce training error via *coarse rounding*, so reductions and sensitive ops often remain FP32.

In practice, for many transformer training runs on modern GPUs:

- **BF16 autocast** can be “drop-in” without loss scaling.
- FP16 autocast often needs a `GradScaler`.

## 2.4 PyTorch AMP docs: the operator policy is the decoder ring

A common mistake is to treat autocast like a global switch.

But autocast is really an **operator policy**:

- Some ops are eligible for lower precision.
- Some ops are forced to FP32.
- Some ops “promote” to the widest input type.

The takeaway for practitioners:

- When debugging AMP, you need to know which ops ran in which dtype.
- The most robust way to learn is to *probe it empirically on your version*, which we’ll do in Section 3.

## 2.5 LLM training stacks (FSDP / ZeRO / DeepSpeed): where precision choices multiply

At LLM scale, training systems often shard:

- parameters
- gradients
- optimizer state

Mixed precision gets more complicated because:

- You may store “working weights” in BF16/FP16.
- You may keep FP32 master weights (sometimes sharded).
- Optimizer states (Adam moments) are often FP32 for stability.

Practical takeaway:

- “AMP” in a distributed stack is not only about autocast; it is about **where each tensor lives and in what dtype**.

## 2.6 How to think about the autocast policy (conceptual categories)

Autocast decisions usually fall into one of these buckets:

1. **Lower precision eligible** (often matmul-like): run in FP16/BF16 for throughput.
2. **Force FP32**: ops with high risk of overflow/underflow or large reduction error (softmax, normalization stats, some losses).
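3. **Promote to widest**: for some multi-input ops (e.g., `addcmul`, `dot`), if any input is FP32, autocast casts all inputs to FP32 and runs the op in FP32. If all inputs are low precision, the op stays in low precision.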

Two practitioner rules:

- If an op creates very large/small magnitudes (exp/log/softmax), assume it needs FP32 unless proven otherwise.
- If an op reduces many values (sum/mean/var), assume accumulation needs FP32 unless proven otherwise.

In Section 3.2 we probe your exact PyTorch build to see what happens.

## 2.7 Dynamic loss scaling (what `GradScaler` is doing)

Loss scaling multiplies the loss by a factor $S$.

- Backprop gradients scale by $S$ as well.
- This shifts gradients into FP16’s representable range.

Then, before the optimizer step:

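- GradScaler **unscales** gradients by dividing by $S$.
- If it detects `inf`/`nan` gradients (overflow), it **skips the step** and lowers $S$ (multiplies it by `backoff_factor`, 0.5 by default).
- After `growth_interval` consecutive steps without `inf`/`nan` (2000 by default), it raises $S$ (multiplies it by `growth_factor`, 2.0 by default), which protects smaller gradients from underflow.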
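That logic can be sketched in plain Python. `ToyLossScaler` is a hypothetical, simplified model, not a PyTorch class. It approximates FP16 backprop by casting the scaled FP32 gradient to `np.float16`. The constants mirror `GradScaler`’s documented defaults (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`):

```python
import numpy as np

class ToyLossScaler:
    """Simplified dynamic loss scaling (hypothetical helper, not torch.amp.GradScaler)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scaled_fp16_grad(self, true_grad: np.ndarray) -> np.ndarray:
        # Stand-in for FP16 backprop of (S * loss): gradients arrive multiplied by S, rounded to FP16.
        with np.errstate(over="ignore"):
            return (true_grad.astype(np.float32) * self.scale).astype(np.float16)

    def unscale(self, grad_fp16: np.ndarray) -> np.ndarray:
        # Divide by S in FP32 so the recovered small values do not underflow again.
        return grad_fp16.astype(np.float32) / np.float32(self.scale)

    def update(self, grad_fp16: np.ndarray) -> bool:
        """Adjust S and return True if the optimizer step should be applied."""
        if not np.all(np.isfinite(grad_fp16)):
            self.scale *= self.backoff_factor  # overflow: skip the step and back off
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor   # long stable run: try a larger S
            self._good_steps = 0
        return True

true_grad = np.array([1e-8], dtype=np.float32)
print("unscaled FP16 grad:", true_grad.astype(np.float16))  # lost to underflow
s = ToyLossScaler()
g16 = s.scaled_fp16_grad(true_grad)
print("scaled FP16 grad:", g16, "-> unscaled FP32:", s.unscale(g16))
big = s.scaled_fp16_grad(np.array([10.0], dtype=np.float32))  # 10 * 65536 overflows FP16
print("step applied:", s.update(big), "new scale:", s.scale)
```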

The reason it’s safe: scaling/unscaling cancels out *if no overflow occurs*.

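This is why gradient clipping (or anything else that inspects gradient magnitudes) must come after `scaler.unscale_(optimizer)`. Before that call, the gradients are still multiplied by $S$, so a clipping threshold would be compared against the wrong values.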

## Section 2 summary

- The original mixed precision recipe explains *why* AMP needs scaling + FP32 updates (especially FP16).
- BF16 works well largely because its exponent range matches FP32.
- PyTorch autocast is an operator policy; learning it by probing is more reliable than memorizing.
- In real LLM training stacks, “precision” applies to parameters, grads, and optimizer state separately.

# Section 3 — Practicalities (experiments + graphs)

This section is where we turn everything into measurements.

Principles:

- Prefer experiments that are **small, fast, and explain a single idea**.
- Log everything you might need for debugging (loss, grad norms, scaler scale, step time, NaN/inf).
- Make the results comparable across dtypes.

## 3.0 Controlling confounders (TF32, randomness, and fair comparisons)

When you compare FP32 vs AMP runs, you can accidentally compare the wrong thing.

Two common confounders:

1. **TF32 on NVIDIA Ampere+**
   - Many FP32 matmuls can use TF32 internally (10-bit mantissa) for speed.
   - That means your “FP32 baseline” may not be strict FP32 precision.

2. **Randomness**
   - Dropout, data sampling, and nondeterministic kernels can introduce run-to-run variance.

For a learning notebook, it’s often useful to **print** these settings and optionally disable TF32 for “clean” comparisons.

In [None]:
# (Optional) inspect / control TF32

if device.type == "cuda":
    print("TF32 matmul allow:", torch.backends.cuda.matmul.allow_tf32)
    print("TF32 cuDNN allow:", torch.backends.cudnn.allow_tf32)

    # Set this to True if you want strict FP32 matmul (slower, but cleaner comparisons).
    DISABLE_TF32 = False
    if DISABLE_TF32:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        print("TF32 disabled")
else:
    print("CUDA not available; TF32 not applicable")

## 3.1 Minimal AMP training loop (reference snippet)

This is the “copy/paste” template, but the rest of the notebook explains *why* each line exists.

In [None]:
# Minimal AMP loop (template)

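def amp_train_step(model, optimizer, x, y, *, device_type: str, amp_dtype: torch.dtype | None, scaler=None):
    """One training step. amp_dtype=None means plain FP32; pass a GradScaler only for FP16."""
    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Autocast is on whenever a low-precision dtype is requested; loss scaling is a separate decision.
    with autocast(device_type=device_type, dtype=amp_dtype, enabled=amp_dtype is not None):
        logits = model(x)
        # reshape (not view) also works when logits are non-contiguous
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    if scaler is None:
        # FP32 or BF16: no loss scaling.
        loss.backward()
        optimizer.step()
        return float(loss)

    # FP16: scale the loss, let the scaler unscale + check grads for inf/nan, then adjust the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return float(loss)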

## 3.2 Build an “autocast operator policy table” from your local PyTorch

Instead of trusting a static table from the internet, we can probe:

- With autocast disabled
- With autocast enabled (FP16 or BF16)

and record the output dtype.

This turns autocast from “mystery magic” into an observable policy.

In [None]:
# Operator policy probe

@dataclass
class OpProbe:
    name: str
    fn: callable


def probe_ops(device: torch.device, amp_dtype: torch.dtype):
    device_type = device.type if device.type != "mps" else "cpu"  # autocast is best-defined for cuda/cpu

    ops = [
        OpProbe("matmul", lambda a, b: a @ b),
        OpProbe("linear", lambda a, w, b: F.linear(a, w, b)),
        OpProbe("softmax", lambda x: F.softmax(x, dim=-1)),
        OpProbe("layer_norm", lambda x: F.layer_norm(x, normalized_shape=(x.size(-1),))),
        OpProbe("exp", lambda x: torch.exp(x)),
        OpProbe("sum", lambda x: x.sum()),
    ]

    rows = []

    # Inputs
    a = torch.randn(128, 128, device=device, dtype=torch.float32)
    b = torch.randn(128, 128, device=device, dtype=torch.float32)
    x = torch.randn(1024, device=device, dtype=torch.float32)

    w = torch.randn(128, 128, device=device, dtype=torch.float32)
    bias = torch.randn(128, device=device, dtype=torch.float32)

    def run(op: OpProbe, use_autocast: bool):
        ctx = autocast(device_type=device_type, dtype=amp_dtype, enabled=use_autocast)
        with ctx:
            if op.name == "matmul":
                y = op.fn(a, b)
            elif op.name == "linear":
                y = op.fn(a, w, bias)
            else:
                y = op.fn(x)
        if isinstance(y, torch.Tensor):
            return str(y.dtype)
        return type(y).__name__

    for op in ops:
        rows.append({
            "op": op.name,
            "autocast": False,
            "out_dtype": run(op, False),
        })
        rows.append({
            "op": op.name,
            "autocast": True,
            "out_dtype": run(op, True),
        })

    return pd.DataFrame(rows)


for dtype in [torch.float16, torch.bfloat16]:
    if device.type == "cuda" and dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
        continue
    if device.type == "cpu" and dtype is torch.float16:
        # CPU autocast is typically bfloat16-focused.
        continue

    df = probe_ops(device, dtype)
    print()
    print("=== amp_dtype =", dtype, "===")
    display(df)
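
The probe above feeds only FP32 tensors, so ops that autocast *forces* to FP32 (per the PyTorch autocast op reference, the CUDA list includes `softmax`, `layer_norm`, `exp`, and `sum`) return float32 either way and look the same as ops autocast ignores. A minimal sketch that separates the two cases by feeding a low-precision input; it assumes CUDA for the softmax check, because the CPU float32 list is different, and `lowp_input_probe` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def lowp_input_probe(n: int = 64):
    """Output dtypes with an FP32 vs a low-precision input, autocast off vs on."""
    if torch.cuda.is_available():
        device_type, lowp = "cuda", torch.float16
    else:
        device_type, lowp = "cpu", torch.bfloat16

    a32 = torch.randn(n, n, device=device_type)               # FP32 input
    x_lowp = torch.randn(n, n, device=device_type).to(lowp)   # low-precision input

    results = {}
    for enabled in (False, True):
        with torch.autocast(device_type=device_type, dtype=lowp, enabled=enabled):
            results[("matmul(fp32 in)", enabled)] = (a32 @ a32).dtype
            results[("softmax(low-precision in)", enabled)] = F.softmax(x_lowp, dim=-1).dtype
    return device_type, lowp, results


device_type, lowp, results = lowp_input_probe()
for (op, enabled), out_dtype in results.items():
    print(f"{device_type} {op:<26} autocast={enabled!s:<5} -> {out_dtype}")
```

On CUDA, the expected pattern is that matmul goes FP32 -> FP16 when autocast is on, while softmax goes the other way, FP16 -> FP32.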

## 3.3 Watch dtype flow through a tiny transformer (hooks)

Printing a single dtype only shows that autocast exists. Here we build a small transformer-style model, attach forward hooks to its `nn.Linear`, `nn.LayerNorm`, and `nn.Embedding` modules, and observe which submodules:

- receive FP32 inputs
- output BF16/FP16 activations
- are forced to FP32 (e.g., normalization)

This makes the op policy visceral.

In [None]:
# A tiny transformer block (enough to exercise matmul/softmax/norm)

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, T, D)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)

        # Causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        att = att.masked_fill(mask, float("-inf"))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v  # (B, H, T, D)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, n_embd: int, hidden_mult: int = 4, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(n_embd, hidden_mult * n_embd)
        self.fc2 = nn.Linear(hidden_mult * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, n_embd: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_heads, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, n_layer: int = 2, n_embd: int = 128, n_heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(n_embd, n_heads, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device)

        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


def install_dtype_hooks(model: nn.Module, watch=(nn.Linear, nn.LayerNorm, nn.Embedding)):
    hooks = []
    records = []

    def make_hook(name):
        def hook(m, inp, out):
            def dt(x):
                if isinstance(x, torch.Tensor):
                    return str(x.dtype)
                return type(x).__name__
            indt = dt(inp[0]) if isinstance(inp, (tuple, list)) and inp else dt(inp)
            outt = dt(out) if not isinstance(out, (tuple, list)) else dt(out[0])
            records.append({"module": name, "type": type(m).__name__, "in": indt, "out": outt})
        return hook

    for name, m in model.named_modules():
        if isinstance(m, watch):
            hooks.append(m.register_forward_hook(make_hook(name)))

    return hooks, records


# Build a tiny model and run a single forward pass under several regimes
vocab_size = 128
block_size = 64
idx = torch.randint(0, vocab_size, (2, block_size), device=device)


def run_dtype_trace(amp_enabled: bool, amp_dtype: torch.dtype | None, param_dtype: torch.dtype):
    model = TinyGPT(vocab_size=vocab_size, block_size=block_size).to(device)
    model = model.to(param_dtype)

    hooks, rec = install_dtype_hooks(model)

    device_type = device.type if device.type != "mps" else "cpu"
    ctx = autocast(device_type=device_type, dtype=amp_dtype, enabled=amp_enabled)
    with torch.inference_mode(), ctx:
        _ = model(idx)

    for h in hooks:
        h.remove()

    df = pd.DataFrame(rec)
    df["count"] = 1
    summary = df.groupby(["type", "in", "out"], as_index=False)["count"].sum().sort_values(["type", "in", "out"])
    return summary


traces = []
traces.append(("no autocast, params fp32", run_dtype_trace(False, None, torch.float32)))

if device.type == "cuda":
    traces.append(("autocast fp16, params fp32", run_dtype_trace(True, torch.float16, torch.float32)))
    if torch.cuda.is_bf16_supported():
        traces.append(("autocast bf16, params fp32", run_dtype_trace(True, torch.bfloat16, torch.float32)))

for title, df in traces:
    print()
    print("---", title, "---")
    display(df)
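
One consequence of type promotion is visible in `Block.forward`: `x = x + self.attn(self.ln1(x))` adds a low-precision branch output to the residual stream. Elementwise `add` is not on an autocast list, so if `x` is FP32 the sum is promoted back to FP32, which is one reason the hooks can report FP32 inputs to later modules even under autocast. A minimal sketch with a single `nn.Linear` standing in for the attention branch (`residual_dtype_demo` is a hypothetical helper):

```python
import torch
import torch.nn as nn


def residual_dtype_demo():
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    lowp = torch.float16 if device_type == "cuda" else torch.bfloat16
    torch.manual_seed(0)
    branch_layer = nn.Linear(32, 32).to(device_type)   # FP32 parameters
    x = torch.randn(4, 32, device=device_type)         # FP32 residual stream
    with torch.autocast(device_type=device_type, dtype=lowp):
        branch = branch_layer(x)   # linear is on the lower-precision list
        out = x + branch           # add is not: FP32 + low precision promotes to FP32
    return x.dtype, branch.dtype, out.dtype


print(residual_dtype_demo())
```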

### Optional: dtype tracing on a real pretrained LLM (Hugging Face)

If you have `transformers` installed and can load a small model, the same dtype-hook method works on a real LLM.

Notes:

- Loading from Hugging Face may require internet access *or* a cached model.
- The forward pass is enough to observe dtype flow; you don’t need to train the model.
- This cell is optional; the notebook remains fully self-contained without it.

In [None]:
# Optional: dtype tracing on a small pretrained causal LM

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as e:
    AutoModelForCausalLM = None
    print("transformers not available; skipping pretrained LLM trace:", type(e).__name__)

if AutoModelForCausalLM is not None and device.type == "cuda":
    model_name = "sshleifer/tiny-gpt2"  # tiny + fast; swap for OPT-125M if you have it cached

    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        mdl = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device).eval()
    except Exception as e:
        print("Could not load model (maybe no internet / not cached):", type(e).__name__, str(e)[:200])
    else:
        text = "AMP dtype tracing on a pretrained LM."
        inputs = tok(text, return_tensors="pt").to(device)

        hooks, rec = install_dtype_hooks(mdl)

        dtype16 = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.inference_mode(), autocast(device_type="cuda", dtype=dtype16):
            _ = mdl(**inputs)

        for h in hooks:
            h.remove()

        df = pd.DataFrame(rec)
        df["count"] = 1
        summary = df.groupby(["type", "in", "out"], as_index=False)["count"].sum().sort_values(["type", "in", "out"])
        display(summary)

## 3.4 Gradient underflow and why loss scaling works

We’ll do a controlled experiment:

1. Create a synthetic gradient distribution.
2. Cast it to FP16 and count how many values become exactly 0.
3. Apply a scale factor \(S\), cast again, then unscale.

This shows the core mechanism of loss scaling *without* needing a full training run.

In [None]:
# Gradient underflow demo

def underflow_report(grads: torch.Tensor, dtype: torch.dtype):
    g = grads.to(dtype)
    zeros = (g == 0).float().mean().item()
    finite = torch.isfinite(g).float().mean().item()
    return {"dtype": str(dtype), "zero_frac": zeros, "finite_frac": finite}


# Log-uniform magnitudes: covers many orders of magnitude.
N = 200_000
log10_mag = torch.empty(N).uniform_(-12, 0)  # 1e-12 .. 1
sign = torch.randint(0, 2, (N,)) * 2 - 1
synthetic = (10 ** log10_mag) * sign
synthetic = synthetic.to(torch.float32)

rows = []
rows.append({"setting": "unscaled", **underflow_report(synthetic, torch.float16)})
rows.append({"setting": "unscaled", **underflow_report(synthetic, torch.bfloat16)})

S = 2**13  # exactly 8192; for reference, GradScaler's default init_scale is 2**16
scaled = synthetic * S
rows.append({"setting": f"scaled by {S}", **underflow_report(scaled, torch.float16)})
rows.append({"setting": f"scaled by {S}", **underflow_report(scaled, torch.bfloat16)})

pd.DataFrame(rows)

In [None]:
# Visualize the distribution and FP16 representable thresholds

fi16 = torch.finfo(torch.float16)
min_normal = float(fi16.tiny)  # 2**-14 ≈ 6.1e-5
min_sub = 2.0 ** -24           # smallest positive FP16 subnormal: 2**-14 * 2**-10 ≈ 6.0e-8

vals = synthetic.abs().cpu().numpy()
plt.figure()
plt.hist(np.log10(vals + 1e-30), bins=200)
plt.axvline(np.log10(min_normal), color="r", linestyle="--", label="FP16 min normal")
plt.axvline(np.log10(min_sub), color="m", linestyle=":", label="FP16 min subnormal")
plt.title("Synthetic |grad| distribution (log10) with FP16 thresholds")
plt.xlabel("log10(|grad|)")
plt.ylabel("count")
plt.legend();
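
The same thresholds as plain arithmetic, using NumPy's IEEE `float16`: a sketch that follows one hypothetical gradient value of `1e-8` through a direct cast, a scaled cast with \(S = 2^{13}\), and an FP32 unscale, plus one larger value that overflows once scaled.

```python
import numpy as np

fp16 = np.finfo(np.float16)
min_normal = float(fp16.tiny)        # 2**-14 ≈ 6.10e-5
min_subnormal = 2.0 ** -24           # 2**-14 * 2**-10 ≈ 5.96e-8
max_finite = float(fp16.max)         # 65504.0

g = 1e-8                             # hypothetical tiny gradient, below min_subnormal / 2
S = 2.0 ** 13                        # same scale factor as the demo above

unscaled = np.float16(g)             # rounds to 0.0: the gradient is lost
scaled = np.float16(g * S)           # ≈ 8.19e-5, above min_normal: it survives the cast
recovered = np.float32(scaled) / np.float32(S)  # unscale in FP32, as GradScaler does for FP32 grads

with np.errstate(over="ignore"):
    overflowed = np.float16(10.0 * S)  # 81920 > 65504: a large gradient overflows to inf

print(f"FP16 min normal {min_normal:.3e}, min subnormal {min_subnormal:.3e}, max {max_finite}")
print("cast without scaling:        ", float(unscaled))
print("scaled -> FP16 -> FP32 / S:  ", float(recovered))
print("large gradient after scaling:", float(overflowed))
```

The last line is the other side of the trade-off: too large a scale pushes big gradients past 65504, which is why the scale is adjusted dynamically.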

### 3.4.1 Underflow in a real backward pass (FP16) + rescue via scaling

The synthetic demo above shows how casting kills tiny values.

Now we’ll show the *actual training failure mode*:

- Build a tiny FP16 network.
- Choose inputs/targets that produce very small gradients.
- Compare gradients with and without loss scaling.

This is the smallest end-to-end demonstration of why scaling exists.

In [None]:
# Real backward underflow demo (tiny MLP)

def tiny_backward(use_scaling: bool, scale: float = 2**13, seed: int = 0):
    torch.manual_seed(seed)  # identical weights and data for the scaled and unscaled runs
    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1)).to(device)
    model = model.to(torch.float16)

    x = (torch.randn(256, 128, device=device) * 1e-3).to(torch.float16)
    y = (torch.randn(256, 1, device=device) * 1e-3).to(torch.float16)

    pred = model(x)
    loss = ((pred - y) ** 2).mean()  # MSE

    if use_scaling:
        (loss * scale).backward()
        # Unscale in FP32 (as GradScaler does with FP32 master grads); dividing the FP16
        # .grad in place would push the tiny values back below FP16's range.
        grads = [p.grad.float() / scale for p in model.parameters() if p.grad is not None]
    else:
        loss.backward()
        grads = [p.grad.float() for p in model.parameters() if p.grad is not None]

    grads = torch.cat([g.flatten().abs() for g in grads])
    return float(loss), float((grads == 0).float().mean()), float(torch.median(grads)), grads


if device.type == "cuda":
    loss0, z0, med0, g0 = tiny_backward(use_scaling=False)
    loss1, z1, med1, g1 = tiny_backward(use_scaling=True)

    # display() is needed: an expression inside an if-block is not rendered as cell output
    display(pd.DataFrame([
        {"setting": "fp16 no scaling", "loss": loss0, "zero_grad_frac": z0, "median|grad|": med0},
        {"setting": "fp16 scaled+unscaled", "loss": loss1, "zero_grad_frac": z1, "median|grad|": med1},
    ]))
else:
    print("This demo is most meaningful on CUDA (FP16 backward kernels).")

## 3.5 Weight update stagnation (why FP32 master weights matter)

Even if you avoid underflow, you can still lose learning signal if your **weight updates are below the spacing** of the weight’s dtype.

Example idea:

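- If \(w \approx 1\) in FP16, the spacing (ULP) just above 1 is \(2^{-10} \approx 9.8\times 10^{-4}\); just below 1 it is \(2^{-11} \approx 4.9\times 10^{-4}\).
- With round-to-nearest, an update smaller than half the local spacing is rounded away entirely, so the weight does not move at all.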

Let’s demonstrate by applying a sequence of tiny updates and checking whether the weight changes.

In [None]:
# Weight stagnation demo

def apply_updates(dtype: torch.dtype, w0: float, delta: float, steps: int = 2000):
    w = torch.tensor(w0, dtype=dtype)
    changed = 0
    for _ in range(steps):
        w_new = (w - torch.tensor(delta, dtype=dtype))
        changed += int(w_new.item() != w.item())
        w = w_new
    return {
        "dtype": str(dtype),
        "w0": w0,
        "delta": delta,
        "steps": steps,
        "num_steps_where_w_changed": changed,
        "final_w": float(w),
        "ulp_at_one": ulp_at_one(dtype),
    }

rows = []
for dtype in [torch.float16, torch.bfloat16, torch.float32]:
    rows.append(apply_updates(dtype, w0=1.0, delta=1e-5, steps=2000))

pd.DataFrame(rows)
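
The usual fix (Micikevicius et al.) is to keep an FP32 master copy of each weight, apply updates to that copy, and derive the low-precision copy from it. A NumPy sketch of the same 2000 updates of `1e-5` under both schemes; it illustrates the idea and is not how PyTorch implements it (with autocast, the FP32 parameters themselves play the master-weight role and low-precision copies are made per op). `fp16_only_vs_master` is a hypothetical helper.

```python
import numpy as np


def fp16_only_vs_master(steps: int = 2000, w0: float = 1.0, delta: float = 1e-5):
    """Apply `steps` updates of -delta (a) directly in FP16 and (b) to an FP32 master copy."""
    # (a) weight stored and updated in FP16: each update is rounded away
    w16 = np.float16(w0)
    d16 = np.float16(delta)
    for _ in range(steps):
        w16 = np.float16(w16 - d16)

    # (b) FP32 master weight accumulates the updates; the FP16 copy is derived from it
    master = np.float32(w0)
    d32 = np.float32(delta)
    for _ in range(steps):
        master = np.float32(master - d32)
    w16_from_master = np.float16(master)
    return w16, master, w16_from_master


w16, master, w16_from_master = fp16_only_vs_master()
print("FP16-only weight:       ", float(w16))      # stays at 1.0
print("FP32 master weight:     ", float(master))   # close to 1 - 2000 * 1e-5 = 0.98
print("FP16 copy of the master:", float(w16_from_master))
```

Each individual update is still invisible in FP16, but the FP32 copy accumulates them until the total crosses an FP16 spacing.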

## 3.6 The main event: train a tiny causal transformer under different precision regimes

We’ll train a small GPT-like model on a tiny in-notebook text corpus (character-level next-token prediction).

Why character-level?

- No external downloads.
- The data is fixed and built in code, so every run trains on the same text.
- Still exercises the transformer mechanics that matter for autocast.

We will compare:

- FP32 baseline
- FP16 naïve (no autocast, no scaling) — expected to be unstable on many models
- AMP FP16 (autocast + GradScaler)
- AMP BF16 (autocast BF16, no scaler)

We will log and plot:

- training loss
- gradient norm
- step time
- (CUDA) memory
- (FP16 AMP) dynamic loss scale

Tip: start with the default `FAST_DEV_RUN` setting in the suite-definition cell. For longer curves, increase `BASE_STEPS`.

In [None]:
# Tiny corpus + char-level tokenizer

lines = [
    "Autocast is not a global cast.",
    "It is an operator policy.",
    "Some ops run in lower precision for speed.",
    "Some ops run in float32 for stability.",
    "",
    "Loss scaling rescues fp16 gradients from underflow.",
    "Bfloat16 usually has enough exponent range to avoid underflow.",
    "",
    "Transformers amplify numeric issues via softmax, layernorm, and large reductions.",
    "AMP exists to route the right operations to the right dtype.",
]

corpus = ("\n".join(lines).strip() + "\n") * 50  # newline between repeats so lines don't run together

# Build vocab
chars = sorted(list(set(corpus)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars)


def encode(s: str):
    return [stoi[c] for c in s]


def decode(ids):
    return "".join(itos[i] for i in ids)


data = torch.tensor(encode(corpus), dtype=torch.long)

# Train/val split
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

print("vocab_size:", vocab_size)
print("train tokens:", len(train_data), "val tokens:", len(val_data))
print("sample decode:\n")
print(decode(train_data[:200].tolist()))

In [None]:
# Batch sampling

def get_batch(split: str, batch_size: int, block_size: int):
    src = train_data if split == "train" else val_data
    ix = torch.randint(len(src) - block_size - 1, (batch_size,))
    x = torch.stack([src[i : i + block_size] for i in ix])
    y = torch.stack([src[i + 1 : i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)


@torch.no_grad()
def estimate_loss(model, block_size: int, batch_size: int, iters: int = 20):
    model.eval()
    out = {}
    for split in ["train", "val"]:
        losses = []
        for _ in range(iters):
            x, y = get_batch(split, batch_size, block_size)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            losses.append(float(loss))
        out[split] = float(np.mean(losses))
    model.train()
    return out

In [None]:
# Training utilities (logging loss, grad norms, speed, and optional eval)

@dataclass
class TrainConfig:
    name: str
    steps: int = 200
    batch_size: int = 32
    block_size: int = 64
    lr: float = 3e-4
    weight_decay: float = 0.0

    # Precision knobs
    use_autocast: bool = False
    amp_dtype: torch.dtype | None = None
    use_grad_scaler: bool = False
    param_dtype: torch.dtype = torch.float32

    # Logging knobs
    eval_interval: int | None = 50
    eval_iters: int = 10


def global_grad_norm(model: nn.Module) -> float:
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad.detach().float()
        total_sq += float(g.norm()) ** 2
    return math.sqrt(total_sq)


def train_one(cfg: TrainConfig):
    model = TinyGPT(
        vocab_size=vocab_size,
        block_size=cfg.block_size,
        n_layer=2,
        n_embd=128,
        n_heads=4,
        dropout=0.0,
    ).to(device)
    model = model.to(cfg.param_dtype)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    device_type = device.type if device.type != "mps" else "cpu"

    scaler = None
    if cfg.use_grad_scaler:
        scaler = GradScaler(enabled=(device.type == "cuda"))

    logs = {
        "step": [],
        "train_loss": [],
        "grad_norm": [],
        "step_time_ms": [],
        "tokens_per_s": [],
        "scale": [],
        "cuda_mem_mb": [],
        "val_step": [],
        "val_loss": [],
    }

    status = "ok"

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()

    tokens_per_step = cfg.batch_size * cfg.block_size

    for step in range(cfg.steps):
        t0 = time.perf_counter()

        x, y = get_batch("train", cfg.batch_size, cfg.block_size)
        optimizer.zero_grad(set_to_none=True)

        ctx = autocast(device_type=device_type, dtype=cfg.amp_dtype, enabled=cfg.use_autocast)
        with ctx:
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

        if not torch.isfinite(loss):
            status = "non_finite_loss"
            break

        if scaler is None:
            loss.backward()
            grad_norm = global_grad_norm(model)
            optimizer.step()
            scale_val = float("nan")
        else:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = global_grad_norm(model)
            scaler.step(optimizer)
            scaler.update()
            scale_val = float(scaler.get_scale())

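        if device.type == "cuda":
            torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them before reading the clock
        t1 = time.perf_counter()
        dt = max(t1 - t0, 1e-12)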

        logs["step"].append(step)
        logs["train_loss"].append(float(loss))
        logs["grad_norm"].append(float(grad_norm))
        logs["step_time_ms"].append(dt * 1000)
        logs["tokens_per_s"].append(tokens_per_step / dt)
        logs["scale"].append(scale_val)

        if device.type == "cuda":
            logs["cuda_mem_mb"].append(torch.cuda.max_memory_allocated() / 1024**2)
        else:
            logs["cuda_mem_mb"].append(float("nan"))

        if cfg.eval_interval is not None and (step % cfg.eval_interval == 0 or step == cfg.steps - 1):
            try:
                losses = estimate_loss(model, cfg.block_size, cfg.batch_size, iters=cfg.eval_iters)
                logs["val_step"].append(step)
                logs["val_loss"].append(losses["val"])
            except Exception:
                logs["val_step"].append(step)
                logs["val_loss"].append(float("nan"))

    return {"config": cfg, "status": status, "logs": logs}

In [None]:
# Define the experiment suite (conditionally, based on device support)

# A small “dev run” knob. CPU runs can be slow; CUDA runs are usually fast enough.
FAST_DEV_RUN = (device.type != "cuda")
BASE_STEPS = 60 if FAST_DEV_RUN else 200

suite = []

suite.append(TrainConfig(
    name="fp32",
    steps=BASE_STEPS,
    use_autocast=False,
    amp_dtype=None,
    use_grad_scaler=False,
    param_dtype=torch.float32,
))

if device.type == "cuda":
    # Naïve FP16: params in fp16, no autocast, no scaler.
    suite.append(TrainConfig(
        name="fp16_naive",
        steps=BASE_STEPS,
        use_autocast=False,
        amp_dtype=None,
        use_grad_scaler=False,
        param_dtype=torch.float16,
    ))

    # AMP FP16: params fp32, autocast fp16, scaler on.
    suite.append(TrainConfig(
        name="amp_fp16",
        steps=BASE_STEPS,
        use_autocast=True,
        amp_dtype=torch.float16,
        use_grad_scaler=True,
        param_dtype=torch.float32,
    ))

    # AMP BF16: params fp32, autocast bf16, no scaler.
    if torch.cuda.is_bf16_supported():
        suite.append(TrainConfig(
            name="amp_bf16",
            steps=BASE_STEPS,
            use_autocast=True,
            amp_dtype=torch.bfloat16,
            use_grad_scaler=False,
            param_dtype=torch.float32,
        ))

elif device.type == "cpu":
    # CPU autocast is typically bfloat16.
    suite.append(TrainConfig(
        name="amp_bf16_cpu",
        steps=BASE_STEPS,
        use_autocast=True,
        amp_dtype=torch.bfloat16,
        use_grad_scaler=False,
        param_dtype=torch.float32,
    ))

print("Planned runs:")
for cfg in suite:
    print("-", cfg.name, "steps", cfg.steps, "param", cfg.param_dtype, "autocast", cfg.use_autocast, cfg.amp_dtype, "scaler", cfg.use_grad_scaler)

In [None]:
# Run training experiments

results = []

for cfg in suite:
    print()
    print("=== Running:", cfg.name, "===")
    res = train_one(cfg)
    print("status:", res["status"], "steps:", len(res["logs"]["step"]))
    results.append(res)

# Quick summary table
summary_rows = []
for r in results:
    cfg = r["config"]
    logs = r["logs"]
    steps_ran = len(logs["step"])

    final_train_loss = logs["train_loss"][-1] if steps_ran else None
    mean_step_ms = float(np.mean(logs["step_time_ms"])) if steps_ran else None
    mean_tokens_per_s = float(np.mean(logs["tokens_per_s"])) if steps_ran else None

    # Best-effort: latest val loss if we logged any
    final_val_loss = logs["val_loss"][-1] if len(logs["val_loss"]) else None

    peak_cuda_mem_mb = None
    if device.type == "cuda" and steps_ran:
        peak_cuda_mem_mb = float(np.nanmax(np.array(logs["cuda_mem_mb"], dtype=np.float64)))

    summary_rows.append({
        "name": cfg.name,
        "status": r["status"],
        "steps_ran": int(steps_ran),
        "final_train_loss": final_train_loss,
        "final_val_loss": final_val_loss,
        "mean_step_ms": mean_step_ms,
        "mean_tokens_per_s": mean_tokens_per_s,
        "peak_cuda_mem_mb": peak_cuda_mem_mb,
    })

pd.DataFrame(summary_rows)

In [None]:
# Plot: training + validation loss curves

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for r in results:
    name = r["config"].name
    logs = r["logs"]

    step = np.array(logs["step"])
    train_loss = np.array(logs["train_loss"]) 

    if len(train_loss):
        ax[0].plot(step, train_loss, label=f"{name} ({r['status']})")

    if len(logs["val_loss"]):
        ax[1].plot(logs["val_step"], logs["val_loss"], marker="o", linestyle="-", label=f"{name} ({r['status']})")

ax[0].set_title("Train loss vs step")
ax[0].set_xlabel("step")
ax[0].set_ylabel("loss")

ax[1].set_title("Val loss vs step (periodic eval)")
ax[1].set_xlabel("step")
ax[1].set_ylabel("loss")

ax[0].legend();
ax[1].legend();

In [None]:
# Plot: step time and throughput

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for r in results:
    name = r["config"].name
    logs = r["logs"]

    step = np.array(logs["step"])
    ms = np.array(logs["step_time_ms"]) 
    tps = np.array(logs["tokens_per_s"]) 

    if len(ms) == 0:
        continue

    ax[0].plot(step, ms, label=name)
    ax[1].plot(step, tps, label=name)

ax[0].set_title("Step time (ms)")
ax[0].set_xlabel("step")
ax[0].set_ylabel("ms")

ax[1].set_title("Throughput (tokens/s)")
ax[1].set_xlabel("step")
ax[1].set_ylabel("tokens/s")

ax[0].legend();
ax[1].legend();

In [None]:
# Plot: GradScaler scale over time (only meaningful for amp_fp16)

plt.figure(figsize=(10, 3))
plotted = False

for r in results:
    if r["config"].name != "amp_fp16":
        continue

    logs = r["logs"]
    step = np.array(logs["step"])
    scale = np.array(logs["scale"], dtype=np.float64)

    if len(scale) == 0 or np.all(np.isnan(scale)):
        continue

    plt.plot(step, scale)
    plotted = True

plt.title("GradScaler scale (AMP FP16)")
plt.xlabel("step")
plt.ylabel("scale")

if not plotted:
    print("No GradScaler scale data to plot (did amp_fp16 run?)")

## 3.7 Interpreting results (what to look for)

### Loss curves

- If `fp16_naive` diverges or hits `non_finite_loss`, that’s a *feature*: it illustrates why mixed precision needs scaling + policies.
- `amp_fp16` should usually be stable if the model and learning rate are reasonable.
- `amp_bf16` is often stable without scaling (when supported).

### Step time

- On modern GPUs, `amp_fp16`/`amp_bf16` often reduce step time because matmuls hit Tensor Cores.
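- If you don’t see a speedup, common reasons:
  - the model is too small, so kernel-launch and Python overhead dominate
  - the step is CPU-bound (data loading, Python logic)
  - you’re not actually running on CUDA
  - the timing is off: CUDA kernels launch asynchronously, so wall-clock timing without `torch.cuda.synchronize()` measures launch time rather than compute time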
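
To time a single op fairly, `torch.utils.benchmark.Timer` repeats it and, on CUDA, synchronizes before reading the clock. A sketch comparing one matmul in FP32 against the same matmul under autocast; `compare_matmul_timing` is a hypothetical helper, and the size `n` and `min_run_time` are arbitrary. Small sizes, or BF16 on CPU, may show no speedup or even a slowdown, for the reasons listed above.

```python
import torch
from torch.utils import benchmark


def compare_matmul_timing(n: int = 512, min_run_time: float = 0.2):
    """Median seconds per n x n matmul: FP32 vs under autocast."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    lowp = torch.float16 if device_type == "cuda" else torch.bfloat16
    a = torch.randn(n, n, device=device_type)
    b = torch.randn(n, n, device=device_type)
    env = {"torch": torch, "a": a, "b": b, "device_type": device_type, "lowp": lowp}

    fp32_timer = benchmark.Timer(stmt="a @ b", globals=env)
    amp_timer = benchmark.Timer(
        stmt="with torch.autocast(device_type=device_type, dtype=lowp):\n    a @ b",
        globals=env,
    )
    # blocked_autorange keeps running the statement until min_run_time has elapsed
    return {
        "fp32": fp32_timer.blocked_autorange(min_run_time=min_run_time).median,
        f"autocast {lowp}": amp_timer.blocked_autorange(min_run_time=min_run_time).median,
    }


print(compare_matmul_timing())
```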

### GradScaler scale

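- Each drop means GradScaler found inf/NaN gradients on that step, skipped `optimizer.step()`, and multiplied the scale by `backoff_factor` (default 0.5). Repeated drops mean repeated overflow events.
- Growth is deliberately slow: the scale is multiplied by `growth_factor` (default 2.0) only after `growth_interval` (default 2000) consecutive clean steps, so a run of a few hundred steps with default settings shows a flat or stepping-down curve, never a rising one.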
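
The update rule behind that curve is simple enough to replay by hand. A pure-Python sketch of dynamic loss scaling using the default arguments documented for `GradScaler`; `simulate_loss_scale` is a hypothetical helper, not a PyTorch API, and the overflow steps are made up for illustration.

```python
def simulate_loss_scale(overflow_steps, num_steps, init_scale=2.0**16,
                        growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Replay the dynamic loss-scaling rule.

    `overflow_steps` lists the steps whose gradients are assumed to contain inf/NaN.
    Returns (scale after each step, list of skipped steps).
    """
    overflow_steps = set(overflow_steps)
    scale, clean_streak = init_scale, 0
    scales, skipped = [], []
    for step in range(num_steps):
        if step in overflow_steps:
            skipped.append(step)        # optimizer.step() is skipped for this batch
            scale *= backoff_factor     # back off immediately
            clean_streak = 0
        else:
            clean_streak += 1
            if clean_streak == growth_interval:
                scale *= growth_factor  # grow only after a long run of clean steps
                clean_streak = 0
        scales.append(scale)
    return scales, skipped


short_scales, short_skipped = simulate_loss_scale(overflow_steps=[3, 4], num_steps=200)
print("200 steps, overflow at 3 and 4:", short_scales[0], "->", short_scales[-1], "skipped:", short_skipped)

long_scales, _ = simulate_loss_scale(overflow_steps=[], num_steps=4000)
print("4000 clean steps:", long_scales[0], "->", long_scales[-1])
```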

## 3.8 Practical checklist (AMP in real code)

### Defaults that usually work

- Prefer **BF16 autocast** if your GPU supports it.
- Otherwise use **FP16 autocast + GradScaler**.
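- Keep parameters and optimizer state in FP32. PyTorch optimizers allocate their state (e.g., AdamW's `exp_avg` and `exp_avg_sq`) in the parameter's dtype, so this holds automatically as long as the model itself is not cast to FP16/BF16.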

### Debugging steps

1. If loss becomes `nan`/`inf`: check for overflow sources (attention logits, exp/log, unstable loss).
2. If gradients are mostly 0 in FP16: use GradScaler / increase scale.
3. If training “does nothing”: check for update stagnation (are weights changing?) and whether you inadvertently cast master weights.
4. Use dtype hooks to confirm what ran in what dtype.
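
For step 1, the classic FP16 overflow is `exp` of a moderately large logit: \(e^{12} \approx 1.6\times 10^{5}\) already exceeds the FP16 maximum of 65504. A NumPy sketch of a naive softmax overflowing to inf/NaN in `float16`, next to the max-subtraction trick that keeps every exponent at or below \(e^0 = 1\); the logit values are arbitrary.

```python
import numpy as np

logits = np.array([12.0, 3.0, -1.0], dtype=np.float16)

# Naive softmax in FP16: exp(12) overflows to inf, and inf / inf gives NaN
with np.errstate(over="ignore", invalid="ignore"):
    naive_exp = np.exp(logits)
    naive = naive_exp / naive_exp.sum()

# Stable softmax: subtract the max first so every exponent is <= 0
shifted = logits - logits.max()
stable_exp = np.exp(shifted)
stable = stable_exp / stable_exp.sum()

print("naive exp:     ", naive_exp)
print("naive softmax: ", naive)
print("stable softmax:", stable)
```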

### Common gotchas

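- Calling `.half()` (or `.to(torch.bfloat16)`) on the model while also using autocast makes parameters, gradients, and optimizer state low precision, which removes the FP32 master weights AMP relies on; with FP16, `GradScaler` refuses to unscale FP16 gradients.
- Gradient clipping must happen **after** `scaler.unscale_(optimizer)` when using GradScaler; otherwise you clip scaled gradients against an unscaled threshold.
- `autocast` should wrap the forward pass and loss computation only; run `backward()` and `optimizer.step()` outside the context.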

In [None]:
# Plot: CUDA memory (if available)

if device.type == "cuda":
    plt.figure(figsize=(10, 3))
    for r in results:
        name = r["config"].name
        logs = r["logs"]
        step = np.array(logs["step"])
        mem = np.array(logs["cuda_mem_mb"], dtype=np.float64)
        if len(mem) == 0:
            continue
        plt.plot(step, mem, label=name)

    plt.title("Max CUDA memory allocated (MB)")
    plt.xlabel("step")
    plt.ylabel("MB")
    plt.legend();
else:
    print("CUDA not available; skipping memory plot")

# Appendix — References and further reading

- Micikevicius et al., *Mixed Precision Training* (arXiv:1710.03740)
- NVIDIA mixed precision training guides/blog posts (engineering implementation view)
- PyTorch docs: `torch.amp` (`autocast`, `GradScaler`) and the autocast op reference
- BF16 background materials (design intent: FP32 range with fewer mantissa bits)
- LLM training system docs (FSDP / ZeRO / DeepSpeed): mixed precision interacts with sharding and optimizer state

This notebook’s experiments are intentionally small; the *mechanisms* are the same at scale.

Generated: 2026-02-22