## Import the necessary libraries

In [1]:
import torch

## Set up the model and training parameters

Simulate a simple training loop to demonstrate the idea of gradient accumulation.

In [2]:
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# Assume a DataLoader supplies the training data; a plain list of (inputs, labels) batches stands in for it here
dataloader = [
    (torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)
]  # 8 batches of size 4

accumulation_steps = 2  # update the parameters once every 2 batches (effective batch size 4 * 2 = 8)

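## Gradient accumulation example

Show how gradients accumulate over several batches before the parameters are updated in a single step.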

In [3]:
optimizer.zero_grad()  # 1. Clear gradients before starting

for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels)

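    # Divide the loss by the number of accumulation steps so the summed gradient has
    # the same scale as the gradient of one batch accumulation_steps times larger
    loss = loss / accumulation_steps
    loss.backward()  # 2. Gradients are added into each parameter's .grad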
    print(
        f"After Batch {i + 1}, weight.grad norm: {model.weight.grad.norm().item():.4f}"
    )
    print(model.weight.grad)  # print the current (accumulated) gradient values

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # 3. Enough gradients accumulated: update the parameters once
        optimizer.zero_grad()  # 4. Clear right after the update to prepare for the next group

After Batch 1, weight.grad norm: 1.5881
tensor([[-0.2413,  1.2419, -0.2858, -0.0980, -0.5272, -0.5415, -0.3746,  0.2173,
          0.0060, -0.2678]])
After Batch 2, weight.grad norm: 2.5128
tensor([[ 0.0264,  1.6652, -0.7211, -0.3304, -0.2354, -0.3890, -1.2311,  0.6651,
         -0.8374, -0.2127]])
After Batch 3, weight.grad norm: 1.3218
tensor([[-0.1533, -0.0759, -0.3419, -0.1852, -0.3860, -0.3199, -0.4549, -0.4195,
         -0.8713, -0.4162]])
After Batch 4, weight.grad norm: 1.9170
tensor([[-0.8682,  1.0178, -0.9146, -0.5453, -0.1068, -0.6311, -0.1655, -0.3507,
         -0.4149, -0.1385]])
After Batch 5, weight.grad norm: 2.5140
tensor([[ 0.9387,  1.4293, -0.1148, -0.7026, -0.1062,  1.2205,  0.2383, -0.9716,
         -0.2133, -0.5845]])
After Batch 6, weight.grad norm: 3.2805
tensor([[ 0.8122,  1.9298,  0.5303, -0.8754, -0.7006,  0.6512, -0.7304, -1.0867,
         -1.1870, -1.1368]])
After Batch 7, weight.grad norm: 2.2507
tensor([[-0.1511,  0.7296, -0.2645, -0.5020, -0.4527, -0.758
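
The effect of dividing each loss by `accumulation_steps` can be checked directly. The sketch below is an illustration under stated assumptions, not part of the original notebook: it uses `MSELoss` with its default mean reduction, a fixed random seed, and one hypothetical batch of 8 samples split into two equal micro-batches of 4. It compares the gradient of the full batch with the gradient accumulated over the two scaled micro-batches.

```python
import torch

torch.manual_seed(0)

# Hypothetical data: one "large" batch of 8 samples, split into 2 micro-batches of 4
inputs = torch.randn(8, 10)
labels = torch.randn(8, 1)
accumulation_steps = 2

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()  # default reduction="mean"

# Reference: gradient of the mean loss over all 8 samples at once
model.zero_grad()
criterion(model(inputs), labels).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulated: each micro-batch loss is divided by accumulation_steps before backward()
model.zero_grad()
for micro_inputs, micro_labels in zip(
    inputs.chunk(accumulation_steps), labels.chunk(accumulation_steps)
):
    loss = criterion(model(micro_inputs), micro_labels) / accumulation_steps
    loss.backward()  # adds this micro-batch's gradient into model.weight.grad
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```

The match relies on equal-size micro-batches and a mean-reduced loss: the average of two equal-size means is the overall mean. With unequal micro-batches, or layers such as BatchNorm that depend on per-batch statistics, the two gradients are no longer guaranteed to agree.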

## The three roles in gradient accumulation

**loss.backward(): the "Producer"**

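Its job is to produce gradients. You can call it once, or several times in a row to accumulate them: each call adds the new gradients into the existing `.grad` instead of overwriting it. This makes it the flexible part of the loop.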
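
A minimal sketch of the "adds, does not overwrite" behaviour, using a single hypothetical scalar parameter `w` so the numbers can be followed by hand:

```python
import torch

# A single scalar leaf tensor keeps the arithmetic easy to follow
w = torch.tensor(3.0, requires_grad=True)

loss = w * 2  # d(loss)/dw = 2
loss.backward()
print(w.grad)  # tensor(2.)

loss = w * 2  # rebuild the graph for a second backward pass
loss.backward()
print(w.grad)  # tensor(4.): the new gradient was added to the old one
```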

**optimizer.step(): the "Consumer"**

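Its job is to consume the gradients and update the parameters. Key point: once it has used the current gradients to update the parameters, those gradients immediately become "garbage" (stale history), because they were computed at the old parameter values.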
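
The staleness can be seen with plain SGD (no momentum) on a hypothetical one-parameter loss `w ** 2`. This is a sketch chosen for easy arithmetic, not the model from the notebook: `step()` uses `.grad` but leaves it in place, while the parameter has already moved.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
optimizer = torch.optim.SGD([w], lr=0.25)  # plain SGD, no momentum

loss = w ** 2  # d(loss)/dw = 2w = 2.0 at w = 1.0
loss.backward()
optimizer.step()  # w <- w - lr * w.grad = 1.0 - 0.25 * 2.0 = 0.5

print(w.item())       # 0.5
print(w.grad.item())  # 2.0: step() does not clear .grad
# At the new point the true gradient would be 2 * 0.5 = 1.0, so the stored 2.0 is stale
```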

**optimizer.zero_grad(): the "Cleaner"**

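Its job is to clear the gradient buffers, so that the next round of `backward()` calls starts from scratch instead of adding onto stale gradients.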
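
A short sketch of what clearing does, and what happens without it. It assumes PyTorch 2.0 or later, where `zero_grad()` defaults to `set_to_none=True`; the one-parameter setup is hypothetical.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
optimizer = torch.optim.SGD([w], lr=0.1)

# Default in PyTorch 2.0+: set_to_none=True, so .grad becomes None
(w * 3).backward()
optimizer.zero_grad()
print(w.grad)  # None

# set_to_none=False keeps the tensor but fills it with zeros
(w * 3).backward()
optimizer.zero_grad(set_to_none=False)
print(w.grad)  # tensor(0.)

# Without a clear between two rounds, round 1's gradient leaks into round 2
(w * 3).backward()  # round 1: .grad == 0 + 3
(w * 3).backward()  # round 2, no zero_grad in between: .grad == 3 + 3
print(w.grad)  # tensor(6.)
```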