Weight decay, or L2 regularization, is a regularization method that adds the sum of the squared weights, scaled by a weight decay factor, to the loss function. The intuition is as follows. The extra term contributes a gradient proportional to each weight, so larger weights receive a stronger push toward zero at every step. Over the course of training, this keeps the parameters small.

That is, we add an extra term to the loss function, where $w_i$ are the weights and $\lambda$ is the weight decay factor:
$$
L_{total} = L_{original} + \lambda \sum_i w_i^2 
$$
When weights are large, the $\lambda \sum_i w_i^2$ term grows. This produces larger gradients and therefore a faster descent toward zero.
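
As a minimal sketch in plain Python (no framework assumed), the penalty can be computed directly from a list of weights. The helper names `l2_penalty` and `total_loss` and the example numbers are illustrative only.

```python
def l2_penalty(weights, lam):
    """Return lam * sum(w_i^2), the weight decay term added to the loss."""
    return lam * sum(w * w for w in weights)


def total_loss(original_loss, weights, lam):
    """L_total = L_original + lam * sum(w_i^2)."""
    return original_loss + l2_penalty(weights, lam)


small = [0.1, -0.2, 0.3]
large = [1.0, -2.0, 3.0]

# Same original loss and same lambda: larger weights give a larger total loss.
print(total_loss(0.5, small, 0.01))
print(total_loss(0.5, large, 0.01))
```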

The gradient used by SGD then becomes:
$$
\frac{\partial L_{total}}{\partial w} = \frac{\partial L_{original}}{\partial w} + 2\lambda w
$$
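
To sanity-check the extra gradient term, a central finite difference on the penalty alone should match $2\lambda w$. This is a small numerical sketch. The step size `h` and the sample weights are arbitrary choices.

```python
def penalty(w, lam):
    # Weight decay term for a single weight: lam * w^2
    return lam * w * w


def numerical_grad(f, w, h=1e-6):
    # Central finite-difference approximation of df/dw
    return (f(w + h) - f(w - h)) / (2 * h)


lam = 0.1
for w in [-2.0, 0.5, 3.0]:
    approx = numerical_grad(lambda x: penalty(x, lam), w)
    exact = 2 * lam * w  # analytic derivative of lam * w^2
    print(f"w={w:+.1f}  numerical={approx:.6f}  analytic={exact:.6f}")
```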
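In practice, the coefficient 2 is usually dropped. It is absorbed into $\lambda$, which is equivalent to writing the penalty as $\frac{\lambda}{2} \sum_i w_i^2$. PyTorch's `weight_decay` argument follows this convention.

So the update rule becomes
$$
w_{new} \leftarrow w_{original} - \eta \left(\frac{\partial L_{original}}{\partial w_{original}} + \lambda w_{original} \right)
$$
where $\eta$ is the learning rate. Rearranging shows where the name comes from. Each step first scales the weight by a factor slightly below 1 and then applies the usual gradient step:
$$
w_{new} \leftarrow (1 - \eta\lambda)\, w_{original} - \eta \frac{\partial L_{original}}{\partial w_{original}}
$$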
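
The update rule can be written out by hand. The sketch below assumes the factor 2 is dropped, so the decay term in the gradient is $\lambda w$. This matches how PyTorch's `torch.optim.SGD` applies `weight_decay` without momentum. `sgd_step` and `grad_f` are hypothetical helpers, not library functions.

```python
def sgd_step(w, grad, lr, weight_decay=0.0):
    # Gradient with the (factor-2-free) weight decay term: dL/dw + lambda * w
    return w - lr * (grad + weight_decay * w)


def grad_f(w):
    # Gradient of the original loss f(w) = w^2
    return 2 * w


w0 = 1.0
print(sgd_step(w0, grad_f(w0), lr=0.1))                    # no weight decay
print(sgd_step(w0, grad_f(w0), lr=0.1, weight_decay=0.1))  # with weight decay
```

The two printed values should agree with the PyTorch example's output (0.8 and 0.79) up to floating-point precision.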

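In the following example, we illustrate weight decay on the function $f(w) = w^2$, starting from $w = 1$. Two SGD optimizers, one without and one with a weight decay factor, each take a single step with learning rate 0.1, so the results can be compared side by side. The gradient of $f$ at $w = 1$ is $2w = 2$. The plain step therefore gives $1 - 0.1 \cdot 2 = 0.8$, while weight decay adds $\lambda w = 0.1$ to the gradient and gives $1 - 0.1 \cdot 2.1 = 0.79$. As expected, the weight ends up smaller with weight decay.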

```python
import torch

# Minimize f(w) = w^2 starting from w = 1.0
w = torch.tensor([1.0], requires_grad=True)
print("Initial weight:", w.item())

# One SGD step without weight decay
opt_no_wd = torch.optim.SGD([w], lr=0.1, weight_decay=0.0)
opt_no_wd.zero_grad()
loss = w * w
loss.backward()
opt_no_wd.step()
print("After step without weight decay:", w.item())

# Reset the weight, then take one SGD step with weight decay
w = torch.tensor([1.0], requires_grad=True)
opt_wd = torch.optim.SGD([w], lr=0.1, weight_decay=0.1)
opt_wd.zero_grad()
loss = w * w
loss.backward()
opt_wd.step()
print("After step with weight decay:", w.item())
```

```
Initial weight: 1.0
After step without weight decay: 0.800000011920929
After step with weight decay: 0.7900000214576721
```
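
Because $f = w^2$ is already minimized at $w = 0$, the example above only shows weight decay speeding up the shrinkage. A loss whose minimum lies away from zero shows a second effect: the penalty moves the point SGD converges to toward zero. The sketch below uses the same update convention (gradient term $\lambda w$) on the loss $f(w) = (w - 3)^2$, which is chosen here purely for illustration. Setting $2(w - 3) + \lambda w = 0$ predicts the fixed point $w^* = 6 / (2 + \lambda)$.

```python
def grad_f(w):
    # Gradient of the illustrative loss f(w) = (w - 3)^2
    return 2 * (w - 3)


def run_sgd(w, lr, weight_decay, steps):
    for _ in range(steps):
        # Plain SGD step with the weight decay term: w -= lr * (grad + weight_decay * w)
        w = w - lr * (grad_f(w) + weight_decay * w)
    return w


lam = 0.1
w_plain = run_sgd(1.0, lr=0.1, weight_decay=0.0, steps=200)
w_decay = run_sgd(1.0, lr=0.1, weight_decay=lam, steps=200)
print("Without weight decay: ", w_plain)   # approaches 3.0
print("With weight decay:    ", w_decay)   # approaches 6 / (2 + lam)
print("Predicted fixed point:", 6 / (2 + lam))
```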
