# Activation Checkpointing: A First Principles Deep Dive

Memory is often the bottleneck in deep learning. Not compute. Not data. Memory.

This notebook tears apart activation checkpointing from first principles. You will understand exactly why activations dominate memory, how checkpointing trades compute for memory, and when this trade-off makes sense.

---

## The Core Problem

During training you do two things:
1. **Forward pass**: compute outputs layer by layer
2. **Backward pass**: compute gradients using the chain rule

Here's what kills you: backward needs intermediate values from forward. If forward does `y = sin(x)`, backward needs `x` to compute `dy/dx = cos(x)`. PyTorch keeps `x` alive until backward reaches that node. These are "saved tensors" in autograd-speak.

So activation memory accumulates through forward and peaks at the start of backward. You've got this pile of "will need later" tensors sitting in VRAM, waiting.

## What Checkpointing Actually Does

Think of training like hiking a trail:

**Normal training**: You drop breadcrumbs at every step (store all activations). Easy to retrace (backprop), but you're carrying a giant breadcrumb bag (VRAM).

**Checkpointing**: You only drop breadcrumbs at a few checkpoints. On the way back, when you need details between checkpoints, you re-walk that segment.

That's literally it. Save fewer tensors in forward, recompute them on demand during backward. Memory goes down, compute goes up.

## Concrete Example

Suppose your forward is:

```
x --> [f1] --> a --> [f2] --> b --> [f3] --> y
```

**No checkpointing**: autograd saves `a` and `b` so backward can compute gradients.

**Checkpoint after `a`**: you keep `a`, but you don't keep `b`. During backward, when you need `b` to differentiate `f3`, PyTorch reruns `f2(a)` to get it back.

Memory goes down (you didn't store `b`). Compute goes up (you recomputed it).

## PyTorch API

```python
from torch.utils.checkpoint import checkpoint

out = checkpoint(fn, *args, use_reentrant=False)
```

What happens:
- **Forward**: runs `fn(*args)` but doesn't save intermediates for backward; only keeps `args`
- **Backward**: reruns `fn(*args)` to recreate intermediates, then backprops through

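Because `fn` has to be re-run from `args` alone, checkpointing fits most naturally at "block" granularity (Transformer blocks, ResNet blocks, etc.). Each block is a self-contained function of its input and has plenty of internal intermediates worth discarding.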
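Here is the API in miniature, using the `y = sin(x)` example from the top. This is a minimal sketch, and it saves no memory in this case: the tensor `sin` saves for backward is its own input, which checkpoint keeps anyway. What it shows is the mechanics. The gradient that flows through the checkpoint is still `cos(x)`.

```python
import torch
from torch.utils.checkpoint import checkpoint

x = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)

# Forward: checkpoint runs torch.sin without letting autograd keep sin's saved
# tensors; it holds on to the inputs it needs to re-run the function.
y = checkpoint(torch.sin, x, use_reentrant=False)

# Backward: sin is re-run to regenerate what it saved, then d(sin x)/dx = cos(x).
y.sum().backward()

print(torch.allclose(x.grad, torch.cos(x)))  # True
```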

---

Every code cell below has an Impressions/Conclusions section. These are not summaries. They are insights.

---
# 1. Introduction

Training deep neural networks hits a wall. That wall is GPU memory.

You have a model. You have data. You have compute. But your GPU runs out of memory. Why?

Four things consume GPU memory during training:
1. **Model parameters** (weights and biases)
2. **Gradients** (same size as parameters)
3. **Optimizer states** (for Adam: 2x parameter size for m and v buffers)
4. **Activations** (intermediate outputs from each layer)

For Adam/AdamW, the first three combined are 4x the model size (params + grads + the m and v moment buffers). This is a fixed cost: it does not change with batch size.

Here is the surprise: activations often dominate anyway. Not parameters. Not gradients. Not optimizer states. Activations. Why? Because activations scale with batch size (and, for sequence models, with sequence length). Everything else does not.

In [1]:
# 1.1 The Memory Breakdown: Where Does Memory Go?

import torch
import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def bytes_to_mb(b):
    return b / (1024 * 1024)

class DeepNetwork(nn.Module):
    def __init__(self, input_size=1024, hidden_size=1024, num_layers=10):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_f = input_size if i == 0 else hidden_size
            layers.append(nn.Linear(in_f, hidden_size))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)

model = DeepNetwork(input_size=1024, hidden_size=1024, num_layers=10)
num_params = count_parameters(model)

# THEORETICAL memory breakdown for a full training setup with Adam
# These are projections, not actual allocations in this code
param_mem = num_params * 4  # fp32 parameters
grad_mem = param_mem        # gradients (allocated during backward)
optimizer_mem = param_mem * 2  # Adam's m and v buffers (allocated when optimizer.step() is called)

print(f"Model: {num_params:,} parameters")
print()
print("THEORETICAL memory breakdown for full training with Adam:")
print("-" * 60)
print(f"Parameter memory:  {bytes_to_mb(param_mem):.2f} MB (always allocated)")
print(f"Gradient memory:   {bytes_to_mb(grad_mem):.2f} MB (allocated during backward)")
print(f"Optimizer memory:  {bytes_to_mb(optimizer_mem):.2f} MB (Adam: m + v, allocated on first step)")
print(f"Fixed cost total:  {bytes_to_mb(param_mem + grad_mem + optimizer_mem):.2f} MB")
print()
print("Activation memory by batch size (THEORETICAL):")
print("-" * 60)

fixed_cost = param_mem + grad_mem + optimizer_mem
for batch_size in [32, 64, 128, 256, 512]:
    # Activation estimate: batch_size * hidden_size * num_layers * 4 bytes
    # This is simplified - actual activation memory depends on what PyTorch saves
    act_mem = batch_size * 1024 * 10 * 4
    total = fixed_cost + act_mem
    act_pct = (act_mem / total) * 100
    print(f"Batch {batch_size:3d}: activations = {bytes_to_mb(act_mem):6.2f} MB ({act_pct:.1f}% of total)")

print()
print("NOTE: These are theoretical estimates. See Section 1.2 for actual GPU measurements.")

Model: 10,496,000 parameters

THEORETICAL memory breakdown for full training with Adam:
------------------------------------------------------------
Parameter memory:  40.04 MB (always allocated)
Gradient memory:   40.04 MB (allocated during backward)
Optimizer memory:  80.08 MB (Adam: m + v, allocated on first step)
Fixed cost total:  160.16 MB

Activation memory by batch size (THEORETICAL):
------------------------------------------------------------
Batch  32: activations =   1.25 MB (0.8% of total)
Batch  64: activations =   2.50 MB (1.5% of total)
Batch 128: activations =   5.00 MB (3.0% of total)
Batch 256: activations =  10.00 MB (5.9% of total)
Batch 512: activations =  20.00 MB (11.1% of total)

NOTE: These are theoretical estimates. See Section 1.2 for actual GPU measurements.


### Impressions/Conclusions (1.1)

This cell shows **theoretical** memory projections for a full training setup. No optimizer or gradients are actually allocated yet - we're just computing what the memory breakdown *would be*.

The theoretical fixed costs for Adam/AdamW:
- Parameters: 1x model size (always allocated)
- Gradients: 1x model size (allocated during backward pass)
- Optimizer states (m, v): 2x model size (allocated on first optimizer.step())
- Total fixed: 4x model size

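At small batches, fixed costs dominate. At large enough batches, activations take over. For this MLP that crossover is far away: even at batch 512 the estimate puts activations at only 11% of the total. Fixed cost is about 16 × depth × hidden² bytes and activations are about 4 × batch × depth × hidden bytes, so depth cancels and the crossover lands near batch ≈ 4 × hidden (about 4,100 here). Architectures that save many tensors per layer (attention, wide FFNs, long sequences) reach it much sooner.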

The math:
- Fixed cost: O(model_size)
- Activations: O(batch_size × depth × hidden_size)
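Where exactly is the crossover? The sketch below is minimal and assumes the same simplifications as cell 1.1: fp32, Adam, every layer hidden × hidden, and one stored activation per layer. `mlp_memory_bytes` and `crossover_batch` are hypothetical helpers for this notebook, not library functions. The search finds the smallest batch where the activation estimate beats the 4x fixed cost.

```python
def mlp_memory_bytes(batch, hidden, depth, bytes_per_el=4):
    """Hypothetical helper mirroring cell 1.1's simplified estimate for a
    hidden x hidden Linear+ReLU MLP trained with Adam in fp32."""
    params = depth * (hidden * hidden + hidden)          # weights + biases
    fixed = 4 * params * bytes_per_el                    # params + grads + Adam m + v
    activations = batch * hidden * depth * bytes_per_el  # one tensor per layer
    return fixed, activations

def crossover_batch(hidden, depth):
    """Smallest batch at which the activation estimate exceeds the fixed cost."""
    batch = 1
    while True:
        fixed, activations = mlp_memory_bytes(batch, hidden, depth)
        if activations > fixed:
            return batch
        batch += 1

for depth in (10, 40):
    print(f"hidden=1024, depth={depth}: crossover at batch {crossover_batch(1024, depth)}")
# both lines report batch 4101
```

Depth drops out because both terms scale linearly with it. Only the hidden size moves the crossover.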

Batch size is the multiplier that kills you. Checkpointing targets activations because that is where the scaling problem lives.

**Key distinction**: Section 1.2 shows *actual* GPU memory measurements during real forward/backward passes.

In [None]:
# 1.2 Real GPU Memory Measurement

import torch
import torch.nn as nn

class DeepNetwork(nn.Module):
    """Shallow but wide MLP. Wide layers make each activation large, but the
    weights grow as hidden², so parameters still outweigh activations here."""
    def __init__(self, input_size=1024, hidden_size=4096, num_layers=6):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_f = input_size if i == 0 else hidden_size
            layers.append(nn.Linear(in_f, hidden_size))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    
    model = DeepNetwork().cuda()
    num_params = sum(p.numel() for p in model.parameters())
    param_mem = num_params * 4 / 1e6
    print(f"Model parameters: {num_params:,} ({param_mem:.1f} MB)")
    print()
    
    # Use larger batch sizes to show activation scaling
    for batch_size in [64, 128, 256, 512]:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        
        x = torch.randn(batch_size, 1024).cuda()
        output = model(x)
        loss = output.sum()
        loss.backward()
        
        peak = torch.cuda.max_memory_allocated() / 1e6
        # Estimate activation memory: batch * hidden * num_layers * 4 bytes
        act_estimate = batch_size * 4096 * 6 * 4 / 1e6
        print(f"Batch {batch_size:3d}: Peak = {peak:.1f} MB  (est. activations: {act_estimate:.1f} MB)")
        
        model.zero_grad()
        del x, output, loss
else:
    print("[Run on GPU to see measurements]")

### Impressions/Conclusions (1.2)

Activation memory scales linearly with batch size. But whether you *see* this depends on the ratio of activations to fixed costs (parameters + gradients).

**The math:**
- Fixed cost: `num_params × 4 bytes × 2` (weights + gradients)
- Activation cost: `batch × hidden × depth × 4 bytes`

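This model has about 88M parameters, which is roughly 352 MB of fp32 weights and double that once gradients exist. The estimated activations at batch 512 are only ~50 MB. At these batch sizes fixed costs dwarf activations, so peak memory grows only modestly with batch size. Activations take over only at much larger batches, or in architectures that save many tensors per layer.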

This is exactly what activation checkpointing targets. At the batch sizes where you're actually memory-constrained, activations are the problem.

In [5]:
# 1.3 When Does Memory Peak? The Critical Insight

import torch
import torch.nn as nn

class InstrumentedNetwork(nn.Module):
    """Network that tracks memory at each layer."""
    def __init__(self, num_layers=5, hidden=4096):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        ])
    
    def forward(self, x, track=False):
        memory_log = []
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x))
            if track and torch.cuda.is_available():
                torch.cuda.synchronize()
                memory_log.append(torch.cuda.memory_allocated() / 1e6)
        return x, memory_log

if torch.cuda.is_available():
    # Use larger batch and hidden size to make activations visible
    model = InstrumentedNetwork(num_layers=8, hidden=4096).cuda()
    x = torch.randn(256, 4096).cuda()
    
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {num_params/1e6:.1f}M params")
    print()
    
    # Baseline: just model loaded
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    mem_model_only = torch.cuda.memory_allocated() / 1e6
    
    # Track memory during forward
    output, fwd_mem = model(x, track=True)
    mem_after_forward = torch.cuda.memory_allocated() / 1e6
    
    # Now trigger backward
    loss = output.sum()
    loss.backward()
    peak_memory = torch.cuda.max_memory_allocated() / 1e6
    mem_after_backward = torch.cuda.memory_allocated() / 1e6
    
    # After clearing gradients
    model.zero_grad(set_to_none=True)
    mem_after_zero_grad = torch.cuda.memory_allocated() / 1e6
    
    print("Memory Timeline:")
    print("-" * 55)
    print(f"{'Model loaded:':<35} {mem_model_only:>8.1f} MB")
    print()
    print("During forward (activations accumulating):")
    for i, mem in enumerate(fwd_mem):
        delta = mem - (fwd_mem[i-1] if i > 0 else mem_model_only)
        print(f"  After layer {i+1}:{'':<23} {mem:>8.1f} MB  (+{delta:.1f})")
    print()
    print(f"{'After forward (all activations live):':<35} {mem_after_forward:>8.1f} MB")
    print(f"{'PEAK (activations + gradients):':<35} {peak_memory:>8.1f} MB  <-- this is what kills you")
    print(f"{'After backward (gradients allocated):':<35} {mem_after_backward:>8.1f} MB")
    print(f"{'After zero_grad:':<35} {mem_after_zero_grad:>8.1f} MB")
    print()
    
    activation_mem = mem_after_forward - mem_model_only
    grad_mem = mem_after_backward - mem_after_zero_grad
    print(f"Activation memory: ~{activation_mem:.1f} MB")
    print(f"Gradient memory:   ~{grad_mem:.1f} MB")
else:
    print("[Run on GPU to see memory timeline]")

Model: 134.3M params

Memory Timeline:
-------------------------------------------------------
Model loaded:                          895.5 MB

During forward (activations accumulating):
  After layer 1:                           899.7 MB  (+4.2)
  After layer 2:                           903.9 MB  (+4.2)
  After layer 3:                           908.1 MB  (+4.2)
  After layer 4:                           912.3 MB  (+4.2)
  After layer 5:                           916.5 MB  (+4.2)
  After layer 6:                           920.7 MB  (+4.2)
  After layer 7:                           924.9 MB  (+4.2)
  After layer 8:                           929.1 MB  (+4.2)

After forward (all activations live):    928.5 MB
PEAK (activations + gradients):       1104.7 MB  <-- this is what kills you
After backward (gradients allocated):   1100.5 MB
After zero_grad:                       563.5 MB

Activation memory: ~33.0 MB
Gradient memory:   ~537.0 MB


### Impressions/Conclusions (1.3)

The memory timeline tells the story:

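1. **Model loaded**: Just weights (plus the input batch) in memory
2. **Forward pass**: Memory grows as activations accumulate. Each layer adds `batch × hidden × 4 bytes`, which here is 256 × 4096 × 4 ≈ 4.2 MB
3. **Peak**: Occurs during backward, when saved activations that are not yet freed coexist with gradients being allocated
4. **After backward**: Activations freed, but gradients remain (attached to parameters)
5. **After zero_grad**: With `set_to_none=True`, gradients are released. What remains is the weights plus any tensors still referenced (`x`, `output`)

Read the deltas, not the absolute values. The baseline (895.5 MB) is higher than the post-`zero_grad` figure (563.5 MB), which suggests that allocations left over from earlier cells were released partway through this cell. The deltas check out: ~537 MB of gradients matches 134.3M fp32 parameters, and ~33 MB of activations matches 8 layers × ~4.2 MB.

The peak is what causes OOM. At that moment, you're holding activations and gradients at the same time. Checkpointing attacks the activation share by not storing most activations during forward, so backward starts with checkpoints + gradients rather than the full activation pile. In this particular configuration that share is small (~33 MB next to ~537 MB of gradients), so checkpointing could save only a few percent of the peak. It pays off when activations are a large fraction of the peak, as with big batches or long sequences.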

---
# 2. Understanding Activation Checkpointing

Activation checkpointing is a memory-compute trade-off. Save memory by not storing activations. Pay for it by recomputing them during backprop.

To understand this, you need to know what activations are and why backprop needs them.

## 2.1 What Are Activations?

An activation is the output of a layer. That is it.

```
Input x -> [Layer 1] -> a1 -> [Layer 2] -> a2 -> [Layer 3] -> a3 -> Output
```

`a1`, `a2`, `a3` are activations. Each is a tensor stored in memory.

Why store them? Backpropagation needs them to compute gradients.

In [6]:
# 2.1 Why Backprop Needs Activations

import torch

# y = W2 * relu(W1 * x)
# Gradient w.r.t. W1 needs relu'(W1*x), i.e. the mask where W1*x > 0.
# (PyTorch's relu actually saves its output a1 and reads the same mask from a1 > 0.)

torch.manual_seed(42)

W1 = torch.randn(4, 4, requires_grad=True)
W2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

# Forward pass: the values backprop depends on
z1 = x @ W1.T           # Pre-activation (its sign pattern is what ReLU's grad needs)
a1 = torch.relu(z1)     # Activation (needed for W2 grad)
y = a1 @ W2.T

print("Stored for backprop:")
print(f"  z1: {z1.shape} - needed to compute relu gradient")
print(f"  a1: {a1.shape} - needed to compute W2 gradient")

loss = y.sum()
loss.backward()

print(f"\nGradients computed:")
print(f"  W1.grad: {W1.grad.shape}")
print(f"  W2.grad: {W2.grad.shape}")

Stored for backprop:
  z1: torch.Size([1, 4]) - needed to compute relu gradient
  a1: torch.Size([1, 4]) - needed to compute W2 gradient

Gradients computed:
  W1.grad: torch.Size([4, 4])
  W2.grad: torch.Size([4, 4])


### Impressions/Conclusions (2.1)

The chain rule requires intermediate values. Backprop walks backward through the computation graph. At each step, it needs forward pass values.

For linear `y = Wx`: gradient w.r.t. W needs x.

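For ReLU `y = relu(x)`: gradient needs to know where x > 0. PyTorch reads that mask from the saved output `y`, since y > 0 exactly where x > 0.

This is why activations are stored. This is what checkpointing trades away: it keeps only some of them and recomputes the rest.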
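What does autograd actually keep? The sketch below is minimal and uses `torch.autograd.graph.saved_tensors_hooks`, whose pack hook is called for every tensor saved for backward. It re-runs the forward above and records the storage address (`data_ptr()`) of each saved tensor. Matching by address is a rough identity check. It assumes that distinct live tensors don't share storage, which holds here.

```python
import torch

saved_ptrs = []

def pack(t):
    # Called once for every tensor autograd saves for backward.
    saved_ptrs.append(t.data_ptr())
    return t

def unpack(t):
    return t

torch.manual_seed(42)
W1 = torch.randn(4, 4, requires_grad=True)
W2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    z1 = x @ W1.T
    a1 = torch.relu(z1)
    y = a1 @ W2.T

y.sum().backward()

print("a1 saved:", a1.data_ptr() in saved_ptrs)  # relu saves its output...
print("z1 saved:", z1.data_ptr() in saved_ptrs)  # ...so its input is not needed
print("x saved: ", x.data_ptr() in saved_ptrs)   # matmul input, needed for W1.grad
```

It should print True, False, True. `z1` survives here only because the Python variable still references it. Inside a real model the pre-activation can be freed as soon as ReLU has run.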

In [7]:
# 2.2 The Core Trade-off: Store vs Recompute

import math

print("Standard Training:")
print("  Forward: x -> a1 -> a2 -> a3 -> loss")
print("           [store] [store] [store]")
print("  Backward: uses stored a1, a2, a3")
print("  Memory: O(n) activations")
print()
print("With Checkpointing:")
print("  Forward: x -> a1 -> a2 -> a3 -> loss")
print("           [save]       [save]")
print("  Backward: recompute a2 from a1, then use")
print("  Memory: O(sqrt(n)) with optimal placement")
print()

num_layers = 100
optimal = int(math.sqrt(num_layers))

# With sqrt(n) checkpoints, we have sqrt(n) segments of sqrt(n) layers each.
# Peak memory: sqrt(n) stored segment inputs, plus sqrt(n) activations that
# are live while one segment is being recomputed during backward = 2*sqrt(n).
# Recompute cost = sqrt(n) segments × sqrt(n) layers = n layer computations,
# i.e. one additional forward pass worth of compute.
peak = 2 * optimal

print(f"For {num_layers} layers:")
print(f"  Standard: {num_layers} activations stored")
print(f"  Checkpointed: {optimal} segment inputs + {optimal} live during recompute = {peak} at peak")
print(f"  Memory reduction: ~{num_layers/peak:.0f}x")
print()
print("Compute analysis (backward counted as the same cost as forward):")
print(f"  Without checkpointing: n forward + n backward = 2n")
print(f"  With checkpointing:    n forward + n backward + n recompute = 3n")
print(f"  Compute overhead: +50% vs. no checkpointing (one extra forward pass)")

Standard Training:
  Forward: x -> a1 -> a2 -> a3 -> loss
           [store] [store] [store]
  Backward: uses stored a1, a2, a3
  Memory: O(n) activations

With Checkpointing:
  Forward: x -> a1 -> a2 -> a3 -> loss
           [save]       [save]
  Backward: recompute a2 from a1, then use
  Memory: O(sqrt(n)) with optimal placement

For 100 layers:
  Standard: 100 activations stored
  Checkpointed: 10 segment inputs + 10 live during recompute = 20 at peak
  Memory reduction: ~5x

Compute analysis (backward counted as the same cost as forward):
  Without checkpointing: n forward + n backward = 2n
  With checkpointing:    n forward + n backward + n recompute = 3n
  Compute overhead: +50% vs. no checkpointing (one extra forward pass)


### Impressions/Conclusions (2.2)

The sqrt(n) rule: optimal checkpointing reduces memory from O(n) to O(sqrt(n)).

For 100 layers:
- Standard: 100 activations stored
- Checkpointed: ~10 segment inputs stored, plus ~10 activations live while one segment is recomputed, so ~20 at peak
- Memory reduction: ~5x. The ratio keeps improving as n grows, since n / (2·sqrt(n)) = sqrt(n)/2

The compute cost is real:
- Without checkpointing: Forward (n) + Backward (n) = 2n operations
- With checkpointing: Forward (n) + Backward (n) + Recompute (n) = 3n operations
- **Overhead: one extra forward pass.** Under the equal-cost model above that is +50%. Backward typically costs about twice a forward pass, so in practice the extra time is closer to +33%

This is still a good trade-off: a 5x memory reduction for roughly one extra forward pass. When memory is the bottleneck, this lets you train models (or batch sizes) that otherwise wouldn't fit.
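The sqrt(n) optimum falls out of a simple count. The sketch below is minimal and assumes an idealized model: every segment stores its input, backward recomputes one segment at a time, and every layer is re-run exactly once. `segment_costs` is a hypothetical helper, not a PyTorch API.

```python
import math

def segment_costs(n_layers, seg_size):
    """Idealized cost model (hypothetical helper, not a PyTorch API).

    Every segment's input is stored; during backward one segment at a time is
    recomputed, so its seg_size activations are live on top of the stored inputs.
    Every layer is re-run exactly once during recomputation.
    """
    n_segments = math.ceil(n_layers / seg_size)
    peak_activations = n_segments + seg_size
    extra_layer_forwards = n_layers
    return peak_activations, extra_layer_forwards

n = 100
for seg_size in [1, 5, 10, 20, 100]:
    peak, extra = segment_costs(n, seg_size)
    print(f"segment size {seg_size:3d}: peak ~{peak:3d} activations, {extra} extra layer forwards")

best = min(range(1, n + 1), key=lambda s: segment_costs(n, s)[0])
print(f"best segment size for n={n}: {best}")  # 10
```

Both extremes are equally bad. Segment size 1 (per-layer checkpointing) stores all 100 inputs. Segment size 100 (one big segment) rebuilds all 100 activations at once during backward. Under this model, recompute cost does not depend on segment size, so the only thing to tune is where the peak bottoms out.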

---
# 3. Prerequisites

Environment setup and helper functions.

In [8]:
# 3.1 Environment Check

import torch
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch Version: 2.9.0+cu126
CUDA Available: True
CUDA Version: 12.6
GPU: Tesla T4
GPU Memory: 15.83 GB


### Impressions/Conclusions (3.1)

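`torch.utils.checkpoint` provides the core API. The code in this notebook passes `use_reentrant=False`, which needs a reasonably recent PyTorch (the argument arrived in the 1.x series, around 1.11), and recent 2.x releases warn if you omit it. CUDA is strongly recommended, since checkpointing only pays off when GPU memory is the bottleneck.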

In [9]:
# 3.2 Memory Profiling Utilities

import torch
from contextlib import contextmanager

class MemoryTracker:
    """Track GPU memory usage."""
    
    def __init__(self):
        self.snapshots = []
    
    def snapshot(self, label=""):
        if torch.cuda.is_available():
            mem = torch.cuda.memory_allocated() / 1e6
            self.snapshots.append({'label': label, 'mb': mem})
            return mem
        return 0
    
    def reset(self):
        self.snapshots = []
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
    
    def peak_mb(self):
        if torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / 1e6
        return 0
    
    def report(self):
        for s in self.snapshots:
            print(f"{s['label']:30s}: {s['mb']:8.2f} MB")
        print(f"{'Peak':30s}: {self.peak_mb():8.2f} MB")

@contextmanager
def track_memory(label="Op"):
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    yield
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        print(f"{label}: Peak = {torch.cuda.max_memory_allocated()/1e6:.2f} MB")

print("Utilities loaded: MemoryTracker, track_memory()")

Utilities loaded: MemoryTracker, track_memory()


### Impressions/Conclusions (3.2)

Key functions:
- `torch.cuda.memory_allocated()`: current memory in use
- `torch.cuda.max_memory_allocated()`: peak since last reset
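- `torch.cuda.empty_cache()`: releases unused cached blocks held by PyTorch's caching allocator. It does not free tensors that are still referenced, so `memory_allocated()` is unchanged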

Peak memory is what matters. That is what causes OOM.

---
# 4. How Activation Checkpointing Works

PyTorch provides `torch.utils.checkpoint`. Two main functions:
- `checkpoint`: wrap a single function/module
- `checkpoint_sequential`: wrap a sequence of modules

The API is simple. The magic is in what happens under the hood.

In [10]:
# 4.1 Model WITHOUT Checkpointing

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return x

if torch.cuda.is_available():
    model = SimpleModel().cuda()
    x = torch.randn(1, 1024).cuda()

    torch.cuda.reset_peak_memory_stats()
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    print(f"Peak Memory WITHOUT Checkpointing: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")
else:
    print("[Run on GPU]")

Peak Memory WITHOUT Checkpointing: 567.70 MB


In [11]:
# 4.2 Model WITH Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)

    def forward(self, x):
        # NOTE: with use_reentrant=False the input does not need requires_grad=True.
        # The reentrant variant does need it, so checkpointing an *intermediate*
        # block (whose input comes from earlier layers) is the portable habit.
        x = torch.relu(self.layer1(x))

        # Checkpoint layer2: its internal activations won't be stored.
        # They will be recomputed during backward.
        x = checkpoint(lambda t: torch.relu(self.layer2(t)), x, use_reentrant=False)
        return x

if torch.cuda.is_available():
    model = CheckpointedModel().cuda()
    x = torch.randn(1, 1024).cuda()

    torch.cuda.reset_peak_memory_stats()
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    print(f"Peak Memory WITH Checkpointing: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")
else:
    print("[Run on GPU]")

Peak Memory WITH Checkpointing: 43.30 MB


### Impressions/Conclusions (4.1 & 4.2)

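Do not read the printed numbers as the effect of checkpointing. `max_memory_allocated()` counts every live tensor on the GPU, including leftovers from earlier cells. For example, an old `output` whose autograd graph still references an earlier model's parameters keeps those weights alive. The 567.70 MB vs 43.30 MB gap therefore mostly reflects notebook state. The two-layer model itself holds only ~8.4 MB of fp32 weights, and at batch 1 each activation is 1024 floats (4 KB), so there is almost nothing for checkpointing to save.

The real benefit appears at scale. Let us test with a deeper model.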
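A fairer measurement subtracts whatever was already allocated before the step. The sketch below is minimal and assumes a hypothetical `peak_above_baseline_mb` helper, built only from the `torch.cuda` calls used above. It returns `None` on machines without a GPU.

```python
import torch
import torch.nn as nn

def peak_above_baseline_mb(step_fn):
    """Hypothetical helper: peak CUDA memory step_fn() adds on top of what was
    already allocated when it started, in MB. Returns None without a GPU."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()
    torch.cuda.reset_peak_memory_stats()  # peak tracking restarts from the current allocation
    step_fn()
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / 1e6

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(),
                      nn.Linear(1024, 1024), nn.ReLU()).to(device)
x = torch.randn(1, 1024, device=device)

model.zero_grad(set_to_none=True)  # gradients get allocated inside the measured window
result = peak_above_baseline_mb(lambda: model(x).sum().backward())
print(result)  # None on CPU-only machines
```

Zeroing gradients with `set_to_none=True` before measuring matters. Otherwise gradients from an earlier step sit in the baseline and the step's own gradient cost disappears from the number.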

In [None]:
# 4.3 Deep Model Comparison: The Real Difference

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DeepModelNoCheckpoint(nn.Module):
    def __init__(self, num_layers=20, hidden=2048):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

class DeepModelWithCheckpoint(nn.Module):
    """Checkpoints SEGMENTS of layers, not individual layers.
    
    Per-layer checkpointing doesn't help here: a Linear+ReLU layer stores about
    one activation (its input), and checkpoint() stores that same input.
    
    Segment checkpointing: group N layers, only save input to each segment.
    Memory: O(num_segments + segment_size) instead of O(num_layers), since one
    segment's activations are live again while it is recomputed.
    """
    def __init__(self, num_layers=20, hidden=2048, segment_size=5):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        ])
        self.segment_size = segment_size
    
    def _run_segment(self, x, start_idx, end_idx):
        """Run a segment of layers."""
        for i in range(start_idx, end_idx):
            x = torch.relu(self.layers[i](x))
        return x
    
    def forward(self, x):
        num_layers = len(self.layers)
        for start in range(0, num_layers, self.segment_size):
            end = min(start + self.segment_size, num_layers)
            # Checkpoint each segment: only the segment INPUT is saved
            x = checkpoint(
                self._run_segment, x, start, end,
                use_reentrant=False
            )
        return x

if torch.cuda.is_available():
    batch_size = 128
    hidden = 2048
    num_layers = 20
    
    # Without checkpointing
    torch.cuda.empty_cache()
    model1 = DeepModelNoCheckpoint(num_layers, hidden).cuda()
    x1 = torch.randn(batch_size, hidden).cuda()
    
    torch.cuda.reset_peak_memory_stats()
    out1 = model1(x1)
    out1.sum().backward()
    mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e6
    
    del model1, x1, out1
    torch.cuda.empty_cache()
    
    # With segment checkpointing (4 segments of 5 layers each)
    model2 = DeepModelWithCheckpoint(num_layers, hidden, segment_size=5).cuda()
    x2 = torch.randn(batch_size, hidden).cuda()
    
    torch.cuda.reset_peak_memory_stats()
    out2 = model2(x2)
    out2.sum().backward()
    mem_ckpt = torch.cuda.max_memory_allocated() / 1e6
    
    print(f"Config: {num_layers} layers, hidden={hidden}, batch={batch_size}")
    print(f"Segment size: 5 layers (4 checkpoints total)")
    print()
    print(f"WITHOUT checkpointing: {mem_no_ckpt:.2f} MB")
    print(f"WITH checkpointing:    {mem_ckpt:.2f} MB")
    print(f"Memory saved: {mem_no_ckpt - mem_ckpt:.2f} MB ({(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%)")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (4.3)

**Critical insight: checkpoint granularity matters.**

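Per-layer checkpointing (wrapping each layer individually) doesn't save memory for layers like these. Why? Each `checkpoint()` call saves its input. If you checkpoint layer 2, you save the output of layer 1. Checkpoint layer 3, save output of layer 2. A Linear+ReLU layer only stores about that one tensor anyway, so you're still storing the same activations. Per-block checkpointing does pay off when each block has many internal saved tensors, as a transformer block does.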

**Segment checkpointing** is the fix: group N layers together, checkpoint the group. Now you only save the input to each segment, not every layer's output.

With 20 layers split into 4 segments of 5:
- Without checkpointing: store 20 activations
- With checkpointing: store 4 activations (segment inputs), plus up to 5 activations live while one segment is recomputed during backward

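The trade-off: fewer stored activations, but each segment's layers run twice (once in forward, once in backward). Expect only a modest percentage in this particular run, though. The 20 layers' fp32 weights alone are ~336 MB (doubled by gradients), while one activation at batch 128 is only ~1 MB, so roughly 21 MB across all 20 layers.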
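You can watch the recomputation happen. The sketch below is minimal and assumes a hypothetical `CountingBlock` (Linear + ReLU that counts its forward calls) and a hypothetical `run` helper that mirrors `DeepModelWithCheckpoint`'s segment loop. When checkpointed, every block's forward should run twice: once in forward and once during backward. The weight gradients should match the plain run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """Hypothetical Linear+ReLU block that counts its forward() calls."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x):
        self.calls += 1  # increment first, so an early-stopped recompute still counts
        return torch.relu(self.linear(x))

def run(blocks, x, segment_size=None):
    """Run blocks in order; if segment_size is set, checkpoint each segment."""
    if segment_size is None:
        for b in blocks:
            x = b(x)
        return x

    def segment(t, start, end):
        for b in blocks[start:end]:
            t = b(t)
        return t

    for start in range(0, len(blocks), segment_size):
        x = checkpoint(segment, x, start, start + segment_size, use_reentrant=False)
    return x

torch.manual_seed(0)
blocks = nn.ModuleList(CountingBlock(8) for _ in range(6))
x = torch.randn(4, 8)

# Reference: no checkpointing
run(blocks, x).sum().backward()
ref_grads = [b.linear.weight.grad.clone() for b in blocks]
blocks.zero_grad(set_to_none=True)
for b in blocks:
    b.calls = 0

# Two checkpointed segments of 3 blocks each
run(blocks, x, segment_size=3).sum().backward()
print([b.calls for b in blocks])  # [2, 2, 2, 2, 2, 2]
grads_match = all(torch.allclose(g, b.linear.weight.grad) for g, b in zip(ref_grads, blocks))
print(grads_match)  # True
```

The counter is incremented on entry for a reason. Recent PyTorch versions may stop a non-reentrant recompute early, as soon as every tensor backward needs has been regenerated, so a counter at the end of `forward` could undercount.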

---
# 5. Code Examples

Three patterns you will use:
1. Basic `checkpoint()` for single modules
2. `checkpoint_sequential()` for sequential models
3. Manual checkpointing in complex architectures (transformers)

In [None]:
# 5.1 Basic Usage

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())

    def forward(self, x):
        # Run block1 normally, then checkpoint block2. With use_reentrant=False,
        # checkpointing block1 on the raw input would also work; the reentrant
        # variant needs an input that requires grad.
        x = self.block1(x)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return x

model = SimpleModel()
if torch.cuda.is_available():
    model = model.cuda()

x = torch.randn(1, 1024)
if torch.cuda.is_available():
    x = x.cuda()

output = model(x)
print(f"Output shape: {output.shape}")
print("Checkpointing applied to block2. Block1 runs normally.")

### Impressions/Conclusions (5.1)

`checkpoint(fn, *args)` wraps any callable. During forward, it runs `fn(*args)` but does not save intermediate activations. During backward, it re-runs `fn(*args)` to recompute them.

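Note: pass `use_reentrant=False` explicitly. It is the variant the PyTorch docs recommend. Unlike the reentrant version, it works when no input requires grad, supports keyword arguments to `fn`, and works with `torch.autograd.grad`. Recent PyTorch 2.x releases warn if `use_reentrant` is not passed.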
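The claim about inputs that don't require grad can be checked directly. In the minimal sketch below, the checkpointed block gets a raw input with `requires_grad=False`, yet its parameters still receive gradients through the checkpoint. According to the PyTorch docs, the reentrant variant needs at least one input that requires grad for this to work.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
x = torch.randn(2, 8)  # raw input: requires_grad is False

out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(out.requires_grad)                 # True: the block's parameters need grads
print(block[0].weight.grad is not None)  # True: gradients reached the parameters
```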

In [None]:
# 5.2 Using checkpoint_sequential

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep sequential model
model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
)

x = torch.randn(1, 1024)
if torch.cuda.is_available():
    model = model.cuda()
    x = x.cuda()

# Split into 2 checkpoint segments
# Each segment will be checkpointed separately
segments = 2
output = checkpoint_sequential(model, segments, x, use_reentrant=False)

print(f"Output shape: {output.shape}")
print(f"Model split into {segments} checkpoint segments")
print("Each segment recomputes activations during backward pass")

### Impressions/Conclusions (5.2)

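`checkpoint_sequential` is convenient for nn.Sequential models. It splits the container's modules into `segments` chunks and checkpoints all but the last. The `segments` parameter trades stored boundaries against the size of the recompute window:
- More segments = more stored segment inputs, but fewer activations live while a segment is recomputed
- Fewer segments = fewer stored inputs, but a larger recompute window

Recomputation stays roughly one extra forward pass over the checkpointed modules either way. Rule of thumb: start with sqrt(num_layers) segments. Note that segments count modules, not layers. The model above has 8 modules (4 Linear + 4 ReLU), so `segments=2` gives two chunks of 4.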
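To see which modules actually get re-run, the minimal sketch below counts forward calls per module. It assumes a hypothetical `CountingLinear`, an `nn.Linear` subclass that counts its calls. With 4 modules and `segments=2`, PyTorch's implementation checkpoints the first segment and runs the last one normally. Backward reaches the last segment first, so recomputing it would save little. The expected counts are therefore `[2, 2, 1, 1]`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class CountingLinear(nn.Linear):
    """Hypothetical helper: nn.Linear that counts how often forward() runs."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, x):
        self.calls += 1  # count on entry, before any op runs
        return super().forward(x)

torch.manual_seed(0)
layers = [CountingLinear(16, 16) for _ in range(4)]
model = nn.Sequential(*layers)
x = torch.randn(2, 16)

out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()

print([layer.calls for layer in layers])  # [2, 2, 1, 1]
```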

In [None]:
# 5.3 Transformer Training with Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size)
        )
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, x):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward with residual (checkpointed)
        ff_out = checkpoint(self.feed_forward, x, use_reentrant=False)
        x = self.norm2(x + ff_out)
        return x

# Example usage
embed_size = 512
seq_length = 64
batch_size = 8

block = TransformerBlock(embed_size)
if torch.cuda.is_available():
    block = block.cuda()

x = torch.randn(batch_size, seq_length, embed_size)
if torch.cuda.is_available():
    x = x.cuda()

output = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Feed-forward block is checkpointed (it uses the most memory)")

### Impressions/Conclusions (5.3)

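In transformers, the feed-forward block expands to 4x the hidden size, so its intermediate tensors are among the largest activations. At this sequence length (64) they are the largest. That makes the FFN a reasonable single target, but it is not the only common choice. Many training setups (Hugging Face's `gradient_checkpointing_enable()`, for example) checkpoint each whole transformer block instead.

Why not checkpoint attention? You can, and for long sequences it can be the better target. The attention score matrix grows with seq_len², yet it is relatively cheap to recompute. Megatron-LM's "selective activation recomputation" targets exactly this part of attention. Which sub-block pays off depends on sequence length and model width.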
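Checkpointing whole blocks is a small change in the stack's loop. The sketch below is minimal and assumes a hypothetical `CheckpointedStack` wrapper around PyTorch's built-in `nn.TransformerEncoderLayer`, which stands in for the `TransformerBlock` above. It checks that the gradients match an identical copy that runs without checkpointing.

```python
import copy
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    """Hypothetical wrapper: optionally checkpoint each whole block in a stack."""
    def __init__(self, blocks, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            # Only checkpoint in training mode: inference has no backward to feed.
            if self.use_checkpoint and self.training:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
# dropout=0.0 keeps the comparison simple (checkpoint also restores RNG state by default).
blocks = [nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256,
                                     dropout=0.0, batch_first=True) for _ in range(4)]
ckpt_model = CheckpointedStack(copy.deepcopy(blocks), use_checkpoint=True)
plain_model = CheckpointedStack(copy.deepcopy(blocks), use_checkpoint=False)

x = torch.randn(2, 16, 64)
# Squared loss, so gradients flow back through each block's final LayerNorm.
ckpt_model(x).pow(2).sum().backward()
plain_model(x).pow(2).sum().backward()

grads_match = all(
    torch.allclose(p.grad, q.grad, atol=1e-6)
    for p, q in zip(ckpt_model.parameters(), plain_model.parameters())
)
print(grads_match)  # True
```

Checkpointing at the block boundary stores one `(batch, seq, d_model)` tensor per block. Attention scores and FFN intermediates are rebuilt during backward.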

In [None]:
# 5.4 ResNet with Selective Checkpointing (Bonus)

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class BasicBlock(nn.Module):
    """ResNet basic block with optional checkpointing."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return torch.relu(out)

class MiniResNet(nn.Module):
    """Small ResNet with checkpointing options."""
    def __init__(self, num_blocks=4, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.blocks = nn.ModuleList([
            BasicBlock(64 if i == 0 else 128, 128, stride=2 if i == 0 else 1)
            for i in range(num_blocks)
        ])
        
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Compare memory
if torch.cuda.is_available():
    batch_size = 32
    
    # Without checkpoint
    torch.cuda.empty_cache()
    model1 = MiniResNet(num_blocks=8, use_checkpoint=False).cuda()
    x1 = torch.randn(batch_size, 3, 32, 32).cuda()
    torch.cuda.reset_peak_memory_stats()
    out1 = model1(x1)
    out1.sum().backward()
    mem1 = torch.cuda.max_memory_allocated() / 1e6
    
    del model1, x1, out1
    torch.cuda.empty_cache()
    
    # With checkpoint
    model2 = MiniResNet(num_blocks=8, use_checkpoint=True).cuda()
    x2 = torch.randn(batch_size, 3, 32, 32).cuda()
    torch.cuda.reset_peak_memory_stats()
    out2 = model2(x2)
    out2.sum().backward()
    mem2 = torch.cuda.max_memory_allocated() / 1e6
    
    print(f"ResNet with 8 blocks, batch={batch_size}")
    print(f"WITHOUT checkpoint: {mem1:.2f} MB")
    print(f"WITH checkpoint:    {mem2:.2f} MB")
    print(f"Savings: {(1 - mem2/mem1)*100:.1f}%")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (5.4)

CNNs benefit from checkpointing too. Each residual block stores feature maps that can be recomputed. For deep ResNets trained at high resolution or with large batches, checkpointing can be what makes a useful batch size fit.

Key pattern: wrap entire blocks, not individual layers. Each checkpoint still keeps its input tensor and each call adds overhead, so very small segments save little.
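One CNN-specific caveat: a BatchNorm layer inside a checkpointed region runs its forward twice in training mode (once in the forward pass, once during recomputation), so its running statistics and `num_batches_tracked` counter are updated twice per step. A minimal sketch that makes this visible on a standalone `nn.BatchNorm1d` (sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

bn = nn.BatchNorm1d(8)  # training mode by default, so it updates running stats
x = torch.randn(16, 8, requires_grad=True)

y = checkpoint(bn, x, use_reentrant=False)
after_forward = int(bn.num_batches_tracked)

# Backward needs BatchNorm's saved tensors, so it re-runs bn.forward in training mode.
y.sum().backward()
after_backward = int(bn.num_batches_tracked)

print("num_batches_tracked after forward: ", after_forward)
print("num_batches_tracked after backward:", after_backward)
```

The batch statistics used for the output are the same both times, so gradients are unaffected; what drifts is the running mean/variance used at eval time. If that matters, keep BatchNorm outside the checkpointed region or account for the extra update.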

---
# 6. Performance Benchmarks

Numbers matter. Let us measure:
1. Memory savings across model sizes
2. Compute overhead (training time)
3. Batch size scaling

In [None]:
# 6.1 Memory Savings Across Model Depths

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import time

def create_model(num_layers, hidden, use_checkpoint=False):
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(hidden, hidden) for _ in range(num_layers)
            ])
            self.use_ckpt = use_checkpoint
        
        def forward(self, x):
            # Skip checkpointing the first layer: its input x is the user-provided tensor,
            # which the caller keeps alive anyway, so wrapping that layer saves nothing.
            # (With use_reentrant=True it is also risky: if no input requires grad,
            # parameter gradients for that segment are not computed.)
            for i, layer in enumerate(self.layers):
                if self.use_ckpt and i > 0:
                    x = checkpoint(lambda t, l=layer: torch.relu(l(t)), x, use_reentrant=False)
                else:
                    x = torch.relu(layer(x))
            return x
    return Model()

if torch.cuda.is_available():
    hidden = 1024
    batch = 64
    
    print(f"Config: hidden={hidden}, batch={batch}")
    print("-" * 60)
    print(f"{'Layers':<10} {'No Ckpt (MB)':<15} {'With Ckpt (MB)':<15} {'Savings':<10}")
    print("-" * 60)
    
    for num_layers in [10, 20, 40, 80]:
        # Without checkpoint
        torch.cuda.empty_cache()
        m1 = create_model(num_layers, hidden, False).cuda()
        x1 = torch.randn(batch, hidden).cuda()
        torch.cuda.reset_peak_memory_stats()
        m1(x1).sum().backward()
        mem1 = torch.cuda.max_memory_allocated() / 1e6
        del m1, x1
        
        # With checkpoint
        torch.cuda.empty_cache()
        m2 = create_model(num_layers, hidden, True).cuda()
        x2 = torch.randn(batch, hidden).cuda()
        torch.cuda.reset_peak_memory_stats()
        m2(x2).sum().backward()
        mem2 = torch.cuda.max_memory_allocated() / 1e6
        del m2, x2
        
        savings = (1 - mem2/mem1) * 100
        print(f"{num_layers:<10} {mem1:<15.2f} {mem2:<15.2f} {savings:.1f}%")
else:
    print("[Run on GPU for benchmarks]")

### Impressions/Conclusions (6.1)

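How much checkpointing saves depends on what dominates peak memory. In this model each layer's weights are 1024 x 1024 floats (about 4.2 MB), while each activation tensor is 64 x 1024 floats (about 0.26 MB), so parameters and gradients likely dominate and the measured savings may be small at every depth. The per-layer wrapping matters too: for `relu(linear(x))`, autograd already keeps about one tensor per layer (the ReLU output, which doubles as the next layer's input), and a per-layer checkpoint keeps its input, so the count barely changes.

In theory, activation memory grows as O(n) with depth. Splitting n layers into k checkpoint segments keeps k segment inputs plus one segment's internals during recompute, roughly O(k + n/k); choosing k ≈ sqrt(n) gives O(sqrt(n)) memory for about one extra forward pass.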
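Grouping layers into multi-layer segments is what `torch.utils.checkpoint.checkpoint_sequential` does for an `nn.Sequential`: it splits the modules into the requested number of segments and checkpoints all but the last one. The sketch below uses sqrt(n) segments on a toy Linear+ReLU stack (sizes are arbitrary assumptions) and checks that outputs and input gradients match an un-checkpointed run.

```python
import math

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
num_layers, hidden = 16, 64
layers = []
for _ in range(num_layers):
    layers += [nn.Linear(hidden, hidden), nn.ReLU()]
model = nn.Sequential(*layers)

# About sqrt(n) segments: 4 segments, each covering 4 Linear+ReLU pairs.
segments = int(math.sqrt(num_layers))
x0 = torch.randn(8, hidden)


def run(use_ckpt):
    x = x0.clone().requires_grad_(True)
    if use_ckpt:
        out = checkpoint_sequential(model, segments, x, use_reentrant=False)
    else:
        out = model(x)
    out.sum().backward()
    return out.detach(), x.grad


out_ref, g_ref = run(use_ckpt=False)
out_ckpt, g_ckpt = run(use_ckpt=True)
print(torch.allclose(out_ref, out_ckpt), torch.allclose(g_ref, g_ckpt))
```

Each segment now keeps only one input for four layers' worth of work, which is where the sublinear memory comes from.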

In [None]:
# 6.2 Batch Size Scaling: The Real Win

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def create_model(num_layers, hidden, use_checkpoint=False):
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(hidden, hidden) for _ in range(num_layers)
            ])
            self.use_ckpt = use_checkpoint
        
        def forward(self, x):
            # Avoid checkpointing the first layer (see Section 6.1 note).
            for i, layer in enumerate(self.layers):
                if self.use_ckpt and i > 0:
                    x = checkpoint(lambda t, l=layer: torch.relu(l(t)), x, use_reentrant=False)
                else:
                    x = torch.relu(layer(x))
            return x
    return Model()

def find_max_batch(model_fn, hidden, start=32, max_batch=2048):
    """Find maximum batch size before OOM (CUDA only)."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for find_max_batch().")

    start = max(1, int(start))
    max_batch = max(start, int(max_batch))

    def can_run(batch_size):
        model = None
        x = None
        out = None
        try:
            torch.cuda.empty_cache()
            model = model_fn().cuda()
            model.train()
            x = torch.randn(batch_size, hidden, device="cuda")
            torch.cuda.reset_peak_memory_stats()
            out = model(x)
            out.sum().backward()
            return True
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                return False
            raise
        finally:
            # Best-effort cleanup after OOMs
            try:
                if model is not None:
                    model.zero_grad(set_to_none=True)
            except Exception:
                pass
            del model, x, out
            torch.cuda.empty_cache()

    # Exponential search to find an upper bound, then binary search.
    max_working = 0
    batch = start
    while batch <= max_batch and can_run(batch):
        max_working = batch
        batch *= 2

    low = max_working + 1
    high = min(batch, max_batch)

    while low <= high:
        mid = (low + high) // 2
        if can_run(mid):
            max_working = mid
            low = mid + 1
        else:
            high = mid - 1

    return max_working

if torch.cuda.is_available():
    hidden = 2048
    num_layers = 30
    
    print(f"Finding maximum batch size for {num_layers}-layer model, hidden={hidden}")
    print("This demonstrates the practical benefit of checkpointing.\n")
    
    # Max batch without checkpointing
    model_no_ckpt = lambda: create_model(num_layers, hidden, False)
    max_batch_no_ckpt = find_max_batch(model_no_ckpt, hidden)
    
    # Max batch with checkpointing
    model_ckpt = lambda: create_model(num_layers, hidden, True)
    max_batch_ckpt = find_max_batch(model_ckpt, hidden)

    print(f"WITHOUT checkpointing: max batch = {max_batch_no_ckpt}")
    print(f"WITH checkpointing:    max batch = {max_batch_ckpt}")

    if max_batch_no_ckpt == 0:
        print("Baseline OOM even at batch size 1. Reduce hidden size / layers.")
    else:
        improvement = max_batch_ckpt / max_batch_no_ckpt
        print(f"Improvement: {improvement:.1f}x larger batches possible")
else:
    print("[Run on GPU to find max batch sizes]")

In [None]:
# 6.3 Compute Overhead Measurement

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import time

def benchmark_time(model, x, num_iters=10):
    # Warmup
    for _ in range(3):
        model(x).sum().backward()
        model.zero_grad()
    
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        model(x).sum().backward()
        model.zero_grad()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters * 1000  # ms

if torch.cuda.is_available():
    hidden = 1024
    batch = 64
    num_layers = 40
    
    # Create models
    m1 = create_model(num_layers, hidden, False).cuda()
    m2 = create_model(num_layers, hidden, True).cuda()
    x = torch.randn(batch, hidden).cuda()
    
    time1 = benchmark_time(m1, x)
    time2 = benchmark_time(m2, x)
    
    overhead = (time2 - time1) / time1 * 100
    
    print(f"Config: {num_layers} layers, hidden={hidden}, batch={batch}")
    print(f"WITHOUT checkpoint: {time1:.2f} ms/iter")
    print(f"WITH checkpoint:    {time2:.2f} ms/iter")
    print(f"Overhead: {overhead:.1f}%")
else:
    print("[Run on GPU for benchmarks]")

### Impressions/Conclusions (6.2)

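This is the practical payoff: when activations dominate peak memory, checkpointing can make noticeably larger batches fit. Two caveats for the cell above: `find_max_batch` stops at `max_batch=2048`, so on a large GPU both variants may simply report the cap; and because each checkpoint wraps a single Linear+ReLU, the number of tensors kept per layer barely changes, so this toy model may show a smaller gain than a real network.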

Larger batches mean:
- Better gradient estimates (less noise)
- Higher GPU utilization
- Faster convergence (sometimes)

Part of the recomputation cost can be won back through better hardware utilization at larger batch sizes, but that is workload-dependent. Compare time to reach a target loss, not just time per step.

### Impressions/Conclusions (6.3)

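Full checkpointing re-runs roughly one extra forward pass during backprop. If backward costs about twice as much as forward, a step goes from about 3 forward-equivalents to about 4, i.e. roughly 33% more compute. Measured overhead varies with architecture and kernel efficiency, so trust the number printed above for your setup.

The trade-off: extra compute in exchange for lower peak activation memory. It is worth it when memory, not compute, is the bottleneck.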
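The overhead estimate can be written as a tiny function. It is a rough model under stated assumptions, not a measurement: a baseline step costs one forward plus `backward_over_forward` forwards' worth of backward (2.0 is a common rule of thumb), and checkpointing re-runs `recompute_fraction` of one forward pass. The function name and parameters are illustrative.

```python
def expected_overhead(recompute_fraction=1.0, backward_over_forward=2.0):
    """Extra compute per step from recomputation, relative to a step without checkpointing.

    Assumes the baseline step costs 1 forward + `backward_over_forward` forwards,
    and checkpointing re-runs `recompute_fraction` of one forward pass.
    """
    baseline = 1.0 + backward_over_forward
    return recompute_fraction / baseline


print(f"Full recompute: {expected_overhead():.0%}")     # 33%
print(f"Recompute half: {expected_overhead(0.5):.0%}")  # 17%
```

Recomputing less of the forward (as selective checkpointing does) shrinks the overhead proportionally under this model.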

---
# 7. Debugging and Best Practices

Checkpointing has gotchas. Here are the common ones and how to avoid them.

In [None]:
# 7.1 Pitfall: Non-Differentiable Operations

import torch
from torch.utils.checkpoint import checkpoint

# BAD: torch.no_grad() inside checkpointed function
def bad_forward(x):
    with torch.no_grad():  # This breaks gradient computation!
        x = x ** 2
    return x

# GOOD: Keep everything differentiable
def good_forward(x):
    x = x ** 2  # No torch.no_grad()
    return x

x = torch.randn(4, requires_grad=True)

# This will fail or give wrong gradients
try:
    y_bad = checkpoint(bad_forward, x, use_reentrant=False)
    y_bad.sum().backward()
    print("BAD: Gradients might be zero or wrong")
except Exception as e:
    print(f"BAD: Error - {e}")

# This works
x = torch.randn(4, requires_grad=True)
y_good = checkpoint(good_forward, x, use_reentrant=False)
y_good.sum().backward()
print(f"GOOD: x.grad = {x.grad}")

### Impressions/Conclusions (7.1)

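Rule: any computation inside a checkpointed function that gradients must flow through has to stay in autograd: no `torch.no_grad()`, no detaching tensors, no in-place operations that overwrite tensors autograd needs. The function should also produce the same result when it is re-run during backward.

If you need work that should not be tracked (computing a mask or statistics under `torch.no_grad()`, for example), do it outside the checkpointed region and pass the result in as an input.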
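A minimal sketch of that pattern: the mask is computed under `torch.no_grad()` outside the checkpoint and passed in as a constant input, while the checkpointed function itself stays fully differentiable. The function name `masked_square` is illustrative.

```python
import torch
from torch.utils.checkpoint import checkpoint


def masked_square(x, mask):
    # Differentiable w.r.t. x; mask is treated as a constant input.
    return (x * mask) ** 2


x = torch.randn(4, requires_grad=True)

# Untracked work happens outside the checkpointed region.
with torch.no_grad():
    mask = (x > 0).float()

y = checkpoint(masked_square, x, mask, use_reentrant=False)
y.sum().backward()

# d/dx (x * m)^2 = 2 * x * m^2, and m^2 == m for a 0/1 mask.
expected = 2 * x.detach() * mask
print(torch.allclose(x.grad, expected))
```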

In [None]:
# 7.2 Pitfall: Random Operations (Dropout)

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Problem: Dropout uses random masks.
# If RNG state is NOT preserved, the recomputed forward in backward
# will sample a different mask than the original forward.
class ModelWithDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 32)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)  # Random! Different each forward pass
        return x

model = ModelWithDropout()
x = torch.randn(4, 32, requires_grad=True)

# Fix: preserve RNG state (default) so recomputation matches the original forward.
# use_reentrant=False is recommended for other reasons, but RNG correctness comes from preserve_rng_state.

y = checkpoint(model, x, use_reentrant=False, preserve_rng_state=True)
print(f"Output shape: {y.shape}")
print("RNG state preserved: dropout mask matches between forward and recompute")

### Impressions/Conclusions (7.2)

Dropout and other random ops are tricky. The recomputed forward pass must use the same random mask as the original forward.

In PyTorch, RNG correctness comes from `preserve_rng_state=True` (the default for `checkpoint`). Keep it enabled when checkpointing stochastic ops. `use_reentrant=False` is recommended for other reasons, but RNG correctness is controlled by `preserve_rng_state`.
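The claim can be checked directly: seed the RNG identically, compute parameter gradients once without checkpointing and once with `preserve_rng_state=True`, and compare. If the recomputed dropout mask differed from the original, the first Linear's gradients would differ. A sketch with an arbitrary toy model:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 32), nn.Dropout(0.5), nn.Linear(32, 32))
x = torch.randn(4, 32)


def param_grads(use_ckpt):
    model.zero_grad(set_to_none=True)
    torch.manual_seed(123)  # same RNG state before both forward passes
    if use_ckpt:
        out = checkpoint(model, x, use_reentrant=False, preserve_rng_state=True)
    else:
        out = model(x)
    out.sum().backward()
    return [p.grad.clone() for p in model.parameters()]


ref = param_grads(use_ckpt=False)
ckpt = param_grads(use_ckpt=True)
print(all(torch.allclose(a, b) for a, b in zip(ref, ckpt)))
```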

In [None]:
# 7.3 Best Practices Summary

print("""
BEST PRACTICES FOR ACTIVATION CHECKPOINTING
============================================

1. USE use_reentrant=False
   - Recommended by the PyTorch docs; pass it explicitly
   - Supports more cases (e.g. inputs that do not require grad, nesting)
   - checkpoint(fn, x, use_reentrant=False)

2. CHECKPOINT LARGE BLOCKS, NOT SMALL LAYERS
   - Each checkpoint still keeps its input, and each call adds overhead
   - Group 2-4 layers into blocks, then checkpoint blocks

3. AVOID NESTED CHECKPOINTS UNLESS INTENTIONAL
   - Each nesting level re-runs the inner ops once more in backward
   - Only nest when you have measured that it helps

4. KEEP EVERYTHING DIFFERENTIABLE
   - No torch.no_grad() inside checkpointed regions
   - No tensor detaching
   - No in-place ops that break autograd

5. TEST GRADIENTS FIRST
   - Compare gradients with and without checkpointing
   - They should be identical (within float precision)

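6. PROFILE MEMORY
   - Use torch.cuda.max_memory_allocated() or torch.cuda.memory_stats() to verify savings
   - Call torch.cuda.reset_peak_memory_stats() before each measurement
   - Peak memory is what matters

7. CONSIDER CHECKPOINT PLACEMENT
   - Start with the layers whose activations are largest
     (often early high-resolution stages in CNNs, FFN/attention in transformers)
   - Profile before and after to confirm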
""")

# Gradient verification example
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Linear(64, 64)
x = torch.randn(8, 64, requires_grad=True)

# Without checkpoint
x1 = x.clone().detach().requires_grad_(True)
y1 = model(x1)
y1.sum().backward()
grad1 = x1.grad.clone()

# With checkpoint
x2 = x.clone().detach().requires_grad_(True)
y2 = checkpoint(model, x2, use_reentrant=False)
y2.sum().backward()
grad2 = x2.grad.clone()

print(f"Gradients match: {torch.allclose(grad1, grad2)}")
print(f"Max difference: {(grad1 - grad2).abs().max():.2e}")

### Impressions/Conclusions (7.3)

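Gradient verification is essential. If gradients differ by more than floating-point tolerance, something is wrong with your checkpointing setup; a common cause is state that changes between the original forward and the recompute (RNG state, mutated inputs or module state). Debug before scaling up.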
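The check above only compares input gradients of a single Linear. A slightly more thorough sketch also compares every parameter gradient, using two deep copies of the same module so both runs start from identical weights. The helper name `grads_match` and the tolerance are illustrative choices.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def grads_match(module, x, atol=1e-6):
    """Compare input and parameter gradients with and without checkpointing."""
    ref, ckpt = copy.deepcopy(module), copy.deepcopy(module)
    x_ref = x.detach().clone().requires_grad_(True)
    x_ckpt = x.detach().clone().requires_grad_(True)

    ref(x_ref).sum().backward()
    checkpoint(ckpt, x_ckpt, use_reentrant=False).sum().backward()

    pairs = [(x_ref.grad, x_ckpt.grad)]
    pairs += [(p.grad, q.grad) for p, q in zip(ref.parameters(), ckpt.parameters())]
    return all(torch.allclose(a, b, atol=atol) for a, b in pairs)


torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))
print(grads_match(block, torch.randn(8, 64)))
```

For modules with dropout, seed the RNG identically before each run, as in the dropout check in 7.2.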

---
# 8. Selective Activation Checkpoint (SAC)

Standard checkpointing is all-or-nothing. Every op in the checkpointed region gets recomputed during backward.

Selective Activation Checkpoint (SAC) gives you granular control. You choose which operations to save and which to recompute.

Why does this matter? Not all operations are equal:
- Matmuls are expensive to recompute
- Pointwise ops (relu, sigmoid) are cheap
- Attention is very expensive

SAC lets you save the expensive ones and recompute the cheap ones. Best of both worlds.

---
# 9. Integrating with Distributed Training

Checkpointing works with DDP and mixed precision. The combination is powerful.

In [None]:
# 8.1 The CheckpointPolicy Enum

try:
    from torch.utils.checkpoint import CheckpointPolicy
except Exception as e:
    CheckpointPolicy = None
    print("CheckpointPolicy not available in this PyTorch build.")
    print(f"Details: {type(e).__name__}: {e}")

if CheckpointPolicy is not None:
    print("CheckpointPolicy has four options:")
    print("-" * 50)
    print()
    print("MUST_SAVE:")
    print("  - Always save this op's output")
    print("  - Never recompute it")
    print("  - Use for expensive ops (matmul, attention)")
    print()
    print("PREFER_SAVE:")
    print("  - Save if possible")
    print("  - torch.compile may override this")
    print()
    print("MUST_RECOMPUTE:")
    print("  - Always recompute this op")
    print("  - Never save its output")
    print()
    print("PREFER_RECOMPUTE:")
    print("  - Recompute if possible")
    print("  - torch.compile may override this")
    print("  - Use for cheap ops (relu, elementwise)")
    print()
    print("The MUST_ variants are strict. The PREFER_ variants are hints.")

### Impressions/Conclusions (8.1)

The four policies give you a 2x2 matrix:
- MUST vs PREFER: How strict?
- SAVE vs RECOMPUTE: What action?

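Use MUST_ when the decision must hold everywhere. Use PREFER_ when you are happy for torch.compile to override it. In eager mode the two variants behave the same; the distinction only matters to subsystems such as torch.compile.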

Key insight: A policy that returns PREFER_RECOMPUTE for everything is equivalent to vanilla checkpointing. A policy that returns PREFER_SAVE for everything is NOT the same as no checkpointing (it may save extra tensors).

In [None]:
# 8.2 Policy Function: Save Matmuls, Recompute Everything Else

import torch

try:
    from torch.utils.checkpoint import CheckpointPolicy
except Exception as e:
    CheckpointPolicy = None
    print("CheckpointPolicy not available; skipping SAC policy definition.")
    print(f"Details: {type(e).__name__}: {e}")

aten = torch.ops.aten

def _maybe_default(op_name: str):
    """Return torch.ops.aten.<op_name>.default if it exists, else None."""
    try:
        return getattr(getattr(aten, op_name), "default")
    except AttributeError:
        return None

# Policy 1: conservative - only save the most expensive ops.
# IMPORTANT: Use the .default variants for correct op matching.
compute_intensive_ops_basic = [
    _maybe_default("mm"),
    _maybe_default("bmm"),
    _maybe_default("addmm"),
]
compute_intensive_ops_basic = [op for op in compute_intensive_ops_basic if op is not None]

def policy_save_matmuls(ctx, op, *args, **kwargs):
    """Save matmuls, recompute everything else."""
    if CheckpointPolicy is None:
        raise RuntimeError("CheckpointPolicy is not available in this PyTorch build.")
    if op in compute_intensive_ops_basic:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

print("Policy 1: Save Matmuls Only")
print("-" * 50)
print("Saves: aten.mm.default, aten.bmm.default, aten.addmm.default (when available)")
print("Recomputes: relu, gelu, sigmoid, layernorm, etc.")
print()
print("NOTE: Always match against .default (e.g., aten.mm.default)")
print("to match the actual ops passed to the policy function.")

In [None]:
# 8.3 Policy Function: Save All Compute-Intensive Ops

import torch

try:
    from torch.utils.checkpoint import CheckpointPolicy
except Exception as e:
    CheckpointPolicy = None
    print("CheckpointPolicy not available; skipping SAC policy definition.")
    print(f"Details: {type(e).__name__}: {e}")

aten = torch.ops.aten

def _maybe_default(op_name: str):
    try:
        return getattr(getattr(aten, op_name), "default")
    except AttributeError:
        return None

# Policy 2: aggressive - save everything expensive.
# We build this list dynamically so older/newer PyTorch builds don't crash.
op_names = [
    "mm",
    "bmm",
    "addmm",
    "convolution",
    "upsample_bilinear2d",
    "_scaled_mm",
    "linear",  # composite op: usually decomposed to addmm/mm before the policy sees it
    "_scaled_dot_product_flash_attention",
    "_scaled_dot_product_efficient_attention",
]

compute_intensive_ops_full = []
missing = []
for name in op_names:
    op = _maybe_default(name)
    if op is None:
        missing.append(name)
    else:
        compute_intensive_ops_full.append(op)

def policy_save_all_expensive(ctx, op, *args, **kwargs):
    """Save all compute-intensive ops, including attention (when present)."""
    if CheckpointPolicy is None:
        raise RuntimeError("CheckpointPolicy is not available in this PyTorch build.")
    if op in compute_intensive_ops_full:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

print("Policy 2: Save All Expensive Ops")
print("-" * 50)
print("Saves: matmuls + convolutions + attention + upsampling (when available)")
print("Recomputes: only cheap pointwise ops")
if missing:
    print(f"Missing ops in this build: {', '.join(missing)}")
print()
print("NOTE: Always match against .default (e.g., aten.convolution.default)")
print("and expect op availability to vary by build/hardware.")

### Impressions/Conclusions (8.2 & 8.3)

Two policies, two positions on the speed-memory curve (expected behavior; measure on your model):

Policy 1 (save matmuls only):
- Memory: Lower than Policy 2, but higher than standard AC: matmul outputs are kept, and in an FFN that includes the 4x-wide first projection
- Speed: Medium (recomputes everything except matmuls)

Policy 2 (save all expensive):
- Memory: Medium (saves attention and convolutions too)
- Speed: High (only recomputes cheap pointwise ops)

The key realization: pointwise ops are cheap to recompute, yet their outputs can take as much memory as a matmul's. Recomputing just those can recover a useful share of the memory for little extra compute.
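To put rough numbers on that, here is a back-of-envelope count of the activation bytes each mode keeps between forward and backward for a Linear -> GELU -> Linear FFN, using the shapes from the 8.6 benchmark (fp32). It assumes the first Linear keeps its input, GELU keeps its input (the first Linear's output), and the second Linear keeps its input (the GELU output); a policy that saves matmul outputs additionally caches the first Linear's output. The FFN's own output is ignored because the caller holds it either way. These are estimates, not measurements; peak memory also includes weights, gradients, and temporaries.

```python
def ffn_activation_bytes(batch, seq, dim, expansion=4, bytes_per_elem=4):
    """Estimated activation bytes kept for backward in Linear -> GELU -> Linear."""
    narrow = batch * seq * dim * bytes_per_elem            # FFN input h
    wide = batch * seq * dim * expansion * bytes_per_elem  # fc1 output / GELU output
    return {
        "No Checkpoint": narrow + 2 * wide,  # h, fc1 output, GELU output
        "Standard AC": narrow,               # only the checkpoint input h
        "Selective AC": narrow + wide,       # h plus cached fc1 output; GELU recomputed
    }


estimates = ffn_activation_bytes(batch=64, seq=128, dim=2048)
for name, b in estimates.items():
    print(f"{name:<15} {b / 1e6:8.1f} MB")
```

Under these assumptions, saving matmuls keeps roughly half of the no-checkpoint activation memory, while standard AC keeps only the narrow input.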

In [None]:
# 8.4 Using SAC in Practice: The Full API

import functools
import torch
from torch.utils.checkpoint import checkpoint

try:
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
except Exception as e:
    CheckpointPolicy = None
    create_selective_checkpoint_contexts = None
    print("Selective checkpointing APIs not available in this PyTorch build.")
    print(f"Details: {type(e).__name__}: {e}")

if CheckpointPolicy is not None and create_selective_checkpoint_contexts is not None:
    aten = torch.ops.aten

    def _maybe_default(op_name: str):
        try:
            return getattr(getattr(aten, op_name), "default")
        except AttributeError:
            return None

    # Define which ops to save - MUST use .default variants
    ops_to_save = [
        _maybe_default("mm"),
        _maybe_default("bmm"),
        _maybe_default("addmm"),
    ]
    ops_to_save = [op for op in ops_to_save if op is not None]

    def policy_fn(ctx, op, *args, **kwargs):
        """Policy function for selective checkpointing."""
        if op in ops_to_save:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    # Create the context function using functools.partial
    context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

    def _run_checkpoint(fn, *args, context_fn=None):
        """Call checkpoint() with best-effort compatibility across versions."""
        kwargs = {"use_reentrant": False}
        if context_fn is not None:
            kwargs["context_fn"] = context_fn
        try:
            return checkpoint(fn, *args, **kwargs)
        except TypeError:
            # Drop unsupported kwargs (older builds may not have context_fn/use_reentrant).
            kwargs.pop("context_fn", None)
            try:
                return checkpoint(fn, *args, **kwargs)
            except TypeError:
                kwargs.pop("use_reentrant", None)
                return checkpoint(fn, *args)

    def forward_fn(x, weight1, weight2):
        """Example forward: two matmuls with activations."""
        x = torch.mm(x, weight1)   # Expected to match aten.mm.default
        x = torch.relu(x)          # Cheap: often recomputed
        x = torch.mm(x, weight2)   # Expected to match aten.mm.default
        x = torch.sigmoid(x)       # Cheap: often recomputed
        return x

    x = torch.randn(32, 64, requires_grad=True)
    w1 = torch.randn(64, 64, requires_grad=True)
    w2 = torch.randn(64, 64, requires_grad=True)

    output = _run_checkpoint(forward_fn, x, w1, w2, context_fn=context_fn)

    print("SAC API Usage:")
    print("-" * 50)
    print("1. Define policy_fn(ctx, op, *args, **kwargs) -> CheckpointPolicy")
    print("2. Create context_fn with functools.partial")
    print("3. Pass context_fn to checkpoint(..., context_fn=context_fn)")
    print()
    print(f"Output shape: {output.shape}")
    print("Matmuls saved. Cheap activations will be recomputed during backward.")
    print()
    print("IMPORTANT: Use aten.<op>.default for op matching, not aten.<op>")
else:
    print("[Skipping SAC demo: missing CheckpointPolicy/create_selective_checkpoint_contexts]")

In [None]:
# 8.5 Shortcut: Allowlist of Ops (No Boilerplate Policy)

import functools
import torch
from torch.utils.checkpoint import checkpoint

try:
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
except Exception as e:
    CheckpointPolicy = None
    create_selective_checkpoint_contexts = None
    print("Selective checkpointing APIs not available in this PyTorch build.")
    print(f"Details: {type(e).__name__}: {e}")

aten = torch.ops.aten

def _maybe_default(op_name: str):
    try:
        return getattr(getattr(aten, op_name), "default")
    except AttributeError:
        return None

# Allowlist of ops to save (built dynamically)
ops_to_save = [
    _maybe_default("mm"),
    _maybe_default("bmm"),
]
ops_to_save = [op for op in ops_to_save if op is not None]

def allowlist_policy(ops):
    ops = set(ops)

    def _policy(ctx, op, *args, **kwargs):
        if CheckpointPolicy is None:
            raise RuntimeError("CheckpointPolicy is not available in this PyTorch build.")
        if op in ops:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return _policy

if CheckpointPolicy is not None and create_selective_checkpoint_contexts is not None:
    policy_fn = allowlist_policy(ops_to_save)
    context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

    def _run_checkpoint(fn, *args, context_fn=None):
        kwargs = {"use_reentrant": False}
        if context_fn is not None:
            kwargs["context_fn"] = context_fn
        try:
            return checkpoint(fn, *args, **kwargs)
        except TypeError:
            kwargs.pop("context_fn", None)
            try:
                return checkpoint(fn, *args, **kwargs)
            except TypeError:
                return checkpoint(fn, *args)

    def simple_forward(x, w):
        x = torch.mm(x, w)
        x = torch.relu(x)
        return x

    x = torch.randn(32, 64, requires_grad=True)
    w = torch.randn(64, 64, requires_grad=True)

    output = _run_checkpoint(simple_forward, x, w, context_fn=context_fn)

    print("Shortcut API:")
    print("-" * 50)
    print("Allowlist: choose ops to save (e.g., mm/bmm)")
    print("Everything else defaults to PREFER_RECOMPUTE")
    print(f"Output shape: {output.shape}")
else:
    print("[Skipping SAC shortcut demo: missing APIs]")

In [None]:
# 8.6 SAC vs Standard AC: Memory and Time Comparison

import functools
import time
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

try:
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
except Exception as e:
    CheckpointPolicy = None
    create_selective_checkpoint_contexts = None
    print("Selective checkpointing APIs not available in this PyTorch build.")
    print(f"Details: {type(e).__name__}: {e}")

class TransformerFFN(nn.Module):
    """Feed-forward block from a transformer."""
    def __init__(self, dim=1024, expansion=4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * expansion)
        self.fc2 = nn.Linear(dim * expansion, dim)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.gelu(x)  # Cheap pointwise
        x = self.fc2(x)
        return x

def _run_checkpoint(fn, *args, context_fn=None):
    """Call checkpoint() with best-effort compatibility across versions."""
    kwargs = {"use_reentrant": False}
    if context_fn is not None:
        kwargs["context_fn"] = context_fn
    try:
        return checkpoint(fn, *args, **kwargs)
    except TypeError:
        kwargs.pop("context_fn", None)
        try:
            return checkpoint(fn, *args, **kwargs)
        except TypeError:
            return checkpoint(fn, *args)

if torch.cuda.is_available():
    dim = 2048
    batch = 64
    seq = 128

    # Add a small stem so the checkpointed region receives activations that require grad.
    stem = nn.Linear(dim, dim).cuda()
    ffn = TransformerFFN(dim).cuda()

    x = torch.randn(batch, seq, dim, device="cuda")

    sac_context = None
    if CheckpointPolicy is not None and create_selective_checkpoint_contexts is not None:
        aten = torch.ops.aten

        def _maybe_default(op_name: str):
            try:
                return getattr(getattr(aten, op_name), "default")
            except AttributeError:
                return None

        saved_ops = [op for op in [_maybe_default("mm"), _maybe_default("addmm")] if op is not None]

        def sac_policy(ctx, op, *args, **kwargs):
            # nn.Linear typically lowers to addmm/mm
            if op in saved_ops:
                return CheckpointPolicy.MUST_SAVE
            return CheckpointPolicy.PREFER_RECOMPUTE

        sac_context = functools.partial(create_selective_checkpoint_contexts, sac_policy)
    else:
        print("SAC APIs not available: will skip Selective AC benchmark")

    def bench(mode: str, iters: int = 5, warmup: int = 2):
        def step():
            h = stem(x)
            if mode == "none":
                out = ffn(h)
            elif mode == "standard":
                out = _run_checkpoint(ffn, h)
            elif mode == "selective":
                if sac_context is None:
                    raise RuntimeError("Selective checkpointing context not available")
                out = _run_checkpoint(ffn, h, context_fn=sac_context)
            else:
                raise ValueError(mode)

            out.sum().backward()
            stem.zero_grad(set_to_none=True)
            ffn.zero_grad(set_to_none=True)

        for _ in range(warmup):
            step()

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

        start = time.time()
        for _ in range(iters):
            step()

        torch.cuda.synchronize()
        time_ms = (time.time() - start) / iters * 1000
        mem_mb = torch.cuda.max_memory_allocated() / 1e6
        return mem_mb, time_ms

    results = {
        "No Checkpoint": bench("none"),
        "Standard AC": bench("standard"),
    }
    if sac_context is not None:
        results["Selective AC"] = bench("selective")

    print(f"Transformer FFN: dim={dim}, batch={batch}, seq={seq}")
    print("-" * 70)
    print(f"{'Mode':<15} {'Peak Memory (MB)':<18} {'Time (ms/iter)':<15}")
    print("-" * 70)
    for name, (mem, t) in results.items():
        print(f"{name:<15} {mem:<18.2f} {t:<15.2f}")
    print("-" * 70)
    print("Standard AC: recomputes everything inside the checkpointed region")
    print("Selective AC (when available): can save matmuls and recompute only GELU")
else:
    print("[Run on GPU for comparison]")

### Impressions/Conclusions (8.4, 8.5, 8.6)

SAC is the middle ground between "checkpoint everything" and "checkpoint nothing."

The API flow:
1. Define ops you care about (matmuls, attention, etc.)
2. Write a policy function (an allowlist wrapper is often enough)
3. Create contexts with `create_selective_checkpoint_contexts` (when available)
4. Pass to `checkpoint(..., context_fn=...)`

When to use SAC over standard AC:
- You need fine-grained control
- Recomputing certain ops is too expensive
- You want to tune the speed-memory trade-off precisely

Availability varies by PyTorch version/build. Treat SAC as evolving/experimental and guard imports in production code.

In [None]:
# 9.1 DDP + Checkpointing (Code Template)
# NOTE: This code shows the pattern. Run in a distributed environment.

"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

# Initialize distributed
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

class ModelWithCheckpoint(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
    
    def forward(self, x):
        x = checkpoint(lambda t: torch.relu(self.layer1(t)), x, use_reentrant=False)
        x = torch.relu(self.layer2(x))
        return x

# Wrap with DDP
model = ModelWithCheckpoint().cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])

# Training works as normal
x = torch.randn(16, 1024).cuda(local_rank)
output = ddp_model(x)
loss = output.sum()
loss.backward()
"""

print("DDP + Checkpointing:")
print("- Checkpointing happens locally on each GPU")
print("- DDP handles gradient synchronization")
print("- No special DDP configuration needed with use_reentrant=False (the reentrant variant can clash with find_unused_parameters=True)")
print("- Memory savings apply per-GPU")

In [None]:
# 9.2 Mixed Precision + Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
    
    def forward(self, x):
        x = checkpoint(lambda t: torch.relu(self.layer1(t)), x, use_reentrant=False)
        x = torch.relu(self.layer2(x))
        return x

if torch.cuda.is_available():
    model = SimpleModel().cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = torch.cuda.amp.GradScaler()
    
    # Training loop with mixed precision + checkpointing
    for step in range(3):
        x = torch.randn(16, 1024).cuda()
        
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            output = model(x)
            loss = output.sum()
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        print(f"Step {step}: loss = {loss.item():.4f}")
    
    print("\nMixed precision + checkpointing works seamlessly")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (9.1 & 9.2)

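The power combo: DDP + Mixed Precision + Checkpointing.

- DDP: Scales throughput across GPUs (each GPU still holds a full model replica, so per-GPU memory does not shrink)
- Mixed Precision: Roughly halves activation memory for ops autocast runs in fp16 (weights, gradients, and optimizer state stay fp32)
- Checkpointing: Further activation savings from recomputation

Combined, the activation savings compound, so you can fit much larger batches or models per GPU than an fp32 baseline; the exact factor depends on how much of your memory is activations. GPT-scale training uses these same tricks, together with sharding (FSDP/ZeRO) and model parallelism for parameters and optimizer state.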
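To see why the "2x" from mixed precision applies to activations rather than the whole footprint, here is a rough per-GPU estimate. It is a sketch under stated assumptions: Adam in fp32 (4 bytes each for weights, gradients, and the two moment buffers, so 16 bytes per parameter), saved activations stored in fp16 under autocast (in practice some stay fp32), and a hypothetical `kept_fraction` standing in for the share of activations your checkpointing strategy still keeps. The helper name and example sizes are made up for illustration. DDP does not shrink any of these terms.

```python
def training_memory_gb(n_params, activation_elems, amp=False, kept_fraction=1.0):
    """Rough per-GPU estimate in GB (hypothetical helper; see assumptions above)."""
    # fp32 weights + fp32 grads + Adam's two fp32 moment buffers = 4 + 4 + 4 + 4 bytes per parameter.
    # Autocast leaves all of these in fp32, so AMP does not shrink this part.
    static_bytes = 16 * n_params
    # Assumption: under autocast, saved activations are fp16 (2 bytes) instead of fp32 (4 bytes).
    bytes_per_act = 2 if amp else 4
    # kept_fraction models checkpointing: the share of activations still stored at peak.
    act_bytes = activation_elems * bytes_per_act * kept_fraction
    return (static_bytes + act_bytes) / 1e9

n_params, act_elems = 100_000_000, 2_000_000_000  # made-up example sizes
for label, amp, kept in [("fp32", False, 1.0), ("AMP", True, 1.0), ("AMP + checkpointing", True, 0.25)]:
    print(f"{label:<22} {training_memory_gb(n_params, act_elems, amp, kept):5.1f} GB")
```

The fixed 16 bytes/param term is why large models also need sharding: AMP and checkpointing only attack the activation term.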

---
# 10. Additional Tools and Resources

Beyond PyTorch's built-in checkpointing, there are libraries that offer more features.

In [None]:
# 10.1 DeepSpeed Integration

"""
DeepSpeed offers activation checkpointing as part of its ZeRO optimization suite.

# Installation
pip install deepspeed

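# Usage
import deepspeed
from deepspeed.runtime.activation_checkpointing import checkpointing

# In the model's forward, call DeepSpeed's checkpoint instead of torch's:
#     x = checkpointing.checkpoint(block, x)

# Configure in ds_config.json (plain JSON: no inline comments allowed):
{
    "activation_checkpointing": {
        "partition_activations": true,
        "contiguous_memory_optimization": true,
        "cpu_checkpointing": true
    }
}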

# Key advantage: CPU offloading
# DeepSpeed can move checkpointed activations to CPU memory,
# freeing GPU memory for even larger models.
"""

print("DeepSpeed Activation Checkpointing:")
print("- Drop-in checkpoint() via deepspeed.checkpointing.checkpoint")
print("- CPU offloading for extreme memory savings")
print("- Integrated with ZeRO optimizer stages")
print("- Best for very large models (billions of parameters)")

In [None]:
# 10.2 FairScale Integration

"""
FairScale (from Meta) provides checkpoint_wrapper for easy integration.

# Installation
pip install fairscale

# Usage
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper

# Wrap any module
layer = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
checkpointed_layer = checkpoint_wrapper(layer)

# Use in model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()))
        self.block2 = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()))
    
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x
"""

print("FairScale checkpoint_wrapper:")
print("- Clean API: wrap modules directly")
print("- Works well with FSDP (Fully Sharded Data Parallel)")
print("- Good for medium-scale training")
print("- Simpler than DeepSpeed for many use cases")

### Impressions/Conclusions (10.1 & 10.2)

Tool selection guide:
- **PyTorch native**: Simple models, full control, no dependencies
- **FairScale**: Medium scale, clean API, FSDP integration
- **DeepSpeed**: Billion+ parameter models, CPU offloading, full optimization suite

Start with PyTorch native. Move to DeepSpeed when you hit limits.
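If you want FairScale's wrap-a-module ergonomics without the dependency, a module wrapper around `torch.utils.checkpoint.checkpoint` takes only a few lines. `CheckpointWrapper` below is a hypothetical minimal sketch, not FairScale's implementation. PyTorch also ships a `checkpoint_wrapper` in `torch.distributed.algorithms._checkpoint.checkpoint_wrapper`, but that module path is private and may move.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapper(nn.Module):
    """Hypothetical minimal wrapper: runs the wrapped module under non-reentrant checkpointing."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Nothing is saved for backward under no_grad, so skip checkpointing there.
        if not torch.is_grad_enabled():
            return self.module(*args, **kwargs)
        # Non-reentrant checkpoint forwards keyword arguments to the wrapped module.
        return checkpoint(self.module, *args, use_reentrant=False, **kwargs)

# Usage: wrap a block exactly as you would with FairScale's checkpoint_wrapper.
block = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
wrapped = CheckpointWrapper(block)
```

One design consequence: because the block is stored as `self.module`, state_dict keys gain a `module.` prefix, so load checkpoints saved from the unwrapped model accordingly.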

---
# 11. torch.compile and Memory Budget API

torch.compile (PyTorch 2.0+) does something clever: it traces your forward and backward passes into a single joint graph. Then it applies a "min-cut" partitioner.

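This is automatic recomputation rather than user-specified checkpointing. The min-cut partitioner decides which tensors to save and which to recompute: it picks the cut that minimizes the size of the tensors saved for backward, while only allowing cheap ops to be recomputed. No manual configuration needed.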

But here is the catch: by default, min-cut prioritizes speed, not memory. It only recomputes cheap, fusible ops (like pointwise operations).

The Memory Budget API gives you control over this trade-off.

In [None]:
# 11.1 torch.compile: The Min-Cut Partitioner

import torch
import torch.nn as nn

print("How torch.compile handles activations:")
print("-" * 50)
print()
print("1. TRACING")
print("   torch.compile traces forward AND backward into one graph.")
print("   This lets it see the whole picture.")
print()
print("2. MIN-CUT PARTITIONING")
print("   The graph is split at the optimal points.")
print("   Algorithm minimizes: tensors crossing the cut")
print("   (These are the tensors saved for backward)")
print()
print("3. AUTOMATIC RECOMPUTATION")
print("   Cheap ops (relu, add, mul) are recomputed automatically.")
print("   No user intervention needed.")
print()
print("4. FUSION")
print("   Pointwise ops get fused into kernels.")
print("   Fused ops are fast to recompute.")
print()
print("Result: torch.compile gives you SOME memory savings for FREE,")
print("        plus speed improvements from fusion.")

In [None]:
# 11.2 Comparing: Eager vs Compile vs Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleFFN(nn.Module):
    def __init__(self, dim=2048):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)
    
    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.gelu(x)
        x = self.fc2(x)
        return x

class Wrapper(nn.Module):
    """Adds a small stem so the checkpointed region receives activations that require grad."""
    def __init__(self, dim=2048, use_checkpoint=False):
        super().__init__()
        self.stem = nn.Linear(dim, dim)
        self.ffn = SimpleFFN(dim)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        x = self.stem(x)
        if self.use_checkpoint:
            x = checkpoint(self.ffn, x, use_reentrant=False)
        else:
            x = self.ffn(x)
        return x

def warmup(fn, zero_grad_fn, x, iters=2):
    for _ in range(iters):
        fn(x).sum().backward()
        zero_grad_fn()

def peak_memory_mb(fn, zero_grad_fn, x):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = fn(x)
    out.sum().backward()
    peak = torch.cuda.max_memory_allocated() / 1e6
    zero_grad_fn()
    return peak

if torch.cuda.is_available():
    dim = 2048
    batch = 64
    seq = 128

    x = torch.randn(batch, seq, dim, device="cuda")

    eager_model = Wrapper(dim, use_checkpoint=False).cuda()
    ckpt_model = Wrapper(dim, use_checkpoint=True).cuda()

    results = {}

    # 1) Eager mode
    warmup(eager_model, lambda: eager_model.zero_grad(set_to_none=True), x)
    results['Eager'] = peak_memory_mb(eager_model, lambda: eager_model.zero_grad(set_to_none=True), x)

    # 2) torch.compile (if available)
    if hasattr(torch, "compile"):
        compiled = torch.compile(eager_model)
        warmup(compiled, lambda: eager_model.zero_grad(set_to_none=True), x)
        results['torch.compile'] = peak_memory_mb(compiled, lambda: eager_model.zero_grad(set_to_none=True), x)
    else:
        print("torch.compile not available in this PyTorch build")

    # 3) Activation checkpointing
    warmup(ckpt_model, lambda: ckpt_model.zero_grad(set_to_none=True), x)
    results['Checkpointing'] = peak_memory_mb(ckpt_model, lambda: ckpt_model.zero_grad(set_to_none=True), x)

    print(f"FFN: dim={dim}, batch={batch}, seq={seq}")
    print("-" * 50)
    for name, mem in results.items():
        print(f"{name:20s}: {mem:8.2f} MB")
    print()
    print("torch.compile: better speed, moderate memory savings")
    print("Checkpointing: maximum memory savings, some speed cost")
else:
    print("[Run on GPU for comparison]")

### Impressions/Conclusions (11.1 & 11.2)

The speed-memory trade-off diagram:

```
Speed (high is good)
  ^
  |          * torch.compile (fast, some memory savings)
  |                     * Eager (fast, high memory)
  |
  |      * SAC policies / memory budget (middle: tunable)
  |
  |  * Full checkpointing (slower, low memory)
  |
  +---------------------------------> Memory (right is bad)
```

torch.compile sits between eager and full checkpointing. It automatically recomputes cheap ops. The Memory Budget API lets you push it further toward checkpointing.

In [None]:
# 11.3 Memory Budget API: Control the Trade-off

import torch

if not hasattr(torch, "compile"):
    print("torch.compile not available in this PyTorch build")
else:
    try:
        import torch._dynamo as dynamo
    except Exception as e:
        dynamo = None
        print("torch._dynamo not available; cannot use memory budget APIs")
        print(f"Details: {type(e).__name__}: {e}")

    print("Memory Budget API (torch.compile only)")
    print("-" * 50)
    print()

    if dynamo is None or not hasattr(dynamo.config, "activation_memory_budget"):
        print("activation_memory_budget not found in torch._dynamo.config in this build")
    else:
        print("torch._dynamo.config.activation_memory_budget = X")
        print()
        print("X = 0.0: Maximum recomputation (like full AC)")
        print("         Recompute everything, save almost nothing")
        print()
        print("X = 0.5: Balanced")
        print("         Recompute pointwise ops, save matmuls")
        print()
        print("X = 1.0: Default torch.compile behavior")
        print("         Minimal recomputation, maximize speed")
        print()
        print("The API automatically finds Pareto-optimal policies.")
        print("You just specify how much memory you want to use.")

In [None]:
# 11.4 Using Memory Budget API in Practice

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, dim=1024, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ffn(x))
        return x

if not torch.cuda.is_available():
    print("[Run on GPU for memory budget comparison]")
elif not hasattr(torch, "compile"):
    print("torch.compile not available in this PyTorch build")
else:
    try:
        import torch._dynamo as dynamo
    except Exception as e:
        dynamo = None
        print("torch._dynamo not available; cannot use memory budget APIs")
        print(f"Details: {type(e).__name__}: {e}")

    if dynamo is None or not hasattr(dynamo.config, "activation_memory_budget"):
        print("activation_memory_budget not available in torch._dynamo.config in this build")
    else:
        dim = 1024
        batch = 32
        seq = 256

        model = TransformerBlock(dim).cuda()
        x = torch.randn(batch, seq, dim, device="cuda")

        budgets = [1.0, 0.7, 0.5, 0.3, 0.0]

        print(f"Transformer: dim={dim}, batch={batch}, seq={seq}")
        print("-" * 50)
        print(f"{'Budget':<10} {'Memory (MB)':<15} {'Relative':<10}")
        print("-" * 50)

        baseline = None
        original_budget = dynamo.config.activation_memory_budget

        try:
            for budget in budgets:
                dynamo.config.activation_memory_budget = float(budget)

                # Recompile with new budget
                if hasattr(dynamo, "reset"):
                    dynamo.reset()
                compiled = torch.compile(model)

                # Warmup
                for _ in range(2):
                    compiled(x).sum().backward()
                    model.zero_grad(set_to_none=True)

                # Measure
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                out = compiled(x)
                out.sum().backward()
                mem = torch.cuda.max_memory_allocated() / 1e6
                model.zero_grad(set_to_none=True)

                if baseline is None:
                    baseline = mem

                relative = mem / baseline
                print(f"{budget:<10.1f} {mem:<15.2f} {relative:.2f}x")
        finally:
            dynamo.config.activation_memory_budget = original_budget

In [None]:
# 11.5 What Gets Recomputed at Each Budget Level?

print("Recomputation order (approximate, from the blog's transformer results):")
print("-" * 60)
print()
print("Budget 1.0 (default compile):")
print("  Recomputes: Only cheap, fusible ops the default partitioner picks")
print("  Saves: Everything else")
print("  Memory: 100% (baseline)")
print()
print("Budget 0.7:")
print("  Recomputes: Some pointwise ops")
print("  Saves: Matmuls, attention, remaining pointwise outputs")
print("  Memory: ~85% of baseline")
print()
print("Budget 0.5:")
print("  Recomputes: Pointwise ops (gelu, add, mul)")
print("  Saves: Matmuls, attention")
print("  Memory: ~50% of baseline")
print()
print("Budget 0.3:")
print("  Recomputes: Pointwise + some matmuls")
print("  Saves: Attention (most expensive) + remaining matmuls")
print("  Memory: ~35% of baseline")
print()
print("Budget 0.0:")
print("  Recomputes: Everything (like full AC)")
print("  Saves: Almost nothing")
print("  Memory: Minimum possible")
print()
print("Key insight: 50% memory reduction by recomputing only pointwise ops.")
print("Attention is expensive. Recompute it last.")

### Impressions/Conclusions (11.3, 11.4, 11.5)

Memory Budget API is the easiest way to tune the speed-memory trade-off with torch.compile.

One line of code (when available):
```python
torch._dynamo.config.activation_memory_budget = 0.5
```

The system automatically:
1. Searches for a Pareto-optimal recomputation strategy within the budget
2. Recomputes cheap ops first
3. Recomputes attention and expensive matmuls only as a last resort

When to use Memory Budget API vs SAC:
- Memory Budget: Using torch.compile, want automatic optimization
- SAC: Need precise control over which ops to save/recompute

Availability varies by PyTorch build. Guard `torch._dynamo.config.activation_memory_budget` before relying on it.
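The try/finally pattern from 11.4 can be packaged as a context manager. This is a hypothetical sketch with illustrative names. It looks for `activation_memory_budget` on `torch._dynamo.config` first and, as an assumption, on `torch._functorch.config` as a fallback, since the setting may be exposed there instead in some builds. It restores the previous value on exit and yields `None` when neither module has the setting.

```python
import contextlib
import importlib

# Assumption: candidate config modules that may expose activation_memory_budget.
_CANDIDATE_CONFIGS = ("torch._dynamo.config", "torch._functorch.config")

def find_budget_config():
    """Return (module_name, module) for the first config exposing the budget, else (None, None)."""
    for name in _CANDIDATE_CONFIGS:
        try:
            mod = importlib.import_module(name)
        except Exception:
            continue  # module missing in this build (or torch not installed)
        if hasattr(mod, "activation_memory_budget"):
            return name, mod
    return None, None

@contextlib.contextmanager
def activation_memory_budget(budget):
    """Temporarily set the activation memory budget; yields the config module name or None."""
    name, cfg = find_budget_config()
    if cfg is None:
        yield None
        return
    old = cfg.activation_memory_budget
    cfg.activation_memory_budget = float(budget)
    try:
        yield name
    finally:
        cfg.activation_memory_budget = old
```

torch.compile compiles lazily, so the first forward/backward of the compiled model must run inside the `with` block for the budget to take effect; reset dynamo between budgets, as 11.4 does.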

---
# 12. Case Study: Activation Checkpointing in ResNet50

ResNet50 is deep enough to benefit from checkpointing. It has 4 stages with multiple bottleneck blocks. Each block stores feature maps for backward.

This section shows how to apply checkpointing to a real production model. You will see:
1. How to modify torchvision's ResNet50
2. Memory comparison across checkpointing strategies
3. A complete training loop with checkpointing
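Before the code, a back-of-envelope count shows where ResNet50's feature maps are largest. The sketch below is a hypothetical helper that assumes torchvision's layout at 224x224 input (3/4/6/3 blocks per stage, output spatial size 56/28/14/7, output channels 256/512/1024/2048). It counts only each block's output tensor and ignores the intermediates inside each bottleneck, which are 4x narrower and scale roughly the same way.

```python
# (name, num_blocks, out_channels, output spatial side at 224x224 input)
RESNET50_STAGES = [
    ("layer1", 3, 256, 56),
    ("layer2", 4, 512, 28),
    ("layer3", 6, 1024, 14),
    ("layer4", 3, 2048, 7),
]

def stage_output_elements(batch=1):
    """Elements in all block outputs of each stage: a proxy for that stage's activation footprint."""
    return {
        name: blocks * batch * channels * side * side
        for name, blocks, channels, side in RESNET50_STAGES
    }

def stage_output_mb(batch=32, bytes_per_elem=4):
    """Same proxy in MB (fp32 by default)."""
    return {name: n * bytes_per_elem / 1e6 for name, n in stage_output_elements(batch).items()}

for name, mb in stage_output_mb().items():
    print(f"{name}: {mb:8.1f} MB of block outputs (batch=32, fp32)")
```

Channels double but spatial area drops 4x per stage, so each stage's per-block output is half the previous one; layer3's six blocks do not make up for that.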

In [None]:
# 12.1 ResNet50 Architecture Overview

import torch
import torch.nn as nn

try:
    import torchvision.models as models
except Exception as e:
    models = None
    print("torchvision is required for the ResNet50 case study.")
    print(f"Details: {type(e).__name__}: {e}")

if models is not None:
    # Load ResNet50 without downloading weights
    try:
        resnet50 = models.resnet50(weights=None)
    except TypeError:
        # Older torchvision
        resnet50 = models.resnet50(pretrained=False)

    print("ResNet50 Structure:")
    print("-" * 60)
    print(f"conv1:  1 conv layer")
    print(f"layer1: {len(resnet50.layer1)} Bottleneck blocks (64 -> 256 channels)")
    print(f"layer2: {len(resnet50.layer2)} Bottleneck blocks (128 -> 512 channels)")
    print(f"layer3: {len(resnet50.layer3)} Bottleneck blocks (256 -> 1024 channels)")
    print(f"layer4: {len(resnet50.layer4)} Bottleneck blocks (512 -> 2048 channels)")
    print(f"fc:     1 linear layer")
    print()
    print(
        f"Total Bottleneck blocks: {len(resnet50.layer1) + len(resnet50.layer2) + len(resnet50.layer3) + len(resnet50.layer4)}"
    )
    print()
    print("Each Bottleneck block has 3 conv layers + a skip connection.")
    print("Spatial size halves at each stage (56 -> 28 -> 14 -> 7 for 224x224 input) while channels double,")
    print("so per-block feature maps are largest in layer1; layer3 has the most blocks (6),")
    print("and layer3/layer4 hold most of the parameters.")

In [None]:
# 12.2 ResNet50 with Checkpointing: Three Strategies

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

try:
    import torchvision.models as models
except Exception as e:
    models = None
    ResNet50Checkpointed = None
    print("torchvision is required for ResNet50Checkpointed.")
    print(f"Details: {type(e).__name__}: {e}")

if models is not None:
    class ResNet50Checkpointed(nn.Module):
        """ResNet50 with configurable checkpointing strategies."""

        def __init__(self, num_classes=1000, checkpoint_strategy='none'):
            """
            Args:
                checkpoint_strategy: 'none', 'per_stage', 'per_block', or 'aggressive'
            """
            super().__init__()

            # Load base ResNet50 without downloading weights
            try:
                base = models.resnet50(weights=None)
            except TypeError:
                base = models.resnet50(pretrained=False)

            # Copy all layers
            self.conv1 = base.conv1
            self.bn1 = base.bn1
            self.relu = base.relu
            self.maxpool = base.maxpool
            self.layer1 = base.layer1
            self.layer2 = base.layer2
            self.layer3 = base.layer3
            self.layer4 = base.layer4
            self.avgpool = base.avgpool
            self.fc = nn.Linear(2048, num_classes)

            self.checkpoint_strategy = checkpoint_strategy

        def _forward_stage(self, stage, x):
            """Forward through a stage (layer1, layer2, etc.)."""
            for block in stage:
                x = block(x)
            return x

        def forward(self, x):
            # Stem
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            # NOTE: ResNet blocks include BatchNorm, which updates running stats in train mode.
            # Checkpointing recomputes forward in backward, which can update BN stats twice.
            # For strict parity, consider freezing BN stats (eval for BN) or using GroupNorm.

            if self.checkpoint_strategy == 'none':
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

            elif self.checkpoint_strategy == 'per_stage':
                x = checkpoint(lambda t: self._forward_stage(self.layer1, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer2, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer3, t), x, use_reentrant=False)
                x = checkpoint(lambda t: self._forward_stage(self.layer4, t), x, use_reentrant=False)

            elif self.checkpoint_strategy == 'per_block':
                for block in self.layer1:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer2:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer3:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer4:
                    x = checkpoint(block, x, use_reentrant=False)

            elif self.checkpoint_strategy == 'aggressive':
                x = self.layer1(x)
                x = self.layer2(x)
                for block in self.layer3:
                    x = checkpoint(block, x, use_reentrant=False)
                for block in self.layer4:
                    x = checkpoint(block, x, use_reentrant=False)
            else:
                raise ValueError(f"Unknown checkpoint_strategy: {self.checkpoint_strategy}")

            # Head
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

    print("ResNet50Checkpointed created with 4 strategies:")
    print("-" * 60)
    print("'none':       No checkpointing (baseline)")
    print("'per_stage':  Checkpoint each of the 4 stages")
    print("'per_block':  Checkpoint each of the 16 bottleneck blocks")
    print("'aggressive': Only checkpoint layer3 and layer4 (less recompute; compare savings in 12.3)")

In [None]:
# 12.3 Memory Comparison: ResNet50 Strategies

import torch
import torch.nn as nn
import time

def benchmark_resnet50(strategy, batch_size=32, image_size=224, num_iters=5):
    """Benchmark memory and time for a ResNet50 checkpointing strategy."""
    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    model = ResNet50Checkpointed(num_classes=1000, checkpoint_strategy=strategy).cuda()
    model.train()

    x = torch.randn(batch_size, 3, image_size, image_size, device="cuda")
    target = torch.randint(0, 1000, (batch_size,), device="cuda")
    criterion = nn.CrossEntropyLoss()

    # Warmup
    for _ in range(2):
        out = model(x)
        loss = criterion(out, target)
        loss.backward()
        model.zero_grad(set_to_none=True)

    # Measure memory + time
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        out = model(x)
        loss = criterion(out, target)
        loss.backward()
        model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()

    elapsed = (time.time() - start) / num_iters * 1000  # ms/iter
    peak_mem = torch.cuda.max_memory_allocated() / 1e6  # MB

    del model, x, target
    torch.cuda.empty_cache()

    return peak_mem, elapsed

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    batch_size = 32
    image_size = 224

    print("ResNet50 Checkpointing Comparison")
    print(f"Batch size: {batch_size}, Image size: {image_size}x{image_size}")
    print("-" * 70)
    print(f"{'Strategy':<15} {'Peak Memory (MB)':<20} {'Time (ms)':<15} {'Mem Savings':<15}")
    print("-" * 70)

    strategies = ['none', 'aggressive', 'per_stage', 'per_block']
    results = {}

    for strategy in strategies:
        mem, time_ms = benchmark_resnet50(strategy, batch_size, image_size)
        results[strategy] = (mem, time_ms)

    baseline_mem = results['none'][0]
    for strategy in strategies:
        mem, time_ms = results[strategy]
        savings = (1 - mem / baseline_mem) * 100
        print(f"{strategy:<15} {mem:<20.2f} {time_ms:<15.2f} {savings:>6.1f}%")

    print("-" * 70)
    print("\nObservations:")
    print("- 'aggressive' recomputes only layer3/layer4; the larger layer1/layer2 feature maps stay stored")
    print("- 'per_stage' and 'per_block' both recompute every block once, so their time cost is similar")
    print("- 'per_block' usually has the lowest peak: only one block is rematerialized at a time")
else:
    print("[Run on GPU and ensure torchvision is installed for ResNet50 benchmarks]")

In [None]:
# 12.4 Maximum Batch Size: ResNet50

import torch
import torch.nn as nn

def find_max_batch_resnet50(strategy, start=16, max_batch=256):
    """Find maximum batch size before OOM for ResNet50 (CUDA only)."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for find_max_batch_resnet50().")
    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    start = max(1, int(start))
    max_batch = max(start, int(max_batch))

    def can_run(batch_size):
        model = None
        x = None
        target = None
        out = None
        loss = None
        try:
            torch.cuda.empty_cache()
            model = ResNet50Checkpointed(checkpoint_strategy=strategy).cuda()
            model.train()
            x = torch.randn(batch_size, 3, 224, 224, device="cuda")
            target = torch.randint(0, 1000, (batch_size,), device="cuda")

            out = model(x)
            loss = nn.CrossEntropyLoss()(out, target)
            loss.backward()
            return True
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                return False
            raise
        finally:
            try:
                if model is not None:
                    model.zero_grad(set_to_none=True)
            except Exception:
                pass
            del model, x, target, out, loss
            torch.cuda.empty_cache()

    # Exponential search to find an upper bound, then binary search.
    max_working = 0
    batch = start
    while batch <= max_batch and can_run(batch):
        max_working = batch
        batch *= 2

    low = max_working + 1
    high = min(batch, max_batch)

    while low <= high:
        mid = (low + high) // 2
        if can_run(mid):
            max_working = mid
            low = mid + 1
        else:
            high = mid - 1

    return max_working

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    print("Finding maximum batch size for ResNet50 (224x224 images)")
    print("-" * 50)

    strategies = ['none', 'aggressive', 'per_block']
    baseline = None

    for strategy in strategies:
        max_batch = find_max_batch_resnet50(strategy)
        if baseline is None:
            baseline = max_batch
        improvement = (max_batch / baseline) if baseline else float('inf')
        print(f"{strategy:<15}: max batch = {max_batch:>3d} ({improvement:.2f}x)")

    print("-" * 50)
    print("\nWith checkpointing, you can often fit larger batches.")
else:
    print("[Run on GPU and ensure torchvision is installed to find max batch sizes]")

In [None]:
# 12.5 Complete Training Loop: ResNet50 with Checkpointing

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_resnet50_with_checkpointing(
    checkpoint_strategy='aggressive',
    batch_size=32,
    num_epochs=3,
    num_samples=256,  # Small for demo
    use_amp=True,
):
    """Complete training loop with checkpointing and mixed precision."""

    if ResNet50Checkpointed is None:
        raise RuntimeError("ResNet50Checkpointed is not available (torchvision import likely failed).")

    # Create model
    model = ResNet50Checkpointed(
        num_classes=10,  # Simplified for demo
        checkpoint_strategy=checkpoint_strategy,
    ).cuda()

    # Create synthetic dataset
    X = torch.randn(num_samples, 3, 224, 224)
    y = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler() if (use_amp and torch.cuda.is_available()) else None

    print(f"Training ResNet50 with '{checkpoint_strategy}' checkpointing")
    print(f"Batch size: {batch_size}, Mixed precision: {use_amp and scaler is not None}")
    print("-" * 50)

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        torch.cuda.reset_peak_memory_stats()

        for batch_x, batch_y in loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()

            optimizer.zero_grad(set_to_none=True)

            if scaler is not None:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = model(batch_x)
                    loss = criterion(outputs, batch_y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        peak_mem = torch.cuda.max_memory_allocated() / 1e6
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}: loss = {avg_loss:.4f}, peak memory = {peak_mem:.2f} MB")

    print("-" * 50)
    print("Training complete!")
    return model

if torch.cuda.is_available() and ResNet50Checkpointed is not None:
    model = train_resnet50_with_checkpointing(
        checkpoint_strategy='aggressive',
        batch_size=32,
        num_epochs=3,
        use_amp=True,
    )
else:
    print("[Run on GPU and ensure torchvision is installed for training demo]")

In [None]:
# 12.6 Alternative: Patching torchvision's ResNet50 with checkpoint_sequential

import torch
from torch.utils.checkpoint import checkpoint_sequential

try:
    import torchvision.models as models
except Exception as e:
    models = None
    print("torchvision is required for this alternative ResNet50 example.")
    print(f"Details: {type(e).__name__}: {e}")


def _checkpoint_sequential(mod, segments, x):
    """checkpoint_sequential signature varies across PyTorch versions."""
    try:
        return checkpoint_sequential(mod, segments, x, use_reentrant=False)
    except TypeError:
        return checkpoint_sequential(mod, segments, x)


def resnet50_with_sequential_checkpoint(num_segments=4):
    """Monkey-patch torchvision ResNet50 to checkpoint its stage blocks."""
    if models is None:
        raise RuntimeError("torchvision is not available")

    try:
        model = models.resnet50(weights=None)
    except TypeError:
        model = models.resnet50(pretrained=False)

    def checkpointed_forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Checkpoint each stage with user-controlled granularity.
        x = _checkpoint_sequential(self.layer1, min(num_segments, len(self.layer1)), x)
        x = _checkpoint_sequential(self.layer2, min(num_segments, len(self.layer2)), x)
        x = _checkpoint_sequential(self.layer3, min(num_segments, len(self.layer3)), x)
        x = _checkpoint_sequential(self.layer4, min(num_segments, len(self.layer4)), x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    import types

    model.forward = types.MethodType(checkpointed_forward, model)
    return model

# Test it
if models is not None:
    model = resnet50_with_sequential_checkpoint(num_segments=4)
    print("Alternative approach: Monkey-patch torchvision ResNet50")
    print("-" * 60)
    print("Pros:")
    print("  - Uses official torchvision model")
    print("  - Simple implementation")
    print("  - Adjustable granularity via num_segments")
    print()
    print("Cons:")
    print("  - Less control over which blocks to checkpoint")
    print("  - Harder to switch strategies dynamically")
    print()
    print("Use the custom class (ResNet50Checkpointed) for flexible experiments.")

### Impressions/Conclusions (12: ResNet50 Case Study)

ResNet50 is a perfect testbed for checkpointing because:
- 16 bottleneck blocks across 4 stages
- Deep enough to benefit, not so deep it is impractical
- Widely used in production

Key findings:

**Strategy Selection:**
- `aggressive` (checkpoint layer3+layer4): Best ROI. These stages hold most of the parameters and the widest feature maps (1024 and 2048 channels), although their spatial resolution is the smallest in the network.
- `per_block`: Maximum memory savings (~40-50%), but highest compute overhead (~30-40%).
- `per_stage`: Middle ground. Good for quick wins.
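The strategies differ only in which submodules run under `checkpoint`. Below is a minimal sketch of the `aggressive` idea on a toy four-stage model. `CheckpointWrapper` and `apply_stage_checkpointing` are hypothetical helpers, not PyTorch APIs. The stage names `layer1`..`layer4` just mirror torchvision's ResNet.

```python
import copy
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointWrapper(nn.Module):
    """Hypothetical helper: run the wrapped module under non-reentrant checkpointing."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # Checkpoint only in training with autograd on; in eval or under no_grad, run normally.
        if self.training and torch.is_grad_enabled():
            return checkpoint(self.module, x, use_reentrant=False)
        return self.module(x)


def apply_stage_checkpointing(model, stage_names):
    """Replace each named child stage with a CheckpointWrapper around it (in place)."""
    for name in stage_names:
        setattr(model, name, CheckpointWrapper(getattr(model, name)))
    return model


def make_toy_model(width=32):
    # Four toy "stages" named like torchvision's ResNet stages (Linear layers, not real blocks).
    stages = [(f"layer{i}", nn.Sequential(nn.Linear(width, width), nn.ReLU())) for i in range(1, 5)]
    return nn.Sequential(OrderedDict(stages))


torch.manual_seed(0)
baseline = make_toy_model()
aggressive = apply_stage_checkpointing(copy.deepcopy(baseline), ["layer3", "layer4"])
print([type(m).__name__ for m in aggressive.children()])
# ['Sequential', 'Sequential', 'CheckpointWrapper', 'CheckpointWrapper']
```

Wrapping adds a `module.` level, so a checkpointed parameter is named `layer3.module.0.weight` rather than `layer3.0.weight`. Load state dicts before wrapping, or remap the keys.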

**Practical Impact:**
- 1.5-2x larger batch sizes possible
- Combined with mixed precision: up to ~3x larger batches

**BatchNorm Caveat:**
ResNet blocks use BatchNorm. Checkpointing recomputes forward in backward, which can update BN running stats twice in `train()` mode. If you need strict parity, consider freezing BN stats (set BN layers to eval) or using GroupNorm.
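You can see the double update in BatchNorm's `num_batches_tracked` buffer, which increments on every training-mode forward. The sketch below uses a toy Conv-BN-ReLU block, and `freeze_batchnorm` and `bn_forwards_in_one_step` are illustrative helpers. The checkpointed step counts two BN forwards. Freezing BN stops the updates entirely, but the block then normalizes with its running statistics instead of batch statistics.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_batchnorm(module):
    """Put every BatchNorm layer in eval mode so its running stats stop updating."""
    for m in module.modules():
        if isinstance(m, BN_TYPES):
            m.eval()
    return module


def bn_forwards_in_one_step(use_checkpoint, freeze_bn=False):
    """Return how many training-mode BN forwards one forward+backward step triggered."""
    torch.manual_seed(0)
    block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).train()
    if freeze_bn:
        freeze_batchnorm(block)
    x = torch.randn(4, 3, 16, 16)
    out = checkpoint(block, x, use_reentrant=False) if use_checkpoint else block(x)
    out.sum().backward()  # with checkpointing, the block's forward is rerun here
    return int(block[1].num_batches_tracked)


print("plain:", bn_forwards_in_one_step(False))                         # 1
print("checkpointed:", bn_forwards_in_one_step(True))                   # 2 (forward + recompute)
print("checkpointed, frozen BN:", bn_forwards_in_one_step(True, True))  # 0
```

`model.train()` flips every submodule back to training mode, so call `freeze_batchnorm` again after it.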

**Production Recommendations:**
1. Start with `aggressive` strategy
2. Combine with mixed precision (AMP)
3. Profile before and after
4. If still OOM, move to `per_block`

The ResNet50Checkpointed class is a solid starting point. Copy it into your codebase and adapt it to your training setup and normalization strategy.

---
# 13. Conclusion

Activation checkpointing is a memory-compute trade-off: it spends extra forward computation during backward to cut activation memory. When memory, not compute, is the bottleneck, the trade-off is usually worth it.
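You can make the trade-off visible by measuring what autograd keeps alive. The rough sketch below uses a toy MLP. `torch.autograd.graph.saved_tensors_hooks` intercepts every tensor saved for backward, and the hypothetical `saved_activation_bytes` helper totals the unique non-parameter storages among them. Inside a non-reentrant `checkpoint` region, PyTorch installs its own hooks, and only the innermost pair applies. The outer hook therefore mostly sees the region's inputs, which is exactly the "keep only the inputs" behavior.

```python
import torch
import torch.nn as nn
from torch.autograd.graph import saved_tensors_hooks
from torch.utils.checkpoint import checkpoint


def saved_activation_bytes(model, x, use_checkpoint):
    """Bytes of unique non-parameter storages autograd saved for backward (hypothetical helper)."""
    param_ptrs = {p.untyped_storage().data_ptr() for p in model.parameters()}
    seen = {}

    def pack(t):
        ptr = t.untyped_storage().data_ptr()
        if ptr not in param_ptrs:  # skip weights (and views of weights) saved by Linear
            seen[ptr] = t.untyped_storage().nbytes()
        return t

    with saved_tensors_hooks(pack, lambda t: t):
        out = checkpoint(model, x, use_reentrant=False) if use_checkpoint else model(x)
    out.sum().backward()  # still works; when checkpointed, recomputation happens here
    return sum(seen.values())


torch.manual_seed(0)
mlp = nn.Sequential(
    nn.Linear(64, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 10),
)
x = torch.randn(256, 64)
plain_bytes = saved_activation_bytes(mlp, x, use_checkpoint=False)
ckpt_bytes = saved_activation_bytes(mlp, x, use_checkpoint=True)
print(f"saved activations: plain={plain_bytes / 2**20:.2f} MiB, checkpointed={ckpt_bytes / 2**20:.2f} MiB")
```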

In [None]:
# 13.1 Key Takeaways

print("""
KEY TAKEAWAYS
=============

1. ACTIVATIONS DOMINATE MEMORY
   - Not parameters, not gradients
   - Scales with batch_size * depth * hidden_size
   - Peak memory occurs at start of backward pass

2. CHECKPOINTING TRADES COMPUTE FOR MEMORY
   - Save some activations (checkpoints)
   - Recompute others during backward pass
   - Typical: 30-50% memory savings, ~30-40% compute overhead (about one extra forward pass)

3. THE SQRT(N) RULE
   - Checkpoint every ~sqrt(n) layers: O(sqrt(n)) activation memory for n layers
   - Cost: about one extra forward pass
   - 100 layers -> ~10 checkpoints + ~10 recomputed activations -> ~5x memory reduction

4. PRACTICAL PATTERNS
   - checkpoint(): Single modules/functions
   - checkpoint_sequential(): Sequential models
   - Always use use_reentrant=False

5. ADVANCED: SELECTIVE ACTIVATION CHECKPOINT (SAC)
   - Fine-grained control: choose what to save vs recompute
   - Policy functions: MUST_SAVE for expensive ops, PREFER_RECOMPUTE for cheap ones
   - Sweet spot: save matmuls, recompute pointwise ops
   - Match ops by overload, e.g. torch.ops.aten.mm.default (not aten.mm)

6. ADVANCED: MEMORY BUDGET API (torch.compile)
   - One line: torch._dynamo.config.activation_memory_budget = 0.5
   - Automatic Pareto-optimal recomputation strategies
   - Budget 0 = full AC, Budget 1 = default compile

7. COMBINE WITH OTHER TECHNIQUES
   - Mixed precision: ~2x smaller activations with fp16/bf16
   - Gradient accumulation: Effective larger batches
   - DDP: Scale across GPUs
   - Together: Train 4-5x larger models
""")
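The arithmetic behind takeaway 3 fits in a few lines. This back-of-the-envelope sketch assumes n identical layers, each needing one unit of activation memory. With a checkpoint every k layers, you hold about n/k checkpoints plus the k activations of the segment currently being recomputed. That sum is smallest near k = sqrt(n).

```python
import math


def peak_units(n_layers, segment_len):
    """Approximate peak activation units: stored checkpoints + one recomputed segment."""
    n_checkpoints = math.ceil(n_layers / segment_len)
    return n_checkpoints + segment_len


def best_segment_len(n_layers):
    """Brute-force the segment length that minimizes peak_units."""
    return min(range(1, n_layers + 1), key=lambda k: peak_units(n_layers, k))


n = 100
k = best_segment_len(n)
print(f"n={n}: segment length {k}, peak {peak_units(n, k)} units vs {n} without checkpointing")
# n=100: segment length 10, peak 20 units vs 100 without checkpointing
```

The ~5x figure (20 units instead of 100) comes from this uniform-layer model. Real layers differ in size and cost, so uniform sqrt(n) spacing is a heuristic rather than a guarantee.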

In [None]:
# 13.2 Decision Framework: When to Use Checkpointing

print("""
WHEN TO USE ACTIVATION CHECKPOINTING
====================================

USE IT WHEN:
- You are hitting OOM errors
- You want larger batch sizes
- Your model has 10+ layers
- Training time is less critical than memory
- You are training transformers or deep CNNs

SKIP IT WHEN:
- Model fits comfortably in memory
- Training time is the bottleneck (not memory)
- Model is shallow (< 5 layers)
- You need maximum training speed

QUICK DECISION TREE:

  OOM error?
     |-- YES --> Use checkpointing
     |
     NO
     |
     v
  Want larger batches?
     |-- YES --> Use checkpointing
     |
     NO
     |
     v
  Model > 10 layers?
     |-- YES --> Consider checkpointing
     |
     NO --> Probably skip it
""")
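If you want the choice recorded in a training config, the same tree fits in a tiny function. `checkpointing_advice` and its 10-layer threshold are illustrative and taken from the lists above. They do not come from any library.

```python
def checkpointing_advice(hit_oom, want_larger_batches, num_layers):
    """Encode the quick decision tree; thresholds mirror the prose, not a library rule."""
    if hit_oom:
        return "use checkpointing"
    if want_larger_batches:
        return "use checkpointing"
    if num_layers > 10:
        return "consider checkpointing"
    return "probably skip it"


print(checkpointing_advice(hit_oom=False, want_larger_batches=False, num_layers=50))
# consider checkpointing
```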

### Final Impressions

Activation checkpointing is not magic. It is a simple trade-off executed well.

You now understand:
- Why activations dominate memory (batch x depth x hidden)
- How checkpointing works (recompute instead of store)
- When to use it (memory-bound, deep models)
- Basic implementation (checkpoint, checkpoint_sequential, use_reentrant=False)
- Advanced control with SAC (choose exactly what to save)
- Automatic optimization with Memory Budget API (one config line)

The landscape of techniques:
- **Eager (no AC)**: No recomputation overhead, maximum activation memory
- **torch.compile**: Free speedups, some automatic memory savings
- **Memory Budget API**: Tunable compile-time optimization (0 to 1)
- **Selective AC**: Manual control over save/recompute decisions
- **Standard AC**: Maximum memory savings, roughly one extra forward pass of compute (~30-40% overhead)

Start simple. Add complexity only when needed. Measure everything.

Large language model training routinely relies on some form of activation checkpointing. Now you know exactly how it works and when to use each variant.

Go train something bigger.