In [None]:
#|default_exp optimizer.foreach

# ForEach Optimizers
> fastai optimizers using PyTorch ForEach methods for improved performance

In [None]:
#|export
from __future__ import annotations
from typing import Optional, Dict

import numpy as np

from torch.nn import Parameter

from fastcore.basics import range_of, merge

from fastai.optimizer import Optimizer

from fastxtend.imports import *

## Test Utils -

In [None]:
#|hide
from fastai.optimizer import (_update, weight_decay, l2_reg, average_grad, sgd_step, momentum_step, 
                              average_sqr_grad, rms_prop_step, step_stat, adam_step, radam_step, 
                              larc_layer_lr, larc_step, lamb_step, Lookahead)

from fastxtend.test_utils import *

In [None]:
#|hide
def tst_param(val, grad=None):
    "Build a float test tensor wrapping `val` whose `.grad` is `grad` (or `val/10` when `grad` is omitted)"
    grad_val = val/10 if grad is None else grad
    param = tensor([val]).float()
    # gradients are attached by hand so an optimizer step can be tested without a backward pass
    param.grad = tensor([grad_val]).float()
    return param

In [None]:
#|hide
def tst_params():
    "Return four test parameters with values 0 to 3, each built by `tst_param`"
    return L.range(4).map(tst_param)

ForEach Optimizers are adapted from the PyTorch [_multi_tensor optim implementations](https://github.com/pytorch/pytorch/tree/master/torch/optim).

## SGD -

In [None]:
#|exporti
def sgd_foreach_step(p:list[Tensor], g:list[Tensor], no_wd_p:list[Tensor], no_wd_g:list[Tensor], grad_avg:list[Tensor],
                     no_wd_grad_avg:list[Tensor], lr:float, wd:float, mom:float, decouple_wd:bool, dampening:bool=False, **kwargs):
    """Apply one SGD step in place to every parameter using PyTorch ForEach ops.

    `p`, `g` and `grad_avg` are the parameters, gradients and momentum buffers which receive weight decay;
    the `no_wd_*` lists are their counterparts which skip it. `decouple_wd` selects true weight decay
    (parameters scaled by `1 - lr*wd`) over L2 regularization (`wd*p` added to the gradients).
    When `mom != 0` the momentum buffers are updated in place (new gradients scaled by `1-mom` if `dampening`)
    and used for the step; otherwise the raw gradients are used and the buffer lists are ignored.
    The list arguments themselves are left unmodified.
    """
    # nothing to update (e.g. a fully frozen param group). Some PyTorch releases raise
    # when a foreach op is handed an empty tensor list, so return before calling any
    if len(p) == 0 and len(no_wd_p) == 0: return

    if len(p) > 0 and wd != 0:
        if decouple_wd:
            # weight_decay
            torch._foreach_mul_(p, 1 - lr * wd)
        else:
            # l2_reg
            torch._foreach_add_(g, p, alpha=wd)

    # combine wd and non-wd lists. New lists are built so the caller's lists are not extended
    if len(no_wd_p) > 0:
        p = p + no_wd_p
        g = g + no_wd_g
        if mom != 0:
            grad_avg = grad_avg + no_wd_grad_avg

    if mom != 0:
        # average_grad
        damp = 1-mom if dampening else 1.
        torch._foreach_mul_(grad_avg, mom)
        torch._foreach_add_(grad_avg, g, alpha=damp)

        # momentum_step
        torch._foreach_add_(p, grad_avg, alpha=-lr)
    else:
        # sgd_step
        torch._foreach_add_(p, g, alpha=-lr)

In [None]:
#|exporti
class SGDForEachOptimizer(Optimizer):
    "An `Optimizer` with a modified step for SGD ForEach"
    @torch.no_grad()
    def step(self, closure=None):
        "Collect every parameter with a gradient, split by whether weight decay applies, and update each param group with one call to the ForEach step"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            do_wd_p, do_wd_g, do_wd_grad_avg, no_wd_p, no_wd_g, no_wd_grad_avg = [], [], [], [], [], []

            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # (re)check while the buffer is missing or None, so a `mom` that starts at zero
                    # and later becomes non-zero still gets a momentum buffer
                    if state.get('grad_avg', None) is None:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) if hyper['mom'] != 0 else None

                    # parameters are routed to the no-wd lists when wd is off or their state sets `do_wd=False`
                    if hyper['wd'] != 0 and state.get('do_wd', True):
                        do_wd_p.append(p)
                        do_wd_g.append(p.grad)
                        do_wd_grad_avg.append(state['grad_avg'])
                    else:
                        no_wd_p.append(p)
                        no_wd_g.append(p.grad)
                        no_wd_grad_avg.append(state['grad_avg'])

            # a group where no parameter has a gradient (e.g. frozen layers) has nothing to update
            if len(do_wd_p) == 0 and len(no_wd_p) == 0: continue

            self.cbs[0](do_wd_p, do_wd_g, no_wd_p, no_wd_g, do_wd_grad_avg, no_wd_grad_avg, **hyper)

In [None]:
#|hide
def SGD(params, lr, mom=0., wd=0., decouple_wd=True, foreach=False):
    "SGD over `params` with `lr`, `mom` and `wd`: a `SGDForEachOptimizer` if `foreach`, otherwise a fastai `Optimizer`"
    if not foreach:
        # reference fastai pipeline: weight decay / L2 callback, optional gradient averaging, then the step
        use_mom = mom != 0
        cbs = [weight_decay if decouple_wd else l2_reg]
        if use_mom: cbs.append(average_grad)
        cbs.append(momentum_step if use_mom else sgd_step)
        return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)
    step_fn = partial(sgd_foreach_step, decouple_wd=decouple_wd)
    return SGDForEachOptimizer(params, step_fn, lr=lr, mom=mom, wd=wd)

In [None]:
#|hide

# Vanilla SGD: the ForEach optimizer must match the reference fastai optimizer after every step
params_org, params_for = tst_params(), tst_params()
opt_org = SGD(params_org, lr=0.1, foreach=False)
opt_for = SGD(params_for, lr=0.1, foreach=True)

# the gradient is fixed at val/10 and lr=0.1, so each step removes 1% of the initial value
for decay in (0.99, 0.98):
    opt_org.step()
    opt_for.step()
    test_close([p.item() for p in params_org], [i*decay for i in range(4)])
    test_close([p.item() for p in params_org], [p.item() for p in params_for])

## Adam -

In [None]:
#|exporti
def adam_foreach_step(p:list[Tensor], g:list[Tensor], no_wd_p:list[Tensor], no_wd_g:list[Tensor], grad_avg:list[Tensor],
                      no_wd_grad_avg:list[Tensor], sqr_avg:list[Tensor], no_wd_sqr_avg:list[Tensor], steps:np.ndarray,
                      lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool, **kwargs):
    """Apply one Adam step in place to every parameter using PyTorch ForEach ops.

    `p`, `g`, `grad_avg` and `sqr_avg` are the parameters, gradients and moment buffers which receive
    weight decay; the `no_wd_*` lists are their counterparts which skip it. `steps` holds one step count
    per parameter, ordered as the wd parameters followed by the no-wd parameters. `decouple_wd` selects
    true weight decay (parameters scaled by `1 - lr*wd`) over L2 regularization (`wd*p` added to the
    gradients). Moment buffers are updated in place; the list arguments themselves are left unmodified.
    """
    # nothing to update (e.g. a fully frozen param group). Some PyTorch releases raise
    # when a foreach op is handed an empty tensor list, so return before calling any
    if len(p) == 0 and len(no_wd_p) == 0: return

    if len(p) > 0 and wd != 0:
        if decouple_wd:
            # weight_decay
            torch._foreach_mul_(p, 1 - lr * wd)
        else:
            # l2_reg
            torch._foreach_add_(g, p, alpha=wd)

    # combine wd and non-wd lists. New lists are built so the caller's lists are not extended
    if len(no_wd_p) > 0:
        p = p + no_wd_p
        g = g + no_wd_g
        grad_avg = grad_avg + no_wd_grad_avg
        sqr_avg = sqr_avg + no_wd_sqr_avg

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # adam_step: per-parameter bias corrections, with -lr folded into the first one
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps)

    # denominator: sqrt(sqr_avg)/sqrt(debias) + eps
    denom = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(denom, debias2.tolist())
    torch._foreach_add_(denom, eps)

    torch._foreach_addcdiv_(p, grad_avg, denom, debias1.tolist())

In [None]:
#|exporti
class AdamForEachOptimizer(Optimizer):
    "An `Optimizer` with a modified step for Adam ForEach"
    @torch.no_grad()
    def step(self, closure=None):
        "Collect every parameter with a gradient, split by whether weight decay applies, and update each param group with one call to the ForEach step"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            do_wd_p, do_wd_g, do_wd_grad_avg, do_wd_sqr_avg, do_wd_steps = [], [], [], [], []
            no_wd_p, no_wd_g, no_wd_grad_avg, no_wd_sqr_avg, no_wd_steps = [], [], [], [], []

            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # lazily create the moment buffers and step counter the first time a parameter is seen
                    if 'grad_avg' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['step'] = 0

                    state['step'] += 1
                    # parameters are routed to the no-wd lists when wd is off or their state sets `do_wd=False`
                    if hyper['wd'] != 0 and state.get('do_wd', True):
                        do_wd_p.append(p)
                        do_wd_g.append(p.grad)
                        do_wd_grad_avg.append(state['grad_avg'])
                        do_wd_sqr_avg.append(state['sqr_avg'])
                        do_wd_steps.append(state['step'])
                    else:
                        no_wd_p.append(p)
                        no_wd_g.append(p.grad)
                        no_wd_grad_avg.append(state['grad_avg'])
                        no_wd_sqr_avg.append(state['sqr_avg'])
                        no_wd_steps.append(state['step'])

            # a group where no parameter has a gradient (e.g. frozen layers) has nothing to update
            if len(do_wd_p) == 0 and len(no_wd_p) == 0: continue

            # step counts follow the order the step function uses after merging: wd first, then no-wd
            steps = np.array([*do_wd_steps, *no_wd_steps], dtype=np.float32)
            self.cbs[0](do_wd_p, do_wd_g, no_wd_p, no_wd_g, do_wd_grad_avg, no_wd_grad_avg, do_wd_sqr_avg, no_wd_sqr_avg, steps, **hyper)

In [None]:
#|hide
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True, foreach=False):
    "Adam over `params` with `lr`, `mom`, `sqr_mom`, `eps` and `wd`: an `AdamForEachOptimizer` if `foreach`, otherwise a fastai `Optimizer`"
    if not foreach:
        # reference fastai pipeline: weight decay / L2 callback, both moment estimates, step counter, then the step
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    step_fn = partial(adam_foreach_step, decouple_wd=decouple_wd)
    return AdamForEachOptimizer(params, step_fn, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|hide
# Adam: the ForEach optimizer must match the reference fastai optimizer after every step
params_org = tst_param([1,2,3], [0.1,0.2,0.3])
opt_org = Adam(params_org, lr=0.1, wd=0, foreach=False)
opt_org.step()

params_for = tst_param([1,2,3], [0.1,0.2,0.3])
opt_for = Adam(params_for, lr=0.1, wd=0, foreach=True)
opt_for.step()

step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], tensor([1+step, 2+step, 3+step]))
# compare the whole parameter: iterating the (1,3) tensor and taking `p[0]` only checked the first element
test_close(params_org[0], params_for[0])

opt_org.step()
opt_for.step()
test_close(params_org[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
test_close(params_org[0], params_for[0])

## RAdam -

In [None]:
#|exporti
def radam_foreach_step(p:list[Tensor], g:list[Tensor], no_wd_p:list[Tensor], no_wd_g:list[Tensor], grad_avg:list[Tensor],
                       no_wd_grad_avg:list[Tensor], sqr_avg:list[Tensor], no_wd_sqr_avg:list[Tensor], ones:list[Tensor],
                       steps:np.ndarray, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool, **kwargs):
    """Apply one RAdam step in place to every parameter using PyTorch ForEach ops.

    `p`, `g`, `grad_avg` and `sqr_avg` are the parameters, gradients and moment buffers which receive
    weight decay; the `no_wd_*` lists are their counterparts which skip it. `ones` (tensors used as a
    unit divisor) and `steps` (one step count per parameter) cover all parameters, ordered as the wd
    parameters followed by the no-wd parameters. `decouple_wd` selects true weight decay (parameters
    scaled by `1 - lr*wd`) over L2 regularization (`wd*p` added to the gradients). Moment buffers are
    updated in place; the list arguments themselves are left unmodified.
    """
    # nothing to update (e.g. a fully frozen param group). Some PyTorch releases raise
    # when a foreach op is handed an empty tensor list, so return before calling any
    if len(p) == 0 and len(no_wd_p) == 0: return

    if len(p) > 0 and wd != 0:
        if decouple_wd:
            # weight_decay
            torch._foreach_mul_(p, 1 - lr * wd)
        else:
            # l2_reg
            torch._foreach_add_(g, p, alpha=wd)

    # combine wd and non-wd lists. New lists are built so the caller's lists are not extended
    if len(no_wd_p) > 0:
        p = p + no_wd_p
        g = g + no_wd_g
        grad_avg = grad_avg + no_wd_grad_avg
        sqr_avg = sqr_avg + no_wd_sqr_avg

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # radam_step: per-parameter bias corrections, with -lr folded into the first one
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps).tolist()

    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*steps*sqr_mom**steps/(1-sqr_mom**steps)

    # The radicand is negative for 2 < r < 4, where the rectified value is discarded anyway.
    # Clamp it so the sqrt stays real: `np.emath.sqrt` would switch the whole array to complex
    radicand = ((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r)
    rect   = np.where(r > 5, debias1*np.sqrt(np.maximum(radicand, 0)), 0).tolist()
    unrect = np.where(r <= 5, debias1, 0).tolist()

    # rectified step (scalar is zero for parameters with r <= 5)
    sqrt_avg_debias2 = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(sqrt_avg_debias2, debias2)
    torch._foreach_add_(sqrt_avg_debias2, eps)
    torch._foreach_addcdiv_(p, grad_avg, sqrt_avg_debias2, scalars=rect)

    # unrectified step (scalar is zero for parameters with r > 5).
    # cannot scale with foreach_add, so divide by one with foreach_addcdiv
    torch._foreach_addcdiv_(p, grad_avg, ones, scalars=unrect)

In [None]:
#|exporti
class RAdamForEachOptimizer(Optimizer):
    "An `Optimizer` with a modified step for RAdam ForEach"
    @torch.no_grad()
    def step(self, closure=None):
        "Collect every parameter with a gradient, split by whether weight decay applies, and update each param group with one call to the ForEach step"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            do_wd_p, do_wd_g, do_wd_grad_avg, do_wd_sqr_avg, do_wd_steps, do_wd_ones = [], [], [], [], [], []
            no_wd_p, no_wd_g, no_wd_grad_avg, no_wd_sqr_avg, no_wd_steps, no_wd_ones = [], [], [], [], [], []

            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # lazily create the moment buffers, the unit divisor and the step counter
                    if 'grad_avg' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['ones']     = torch.ones_like(p, memory_format=torch.preserve_format)
                        state['step'] = 0

                    state['step'] += 1
                    # parameters are routed to the no-wd lists when wd is off or their state sets `do_wd=False`
                    if hyper['wd'] != 0 and state.get('do_wd', True):
                        do_wd_p.append(p)
                        do_wd_g.append(p.grad)
                        do_wd_grad_avg.append(state['grad_avg'])
                        do_wd_sqr_avg.append(state['sqr_avg'])
                        do_wd_ones.append(state['ones'])
                        do_wd_steps.append(state['step'])
                    else:
                        no_wd_p.append(p)
                        no_wd_g.append(p.grad)
                        no_wd_grad_avg.append(state['grad_avg'])
                        no_wd_sqr_avg.append(state['sqr_avg'])
                        no_wd_ones.append(state['ones'])
                        no_wd_steps.append(state['step'])

            # a group where no parameter has a gradient (e.g. frozen layers) has nothing to update
            if len(do_wd_p) == 0 and len(no_wd_p) == 0: continue

            # `steps` and `ones` follow the order the step function uses after merging: wd first, then no-wd
            steps = np.array([*do_wd_steps, *no_wd_steps], dtype=np.float32)
            ones = do_wd_ones + no_wd_ones
            self.cbs[0](do_wd_p, do_wd_g, no_wd_p, no_wd_g, do_wd_grad_avg, no_wd_grad_avg, do_wd_sqr_avg, no_wd_sqr_avg, ones, steps, **hyper)

In [None]:
#|hide
def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_wd=True, foreach=False):
    "RAdam over `params` with `lr`, `mom`, `sqr_mom`, `eps` and `wd`: a `RAdamForEachOptimizer` if `foreach`, otherwise a fastai `Optimizer`"
    if not foreach:
        # reference fastai pipeline: weight decay / L2 callback, both moment estimates, step counter, then the step
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, radam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)
    step_fn = partial(radam_foreach_step, decouple_wd=decouple_wd)
    # `beta` is not forwarded to the ForEach optimizer, so tell the user it has no effect here
    if beta != 0: warn('RAdam foreach does not use beta, set foreach=False if beta!=0')
    return RAdamForEachOptimizer(params, step_fn, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|hide
# RAdam: reference and ForEach optimizers start from identical parameters and gradients
params_org, params_for = (tst_param([1,2,3], [0.1,0.2,0.3]) for _ in range(2))
opt_org = RAdam(params_org, lr=0.1, foreach=False)
opt_for = RAdam(params_for, lr=0.1, foreach=True)

# r stays at or below 5 for the first 5 steps, so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step(); opt_for.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_for[0])

# r exceeds 5 on the sixth step, so this update is the rectified RAdam step
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step(); opt_for.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params_org[0], p+step)
test_close(params_org[0], params_for[0])

## LAMB -

In [None]:
#|exporti
def lamb_foreach_step(p:list[Tensor], g:list[Tensor], no_wd_p:list[Tensor], no_wd_g:list[Tensor], grad_avg:list[Tensor],
                      no_wd_grad_avg:list[Tensor], sqr_avg:list[Tensor], no_wd_sqr_avg:list[Tensor], ones:list[Tensor],
                      steps:np.ndarray, lr:float, wd:float, mom:float, sqr_mom:float, eps:float, decouple_wd:bool, **kwargs):
    """Apply one LAMB step in place to every parameter using PyTorch ForEach ops.

    `p`, `g`, `grad_avg` and `sqr_avg` are the parameters, gradients and moment buffers which receive
    weight decay; the `no_wd_*` lists are their counterparts which skip it. `ones` (tensors used as a
    unit divisor) and `steps` (one step count per parameter) cover all parameters, ordered as the wd
    parameters followed by the no-wd parameters. `decouple_wd` selects true weight decay (parameters
    scaled by `1 - lr*wd`) over L2 regularization (`wd*p` added to the gradients). Moment buffers are
    updated in place; the list arguments themselves are left unmodified.
    """
    # nothing to update (e.g. a fully frozen param group). Some PyTorch releases raise
    # when a foreach op is handed an empty tensor list, so return before calling any
    if len(p) == 0 and len(no_wd_p) == 0: return

    if len(p) > 0 and wd != 0:
        if decouple_wd:
            # weight_decay
            torch._foreach_mul_(p, 1 - lr * wd)
        else:
            # l2_reg
            torch._foreach_add_(g, p, alpha=wd)

    # combine wd and non-wd lists. New lists are built so the caller's lists are not extended
    if len(no_wd_p) > 0:
        p = p + no_wd_p
        g = g + no_wd_g
        grad_avg = grad_avg + no_wd_grad_avg
        sqr_avg = sqr_avg + no_wd_sqr_avg

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # lamb_step: per-parameter bias corrections
    debias1 = 1 - mom**steps
    debias2 = np.sqrt(1 - sqr_mom**steps)

    # Adam-style update direction: (grad_avg/debias1) / (sqrt(sqr_avg)/debias2 + eps)
    denom = torch._foreach_div(torch._foreach_sqrt(sqr_avg), debias2.tolist())
    torch._foreach_add_(denom, eps)
    lstep = torch._foreach_div(grad_avg, debias1.tolist())
    torch._foreach_div_(lstep, denom)

    # Scale each update by the ratio of parameter RMS to update RMS, capped at 10.
    # there currently is no foreach_mean or foreach_where/if methods
    q = []
    for pi, stepi in zip(p, lstep):
        r1 = pi.pow(2).mean().sqrt().item()
        r2 = stepi.pow(2).mean().sqrt().item()
        # fall back to the plain learning rate when either RMS is zero
        q.append(-lr if r1 == 0 or r2 == 0 else -lr*min(r1/r2, 10.))

    # cannot scale with foreach_add, so divide by one with foreach_addcdiv
    torch._foreach_addcdiv_(p, lstep, ones, scalars=q)

In [None]:
#|exporti
class LambForEachOptimizer(RAdamForEachOptimizer):
    "An `Optimizer` with a modified step for Lamb ForEach. Adds nothing of its own: `step` is inherited unchanged from `RAdamForEachOptimizer`"

In [None]:
#|hide
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True, foreach=False):
    "LAMB over `params` with `lr`, `mom`, `sqr_mom`, `eps` and `wd`: a `LambForEachOptimizer` if `foreach`, otherwise a fastai `Optimizer`"
    if not foreach:
        # reference fastai pipeline: weight decay / L2 callback, both moment estimates, step counter, then the step
        wd_cb = weight_decay if decouple_wd else l2_reg
        cbs = [wd_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    step_fn = partial(lamb_foreach_step, decouple_wd=decouple_wd)
    return LambForEachOptimizer(params, step_fn, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|hide
# Lamb: reference and ForEach optimizers start from identical parameters and gradients
params_org, params_for = (tst_param([1,2,3], [0.1,0.2,0.3]) for _ in range(2))
opt_org = Lamb(params_org, lr=0.1, foreach=False)
opt_for = Lamb(params_for, lr=0.1, foreach=True)

opt_org.step()
opt_for.step()

# one step against the known reference values, then against each other
test_close(params_org[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
test_close(params_org[0], params_for[0])

## Ranger -

In [None]:
#|exporti
def ranger_foreach_step(p:list[Tensor], g:list[Tensor], no_wd_p:list[Tensor], no_wd_g:list[Tensor], grad_avg:list[Tensor],
                        no_wd_grad_avg:list[Tensor], sqr_avg:list[Tensor], no_wd_sqr_avg:list[Tensor], ones:list[Tensor],
                        slow_p:list[Tensor], steps:np.ndarray, lr:float, wd:float, mom:float, sqr_mom:float, eps:float,
                        decouple_wd:bool, count:int, k:int, alpha:float, **kwargs):
    """Apply one Ranger (RAdam + Lookahead) step in place to every parameter using PyTorch ForEach ops.

    `p`, `g`, `grad_avg` and `sqr_avg` are the parameters, gradients and moment buffers which receive
    weight decay; the `no_wd_*` lists are their counterparts which skip it. `ones` (tensors used as a
    unit divisor), `slow_p` (Lookahead slow weights) and `steps` (one step count per parameter) cover
    all parameters, ordered as the wd parameters followed by the no-wd parameters. `decouple_wd`
    selects true weight decay (parameters scaled by `1 - lr*wd`) over L2 regularization (`wd*p` added
    to the gradients). When `count` is a multiple of `k`, the slow weights move a fraction `alpha`
    towards the fast weights and the fast weights are reset to them. Moment buffers and slow weights
    are updated in place; the list arguments themselves are left unmodified.
    """
    # nothing to update (e.g. a fully frozen param group). Some PyTorch releases raise
    # when a foreach op is handed an empty tensor list, so return before calling any
    if len(p) == 0 and len(no_wd_p) == 0: return

    if len(p) > 0 and wd != 0:
        if decouple_wd:
            # weight_decay
            torch._foreach_mul_(p, 1 - lr * wd)
        else:
            # l2_reg
            torch._foreach_add_(g, p, alpha=wd)

    # combine wd and non-wd lists. New lists are built so the caller's lists are not extended
    if len(no_wd_p) > 0:
        p = p + no_wd_p
        g = g + no_wd_g
        grad_avg = grad_avg + no_wd_grad_avg
        sqr_avg = sqr_avg + no_wd_sqr_avg

    # average_grad, dampening=True
    torch._foreach_mul_(grad_avg, mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # average_sqr_grad
    torch._foreach_mul_(sqr_avg, sqr_mom)
    torch._foreach_addcmul_(sqr_avg, g, g, value=1-sqr_mom)

    # radam_step: per-parameter bias corrections, with -lr folded into the first one
    debias1 = -lr / (1 - mom**steps)
    debias2 = np.sqrt(1 - sqr_mom**steps).tolist()

    r_inf = 2/(1-sqr_mom) - 1
    r = r_inf - 2*steps*sqr_mom**steps/(1-sqr_mom**steps)

    # The radicand is negative for 2 < r < 4, where the rectified value is discarded anyway.
    # Clamp it so the sqrt stays real: `np.emath.sqrt` would switch the whole array to complex
    radicand = ((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r)
    rect   = np.where(r > 5, debias1*np.sqrt(np.maximum(radicand, 0)), 0).tolist()
    unrect = np.where(r <= 5, debias1, 0).tolist()

    # rectified step (scalar is zero for parameters with r <= 5)
    sqrt_avg_debias2 = torch._foreach_sqrt(sqr_avg)
    torch._foreach_div_(sqrt_avg_debias2, debias2)
    torch._foreach_add_(sqrt_avg_debias2, eps)
    torch._foreach_addcdiv_(p, grad_avg, sqrt_avg_debias2, scalars=rect)

    # unrectified step (scalar is zero for parameters with r > 5).
    # cannot scale with foreach_add, so divide by one with foreach_addcdiv
    torch._foreach_addcdiv_(p, grad_avg, ones, scalars=unrect)

    # lookahead: every k steps move the slow weights towards the fast ones, then reset the fast ones
    if count % k == 0:
        torch._foreach_add_(slow_p, torch._foreach_sub(p, slow_p), alpha=alpha)
        # Copy values instead of `set_`: `set_` would make the fast weights share storage with the
        # slow weights, so every later fast update would also overwrite the slow weights
        for pi, slow_pi in zip(p, slow_p):
            pi.copy_(slow_pi)

In [None]:
#|exporti
class RangerOptimizer(Optimizer):
    "An `Optimizer` with a modified step for Ranger (RAdam with Lookahead) ForEach"
    def __init__(self, params:Tensor, cbs:list, train_bn:bool=True, **defaults):
        super().__init__(params, cbs, train_bn, **defaults)
        self._init_state()

    @torch.no_grad()
    def step(self, closure=None):
        "Collect every parameter with a gradient, split by whether weight decay applies, and update each param group with one call to the ForEach step"
        # validate before touching `count` so a rejected call does not advance the lookahead counter
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        self.count += 1
        for pg, hyper in zip(self.param_lists, self.hypers):
            do_wd_p, do_wd_g, do_wd_grad_avg, do_wd_sqr_avg, do_wd_steps, do_wd_ones, do_wd_slow = [], [], [], [], [], [], []
            no_wd_p, no_wd_g, no_wd_grad_avg, no_wd_sqr_avg, no_wd_steps, no_wd_ones, no_wd_slow = [], [], [], [], [], [], []

            for p in pg:
                if hasattr(p, 'grad') and p.grad is not None:
                    state = self.state[p]

                    # lazily create the moment buffers, unit divisor, slow weights and step counter
                    if 'grad_avg' not in state:
                        state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['sqr_avg']  = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['ones']     = torch.ones_like(p, memory_format=torch.preserve_format)
                        state['slow_p']   = p.clone().detach()
                        state['step'] = 0

                    state['step'] += 1
                    # parameters are routed to the no-wd lists when wd is off or their state sets `do_wd=False`
                    if hyper['wd'] != 0 and state.get('do_wd', True):
                        do_wd_p.append(p)
                        do_wd_g.append(p.grad)
                        do_wd_grad_avg.append(state['grad_avg'])
                        do_wd_sqr_avg.append(state['sqr_avg'])
                        do_wd_ones.append(state['ones'])
                        do_wd_slow.append(state['slow_p'])
                        do_wd_steps.append(state['step'])
                    else:
                        no_wd_p.append(p)
                        no_wd_g.append(p.grad)
                        no_wd_grad_avg.append(state['grad_avg'])
                        no_wd_sqr_avg.append(state['sqr_avg'])
                        no_wd_ones.append(state['ones'])
                        no_wd_slow.append(state['slow_p'])
                        no_wd_steps.append(state['step'])

            # a group where no parameter has a gradient (e.g. frozen layers) has nothing to update
            if len(do_wd_p) == 0 and len(no_wd_p) == 0: continue

            # `steps`, `ones` and `slow_p` follow the order the step function uses after merging: wd first, then no-wd
            steps = np.array([*do_wd_steps, *no_wd_steps], dtype=np.float32)
            ones = do_wd_ones + no_wd_ones
            slow_p = do_wd_slow + no_wd_slow
            self.cbs[0](do_wd_p, do_wd_g, no_wd_p, no_wd_g, do_wd_grad_avg, no_wd_grad_avg,
                        do_wd_sqr_avg, no_wd_sqr_avg, ones, slow_p, steps, count=self.count, **hyper)

    def clear_state(self):
        "Clear the optimizer state and reset the lookahead step counter"
        super().clear_state()
        self._init_state()

    def state_dict(self):
        "Return the parent state dict extended with the lookahead step counter under `count`"
        state = super().state_dict()
        state.update({'count': self.count})
        return state

    def load_state_dict(self, sd):
        "Restore `count` and the parent optimizer state from `sd` without modifying `sd`"
        # shallow copy so popping `count` does not remove it from the caller's dict
        sd = dict(sd)
        self.count = sd.pop('count')
        super().load_state_dict(sd)

    def _init_state(self):
        "Reset the lookahead step counter"
        self.count = 0

In [None]:
#|hide
def ranger(params, lr, mom=0.95, sqr_mom=0.99, eps=1e-6, wd=0.01, beta=0., k=6, alpha=0.5, decouple_wd=True, foreach=False):
    "Ranger (RAdam with Lookahead) over `params`: a `RangerOptimizer` if `foreach`, otherwise `Lookahead` wrapping `RAdam`"
    if foreach:
        cb = partial(ranger_foreach_step, decouple_wd=decouple_wd)
        # the ForEach step function does not use `beta`; warn like `RAdam` does instead of ignoring it silently
        if beta != 0: warn('ranger foreach does not use beta, set foreach=False if beta!=0')
        return RangerOptimizer(params, cb, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta, k=k, alpha=alpha)
    else:
        # forward `decouple_wd` so the reference path honours L2 regularization when it is requested
        radam = RAdam(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta, decouple_wd=decouple_wd)
        return Lookahead(radam, k=k, alpha=alpha)

In [None]:
#|hide
# Ranger: reference (Lookahead around RAdam) and ForEach optimizers start from identical parameters and gradients
po = tensor([1,2,3])
params_org, params_for = (tst_param([1,2,3], [0.1,0.2,0.3]) for _ in range(2))
opt_org = ranger(params_org, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., foreach=False)
opt_for = ranger(params_for, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., foreach=True)

# The first 5 steps are normal RAdam steps.
# r stays at or below 5 for them, so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt_org.step(); opt_for.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params_org[0], p)
test_close(params_org[0], params_for[0])

# r exceeds 5 on the sixth step, so this update is the rectified RAdam step
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt_org.step(); opt_for.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)

# Since k=6, the sixth step is a moving average of the 6 RAdam steps with the initial weight
test_close(params_org[0], po+((p+step)-po)*0.5)
test_close(params_org[0], params_for[0])