In [None]:
from __future__ import annotations

from collections.abc import MutableSequence

from fastai.vision.all import *
from fastai.optimizer import _BaseOptimizer, _update

from fastxtend.optimizer.foreach import ForEachOptimizer

# Optimizers: Part 2

In the Lesson 14 Optimizers notebook, we went over PyTorch and fastai Optimizers and showed how SGD was implemented in both frameworks.

In this notebook, we will review how fastai Optimizers work and implement AdamW from the paper's algorithm.

## Fastai Optimizer Review

fastai Optimizers are unique in two ways. First, fastai Optimizers can set weight decay (and any hyperparameter) uniquely per parameter. This allows fastai Optimizers to have half as many parameter groups as PyTorch Optimizers, which simplifies discriminative learning rates (i.e. different learning rates per parameter group) out of the box.

Second, fastai Optimizers are defined as a series of optimizer step callbacks, which `Optimizer` calls one after another. This allows fastai to reuse optimizer steps across multiple optimizers.

Like PyTorch optimizers, `Optimizer` is initialized with the parameters we want to train and any optimizer specific hyperparameters.

In [None]:
class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self,
        params:Tensor|Iterable, # Model parameters
        cbs:callable|MutableSequence, # `Optimizer` step callbacks
        **defaults # Hyper parameters default values
    ):
        "Wrap `params` into parameter groups and give every group its own hyperparameter dict"
        params = L(params)
        self.cbs = L(cbs)
        # per-parameter state, created on first access
        self.state = defaultdict(dict)
        # `defaults` declared on the callbacks, merged with the explicitly passed `defaults` last
        hyper_defaults = merge(*self.cbs.attrgot('defaults'), defaults)
        # a list of lists means the caller supplied parameter groups; otherwise everything forms one group
        if isinstance(params[0], (L,list)):
            self.param_lists = L(L(group) for group in params)
        else:
            self.param_lists = L([params])
        # one (initially empty) hyperparameter dict per parameter group
        self.hypers = L({} for _ in range_of(self.param_lists))
        self.set_hypers(**hyper_defaults)
        self.frozen_idx = 0

The Optimizer step function loops through all parameter groups, collecting each group's hyperparameter values, and then loops through all parameters, applying the optimization step to any parameter with a gradient.

There is an additional loop for optimizer callbacks, which apply the optimizer steps sequentially and update the per parameter state dictionary.

In [None]:
def all_params(self,
    with_grad:bool=False # Get all parameters. If `True` select only those with a gradient
):
    "Return a `(param, param group, param state, group hypers)` tuple for every parameter"
    entries = L(
        (param, group, self.state[param], hyper)
        for group, hyper in zip(self.param_lists, self.hypers)
        for param in group
    )
    if not with_grad:
        return entries
    # keep only the parameters that currently hold a gradient
    return L(entry for entry in entries if getattr(entry[0], 'grad', None) is not None)


def step(self):
    "Run every optimizer callback, in order, on each parameter that has a gradient"
    for param, _group, param_state, hypers in self.all_params(with_grad=True):
        for callback in self.cbs:
            # callbacks see the parameter's current state plus its group's hyperparameters as keywords
            cb_kwargs = {**param_state, **hypers}
            param_state = _update(param_state, callback(param, **cb_kwargs))
        self.state[param] = param_state

A simplified SGDW (without momentum) shows how we would define a fastai optimizer out of separate Optimizer callbacks.

In [None]:
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "Weight decay as decaying `p` with `lr*wd`"
    # nothing to do when decay is switched off for this parameter or the coefficient is zero
    if not do_wd or wd == 0:
        return
    decay_factor = 1 - lr*wd
    p.data.mul_(decay_factor)

def sgd_step(p, lr, **kwargs):
    "Step for SGD with `lr`"
    gradient = p.grad.data
    # in-place p <- p - lr * gradient
    p.data.add_(gradient, alpha=-lr)

def SGD(
    params:Tensor|Iterable, # Model parameters
    lr:float|slice, # Default learning rate
    wd:float=0., # Optional true weight decay
) -> Optimizer:
    "Build an `Optimizer` that runs `weight_decay` and then `sgd_step` on each parameter"
    return Optimizer(params, [weight_decay, sgd_step], lr=lr, wd=wd)

## Implementing AdamW from Scratch

AdamW was introduced as an improvement of the Adam Optimizer in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) by Ilya Loshchilov & Frank Hutter.

It introduces decoupled (or true) weight decay, which, unlike L2 regularization, does not directly interact with the learning rate, allowing weight decay and the learning rate to be tuned independently. AdamW's weight decay also outperforms L2 regularization on most tasks, which is why AdamW is the default Adam in fastai.

AdamW is defined by the following formula:

$$
\begin{aligned}
t \leftarrow t + 1 \\[0.5em]
\bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t \\[0.5em]
\bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t \\[0.5em]
\hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t) \\[0.5em]
\hat{\bm{{v}}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t) \\[0.5em]
\eta_t \leftarrow \text{SetScheduleMultiplier}(t) \\[0.5em]
\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \right) \\[0.5em]
\end{aligned}
$$
which we will go through and implement line by line.

First, we assume we have some parameter `param` and a gradient of that parameter `grad` passed to our AdamW optimizer.

In [None]:
# Placeholder tensors standing in for a model parameter and its gradient.
# NOTE(review): if `Tensor` is `torch.Tensor`, `Tensor(0)` builds an empty (zero-element) tensor, not a scalar 0 — confirm
param = Tensor(0)
grad = Tensor(0)

The first line simply keeps track of the number of steps we've taken, incrementing it by one every pass
$$t \leftarrow t + 1$$

In [None]:
# t: number of optimizer steps taken so far
step = 0

# t <- t + 1
step += 1

Next, assuming we have a gradient from a PyTorch model, we calculate the moving average $\bm{m}_t$ of the gradient $\bm{g}_t$. 
$$\bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t$$

In [None]:
# m_{t-1}: no moving average exists before the first step
grad_avg = None
# beta_1
mom = 0.9

# lazily create a zero-filled buffer matching `param`
if grad_avg is None: 
    grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

# average_grad
# m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t, built out of place (`mul`/`add` return new tensors)
grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

Next, we calculate the moving average $\bm{v}_t$ of the squared gradient $\bm{g}^2_t$. 
$$\bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t$$
where `addcmul` is a single op to add the weighted multiplied terms to the first term.

In [None]:
# v_{t-1}: no moving average exists before the first step
sqr_avg = None
# beta_2
sqr_mom = 0.99

# lazily create a zero-filled buffer matching `param`
if sqr_avg is None: 
    sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

# average_sqr_grad
# v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t * g_t; `addcmul` adds `value * grad * grad` in one op
sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

Next we calculate the debias terms and debias our gradient moving average and gradient squared moving average
$$
\begin{aligned}
\hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t) \\[0.5em]
\hat{\bm{v}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t)
\end{aligned}
$$

In [None]:
# bias-correction denominators: 1 - beta_1^t and 1 - beta_2^t
debias1 = 1-mom**step
debias2 = 1-sqr_mom**step

# debiased moving averages m_hat and v_hat
grad_avg_d = grad_avg/debias1
sqr_avg_d  = sqr_avg /debias2

The next line in the equation tells us that the Learning Rate has been set by a scheduler
$$\eta_t \leftarrow \text{SetScheduleMultiplier}(t)$$
And finally we arrive at the AdamW step.
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \right)$$
This is a lot in one equation, so we'll break it down into multiple parts.

First, let's look at the decoupled (true) weight decay all the way on the right.
$$\lambda\bm{\theta}_{t-1}$$
The full equation applied to the parameter weights includes the negative learning rate, so we add that to our code
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t\lambda\bm{\theta}_{t-1}$$

In [None]:
# eta: learning rate
lr = 3e-3
# lambda: weight decay
wd = 1e-2

# theta - lr*wd*theta, written as a single multiplication by (1 - lr*wd)
param = param.mul(1 - lr*wd)

Next, we calculate the actual weight update from our debiased moving averages: we take the square root of the debiased squared moving average, add an epsilon term for numerical stability, and divide the debiased moving average by the result:
$$\bm{u}_t = \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon)$$

In [None]:
# epsilon: keeps the denominator away from zero
eps = 1e-5

# u_t = m_hat / (sqrt(v_hat) + eps)
update = grad_avg_d/(torch.sqrt(sqr_avg_d) + eps)

And then we update the parameter by our Adam step
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \bm{u}_t$$

In [None]:
# theta_t = theta_{t-1} - lr * u_t; `torch.add` scales its second operand through the `alpha` keyword
# (`value` is the keyword of `addcmul`/`addcdiv`, and `torch.add` rejects it)
param = torch.add(param, update, alpha=-lr)

However, in practice, this is slow, so we combine the entire update step into as few operations as possible using `addcdiv`:
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) \right)$$

In [None]:
# Fused form of the Adam step: param + (-lr) * grad_avg_d / (sqrt(sqr_avg_d) + eps) in one `addcdiv` call.
# NOTE: this is an alternative to the separate `update` + `torch.add` cells; running both applies the step twice.
param = torch.addcdiv(param, grad_avg_d, torch.sqrt(sqr_avg_d) + eps, value = -lr)

And there we have it, the entire AdamW Optimizer update algorithm
$$
\begin{aligned}
t \leftarrow t + 1 \\[0.5em]
\bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t \\[0.5em]
\bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t \\[0.5em]
\hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t) \\[0.5em]
\hat{\bm{{v}}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t) \\[0.5em]
\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \right) \\[0.5em]
\end{aligned}
$$
in one fastai optimizer compatible method.

In [None]:
def adam_step(
    param:Tensor, # Parameter to update, modified in place
    grad:Tensor, # Gradient of `param`
    lr:float, # Learning rate
    wd:float, # True (decoupled) weight decay, skipped when 0
    mom:float, # Gradient moving average (β1) coefficient
    sqr_mom:float, # Gradient squared moving average (β2) coefficient
    eps:float, # Added to the denominator for numerical stability
    grad_avg:Optional[Tensor]=None, # Gradient moving average from the previous step
    sqr_avg:Optional[Tensor]=None, # Gradient squared moving average from the previous step
    step:int=0 # Number of steps already taken
):
    "Apply one AdamW update to `param` in place and return the new `grad_avg`, `sqr_avg` and `step` state"
    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

    step += 1

    # average_grad
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # adam_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step

    # `param` must be modified in place: only the state dict is returned, so rebinding the local
    # name would throw the update away. `no_grad` allows in-place ops on leaf tensors that require grad.
    with torch.no_grad():
        if wd != 0:
            # true weight_decay
            param.mul_(1 - lr*wd)
        # dividing the step size by `debias1` debiases `grad_avg` without an extra tensor op
        param.addcdiv_(grad_avg, torch.sqrt(sqr_avg/debias2) + eps, value=-lr/debias1)

    return {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step}

## Can We Go Faster?

The only problem with this formulation of AdamW, and the fastai callbacks version, is that it's slow. Why is that? Because PyTorch must call all of these operations on every Tensor in our model sequentially, which means we are relaunching the same `mul`, `add`, `addcmul`, `addcdiv`, etc. CUDA kernels multiple times.

Is there a way to get around this and apply the optimizer step even faster? 

Yes. Two actually.

The first way is to use PyTorch's undocumented `foreach` methods. Unlike `torch.mul`, which applies a multiply operation to one Tensor at a time, `torch._foreach_mul` takes a list of Tensors and applies the multiply operation to them in parallel with as few CUDA kernel launches as possible (I assume; I haven't looked at the CUDA code for these methods).

So let’s rewrite AdamW, except using the faster `foreach` methods.

First, we need an `Optimizer` that loops through our model's parameters, grabs all the Tensors with gradients, and adds them to lists.

In [None]:
#|exporti
class AdamForEachOptimizer(ForEachOptimizer):
    "An `ForEachOptimizer` with a modified step for `adam_foreach_step`"
    @torch.no_grad()
    def step(self, closure=None):
        "Collect every parameter with a gradient, one parameter group at a time, and run `opt_step` once per group"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            # `ones` stays empty: it is forwarded only because `adam_foreach_step` declares it as a required argument
            pl, gl, grad_avg, sqr_avg, ones, steps, do_wd = [], [], [], [], [], [], []

            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # first time this parameter is stepped: create its moving average buffers
                    if 'step' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['step'] = 0

                    state['step'] += 1
                    pl.append(p)
                    gl.append(p.grad)
                    grad_avg.append(state['grad_avg'])
                    sqr_avg.append(state['sqr_avg'])
                    steps.append(state['step'])
                    do_wd.append(state.get('do_wd', True))

            # a group without gradients (e.g. frozen) has nothing to update; don't hand empty lists to the foreach ops
            if not pl: continue

            # keyword names must match the `adam_foreach_step` signature (`p`, `g`, `ones`)
            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, sqr_avg=sqr_avg, ones=ones,
                          steps=np.array(steps, dtype=np.int32), do_wd=np.array(do_wd, dtype=bool), 
                          decouple_wd=self.decouple_wd, **hyper)

We don't need to keep track of all the steps in the AdamW step implementation, since we are already doing that in `AdamForEachOptimizer.step`
```
if 'step' not in state:
    state['step'] = 0

state['step'] += 1
```
so we can jump straight to the gradient $\bm{g}_t$ moving average $\bm{m}_t$ .
$$\bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t$$
The `torch._foreach_mul_` method takes an existing list of Tensors, `grad_avg`, and applies the in-place `mul_` operation to all of the Tensors, just like calling `t.mul_(mom)` on each Tensor `t` in `grad_avg` would.

In [None]:
# m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t, applied in place to every tensor in the `grad_avg` list
torch._foreach_mul_(grad_avg, mom)
torch._foreach_add_(grad_avg, grad, alpha=1-mom)

Next, we calculate the moving average $\bm{v}_t$ of the squared gradient $\bm{g}^2_t$. 
$$\bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t$$
where `addcmul` is a single op to add the weighted multiplied terms to the first term.

In [None]:
# v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t * g_t, applied in place to every tensor in the `sqr_avg` list
torch._foreach_mul_(sqr_avg, sqr_mom)
torch._foreach_addcmul_(sqr_avg, grad, grad, value=1-sqr_mom)

Next, we calculate the debias terms and debias our gradient moving average and gradient squared moving average
$$
\begin{aligned}
\hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t) \\[0.5em]
\hat{\bm{v}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t)
\end{aligned}
$$
In this case, steps are a NumPy array, so all the debias terms (which might be different per Tensor) can be vectorized. But we won't apply them to `grad_avg` and `sqr_avg` until later. 

Note, we include the learning rate in the first debias term and take the square root of the squared moving average debias term here, so we don't have to repeat ops later.

In [None]:
# one step count per tensor; `...` stands in for the remaining entries
steps = np.array([1,1,1,...])

# negative learning rate folded into the first debias term: -lr / (1 - beta_1^t)
debias1 = -lr / (1 - mom**steps)
# square root of the second debias term: sqrt(1 - beta_2^t)
debias2 = np.sqrt(1 - sqr_mom**steps)

Just like our slow implementation, we will break up the AdamW step into multiple substeps
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \right)$$
First being the weight decay term
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t\lambda\bm{\theta}_{t-1}$$
Not all parameters necessarily apply weight decay, so I use `np.where` to compute the weight decay multiplier for the parameters that do, and a multiplier of 1 for the parameters that don't.

In [None]:
# one weight decay flag per tensor; `...` stands in for the remaining entries
do_wd = np.array([True, False, True, ...])

if wd != 0:
    # weight_decay
    # per-tensor multiplier: 1 - lr*wd where the flag is set, 1 (no change) elsewhere
    wd = np.where(do_wd, 1-lr*wd, 1.)
    torch._foreach_mul_(param, scalars=wd.tolist())

Next, we calculate the debiased root of the squared moving average
$$\bm{r}_{t} = \sqrt{\hat{\bm{v}}_t} + \epsilon$$
due to limitations in the `foreach` methods, this becomes three lines of code:

In [None]:
# sqrt(v_t), returned as a new list so the `sqr_avg` tensors themselves are not overwritten
sqr_avg_debias2 = torch._foreach_sqrt(sqr_avg)
# dividing by sqrt(1 - beta_2^t) yields sqrt(v_t / (1 - beta_2^t))
torch._foreach_div_(sqr_avg_debias2, debias2.tolist())
# add epsilon to every denominator
torch._foreach_add_(sqr_avg_debias2, eps)

And then we finally calculate the rest of the Adam step
$$\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / \bm{r}_{t} \right)$$
where `debias1` term has the negative learning rate included. 

In [None]:
# param += debias1 * grad_avg / sqr_avg_debias2, using one `debias1` scalar per tensor
torch._foreach_addcdiv_(param, grad_avg, sqr_avg_debias2, scalars=debias1.tolist())

Now we have reimplemented AdamW in PyTorch again.
$$
\begin{aligned}
\bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t \\[0.5em]
\bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t \\[0.5em]
\hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t) \\[0.5em]
\hat{\bm{{v}}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t) \\[0.5em]
\bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \eta_t \left( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \right) \\[0.5em]
\end{aligned}
$$
Except it runs 21 to 293 percent faster than our first implementation (across models I tested). You can find this, and a handful of other fastai optimizers reimplemented using foreach methods in [fastxtend](https://fastxtend.benjaminwarner.dev/optimizer.fused.html).

In [None]:
def adam_foreach_step(
    p:list[Tensor], # Parameters to update, modified in place
    g:list[Tensor], # Gradients, one per parameter
    grad_avg:list[Tensor], # Gradient moving averages, modified in place
    sqr_avg:list[Tensor], # Gradient squared moving averages, modified in place
    ones:list[Tensor|None], # Not used by this step
    steps:np.ndarray[Any, int], # Step count of each parameter
    do_wd:np.ndarray[Any, bool], # Whether to apply weight decay to each parameter
    lr:float, # Learning rate
    wd:float, # True (decoupled) weight decay, skipped when 0
    mom:float, # Gradient moving average (β1) coefficient
    sqr_mom:float, # Gradient squared moving average (β2) coefficient
    eps:float, # Added to the denominator for numerical stability
    **kwargs
):
    "Apply one AdamW update to every tensor in `p` using `foreach` ops, modifying `p`, `grad_avg` and `sqr_avg` in place"
    # nothing to update; this also avoids calling the foreach ops with empty lists,
    # which some PyTorch versions may reject
    if len(p) == 0: return

    if wd != 0:
        # weight_decay: multiply by 1 - lr*wd where `do_wd` is set and by 1 elsewhere
        wd_scale = np.where(do_wd, 1-lr*wd, 1.)
        torch._foreach_mul_(p, scalars=wd_scale.tolist())

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # adam_step
    # per-parameter debias terms; the negative learning rate is folded into the first one
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps)

    # sqrt(sqr_avg)/sqrt(1 - sqr_mom**step) + eps, built in a new list so `sqr_avg` is not overwritten
    sqr_avg_debias2 = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(sqr_avg_debias2, debias2.tolist())
    torch._foreach_add_(sqr_avg_debias2, eps)

    torch._foreach_addcdiv_(p, grad_avg, sqr_avg_debias2, debias1.tolist())

To use them, simply import `fastxtend.optimizer.all` and pass the optimizer to `opt_func` like normal. 

In [None]:
from fastai.vision.all import *
from fastxtend.optimizer.all import *

# `...` stands in for the usual `Learner` arguments.
# NOTE(review): `adam(foreach=True)` presumably selects the foreach implementation — confirm against fastxtend
Learner(..., opt_func=adam(foreach=True))

## But Can We Still Go Faster?
Yes. The foreach methods above are horizontally fused operations: each call applies a single operation to multiple tensors at once, but the operations themselves still run sequentially, one after another.

But we can both horizontally and vertically fuse operations to apply multiple steps to multiple tensors at once, instead of sequentially.

[Nvidia's Apex](https://github.com/NVIDIA/apex) package does this. It is custom CUDA code, so we won't look at it today. It's even faster, but it is not 100% fastai compatible because it doesn't support per-parameter weight decay.