# Activation Checkpointing: A First Principles Deep Dive

Memory is often the bottleneck when training deep networks. Not compute. Not data. Memory.

This notebook tears apart activation checkpointing from first principles. You will understand exactly why activations dominate memory, how checkpointing trades compute for memory, and when this trade-off makes sense.

Most code cells are followed by an Impressions/Conclusions section. These are not summaries. They are insights.

---
# 1. Introduction

Training deep neural networks hits a wall. That wall is GPU memory.

You have a model. You have data. You have compute. But your GPU runs out of memory. Why?

Four things consume most GPU memory during training:
1. **Model parameters** (weights and biases)
2. **Gradients** (same size as parameters)
3. **Optimizer state** (e.g., Adam keeps two extra tensors per parameter)
4. **Activations** (intermediate outputs from each layer)

Here is the surprise: at realistic batch sizes and sequence lengths, activations often dominate. Not parameters. Not gradients. Activations.

In [None]:
# 1.1 The Memory Breakdown: Where Does Memory Go?

import torch
import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def bytes_to_mb(b):
    return b / (1024 * 1024)

class DeepNetwork(nn.Module):
    def __init__(self, input_size=1024, hidden_size=1024, num_layers=10):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_f = input_size if i == 0 else hidden_size
            layers.append(nn.Linear(in_f, hidden_size))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)

model = DeepNetwork(input_size=1024, hidden_size=1024, num_layers=10)
num_params = count_parameters(model)
param_mem = num_params * 4
grad_mem = param_mem

print(f"Model: {num_params:,} parameters")
print(f"Param memory: {bytes_to_mb(param_mem):.2f} MB")
print(f"Gradient memory: {bytes_to_mb(grad_mem):.2f} MB")
print(f"Total (params+grads): {bytes_to_mb(param_mem + grad_mem):.2f} MB")
print()
print("Activation memory by batch size:")
print("-" * 50)

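for batch_size in [32, 64, 128, 256, 512]:
    # Rough estimate: one fp32 (batch, 1024) output kept per layer, 10 layers.
    # Real layers often keep several tensors each, so treat this as a lower bound.
    act_mem = batch_size * 1024 * 10 * 4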
    total = param_mem + grad_mem + act_mem
    pct = (act_mem / total) * 100
    print(f"Batch {batch_size:3d}: {bytes_to_mb(act_mem):6.2f} MB ({pct:.1f}% of total)")

### Impressions/Conclusions (1.1)

Activation memory scales with batch size. Parameter memory does not.

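In this toy MLP the numbers are modest: activations are ~1.5% of the total at batch 32 and ~20% at batch 512, because each layer keeps just one 1024-wide tensor per sample (optimizer state is ignored here). The trend is the point. Every doubling of the batch doubles activation memory while parameters and gradients stay fixed. Real models keep many tensors per layer (attention maps, normalization inputs, MLP expansions) and process long sequences, so activations typically overtake parameters at far smaller batches. That is why raising the batch size or sequence length is such a common way to OOM.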

The math:
- Parameters: O(model_size)
- Activations: O(batch_size × depth × hidden_size)

Batch size is the multiplier that kills you.
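The crossover point can be computed directly. The sketch below uses a hypothetical helper, `mlp_memory_bytes`, with the same simplifying assumptions as the cell above (fp32, one kept activation per layer, no optimizer state or temporary buffers) to find the batch size at which activations equal parameters plus gradients.

```python
# Hypothetical helper: a rough fp32 memory model for the Linear+ReLU stack above.
# Assumes one (batch, hidden) activation kept per layer; ignores optimizer state,
# temporary buffers, and allocator overhead.
def mlp_memory_bytes(batch, hidden, num_layers, input_size=None, bytes_per_elem=4):
    input_size = input_size or hidden
    params = (input_size * hidden + hidden) + (num_layers - 1) * (hidden * hidden + hidden)
    fixed = 2 * params * bytes_per_elem                  # parameters + gradients
    acts = batch * hidden * num_layers * bytes_per_elem  # grows linearly with batch
    return fixed, acts

# Batch size at which activations equal parameters + gradients
fixed, acts_per_sample = mlp_memory_bytes(batch=1, hidden=1024, num_layers=10)
crossover = fixed / acts_per_sample
print(f"Activations overtake params+grads at batch ~{crossover:.0f}")
```

Under these assumptions the 10-layer, 1024-wide MLP needs a batch of about 2050 before activations match parameters plus gradients. Models that keep more tensors per layer, or process many tokens per sample, cross over far earlier.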

In [None]:
# 1.2 Real GPU Memory Measurement

import torch
import torch.nn as nn

class DeepNetwork(nn.Module):
    def __init__(self, input_size=1024, hidden_size=2048, num_layers=20):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_f = input_size if i == 0 else hidden_size
            layers.append(nn.Linear(in_f, hidden_size))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    
    model = DeepNetwork().cuda()
    
    for batch_size in [16, 32, 64, 128]:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        
        x = torch.randn(batch_size, 1024).cuda()
        output = model(x)
        loss = output.sum()
        loss.backward()
        
        peak = torch.cuda.max_memory_allocated() / 1e6
        print(f"Batch {batch_size:3d}: Peak = {peak:.2f} MB")
        
        model.zero_grad()
        del x, output, loss
else:
    print("[Run on GPU to see measurements]")

### Impressions/Conclusions (1.2)

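Peak memory is a fixed part (parameters and gradients) plus a part that grows linearly with batch size (activations). Here the fixed part dominates: this 20-layer, hidden=2048 MLP has about 82M fp32 parameters, so weights plus gradients take roughly 650 MB, while the kept activations at batch 128 are only on the order of 20 MB. Doubling the batch doubles the activation part, not the total.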

This is what activation checkpointing solves. Reduce activation memory to:
1. Train larger models
2. Use larger batches
3. Go deeper without OOM

The solution: do not store all activations. Recompute them when needed.

---
# 2. Understanding Activation Checkpointing

Activation checkpointing is a memory-compute trade-off. Save memory by not storing activations. Pay for it by recomputing them during backprop.

To understand this, you need to know what activations are and why backprop needs them.

## 2.1 What Are Activations?

An activation is the output of a layer. That is it.

```
Input x -> [Layer 1] -> a1 -> [Layer 2] -> a2 -> [Layer 3] -> a3 -> Output
```

`a1`, `a2`, `a3` are activations. Each is a tensor stored in memory.

Why store them? Backpropagation needs them to compute gradients.
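You can watch autograd do this. `torch.autograd.graph.saved_tensors_hooks` (PyTorch 1.10+) calls a pack hook every time the forward pass saves a tensor for backward. This minimal sketch only records shapes and byte counts; exactly which tensors each op saves is an implementation detail that can vary between PyTorch versions.

```python
import torch

saved = []  # (shape, bytes) of every tensor autograd keeps for backward

def pack(t):
    saved.append((tuple(t.shape), t.numel() * t.element_size()))
    return t  # store the tensor unchanged

def unpack(t):
    return t

torch.manual_seed(0)
W1 = torch.randn(4, 4, requires_grad=True)
W2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    a1 = torch.relu(x @ W1.T)
    y = a1 @ W2.T

for shape, nbytes in saved:
    print(shape, nbytes, "bytes")
print("total saved:", sum(b for _, b in saved), "bytes")

y.sum().backward()  # unpack() hands the saved tensors back to backward
```

Every entry printed here is memory that stays allocated from the forward pass until backward consumes it. Checkpointing works by intercepting exactly this step.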

In [None]:
# 2.1 Why Backprop Needs Activations

import torch

# y = W2 * relu(W1 * x)
# Gradient w.r.t. W1 needs the pre-activation W1*x
# to compute relu'(W1*x)

torch.manual_seed(42)

W1 = torch.randn(4, 4, requires_grad=True)
W2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

# Forward pass stores these
z1 = x @ W1.T           # Pre-activation (needed for ReLU grad)
a1 = torch.relu(z1)     # Activation (needed for W2 grad)
y = a1 @ W2.T

print("Stored for backprop:")
print(f"  z1: {z1.shape} - needed to compute relu gradient")
print(f"  a1: {a1.shape} - needed to compute W2 gradient")

loss = y.sum()
loss.backward()

print(f"\nGradients computed:")
print(f"  W1.grad: {W1.grad.shape}")
print(f"  W2.grad: {W2.grad.shape}")

### Impressions/Conclusions (2.1)

The chain rule requires intermediate values. Backprop walks backward through the computation graph. At each step, it needs forward pass values.

For linear `y = Wx`: gradient w.r.t. W needs x.

For ReLU `y = relu(x)`: gradient needs to know where x > 0.

This is why activations are stored. This is what checkpointing cuts down: it keeps only a few of these values and recomputes the rest.
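To see exactly which stored values each gradient consumes, here is the same two-layer example with the backward pass written out by hand and checked against autograd. It is a sketch for this specific network only (`y = relu(x W1^T) W2^T`, loss = sum of `y`).

```python
import torch

torch.manual_seed(42)
W1 = torch.randn(4, 4, requires_grad=True)
W2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

# Forward pass: keep z1 and a1 around, as autograd would
z1 = x @ W1.T
a1 = torch.relu(z1)
y = a1 @ W2.T
y.sum().backward()  # autograd's gradients, for comparison

# Manual backward pass for loss = y.sum()
with torch.no_grad():
    g_y = torch.ones_like(y)        # dloss/dy
    g_W2 = g_y.T @ a1               # needs the stored activation a1
    g_a1 = g_y @ W2                 # needs the weight W2
    g_z1 = g_a1 * (z1 > 0).float()  # needs where z1 > 0 (the ReLU mask)
    g_W1 = g_z1.T @ x               # needs the input x

print("W1 grad matches:", torch.allclose(g_W1, W1.grad))
print("W2 grad matches:", torch.allclose(g_W2, W2.grad))
```

Each line of the manual backward reads a value from the forward pass. Drop any of them and you must recompute it, which is precisely the trade checkpointing makes.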

In [None]:
# 2.2 The Core Trade-off: Store vs Recompute

import math

print("Standard Training:")
print("  Forward: x -> a1 -> a2 -> a3 -> loss")
print("           [store] [store] [store]")
print("  Backward: uses stored a1, a2, a3")
print("  Memory: O(n) activations")
print()
print("With Checkpointing:")
print("  Forward: x -> a1 -> a2 -> a3 -> loss")
print("           [save]       [save]")
print("  Backward: recompute a2 from a1, then use")
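print("  Memory: O(sqrt(n)) with a checkpoint every ~sqrt(n) layers")
print()

num_layers = 100
k = int(math.sqrt(num_layers))            # checkpoint every k layers
checkpoints = math.ceil(num_layers / k)   # boundary activations kept
peak = checkpoints + k                    # plus one segment rebuilt at a time

print(f"For {num_layers} layers:")
print(f"  Standard: {num_layers} activations")
print(f"  Checkpointed: ~{peak} activations ({checkpoints} checkpoints + {k} recomputed)")
print(f"  Memory reduction: ~{num_layers / peak:.0f}x")
print("  Compute overhead: ~1 extra forward pass (~33% per step if backward costs ~2x forward)")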

### Impressions/Conclusions (2.2)

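The sqrt(n) rule: checkpointing every ~sqrt(n) layers reduces activation memory from O(n) to O(sqrt(n)). You keep ~sqrt(n) checkpoints, plus the activations of one segment while backward rebuilds it.

For 100 layers:
- Standard: 100 activations
- Checkpointed: ~20 activations (10 checkpoints + 10 in the segment being recomputed)
- Cost: one extra forward pass, roughly 33% more compute per training step

5x less activation memory for about a third more compute. When memory is the constraint, that is a very good trade-off.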
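Where does sqrt(n) come from? Under a simplified cost model (an assumption, not a profile of any real network), checkpointing every k layers of an n-layer chain keeps about n/k checkpoints, and backward rebuilds one k-layer segment at a time, so the peak is about n/k + k. That sum is minimized at k = sqrt(n), giving a peak of about 2·sqrt(n). The sketch brute-forces k to confirm.

```python
import math

def peak_stored(n, k):
    """Activations alive at peak when checkpointing every k layers of an n-layer chain.

    Simplified model: ceil(n / k) stored checkpoints plus the k activations of the
    single segment being recomputed during backward.
    """
    return math.ceil(n / k) + k

n = 100
costs = {k: peak_stored(n, k) for k in range(1, n + 1)}
best_k = min(costs, key=costs.get)  # smallest k achieving the minimum
print(f"n={n}: best segment size k={best_k}, peak ~{costs[best_k]} activations "
      f"(sqrt(n)={math.sqrt(n):.1f}, no checkpointing: {n})")
```

The minimum is shallow: nearby segment sizes cost only one or two extra activations, so rounding sqrt(n) to a convenient block boundary is fine in practice.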

---
# 3. Prerequisites

Environment setup and helper functions.

In [None]:
# 3.1 Environment Check

import torch
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

### Impressions/Conclusions (3.1)

The `use_reentrant=False` calls used throughout need a recent PyTorch: the argument does not exist in older 1.x releases, and 2.x is recommended. `torch.utils.checkpoint` provides the core API. CUDA is strongly recommended, since checkpointing shines when GPU memory is the bottleneck.

In [None]:
# 3.2 Memory Profiling Utilities

import torch
from contextlib import contextmanager

class MemoryTracker:
    """Track GPU memory usage."""
    
    def __init__(self):
        self.snapshots = []
    
    def snapshot(self, label=""):
        if torch.cuda.is_available():
            mem = torch.cuda.memory_allocated() / 1e6
            self.snapshots.append({'label': label, 'mb': mem})
            return mem
        return 0
    
    def reset(self):
        self.snapshots = []
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
    
    def peak_mb(self):
        if torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / 1e6
        return 0
    
    def report(self):
        for s in self.snapshots:
            print(f"{s['label']:30s}: {s['mb']:8.2f} MB")
        print(f"{'Peak':30s}: {self.peak_mb():8.2f} MB")

@contextmanager
def track_memory(label="Op"):
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    yield
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        print(f"{label}: Peak = {torch.cuda.max_memory_allocated()/1e6:.2f} MB")

print("Utilities loaded: MemoryTracker, track_memory()")

### Impressions/Conclusions (3.2)

Key functions:
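- `torch.cuda.memory_allocated()`: memory currently occupied by tensors
- `torch.cuda.max_memory_allocated()`: peak since the last `torch.cuda.reset_peak_memory_stats()`
- `torch.cuda.empty_cache()`: returns unused blocks held by PyTorch's caching allocator to the driver (it does not change `memory_allocated()`)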

Peak memory is what matters. That is what causes OOM.

---
# 4. How Activation Checkpointing Works

PyTorch provides `torch.utils.checkpoint`. Two main functions:
- `checkpoint`: wrap a single function/module
- `checkpoint_sequential`: wrap a sequence of modules

The API is simple. The magic is in what happens under the hood.
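To make "under the hood" concrete, here is a deliberately simplified checkpoint built from `torch.autograd.Function`. It follows the idea behind the reentrant implementation: run the forward without recording a graph, keep only the input, then in backward re-run the function with grad enabled and backpropagate through the rebuilt graph. `NaiveCheckpoint` is a hypothetical teaching helper; unlike `torch.utils.checkpoint`, it does not restore RNG or autocast state and handles a single tensor input.

```python
import torch
import torch.nn as nn

class NaiveCheckpoint(torch.autograd.Function):
    """Hypothetical teaching helper: single-tensor checkpoint via recomputation."""

    @staticmethod
    def forward(ctx, fn, x):
        # Function.forward runs with autograd disabled, so fn's intermediates are
        # not recorded and are freed as soon as fn returns.
        ctx.fn = fn
        ctx.save_for_backward(x)  # the only activation kept for backward
        return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.fn(x)  # recompute the forward, this time building a graph
        # Backprop through the rebuilt graph: fills x.grad and accumulates
        # gradients directly into fn's parameters.
        torch.autograd.backward(out, grad_out)
        return None, x.grad  # no gradient for fn, gradient for x

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 8))
x = torch.randn(3, 8, requires_grad=True)

# Reference gradients without checkpointing
block(x).sum().backward()
ref_x, ref_w = x.grad.clone(), block[0].weight.grad.clone()

# Same computation through the naive checkpoint
x.grad = None
block.zero_grad()
NaiveCheckpoint.apply(block, x).sum().backward()

print("input grads match: ", torch.allclose(x.grad, ref_x))
print("weight grads match:", torch.allclose(block[0].weight.grad, ref_w))
```

Note a consequence of this design: if `x` did not require grad, the Function's output would not require grad, `backward` would never run, and the parameters would get no gradients. The real reentrant variant shares that limitation.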

In [None]:
# 4.1 Model WITHOUT Checkpointing

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return x

if torch.cuda.is_available():
    model = SimpleModel().cuda()
    x = torch.randn(1, 1024).cuda()

    torch.cuda.reset_peak_memory_stats()
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    print(f"Peak Memory WITHOUT Checkpointing: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")
else:
    print("[Run on GPU]")

In [None]:
# 4.2 Model WITH Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)

    def forward(self, x):
        # Checkpoint layer1: its activation won't be stored
        # It will be recomputed during backward pass
        x = checkpoint(lambda t: torch.relu(self.layer1(t)), x, use_reentrant=False)
        x = torch.relu(self.layer2(x))
        return x

if torch.cuda.is_available():
    model = CheckpointedModel().cuda()
    x = torch.randn(1, 1024).cuda()

    torch.cuda.reset_peak_memory_stats()
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    print(f"Peak Memory WITH Checkpointing: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (4.1 & 4.2)

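Do not expect a difference here. The checkpointed region (layer1 + ReLU) has only one intermediate, layer1's output before the ReLU, and autograd does not save it anyway (ReLU's backward uses its output). Both versions therefore keep the same tensors: the input `x` and the ReLU output that layer2 needs. On top of that, with batch 1 the activations are a few KB next to ~8 MB of weights.

Savings appear when a checkpointed region contains many saved intermediates and activations are a large share of memory. Let us test with a deeper model.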
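A direct way to see recomputation, even on CPU, is to count how often a checkpointed module's `forward` runs. `CountingBlock` is a hypothetical helper. The counter is incremented at the top of `forward` so that a recomputation is recorded even if recent PyTorch versions stop it early once the needed tensors have been rebuilt.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """Hypothetical helper: Linear + tanh that counts its forward calls."""
    def __init__(self, dim=16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x):
        self.calls += 1  # count first, so a recomputation is always recorded
        return torch.tanh(self.linear(x))

block = CountingBlock()
x = torch.randn(4, 16)

y = checkpoint(block, x, use_reentrant=False)
calls_after_forward = block.calls
y.sum().backward()  # tanh's saved output is needed here, so the block re-runs

print("forward calls after forward: ", calls_after_forward)
print("forward calls after backward:", block.calls)
```

The second call is the compute you pay for the memory you saved. The counter itself is a side effect that runs twice, which is harmless here but is exactly the kind of behavior to watch for in real checkpointed code.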

In [None]:
# 4.3 Deep Model Comparison: The Real Difference

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DeepModelNoCheckpoint(nn.Module):
    def __init__(self, num_layers=20, hidden=2048):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

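class DeepModelWithCheckpoint(nn.Module):
    def __init__(self, num_layers=20, hidden=2048, segment_size=4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        ])
        self.segment_size = segment_size

    def _run_segment(self, x, start):
        for layer in self.layers[start:start + self.segment_size]:
            x = torch.relu(layer(x))
        return x

    def forward(self, x):
        # Checkpoint segments of ~sqrt(num_layers) layers. Checkpointing each single
        # Linear+ReLU would keep almost the same tensors as no checkpointing: every
        # checkpoint stores its input, which the Linear would have saved anyway.
        for start in range(0, len(self.layers), self.segment_size):
            x = checkpoint(self._run_segment, x, start, use_reentrant=False)
        return x

if torch.cuda.is_available():
    batch_size = 4096  # large enough that activations are a sizable share of peak memory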
    hidden = 2048
    num_layers = 20
    
    # Without checkpointing
    torch.cuda.empty_cache()
    model1 = DeepModelNoCheckpoint(num_layers, hidden).cuda()
    x1 = torch.randn(batch_size, hidden).cuda()
    
    torch.cuda.reset_peak_memory_stats()
    out1 = model1(x1)
    out1.sum().backward()
    mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e6
    
    del model1, x1, out1
    torch.cuda.empty_cache()
    
    # With checkpointing
    model2 = DeepModelWithCheckpoint(num_layers, hidden).cuda()
    x2 = torch.randn(batch_size, hidden).cuda()
    
    torch.cuda.reset_peak_memory_stats()
    out2 = model2(x2)
    out2.sum().backward()
    mem_ckpt = torch.cuda.max_memory_allocated() / 1e6
    
    print(f"Config: {num_layers} layers, hidden={hidden}, batch={batch_size}")
    print(f"WITHOUT checkpointing: {mem_no_ckpt:.2f} MB")
    print(f"WITH checkpointing:    {mem_ckpt:.2f} MB")
    print(f"Memory saved: {mem_no_ckpt - mem_ckpt:.2f} MB ({(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%)")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (4.3)

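Savings are capped by the share of peak memory that activations occupy. This model has about 84M fp32 parameters, so weights plus gradients take roughly 670 MB whatever you checkpoint. At batch 64 the kept activations would be only ~11 MB (about 21 tensors of 64×2048), so no checkpointing scheme could save more than a few percent; you need a large batch (or long sequences) before the difference is substantial.

Granularity matters as well. Checkpointing every single Linear+ReLU keeps nearly the same tensors as no checkpointing, because each checkpoint must store its input and that input is exactly what the Linear layer saves anyway. Checkpointing multi-layer segments is what removes the stored intermediates.

The key insight: activation memory grows linearly with depth. Segment checkpointing trades that O(n) growth for roughly O(sqrt(n)).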

---
# 5. Code Examples

Three patterns you will use:
1. Basic `checkpoint()` for single modules
2. `checkpoint_sequential()` for sequential models
3. Manual checkpointing in complex architectures (transformers)

In [None]:
# 5.1 Basic Usage

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())

    def forward(self, x):
        # Checkpoint block1
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = self.block2(x)
        return x

model = SimpleModel()
if torch.cuda.is_available():
    model = model.cuda()

x = torch.randn(1, 1024)
if torch.cuda.is_available():
    x = x.cuda()

output = model(x)
print(f"Output shape: {output.shape}")
print("Checkpointing applied to block1. Block2 runs normally.")

### Impressions/Conclusions (5.1)

`checkpoint(fn, *args)` wraps any callable. During forward, it runs `fn(*args)` but does not save intermediate activations. During backward, it re-runs `fn(*args)` to recompute them.

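Note: `use_reentrant=False` is the variant PyTorch recommends, and recent releases warn if you omit the argument. The older reentrant variant only works with `.backward()` (not `torch.autograd.grad`), and if none of the checkpoint's inputs require grad, the checkpointed region gets no gradients at all.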
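A concrete difference between the two variants shows up when no input to the checkpoint requires grad, which is the normal case for a batch of data. In this sketch the reentrant call produces an output detached from the layer's parameters, while the non-reentrant call keeps them in the graph. The reentrant path emits a warning in this situation, which the sketch silences.

```python
import warnings
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = nn.Linear(8, 8)
x = torch.randn(2, 8)  # a plain input: requires_grad is False

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # reentrant path warns that no input requires grad
    y_reentrant = checkpoint(layer, x, use_reentrant=True)
y_nonreentrant = checkpoint(layer, x, use_reentrant=False)

print("reentrant output requires grad:    ", y_reentrant.requires_grad)
print("non-reentrant output requires grad:", y_nonreentrant.requires_grad)

y_nonreentrant.sum().backward()
print("layer.weight.grad is set:", layer.weight.grad is not None)
```

With the reentrant variant, calling `backward()` on a loss built only from `y_reentrant` would fail, and in a larger model the checkpointed layers would silently receive no gradients.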

In [None]:
# 5.2 Using checkpoint_sequential

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep sequential model
model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
)

x = torch.randn(1, 1024)
if torch.cuda.is_available():
    model = model.cuda()
    x = x.cuda()

# Split the 8 child modules into 2 segments. Every segment except the last is
# checkpointed; the last one runs normally.
segments = 2
output = checkpoint_sequential(model, segments, x, use_reentrant=False)

print(f"Output shape: {output.shape}")
print(f"Model split into {segments} segments")
print("Checkpointed segments recompute their activations during the backward pass")

### Impressions/Conclusions (5.2)

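`checkpoint_sequential` is convenient for nn.Sequential models. It splits the module's children (here 8: four Linear and four ReLU) into `segments` chunks and checkpoints every chunk except the last, which runs without checkpointing. The `segments` parameter controls granularity:
- Fewer segments: fewer stored boundary tensors, but each segment's full activations must be rebuilt at once during backward
- More segments: smaller segments to rebuild, but more boundary tensors kept
- Recompute cost is roughly one extra forward pass either way

Rule of thumb: start with about sqrt(num_layers) segments.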
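Counting forward calls per module makes the segment behavior of `checkpoint_sequential` visible. `Counted` is a hypothetical helper. With 4 modules and 2 segments, the first segment (modules 0 and 1) is checkpointed and re-runs during backward, while the last segment (modules 2 and 3) runs only once.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class Counted(nn.Module):
    """Hypothetical helper: Linear + tanh that counts how often its forward runs."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x):
        self.calls += 1  # incremented before any work, so recomputation is always counted
        return torch.tanh(self.linear(x))

seq_model = nn.Sequential(*[Counted(16) for _ in range(4)])
x = torch.randn(2, 16)

out = checkpoint_sequential(seq_model, 2, x, use_reentrant=False)
out.sum().backward()

calls = [m.calls for m in seq_model]
print("forward calls per module:", calls)
```

Because the last segment is never checkpointed, its activations are always stored. Keep that in mind when the tail of a model is activation-heavy.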

In [None]:
# 5.3 Transformer Training with Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size)
        )
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, x):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward with residual (checkpointed)
        ff_out = checkpoint(self.feed_forward, x, use_reentrant=False)
        x = self.norm2(x + ff_out)
        return x

# Example usage
embed_size = 512
seq_length = 64
batch_size = 8

block = TransformerBlock(embed_size)
if torch.cuda.is_available():
    block = block.cuda()

x = torch.randn(batch_size, seq_length, embed_size)
if torch.cuda.is_available():
    x = x.cuda()

output = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Feed-forward block is checkpointed (its 4x-wide hidden activation is large)")

### Impressions/Conclusions (5.3)

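In transformers, the feed-forward block expands to 4x the hidden size, so its intermediate activation is large and checkpointing it is a cheap, reasonable choice.

It is not the only one. When the attention score matrix is materialized, it is batch × heads × seq × seq, grows quadratically with sequence length, and is relatively cheap to recompute. Large language model training therefore commonly checkpoints the entire transformer block, or selectively recomputes just the attention core.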
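Checkpointing whole blocks is a one-line change. The sketch below uses PyTorch's `nn.TransformerEncoderLayer` as a stand-in for the hand-written block (the sizes are arbitrary assumptions), checkpoints each layer in full, and checks that gradients match the uncheckpointed run. Dropout is set to 0.0 so the comparison is deterministic. The loss is a random projection rather than a plain sum, because summing a LayerNorm output gives near-zero gradients for everything before it.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layers = nn.ModuleList([
    nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256,
                               dropout=0.0, batch_first=True)
    for _ in range(3)
])
x = torch.randn(2, 10, 64)
proj = torch.randn(2, 10, 64)  # random projection used as the loss weights

def run(use_ckpt):
    h = x
    for layer in layers:
        # Whole-block checkpointing: only each block's input is kept for backward
        h = checkpoint(layer, h, use_reentrant=False) if use_ckpt else layer(h)
    return (h * proj).sum()

layers.zero_grad()
run(False).backward()
ref = [p.grad.clone() for p in layers.parameters()]

layers.zero_grad()
run(True).backward()
match = all(torch.allclose(r, p.grad, atol=1e-6) for r, p in zip(ref, layers.parameters()))
print("gradients match:", match)
```

Whole-block checkpointing keeps one tensor per layer and recomputes everything inside, attention included. It is the simplest policy to apply uniformly across a deep stack.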

In [None]:
# 5.4 ResNet with Selective Checkpointing (Bonus)

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class BasicBlock(nn.Module):
    """ResNet basic block with optional checkpointing."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return torch.relu(out)

class MiniResNet(nn.Module):
    """Small ResNet with checkpointing options."""
    def __init__(self, num_blocks=4, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.blocks = nn.ModuleList([
            BasicBlock(64 if i == 0 else 128, 128, stride=2 if i == 0 else 1)
            for i in range(num_blocks)
        ])
        
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Compare memory
if torch.cuda.is_available():
    batch_size = 32
    
    # Without checkpoint
    torch.cuda.empty_cache()
    model1 = MiniResNet(num_blocks=8, use_checkpoint=False).cuda()
    x1 = torch.randn(batch_size, 3, 32, 32).cuda()
    torch.cuda.reset_peak_memory_stats()
    out1 = model1(x1)
    out1.sum().backward()
    mem1 = torch.cuda.max_memory_allocated() / 1e6
    
    del model1, x1, out1
    torch.cuda.empty_cache()
    
    # With checkpoint
    model2 = MiniResNet(num_blocks=8, use_checkpoint=True).cuda()
    x2 = torch.randn(batch_size, 3, 32, 32).cuda()
    torch.cuda.reset_peak_memory_stats()
    out2 = model2(x2)
    out2.sum().backward()
    mem2 = torch.cuda.max_memory_allocated() / 1e6
    
    print(f"ResNet with 8 blocks, batch={batch_size}")
    print(f"WITHOUT checkpoint: {mem1:.2f} MB")
    print(f"WITH checkpoint:    {mem2:.2f} MB")
    print(f"Savings: {(1 - mem2/mem1)*100:.1f}%")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (5.4)

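CNNs benefit from checkpointing too. Each residual block holds several intermediate feature maps (conv outputs, BatchNorm inputs, ReLU outputs) that can be recomputed from the block's input. For deep ResNets, this lets you train with larger batches or higher resolutions on the same GPU.

Key pattern: wrap entire blocks, not individual layers. Each checkpoint must keep its input, so tiny regions save little while still paying the call overhead.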
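One gotcha for CNNs like this one: BatchNorm in training mode updates its running statistics on every forward call, and checkpointing re-runs the forward during backward. This sketch counts the updates through BatchNorm's `num_batches_tracked` buffer for a single checkpointed step.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))  # training mode by default
x = torch.randn(8, 4)

y = checkpoint(block, x, use_reentrant=False)
after_forward = block[1].num_batches_tracked.item()

(y * torch.randn_like(y)).sum().backward()  # backward triggers the recomputation
after_backward = block[1].num_batches_tracked.item()

print("BatchNorm updates after forward: ", after_forward)
print("BatchNorm updates after backward:", after_backward)
```

The running mean and variance are updated twice per step, which effectively changes the momentum. If that matters for your model, one option is to keep BatchNorm layers outside the checkpointed region.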

---
# 6. Performance Benchmarks

Numbers matter. Let us measure:
1. Memory savings across model sizes
2. Compute overhead (training time)
3. Batch size scaling

In [None]:
# 6.1 Memory Savings Across Model Depths

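import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import time

def create_model(num_layers, hidden, use_checkpoint=False, segment_size=None):
    # Checkpoint segments of ~sqrt(num_layers) layers. Checkpointing each single
    # Linear+ReLU would keep almost the same tensors as no checkpointing at all.
    seg = segment_size or max(1, int(math.sqrt(num_layers)))

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(hidden, hidden) for _ in range(num_layers)
            ])
            self.use_ckpt = use_checkpoint

        def _run(self, x, start):
            for layer in self.layers[start:start + seg]:
                x = torch.relu(layer(x))
            return x

        def forward(self, x):
            for start in range(0, len(self.layers), seg):
                if self.use_ckpt:
                    x = checkpoint(self._run, x, start, use_reentrant=False)
                else:
                    x = self._run(x, start)
            return x
    return Model()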

if torch.cuda.is_available():
    hidden = 1024
    batch = 2048  # larger batch so activations are a visible share of peak memory
    
    print(f"Config: hidden={hidden}, batch={batch}")
    print("-" * 60)
    print(f"{'Layers':<10} {'No Ckpt (MB)':<15} {'With Ckpt (MB)':<15} {'Savings':<10}")
    print("-" * 60)
    
    for num_layers in [10, 20, 40, 80]:
        # Without checkpoint
        torch.cuda.empty_cache()
        m1 = create_model(num_layers, hidden, False).cuda()
        x1 = torch.randn(batch, hidden).cuda()
        torch.cuda.reset_peak_memory_stats()
        m1(x1).sum().backward()
        mem1 = torch.cuda.max_memory_allocated() / 1e6
        del m1, x1
        
        # With checkpoint
        torch.cuda.empty_cache()
        m2 = create_model(num_layers, hidden, True).cuda()
        x2 = torch.randn(batch, hidden).cuda()
        torch.cuda.reset_peak_memory_stats()
        m2(x2).sum().backward()
        mem2 = torch.cuda.max_memory_allocated() / 1e6
        del m2, x2
        
        savings = (1 - mem2/mem1) * 100
        print(f"{num_layers:<10} {mem1:<15.2f} {mem2:<15.2f} {savings:.1f}%")
else:
    print("[Run on GPU for benchmarks]")

### Impressions/Conclusions (6.1)

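How much you save depends on granularity and on the activation share of peak memory. Checkpointing segments of ~sqrt(n) layers cuts the kept activations from ~n tensors to ~2·sqrt(n), so savings grow with depth; checkpointing each Linear+ReLU on its own would keep nearly as many tensors as no checkpointing.

The percentage is still capped: parameters and gradients, which checkpointing does not touch, grow with depth too.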

In [None]:
# 6.2 Compute Overhead Measurement

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import time

def benchmark_time(model, x, num_iters=10):
    # Warmup
    for _ in range(3):
        model(x).sum().backward()
        model.zero_grad()
    
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        model(x).sum().backward()
        model.zero_grad()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters * 1000  # ms per iteration

if torch.cuda.is_available():
    hidden = 1024
    batch = 64
    num_layers = 40
    
    # Create models
    m1 = create_model(num_layers, hidden, False).cuda()
    m2 = create_model(num_layers, hidden, True).cuda()
    x = torch.randn(batch, hidden).cuda()
    
    time1 = benchmark_time(m1, x)
    time2 = benchmark_time(m2, x)
    
    overhead = (time2 - time1) / time1 * 100
    
    print(f"Config: {num_layers} layers, hidden={hidden}, batch={batch}")
    print(f"WITHOUT checkpoint: {time1:.2f} ms/iter")
    print(f"WITH checkpoint:    {time2:.2f} ms/iter")
    print(f"Overhead: {overhead:.1f}%")
else:
    print("[Run on GPU for benchmarks]")

### Impressions/Conclusions (6.2)

Expect roughly 20-40% overhead, depending on model architecture. It comes from the extra forward passes during backprop: every checkpointed segment runs its forward twice.

The trade-off: a large cut in activation memory for roughly a third more compute per step. Worth it when memory is the bottleneck; pure cost when it is not.
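A back-of-the-envelope estimate of the overhead, assuming a training step costs one forward pass $F$ plus a backward pass $B \approx 2F$ (a common heuristic, not a measurement) and that checkpointing re-runs each checkpointed forward exactly once:

$$
\text{overhead} \;=\; \frac{(F + F) + B}{F + B} - 1 \;=\; \frac{F}{F + B} \;\approx\; \frac{F}{F + 2F} \;=\; \frac{1}{3} \approx 33\%
$$

Measured numbers land below this when part of the model is not checkpointed (for example the last segment of `checkpoint_sequential`) or when the step also includes optimizer and data-loading time.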

---
# 7. Debugging and Best Practices

Checkpointing has gotchas. Here are the common ones and how to avoid them.

In [None]:
# 7.1 Pitfall: Non-Differentiable Operations

import torch
from torch.utils.checkpoint import checkpoint

# BAD: torch.no_grad() inside checkpointed function
def bad_forward(x):
    with torch.no_grad():  # This breaks gradient computation!
        x = x ** 2
    return x

# GOOD: Keep everything differentiable
def good_forward(x):
    x = x ** 2  # No torch.no_grad()
    return x

x = torch.randn(4, requires_grad=True)

# backward() raises here: the output was computed under no_grad, so it has no grad_fn
try:
    y_bad = checkpoint(bad_forward, x, use_reentrant=False)
    y_bad.sum().backward()
    print(f"BAD: backward ran, but x.grad = {x.grad}")
except Exception as e:
    print(f"BAD: {type(e).__name__} - {e}")

# This works
x = torch.randn(4, requires_grad=True)
y_good = checkpoint(good_forward, x, use_reentrant=False)
y_good.sum().backward()
print(f"GOOD: x.grad = {x.grad}")

### Impressions/Conclusions (7.1)

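Rule: anything that cuts the autograd graph (`torch.no_grad()`, `.detach()`) breaks gradients inside a checkpointed function, just as it would outside one. The checkpoint-specific rule is stricter: the function must behave identically when it is re-run during backward. Same inputs, same control flow, same shapes, and no side effects that should only happen once.

If you need graph-cutting or side-effecting ops, do them outside the checkpointed region.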

In [None]:
# 7.2 Pitfall: Random Operations (Dropout)

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Problem: Dropout uses different random values in forward vs recompute
class ModelWithDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 32)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)  # Random! Different each forward pass
        return x

model = ModelWithDropout()
x = torch.randn(4, 32, requires_grad=True)

# Concern: recomputation re-runs dropout, which could draw a different random mask.
# torch.utils.checkpoint handles this by default: preserve_rng_state=True stashes the
# RNG state before the forward and restores it before recomputing, in both the
# reentrant and non-reentrant variants. Passing preserve_rng_state=False is a bit
# faster but makes the recomputed mask differ from the original one.

y = checkpoint(model, x, use_reentrant=False)  # preserve_rng_state=True by default
print(f"Output shape: {y.shape}")
print("RNG state is stashed and restored, so the recomputed dropout mask matches")

### Impressions/Conclusions (7.2)

Dropout and other random ops are tricky. The recomputed forward pass might use different random values than the original.

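`torch.utils.checkpoint` restores the RNG state for you by default (`preserve_rng_state=True`), for both `use_reentrant=True` and `use_reentrant=False`. Leave that default on whenever the checkpointed region contains dropout.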
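You can check the RNG handling numerically. This sketch seeds the generator identically before each run, so the original forward always draws the same dropout mask, and then compares parameter gradients against an uncheckpointed run. The block sizes and seeds are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(32, 32), nn.Dropout(0.5), nn.Linear(32, 32))
x = torch.randn(4, 32)

def param_grads(use_ckpt, preserve_rng_state=True):
    block.zero_grad()
    torch.manual_seed(123)  # identical dropout mask in the original forward of every run
    if use_ckpt:
        y = checkpoint(block, x, use_reentrant=False,
                       preserve_rng_state=preserve_rng_state)
    else:
        y = block(x)
    y.sum().backward()
    return [p.grad.clone() for p in block.parameters()]

def same(a, b):
    return all(torch.allclose(g, h) for g, h in zip(a, b))

ref = param_grads(use_ckpt=False)
print("preserve_rng_state=True  matches:", same(ref, param_grads(True)))
print("preserve_rng_state=False matches:", same(ref, param_grads(True, False)))
```

With the default, the recomputed mask equals the original and the gradients match. Turning preservation off lets the recomputation draw a fresh mask, so the backward pass no longer corresponds to the forward that produced the loss.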

In [None]:
# 7.3 Best Practices Summary

print("""
BEST PRACTICES FOR ACTIVATION CHECKPOINTING
============================================

1. USE use_reentrant=False
   - Modern API, handles edge cases better
   - checkpoint(fn, x, use_reentrant=False)

2. CHECKPOINT LARGE BLOCKS, NOT SMALL LAYERS
   - Each checkpoint keeps its input, so tiny regions save little
   - Group 2-4 layers (or ~sqrt(n)-layer segments) into blocks, then checkpoint blocks

3. AVOID NESTED CHECKPOINTS
   - Don't checkpoint inside checkpointed functions without a reason
   - Every nesting level re-runs the inner region again during backward

4. KEEP RE-EXECUTION FAITHFUL
   - No torch.no_grad() or .detach() where you need gradients
   - Same control flow and shapes on recompute as in the forward
   - Watch side effects that run twice (e.g., BatchNorm running stats)

5. TEST GRADIENTS FIRST
   - Compare gradients with and without checkpointing
   - They should be identical (within float precision)

6. PROFILE MEMORY
   - Use torch.cuda.max_memory_allocated() (or memory_stats()) to verify savings
   - Peak memory is what matters

7. CONSIDER CHECKPOINT PLACEMENT
   - Target the blocks with the largest activations first
   - In CNNs these are usually the early, high-resolution stages
""")

# Gradient verification example
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Linear(64, 64)
x = torch.randn(8, 64, requires_grad=True)

# Without checkpoint
x1 = x.clone().detach().requires_grad_(True)
y1 = model(x1)
y1.sum().backward()
grad1 = x1.grad.clone()

# With checkpoint
x2 = x.clone().detach().requires_grad_(True)
y2 = checkpoint(model, x2, use_reentrant=False)
y2.sum().backward()
grad2 = x2.grad.clone()

print(f"Gradients match: {torch.allclose(grad1, grad2)}")
print(f"Max difference: {(grad1 - grad2).abs().max():.2e}")

### Impressions/Conclusions (7.3)

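Gradient verification is essential. If gradients do not match (up to floating-point tolerance), something is wrong with your checkpointing setup. Debug before scaling up. Note that the check above compares only input gradients, and the shared `model` accumulates parameter gradients from both runs.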
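A slightly more thorough check compares parameter gradients as well. The sketch below uses a hypothetical helper, `grads_match`, that runs a deep copy of the module for the checkpointed pass so the two sets of parameter gradients never accumulate into the same `.grad` tensors, and reseeds the RNG before each run in case the module contains dropout.

```python
import copy
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def grads_match(module, x, atol=1e-6):
    """Hypothetical helper: compare input AND parameter grads with/without checkpoint."""
    grads = []
    # The copy is made before either run, so it starts with no gradients
    for mod, use_ckpt in ((module, False), (copy.deepcopy(module), True)):
        mod.zero_grad()
        xi = x.detach().clone().requires_grad_(True)
        torch.manual_seed(0)  # same RNG stream in both runs (matters if mod has dropout)
        out = checkpoint(mod, xi, use_reentrant=False) if use_ckpt else mod(xi)
        out.sum().backward()
        grads.append([xi.grad] + [p.grad for p in mod.parameters()])
    return all(torch.allclose(a, b, atol=atol) for a, b in zip(*grads))

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64)
print("Input and parameter gradients match:", grads_match(model, x))
```

Running a check like this on each block you plan to checkpoint is cheap on CPU and catches non-deterministic or side-effecting code before it reaches a long GPU run.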

---
# 8. Integrating with Distributed Training

Checkpointing works with DDP and mixed precision. The combination is powerful.

In [None]:
# 8.1 DDP + Checkpointing (Code Template)
# NOTE: This shows the pattern only. Launch it with torchrun, e.g. torchrun --nproc_per_node=2 train.py

"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

# Initialize distributed
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

class ModelWithCheckpoint(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
    
    def forward(self, x):
        x = checkpoint(lambda t: torch.relu(self.layer1(t)), x, use_reentrant=False)
        x = torch.relu(self.layer2(x))
        return x

# Wrap with DDP
model = ModelWithCheckpoint().cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])

# Training works as normal
x = torch.randn(16, 1024).cuda(local_rank)
output = ddp_model(x)
loss = output.sum()
loss.backward()
"""

print("DDP + Checkpointing:")
print("- Checkpointing happens locally on each GPU")
print("- DDP handles gradient synchronization")
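print("- With use_reentrant=False, usually no special DDP configuration is needed")
print("  (reentrant checkpointing has historically needed workarounds such as static_graph=True)")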
print("- Memory savings apply per-GPU")

In [None]:
# 8.2 Mixed Precision + Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
    
    def forward(self, x):
        x = checkpoint(lambda t: torch.relu(self.layer1(t)), x, use_reentrant=False)
        x = torch.relu(self.layer2(x))
        return x

if torch.cuda.is_available():
    model = SimpleModel().cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = torch.cuda.amp.GradScaler()  # newer releases prefer torch.amp.GradScaler("cuda")
    
    # Training loop with mixed precision + checkpointing
    for step in range(3):
        x = torch.randn(16, 1024).cuda()
        
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            output = model(x)
            loss = output.sum()
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        print(f"Step {step}: loss = {loss.item():.4f}")
    
    print("\nMixed precision + checkpointing works seamlessly")
else:
    print("[Run on GPU]")

### Impressions/Conclusions (8.1 & 8.2)

The power combo: DDP + Mixed Precision + Checkpointing.

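- DDP: Scales throughput across GPUs. Each GPU still holds a full model replica, so DDP alone does not reduce per-GPU memory.
- Mixed Precision: Roughly halves activation memory, because activations are stored in fp16. Parameters stay in fp32 under autocast.
- Checkpointing: Reduces memory further by recomputing activations instead of storing them.

Combined, these techniques let you fit considerably larger models or batches per GPU than an fp32 baseline without checkpointing. The exact gain depends on the architecture. GPT-scale training builds on the same techniques and also shards parameters, gradients, and optimizer state across GPUs (e.g. ZeRO, FSDP).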
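The "2x" from fp16 applies to stored activations, not to the fp32 parameters. The back-of-the-envelope sketch below makes this concrete. It uses the 1024-wide, 10-layer shape from Section 1 and rests on two simplifying assumptions: exactly one batch × hidden tensor is kept per layer (real layers may save more), and the batch size is an illustrative choice.

```python
# Back-of-the-envelope activation vs. parameter memory (simplified model).

def activation_mb(batch, hidden, layers, bytes_per_elem):
    """Assumes one (batch x hidden) tensor is kept per layer for backward."""
    return batch * hidden * layers * bytes_per_elem / 2**20

def param_mb(hidden, layers, bytes_per_elem=4):
    """Weights + biases of `layers` Linear(hidden, hidden) modules, kept in fp32."""
    return layers * (hidden * hidden + hidden) * bytes_per_elem / 2**20

batch, hidden, layers = 4096, 1024, 10  # batch=4096 is an illustrative choice

fp32_act = activation_mb(batch, hidden, layers, bytes_per_elem=4)
fp16_act = activation_mb(batch, hidden, layers, bytes_per_elem=2)

print(f"parameters (fp32):  {param_mb(hidden, layers):7.2f} MB")
print(f"activations (fp32): {fp32_act:7.2f} MB")
print(f"activations (fp16): {fp16_act:7.2f} MB")
```

Under these assumptions, the activations drop from 160 MB to 80 MB while the roughly 40 MB of fp32 parameters is unchanged. The activation term, and with it the saving, grows linearly with batch size.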

---
# 10. Additional Tools and Resources

Beyond PyTorch's built-in checkpointing, there are libraries that offer more features.

In [None]:
# 10.1 DeepSpeed Integration

"""
DeepSpeed offers activation checkpointing as part of its ZeRO optimization suite.

# Installation
pip install deepspeed

# Usage
import deepspeed
from deepspeed.runtime.activation_checkpointing import checkpointing

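# Configure in ds_config.json (JSON itself does not allow comments):
{
    "activation_checkpointing": {
        "partition_activations": true,
        "contiguous_memory_optimization": true,
        "cpu_checkpointing": true
    }
}

# In the model, wrap each segment with DeepSpeed's checkpoint function,
# which mirrors torch.utils.checkpoint.checkpoint (`block` = any layer/function):
# x = checkpointing.checkpoint(block, x)

# Key advantage: CPU offloading ("cpu_checkpointing")
# DeepSpeed can move checkpointed activations to CPU memory,
# freeing GPU memory for even larger models. It offloads partitioned
# activations, so it is enabled together with "partition_activations".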
"""

print("DeepSpeed Activation Checkpointing:")
print("- checkpoint() function mirroring torch.utils.checkpoint")
print("- CPU offloading for extreme memory savings")
print("- Integrated with ZeRO optimizer stages")
print("- Best for very large models (billions of parameters)")

In [None]:
# 10.2 FairScale Integration

"""
FairScale (from Meta) provides checkpoint_wrapper for easy integration.

# Installation
pip install fairscale

# Usage
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper

# Wrap any module
layer = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
checkpointed_layer = checkpoint_wrapper(layer)

# Use in model (the wrapped blocks are illustrative)
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()))
        self.block2 = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()))
    
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x
"""

print("FairScale checkpoint_wrapper:")
print("- Clean API: wrap modules directly")
print("- Works well with FSDP (Fully Sharded Data Parallel)")
print("- Good for medium-scale training")
print("- Simpler than DeepSpeed for many use cases")

### Impressions/Conclusions (10.1 & 10.2)

Tool selection guide:
- **PyTorch native**: Simple models, full control, no dependencies
- **FairScale**: Medium scale, clean API, FSDP integration
- **DeepSpeed**: Billion+ parameter models, CPU offloading, full optimization suite

Start with PyTorch native. Move to DeepSpeed when you hit limits.

---
# 12. Conclusion

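Activation checkpointing is a memory-compute trade-off. It spends extra forward computation, re-running checkpointed segments during the backward pass, to reduce activation memory. The trade-off pays off whenever memory, not compute, is the binding constraint.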

In [None]:
# 12.1 Key Takeaways

print("""
KEY TAKEAWAYS
=============

1. ACTIVATIONS DOMINATE MEMORY
   - Not parameters, not gradients
   - Scales with batch_size * depth * hidden_size (times seq_len for transformers)

2. CHECKPOINTING TRADES COMPUTE FOR MEMORY
   - Save some activations (checkpoints)
   - Recompute others during backward pass
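   - Compute overhead: up to about one extra forward pass (~33% if backward costs ~2x forward)
   - Memory savings depend on how many checkpoints you keep and where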

3. THE SQRT(N) RULE
   - Checkpoint every ~sqrt(n) layers: O(sqrt(n)) activation memory for n layers
   - 100 layers -> 10 checkpoints + 10 recomputed layers -> ~20 stored vs 100 (~5x reduction)

4. PRACTICAL PATTERNS
   - checkpoint(): Single modules/functions
   - checkpoint_sequential(): Sequential models
   - Pass use_reentrant=False (the variant the PyTorch docs recommend)

5. COMBINE WITH OTHER TECHNIQUES
   - Mixed precision: ~2x less activation memory from fp16
   - Gradient accumulation: Effective larger batches
   - DDP: Scale throughput across GPUs
   - Together: Fit substantially larger models or batches per GPU (gain is model-dependent)
""")
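The sqrt(n) arithmetic in takeaway 3 can be checked directly. This sketch uses a simplified cost model, which is an assumption rather than a measurement: with a checkpoint every `segment` layers, peak memory holds one stored activation per segment plus the activations of the single segment being recomputed.

```python
import math

def peak_activations(n_layers: int, segment: int) -> int:
    """Simplified count of layer activations alive at the peak of backward when a
    checkpoint is kept every `segment` layers: the stored checkpoints plus the
    activations of the single segment currently being recomputed."""
    n_checkpoints = math.ceil(n_layers / segment)
    return n_checkpoints + segment

n = 100
best = min(range(1, n + 1), key=lambda k: peak_activations(n, k))
peak = peak_activations(n, best)

print(f"without checkpointing: {n} activations")
print(f"segment size {best}: {peak} activations ({n / peak:.1f}x less)")

# Compute side: if backward costs ~2x forward (a common rule of thumb), one extra
# forward pass takes a step from 3 units of work to 4, i.e. ~33% overhead.
overhead = (1 + 2 + 1) / (1 + 2) - 1
print(f"approximate compute overhead: {overhead:.0%}")
```

For `checkpoint_sequential`, the `segments` argument is the number of chunks rather than the chunk length. At n = 100 both values are 10.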

In [None]:
# 12.2 Decision Framework: When to Use Checkpointing

print("""
WHEN TO USE ACTIVATION CHECKPOINTING
====================================

USE IT WHEN:
- You are hitting OOM errors
- You want larger batch sizes
- Your model has 10+ layers
- Training time is less critical than memory
- You are training transformers or deep CNNs

SKIP IT WHEN:
- Model fits comfortably in memory
- Training time is the bottleneck (not memory)
- Model is shallow (< 5 layers)
- You need maximum training speed

QUICK DECISION TREE:

  OOM error?
     |-- YES --> Use checkpointing
     |
     NO
     |
     v
  Want larger batches?
     |-- YES --> Use checkpointing
     |
     NO
     |
     v
  Model > 10 layers?
     |-- YES --> Consider checkpointing
     |
     NO --> Probably skip it
""")

### Final Impressions

Activation checkpointing is not magic. It is a simple trade-off executed well.

You now understand:
- Why activations dominate memory (batch × depth × hidden)
- How checkpointing works (recompute instead of store)
- When to use it (memory-bound, deep models)
- How to implement it (checkpoint, checkpoint_sequential, use_reentrant=False)

The technique is simple. The impact is massive. Activation checkpointing, full or selective, is a standard option in large-model training stacks such as Megatron-LM, DeepSpeed, and PyTorch FSDP.

Go train something bigger.