In [None]:
#|default_exp optimizer.fused

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Fused Optimizers
> Fused fastai optimizers using ForEach methods and TorchScript

fastxtend's fused optimizers are drop-in replacements for fastai native optimizers. In the benchmarks below, the ForEach versions take the optimizer step 21 to 293 percent faster.

Like fastai optimizers, fastxtend fused optimizers support both discriminative learning rates across multiple parameter groups and per-parameter weight decay without any extra setup.

While all fastai optimizers have vertically fused TorchScript implementations, only a subset have horizontally fused ForEach implementations. These optimizers, [SGD](#sgd-optimizer), [Adam](#adam-optimizer), [RAdam](#radam-optimizer), [Lamb](#lamb-optimizer), and [Ranger](#ranger-optimizer), usually outperform their TorchScript counterparts in all but the tiniest models.

fastxtend ForEach optimizers perform on par with PyTorch ForEach optimizers that use two parameter groups: one for parameters with weight decay and one for parameters without weight decay.

:::{.callout-important}
ForEach and TorchScript optimizers have only been tested on PyTorch 1.12+ and are not guaranteed to work on older versions.
:::

:::{.callout-note}
Documentation for individual optimizers is lightly adapted from the [fastai optimizer documentation](https://docs.fast.ai/optimizer.html). [Docments](https://nbdev.fast.ai/tutorials/best_practices.html#document-parameters-with-docments) and type hints have been [upstreamed to fastai](https://github.com/fastai/fastai/pull/3847).

For implementation details, see the [ForEach](optimizer.foreach.html) or [TorchScript](optimizer.torchscript.html) documentation.

fastxtend ForEach optimizers are adapted from the PyTorch ForEach [`_multi_tensor`](https://github.com/pytorch/pytorch/tree/master/torch/optim) implementations, but seamlessly work with fastai features.
:::

## Fused Performance

As shown in @tbl-single and @tbl-layers, ForEach optimizers take the [AdamW](#adam-optimizer) optimizer step 21 to 293 percent faster[^faster] than fastai implementations across the benchmarked models. Complex optimizers without ForEach implementations, such as [QHAdam](#qhadam-optimizer), are up to 137 percent faster using TorchScript implementations.

| Model         | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
| :-----------: | :---------: | :----------: | :-------------: | :------: | :---------: |
| XResNet18     | 26ms        | 12ms         | 109%            | 20ms     | 29%         |
| XResNet50     | 56ms        | 32ms         | 74%             | 46ms     | 20%         |
| XSE-ResNeXt50 | 72ms        | 43ms         | 68%             | 61ms     | 18%         |
| XResNet101    | 88ms        | 47ms         | 84%             | 68ms     | 30%         |
| DeBERTa Base  | 27ms        | 6.9ms        | 293%            | 19ms     | 46%         |

: Increase in [AdamW](#adam-optimizer) `opt_step` Speed vs fastai Native Optimizer {#tbl-single} {tbl-colwidths="[10,8,8,8,8,8]"}

This speedup persists with single or multiple parameter groups, although more groups can lead to a small decrease in optimizer step speed, as shown by DeBERTa in @tbl-layers.

| Model             | Layers | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
| :---------------: | :----: | :---------: | :----------: | :-------------: | :------: | :---------: |
| XResNet18         | 2      | 25ms        | 12ms         | 103%            | 19ms     | 30%         |
| XResNet50         | 2      | 56ms        | 32ms         | 76%             | 46ms     | 24%         |
| XSE-ResNeXt50     | 2      | 72ms        | 45ms         | 60%             | 61ms     | 17%         |
| XResNet101        | 2      | 87ms        | 47ms         | 85%             | 67ms     | 29%         |
| ConvNeXt Tiny     | 2      | 125ms       | 102ms        | 22%             | 115ms    | 9.4%        |
| ConvNeXt Small    | 2      | 200ms       | 165ms        | 21%             | 181ms    | 10%         |
| ViT Patch16 Small | 2      | 62ms        | 38ms         | 62%             | 52ms     | 20%         |
| DeBERTa Base      | 4      | 27ms        | 7.7ms        | 254%            | 19ms     | 47%         |

: Increase in [AdamW](#adam-optimizer) `opt_step` Speed With Multiple Param Groups vs fastai Native Optimizer {#tbl-layers} {tbl-colwidths="[10,2,8,8,8,8,8]"}
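The speedup columns can be read as the baseline step time divided by the fused step time, minus one. The sketch below assumes that definition (`speedup_pct` is a hypothetical helper). Per the footnote, the tables compute speedup from the total optimization-step time, so ratios of the rounded per-step times only approximately reproduce the reported percentages.

```python
def speedup_pct(baseline_ms: float, fused_ms: float) -> float:
    """Percent by which the fused step is faster than the baseline step."""
    return (baseline_ms / fused_ms - 1) * 100

# XResNet50 AdamW step from @tbl-single: fastai (56ms) vs ForEach (32ms)
print(round(speedup_pct(56, 32)))  # 75, close to the reported 74%
```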

[^faster]: All optimizers benchmarked on a GeForce RTX 3080 Ti using PyTorch 1.12.1, CUDA 11.6, mixed precision, [Channels Last](callback.channelslast.html) (except ViT and DeBERTa), and fastxtend's [Simple Profiler Callback](callback.simpleprofiler.html). Results may differ with other optimizers, models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the optimization step.

## Examples

For backwards compatibility, all fastxtend optimizers return a fastai native optimizer by default. To use a fused version, set `foreach=True` (where available) or `jit=True`.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# Use ForEach AdamW
opt_func = adam(foreach=True)

# Or use TorchScript AdamW
opt_func = adam(jit=True)

Learner(..., opt_func=opt_func)
```

Or import fused optimizers independent of other fastxtend features.

```python
from fastai.vision.all import *
from fastxtend.optimizer.all import *

Learner(..., opt_func=partial(Adam, foreach=True))
```

:::{.callout-note}
`adam(...)` is a fastxtend convenience function equivalent to `partial(Adam, ...)`. fastxtend adds lowercase convenience functions for all fastai optimizers.
:::
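fastai's `Learner` builds its optimizer by calling `opt_func` with the model's parameter groups and the learning rate, which is why the lowercase helpers bind every hyperparameter except those two. The sketch below illustrates that pattern only: the `Adam` here is a hypothetical stub standing in for fastxtend's `Adam`, and `functools.partial` stands in for fastcore's `partialler`. It does not construct a real optimizer.

```python
from functools import partial

def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01,
         decouple_wd=True, foreach=False, jit=False):
    # Hypothetical stub: records which implementation would be built instead of building it
    impl = "foreach" if foreach else "jit" if jit else "fastai"
    return {"n_groups": len(params), "lr": lr, "wd": wd, "impl": impl}

def adam(mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True, foreach=False, jit=False):
    # Bind every hyperparameter except `params` and `lr`
    return partial(Adam, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                   decouple_wd=decouple_wd, foreach=foreach, jit=jit)

opt_func = adam(foreach=True)
# A Learner-like caller supplies the parameter groups and learning rate later
opt = opt_func([["conv weights"], ["norm weights"]], lr=1e-3)
print(opt)  # {'n_groups': 2, 'lr': 0.001, 'wd': 0.01, 'impl': 'foreach'}
```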

In [None]:
#|export
from __future__ import annotations
from typing import Optional, Dict

from fastcore.basics import partialler

from fastai.optimizer import (Optimizer, weight_decay, l2_reg, average_grad, sgd_step, momentum_step,
                              average_sqr_grad, rms_prop_step, step_stat, adam_step, radam_step,
                              qhadam_step, larc_layer_lr, larc_step, lamb_step, Lookahead)

from fastxtend.optimizer.torchscript import (JitOptimizer, radam_jit_step, sgd_jit_step, rmsprop_jit_step,
                                             adam_jit_step, qhadam_jit_step, larc_jit_step,
                                             lamb_jit_step, JitLookahead, ranger_jit_step)

from fastxtend.optimizer.foreach import (SGDForEachOptimizer, sgd_foreach_step, AdamForEachOptimizer,
                                         adam_foreach_step, RAdamForEachOptimizer, radam_foreach_step,
                                         LambForEachOptimizer, lamb_foreach_step, RangerForEachOptimizer,
                                         ranger_foreach_step)

from fastxtend.imports import *

## Test Utils -

In [None]:
#|hide
from fastxtend.test_utils import *

In [None]:
#|hide
# full tests in the optimizer.torchscript and optimizer.foreach notebooks
# tests are copied with light modifications from fastai
def tst_param(val, grad=None):
    "Create a tensor with `val` and a gradient of `grad` for testing"
    res = tensor([val]).float()
    res.grad = tensor([val/10 if grad is None else grad]).float()
    return res

def tst_params():
    r = L.range(4)
    return r.map(tst_param)

## SGD Optimizer

Stochastic gradient descent, optionally with momentum.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
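To make the difference between true weight decay and L2 regularization concrete, here is a minimal pure-Python sketch of one SGD step on lists of floats. It is not fastxtend's implementation; it assumes fastai's conventions of decaying the weights by `1 - lr*wd` and keeping an undampened momentum buffer (`avg = mom*avg + grad`). `sgd_update` is a hypothetical helper.

```python
def sgd_update(params, grads, lr, mom=0.0, wd=0.0, decouple_wd=True, grad_avgs=None):
    """Apply one SGD/SGDW step to parallel lists of floats.

    Returns the updated parameters and the (undampened) momentum buffers.
    """
    if grad_avgs is None:
        grad_avgs = [0.0] * len(params)
    new_params, new_avgs = [], []
    for p, g, avg in zip(params, grads, grad_avgs):
        if decouple_wd:
            p = p * (1 - lr * wd)  # true weight decay: shrink the weight itself
        else:
            g = g + wd * p         # L2 regularization: add the decay to the gradient
        if mom != 0:
            avg = mom * avg + g    # momentum buffer without dampening
            g = avg
        new_params.append(p - lr * g)
        new_avgs.append(avg)
    return new_params, new_avgs

# Each parameter has a gradient of one tenth its value, as in this notebook's SGD test
params, avgs = sgd_update([0., 1., 2., 3.], [0., 0.1, 0.2, 0.3], lr=0.1)
print(params)  # approximately [0.0, 0.99, 1.98, 2.97]
```

For plain SGD the two forms of weight decay produce identical updates; they diverge once the decayed gradient is fed through momentum (or, in Adam, through adaptive scaling).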

In [None]:
#|export
def SGD(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0., # Gradient moving average (β1) coefficient
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (SGDW) or L2 regularization (SGD)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|SGDForEachOptimizer|JitOptimizer:
    "A fastai SGD/SGDW optimizer with fused ForEach and TorchScript implementations"
    if foreach:
        return SGDForEachOptimizer(params, sgd_foreach_step, lr=lr, mom=mom, wd=wd, decouple_wd=decouple_wd)
    elif jit:
        return JitOptimizer(params, sgd_jit_step, lr=lr, mom=mom, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        if mom != 0: cbs.append(average_grad)
        cbs.append(sgd_step if mom==0 else momentum_step)
        return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)

In [None]:
#|export
def sgd(
    mom:float=0., # Gradient moving average (β1) coefficient
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (SGDW) or L2 regularization (SGD)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|SGDForEachOptimizer|JitOptimizer:
    "Partial function for the SGD/SGDW optimizer with fused ForEach and TorchScript implementations"
    return partialler(SGD, mom=mom, wd=wd, decouple_wd=decouple_wd, jit=jit, foreach=foreach)

In [None]:
#|hide
params_org = tst_params()
opt_org = SGD(params_org, lr=0.1)
opt_org.step()

params_jit = tst_params()
opt_jit = SGD(params_jit, lr=0.1, jit=True)
opt_jit.step()

params_for = tst_params()
opt_for = SGD(params_for, lr=0.1, foreach=True)
opt_for.step()

test_close([p.item() for p in params_org], [i*0.99 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

opt_org.step()
opt_jit.step()
opt_for.step()
test_close([p.item() for p in params_org], [i*0.98 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_jit])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

## RMSProp Optimizer

RMSProp was introduced by Geoffrey Hinton in his [course](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). The `sqr_mom` hyperparameter here corresponds to `alpha` in the course.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

:::{.callout-note}
The order of the `mom` and `sqr_mom` hyperparameters has been swapped from fastai to follow the order of all the other fastai and fastxtend optimizers.
:::
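A minimal scalar sketch of one RMSProp step, assuming fastai's conventions: a dampened moving average of squared gradients weighted by `sqr_mom`, an optional undampened momentum buffer, and `eps` added after the square root. Weight decay is left out, and `rmsprop_update` is a hypothetical helper, not part of fastai or fastxtend.

```python
import math

def rmsprop_update(p, g, state, lr, mom=0.0, sqr_mom=0.99, eps=1e-8):
    """One RMSProp step on a scalar parameter `p` with gradient `g`.

    `state` is a dict that carries the moving averages between calls.
    """
    # Dampened moving average of the squared gradient (`sqr_mom` is the course's alpha)
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    if mom != 0:
        # Optional undampened momentum on the gradient
        state["grad_avg"] = mom * state.get("grad_avg", 0.0) + g
        g = state["grad_avg"]
    return p - lr * g / (math.sqrt(state["sqr_avg"]) + eps)

state = {}
p = 1.0
for _ in range(2):
    p = rmsprop_update(p, 0.1, state, lr=0.1, mom=0.9)
```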

In [None]:
#|export
def RMSProp(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0., # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (RMSPropW) or L2 regularization (RMSProp)
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "A fastai RMSProp/RMSPropW optimizer with a fused TorchScript implementation"
    if jit:
        return JitOptimizer(params, rmsprop_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                            eps=eps, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += ([average_sqr_grad] if mom==0. else [average_grad, average_sqr_grad])
        cbs.append(rms_prop_step)
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, eps=eps)

In [None]:
#|export
def rmsprop(
    mom:float=0., # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (RMSPropW) or L2 regularization (RMSProp)
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "Partial function for the RMSProp/RMSPropW optimizer with a fused TorchScript implementation"
    return partialler(RMSProp, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd, jit=jit)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RMSProp(params_org, lr=0.1, mom=0.9)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RMSProp(params_jit, lr=0.1, mom=0.9, jit=True)
opt_jit.step()

test_close(params_org[0], tensor([0.,1.,2.]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params_org[0], tensor([step, 1+step, 2+step]))
test_close(params_org[0], params_jit[0])

## Adam Optimizer

Adam was introduced by Diederik P. Kingma and Jimmy Ba in *[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)*. For consistency across optimizers, fastai renamed `beta1` and `beta2` from the paper to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). In fastai's experiments, these values seem to work better in a wide range of situations.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

:::{.callout-note}
Don't forget that `eps` is a hyperparameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer Adam is to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that Adam doesn't manage to get results as good as SGD.
:::
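The role of `eps` is easiest to see in a scalar sketch of one bias-corrected Adam step. This assumes fastai's formulation (dampened moving averages, `eps` added after the square root) and leaves out weight decay; `adam_update` is a hypothetical helper.

```python
import math

def adam_update(p, g, state, lr, mom=0.9, sqr_mom=0.99, eps=1e-5):
    """One bias-corrected Adam step on a scalar parameter (weight decay omitted)."""
    step = state["step"] = state.get("step", 0) + 1
    # Dampened moving averages of the gradient and squared gradient
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * g
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    # Bias corrections for the zero-initialized averages
    debias1 = 1 - mom ** step
    debias2 = 1 - sqr_mom ** step
    denom = math.sqrt(state["sqr_avg"] / debias2) + eps
    return p - (lr / debias1) * state["grad_avg"] / denom

# A tiny gradient: with a small eps Adam still takes a near-full-size step,
# while a large eps shrinks it toward lr * (debiased grad_avg) / eps, i.e. momentum SGD
small_eps = adam_update(0.0, 1e-4, {}, lr=0.1, eps=1e-5)
large_eps = adam_update(0.0, 1e-4, {}, lr=0.1, eps=0.1)
```

Because `eps` sits in the denominator next to the root mean square of the gradients, a large `eps` dominates it whenever gradients are small, which is why a higher `eps` makes Adam behave more like SGD.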

In [None]:
#|export
def Adam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0.01, # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (AdamW) or L2 regularization (Adam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|AdamForEachOptimizer|JitOptimizer:
    "A fastai Adam/AdamW optimizer with fused ForEach and TorchScript implementations"
    if foreach:
        return AdamForEachOptimizer(params, adam_foreach_step, lr=lr, mom=mom,
                                    sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd)
    elif jit:
        return JitOptimizer(params, adam_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                            eps=eps, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|export
def adam(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0.01, # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (AdamW) or L2 regularization (Adam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|AdamForEachOptimizer|JitOptimizer:
    "Partial function for the Adam/AdamW optimizer with fused ForEach and TorchScript implementations"
    return partialler(Adam, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                      decouple_wd=decouple_wd, foreach=foreach, jit=jit)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Adam(params_jit, lr=0.1, wd=0, jit=True)
opt_jit.step()

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Adam(params_for, lr=0.1, wd=0, foreach=True)
opt_for.step()

step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_for])

opt_org.step()
opt_jit.step()
opt_for.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_jit])
test_close([p[0].item() for p in params_org], [p[0].item() for p in params_for])

## RAdam Optimizer

RAdam (for rectified Adam) was introduced by Liu et al. in *[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)* to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.

The native fastai implementation also incorporates [SAdam](https://arxiv.org/abs/1908.00700); set `beta` to enable this (definition same as in the paper).

:::{.callout-note}
fastxtend ForEach and TorchScript implementations do not support `beta` or SAdam.
:::

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
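A scalar sketch of one RAdam step, assuming fastai's formulation: while the rectification term `r` is at most 5, the update falls back to bias-corrected momentum; afterwards, the adaptive step is scaled by the rectifier `v`. Weight decay and SAdam are left out, and `radam_update` is a hypothetical helper.

```python
import math

def radam_update(p, g, state, lr, mom=0.9, sqr_mom=0.99, eps=1e-5):
    """One RAdam step on a scalar parameter (weight decay and SAdam omitted)."""
    step = state["step"] = state.get("step", 0) + 1
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * g
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    debias1 = 1 - mom ** step
    debias2 = 1 - sqr_mom ** step
    # Length of the approximated simple moving average (rho in the paper)
    r_inf = 2 / (1 - sqr_mom) - 1
    r = r_inf - 2 * step * sqr_mom ** step / (1 - sqr_mom ** step)
    if r > 5:
        # Variance is tractable: rectify the adaptive learning rate
        v = math.sqrt(((r - 4) * (r - 2) * r_inf) / ((r_inf - 4) * (r_inf - 2) * r))
        denom = math.sqrt(state["sqr_avg"] / debias2) + eps
        return p - (lr * v / debias1) * state["grad_avg"] / denom
    # Early steps: fall back to bias-corrected momentum SGD
    return p - (lr / debias1) * state["grad_avg"]
```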

In [None]:
#|export
def RAdam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    beta:float=0., # Set to enable SAdam with native fastai RAdam
    decouple_wd:bool=True, # Apply true weight decay (RAdamW) or L2 regularization (RAdam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|RAdamForEachOptimizer|JitOptimizer:
    "A fastai RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations"
    if (foreach or jit) and beta != 0:
        raise ValueError(f'ForEach and TorchScript RAdam do not support {beta=}; set `jit` and `foreach` to False to use SAdam')
    if foreach:
        return RAdamForEachOptimizer(params, radam_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                     eps=eps, wd=wd, decouple_wd=decouple_wd)
    if jit:
        return JitOptimizer(params, radam_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps,
                            wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, radam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)

In [None]:
#|export
def radam(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    beta:float=0., # Set to enable SAdam with native fastai RAdam
    decouple_wd:bool=True, # Apply true weight decay (RAdamW) or L2 regularization (RAdam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|RAdamForEachOptimizer|JitOptimizer:
    "Partial function for the RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations"
    return partialler(RAdam, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta,
                      decouple_wd=decouple_wd, foreach=foreach, jit=jit)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = RAdam(params_jit, lr=0.1, jit=True)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = RAdam(params_for, lr=0.1, foreach=True)

#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
    opt_for.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])
test_close(params_org[0], params_for[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
opt_for.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], p+step)
test_close(params_org[0], params_jit[0])
test_close(params_org[0], params_for[0])

## QHAdam Optimizer

QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in *[Quasi-Hyperbolic Momentum and Adam for Deep Learning](https://arxiv.org/abs/1810.06801)* as a *"computationally cheap, intuitive to interpret, and simple to implement"* optimizer. Additional code can be found in their [qhoptim repo](https://github.com/facebookresearch/qhoptim). QHAdam is based on QH-Momentum, which introduces the immediate discount factor `nu`, encapsulating plain SGD (`nu = 0`) and momentum (`nu = 1`). QH-Momentum is defined below, where $g_{t+1}$ is the updated momentum buffer. QHM can be interpreted as a `nu`-weighted average of the momentum update step and the plain SGD update step.

$$\theta_{t+1} \leftarrow \theta_t - \text{lr} \cdot \left[(1 - \nu) \cdot \nabla L_t(\theta_t) + \nu \cdot g_{t+1}\right]$$
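The QH-Momentum rule can be written as a short scalar function. This sketch assumes, as in the QHM paper, that the momentum buffer $g_{t+1}$ is a dampened moving average of the gradient with coefficient `beta`; `qhm_update` is a hypothetical helper. Setting `nu=0` reduces it to plain SGD and `nu=1` to (dampened) momentum.

```python
def qhm_update(theta, grad, g_prev, lr, beta, nu):
    """One quasi-hyperbolic momentum step on a scalar parameter.

    Returns the new parameter and the new momentum buffer g_{t+1}.
    """
    # Momentum buffer as a dampened moving average of the gradient (assumption, per the QHM paper)
    g_next = beta * g_prev + (1 - beta) * grad
    # nu-weighted average of the plain SGD step and the momentum step
    theta_next = theta - lr * ((1 - nu) * grad + nu * g_next)
    return theta_next, g_next

sgd_like, _ = qhm_update(1.0, 0.5, 0.2, lr=0.1, beta=0.9, nu=0.0)
momentum_like, buf = qhm_update(1.0, 0.5, 0.2, lr=0.1, beta=0.9, nu=1.0)
```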

QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms. 

The paper's suggested default parameters are `mom = 0.999`, `sqr_mom = 0.999`, `nu_1 = 0.7` and `nu_2 = 1.0`. When training is not stable, setting `nu_2 < 1` may improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when `nu_1 = nu_2 = 1.0`. QHAdam recovers RMSProp (Hinton et al., 2012) when `nu_1 = 0` and `nu_2 = 1`, and NAdam (Dozat, 2016) when `nu_1 = mom` and `nu_2 = 1`.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
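A scalar sketch of one QHAdam step, assuming fastai's formulation: Adam's bias-corrected first and second moment estimates are mixed with the raw gradient and squared gradient using `nu_1` and `nu_2`. Weight decay is left out and `qhadam_update` is a hypothetical helper. With `nu_1 = nu_2 = 1` it reduces to the bias-corrected Adam step.

```python
import math

def qhadam_update(p, g, state, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2=1.0, eps=1e-8):
    """One QHAdam step on a scalar parameter (weight decay omitted)."""
    step = state["step"] = state.get("step", 0) + 1
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * g
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    # Bias-corrected moment estimates, as in Adam
    m_hat = state["grad_avg"] / (1 - mom ** step)
    v_hat = state["sqr_avg"] / (1 - sqr_mom ** step)
    # Quasi-hyperbolic mix of the raw gradient and Adam's moment estimates
    num = (1 - nu_1) * g + nu_1 * m_hat
    den = math.sqrt((1 - nu_2) * g * g + nu_2 * v_hat) + eps
    return p - lr * num / den
```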

In [None]:
#|export
def QHAdam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.999, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.999, # Gradient squared moving average (β2) coefficient
    nu_1:float=0.7, # QH immediate discount factor
    nu_2:float=1.0, # QH momentum discount factor
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (QHAdamW) or L2 regularization (QHAdam)
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "A fastai QHAdam/QHAdamW optimizer with a fused TorchScript implementation"
    if jit:
        return JitOptimizer(params, qhadam_jit_step, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom,
                            sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, qhadam_step]
        return Optimizer(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|export
def qhadam(
    mom:float=0.999, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.999, # Gradient squared moving average (β2) coefficient
    nu_1:float=0.7, # QH immediate discount factor
    nu_2:float=1.0, # QH momentum discount factor
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay (QHAdamW) or L2 regularization (QHAdam)
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "Partial function for the QHAdam/QHAdamW optimizer with a fused TorchScript implementation"
    return partialler(QHAdam, mom=mom, sqr_mom=sqr_mom, nu_1=nu_1, nu_2=nu_2, eps=eps,
                      wd=wd, decouple_wd=decouple_wd, jit=jit)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = QHAdam(params_org, lr=0.1)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = QHAdam(params_jit, lr=0.1, jit=True)
opt_jit.step()

step = -0.1 * (((1-0.7) * 0.1) + (0.7 * 0.1)) / (
     math.sqrt(((1-1.0) * 0.1**2) + (1.0 * 0.1**2)) + 1e-8)

test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
test_close(params_org[0], params_jit[0])

opt_org.step()
opt_jit.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_jit[0])

## LARS/LARC Optimizer

The LARS optimizer was first introduced in *[Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888)* then refined in its LARC variant (the original LARS corresponds to `clip=False`). A local learning rate is computed for each individual layer with a certain `trust_coeff`, then clipped so it never exceeds `lr`.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
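A sketch of the layer-wise learning rate on flat lists of floats, assuming fastai's formula `lr * trust_coeff * ||p|| / (||grad|| + wd * ||p|| + eps)`, clipped at `lr` for LARC. Momentum and weight decay in the update itself are left out (on the first step fastai's momentum buffer equals the gradient anyway), and `larc_local_lr` and `larc_update` are hypothetical helpers.

```python
import math

def larc_local_lr(p, g, lr, trust_coeff=0.02, wd=0.0, eps=1e-8, clip=True):
    """Layer-wise learning rate for one parameter tensor, given as flat lists."""
    p_norm = math.sqrt(sum(x * x for x in p))
    g_norm = math.sqrt(sum(x * x for x in g))
    local_lr = lr * trust_coeff * p_norm / (g_norm + p_norm * wd + eps)
    # LARC clips the local rate at the global `lr`; LARS (clip=False) does not
    return min(lr, local_lr) if clip else local_lr

def larc_update(p, g, lr, **kwargs):
    """One LARC/LARS step without momentum; returns new weights and the local lr."""
    local_lr = larc_local_lr(p, g, lr, **kwargs)
    return [x - local_lr * gx for x, gx in zip(p, g)], local_lr

# Large gradients relative to the weights: local lr ~0.02, below lr, so not clipped
weights, local_lr = larc_update([1., 2., 3.], [0.1, 0.2, 0.3], lr=0.1)
```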

In [None]:
#|export
def Larc(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    clip:bool=True, # LARC if clip=True, LARS if clip=False
    trust_coeff:float=0.02, # Trust coefficient for calculating layerwise LR
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay or L2 regularization
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "A fastai LARC/LARS optimizer with a fused TorchScript implementation"
    if jit:
        cb = partial(larc_jit_step, clip=clip)
        return JitOptimizer(params, cb, lr=lr, mom=mom, trust_coeff=trust_coeff,
                            eps=eps, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        if mom!=0.: cbs.append(average_grad)
        cbs += [partial(larc_layer_lr, clip=clip), larc_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)

In [None]:
#|export
def larc(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    clip:bool=True, # LARC if clip=True, LARS if clip=False
    trust_coeff:float=0.02, # Trust coefficient for calculating layerwise LR
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay or L2 regularization
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|JitOptimizer:
    "Partial function for the LARC/LARS optimizer with a fused TorchScript implementation"
    return partialler(Larc, mom=mom, clip=clip, eps=eps, trust_coeff=trust_coeff,
                      wd=wd, decouple_wd=decouple_wd, jit=jit)

In [None]:
#|hide
params_org = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_org = Larc(params_org, lr=0.1)
opt_org.step()

params_jit = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt_jit = Larc(params_jit, lr=0.1, jit=True)
opt_jit.step()

#First param local lr is 0.02 < lr so it's not clipped
test_close(opt_org.state[params_org[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt_org.state[params_org[1]]['local_lr'], 0.1)

test_close(params_org[0], tensor([0.998,1.996,2.994]))
test_close(params_org[0], params_jit[0])
test_close(params_org[1], tensor([0.999,1.998,2.997]))
test_close(params_org[1], params_jit[1])

## LAMB Optimizer

LAMB was introduced in *[Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)*. Intuitively, it's LARC applied to Adam. As in `Adam`, `beta1` and `beta2` from the paper are renamed to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). In fastai's experiments, these values seem to work better in a wide range of situations.

Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
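A sketch of one LAMB step on a single parameter tensor represented as a list of floats. It assumes fastai's formulation: compute the bias-corrected Adam direction, then scale it by the trust ratio of the weights' root mean square to the step's root mean square, capped at 10. Weight decay is left out and `lamb_update` is a hypothetical helper.

```python
import math

def lamb_update(p, g, state, lr, mom=0.9, sqr_mom=0.99, eps=1e-5):
    """One LAMB step on a parameter tensor given as flat lists (weight decay omitted)."""
    n = len(p)
    step = state["step"] = state.get("step", 0) + 1
    grad_avg = state.setdefault("grad_avg", [0.0] * n)
    sqr_avg = state.setdefault("sqr_avg", [0.0] * n)
    for i, gi in enumerate(g):
        grad_avg[i] = mom * grad_avg[i] + (1 - mom) * gi
        sqr_avg[i] = sqr_mom * sqr_avg[i] + (1 - sqr_mom) * gi * gi
    debias1 = 1 - mom ** step
    debias2 = 1 - sqr_mom ** step
    # Element-wise Adam direction
    adam_step = [(m / debias1) / (math.sqrt(v / debias2) + eps) for m, v in zip(grad_avg, sqr_avg)]
    # Trust ratio: RMS of the weights over RMS of the Adam step, capped at 10
    r1 = math.sqrt(sum(x * x for x in p) / n)
    r2 = math.sqrt(sum(s * s for s in adam_step) / n)
    q = 1.0 if r1 == 0 or r2 == 0 else min(r1 / r2, 10.0)
    return [x - lr * q * s for x, s in zip(p, adam_step)]

new_p = lamb_update([1., 2., 3.], [0.1, 0.2, 0.3], {}, lr=0.1)
```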

In [None]:
#|export
def Lamb(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay or L2 regularization
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|LambForEachOptimizer|JitOptimizer:
    "A fastai LAMB optimizer with fused ForEach and TorchScript implementations"
    if foreach:
        return LambForEachOptimizer(params, lamb_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                    eps=eps, wd=wd, decouple_wd=decouple_wd)
    if jit:
        return JitOptimizer(params, lamb_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                            eps=eps, wd=wd, decouple_wd=decouple_wd)
    else:
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|export
def lamb(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True, # Apply true weight decay or L2 regularization
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|LambForEachOptimizer|JitOptimizer:
    "Partial function for the LAMB optimizer with fused ForEach and TorchScript implementations"
    return partialler(Lamb, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                      decouple_wd=decouple_wd, foreach=foreach, jit=jit)

In [None]:
#|hide
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, wd=0.)
opt_org.step()

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Lamb(params_jit, lr=0.1, wd=0., jit=True)
opt_jit.step()

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Lamb(params_for, lr=0.1, wd=0., foreach=True)
opt_for.step()

test_close(params_org[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
test_close(params_org[0], params_jit[0])
test_close(params_org[0], params_for[0])

## Ranger Optimizer

:::{.callout-warning}
Ranger is the only fastxtend optimizer that is not backward compatible. `Ranger` is equivalent to fastai's [<code>ranger</code>](https://docs.fast.ai/optimizer.html#ranger), while fastxtend's `ranger` is a partial function which returns `Ranger`. Most fastai code should be unaffected by this change.
:::

Lookahead was introduced by Zhang et al. in *[Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610)*. With Lookahead, the final weights (slow weights) are a moving average of the normal weights (fast weights). Every k steps, Lookahead updates the current weights to a moving average of the fast weights (normal weights) and the slow weights (the copy of the weights from k steps ago). The slow weights act as a stability mechanism.
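As a rough illustration of the Lookahead update, here is a minimal pure-Python sketch on a list of floats. The helper `lookahead`, the toy `sgd_step`, and the constant gradients are hypothetical and not part of fastai or fastxtend. It assumes the slow weights start as a copy of the initial weights and that, every `k` steps, the slow weights move a fraction `alpha` toward the fast weights before the fast weights restart from the slow weights.

```python
from typing import Callable, List

def lookahead(
    weights: List[float],
    inner_step: Callable[[List[float]], List[float]],
    n_steps: int,
    k: int = 6,
    alpha: float = 0.5,
) -> List[float]:
    "Hypothetical sketch: run `inner_step` for `n_steps`, syncing with the slow weights every `k` steps"
    fast = list(weights)
    slow = list(weights)  # slow weights: copy of the weights at the last sync
    for step in range(1, n_steps + 1):
        fast = inner_step(fast)  # one update of the inner (fast) optimizer
        if step % k == 0:
            # Move the slow weights a fraction `alpha` toward the fast weights...
            slow = [s + alpha * (f - s) for s, f in zip(slow, fast)]
            # ...then restart the fast weights from the updated slow weights
            fast = list(slow)
    return fast

# Toy inner optimizer (assumption): plain SGD with lr=0.1 on constant gradients
grads = [0.1, 0.2, 0.3]

def sgd_step(w: List[float]) -> List[float]:
    return [wi - 0.1 * gi for wi, gi in zip(w, grads)]

# Six SGD steps alone reach about [0.94, 1.88, 2.82]; with k=6 and alpha=0.5
# Lookahead lands halfway back toward the start, about [0.97, 1.94, 2.91]
result = lookahead([1.0, 2.0, 3.0], sgd_step, n_steps=6)
```

Swapping `sgd_step` for an RAdam step gives the behavior Ranger builds on.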

Ranger was introduced by Less Wright in *[New Deep Learning Optimizer, Ranger: Synergistic combination of RAdam + Lookahead for the best of both](https://lessw.medium.com/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d)*. It combines RAdam and Lookahead in a single optimizer, and the combination of RAdam's warmup heuristic and Lookahead's interpolation of parameter weights reduces the need for hyperparameter tuning.
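RAdam's warmup heuristic comes from its variance rectification term. The sketch below computes the length of the approximated simple moving average, r_t, and the rectification factor v_t for a given step, assuming the cutoff r_t > 5 used by fastai's RAdam; below the cutoff, RAdam updates with the bias-corrected gradient average alone. The helper name `radam_rectification` is hypothetical.

```python
import math
from typing import Optional, Tuple

def radam_rectification(step: int, sqr_mom: float = 0.99) -> Tuple[float, Optional[float]]:
    "Hypothetical helper: return (r_t, v_t), with v_t=None while RAdam is still warming up"
    r_inf = 2 / (1 - sqr_mom) - 1  # maximum length of the approximated SMA
    r = r_inf - 2 * step * sqr_mom**step / (1 - sqr_mom**step)
    if r <= 5:
        # Too few steps for a reliable variance estimate: momentum-only update
        return r, None
    # Rectification factor that scales the adaptive (Adam-style) update
    v = math.sqrt(((r - 4) * (r - 2) * r_inf) / ((r_inf - 4) * (r_inf - 2) * r))
    return r, v

for t in range(1, 9):
    r, v = radam_rectification(t)
    print(f"step {t}: r_t={r:.2f}, rectified update: {v is not None}")
```

With `sqr_mom=0.99` (the Ranger default), r_t first exceeds 5 at step 6, so the first five steps are momentum-only updates, and v_t approaches 1 as training continues.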

Ranger performs best on vision tasks when paired with the `fit_flat_cos` or `fit_flat_varied` schedulers.

Optional weight decay `wd` is applied as true weight decay (decaying the weights directly) if `decouple_wd=True`, otherwise as L2 regularization (adding the decay to the gradients).
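The difference between the two modes matters for adaptive optimizers like RAdam. The hypothetical helpers below mirror the formulas of fastai's `weight_decay` and `l2_reg` callbacks (multiply the weights by `1 - lr*wd`, or add `wd*p` to the gradient) on a single scalar parameter, followed by a toy adaptive step `g/(|g| + eps)`, which is what a first bias-corrected Adam-style step reduces to. The toy step is an assumption for illustration, not fastxtend's optimizer step.

```python
def adaptive_step(p: float, g: float, lr: float, eps: float = 1e-8) -> float:
    "Toy adaptive update (assumption): the gradient is normalized by its own magnitude"
    return p - lr * g / (abs(g) + eps)

def decoupled_wd_step(p: float, g: float, lr: float, wd: float) -> float:
    "True weight decay: decay the weight directly, then take the adaptive step"
    p = p * (1 - lr * wd)
    return adaptive_step(p, g, lr)

def l2_reg_step(p: float, g: float, lr: float, wd: float) -> float:
    "L2 regularization: add the decay to the gradient before the adaptive step"
    g = g + wd * p
    return adaptive_step(p, g, lr)

print(decoupled_wd_step(1.0, 0.1, lr=0.1, wd=0.01))  # about 0.899: the decay survives normalization
print(l2_reg_step(1.0, 0.1, lr=0.1, wd=0.01))        # about 0.9: the decay is normalized away
```

With plain SGD the two modes produce identical updates, since `p*(1 - lr*wd) - lr*g` equals `p - lr*(g + wd*p)`; they only diverge once the gradient is rescaled per parameter.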

:::{.callout-important}
While fastai's `Lookahead` can be applied to any optimizer, fastxtend's `JitLookahead` requires a custom-written TorchScript callback and a `ForEachOptimizer` requires a custom Lookahead optimizer step. Currently, Ranger (RAdam with Lookahead) is the only TorchScript and ForEach optimizer with Lookahead support.
:::

In [None]:
#|export
def Ranger(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.95, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-6, # Added for numerical stability
    wd:float=0.01, # Optional weight decay (true or L2)
    k:int=6, # How often to conduct Lookahead step
    alpha:float=0.5, # Slow weight moving average coefficient
    decouple_wd:bool=True, # Apply true weight decay (RAdamW) or L2 regularization (RAdam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Lookahead|RangerForEachOptimizer|JitLookahead:
    "Convenience method for `Lookahead` with `RAdam` fused ForEach and TorchScript implementations"
    if foreach:
        return RangerForEachOptimizer(params, ranger_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                      eps=eps, wd=wd, decouple_wd=decouple_wd, k=k, alpha=alpha)
    elif jit:
        return JitLookahead(params, ranger_jit_step, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps,
                            wd=wd, k=k, alpha=alpha, decouple_wd=decouple_wd)
    else:
        return Lookahead(RAdam(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                               decouple_wd=decouple_wd),
                         k=k, alpha=alpha)

In [None]:
#|export
def ranger(
    mom:float=0.95, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-6, # Added for numerical stability
    wd:float=0.01, # Optional weight decay (true or L2)
    k:int=6, # How often to conduct Lookahead step
    alpha:float=0.5, # Slow weight moving average coefficient
    decouple_wd:bool=True, # Apply true weight decay (RAdamW) or L2 regularization (RAdam)
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Lookahead|RangerForEachOptimizer|JitLookahead:
    "Partial function of the convenience method for `Lookahead` with `RAdam` fused ForEach and TorchScript implementations"
    return partialler(Ranger, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, k=k,
                      alpha=alpha, decouple_wd=decouple_wd, foreach=foreach, jit=jit)

In [None]:
#|hide
po = tensor([1,2,3])

params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Ranger(params_org, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.)

params_jit = tst_param([1,2,3], [0.1,0.2,0.3])
opt_jit = Ranger(params_jit, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., jit=True)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Ranger(params_for, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., foreach=True)

# The first 5 steps are normal RAdam steps
# r <= 5 during the first 5 steps, so updates use the bias-corrected average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_jit.step()
    opt_for.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_jit[0])
test_close(params_org[0], params_for[0])

# r > 5 on the sixth step, so RAdam applies the rectified adaptive update
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_jit.step()
opt_for.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)

# Since k=6, the sixth step ends with a Lookahead sync: the weights move halfway (alpha=0.5) from the initial weights toward the result of the six RAdam steps
test_close(params_org[0], po+((p+step)-po)*0.5)
test_close(params_org[0], params_jit[0])
test_close(params_org[0], params_for[0])