# Homework: Activation Checkpointing in PyTorch

In this notebook, you will learn how to use **activation checkpointing** in PyTorch to reduce GPU memory usage when training deep models such as language models.

We will go step by step, comparing normal training vs checkpointed training, and measuring both **memory usage** and **speed trade-offs**.

---


## Setup

In [1]:
# Environment Check

import torch
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch Version: 2.9.0+cu126
CUDA Available: True
CUDA Version: 12.6
GPU: NVIDIA A100-SXM4-40GB
GPU Memory: 42.47 GB


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.half
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

torch.backends.cuda.enable_flash_sdp(True) # ðŸ’¡ this alone won't enable flash attention :) check the Bonus section!
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

Using device: cuda
Using dtype: torch.bfloat16


## Step 1: Define a Toy Transformer Block

In [3]:
class SimpleTransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.linear1 = nn.Linear(d_model, 4*d_model)
        self.linear2 = nn.Linear(4*d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def attn_path(self, x):
        h = self.norm1(x)
        return x + self.attn(h,h,h, need_weights=True)[0]

    def mlp_path(self, x):
        h = self.norm2(x)
        return x + self.linear2(F.relu(self.linear1(h)))

    def forward(self, x):
        x = self.attn_path(x)
        x = self.mlp_path(x)
        return x

## Step 2: Stack Many Blocks (Without Checkpointing)

In [4]:
class DeepModel(nn.Module):
    def __init__(self, depth=12):
        super().__init__()
        self.blocks = nn.ModuleList([SimpleTransformerBlock().to(dtype=dtype, device='cuda') for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

## Step 3: Add Activation Checkpointing

Complete the forward pass so to enable activation checkpointing:

- Attention sublayer is checkpointed when ckpt_attn=True.

- MLP sublayer is checkpointed when ckpt_mlp=True.

ðŸ’¡ The first block runs normally (no checkpoint) to avoid cold-start edge cases.

In [5]:

class CheckpointedModel(nn.Module):
    def __init__(self, depth=12, ckpt_attn=True, ckpt_mlp=True):
        super().__init__()
        self.ckpt_attn = ckpt_attn
        self.ckpt_mlp = ckpt_mlp
        self.blocks = nn.ModuleList([SimpleTransformerBlock().to(dtype=dtype, device='cuda')  for _ in range(depth)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if i == 0:
                # First block runs normally (no checkpoint), as I read it may have cold-start edge cases
                # Probably an ugly if condition, should rethink
                x = block(x)
            else:
                # Apply checkpointing based on flags
                if self.ckpt_attn:
                    x = checkpoint(block.attn_path, x, use_reentrant=False)
                else:
                    x = block.attn_path(x)
                
                if self.ckpt_mlp:
                    x = checkpoint(block.mlp_path, x, use_reentrant=False)
                else:
                    x = block.mlp_path(x)
        return x


## Step 4: Compare Memory Usage and Speed

For the 4 configs below, check the elapsed time and peak memory usage.

Why does checkpointing the attention sublayer lead to large memory savings?

Why does checkpointing the MLP sublayer usually give much smaller savings compared to attention?

In [6]:
def measure_run(model, x, steps=10):
    model.to(device)
    x = x.to(device)
    opt = torch.optim.AdamW(model.parameters())
    torch.cuda.reset_peak_memory_stats()

    start = time.time()
    for _ in range(steps):
        opt.zero_grad()
        out = model(x)
        loss = out.mean()
        loss.backward()
        opt.step()
    end = time.time()

    max_mem = torch.cuda.max_memory_allocated() / 1e6
    return (end-start)/steps, max_mem


x = torch.randn(1, 4096, 512, dtype=dtype, device="cuda")  # batch, seq, hidden - adjust based on your GPU memory

m1 = CheckpointedModel(depth=12, ckpt_attn=False, ckpt_mlp=False)
m2 = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=True)
m3 = CheckpointedModel(depth=12, ckpt_attn=False, ckpt_mlp=True)
m4 = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=False)
t1, m1_mem = measure_run(m1, x)
t2, m2_mem = measure_run(m2, x)
t3, m3_mem = measure_run(m3, x)
t4, m4_mem = measure_run(m4, x)


print(f"Without checkpointing: {t1:.2f}s, {m1_mem:.0f} MB")
print(f"With checkpointing:    {t2:.2f}s, {m2_mem:.0f} MB")
print(f"CKPT mlp:              {t3:.2f}s, {m3_mem:.0f} MB")
print(f"CKPT attn:             {t4:.2f}s, {m4_mem:.0f} MB")


Without checkpointing: 0.12s, 3140 MB
With checkpointing:    0.06s, 1414 MB
CKPT mlp:              0.03s, 3093 MB
CKPT attn:             0.04s, 1787 MB


## Answer to our above questions

Let's look at our forward pass, our model has d_model=512, n_heads=4, and we're feeding it a sequence of length 4096.

In attention, we compute Q, K, V from the input, then do Q @ K^T to get attention scores. That matrix multiply means that with 4096 tokens, we have 4096 Ã— 4096 = ~16.7 million entries per head for the attention score calculation of Q@K_T. We have 4 heads, so that's ~65 million numbers per layer just for the attention weights. In bfloat16 (2 bytes each), that's roughly 134 MB per layer. Multiply by 11 checkpointed layers and we have ~1.5 GB of attention matrices that need to stick around for the backward pass.

Now let's look at the MLP. It takes in (batch, seq, 512) and expands it to (batch, seq, 2048) via linear1, then squashes back down. The intermediate activation we need to save is (1, 4096, 2048) = ~8 million numbers. That's about 16 MB per layer, or ~176 MB across all layers.

So, attention scales with sequence length *squared* because every token talks to every other token. MLP scales *linearly* because it just processes each token position independently because it doesn't care how many other tokens exist. At seq_len=4096, the squared term completely dominates. That's why checkpointing attention alone drops memory from 3140 MB to 1787 MB (saving ~1350 MB), while checkpointing MLP alone barely moves the needle from 3140 MB to 3093 MB (saving ~47 MB).

This is also why FlashAttention is awesome as it's specifically attacking that O(n^2) memory problem in attention by being clever about what or how it computes/fuses.

# Bonus
Modify the code to only checkpoint **every other block** instead of all blocks.  
   - What trade-off do you observe?

#### 'checkpoint_sequential':

PyTorch provides 'torch.utils.checkpoint.checkpoint_sequential' for checkpointing sequential models. It splits a model into N contiguous segments and checkpoints each segment. However, it's less suitable here because: a) it operates at the block level, not sublayer level (can't separately control attention vs MLP checkpointing), and b) it creates contiguous chunks rather than alternating patterns. It may be a good choice when we want simple, coarse-grained checkpointing without fine control.

In [7]:
class EveryOtherCheckpointModel(nn.Module):
    def __init__(self, depth=12, ckpt_attn=True, ckpt_mlp=True):
        super().__init__()
        self.ckpt_attn = ckpt_attn
        self.ckpt_mlp = ckpt_mlp
        self.blocks = nn.ModuleList([SimpleTransformerBlock().to(dtype=dtype, device='cuda') for _ in range(depth)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            use_ckpt = (i % 2 == 1)  # checkpoint odd-indexed blocks only
            if use_ckpt and self.ckpt_attn:
                x = checkpoint(block.attn_path, x, use_reentrant=False)
            else:
                x = block.attn_path(x)
            if use_ckpt and self.ckpt_mlp:
                x = checkpoint(block.mlp_path, x, use_reentrant=False)
            else:
                x = block.mlp_path(x)
        return x

# Compare: full checkpointing vs every-other checkpointing
x = torch.randn(1, 4096, 512, dtype=dtype, device="cuda")

m_full = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=True)
m_half = EveryOtherCheckpointModel(depth=12, ckpt_attn=True, ckpt_mlp=True)

t_full, mem_full = measure_run(m_full, x)
t_half, mem_half = measure_run(m_half, x)

print(f"Full checkpointing:        {t_full:.2f}s, {mem_full:.0f} MB")
print(f"Every-other checkpointing: {t_half:.2f}s, {mem_half:.0f} MB")

Full checkpointing:        0.05s, 1819 MB
Every-other checkpointing: 0.04s, 2803 MB


## Trade-off



No checkpointing (from above) 0.12s, 3140 MB  
Full checkpointing:        0.05s, 1819 MB  
Every-other checkpointing: 0.04s, 2803 MB  

When it comes to mem, every-other uses ~50% more memory than full checkpointing (2803 vs 2829 MB). This makes sense since we only discard half the activations. Still saves ~10% vs no checkpointing.
Regarding speed though, this is surprisingly similar to full checkpointing (0.05s).... 
The recomputation cost is dominated by attention O(n^2), and we're still recomputing half the attention blocks.

Overall though, if we compare every-other with full checkpointing, every-other seems to provide a bad trade-off if we just look at these numbers for the A100 GPU, but in general, we would expect every-other to have more compute involvement. We get most of the speed penalty of full checkpointing but just half of the memory savings. This is likely because attention recomputation is expensive regardless of how many blocks we skip...

With this limited toy example, we could make a case that perhaps full checkpointing or no checkpointing are usually better choices than partial strategies.

# Bonus
The code above currently does not use FlashAttention, which means attention has quadratic memory complexity with respect to sequence length (O(nÂ²)).

ðŸ‘‰ However, PyTorch lets you enable FlashAttention by simply toggling a single flag in your code.

ðŸ’¡ Challenge: Can you find and change that flag so that your model runs with FlashAttention instead of the standard (quadratic) path?

In [None]:
class SimpleTransformerBlockFlash(nn.Module):
    def __init__(self, d_model=512, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.linear1 = nn.Linear(d_model, 4*d_model)
        self.linear2 = nn.Linear(4*d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def attn_path(self, x):
        h = self.norm1(x)
        return x + self.attn(h, h, h, need_weights=False)[0]  # 'need_weights=False' enables FlashAttention!

    def mlp_path(self, x):
        h = self.norm2(x)
        return x + self.linear2(F.relu(self.linear1(h)))

    def forward(self, x):
        x = self.attn_path(x)
        x = self.mlp_path(x)
        return x


class CheckpointedModelWithFlashAttn(nn.Module):
    def __init__(self, depth=12, ckpt_attn=True, ckpt_mlp=True):
        super().__init__()
        self.ckpt_attn = ckpt_attn
        self.ckpt_mlp = ckpt_mlp
        self.blocks = nn.ModuleList([SimpleTransformerBlockFlash().to(dtype=dtype, device='cuda') for _ in range(depth)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if i == 0:
                x = block(x)
            else:
                if self.ckpt_attn:
                    x = checkpoint(block.attn_path, x, use_reentrant=False)
                else:
                    x = block.attn_path(x)
                if self.ckpt_mlp:
                    x = checkpoint(block.mlp_path, x, use_reentrant=False)
                else:
                    x = block.mlp_path(x)
        return x


class EveryOtherCheckpointModelWithFlashAttn(nn.Module):
    def __init__(self, depth=12, ckpt_attn=True, ckpt_mlp=True):
        super().__init__()
        self.ckpt_attn = ckpt_attn
        self.ckpt_mlp = ckpt_mlp
        self.blocks = nn.ModuleList([SimpleTransformerBlockFlash().to(dtype=dtype, device='cuda') for _ in range(depth)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            use_ckpt = (i % 2 == 1)
            if use_ckpt and self.ckpt_attn:
                x = checkpoint(block.attn_path, x, use_reentrant=False)
            else:
                x = block.attn_path(x)
            if use_ckpt and self.ckpt_mlp:
                x = checkpoint(block.mlp_path, x, use_reentrant=False)
            else:
                x = block.mlp_path(x)
        return x


# Compare with and without FlashAttention
x = torch.randn(1, 4096, 512, dtype=dtype, device="cuda")

m_no_flash = CheckpointedModel(depth=12, ckpt_attn=True, ckpt_mlp=True)
m_flash = CheckpointedModelWithFlashAttn(depth=12, ckpt_attn=True, ckpt_mlp=True)
m_flash_every_other = EveryOtherCheckpointModelWithFlashAttn(depth=12, ckpt_attn=True, ckpt_mlp=True)

t1, mem1 = measure_run(m_no_flash, x)
t2, mem2 = measure_run(m_flash, x)
t3, mem3 = measure_run(m_flash_every_other, x)

print(f"Full ckpt (no flash):     {t1:.2f}s, {mem1:.0f} MB")
print(f"Full ckpt (flash):        {t2:.2f}s, {mem2:.0f} MB")
print(f"Every-other ckpt (flash): {t3:.2f}s, {mem3:.0f} MB")

Full ckpt (no flash):     0.05s, 2216 MB
Full ckpt (flash):        0.05s, 1691 MB
Every-other ckpt (flash): 0.04s, 1987 MB
