In [1]:
import torch 

# super simple little MLP: 16 -> 32 -> GELU -> 1
# Seed *before* building the net so the initial weights (not just the data)
# are reproducible. Otherwise re-running this cell without restarting the
# kernel builds a differently-initialized net and prints different gradients.
torch.manual_seed(42)
net = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.GELU(),
    torch.nn.Linear(32, 1),
)
x = torch.randn(4, 16)  # one batch of B=4 inputs, 16 features each
y = torch.randn(4, 1)   # matching B=4 regression targets
net.zero_grad()
yhat = net(x)
loss = torch.nn.functional.mse_loss(yhat, y)
loss.backward()
print(net[0].weight.grad.view(-1)[:10])

# The loss objective here is (due to reduction='mean' in mse_loss):
# L = 1/4 * sum_i (yhat_i - y_i)^2
#   = 1/4 [
#          (yhat[0] - y[0])**2 +
#          (yhat[1] - y[1])**2 +
#          (yhat[2] - y[2])**2 +
#          (yhat[3] - y[3])**2
#   ]
# NOTE: the 1/4 normalizer comes from averaging over the batch!

tensor([-0.2537, -0.0825,  0.0240,  0.1384,  0.0733, -0.0872, -0.0494, -0.2356,
        -0.3026,  0.0435])


In [5]:
# Now do the same update with gradient accumulation: micro-batches of B=1 and
# grad_accum_steps = x.size(0) (= 4 here), one micro-step per example above.
# The loss objective here is different, because
#   accumulation in the gradient <---> SUM in the loss
# i.e. we instead get:
# L0 = (y[0] - yhat[0])**2
# L1 = (y[1] - yhat[1])**2
# L2 = (y[2] - yhat[2])**2
# L3 = (y[3] - yhat[3])**2
# L = L0 + L1 + L2 + L3
# NOTE: the "normalizer" of 1/4 is lost, so the accumulated gradient comes
# out grad_accum_steps times (4x) larger than the full-batch gradient.
grad_accum_steps = x.size(0)  # derived from the batch rather than a hard-coded 4

# Full-batch reference gradient, computed here so the check below is self-contained.
net.zero_grad()
torch.nn.functional.mse_loss(net(x), y).backward()
grad_full_batch = net[0].weight.grad.clone()

net.zero_grad()
for i in range(grad_accum_steps):
    yhat = net(x[i])
    loss = torch.nn.functional.mse_loss(yhat, y[i])
    loss.backward()  # .grad accumulates (+=) across backward() calls

print(net[0].weight.grad.view(-1)[:10])
# Unnormalized accumulation == grad_accum_steps x the full-batch (mean) gradient.
torch.testing.assert_close(net[0].weight.grad, grad_accum_steps * grad_full_batch)

tensor([-1.0148, -0.3299,  0.0961,  0.5536,  0.2931, -0.3488, -0.1977, -0.9425,
        -1.2104,  0.1740])


In [7]:
# Gradient accumulation done right: still micro-batches of B=1 and
# grad_accum_steps = x.size(0) (= 4 here), but each micro-loss is now scaled
# by 1/grad_accum_steps, so the accumulated SUM recovers the full-batch MEAN:
# L0 = 1/4 (y[0] - yhat[0])**2
# L1 = 1/4 (y[1] - yhat[1])**2
# L2 = 1/4 (y[2] - yhat[2])**2
# L3 = 1/4 (y[3] - yhat[3])**2
# L = L0 + L1 + L2 + L3 = 1/4 * sum_i (y[i] - yhat[i])**2
# NOTE: the "normalizer" of 1/4 is now inside every one of the components,
# so this objective (and its gradient) matches the full-batch one.
# The loop count and the normalizer must agree, so both derive from one value.
grad_accum_steps = x.size(0)

# Full-batch reference gradient, computed here so the check below is self-contained.
net.zero_grad()
torch.nn.functional.mse_loss(net(x), y).backward()
grad_full_batch = net[0].weight.grad.clone()

net.zero_grad()
for i in range(grad_accum_steps):
    yhat = net(x[i])
    loss = torch.nn.functional.mse_loss(yhat, y[i])
    loss = loss / grad_accum_steps  # <-- this is the "normalizer"
    loss.backward()

print(net[0].weight.grad.view(-1)[:10])
# Normalized accumulation reproduces the full-batch gradient.
torch.testing.assert_close(net[0].weight.grad, grad_full_batch)

tensor([-0.2537, -0.0825,  0.0240,  0.1384,  0.0733, -0.0872, -0.0494, -0.2356,
        -0.3026,  0.0435])
