In [None]:
#|default_exp optimizer.torchscript

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# TorchScript Optimizers
> Fused fastai optimizers compiled with TorchScript for improved performance

fastxtend TorchScript optimizers are adapted from [fastai optimizers](https://docs.fast.ai/optimizer.html) and are modified to be compiled with TorchScript. They are 10 to 137 percent faster relative to fastai native optimizers depending on the model and optimizer, with complex optimizers like `QHAdam` receiving the largest performance increase.

Unlike fastai optimizers, which are made of [multiple stepper callbacks](https://docs.fast.ai/optimizer.html#basic-steppers), TorchScript optimizers require a per-optimizer step so TorchScript can fuse the operation into as few CUDA calls as possible. All fastai optimizers have TorchScript implementations.

:::{.callout-important}
TorchScript optimizers have only been tested on PyTorch 1.12+ with NVFuser and are not guaranteed to work on older versions.
:::

TorchScript optimizers are faster due to vertical fusion across multiple CUDA calls. Using `xresnet50` and `SGD` with momentum as an example, a TorchScript fused `SGD` step would (hopefully) fuse all three CUDA calls (`mul`, `add`, and `add`) into one or two CUDA kernels per parameter, resulting in 167 or 334 CUDA calls across `xresnet50`'s 167 parameter tensors.

```python
@torch.jit.script
def sgd_momentum_jit(param:Tensor, grad:Tensor, grad_avg:Tensor, lr:float, mom:float):
    grad_avg = grad_avg.mul(mom).add(grad)
    param = param.add(grad_avg, alpha=-lr)
```

In contrast, a standard PyTorch optimizer would call the `SGD` with momentum step 167 times, once per parameter, for a total of 501 in-place CUDA kernel calls:

```python
def simple_momentum_standard(param:Tensor, grad_avg:Tensor, lr:float, mom:float):
    grad_avg.mul_(mom).add_(param.grad)
    param.add_(grad_avg, alpha=-lr)
```

TorchScript optimizers are tested to be equal to fastai optimizers for 25 steps using [nbdev's GitHub CI](https://nbdev.fast.ai/tutorials/tutorial.html#check-out-your-workflows).

In [None]:
#|export
from __future__ import annotations
from typing import Optional, Dict

from packaging.version import parse

import fastai
from fastai.optimizer import Optimizer

from fastxtend.imports import *

## Test Utils -

In [None]:
#|hide
from nbdev.showdoc import *

from fastai.optimizer import (weight_decay, l2_reg, average_grad, sgd_step, momentum_step,
                              average_sqr_grad, rms_prop_step, step_stat, adam_step, radam_step,
                              qhadam_step, larc_layer_lr, larc_step, lamb_step, Lookahead)

from fastxtend.test_utils import *

In [None]:
#|hide
# tests are copied with light modifications from fastai
test_steps = 25  # extra steps each optimizer pair runs before the final agreement check

def tst_param(val, grad=None):
    "Float test tensor wrapping `val`, with `.grad` set to `grad` or, when omitted, to `val/10`"
    if grad is None:
        grad = val/10
    param = tensor([val]).float()
    param.grad = tensor([grad]).float()
    return param

def tst_params():
    "`L` of four test params holding 0, 1, 2 and 3 (each gradient defaults to a tenth of its value)"
    return L(tst_param(i) for i in range(4))

## JitOptimizer -

In [None]:
#|exporti
def _update(
    state:dict,
    new=None # New values to update `state` dict
):
    if isinstance(new, dict): state.update(new)

In [None]:
#|exporti
class JitOptimizer(Optimizer):
    "An `Optimizer` with a modified step for TorchScript optimizers"
    def __init__(self,
        params:Listified[Tensor], # Model parameters
        opt_step:Callable, # `JitOptimizer` optimizer step
        decouple_wd:bool=False, # Use decoupled weight decay or L2 regularization, if applicable
        **defaults
    ):
        if notmax_torch('1.12'):
            warn(f'TorchScript optimizers are untested on PyTorch {torch.__version__}, recommended to use 1.12 or newer')
        # fastai < 2.7.11 `Optimizer.__init__` takes an extra positional argument (presumably `train_bn` — confirm).
        # `[None]` stands in for the callback list: `step` below never runs callbacks, it calls `opt_step` instead.
        if parse(fastai.__version__) < parse('2.7.11'):
            super().__init__(params, [None], True, **defaults)
        else:
            super().__init__(params, [None], **defaults)
        self.opt_step = opt_step
        self.decouple_wd = decouple_wd

    @torch.no_grad()
    def step(self, closure=None):
        "Call `opt_step` on every parameter with a gradient and merge any returned dict into its state"
        if closure is not None:
            raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]
                    # Per-parameter state and group hyper-parameters become keyword arguments; hypers win on clashes
                    _update(state, self.opt_step(p=p, g=p.grad, decouple_wd=self.decouple_wd, **{**state, **hyper}))

In [None]:
# Render the API documentation for `JitOptimizer`
show_doc(JitOptimizer)

## SGD -

In [None]:
#|exporti
@torch.jit.script
def sgd_jit_step(p:Tensor, g:Tensor, decouple_wd:bool, lr:float, wd:float, mom:float,
                 grad_avg:Optional[Tensor]=None, do_wd:bool=True, force_train:Optional[bool]=None):
    "SGD TorchScript compiled `JitOptimizer` step"
    # Fused equivalent of fastai's weight_decay/l2_reg + average_grad + sgd_step/momentum_step callbacks.
    # p, g: parameter and its gradient; both are replaced at the end with `set_` instead of in-place math.
    #       With L2 regularization, the penalized gradient is what gets written back into `g`.
    # decouple_wd: True scales `p` by (1 - lr*wd); False adds wd*p to the gradient.
    # grad_avg: momentum buffer from the previous call (None on the first call).
    # do_wd: when False, weight decay is skipped.
    # force_train: accepted but unused here (presumably present in fastai's per-parameter state).
    # Returns {'grad_avg': ...} when mom != 0, otherwise None (nothing to store).
    param = p
    grad = g
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1-lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if mom != 0:
        if grad_avg is None:
            grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

        # average_grad, dampening=False
        grad_avg = grad_avg.mul(mom).add(grad)

        # momentum_step
        param = param.add(grad_avg, alpha=-lr)
        p.set_(param)
        g.set_(grad)
        return {'grad_avg': grad_avg}
    else:
        # sgd_step
        param = param.add(grad, alpha=-lr)
        p.set_(param)
        g.set_(grad)
        return None

In [None]:
# Render the documentation for the TorchScript-compiled SGD step, referenced by its `.name`
show_doc(sgd_jit_step.name)

In [None]:
#|hide
def SGD(params, lr, mom=0., wd=0., decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for SGD, optionally with momentum"
    if not jit:
        # Callback chain: weight decay (decoupled or L2), optional gradient averaging, then the update step
        wd_cb = weight_decay if decouple_wd else l2_reg
        avg_cbs = [average_grad] if mom != 0 else []
        step_cb = momentum_step if mom != 0 else sgd_step
        return Optimizer(params, [wd_cb, *avg_cbs, step_cb], lr=lr, mom=mom, wd=wd)
    return JitOptimizer(params, sgd_jit_step, lr=lr, mom=mom, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# Vanilla SGD
# `_org` is fastai's callback-based optimizer and `_jit` the fused TorchScript one; they must stay in lockstep
params_org = tst_params()
opt_org = SGD(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = tst_params()
opt_jit = SGD(params_jit, lr=0.1, jit=True)
opt_jit.step()

# Gradients stay at i/10 (no backward pass), so each step with lr=0.1 subtracts 0.01*i
test_close([p.item() for p in params_org], [i*0.99 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

opt_org.step()
opt_jit.step()
test_close([p.item() for p in params_org], [i*0.98 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

# Run more steps and check both optimizers still agree
for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

In [None]:
#|hide
# SGD with momentum
params_org = tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, jit=False)
opt_org.step()

params_jit = tst_params()
opt_jit = SGD(params_jit, lr=0.1, mom=0.9, jit=True)
opt_jit.step()

# First step: the momentum buffer starts at zero, so it equals the gradient 0.1*i
test_close([p.item() for p in params_org], [i*0.99 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

opt_org.step()
opt_jit.step()

# Second step: grad_avg = 0.9*0.1*i + 0.1*i = 0.19*i (no dampening)
test_close([p.item() for p in params_org], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

# Both optimizers must store the same momentum buffer under 'grad_avg'
for i,p in enumerate(params_org):
    test_close(opt_org.state[p]['grad_avg'].item(), i*0.19)
for i,p in enumerate(params_jit):
    test_close(opt_jit.state[p]['grad_avg'].item(), i*0.19)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

In [None]:
#|hide
# Weight decay
# Decoupled weight decay (SGD's default decouple_wd=True) combined with momentum
params_org = tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, wd=0.1, jit=False)
opt_org.step()
opt_org.step()

params_jit = tst_params()
opt_jit = SGD(params_jit, lr=0.1, mom=0.9, wd=0.1, jit=True)
opt_jit.step()
opt_jit.step()

test_close([p.item() for p in params_org], [i*0.9512 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

opt_org.step()
opt_org.step()
opt_jit.step()
opt_jit.step()
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

In [None]:
#|hide
# L2 reg
# decouple_wd=False adds wd*param to the gradient instead of scaling the parameter
params_org = tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False, jit=False)
opt_org.step()
opt_org.step()

params_jit = tst_params()
opt_jit = SGD(params_jit, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False, jit=True)
opt_jit.step()
opt_jit.step()

test_close([p.item() for p in params_org], [i*0.9322 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close([p.item() for p in params_org], [p.item() for p in params_jit])

## RMSProp -

In [None]:
#|exporti
@torch.jit.script
def rmsprop_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool,
                     grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None, do_wd:bool=True, force_train:Optional[bool]=None):
    "RMSProp TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + optional average_grad + average_sqr_grad + rms_prop_step.
    # p, g: parameter and gradient, replaced at the end via `set_` (with L2, `g` receives the penalized gradient).
    # grad_avg: momentum buffer, only used and returned when mom != 0.
    # sqr_avg: running average of squared gradients, zero-initialized on the first call.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # Returns the updated state that `JitOptimizer.step` merges into the parameter's state.
    param = p
    grad = g
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if sqr_avg is None:
        sqr_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

    if mom != 0:
        if grad_avg is None:
            grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

        # average_grad, dampening=False
        grad_avg = grad_avg.mul(mom).add(grad)

        # average_sqr_grad
        sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

        # rms_prop_step
        param = param.addcdiv(grad_avg, sqr_avg.sqrt().add(eps), value=-lr)
        p.set_(param)
        g.set_(grad)
        return {'grad_avg': grad_avg, 'sqr_avg': sqr_avg}
    else:
        # average_sqr_grad
        sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

        # rms_prop_step
        # without momentum the raw gradient is the numerator
        param = param.addcdiv(grad, sqr_avg.sqrt().add(eps), value=-lr)
        p.set_(param)
        g.set_(grad)
        return {'sqr_avg': sqr_avg}

In [None]:
# Render the documentation for the TorchScript-compiled RMSProp step, referenced by its `.name`
show_doc(rmsprop_jit_step.name)

In [None]:
#|hide
def RMSProp(params, lr, sqr_mom=0.99, mom=0., eps=1e-8, wd=0., decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for RMSProp"
    if not jit:
        # Gradient averaging is only added to the chain when momentum is enabled
        stat_cbs = [average_sqr_grad] if mom == 0. else [average_grad, average_sqr_grad]
        cbs = [weight_decay if decouple_wd else l2_reg, *stat_cbs, rms_prop_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, eps=eps)
    return JitOptimizer(params, rmsprop_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                        eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# Without momentum
# tst_param wraps the list, so each param is a 1x3 tensor and `params_org[0]` is its row of values
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RMSProp(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RMSProp(params_jit, lr=0.1, jit=True)
opt_jit.step()

# First step: sqr_avg = 0.01*g**2, so every element moves by lr*g/(0.1*|g|) = 1
test_close(params_org[0], tensor([0.,1.,2.]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
# Second step: sqr_avg = (0.01*0.99 + 0.01)*g**2; the step size is the same for every element, computed here with g=0.1
step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params_org[0], tensor([step, 1+step, 2+step]))
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close(params_org[0], params_jit[0])

In [None]:
#|hide
# With momentum
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RMSProp(params_org, lr=0.1, mom=0.9, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RMSProp(params_jit, lr=0.1, mom=0.9, jit=True)
opt_jit.step()

test_close(params_org[0], tensor([0.,1.,2.]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
# Second step numerator is grad_avg = g + 0.9*g (momentum without dampening)
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params_org[0], tensor([step, 1+step, 2+step]))
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close(params_org[0], params_jit[0])

## Adam -

In [None]:
#|exporti
@torch.jit.script
def adam_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                  decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                  do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    "Adam TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + average_grad(dampening=True) + average_sqr_grad + step_stat + adam_step.
    # grad_avg, sqr_avg: running first/second moment estimates, zero-initialized on the first call.
    # step: number of previous steps; incremented first so bias correction uses step >= 1.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # The returned dict mixes Tensor and int values, so its type is spelled out with `torch.jit.annotate`.
    param = p
    grad = g
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # adam_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step
    # bias-corrected update: p - lr/debias1 * grad_avg / (sqrt(sqr_avg/debias2) + eps)
    param = torch.addcdiv(param, grad_avg, torch.sqrt(sqr_avg/debias2) + eps, value = -lr / debias1)
    p.set_(param)
    g.set_(grad)

    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
# Render the documentation for the TorchScript-compiled Adam step, referenced by its `.name`
show_doc(adam_jit_step.name)

In [None]:
#|hide
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for Adam"
    if not jit:
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return JitOptimizer(params, adam_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                        eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# fastai test
# wd=0 isolates the Adam update itself
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Adam(params_jit, lr=0.1, wd=0, jit=True)
opt_jit.step()

# With bias correction the first update is -lr*g/(|g| + eps), i.e. about -lr for every element
step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])

opt_org.step()
opt_jit.step()
# Constant gradients keep the debiased averages at g and g**2, so the second step repeats the first
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (Adam's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Adam(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])

In [None]:
#|hide
# test with l2 reg
# decouple_wd=False adds wd*param to the gradient before the moment updates
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0.1, decouple_wd=False, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Adam(params_jit, lr=0.1, wd=0.1, decouple_wd=False, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])

## RAdam -

In [None]:
#|exporti
@torch.jit.script
def radam_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                   decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                   do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    "RAdam TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + average_grad(dampening=True) + average_sqr_grad + step_stat + radam_step.
    # grad_avg, sqr_avg: running moment estimates, zero-initialized on the first call.
    # step: number of previous steps; incremented first so the debias terms use step >= 1.
    # There is no `beta` argument: this step never applies a `beta` transform to the denominator.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # Returns grad_avg, sqr_avg and step, typed with `torch.jit.annotate` because values mix Tensor and int.
    param = p
    grad = g
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # radam_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step
    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*step*sqr_mom**step/(1-sqr_mom**step)

    # Above r = 5 use the rectified adaptive step scaled by v; otherwise a plain debiased momentum step
    if r > 5:
        v = math.sqrt(((r-4)*(r-2)*r_inf)/((r_inf-4)*(r_inf-2)*r))
        denom = torch.sqrt(sqr_avg/debias2).add(eps)
        param = param.addcdiv(grad_avg, denom, value=-lr*v/debias1)
    else:
        param = param.add(grad_avg, alpha=-lr/debias1)
    p.set_(param)
    g.set_(grad)

    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
# Render the documentation for the TorchScript-compiled RAdam step, referenced by its `.name`
show_doc(radam_jit_step.name)

In [None]:
#|hide
def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for RAdam"
    if not jit:
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, radam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)
    # `radam_jit_step` has no `beta` parameter, so a non-zero `beta` cannot be honoured here
    if beta != 0:
        warn('TorchScript RAdam does not use beta, set jit=False if beta!=0')
    return JitOptimizer(params, radam_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps,
                        wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RAdam(params_jit, lr=0.1, jit=True)

#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
# Constant gradients keep the debiased grad_avg at g, so each of these steps subtracts lr*g
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], p+step)
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (RAdam's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RAdam(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

## QHAdam -

In [None]:
#|exporti
@torch.jit.script
def qhadam_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                    nu_1:float, nu_2:float, decouple_wd:bool, grad_avg:Optional[Tensor]=None,
                    sqr_avg:Optional[Tensor]=None, do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    "QHAdam TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + average_grad(dampening=True) + average_sqr_grad + step_stat + qhadam_step.
    # nu_1, nu_2: weights mixing the debiased averages with the raw gradient in the numerator and denominator.
    # grad_avg, sqr_avg: running moment estimates, zero-initialized on the first call.
    # step: number of previous steps; incremented first so the debias terms use step >= 1.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # Returns grad_avg, sqr_avg and step, typed with `torch.jit.annotate` because values mix Tensor and int.
    param = p
    grad = g
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # qhadam_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step
    # numerator: (1-nu_1)*grad + nu_1*debiased grad_avg; denominator: sqrt((1-nu_2)*grad**2 + nu_2*debiased sqr_avg) + eps
    param = param.addcdiv(((1-nu_1) * grad) + (nu_1 * (grad_avg / debias1)),
                          torch.sqrt(((1 - nu_2) * (grad)**2) + (nu_2 * (sqr_avg / debias2))) + eps,
                          value = -lr)
    p.set_(param)
    g.set_(grad)

    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
# Render the documentation for the TorchScript-compiled QHAdam step, referenced by its `.name`
show_doc(qhadam_jit_step.name)

In [None]:
#|hide
def QHAdam(params, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2=1.0, eps=1e-8, wd=0., decouple_wd=True, jit=True):
    "Build a fused `JitOptimizer` (`jit=True`, the default here) or a fastai `Optimizer` (`jit=False`) for QHAdam"
    # NOTE(review): unlike the other helpers in this notebook, `jit` defaults to True;
    # every call in this notebook passes `jit` explicitly
    if not jit:
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, qhadam_step]
        return Optimizer(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return JitOptimizer(params, qhadam_jit_step, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom,
                        sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = QHAdam(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = QHAdam(params_jit, lr=0.1, jit=True)
opt_jit.step()

# Expected first step for g=0.1 with the defaults nu_1=0.7 and nu_2=1.0 (the denominator then uses only sqr_avg)
step = -0.1 * (((1-0.7) * 0.1) + (0.7 * 0.1)) / (
     math.sqrt(((1-1.0) * 0.1**2) + (1.0 * 0.1**2)) + 1e-8)

test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (QHAdam's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = QHAdam(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = QHAdam(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

## LARS/LARC -

In [None]:
#|exporti
@torch.jit.script
def larc_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, eps:float, trust_coeff:float, decouple_wd:bool,
                  clip:bool, grad_avg:Optional[Tensor]=None, do_wd:bool=True, dampening:bool=False, force_train:Optional[bool]=None):
    "LARC TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + optional average_grad + larc_layer_lr + larc_step.
    # trust_coeff: scales the layer-wise LR lr*trust_coeff*||p|| / (||g|| + wd*||p|| + eps).
    # clip: when True the layer-wise LR is capped at `lr` (LARC); otherwise it is used as-is.
    # grad_avg: momentum buffer, only updated when mom != 0 (otherwise returned unchanged, None on first call).
    # dampening and force_train are accepted but unused in this body.
    # do_wd: when False, weight decay is skipped.
    # Unlike the callback version, the layer-wise LR is not stored in the returned state.
    param = p
    grad = g
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    # larc_layer_lr
    p_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    local_lr = lr*trust_coeff * (p_norm) / (g_norm + p_norm * wd + eps)
    # NOTE(review): `local_lr` is a 0-dim Tensor while `lr` is declared float; this relies on
    # TorchScript converting it when reassigned — confirm on new PyTorch versions
    if clip:
        lr = min(local_lr, lr)
    else:
        lr = local_lr

    if mom != 0:
        if grad_avg is None:
            grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)

        # average_grad, dampening=False
        grad_avg = grad_avg.mul(mom).add(grad)

        # larc_step
        param = torch.add(param, grad_avg, alpha=-lr)
    else:
        # larc_step
        param = torch.add(param, grad, alpha=-lr)

    p.set_(param)
    g.set_(grad)
    return {'grad_avg': grad_avg}

In [None]:
# Render the documentation for the TorchScript-compiled LARC step, referenced by its `.name`
show_doc(larc_jit_step.name)

In [None]:
#|hide
def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for LARC/LARS"
    if not jit:
        wd_cb = weight_decay if decouple_wd else l2_reg
        avg_cbs = [] if mom == 0. else [average_grad]
        cbs = [wd_cb, *avg_cbs, partial(larc_layer_lr, clip=clip), larc_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)
    # `clip` is bound with `partial` since it is not passed to the optimizer as a hyper-parameter
    step_fn = partial(larc_jit_step, clip=clip)
    return JitOptimizer(params, step_fn, lr=lr, mom=mom, trust_coeff=trust_coeff,
                        eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, jit=True)
opt_jit.step()

# Gradients 10x apart give different layer-wise LRs; `local_lr` is only checked on the fastai
# optimizer because the fused step does not store it in state
#First param local lr is 0.02 < lr so it's not clipped
test_close(opt_org.state[params_org[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt_org.state[params_org[1]]['local_lr'], 0.1)

test_close(params_org[0], tensor([0.998,1.996,2.994]))
test_close(params_org[0], params_jit[0])
test_close(params_org[1], tensor([0.999,1.998,2.997]))
test_close(params_org[1], params_jit[1])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])
test_close(params_org[1], params_jit[1])

In [None]:
#|hide
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1, clip=False, jit=False)
opt_org.step()

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, clip=False, jit=True)
opt_jit.step()

#No clipping
# The second param uses its full layer-wise LR of 0.2, so both params move by the same 0.2% here
test_close(opt_org.state[params_org[0]]['local_lr'], 0.02)
test_close(opt_org.state[params_org[1]]['local_lr'], 0.2)
test_close(params_org[0], tensor([0.998,1.996,2.994]))
test_close(params_org[0], params_jit[0])
test_close(params_org[1], tensor([0.998,1.996,2.994]))
test_close(params_org[1], params_jit[1])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])
test_close(params_org[1], params_jit[1])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (Larc's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])
test_close(params_org[1], params_jit[1])

## LAMB -

In [None]:
#|exporti
@torch.jit.script
def lamb_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                  decouple_wd:bool, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                  do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    "LAMB TorchScript compiled `JitOptimizer` step"
    # Fused weight_decay/l2_reg + average_grad(dampening=True) + average_sqr_grad + step_stat + lamb_step.
    # grad_avg, sqr_avg: running moment estimates, zero-initialized on the first call.
    # step: number of previous steps; incremented first so the debias terms use step >= 1.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # Returns grad_avg, sqr_avg and step, typed with `torch.jit.annotate` because values mix Tensor and int.
    param = p
    grad = g
    step += 1
    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # lamb_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step

    # trust ratio r1/r2 from L2 norms of the parameter and of the Adam-style step `lstep`
    r1 = param.norm(2)
    lstep = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)
    r2 = lstep.norm(2)

    # the ratio is skipped (scale 1) when either norm is zero, and capped at 10 otherwise
    if r1 == 0 or r2 == 0:
        param = param.add(lstep, alpha=-lr)
    else:
        q = min(r1/r2, 10.)
        param = param.add(lstep, alpha=-lr*q)

    p.set_(param)
    g.set_(grad)

    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step})

In [None]:
# Render the documentation for the TorchScript-compiled LAMB step, referenced by its `.name`
show_doc(lamb_jit_step.name)

In [None]:
#|hide
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True, jit=False):
    "Build a fastai `Optimizer` (`jit=False`) or a fused `JitOptimizer` (`jit=True`) for LAMB"
    if not jit:
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return JitOptimizer(params, lamb_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                        eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, jit=False)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Lamb(params_jit, lr=0.1, jit=True)
opt_jit.step()

# Every element moves by the same amount after one trust-ratio-scaled step
test_close(params_org[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (Lamb's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Lamb(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

## Lookahead -

In [None]:
#|exporti
class JitLookahead(JitOptimizer):
    "A `JitOptimizer` with a modified step for Lookahead TorchScript optimizers"
    def __init__(self,
        params:Listified[Tensor], # Model parameters
        opt_step:Callable, # `JitLookahead` optimizer step
        decouple_wd:bool=False, # Use decoupled weight decay or L2 regularization, if applicable
        **defaults
    ):
        super().__init__(params, opt_step, decouple_wd, **defaults)
        self._init_state()

    @torch.no_grad()
    def step(self, closure=None):
        "Run `opt_step` on every parameter with a gradient, also passing the shared Lookahead step `count`"
        # Reject closures before touching `count` so a failed call leaves the step counter unchanged
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        self.count += 1
        for pg, hyper in zip(self.param_lists, self.hypers):
            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    _update(self.state[p], self.opt_step(p, p.grad, decouple_wd=self.decouple_wd, **{**self.state[p], **hyper}, count=self.count))

    def clear_state(self):
        "Clear the per-parameter state and reset the Lookahead step `count`"
        super().clear_state()
        self._init_state()

    def state_dict(self):
        "Return the optimizer state dict extended with the Lookahead step `count`"
        state = super().state_dict()
        state.update({'count': self.count})
        return state

    def load_state_dict(self, sd):
        "Restore the Lookahead step `count` and the optimizer state from `sd`, leaving `sd` unmodified"
        # Work on a shallow copy: popping from the caller's dict would make it unusable for a second load
        sd = dict(sd)
        self.count = sd.pop('count')
        super().load_state_dict(sd)

    def _init_state(self):
        "Reset the Lookahead step counter"
        self.count = 0

In [None]:
# Render the API documentation for `JitLookahead`
show_doc(JitLookahead)

## ranger -

In [None]:
#|exporti
@torch.jit.script
def ranger_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool,
                    count:int, k:int, alpha:float, grad_avg:Optional[Tensor]=None, sqr_avg:Optional[Tensor]=None,
                    slow_p:Optional[Tensor]=None, do_wd:bool=True, step:int=0, force_train:Optional[bool]=None):
    "ranger TorchScript compiled `JitOptimizer` step"
    # Fused RAdam step (see radam_jit_step) followed by a Lookahead sync every `k` optimizer steps.
    # count: Lookahead step counter supplied by `JitLookahead.step` (shared across parameters).
    # k, alpha: sync period and interpolation factor of the slow weights towards the fast weights.
    # slow_p: slow weights; initialized on the first call from `p` before weight decay is applied.
    # do_wd: when False, weight decay is skipped; force_train is accepted but unused.
    # Returns grad_avg, sqr_avg, step and slow_p, typed with `torch.jit.annotate` because values mix Tensor and int.
    param = p
    grad = g
    step += 1
    if slow_p is None:
        slow_p = param.detach().clone()

    if do_wd and wd != 0:
        if decouple_wd:
            # weight_decay
            param = param.mul(1 - lr*wd)
        else:
            # l2_reg
            grad = grad.add(param, alpha=wd)

    if grad_avg is None:
        grad_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
    if sqr_avg is None:
        sqr_avg  = torch.zeros_like(param, memory_format=torch.preserve_format)

    # average_grad, dampening=True
    grad_avg = grad_avg.mul(mom).add(grad, alpha=1-mom)

    # average_sqr_grad
    sqr_avg = sqr_avg.mul(sqr_mom).addcmul(grad, grad, value=1-sqr_mom)

    # radam_step
    debias1 = 1-mom**step
    debias2 = 1-sqr_mom**step
    r_inf = 2/(1-sqr_mom)-1
    r = r_inf - 2*step*sqr_mom**step/(1-sqr_mom**step)

    if r > 5:
        v = math.sqrt(((r-4)*(r-2)*r_inf)/((r_inf-4)*(r_inf-2)*r))
        denom = torch.sqrt(sqr_avg/debias2).add(eps)
        param = param.addcdiv(grad_avg, denom, value=-lr*v/debias1)
    else:
        param = param.add(grad_avg, alpha=-lr/debias1)

    # lookahead step
    # Off-sync steps keep the fast weights; every k-th step moves slow_p toward them by `alpha` and resets p to slow_p.
    # NOTE(review): on sync steps `p` is `set_` to the very tensor stored as `slow_p`, so both share storage until
    # the next step replaces `p`; an in-place write to `p` in between would also change `slow_p` — confirm intended.
    if count % k != 0:
        p.set_(param)
        g.set_(grad)
    else:
        slow_p = slow_p.add(param.sub(slow_p), alpha=alpha)
        p.set_(slow_p)
        g.set_(grad)
    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg': grad_avg, 'sqr_avg': sqr_avg, 'step': step, 'slow_p': slow_p})

In [None]:
# Render the documentation for the TorchScript-compiled ranger step, referenced by its `.name`
show_doc(ranger_jit_step.name)

In [None]:
#|hide
def Ranger(params, lr, mom=0.95, sqr_mom=0.99, eps=1e-6, wd=0.01, k=6, alpha=0.5, decouple_wd=True, jit=False):
    "Build fastai's `Lookahead`-wrapped RAdam (`jit=False`) or a fused `JitLookahead` (`jit=True`) for ranger"
    if not jit:
        radam = RAdam(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                      decouple_wd=decouple_wd, jit=False)
        return Lookahead(radam, k=k, alpha=alpha)
    return JitLookahead(params, ranger_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps,
                        wd=wd, k=k, alpha=alpha, decouple_wd=decouple_wd)

In [None]:
#|hide
po = tensor([1,2,3])
# `po` holds the initial weights, which the fused step copies into `slow_p` on its first call

params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Ranger(params_org, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Ranger(params_jit, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., jit=True)

#The first 5 steps are normal RAdam steps
#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)

#Since k=6, sixth step is a moving average of the 6 RAdam steps with the initial weight
test_close(params_org[0], po+((p+step)-po)*0.5)
test_close(params_org[0], params_jit[0])

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])

In [None]:
#|hide
# test with weight decay
# Decoupled weight decay (Ranger's default decouple_wd=True); only fastai vs fused agreement is checked
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Ranger(params_org, lr=0.1, wd=0.1, jit=False)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Ranger(params_jit, lr=0.1, wd=0.1, jit=True)

for i in range(test_steps):
    opt_org.step()
    opt_jit.step()

test_close(params_org[0], params_jit[0])