In [None]:
from __future__ import annotations

from typing import Iterable, List, Union

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.optimizer import required  # type: ignore

from timm.optim.optim_factory import create_optimizer_v2, param_groups_weight_decay

In [None]:
# Placeholder objects so the illustrative snippets below have names to refer to.
param = torch.Tensor()
learning_rate = 3e-3
model = nn.Module()
# timm's `param_groups_weight_decay` matches this list against parameter *names*,
# so collect names rather than wrapping the `model.parameters()` generator in a list.
list_of_params_without_wd = [name for name, _ in model.named_parameters()]

# Optimizers

We will start by looking at Composer's `DecoupledSGDW`, as it is equivalent to fastai's SGD with momentum and decoupled weight decay and is less complicated than PyTorch's `SGD`.

Remember that a basic SGD step is:

In [None]:
# Plain SGD: move the parameter against its gradient, scaled by the learning rate
param = param + param.grad * -learning_rate

# Or written using PyTorch ops (`add` with `alpha` computes `param + alpha * param.grad`)

param = param.add(param.grad, alpha=-learning_rate)

We'll look at each part step by step, with Nesterov momentum, `initial_lr`, closure, and extra comments removed to simplify the code a bit.

An optimizer is initialized with the parameters we want to train, from any number of models with gradients, and the optimizer specific hyperparameters.

In [None]:
class DecoupledSGDW(SGD):
    """SGD optimizer with decoupled weight decay.

    Construction forwards every hyperparameter to ``torch.optim.SGD``.

    Args:
        params (iterable): Parameters to optimize, or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float, optional): Momentum factor. Default: ``0``.
        dampening (float, optional): Dampening factor applied to the momentum. Default: ``0``.
        weight_decay (float, optional): Decoupled weight decay factor. Default: ``0``.
    """
    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = required,
                 momentum: float = 0,
                 dampening: float = 0,
                 weight_decay: float = 0):
        # Gather the optimizer-wide hyperparameters and hand them to the parent SGD
        hyperparameters = dict(lr=lr,
                               momentum=momentum,
                               dampening=dampening,
                               weight_decay=weight_decay)
        super().__init__(params=params, **hyperparameters)

`params` can also be parameter groups, which are dictionaries holding a set of parameters along with hyperparameter values that override the optimizer defaults for that group. So, if we wanted to not apply weight decay to normalization layers or bias terms (which is usually a good idea), we'd create two parameter groups, one with normal model parameters and one with our normalization layers or bias term parameters.

In [None]:
# For conciseness, the model parameters are omitted here; in a real parameter group
# they are required and are stored under the group's `params` key.
parameter_groups = [
    dict(weight_decay=0.),
    dict(weight_decay=1e-2)
]

# More likely, we'd use a helper to build the groups for us, such as
# `param_groups_weight_decay` from timm.
parameter_groups = param_groups_weight_decay(model, 1e-2, list_of_params_without_wd)

Recall in the training loop the optimization step occurs after we calculate the gradients from the loss.

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    "Train `model` for `epochs` passes over `train_dl` (`valid_dl` is accepted but unused here)"
    for _ in range(epochs):
        model.train()
        for inputs, targets in train_dl:
            predictions = model(inputs)
            # Gradients must exist before the optimizer can use them
            loss_func(predictions, targets).backward()
            opt.step()
            # Clear the gradients so they do not accumulate into the next batch
            opt.zero_grad()

The optimizer step function will loop through all parameter groups and collect each parameter group's hyperparameter values, then loop through all parameters in the group and apply the optimization step to any parameter with a gradient. (Frozen model layers will not have gradients.)

In [None]:
@torch.no_grad()
def step(self):
    "Run one optimization step over every parameter group."
    for group in self.param_groups:
        # Hyperparameters are read per group, so each group may use its own values
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        lr = group['lr']

        # Only parameters that received a gradient take part in the update
        active = [p for p in group['params'] if p.grad is not None]
        grads = [p.grad for p in active]
        # `None` marks a parameter whose momentum buffer has not been created yet
        buffers = [self.state[p].get('momentum_buffer') for p in active]

        self.sgdw(active,
                  grads,
                  buffers,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening)

        # `sgdw` may have filled in new buffers; persist them in the optimizer state
        for p, buf in zip(active, buffers):
            self.state[p]['momentum_buffer'] = buf

And then we apply the optimizer step to the parameters using the gradients and our hyperparameters.

In [None]:
@staticmethod
def sgdw(params: List[torch.Tensor], grad_list: List[torch.Tensor], momentum_buffer_list: List[torch.Tensor], *,
            weight_decay: float, momentum: float, lr: float, initial_lr: Union[float, None] = None,
            dampening: float, nesterov: bool = False):
    """Functional API that performs the (simplified) SGDW algorithm computation.

    Args:
        params: Parameters to update in place.
        grad_list: Gradient of each entry in `params`, in the same order.
        momentum_buffer_list: Momentum buffer of each entry in `params`; ``None`` entries
            are replaced in place by a newly created buffer when `momentum` is non-zero.
        weight_decay: Decoupled weight decay factor.
        momentum: Momentum factor; ``0`` disables momentum.
        lr: Learning rate.
        initial_lr: Unused by this simplified version; optional so that callers which
            do not track it (such as the simplified `step`) can omit it.
        dampening: Dampening factor applied to the gradient added to the momentum buffer.
        nesterov: Unused by this simplified version; optional for the same reason as
            `initial_lr`.
    """
    for i, param in enumerate(params):
        grad = grad_list[i]

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                # First step for this parameter: seed the buffer with the current gradient
                buf = torch.clone(grad).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
            # With momentum, the update direction is the buffer rather than the raw gradient
            grad = buf

        if weight_decay != 0:
            # Decoupled weight decay: shrink the weights directly instead of adding to the gradient
            param.mul_(1 - lr * weight_decay)

        param.add_(grad, alpha=-lr)

Below is the full Composer `DecoupledSGDW` optimizer (with extra comments and warnings removed), which is equivalent to fastai's SGD with momentum and decoupled weight decay, except that it scales the weight decay by `lr / initial_lr` rather than by `lr`.

In [None]:
class DecoupledSGDW(SGD):
    """SGD optimizer with decoupled weight decay, optional momentum and Nesterov momentum.

    Each parameter group remembers the learning rate it was created with as
    ``initial_lr``; the weight decay applied at every step is scaled by
    ``lr / initial_lr``.
    """
    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = required,
                 momentum: float = 0,
                 dampening: float = 0,
                 weight_decay: float = 0,
                 nesterov: bool = False):
        hyperparameters = dict(lr=lr,
                               momentum=momentum,
                               dampening=dampening,
                               weight_decay=weight_decay,
                               nesterov=nesterov)
        super().__init__(params=params, **hyperparameters)
        # Remember the starting learning rate of every group for the decay scaling
        for param_group in self.param_groups:
            param_group['initial_lr'] = param_group['lr']

    @staticmethod
    def sgdw(params: List[torch.Tensor], grad_list: List[torch.Tensor], momentum_buffer_list: List[torch.Tensor], *,
             weight_decay: float, momentum: float, lr: float, initial_lr: float, dampening: float, nesterov: bool):
        "Functional API that performs SGDW algorithm computation."
        for idx, weights in enumerate(params):
            update = grad_list[idx]

            if momentum != 0:
                velocity = momentum_buffer_list[idx]

                if velocity is not None:
                    velocity.mul_(momentum).add_(update, alpha=1 - dampening)
                else:
                    # First step for this parameter: seed the buffer with the gradient
                    velocity = torch.clone(update).detach()
                    momentum_buffer_list[idx] = velocity

                # Nesterov combines the raw gradient with the buffer; plain momentum uses the buffer
                update = update.add(velocity, alpha=momentum) if nesterov else velocity

            if weight_decay != 0:
                # A falsy `initial_lr` falls back to an unscaled decay
                scale = (lr / initial_lr) if initial_lr else 1.0
                weights.mul_(1 - scale * weight_decay)

            weights.add_(update, alpha=-lr)

    @torch.no_grad()
    def step(self, closure=None):
        "Performs a single optimization step and returns the closure's loss, if any."
        loss = None
        if closure is not None:
            # The closure re-evaluates the model, so gradients must be enabled for it
            with torch.enable_grad():
                loss = closure()

        hyper_names = ('weight_decay', 'momentum', 'dampening', 'nesterov', 'lr', 'initial_lr')
        for group in self.param_groups:
            hypers = {name: group[name] for name in hyper_names}

            # Only parameters that received a gradient take part in the update
            active = [p for p in group['params'] if p.grad is not None]
            grads = [p.grad for p in active]
            # `None` marks a parameter whose momentum buffer has not been created yet
            buffers = [self.state[p].get('momentum_buffer') for p in active]

            self.sgdw(active, grads, buffers, **hypers)

            # `sgdw` may have filled in new buffers; persist them in the optimizer state
            for p, buf in zip(active, buffers):
                self.state[p]['momentum_buffer'] = buf

        return loss

See PyTorch's `SGD` for additional features such as foreach implementations, Nesterov momentum, etc.

## Fastai Optimizers

In [None]:
from fastai.vision.all import *
from fastai.optimizer import _BaseOptimizer, _update

fastai optimizers are a bit different.

First, every hyperparameter can be set per parameter or parameter group.

Second, optimizers are defined as optim step callbacks, which are called by `Optimizer` one after another.

Third, they support discriminative learning rates out of the box, i.e., different learning rates per parameter group.

Like PyTorch optimizers, `Optimizer` is initialized with the parameters we want to train and the optimizer specific hyperparameters.

In [None]:
class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self,
        params:Tensor, # Model parameters or parameter groups
        cbs:list, # `Optimizer` step callbacks
        **defaults # Hyper parameters default values
    ):
        params = L(params)
        self.cbs = L(cbs)
        self.state = defaultdict(dict)
        # Defaults declared on the callbacks are merged first, so explicit `defaults` win
        defaults = merge(*self.cbs.attrgot('defaults'), defaults)
        # A list of lists means parameter groups; a flat list becomes a single group
        if isinstance(params[0], (L,list)):
            self.param_lists = L(L(p) for p in params)
        else:
            self.param_lists = L([params])
        # One hyperparameter dict per parameter group
        self.hypers = L({} for _ in range_of(self.param_lists))
        self.set_hypers(**defaults)
        self.frozen_idx = 0

The optimizer step function again loops through all parameter groups, collecting each parameter group's hyperparameter values, and then loops through all parameters and applies the optimization step to any parameter with a gradient.

Except there is an additional loop for optimizer callbacks, which apply the optimizer steps and update the state dictionary.

In [None]:
def all_params(self,
    with_grad:bool=False # Get all parameters. If `True` select only those with a gradient
):
    "Collect a `(param, param_group, state, hypers)` tuple for every parameter"
    entries = L((p, group, self.state[p], hyper)
                for group, hyper in zip(self.param_lists, self.hypers)
                for p in group)
    if not with_grad:
        return entries
    # Keep only the entries whose parameter has a gradient
    return L(entry for entry in entries if getattr(entry[0], 'grad', None) is not None)


def step(self):
    "Run every step callback, in order, on each parameter that has a gradient"
    for param, _, param_state, hyper in self.all_params(with_grad=True):
        for callback in self.cbs:
            # Callbacks receive the current state and hyperparameters as keyword
            # arguments; hyperparameters win when a key appears in both
            cb_kwargs = {**param_state, **hyper}
            param_state = _update(param_state, callback(param, **cb_kwargs))
        self.state[param] = param_state

A simplified SGD (showing only SGDW decoupled true weight decay). 

Note all the separate optimizer steps defined as callbacks added to `cbs`.

In [None]:
def SGD(
    params:Tensor, # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0., # Gradient moving average (β1) coefficient
    wd:float=0., # Optional weight decay (true or L2)
) -> Optimizer:
    "Build an `Optimizer` performing SGD on `params` with `lr`, `mom` and `wd`"
    # Weight decay always runs first; with momentum the gradient average is
    # updated and then used for the step, otherwise the raw gradient is used
    if mom != 0:
        cbs = [weight_decay, average_grad, momentum_step]
    else:
        cbs = [weight_decay, sgd_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)

And all the optimizer callback steps (simplified a bit), which are almost the same as `DecoupledSGDW` above.

In [None]:
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "Decay `p` in place by the factor `1 - lr*wd`, unless disabled by `do_wd` or `wd` is zero"
    if not do_wd or wd == 0:
        return
    shrink = 1 - lr*wd
    p.data.mul_(shrink)


def average_grad(p, mom, grad_avg=None, **kwargs):
    "Update the running gradient average of `p` with `mom` and return it as new state"
    current_grad = p.grad.data
    # The first call has no average yet, so start from zeros
    running = torch.zeros_like(current_grad) if grad_avg is None else grad_avg
    running.mul_(mom).add_(current_grad)
    return {'grad_avg': running}


def momentum_step(p, lr, grad_avg, **kwargs):
    "SGD-with-momentum update: move `p` in place by `-lr` times `grad_avg`"
    weights = p.data
    weights.add_(grad_avg, alpha=-lr)


def sgd_step(p, lr, **kwargs):
    "Plain SGD update: move `p` in place by `-lr` times its gradient"
    gradient = p.grad.data
    p.data.add_(gradient, alpha=-lr)

The full `Optimizer` class, with `zero_grad`, clearing the state, and saving and loading the `state_dict`.

In [None]:
class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self,
        params:Tensor, # Model parameters or parameter groups
        cbs:list, # `Optimizer` step callbacks
        **defaults # Hyper parameters default values
    ):
        params = L(params)
        self.cbs = L(cbs)
        self.state = defaultdict(dict)
        # Defaults declared on the callbacks are merged first, so explicit `defaults` win
        defaults = merge(*self.cbs.attrgot('defaults'), defaults)
        # A list of lists means parameter groups; a flat list becomes a single group
        if isinstance(params[0], (L,list)):
            self.param_lists = L(L(p) for p in params)
        else:
            self.param_lists = L([params])
        # One hyperparameter dict per parameter group
        self.hypers = L({} for _ in range_of(self.param_lists))
        self.set_hypers(**defaults)
        self.frozen_idx = 0

    def zero_grad(self):
        "Detach and zero the gradient of every parameter that has one"
        for entry in self.all_params(with_grad=True):
            grad = entry[0].grad
            grad.detach_()
            grad.zero_()

    def step(self, closure=None):
        "Run every step callback, in order, on each parameter that has a gradient"
        if closure is not None:
            raise NotImplementedError("fastai optimizers currently do not support closure")
        for param, _, param_state, hyper in self.all_params(with_grad=True):
            for callback in self.cbs:
                # Hyperparameters win when a key appears in both the state and the hypers
                cb_kwargs = {**param_state, **hyper}
                param_state = _update(param_state, callback(param, **cb_kwargs))
            self.state[param] = param_state

    def clear_state(self):
        "Reset the state of every parameter, keeping only the `_keep_on_clear` entries"
        for param, _, param_state, _ in self.all_params():
            kept = {}
            for key in self._keep_on_clear:
                if key in param_state:
                    kept[key] = param_state[key]
            self.state[param] = kept

    def state_dict(self):
        "Return the per-parameter state (in `all_params` order) together with the hypers"
        per_param_state = [self.state[entry[0]] for entry in self.all_params()]
        return {'state': per_param_state, 'hypers': self.hypers}

    def load_state_dict(self,
        sd:dict # State dict with `hypers` and `state` to load on the optimizer
    ):
        "Load `hypers` and per-parameter `state` from `sd`, which must match this optimizer's layout"
        n_groups = len(self.param_lists)
        n_params = sum(len(pg) for pg in self.param_lists)
        assert len(sd["hypers"]) == n_groups
        assert len(sd["state"])  == n_params
        self.hypers = sd['hypers']
        # States are stored positionally, so pair them with the parameters in order
        self.state = dict(zip(self.all_params().itemgot(0), sd['state']))

Setting hyperparameters, freezing and unfreezing, and compatibility with PyTorch optimizers are all part of `_BaseOptimizer`.