In [None]:
#|default_exp optimizer.foreach

# ForEach Optimizers
> Fused fastai optimizers using PyTorch ForEach methods for improved performance

fastxtend ForEach optimizers are adapted from the PyTorch ForEach [`_multi_tensor`](https://github.com/pytorch/pytorch/tree/master/torch/optim) implementations. Depending on the model, they are 21 to 293 percent faster than fastai's native optimizers.

The primary difference between PyTorch's ForEach implementations and fastxtend is that fastxtend's ForEach optimizers apply per-parameter weight decay in one optimizer step, instead of requiring separate weight decay and non-weight decay parameter groups. This also allows seamless support for fastai's [discriminative learning rates](https://docs.fast.ai/callback.schedule.html#learner.fine_tune).

Unlike fastai optimizers, which are made of [multiple stepper callbacks](https://docs.fast.ai/optimizer.html#basic-steppers) and share one `Optimizer`, ForEach optimizers require an optimizer-specific `ForEachOptimizer` implementation.

Currently `SGD`, `Adam`, `RAdam`, `Lamb`, and `Ranger` have ForEach implementations.

:::{.callout-important}
ForEach optimizers have only been tested on PyTorch 1.12 and are not guaranteed to work on older versions.
:::

ForEach optimizers are faster due to horizontal fusion across multiple parameters. Using `xresnet50` and the simplest form of `SGD` as an example, a ForEach optimizer would construct a list of all 167 `params` and their `grads` before performing one horizontally fused step.

```python
def simple_sgd_foreach(params:list[Tensor], grads:list[Tensor], lr:float):
    torch._foreach_add_(params, grads, alpha=-lr)
```

In contrast, a standard PyTorch optimizer would call the simple `SGD` step 167 times:

```python
def simple_sgd_standard(param:Tensor, lr:float):
    param.add_(param.grad, alpha=-lr)
```

ForEach optimizers are tested to be equal to fastai optimizers for 25 steps using [nbdev's GitHub CI](https://nbdev.fast.ai/tutorials/tutorial.html#check-out-your-workflows).

In [None]:
#|export
from __future__ import annotations

import numpy as np

from fastai.optimizer import Optimizer

from fastxtend.imports import *

## Test Utils -

In [None]:
#|hide
from nbdev.showdoc import *

from fastai.optimizer import (weight_decay, l2_reg, average_grad, sgd_step, momentum_step, 
                              average_sqr_grad, rms_prop_step, step_stat, adam_step, radam_step, 
                              larc_layer_lr, larc_step, lamb_step, Lookahead)

from fastxtend.test_utils import *

In [None]:
#|hide
# Tests are copied with light modifications from fastai's optimizer tests.
# Number of extra steps each fastai/ForEach comparison runs after its explicit checks.
test_steps = 25

def tst_param(val, grad=None):
    "Create a tensor with `val` and a gradient of `grad` for testing"
    out = tensor([val]).float()
    # default gradient is one tenth of the value
    out.grad = tensor([grad if grad is not None else val/10]).float()
    return out

def tst_params():
    "Four scalar test params with values 0, 1, 2, 3 and gradients of one tenth their value"
    return L.range(4).map(tst_param)

## ForEachOptimizer -

In [None]:
#|exporti
class ForEachOptimizer(Optimizer):
    "Base foreach optimizer class, updating `params` with `opt_step` instead of `Optimizer.cbs`"
    def __init__(self,
        params:listified[Tensor], # Model parameters
        opt_step:Callable, # `ForEachOptimizer` optimizer step
        decouple_wd:bool=True, # Use true weight decay or L2 regularization, if applicable
        **defaults # Optimizer specific hyper parameters
    ):
        # presumably True when the installed PyTorch is older than 1.12 (see the warning text) — TODO confirm
        if notmax_torch('1.12'):
            warn(f'ForEach optimizers are untested on PyTorch {torch.__version__}, recommended to use 1.12 or newer')
        # Stepping is done by `opt_step`, so fastai's callback list only gets a `None` placeholder
        super().__init__(params, [None], True, **defaults)
        self.opt_step = opt_step
        self.decouple_wd = decouple_wd

In [None]:
# Render the nbdev documentation entry for `ForEachOptimizer`
show_doc(ForEachOptimizer)

## SGD -

In [None]:
#|exporti
def sgd_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor|None], ones:list[Tensor|None], 
                     do_wd:np.ndarray[Any, bool], lr:float, wd:float, mom:float, decouple_wd:bool, **kwargs):
    """Horizontally fused SGD step over every parameter in `p`.

    Replicates fastai's `weight_decay`/`l2_reg`, `average_grad`, and `sgd_step`/`momentum_step` stepper callbacks.
    `p`, `g`, `grad_avg`, `ones`, and `do_wd` are parallel per-parameter sequences. `grad_avg` is only read
    when `mom != 0`, and `ones` only when L2 regularization is used. L2 regularization modifies `g` in place.
    Extra hyper parameters in `kwargs` are ignored.
    """
    if len(p) == 0:
        # torch._foreach_* ops raise on empty tensor lists (e.g. a param group without gradients)
        return

    if wd != 0:
        if decouple_wd:
            # weight_decay: p *= 1-lr*wd, or *1 (no-op) for params with do_wd=False
            scale = np.where(do_wd, 1-lr*wd, 1.)
            torch._foreach_mul_(p, scalars=scale.tolist())
        else:
            # l2_reg: g += wd*p, or +0*p (no-op) for params with do_wd=False
            # cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
            scale = np.where(do_wd, wd, 0.)
            torch._foreach_addcdiv_(g, p, ones, scalars=scale.tolist())

    if mom != 0:
        # average_grad (no dampening): grad_avg = mom*grad_avg + g
        torch._foreach_mul_(grad_avg, mom)
        torch._foreach_add_(grad_avg, g)

        # momentum_step
        torch._foreach_add_(p, grad_avg, alpha=-lr)
    else:
        # sgd_step
        torch._foreach_add_(p, g, alpha=-lr)

In [None]:
#|exporti
class SGDForEachOptimizer(ForEachOptimizer):
    "A `ForEachOptimizer` with a modified step for `sgd_foreach_step`"
    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            # parallel per-parameter lists passed to `sgd_foreach_step`
            pl, gl, grad_avg, ones, do_wd = [], [], [], [], []

            for p in pg:
                # only parameters with a gradient are updated
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # Buffers are created whenever they are missing, not only on the first step. A momentum
                    # that starts at zero and is scheduled up later still gets a `grad_avg` buffer.
                    if hyper['mom'] != 0 and 'grad_avg' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if not self.decouple_wd and 'ones' not in state:
                        # divisor used by the foreach_addcdiv based L2 regularization
                        state['ones'] = torch.ones_like(p, memory_format=torch.preserve_format)
                    state['setup'] = True

                    pl.append(p)
                    gl.append(p.grad)
                    grad_avg.append(state.get('grad_avg', None))
                    ones.append(state.get('ones', None))
                    do_wd.append(state.get('do_wd', True))

            # skip groups without gradients (e.g. frozen layers): foreach ops raise on empty tensor lists
            if len(pl) == 0: continue

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, ones=ones, do_wd=np.array(do_wd, dtype=bool), 
                          decouple_wd=self.decouple_wd, **hyper)

In [None]:
# Render the nbdev documentation entry for `SGDForEachOptimizer`
show_doc(SGDForEachOptimizer)

In [None]:
#|hide
def SGD(params, lr, mom=0., wd=0., decouple_wd=True, foreach=False):
    "Build a fastai (`foreach=False`) or ForEach (`foreach=True`) SGD optimizer for the comparison tests"
    if not foreach:
        wd_cb = weight_decay if decouple_wd else l2_reg
        step_cbs = [sgd_step] if mom == 0 else [average_grad, momentum_step]
        return Optimizer(params, [wd_cb, *step_cbs], lr=lr, mom=mom, wd=wd)
    return SGDForEachOptimizer(params, sgd_foreach_step, lr=lr, mom=mom, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# Vanilla SGD: fastai (`org`) and ForEach (`for`) optimizers are stepped in lockstep and compared
params_org, params_for = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, foreach=False)
opt_for = SGD(params_for, lr=0.1, foreach=True)

# every step subtracts lr*grad = 0.01*i from param i
for expected in (0.99, 0.98):
    opt_org.step()
    opt_for.step()
    test_close([p.item() for p in params_org], [i*expected for i in range(4)])
    test_close([p.item() for p in params_org], [p.item() for p in params_for])

for _ in range(test_steps):
    opt_org.step()
    opt_for.step()
test_close([p.item() for p in params_org], [p.item() for p in params_for])

In [None]:
#|hide
# SGD with momentum: fastai (`org`) and ForEach (`for`) are stepped in lockstep and compared
params_org, params_for = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, foreach=False)
opt_for = SGD(params_for, lr=0.1, mom=0.9, foreach=True)

# step 1: grad_avg equals the gradient 0.1*i
for opt in (opt_org, opt_for): opt.step()
test_close([p.item() for p in params_org], [i*0.99 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

# step 2: grad_avg = 0.9*0.1*i + 0.1*i = 0.19*i
for opt in (opt_org, opt_for): opt.step()
test_close([p.item() for p in params_org], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_for])
for opt, params in ((opt_org, params_org), (opt_for, params_for)):
    for i, p in enumerate(params):
        test_close(opt.state[p]['grad_avg'].item(), i*0.19)

for _ in range(test_steps):
    for opt in (opt_org, opt_for): opt.step()
test_close([p.item() for p in params_org], [p.item() for p in params_for])

In [None]:
#|hide
# Decoupled weight decay with momentum, checked after two steps and again after `test_steps` more
params_org, params_for = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, wd=0.1, foreach=False)
opt_for = SGD(params_for, lr=0.1, mom=0.9, wd=0.1, foreach=True)
for _ in range(2):
    opt_org.step()
    opt_for.step()

test_close([p.item() for p in params_org], [i*0.9512 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

for _ in range(test_steps):
    opt_org.step()
    opt_for.step()
test_close([p.item() for p in params_org], [p.item() for p in params_for])

In [None]:
#|hide
# L2 regularization (applied to the gradients) with momentum, checked after two steps and after `test_steps` more
params_org, params_for = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False, foreach=False)
opt_for = SGD(params_for, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False, foreach=True)
for _ in range(2):
    opt_org.step()
    opt_for.step()

test_close([p.item() for p in params_org], [i*0.9322 for i in range(4)])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

for _ in range(test_steps):
    opt_org.step()
    opt_for.step()
test_close([p.item() for p in params_org], [p.item() for p in params_for])

## Adam -

In [None]:
#|exporti
def adam_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor], sqr_avg:list[Tensor], ones:list[Tensor|None], 
                      steps:np.ndarray[Any, int], do_wd:np.ndarray[Any, bool], lr:float, wd:float, mom:float, sqr_mom:float, 
                      eps:float, decouple_wd:bool, **kwargs):
    """Horizontally fused Adam step over every parameter in `p`.

    Replicates fastai's `weight_decay`/`l2_reg`, `average_grad(dampening=True)`, `average_sqr_grad`, and
    `adam_step` callbacks. All list/array arguments are parallel per-parameter sequences. `steps` holds each
    parameter's step count, used for bias correction. `ones` is only read for L2 regularization, which
    modifies `g` in place. Extra hyper parameters in `kwargs` are ignored.
    """
    if len(p) == 0:
        # torch._foreach_* ops raise on empty tensor lists (e.g. a param group without gradients)
        return

    if wd != 0:
        if decouple_wd:
            # weight_decay: p *= 1-lr*wd, or *1 (no-op) for params with do_wd=False
            scale = np.where(do_wd, 1-lr*wd, 1.)
            torch._foreach_mul_(p, scalars=scale.tolist())
        else:
            # l2_reg: g += wd*p, or +0*p (no-op) for params with do_wd=False
            # cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
            scale = np.where(do_wd, wd, 0.)
            torch._foreach_addcdiv_(g, p, ones, scalars=scale.tolist())

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # adam_step: p += -lr/(1-mom**step) * grad_avg / (sqrt(sqr_avg)/sqrt(1-sqr_mom**step) + eps)
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps)

    denom = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(denom, debias2.tolist())
    torch._foreach_add_(denom, eps)

    torch._foreach_addcdiv_(p, grad_avg, denom, debias1.tolist())

In [None]:
#|exporti
class AdamForEachOptimizer(ForEachOptimizer):
    "A `ForEachOptimizer` with a modified step for `adam_foreach_step`"
    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            # parallel per-parameter lists passed to `adam_foreach_step`
            pl, gl, grad_avg, sqr_avg, ones, steps, do_wd = [], [], [], [], [], [], []

            for p in pg:
                # only parameters with a gradient are updated
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    if 'step' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['step'] = 0
                    if not self.decouple_wd and 'ones' not in state:
                        # divisor used by the foreach_addcdiv based L2 regularization
                        state['ones'] = torch.ones_like(p, memory_format=torch.preserve_format)

                    # per-parameter step count used for bias correction
                    state['step'] += 1
                    pl.append(p)
                    gl.append(p.grad)
                    grad_avg.append(state['grad_avg'])
                    sqr_avg.append(state['sqr_avg'])
                    ones.append(state.get('ones', None))
                    steps.append(state['step'])
                    do_wd.append(state.get('do_wd', True))

            # skip groups without gradients (e.g. frozen layers): foreach ops raise on empty tensor lists
            if len(pl) == 0: continue

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, sqr_avg=sqr_avg, ones=ones, 
                          steps=np.array(steps, dtype=np.int32), do_wd=np.array(do_wd, dtype=bool), 
                          decouple_wd=self.decouple_wd, **hyper)

In [None]:
# Render the nbdev documentation entry for `AdamForEachOptimizer`
show_doc(AdamForEachOptimizer)

In [None]:
#|hide
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True, foreach=False):
    "Build a fastai (`foreach=False`) or ForEach (`foreach=True`) Adam optimizer for the comparison tests"
    if not foreach:
        cbs = [weight_decay if decouple_wd else l2_reg, partial(average_grad, dampening=True),
               average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return AdamForEachOptimizer(params, adam_foreach_step, lr=lr, mom=mom,
                                sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# Adam without weight decay: with a constant gradient, each step moves every element by about -lr
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0, foreach=False)
opt_org.step()

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Adam(params_for, lr=0.1, wd=0, foreach=True)
opt_for.step()

step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
# the parameter is a single (1,3) tensor: compare the whole row, not only its first element
test_close(params_org[0], params_for[0])

opt_org.step()
opt_for.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_for[0])

for i in range(test_steps):
    opt_org.step()
    opt_for.step()
test_close(params_org[0], params_for[0])

In [None]:
#|hide
# test with weight decay
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0.1, foreach=False)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Adam(params_for, lr=0.1, wd=0.1, foreach=True)

for i in range(test_steps):
    opt_org.step()
    opt_for.step()

# the parameter is a single (1,3) tensor: compare the whole row, not only its first element
test_close(params_org[0], params_for[0])

In [None]:
#|hide
# test with l2 reg
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0.1, decouple_wd=False, foreach=False)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Adam(params_for, lr=0.1, wd=0.1, decouple_wd=False, foreach=True)

for i in range(test_steps):
    opt_org.step()
    opt_for.step()

# the parameter is a single (1,3) tensor: compare the whole row, not only its first element
test_close(params_org[0], params_for[0])

## RAdam -

In [None]:
#|exporti
def radam_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor], sqr_avg:list[Tensor], ones:list[Tensor],
                       steps:np.ndarray[Any, int], do_wd:np.ndarray[Any, bool], lr:float, wd:float, mom:float, sqr_mom:float,
                       eps:float, decouple_wd:bool, **kwargs):
    """Horizontally fused RAdam step over every parameter in `p`.

    Replicates fastai's `weight_decay`/`l2_reg`, `average_grad(dampening=True)`, `average_sqr_grad`, and
    `radam_step` callbacks. fastai's `beta` (softplus) option is not supported. All list/array arguments are
    parallel per-parameter sequences. `steps` holds each parameter's step count. Every param receives both a
    rectified and an unrectified update, with exactly one of the two scales being zero. L2 regularization
    modifies `g` in place. Extra hyper parameters in `kwargs` are ignored.
    """
    if len(p) == 0:
        # torch._foreach_* ops raise on empty tensor lists (e.g. a param group without gradients)
        return

    if wd != 0:
        if decouple_wd:
            # weight_decay: p *= 1-lr*wd, or *1 (no-op) for params with do_wd=False
            scale = np.where(do_wd, 1-lr*wd, 1.)
            torch._foreach_mul_(p, scalars=scale.tolist())
        else:
            # l2_reg: g += wd*p, or +0*p (no-op) for params with do_wd=False
            # cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
            scale = np.where(do_wd, wd, 0.)
            torch._foreach_addcdiv_(g, p, ones, scalars=scale.tolist())

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # radam_step
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps)

    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*steps*sqr_mom**steps/(1-sqr_mom**steps)
    rectify = r > 5

    # The variance rectification term is only defined for r > 5. It is evaluated for every param, and the
    # undefined entries are discarded by np.where. This avoids the uninitialized values that
    # `np.sqrt(..., where=...)` without `out=` leaves behind, and silences warnings from the discarded entries.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        rect = np.where(rectify, debias1*np.sqrt(((r-4)*(r-2)*r_inf)/((r_inf-4)*(r_inf-2)*r)), 0.)
    unrect = np.where(rectify, 0., debias1)

    # rectified step
    denom = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(denom, debias2.tolist())
    torch._foreach_add_(denom, eps)
    torch._foreach_addcdiv_(p, grad_avg, denom, scalars=rect.tolist())

    # unrectified step. cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
    torch._foreach_addcdiv_(p, grad_avg, ones, scalars=unrect.tolist())

In [None]:
#|exporti
class RAdamForEachOptimizer(ForEachOptimizer):
    "A `ForEachOptimizer` with a modified step for `radam_foreach_step`"
    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            # parallel per-parameter lists passed to `opt_step`
            pl, gl, grad_avg, sqr_avg, steps, ones, do_wd = [], [], [], [], [], [], []

            for p in pg:
                # only parameters with a gradient are updated
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    if 'step' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # always needed: the unrectified update divides by one via foreach_addcdiv
                        state['ones']     = torch.ones_like(p, memory_format=torch.preserve_format)
                        state['step']     = 0

                    # per-parameter step count used for bias correction and rectification
                    state['step'] += 1
                    pl.append(p)
                    gl.append(p.grad)
                    grad_avg.append(state['grad_avg'])
                    sqr_avg.append(state['sqr_avg'])
                    ones.append(state['ones'])
                    steps.append(state['step'])
                    do_wd.append(state.get('do_wd', True))

            # skip groups without gradients (e.g. frozen layers): foreach ops raise on empty tensor lists
            if len(pl) == 0: continue

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, sqr_avg=sqr_avg, ones=ones, 
                          steps=np.array(steps, dtype=np.int32), do_wd=np.array(do_wd, dtype=bool), 
                          decouple_wd=self.decouple_wd, **hyper)

In [None]:
# Render the nbdev documentation entry for `RAdamForEachOptimizer`
show_doc(RAdamForEachOptimizer)

In [None]:
#|hide
def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_wd=True, foreach=False):
    "Build a fastai (`foreach=False`) or ForEach (`foreach=True`) RAdam optimizer for the comparison tests"
    if not foreach:
        cbs = [weight_decay if decouple_wd else l2_reg, partial(average_grad, dampening=True),
               average_sqr_grad, step_stat, radam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)
    # the ForEach step has no softplus `beta` option
    if beta != 0:
        warn('ForEach RAdam does not use beta, set foreach=False if beta!=0')
    return RAdamForEachOptimizer(params, radam_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                 eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# Compare fastai RAdam (`org`) with ForEach RAdam (`for`) on one (1,3) parameter with a constant gradient,
# using lr=0.1 and the `RAdam` defaults mom=0.9, sqr_mom=0.99, wd=0
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1, foreach=False)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = RAdam(params_for, lr=0.1, foreach=True)

#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5): 
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_for.step()
# each unrectified step subtracts lr*grad (the debiased average equals the constant gradient): 5*0.01*[1,2,3]
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_for[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_for.step()
# expected rectified step: -lr * v * grad/|grad| with variance rectification factor v
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], p+step)
test_close(params_org[0], params_for[0])

# keep stepping both optimizers and check they stay equal
for i in range(test_steps):
    opt_org.step()
    opt_for.step()

test_close(params_org[0], params_for[0])

In [None]:
#|hide
# RAdam with decoupled weight decay must match fastai after one step and again after `test_steps` more
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = RAdam(params_org, lr=0.1, wd=0.1, foreach=False)
opt_for = RAdam(params_for, lr=0.1, wd=0.1, foreach=True)

for n_steps in (1, test_steps):
    for _ in range(n_steps):
        opt_org.step()
        opt_for.step()
    test_close(params_org[0], params_for[0])

## LAMB -

In [None]:
#|exporti
@torch.jit.script
def lamb_jit_substep(p:Tensor, lstep:Tensor, lr:float):
    # LAMB trust-ratio scale for a single parameter, compiled with TorchScript.
    # r1 and r2 are the root-mean-square of the parameter and of its unscaled LAMB step.
    # Returns -lr * min(r1/r2, 10), or plain -lr when either RMS is zero (avoids dividing by zero).
    # NOTE(review): the second branch may yield a 0-dim tensor rather than a float under TorchScript
    # typing; the caller passes it as a foreach scalar — confirm on new PyTorch versions.
    r1 = p.pow(2).mean().sqrt()
    r2 = lstep.pow(2).mean().sqrt()
    if r1 == 0 or r2 == 0:
        return -lr
    else:
        return -lr*min(r1/r2, 10.)

In [None]:
#|exporti
def lamb_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor], sqr_avg:list[Tensor], ones:list[Tensor], 
                      steps:np.ndarray[Any, int], do_wd:np.ndarray[Any, bool], lr:float, wd:float, mom:float, sqr_mom:float, 
                      eps:float, decouple_wd:bool, **kwargs):
    """Horizontally fused LAMB step over every parameter in `p`.

    Replicates fastai's `weight_decay`/`l2_reg`, `average_grad(dampening=True)`, `average_sqr_grad`, and
    `lamb_step` callbacks. All list/array arguments are parallel per-parameter sequences. `steps` holds each
    parameter's step count. The per-parameter trust ratio is computed one parameter at a time with
    `lamb_jit_substep`. L2 regularization modifies `g` in place. Extra hyper parameters in `kwargs` are ignored.
    """
    if len(p) == 0:
        # torch._foreach_* ops raise on empty tensor lists (e.g. a param group without gradients)
        return

    if wd != 0:
        if decouple_wd:
            # weight_decay: p *= 1-lr*wd, or *1 (no-op) for params with do_wd=False
            scale = np.where(do_wd, 1-lr*wd, 1.)
            torch._foreach_mul_(p, scalars=scale.tolist())
        else:
            # l2_reg: g += wd*p, or +0*p (no-op) for params with do_wd=False
            # cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
            scale = np.where(do_wd, wd, 0.)
            torch._foreach_addcdiv_(g, p, ones, scalars=scale.tolist())

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # lamb_step: lstep = (grad_avg/debias1) / (sqrt(sqr_avg)/debias2 + eps)
    debias1 = 1 - mom**steps
    debias2 = np.sqrt(1 - sqr_mom**steps)

    denom = torch._foreach_div(torch._foreach_sqrt(sqr_avg), debias2.tolist())
    torch._foreach_add_(denom, eps)
    lstep = torch._foreach_div(grad_avg, debias1.tolist())
    torch._foreach_div_(lstep, denom)

    # there are no implementations for foreach_pow, foreach_mean, or foreach_where/if methods,
    # so the trust-ratio scale (which already includes -lr) is computed per parameter
    q = [lamb_jit_substep(pi, ls, lr) for pi, ls in zip(p, lstep)]

    # cannot use scalars with foreach_add & multiple tensors, so divide by one with foreach_addcdiv
    torch._foreach_addcdiv_(p, lstep, ones, scalars=q)

In [None]:
#|exporti
class LambForEachOptimizer(RAdamForEachOptimizer):
    "A `ForEachOptimizer` with a modified step for `lamb_foreach_step`"
    # LAMB needs the same per-parameter state as RAdam (grad_avg, sqr_avg, ones, step), so the
    # inherited `RAdamForEachOptimizer.step` is reused as-is; only the `opt_step` passed in differs.

In [None]:
# Render the nbdev documentation entry for `LambForEachOptimizer`
show_doc(LambForEachOptimizer)

In [None]:
#|hide
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True, foreach=False):
    "Build a fastai (`foreach=False`) or ForEach (`foreach=True`) LAMB optimizer for the comparison tests"
    if not foreach:
        cbs = [weight_decay if decouple_wd else l2_reg, partial(average_grad, dampening=True),
               average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return LambForEachOptimizer(params, lamb_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                eps=eps, wd=wd, decouple_wd=decouple_wd)

In [None]:
#|hide
# LAMB without weight decay: check the first step against a known value, then fastai vs ForEach lockstep
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, foreach=False)
opt_for = Lamb(params_for, lr=0.1, foreach=True)
opt_org.step()
opt_for.step()

test_close(params_org[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
test_close(params_org[0], params_for[0])

for _ in range(test_steps):
    opt_org.step()
    opt_for.step()
test_close(params_org[0], params_for[0])

In [None]:
#|hide
# LAMB with decoupled weight decay must stay equal to fastai's for `test_steps` steps
params_org, params_for = tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Lamb(params_org, lr=0.1, wd=0.1, foreach=False)
opt_for = Lamb(params_for, lr=0.1, wd=0.1, foreach=True)

for _ in range(test_steps):
    for opt in (opt_org, opt_for): opt.step()

test_close(params_org[0], params_for[0])

## Ranger -

In [None]:
#|exporti
def ranger_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor], sqr_avg:list[Tensor], slow_p:list[Tensor], 
                        ones:list[Tensor], steps:np.ndarray[Any, int], do_wd:np.ndarray[Any, bool], lr:float, wd:float, 
                        mom:float, sqr_mom:float, eps:float, decouple_wd:bool, count:int, k:int, alpha:float, **kwargs):
    """Fused Ranger step: a `radam_foreach_step`, then a Lookahead update every `k` optimizer steps.

    `slow_p` holds the Lookahead slow weights, parallel to `p`. `count` is the optimizer step counter.
    When `count % k == 0`, the slow weights move `alpha` of the way towards `p`, and `p` is then reset to them.
    """
    if len(p) == 0:
        # torch._foreach_* ops raise on empty tensor lists (e.g. a param group without gradients)
        return

    radam_foreach_step(p=p, g=g, grad_avg=grad_avg, sqr_avg=sqr_avg, ones=ones, steps=steps, do_wd=do_wd, 
                       lr=lr, wd=wd, mom=mom, sqr_mom=sqr_mom, eps=eps, decouple_wd=decouple_wd)

    if count % k == 0:
        # Lookahead: slow_p += alpha*(p - slow_p)
        torch._foreach_add_(slow_p, torch._foreach_sub(p, slow_p), alpha=alpha)
        # There is no foreach copy method on the supported PyTorch versions. Copy in place like fastai's
        # Lookahead; `set_` would swap each param's storage and allocate a new tensor every k steps.
        for pi, slow_pi in zip(p, slow_p):
            pi.copy_(slow_pi)

In [None]:
#|exporti
class RangerForEachOptimizer(ForEachOptimizer):
    "A `ForEachOptimizer` with a modified `LookAhead` step for `ranger_foreach_step`"
    def __init__(self, 
        params:listified[Tensor], # Model parameters
        opt_step:Callable, # `ForEachOptimizer` optimizer step
        decouple_wd:bool=True, # Use true weight decay or L2 regularization, if applicable
        **defaults # Optimizer specific hyper parameters default values
    ):
        super().__init__(params, opt_step, decouple_wd, **defaults)
        self._init_state()

    @torch.no_grad()
    def step(self, closure=None):
        # validate first, so a rejected call does not advance the Lookahead counter
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        # one increment per optimizer step, shared by all param groups
        self.count += 1
        for pg, hyper in zip(self.param_lists, self.hypers):
            # parallel per-parameter lists passed to `ranger_foreach_step`
            pl, gl, grad_avg, sqr_avg, slow_p, steps, ones, do_wd = [], [], [], [], [], [], [], []
            for p in pg:
                # only parameters with a gradient are updated
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    if 'step' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['ones']     = torch.ones_like(p, memory_format=torch.preserve_format)
                        # Lookahead slow weights start from the parameter values seen at its first step
                        state['slow_p']   = p.data.clone()
                        state['step']     = 0

                    state['step'] += 1
                    pl.append(p)
                    gl.append(p.grad)
                    grad_avg.append(state['grad_avg'])
                    sqr_avg.append(state['sqr_avg'])
                    slow_p.append(state['slow_p'])
                    ones.append(state['ones'])
                    steps.append(state['step'])
                    do_wd.append(state.get('do_wd', True))

            # skip groups without gradients (e.g. frozen layers): foreach ops raise on empty tensor lists
            if len(pl) == 0: continue

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, sqr_avg=sqr_avg, slow_p=slow_p, ones=ones, 
                          steps=np.array(steps, dtype=np.int32), do_wd=np.array(do_wd, dtype=bool), 
                          decouple_wd=self.decouple_wd, count=self.count, **hyper)

    def clear_state(self):
        "Clear the base optimizer state and reset the Lookahead step counter"
        super().clear_state()
        self._init_state()

    def state_dict(self):
        "Base optimizer state dict plus the Lookahead step `count`"
        state = super().state_dict()
        state.update({'count': self.count})
        return state

    def load_state_dict(self, sd):
        "Load a state dict produced by `state_dict`, restoring the Lookahead step `count`"
        # shallow copy so popping 'count' does not mutate the caller's dict (e.g. loading it twice)
        sd = dict(sd)
        self.count = sd.pop('count')
        super().load_state_dict(sd)

    def _init_state(self): 
        self.count = 0

In [None]:
# Render the nbdev documentation entry for `RangerForEachOptimizer`
show_doc(RangerForEachOptimizer)

In [None]:
#|hide
def Ranger(params, lr, mom=0.95, sqr_mom=0.99, eps=1e-6, wd=0.01, beta=0., k=6, alpha=0.5, decouple_wd=True, foreach=False):
    "Build a fastai Lookahead(RAdam) (`foreach=False`) or ForEach (`foreach=True`) Ranger optimizer for the tests"
    if not foreach:
        radam = RAdam(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                      beta=beta, decouple_wd=decouple_wd, foreach=False)
        return Lookahead(radam, k=k, alpha=alpha)
    # the ForEach step has no softplus `beta` option
    if beta != 0:
        warn('RAdam foreach does not use beta, set foreach=False if beta!=0')
    return RangerForEachOptimizer(params, ranger_foreach_step, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                  eps=eps, wd=wd, decouple_wd=decouple_wd, k=k, alpha=alpha)

In [None]:
#|hide
# initial weights: Lookahead's slow weights start here
po = tensor([1,2,3])

# Compare fastai Lookahead(RAdam) (`org`) with ForEach Ranger (`for`) on one (1,3) parameter, with the
# `Ranger` defaults k=6 and alpha=0.5
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Ranger(params_org, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., foreach=False)

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Ranger(params_for, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., foreach=True)

#The first 5 steps are normal RAdam steps
#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5): 
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step()
    opt_for.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_for[0])

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step()
opt_for.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)

#Since k=6, sixth step is a moving average of the 6 RAdam steps with the initial weight
test_close(params_org[0], po+((p+step)-po)*0.5)
test_close(params_org[0], params_for[0])

# keep stepping both optimizers (more Lookahead syncs happen every 6 steps) and check they stay equal
for i in range(test_steps):
    opt_org.step()
    opt_for.step()

test_close(params_org[0], params_for[0])

In [None]:
#|hide
# Ranger with decoupled weight decay must stay equal to fastai's Lookahead(RAdam) for `test_steps` steps
params_org, params_for = tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Ranger(params_org, lr=0.1, wd=0.1, foreach=False)
opt_for = Ranger(params_for, lr=0.1, wd=0.1, foreach=True)

for _ in range(test_steps):
    for opt in (opt_org, opt_for): opt.step()

test_close(params_org[0], params_for[0])