## 导入必要的库

In [1]:
import gc
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

## 辅助函数

定义用于监控内存使用和重置内存状态的函数

In [2]:
def get_memory_usage():
    """Return the CUDA memory currently allocated, in MB (0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024 / 1024

def reset_memory():
    """Free unreferenced memory and reset CUDA peak statistics before a measurement.

    The Python garbage collector runs first, then (when CUDA is available) the
    CUDA caching allocator is emptied and its peak-memory counters are reset.
    Without CUDA only the garbage collection happens.
    """
    # Collect first: tensors that are only kept alive by reference cycles are
    # released here, and only then can empty_cache() hand their blocks back.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

## 定义模型组件

创建一个占用大量激活内存的层

In [3]:
class HeavyLayer(nn.Module):
    """
    Square linear projection followed by a ReLU.

    Used as a building block whose forward pass produces a `size`-wide
    activation for every input row.
    """

    def __init__(self, size=2048):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear = nn.Linear(size, size)
        self.act = nn.ReLU()

    def forward(self, x):
        projected = self.linear(x)
        activated = self.act(projected)
        return activated

## 支持检查点的模型

定义一个可选择性使用梯度检查点的顺序模型

In [None]:
class CheckpointedModel(nn.Module):
    """
    A stack of HeavyLayer modules that can optionally use gradient checkpointing.

    Args:
        num_layers: Number of HeavyLayer modules in the stack.
        size: Feature width passed to every HeavyLayer.
        use_checkpointing: When True, the forward pass wraps consecutive groups
            of layers in ``torch.utils.checkpoint.checkpoint``.
        segment_size: Number of consecutive layers grouped into one checkpointed
            segment (the last segment may be shorter). Defaults to 4.

    Raises:
        ValueError: If ``segment_size`` is smaller than 1.
    """

    def __init__(self, num_layers=20, size=2048, use_checkpointing=False, segment_size=4):
        super().__init__()
        if segment_size < 1:
            raise ValueError(f"segment_size must be >= 1, got {segment_size}")
        self.layers = nn.ModuleList([HeavyLayer(size) for _ in range(num_layers)])
        self.use_checkpointing = use_checkpointing
        self.size = size
        self.segment_size = segment_size

    def _run_segment(self, x, start, stop):
        """Apply layers ``start`` (inclusive) through ``stop`` (exclusive) to ``x``."""
        for idx in range(start, stop):
            x = self.layers[idx](x)
        return x

    def forward(self, x):
        """Run ``x`` through every layer, checkpointing per segment if enabled."""
        if not self.use_checkpointing:
            # Standard forward pass: autograd keeps what it needs for every layer.
            for layer in self.layers:
                x = layer(x)
            return x

        total = len(self.layers)
        for start in range(0, total, self.segment_size):
            stop = min(start + self.segment_size, total)
            # Each segment is checkpointed as a unit; its bounds are passed as
            # plain arguments so no per-iteration closure is needed.
            x = checkpoint.checkpoint(
                self._run_segment, x, start, stop, use_reentrant=False
            )
        return x

## 实验函数

运行实验来测量有无梯度检查点时的内存使用

In [None]:
def run_experiment(
    use_checkpointing: bool,
    num_layers: int = 20,
    size: int = 4096,
    batch_size: int = 32,
):
    """Run one forward/backward pass and report CUDA memory usage at each stage.

    Args:
        use_checkpointing: Forwarded to CheckpointedModel to toggle checkpointing.
        num_layers: Number of layers in the model (default 20).
        size: Feature width of the model and of the random input (default 4096).
        batch_size: Number of rows in the random input (default 32).

    Returns:
        A tuple ``(mem_after_forward, mem_peak_backward)`` in MB: the memory
        allocated right after the forward pass, and the peak allocation
        recorded during the backward pass. Both are 0 when CUDA is unavailable.
    """
    print(
        f"\n--- Running Experiment: Checkpointing={'ON' if use_checkpointing else 'OFF'} ---"
    )
    reset_memory()

    on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")
    if not on_gpu:
        # The device falls back to CPU, so say so up front rather than failing
        # after the model has been built and the forward pass has run.
        print("CUDA is not available: running on CPU, memory statistics are reported as 0.")

    def peak_memory_mb():
        """Peak CUDA allocation since the last reset, in MB (0 without CUDA)."""
        return torch.cuda.max_memory_allocated() / 1024 / 1024 if on_gpu else 0.0

    def reset_peak():
        """Reset the CUDA peak counter; does nothing without CUDA."""
        if on_gpu:
            torch.cuda.reset_peak_memory_stats()

    # Setup model and data
    model = CheckpointedModel(
        num_layers=num_layers, size=size, use_checkpointing=use_checkpointing
    ).to(device)
    inputs = torch.randn(batch_size, size, device=device, requires_grad=True)

    # Measure memory before forward (model and inputs are already allocated)
    mem_start = get_memory_usage()
    print(f"Memory Usage Before Forward: {mem_start:.2f} MB")

    # Forward pass
    output = model(inputs)
    loss = output.sum()

    # Measure current and peak memory after forward (before backward)
    mem_after_forward = get_memory_usage()
    mem_peak_forward = peak_memory_mb()

    print("Forward pass complete.")
    print(f"Memory Usage After Forward (before backward): {mem_after_forward:.2f} MB")
    print(f"Peak Memory Usage During Forward: {mem_peak_forward:.2f} MB")

    # Reset peak stats to measure backward separately
    reset_peak()

    # Backward pass
    loss.backward()

    mem_peak_backward = peak_memory_mb()
    mem_after_backward = get_memory_usage()

    print(f"Peak Memory Usage During Backward: {mem_peak_backward:.2f} MB")
    print(f"Memory Usage After Backward: {mem_after_backward:.2f} MB")
    print("Backward pass complete.")

    # Leave the peak counter reset for whatever is measured next.
    reset_peak()
    return mem_after_forward, mem_peak_backward

## 运行实验

### 说明

这个脚本演示了梯度检查点如何以计算换内存。它构建了一个深度网络，并测量了有无检查点时的峰值内存使用。

**注意：** 本示例依赖 `torch.cuda` 的内存统计接口，需要在带有 CUDA GPU 的机器上运行才能看到有意义的结果；下面的单元格在检测不到 GPU 时会直接抛出 `RuntimeError`。

In [12]:
# Stop early with an accurate message: this cell hard-fails without CUDA, so the
# wording states a requirement rather than a recommendation.
if not torch.cuda.is_available():
    raise RuntimeError(
        "This demonstration requires a CUDA-capable GPU: its memory "
        "measurements come from torch.cuda allocator statistics."
    )

print("PyTorch Gradient Checkpointing Demonstration")
print("============================================")
print(
    "This script demonstrates how gradient checkpointing trades compute for memory."
)
print(
    "It constructs a deep network and measures peak memory usage with and without checkpointing.\n"
)


PyTorch Gradient Checkpointing Demonstration
This script demonstrates how gradient checkpointing trades compute for memory.
It constructs a deep network and measures peak memory usage with and without checkpointing.



### 1. 不使用检查点运行（标准行为）

这将为所有 20 层存储激活值以计算梯度

In [17]:
# Baseline run without checkpointing. The cell displays the returned tuple:
# (memory after forward, peak memory during backward), both in MB.
run_experiment(use_checkpointing=False)


--- Running Experiment: Checkpointing=OFF ---
Memory Usage Before Forward: 1299.06 MB
Forward pass complete.
Memory Usage After Forward (before backward): 1309.06 MB
Peak Memory Usage During Forward: 1309.56 MB
Peak Memory Usage During Backward: 2580.88 MB
Memory Usage After Backward: 2580.38 MB
Backward pass complete.
Memory Usage Before Forward: 1299.06 MB
Forward pass complete.
Memory Usage After Forward (before backward): 1309.06 MB
Peak Memory Usage During Forward: 1309.56 MB
Peak Memory Usage During Backward: 2580.88 MB
Memory Usage After Backward: 2580.38 MB
Backward pass complete.


(1309.06298828125, 2580.8759765625)

### 2. 使用检查点运行

这将只存储每个检查点段的输入。段内的激活值在反向传播期间重新计算

In [18]:
# Same experiment with checkpointing enabled. The cell displays the returned
# tuple: (memory after forward, peak memory during backward), both in MB.
run_experiment(use_checkpointing=True)


--- Running Experiment: Checkpointing=ON ---
Memory Usage Before Forward: 1299.06 MB
Forward pass complete.
Memory Usage After Forward (before backward): 1301.56 MB
Peak Memory Usage During Forward: 1302.56 MB
Peak Memory Usage During Backward: 2580.88 MB
Memory Usage After Backward: 2580.38 MB
Backward pass complete.
Memory Usage Before Forward: 1299.06 MB
Forward pass complete.
Memory Usage After Forward (before backward): 1301.56 MB
Peak Memory Usage During Forward: 1302.56 MB
Peak Memory Usage During Backward: 2580.88 MB
Memory Usage After Backward: 2580.38 MB
Backward pass complete.


(1301.56298828125, 2580.8759765625)

## 结果分析

从上面的结果可以看到：

### Forward Pass 内存差异（关键指标）
- **不使用检查点**: Forward 后内存 ~1309 MB
- **使用检查点**: Forward 后内存 ~1302 MB

两者的差异并不大（约 7.5 MB），原因如下：

1. **模型参数始终需要存储**（约 1.3 GB）；反向传播之后还要再存储同样大小的参数梯度，因此激活值只占总内存的很小一部分
2. **激活值的节省主要体现在前向传播后**
3. **Backward 峰值相同是正常的**，因为：
   - 梯度检查点在 backward 时需要重新计算前向传播
   - 重新计算时仍然会临时产生激活值
   - 参数梯度必须保存

### 如何看到更明显的差异？

要看到更显著的内存节省，需要：
1. 使用更深的网络（50+ 层）
2. 使用更大的批次大小
3. 使用更大的隐藏层维度
4. 关注 forward 后的内存使用，而不是 backward 峰值

## 更极端的例子

让我们用更大的模型来展示更明显的内存差异

In [None]:
def run_extreme_experiment(
    use_checkpointing: bool,
    num_layers: int = 40,
    size: int = 4096,
    batch_size: int = 64,
):
    """Repeat the memory experiment with a larger model and batch.

    Args:
        use_checkpointing: Forwarded to CheckpointedModel to toggle checkpointing.
        num_layers: Number of layers in the model (default 40).
        size: Feature width of the model and of the random input (default 4096).
        batch_size: Number of rows in the random input (default 64).

    Returns:
        A tuple ``(activation_memory, mem_peak_backward)`` in MB: the growth in
        allocated memory across the forward pass, and the peak allocation
        recorded during the backward pass. Both are 0 when CUDA is unavailable.
    """
    print(
        f"\n--- EXTREME Experiment: Checkpointing={'ON' if use_checkpointing else 'OFF'} ---"
    )
    reset_memory()

    on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")

    # Larger configuration than the basic experiment: more layers, bigger batch.
    model = CheckpointedModel(
        num_layers=num_layers,
        size=size,
        use_checkpointing=use_checkpointing,
    ).to(device)
    inputs = torch.randn(batch_size, size, device=device, requires_grad=True)

    mem_start = get_memory_usage()
    print(f"Memory Before Forward: {mem_start:.2f} MB")

    # Forward
    output = model(inputs)
    loss = output.sum()

    # get_memory_usage() is CPU-safe, unlike a direct torch.cuda call.
    mem_after_forward = get_memory_usage()
    activation_memory = mem_after_forward - mem_start
    print(f"Memory After Forward: {mem_after_forward:.2f} MB")
    print(f"  → Activation Memory: {activation_memory:.2f} MB")

    # Backward: reset the peak counter so it reflects this phase only. The
    # torch.cuda calls are guarded because the device may have fallen back to CPU.
    if on_gpu:
        torch.cuda.reset_peak_memory_stats()
    loss.backward()

    mem_peak_backward = (
        torch.cuda.max_memory_allocated() / 1024 / 1024 if on_gpu else 0.0
    )
    print(f"Peak Memory During Backward: {mem_peak_backward:.2f} MB")

    return activation_memory, mem_peak_backward

In [20]:
print("\n" + "="*60)
print("极端测试：40 层，批次大小 64")
print("="*60)

result_no_ckpt = run_extreme_experiment(use_checkpointing=False)
result_with_ckpt = run_extreme_experiment(use_checkpointing=True)

# Index 0 of each result is the activation memory added by the forward pass (MB).
baseline_activation_mb = result_no_ckpt[0]
checkpointed_activation_mb = result_with_ckpt[0]
saved_mb = baseline_activation_mb - checkpointed_activation_mb

# A zero baseline would make the relative saving a division by zero, so report
# it as unavailable instead of crashing the summary.
if baseline_activation_mb:
    saved_pct = f"{(1 - checkpointed_activation_mb / baseline_activation_mb) * 100:.1f}%"
else:
    saved_pct = "N/A"

print("\n" + "="*60)
print("对比总结")
print("="*60)
print(f"不使用检查点 - Forward 激活内存: {baseline_activation_mb:.2f} MB")
print(f"使用检查点   - Forward 激活内存: {checkpointed_activation_mb:.2f} MB")
print(f"内存节省: {saved_mb:.2f} MB ({saved_pct})")
print("="*60)


极端测试：40 层，批次大小 64

--- EXTREME Experiment: Checkpointing=OFF ---
Memory Before Forward: 2579.88 MB
Memory After Forward: 2619.88 MB
  → Activation Memory: 40.00 MB
Peak Memory During Backward: 5143.50 MB

--- EXTREME Experiment: Checkpointing=ON ---
Memory Before Forward: 2579.88 MB
Memory After Forward: 2589.88 MB
  → Activation Memory: 10.00 MB
Peak Memory During Backward: 5143.50 MB

对比总结
不使用检查点 - Forward 激活内存: 40.00 MB
使用检查点   - Forward 激活内存: 10.00 MB
内存节省: 30.00 MB (75.0%)
