In [None]:
#|default_exp optimizer.torchscript

# TorchScript Optimizers
> fastai optimizers compiled with TorchScript for improved performance

In [None]:
#|export
from __future__ import annotations
from typing import Optional, Dict

from torch.nn import Parameter

from fastcore.basics import range_of, merge

from fastai.optimizer import (Optimizer, _update, weight_decay, l2_reg, average_grad, sgd_step, 
                              momentum_step, average_sqr_grad, rms_prop_step, step_stat, adam_step, 
                              radam_step, qhadam_step, larc_layer_lr, larc_step, lamb_step, Lookahead)

from fastxtend.imports import *

In [None]:
#|hide
from fastxtend.test_utils import *

## Test Utils -

In [None]:
#|hide
def tst_param(val, grad=None):
    "Build a float tensor wrapping `val` whose `.grad` is `grad`, or `val/10` when `grad` is omitted"
    grad_val = val/10 if grad is None else grad
    param = tensor([val]).float()
    param.grad = tensor([grad_val]).float()
    return param

In [None]:
#|hide
def tst_params():
    "Four single-value test params holding 0, 1, 2 and 3, each with a gradient of a tenth of its value"
    return L([tst_param(i) for i in range(4)])

## JitOptimizer -

In [None]:
#|exporti
def _update(
    state:dict,
    new=None # New values to update `state` dict
):
    if isinstance(new, dict): state.update(new)

In [None]:
#|export
class JitOptimizer(Optimizer):
    "An `Optimizer` whose `step` hands each parameter to a single TorchScript callback, `cbs[0]`"
    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for param_group, hyper in zip(self.param_lists, self.hypers):
            for param in param_group:
                # Only parameters that currently hold a gradient are updated
                if not hasattr(param, 'grad') or param.grad is None: continue
                state = self.state[param]
                # Per-parameter state is merged first, so hyper-parameters win on key collisions;
                # a dict returned by the callback is merged back into that parameter's state
                _update(state, self.cbs[0](param, param.grad, **{**state, **hyper}))

## Optimizers

:::{.callout-note}
Documentation for individual optimizers lightly adapted from the [fastai optimizer documentation](https://docs.fast.ai/optimizer.html).
:::

## SGD -

In [None]:
#|exporti
@torch.jit.script
def sgd_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, decouple_wd:bool, grad_avg:Optional[Tensor]=None, 
                 do_wd:bool=True, dampening:bool=False, force_train:Optional[bool]=None):
    "Fused SGD step: optional weight decay, optional momentum, then the new weights written into `p` via `set_`"
    new_p = p
    apply_wd = do_wd and wd != 0
    if apply_wd and decouple_wd:
        # weight_decay: shrink the weights directly
        new_p = new_p.mul(1 - lr*wd)
    elif apply_wd:
        # l2_reg: fold the decay into the gradient
        grad = grad.add(new_p, alpha=wd)

    if mom != 0:
        # average_grad (dampened only when requested), then momentum_step
        if grad_avg is None:
            grad_avg = torch.zeros_like(new_p, memory_format=torch.preserve_format)
        damp = 1-mom if dampening else 1.
        grad_avg = grad_avg.mul(mom).add(grad, alpha=damp)
        new_p = new_p.add(grad_avg, alpha=-lr)
        p.set_(new_p)
        return {'grad_avg': grad_avg}

    # sgd_step: plain gradient descent, no state to keep
    new_p = new_p.add(grad, alpha=-lr)
    p.set_(new_p)
    return None

In [None]:
#|export
def SGD(params, lr, mom=0., wd=0., decouple_wd=True, jit=False):
    "Build an SGD optimizer with learning rate `lr`, momentum `mom` and weight decay `wd`; `jit=True` returns a `JitOptimizer`"
    if jit:
        return JitOptimizer(params, partial(sgd_jit_step, decouple_wd=decouple_wd), lr=lr, mom=mom, wd=wd)
    # Eager path: fastai callback chain, with the momentum callbacks only when `mom` is non-zero
    wd_cb = weight_decay if decouple_wd else l2_reg
    cbs = [wd_cb, sgd_step] if mom == 0 else [wd_cb, average_grad, momentum_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#|hide

# Test Vanilla SGD: eager and TorchScript versions must match the closed form and each other
params_org, params_jit = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, jit=False)
opt_jit = SGD(params_jit, lr=0.1, jit=True)

# grad stays at val/10 and lr=0.1, so each step removes 1% of the initial value
for decay in (0.99, 0.98):
    opt_org.step()
    opt_jit.step()
    test_close([p.item() for p in params_org], [i*decay for i in range(4)])
    test_close([p.item() for p in params_org], [p.item() for p in params_jit])

## RMSProp -

In [None]:
#|exporti
@torch.jit.script
def rmsprop_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool, 
                     grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None, do_wd:bool=True, force_train:Optional[bool]=None):
    "Fused RMSProp step: optional weight decay, squared-gradient average, optional undampened momentum, update via `set_`"
    new_p = p
    apply_wd = do_wd and wd != 0
    if apply_wd and decouple_wd:
        # weight_decay
        new_p = new_p.mul(1 - lr*wd)
    elif apply_wd:
        # l2_reg
        grad = grad.add(new_p, alpha=wd)

    # average_sqr_grad (dampened by 1-sqr_mom) -- shared by both update branches
    if sqr_avg is None:
        sqr_avg = torch.zeros_like(new_p, memory_format=torch.preserve_format)
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)
    denom = sqr_avg.sqrt().add(eps)

    if mom != 0:
        # average_grad with dampening=False, then rms_prop_step on the averaged gradient
        if grad_avg is None:
            grad_avg = torch.zeros_like(new_p, memory_format=torch.preserve_format)
        grad_avg = grad_avg.mul(mom).add(grad)
        p.set_(new_p.addcdiv(grad_avg, denom, value=-lr))
        return {'grad_avg': grad_avg, 'sqr_avg': sqr_avg}

    # rms_prop_step on the raw gradient
    p.set_(new_p.addcdiv(grad, denom, value=-lr))
    return {'sqr_avg': sqr_avg}

In [None]:
#|export
def RMSProp(params, lr, sqr_mom=0.99, mom=0., eps=1e-8, wd=0., decouple_wd=True, jit=False):
    "Build an RMSProp optimizer with `lr`, `sqr_mom`, `mom`, `eps` and `wd`; `jit=True` returns a `JitOptimizer`"
    if jit:
        return JitOptimizer(params, partial(rmsprop_jit_step, decouple_wd=decouple_wd),
                            lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    # Eager path: the gradient average is only tracked when momentum is enabled
    stat_cbs = [average_sqr_grad] if mom == 0. else [average_grad, average_sqr_grad]
    cbs = [weight_decay if decouple_wd else l2_reg, *stat_cbs, rms_prop_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, eps=eps)

RMSProp was introduced by Geoffrey Hinton in his [course](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). What is named `sqr_mom` here is the `alpha` in the course. 

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#|hide
#Without momentum: eager and TorchScript RMSProp must agree step for step
params_org, params_jit = (tst_param([1,2,3], [0.1,0.2,0.3]) for _ in range(2))
opt_org = RMSProp(params_org, lr=0.1, jit=False)
opt_jit = RMSProp(params_jit, lr=0.1, jit=True)

# First step: sqrt(sqr_avg) is 0.1*|grad|, so every weight drops by ~1
for opt in (opt_org, opt_jit): opt.step()
test_close(params_org[0], tensor([0.,1.,2.]))
test_close(params_org[0], params_jit[0])

# Second step: sqr_avg now holds two damped squared gradients
for opt in (opt_org, opt_jit): opt.step()
step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params_org[0], tensor([step, 1+step, 2+step]))
test_close(params_org[0], params_jit[0])

In [None]:
#|hide
#With momentum: the second update uses grad_avg = grad + 0.9*grad
params_org, params_jit = (tst_param([1,2,3], [0.1,0.2,0.3]) for _ in range(2))
opt_org = RMSProp(params_org, lr=0.1, mom=0.9, jit=False)
opt_jit = RMSProp(params_jit, lr=0.1, mom=0.9, jit=True)

for opt in (opt_org, opt_jit): opt.step()
test_close(params_org[0], tensor([0.,1.,2.]))
test_close(params_org[0], params_jit[0])

for opt in (opt_org, opt_jit): opt.step()
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params_org[0], tensor([step, 1+step, 2+step]))
test_close(params_org[0], params_jit[0])

## Adam -

In [None]:
#|exporti
def debias(beta:float, step:int):
    "Bias-correction factor `1 - beta**step`, used by the Adam-family steps to debias their running averages"
    decayed = beta**step
    return 1-decayed

In [None]:
#|exporti
@torch.jit.script
def adam_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, 
                  decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None, 
                  do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    # Fused TorchScript version of the eager Adam callback chain built by `Adam(jit=False)`:
    #   weight_decay|l2_reg -> average_grad(dampening=True) -> average_sqr_grad -> step_stat -> adam_step
    # `grad_avg`, `sqr_avg` and `step` are the per-parameter state returned by the previous call (None/0 at first).
    # `force_train` is accepted but unused -- presumably so that entry can be passed through from the state dict.
    # Writes the new weights into `p` with `set_` and returns the updated state dict.
    dp = p
    # step_stat: count this update before computing the bias corrections
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    # First call: both running averages start at zero
    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg  = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = torch.mul(grad_avg, mom)
    grad_avg = torch.add(grad_avg, grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = torch.mul(sqr_avg, sqr_mom)
    sqr_avg = torch.addcmul(sqr_avg, grad, grad, value=1-sqr_mom)

    # adam_step: dp -= (lr/debias1) * grad_avg / (sqrt(sqr_avg/debias2) + eps)
    debias1 = debias(mom, step)
    debias2 = debias(sqr_mom, step)
    dp = torch.addcdiv(dp, grad_avg, torch.sqrt(sqr_avg/debias2) + eps, value = -lr / debias1)
    # `set_` makes `p` use the storage of the newly computed tensor instead of copying into it
    p.set_(dp)

    # Explicit annotation: the dict mixes tensors with the int `step`
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
#|export
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True, jit=False):
    "Build an Adam optimizer with `lr`, `mom`, `sqr_mom`, `eps` and `wd`; `jit=True` returns a `JitOptimizer`"
    if not jit:
        # Eager path: fastai's callback chain, ending with the debiased Adam update
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    cb = partial(adam_jit_step, decouple_wd=decouple_wd)
    return JitOptimizer(params, cb, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

Adam was introduced by Diederik P. Kingma and Jimmy Ba in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). For consistency across optimizers, fastai renamed the paper's `beta1` and `beta2` to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`); those values performed better in experiments across a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

:::{.callout-note}
Don't forget that `eps` is a hyperparameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer we are to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that we don't manage to get results as good as with SGD.
:::

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Adam(params_jit, lr=0.1, wd=0, jit=True)
opt_jit.step()

# With constant gradients the debiased ratio grad_avg/sqrt(sqr_avg) is ~1, so each step moves every weight by ~-lr
step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
# Compare the whole tensor: iterating a (1,3) param and taking p[0] only checked its first element
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_jit[0])

## RAdam -

In [None]:
#|exporti
@torch.jit.script
def radam_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, beta:float,
                   decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                   do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    # Fused TorchScript version of the eager RAdam callback chain built by `RAdam(jit=False)`:
    #   weight_decay|l2_reg -> average_grad(dampening=True) -> average_sqr_grad -> step_stat -> radam_step
    # `grad_avg`, `sqr_avg` and `step` are per-parameter state from the previous call (None/0 at first).
    # `beta != 0` enables the SAdam softplus on the denominator; `force_train` is accepted but unused.
    # Writes the new weights into `p` with `set_` and returns the updated state dict.
    dp = p
    # step_stat
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg  = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = torch.mul(grad_avg, mom)
    grad_avg = torch.add(grad_avg, grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = torch.mul(sqr_avg, sqr_mom)
    sqr_avg = torch.addcmul(sqr_avg, grad, grad, value=1-sqr_mom)

    # radam_step: `r_inf` and `r` follow the RAdam rectification formula
    debias1 = debias(mom, step)
    debias2 = debias(sqr_mom, step)
    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*step*sqr_mom**step/(1-sqr_mom**step)
    
    if r > 5:
        # Rectified adaptive update, scaled by the variance-rectification term `v`
        v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
        denom = torch.sqrt(sqr_avg/debias2)
        if eps != 0: 
            denom = denom + eps
        if beta != 0: 
            denom = F.softplus(denom, beta)
        dp = torch.addcdiv(dp, grad_avg, denom, value = -lr*v / debias1)
    else:
        # Early steps (r <= 5): skip the adaptive denominator and step with the debiased average gradient
        dp = torch.add(dp, grad_avg, alpha=-lr / debias1)
    p.set_(dp)

    # Explicit annotation: the dict mixes tensors with the int `step`
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
#|export
def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_wd=True, jit=False):
    "Build a RAdam optimizer with `lr`, `mom`, `sqr_mom`, `eps`, `wd` and SAdam `beta`; `jit=True` returns a `JitOptimizer`"
    # Both paths register the same hyper-parameters, in the same order
    hypers = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)
    if jit:
        return JitOptimizer(params, partial(radam_jit_step, decouple_wd=decouple_wd), **hypers)
    cbs = [weight_decay if decouple_wd else l2_reg,
           partial(average_grad, dampening=True), average_sqr_grad, step_stat, radam_step]
    return Optimizer(params, cbs, **hypers)

RAdam (for rectified Adam) was introduced by Liu et al. in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) to slightly modify the Adam optimizer so it is more stable at the beginning of training (and thus does not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.

This version also incorporates [SAdam](https://arxiv.org/abs/1908.00700); set `beta` to enable this (definition same as in the paper).

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RAdam(params_jit, lr=0.1, jit=True)

#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
# With constant gradients grad_avg/debias1 == grad, so each of these steps subtracts lr*grad = [0.01, 0.02, 0.03]
r_inf = 2/(1-0.99) - 1
for i in range(5): 
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
# grad_avg/debias1 and sqrt(sqr_avg/debias2) both equal grad here, so every weight moves by ~-lr*v
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], p+step)
test_close(params_org[0], params_jit[0])

## QHAdam -

In [None]:
#|exporti
@torch.jit.script
def qhadam_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                    nu_1:float, nu_2:float, decouple_wd:bool, grad_avg:Optional[Tensor]=None, 
                    sqr_avg:Optional[Tensor]=None, do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    # Fused TorchScript version of the eager QHAdam callback chain built by `QHAdam(jit=False)`:
    #   weight_decay|l2_reg -> average_grad(dampening=True) -> average_sqr_grad -> step_stat -> qhadam_step
    # `grad_avg`, `sqr_avg` and `step` are per-parameter state from the previous call (None/0 at first);
    # `force_train` is accepted but unused. Writes the new weights into `p` with `set_`.
    dp = p
    # step_stat
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg  = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = torch.mul(grad_avg, mom)
    grad_avg = torch.add(grad_avg, grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = torch.mul(sqr_avg, sqr_mom)
    sqr_avg = torch.addcmul(sqr_avg, grad, grad, value=1-sqr_mom)

    # qhadam_step: the numerator mixes the raw gradient with the debiased grad_avg (weight nu_1);
    # the denominator mixes grad**2 with the debiased sqr_avg (weight nu_2) before the sqrt
    debias1 = debias(mom, step)
    debias2 = debias(sqr_mom, step)
    dp = torch.addcdiv(dp, ((1-nu_1) * grad) + (nu_1 * (grad_avg / debias1)),
                       torch.sqrt(((1 - nu_2) * (grad)**2) + (nu_2 * (sqr_avg / debias2))) + eps,
                       value = -lr)
    p.set_(dp)

    # Explicit annotation: the dict mixes tensors with the int `step`
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
#|export
def QHAdam(params, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2=1.0, eps=1e-8, wd=0., decouple_wd=True, jit=True):
    "A `Optimizer` or `JitOptimizer` for QHAdam with `lr`, `mom`, `sqr_mom`, `nu_1`, `nu_2`, `eps` and `params`"
    # NOTE(review): `jit` defaults to True here but to False in SGD, RMSProp, Adam, RAdam, Larc, Lamb and ranger.
    # Kept as-is so existing callers are unaffected -- confirm the difference is intended.
    if jit:
        cb = partial(qhadam_jit_step, decouple_wd=decouple_wd)
        return JitOptimizer(params, cb, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, qhadam_step]
        return Optimizer(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in [Quasi-Hyperbolic Momentum and Adam for Deep Learning](https://arxiv.org/pdf/1810.06801.pdf) as a *"computationally cheap, intuitive to interpret, and simple to implement"* optimizer. Additional code can be found in their [qhoptim repo](https://github.com/facebookresearch/qhoptim). QHAdam is based on QH-Momentum, which introduces the immediate discount factor `nu`, encapsulating plain SGD (`nu = 0`) and momentum (`nu = 1`). QH-Momentum is defined below, where $g_{t+1}$ is the updated momentum buffer. QHM can be interpreted as a `nu`-weighted average of the momentum update step and the plain SGD update step.

$$\theta_{t+1} \leftarrow \theta_t - \text{lr} \cdot \left[(1 - \nu) \cdot \nabla L_t(\theta_t) + \nu \cdot g_{t+1}\right]$$

QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms. 

The paper's suggested default parameters are `mom = 0.999`, `sqr_mom = 0.999`, `nu_1 = 0.7` and `nu_2 = 1.0`. When training is not stable, setting `nu_2 < 1` may improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when `nu_1 = nu_2 = 1.0`. QHAdam recovers RMSProp (Hinton et al., 2012) when `nu_1 = 0` and `nu_2 = 1`, and NAdam (Dozat, 2016) when `nu_1 = mom` and `nu_2 = 1`.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = QHAdam(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = QHAdam(params_jit, lr=0.1, jit=True)
opt_jit.step()

# After one step the debiased averages equal grad and grad**2, so both nu-weighted mixes reduce to them
step = -0.1 * (((1-0.7) * 0.1) + (0.7 * 0.1)) / (
     math.sqrt(((1-1.0) * 0.1**2) + (1.0 * 0.1**2)) + 1e-8) 

test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_jit[0])

## LARS/LARC -

In [None]:
#|exporti
@torch.jit.script
def larc_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, eps:float, trust_coeff:float, decouple_wd:bool,
                  clip:bool, grad_avg:Optional[Tensor]=None, do_wd:bool=True, dampening:bool=False, force_train:Optional[bool]=None):
    "Fused LARS/LARC step: weight decay, per-layer `local_lr`, optional momentum, then the update written into `p`"
    dp = p
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    # larc_layer_lr: norms are taken after weight decay, matching the eager callback order in `Larc`
    p_norm = torch.norm(dp)
    g_norm = torch.norm(grad)
    # The ratio is a 0-dim tensor; convert it explicitly so `lr` keeps its declared `float` type
    local_lr = float(lr*trust_coeff * (p_norm) / (g_norm + p_norm * wd + eps))
    if clip:
        # LARC: never step faster than the global `lr`
        lr = min(local_lr, lr)
    else:
        # LARS: use the trust-ratio learning rate as-is
        lr = local_lr

    if mom != 0:
        if grad_avg is None: 
            grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)

        # average_grad: undampened by default (like fastai's `average_grad`); honour `dampening` as `sgd_jit_step` does
        damp = 1-mom if dampening else 1.
        grad_avg = torch.mul(grad_avg, mom)
        grad_avg = torch.add(grad_avg, grad, alpha=damp)

        # larc_step
        dp = torch.add(dp, grad_avg, alpha=-lr)
    else:
        # larc_step
        dp = torch.add(dp, grad, alpha=-lr)

    p.set_(dp)
    # `grad_avg` is still None when momentum is disabled
    return {'grad_avg': grad_avg}

In [None]:
#|export
def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., decouple_wd=True, jit=False):
    "A `Optimizer` or `JitOptimizer` for LARC (LARS when `clip=False`) with `lr`, `mom`, `trust_coeff`, `eps` and `params`"
    if jit:
        # `clip` is structural rather than a hyper-parameter, so it is bound into the callback
        cb = partial(larc_jit_step, decouple_wd=decouple_wd, clip=clip)
        return JitOptimizer(params, cb, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        if mom!=0.: cbs.append(average_grad)
        cbs += [partial(larc_layer_lr, clip=clip), larc_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)

The LARS optimizer was first introduced in [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888), then refined in its LARC variant (original LARS corresponds to `clip=False`). A learning rate is computed for each individual layer with a certain `trust_coeff`; when `clip=True`, it is then clipped so that it is never greater than `lr`.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#|hide
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, jit=True)
opt_jit.step()

#First param local lr is 0.02 < lr so it's not clipped
test_close(opt_org.state[params_org[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt_org.state[params_org[1]]['local_lr'], 0.1)

# The TorchScript step does not store `local_lr`, so only the resulting weights are compared
test_close(params_org[0], tensor([0.998,1.996,2.994]))
test_close(params_org[0], params_jit[0])
test_close(params_org[1], tensor([0.999,1.998,2.997]))
test_close(params_org[1], params_jit[1])

In [None]:
#|hide
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1, clip=False, jit=False)
opt_org.step()

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, clip=False, jit=True)
opt_jit.step()

#No clipping
test_close(opt_org.state[params_org[0]]['local_lr'], 0.02)
test_close(opt_org.state[params_org[1]]['local_lr'], 0.2)
# Unclipped, both params take the same absolute step: local_lr * grad is identical for the two
test_close(params_org[0], tensor([0.998,1.996,2.994]))
test_close(params_org[0], params_jit[0])
test_close(params_org[1], tensor([0.998,1.996,2.994]))
test_close(params_org[1], params_jit[1])

## LAMB -

In [None]:
#|exporti
@torch.jit.script
def lamb_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, 
                  decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None, 
                  do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    # Fused TorchScript version of the eager LAMB callback chain built by `Lamb(jit=False)`:
    #   weight_decay|l2_reg -> average_grad(dampening=True) -> average_sqr_grad -> step_stat -> lamb_step
    # `grad_avg`, `sqr_avg` and `step` are per-parameter state from the previous call (None/0 at first);
    # `force_train` is accepted but unused. Writes the new weights into `p` with `set_`.
    dp = p
    # step_stat
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg  = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = torch.mul(grad_avg, mom)
    grad_avg = torch.add(grad_avg, grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = torch.mul(sqr_avg, sqr_mom)
    sqr_avg = torch.addcmul(sqr_avg, grad, grad, value=1-sqr_mom)

    # lamb_step
    debias1 = debias(mom, step)
    debias2 = debias(sqr_mom, step)
    # r1: RMS of the weights; lstep: the debiased Adam-style step; r2: RMS of that step
    r1 = dp.pow(2).mean().sqrt()
    lstep = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)
    r2 = lstep.pow(2).mean().sqrt()
    # Trust ratio q = r1/r2, capped at 10, falling back to 1 when either norm is zero
    # NOTE(review): r1/r2 is a 0-dim tensor; this relies on TorchScript converting it so `q` is a float in both branches
    if r1 == 0 or r2 == 0:
        q = 1.
    else:
        q = min(r1/r2, 10.)
    dp = torch.add(dp,lstep, alpha = -lr * q)

    p.set_(dp)

    # Explicit annotation: the dict mixes tensors with the int `step`
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
#|export
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True, jit=False):
    "A `Optimizer` or `JitOptimizer` for LAMB with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    if jit:
        cb = partial(lamb_jit_step, decouple_wd=decouple_wd)
        return JitOptimizer(params, cb, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    else:
        # Same statistics as Adam; `lamb_step` adds the layer-wise trust ratio
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

LAMB was introduced in [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962). Intuitively, it's LARC applied to Adam. As in `Adam`, fastai renamed the paper's `beta1` and `beta2` to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`); those values performed better in experiments across a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#|hide
params_org, params_jit = tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, jit=False)
opt_jit = Lamb(params_jit, lr=0.1, jit=True)
for opt in (opt_org, opt_jit): opt.step()

# One LAMB step must land both versions on the same reference weights
test_close(params_org[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
test_close(params_org[0], params_jit[0])

## Lookahead -

In [None]:
#|export
class JitLookahead(Optimizer):
    "An `Optimizer` with a modified step for Lookahead TorchScript optimizers"
    def __init__(self, params:Tensor, cbs:list, train_bn:bool=True, **defaults):
        # Forward `train_bn` by keyword, and only when it differs from its default. Some fastai versions'
        # `Optimizer.__init__` take no positional `train_bn` (and warn whenever one is supplied), while
        # versions that do have the parameter still accept it by keyword.
        if not train_bn: defaults['train_bn'] = train_bn
        super().__init__(params, cbs, **defaults)
        self._init_state()

    @torch.no_grad()
    def step(self, closure=None):
        "Run the fused callback on every parameter with a gradient, passing the optimizer-wide step `count`"
        # Reject closures before touching `count`, so a failed call does not shift the Lookahead `k` schedule
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        self.count += 1
        for pg, hyper in zip(self.param_lists, self.hypers):
            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    _update(self.state[p], self.cbs[0](p, p.grad, **{**self.state[p], **hyper}, count=self.count))

    def clear_state(self):
        "Clear the per-parameter state and reset the step counter"
        super().clear_state()
        self._init_state()

    def state_dict(self):
        "Optimizer state plus the Lookahead step counter"
        state = super().state_dict()
        state.update({'count': self.count})
        return state

    def load_state_dict(self, sd):
        "Restore the step counter and optimizer state from `sd` without mutating the caller's dict"
        sd = dict(sd)
        self.count = sd.pop('count')
        super().load_state_dict(sd)

    def _init_state(self): 
        self.count = 0

Lookahead was introduced by Zhang et al. in [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610). With Lookahead, the final weights (*slow weights*) are a moving average of the normal weights (*fast weights*). Every `k` steps, Lookahead updates the current weights to a moving average of the *fast weights* (normal weights) and the *slow weights* (the copy of the weights from `k` steps ago). These *slow weights* act as a stability mechanism.

:::{.callout-important}
While fastai's `Lookahead` can be applied to any optimizer, fastxtend's `JitLookahead` requires a custom-written TorchScript optimizer callback. Currently `ranger` (Lookahead with `RAdam`) is the only optimizer with `JitLookahead` support.
:::

## ranger -

In [None]:
#|exporti
@torch.jit.script
def ranger_jit_step(p:Tensor, grad:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, beta:float,
                    decouple_wd:bool, count:int, k:int, alpha:float, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                    slow_p:Optional[Tensor]=None, do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    # Fused TorchScript Ranger step: the same RAdam update as `radam_jit_step`, followed by a Lookahead sync.
    # `count` is the optimizer-wide step counter supplied by `JitLookahead.step`; every `k`-th call the weights are
    # pulled toward `slow_p` by `alpha`. `grad_avg`, `sqr_avg`, `step` and `slow_p` are per-parameter state
    # (None/0 at first); `force_train` is accepted but unused.
    dp = p
    # step_stat
    step += 1
    # First call: the slow weights start as a detached copy of the weights before any update
    if slow_p is None: 
        slow_p = dp.clone().detach()

    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            dp = dp.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(dp, alpha=wd)

    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg  = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = torch.mul(grad_avg, mom)
    grad_avg = torch.add(grad_avg, grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = torch.mul(sqr_avg, sqr_mom)
    sqr_avg = torch.addcmul(sqr_avg, grad, grad, value=1-sqr_mom)

    # radam_step
    debias1 = debias(mom, step)
    debias2 = debias(sqr_mom, step)
    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*step*sqr_mom**step/(1-sqr_mom**step)
    
    if r > 5:
        # Rectified adaptive update
        v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
        denom = torch.sqrt(sqr_avg/debias2)
        if eps != 0: 
            denom = denom + eps
        if beta != 0: 
            denom = F.softplus(denom, beta)
        dp = torch.addcdiv(dp, grad_avg, denom, value = -lr*v / debias1)
    else:
        # Early steps (r <= 5): debiased average-gradient step without the adaptive denominator
        dp = torch.add(dp, grad_avg, alpha=-lr / debias1)

    # lookahead step
    if count % k != 0:
        # Between syncs: keep the fast weights
        p.set_(dp)
    else:
        # Sync: slow_p += alpha * (fast - slow), and the weights jump to the new slow weights
        # NOTE(review): after `p.set_(slow_p)`, `p` and the stored `slow_p` share storage, so any in-place write
        # to `p` before the next step would also change `slow_p`. The ops in this function are out-of-place;
        # confirm nothing else mutates the parameters in place.
        slow_p = torch.add(slow_p, torch.sub(dp, slow_p), alpha=alpha)
        p.set_(slow_p)
    
    # Explicit annotation: the dict mixes tensors with the int `step`
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step, 'slow_p': slow_p})

In [None]:
#|export
def ranger(params, lr, mom=0.95, sqr_mom=0.99, eps=1e-6, wd=0.01, beta=0., k=6, alpha=0.5, decouple_wd=True, jit=False):
    "Convenience method for `Lookahead` with `RAdam`, or the fused `JitLookahead` equivalent when `jit=True`"
    if jit:
        cb = partial(ranger_jit_step, decouple_wd=decouple_wd)
        return JitLookahead(params, cb, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta, k=k, alpha=alpha)
    else:
        # Forward `decouple_wd` so the eager path applies weight decay the same way as the TorchScript path
        return Lookahead(RAdam(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta, decouple_wd=decouple_wd),
                         k=k, alpha=alpha)

Ranger was introduced by Less Wright in [New Deep Learning Optimizer, Ranger: Synergistic combination of RAdam + Lookahead for the best of both.](https://lessw.medium.com/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d). It combines RAdam and Lookahead in one optimizer and reduces the need for hyperparameter tuning, thanks to the combination of RAdam's warmup heuristic and Lookahead's interpolation of parameter weights.

Ranger performs best on vision tasks when paired with the `fit_flat_cos` scheduler.

In [None]:
#|hide
# Initial weights; Lookahead's slow weights start as a copy of these
po = tensor([1,2,3])

params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = ranger(params_org, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = ranger(params_jit, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., jit=True)

#The first 5 steps are normal RAdam steps
#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5): 
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)

#Since k=6, sixth step is a moving average of the 6 RAdam steps with the initial weight
# (alpha=0.5: halfway between the initial weights and the fast weights after six steps)
test_close(params_org[0], po+((p+step)-po)*0.5)
test_close(params_org[0], params_jit[0])