# Homework: Activation Checkpointing in PyTorch

In this notebook, you will learn how to use **activation checkpointing** in PyTorch to reduce GPU memory usage when training deep models such as language models.

We will go step by step, comparing normal training vs checkpointed training, and measuring both **memory usage** and **speed trade-offs**.

---


## Setup

In [2]:
# Environment Check

import torch
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch Version: 2.9.0+cu126
CUDA Available: True
CUDA Version: 12.6
GPU: Tesla T4
GPU Memory: 15.83 GB


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.half
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

torch.backends.cuda.enable_flash_sdp(True)  # 💡 this alone won't enable FlashAttention :) check the Bonus section!
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

Using device: cuda
Using dtype: torch.bfloat16


## Step 1: Define a Toy Transformer Block

In [4]:
class SimpleTransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.linear1 = nn.Linear(d_model, 4*d_model)
        self.linear2 = nn.Linear(4*d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

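    def attn_path(self, x):
        # Pre-norm self-attention sublayer with a residual connection
        h = self.norm1(x)
        return x + self.attn(h,h,h, need_weights=True)[0]

    def mlp_path(self, x):
        # Pre-norm feed-forward (MLP) sublayer with a residual connection
        h = self.norm2(x)
        return x + self.linear2(F.relu(self.linear1(h)))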

    def forward(self, x):
        x = self.attn_path(x)
        x = self.mlp_path(x)
        return x

## Step 2: Stack Many Blocks (Without Checkpointing)

In [5]:
class DeepModel(nn.Module):
    def __init__(self, depth=12):
        super().__init__()
        self.blocks = nn.ModuleList([SimpleTransformerBlock().to(dtype=dtype, device=device) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

## Step 3: Add Activation Checkpointing

Complete the forward pass to enable activation checkpointing:

- The attention sublayer is checkpointed when `ckpt_attn=True`.

- The MLP sublayer is checkpointed when `ckpt_mlp=True`.

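💡 The first block runs normally (no checkpoint) to avoid cold-start edge cases: its input comes straight from the data and does not require grad, and with the reentrant implementation (`use_reentrant=True`) the parameters of a checkpointed segment whose inputs do not require grad receive no gradients.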

In [8]:

class CheckpointedModel(nn.Module):
    def __init__(self, depth=12, ckpt_attn=True, ckpt_mlp=True):
        super().__init__()
        self.ckpt_attn = ckpt_attn
        self.ckpt_mlp = ckpt_mlp
        self.blocks = nn.ModuleList([SimpleTransformerBlock().to(dtype=dtype, device=device) for _ in range(depth)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if i == 0:
                # Block 0 runs without checkpointing: its input is the raw data, which does not
                # require grad, the case that breaks reentrant checkpointing (use_reentrant=True)
                x = block(x)
            else:
                # Apply checkpointing based on flags
                if self.ckpt_attn:
                    x = checkpoint(block.attn_path, x, use_reentrant=False)
                else:
                    x = block.attn_path(x)
                
                if self.ckpt_mlp:
                    x = checkpoint(block.mlp_path, x, use_reentrant=False)
                else:
                    x = block.mlp_path(x)
        return x
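Checkpointing should change only *when* activations are computed, not the gradients themselves. A quick CPU sanity check, assuming a small hypothetical `nn.Sequential` MLP as a stand-in for `block.mlp_path`, compares parameter gradients with and without `checkpoint(..., use_reentrant=False)`. The input deliberately does not require grad, mirroring the raw input to block 0; the non-reentrant variant still produces parameter gradients in that case.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for block.mlp_path, small enough to run on CPU
mlp = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(2, 8, 16)  # like the raw input to block 0: requires_grad=False


def param_grads(use_ckpt):
    mlp.zero_grad(set_to_none=True)
    if use_ckpt:
        # Intermediate activations are freed after forward and recomputed in backward
        out = checkpoint(mlp, x, use_reentrant=False)
    else:
        out = mlp(x)
    out.mean().backward()
    return [p.grad.clone() for p in mlp.parameters()]


plain = param_grads(use_ckpt=False)
ckpt = param_grads(use_ckpt=True)
same = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print(f"Gradients match: {same}")
```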


## Step 4: Compare Memory Usage and Speed

For the four configurations below, compare the time per training step and the peak GPU memory usage.

Why does checkpointing the attention sublayer lead to large memory savings?

Why does checkpointing the MLP sublayer usually give much smaller savings than checkpointing attention?
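As a starting point for both questions, here is a rough back-of-the-envelope estimate of the single largest tensor each sublayer keeps for the backward pass, using the experiment's shapes below. It assumes attention runs on the standard (non-FlashAttention) path, so the softmax output for every head is stored, and that the MLP stores the ReLU output. It ignores smaller saved tensors, backward temporaries, and allocator behavior, so it will not match the measured peaks exactly.

```python
BF16_BYTES = 2  # bytes per bfloat16 element


def attn_probs_bytes(batch, heads, seq, bytes_per_el=BF16_BYTES):
    # Softmax output: one (seq x seq) matrix per head, kept for the backward pass
    return batch * heads * seq * seq * bytes_per_el


def mlp_hidden_bytes(batch, seq, d_model, expansion=4, bytes_per_el=BF16_BYTES):
    # ReLU output: one (seq x expansion*d_model) tensor, kept for the backward pass
    return batch * seq * expansion * d_model * bytes_per_el


# Same shapes as the experiment: batch=1, seq=4096, d_model=512, n_heads=4
attn = attn_probs_bytes(batch=1, heads=4, seq=4096)
mlp = mlp_hidden_bytes(batch=1, seq=4096, d_model=512)
print(f"attention probabilities per block: {attn / 1e6:.0f} MB")
print(f"MLP hidden activation per block:   {mlp / 1e6:.0f} MB")
print(f"ratio: {attn / mlp:.0f}x")
```

The attention term grows with `seq * seq` while the MLP term grows only with `seq`, so the gap widens as sequences get longer.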

In [9]:
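def measure_run(model, x, steps=10):
    model.to(device)
    x = x.to(device)
    opt = torch.optim.AdamW(model.parameters())
    torch.cuda.reset_peak_memory_stats()

    torch.cuda.synchronize()  # finish any pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(steps):
        opt.zero_grad()
        out = model(x)
        loss = out.mean()
        loss.backward()
        opt.step()
    torch.cuda.synchronize()  # CUDA kernels run asynchronously: wait for them before stopping the clock
    end = time.perf_counter()

    max_mem = torch.cuda.max_memory_allocated() / 1e6  # peak allocated memory in MB
    return (end - start) / steps, max_mem  # (seconds per step, peak MB)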


x = torch.randn(1, 4096, 512, dtype=dtype, device=device)  # (batch, seq, hidden): reduce seq if you run out of GPU memory

m1 = CheckpointedModel(depth=12, ckpt_attn=False, ckpt_mlp=False)
m2 = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=True)
m3 = CheckpointedModel(depth=12, ckpt_attn=False, ckpt_mlp=True)
m4 = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=False)
t1, m1_mem = measure_run(m1, x)
t2, m2_mem = measure_run(m2, x)
t3, m3_mem = measure_run(m3, x)
t4, m4_mem = measure_run(m4, x)


print(f"Without checkpointing: {t1:.2f}s, {m1_mem:.0f} MB")
print(f"With checkpointing:    {t2:.2f}s, {m2_mem:.0f} MB")
print(f"CKPT mlp:              {t3:.2f}s, {m3_mem:.0f} MB")
print(f"CKPT attn:             {t4:.2f}s, {m4_mem:.0f} MB")


Without checkpointing: 0.75s, 3219 MB
With checkpointing:    1.01s, 1497 MB
CKPT mlp:              0.85s, 3176 MB
CKPT attn:             0.99s, 1878 MB
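To read the results above in relative terms, a small helper can express each checkpointed run as time overhead and memory saved against the no-checkpointing baseline. The numbers are copied from the printed output above; the helper name `relative_to_baseline` is illustrative only.

```python
def relative_to_baseline(baseline, runs):
    """Return {name: (time overhead %, memory saved %)} versus the baseline run."""
    base_t, base_mem = baseline
    return {
        name: (100 * (t / base_t - 1), 100 * (1 - mem / base_mem))
        for name, (t, mem) in runs.items()
    }


# Numbers copied from the printed output above: (seconds per step, peak MB)
baseline = (0.75, 3219)
runs = {"attn + mlp": (1.01, 1497), "mlp only": (0.85, 3176), "attn only": (0.99, 1878)}

for name, (dt, saved) in relative_to_baseline(baseline, runs).items():
    print(f"{name:>10}: {dt:+.0f}% time, {saved:.0f}% memory saved")
```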


## Bonus 1: Checkpoint Every Other Block

Modify the code to checkpoint only **every other block** instead of all blocks.

- What trade-off do you observe?
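One possible starting point is sketched below. To keep it self-contained and CPU-friendly, it uses a hypothetical `TinyBlock` (a single pre-norm residual MLP) in place of `SimpleTransformerBlock`, and checkpoints whole blocks rather than individual sublayers. The `ckpt_every` parameter is also hypothetical: `ckpt_every=1` reproduces the "all blocks except block 0" policy of `CheckpointedModel`, while `ckpt_every=2` checkpoints blocks 1, 3, 5, and so on.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Hypothetical stand-in for SimpleTransformerBlock: one pre-norm residual MLP."""

    def __init__(self, d=32):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(self.norm(x))


class EveryOtherCheckpointedModel(nn.Module):
    def __init__(self, depth=6, d=32, ckpt_every=2):
        super().__init__()
        self.ckpt_every = ckpt_every
        self.blocks = nn.ModuleList([TinyBlock(d) for _ in range(depth)])

    def is_checkpointed(self, i):
        # Block 0 always runs normally; afterwards checkpoint every `ckpt_every`-th block
        return i >= 1 and (i - 1) % self.ckpt_every == 0

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.is_checkpointed(i):
                # Activations inside this block are recomputed during backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                # Activations of this block stay in memory until backward
                x = block(x)
        return x


model = EveryOtherCheckpointedModel(depth=6, ckpt_every=2)
x = torch.randn(2, 16, 32)
model(x).mean().backward()
ckpt_ids = [i for i in range(len(model.blocks)) if model.is_checkpointed(i)]
print(ckpt_ids)
```

Checkpointing fewer blocks keeps more activations resident but recomputes less in backward, so memory and step time both land somewhere between the two extremes measured above.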

## Bonus 2: FlashAttention

The code above does not currently use FlashAttention, so attention memory grows quadratically with sequence length (O(n²)).

👉 However, PyTorch lets you enable FlashAttention by toggling a single flag in your code.

💡 Challenge: Can you find and change that flag so that your model runs with FlashAttention instead of the standard (quadratic-memory) path?