In [None]:
# Run the remaining cells from the parent directory of the notebook's location.
%cd ..
# Restrict CUDA to the device with index 3; this is set before torch is imported below.
%env CUDA_VISIBLE_DEVICES=3

In [None]:
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Fix the random seed (via accelerate's set_seed) so repeated runs of the notebook are reproducible.
set_seed(0)

# Toy regression problem: inputs 1..8, each label is exactly twice its input (y = 2x).
x = torch.arange(1.0, 9.0)
y = 2 * x

# Split the 8 samples into `gradient_accumulation_steps` micro-batches of equal size.
gradient_accumulation_steps = 4
batch_size = x.shape[0] // gradient_accumulation_steps

# Wrap the tensors in a dataset and serve them in order, one micro-batch at a time.
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=batch_size)

# define model, optimizer and loss function
class SimpleLinearModel(torch.nn.Module):
    """Bias-free linear model whose only parameter is a (1, 1) weight matrix."""

    def __init__(self):
        super().__init__()
        # The single trainable weight starts at zero.
        self.weight = torch.nn.Parameter(torch.zeros(1, 1))

    def forward(self, inputs):
        """Return the matrix product of ``inputs`` and the weight.

        The weight is (1, 1), so the last dimension of ``inputs`` must be 1;
        the callers in this file pass tensors reshaped with ``view(-1, 1)``.
        """
        return torch.matmul(inputs, self.weight)

# Two copies of the same model: `model` is trained through Accelerate with
# gradient accumulation, `model_clone` is the plain-PyTorch reference.
model = SimpleLinearModel()
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()

# Both optimizers use the same SGD learning rate so the two updates are comparable.
learning_rate = 0.02
model_optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=learning_rate)

# Only the accumulating model, its optimizer and the dataloader go through Accelerate.
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)

# Report the starting weight of the accumulating model first, then of the reference copy.
for net in (model, model_clone):
    print(f"initial model weight is {net.weight.mean().item():.5f}")

In [None]:
# Update with accumulation: 4 micro-batches of 2 samples each give an
# effective batch size of 2 x 4 = 8.
# NOTE(review): step() and zero_grad() are called on every iteration; per the
# Accelerate docs the accumulate() context and the prepared optimizer are
# expected to defer the real update until the accumulation boundary.
for step, (batch_x, batch_y) in enumerate(dataloader):
    with accelerator.accumulate(model):
        # Reshape the flat micro-batch into column vectors of shape (N, 1).
        features = batch_x.view(-1, 1)
        targets = batch_y.view(-1, 1)
        print(step, features.flatten())
        loss = criterion(model(features), targets)
        accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()

# Update without accumulation: a single SGD step on the loss computed over all
# 8 samples at once, so the real batch size is 8.
model_clone_optimizer.zero_grad()
full_inputs = x.view(-1, 1)
full_targets = y.view(-1, 1)
loss = criterion(torch.matmul(full_inputs, model_clone.weight), full_targets)
loss.backward()
model_clone_optimizer.step()

# Compare the weight learned with accumulation against the full-batch reference.
accumulated_weight = model.weight.mean().item()
reference_weight = model_clone.weight.mean().item()
print(f"w/ accumulation, the final model weight is {accumulated_weight:.5f}")
print(f"w/o accumulation, the final model weight is {reference_weight:.5f}")