In [None]:
#|default_exp optimizers.adan

# Adan: ADAptive Nesterov Momentum Optimizer
> Adds the Adan optimizer from [Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models](https://arxiv.org/abs/2208.06677) to fastai.

In [None]:
#|export
from __future__ import annotations

from fastai.optimizer import Optimizer, step_stat

from fastxtend.imports import *

In [None]:
#|exporti
def debias(beta, step):
    "Debiasing term `1 - beta**step`"
    decay = beta**step
    return 1 - decay

In [None]:
#|exporti
def adan_setup(p, step=0, grad_avg=None, diff_avg=None, nesterov_est=None, prior_grad=None, paper_init=False, **kwargs):
    "Creates the Adan state tensors for `p` when `step` is 0 and increments `step` on every call"
    needs_init = step == 0
    step += 1

    # once the state exists, only the step counter is returned for updating
    if not needs_init:
        return {'step':step}

    def _zeros():
        # zero buffer shaped like `p`, keeping its memory format
        return torch.zeros_like(p, memory_format=torch.preserve_format)

    state = {'grad_avg':_zeros(), 'diff_avg':_zeros(), 'nesterov_est':_zeros()}
    # `paper_init` seeds the prior gradient with a copy of the current gradient instead of zeros
    state['prior_grad'] = p.grad.clone() if paper_init else _zeros()
    state['step'] = step
    return state

In [None]:
#|exporti
def adan_avgs(p, beta1, beta2, beta3, grad_avg, diff_avg, nesterov_est, prior_grad, **kwargs):
    """Updates Adan moving averages

    `grad_avg`, `diff_avg` and `nesterov_est` are updated in place from the current
    gradient `p.grad` and the previous step's gradient `prior_grad`. Returns a dict
    with the three updated averages plus a fresh copy of the current gradient under
    `'prior_grad'` for use on the next step.
    """
    grad = p.grad.data
    # g_k - g_{k-1}: computed once and shared by the v_k and n_k updates
    grad_diff = grad - prior_grad

    # update m_k: moving average of the gradient
    grad_avg.mul_(beta1).add_(grad, alpha=1-beta1)

    # update v_k: moving average of the gradient difference
    diff_avg.mul_(beta2).add_(grad_diff, alpha=1-beta2)

    # update n_k: moving average of (g_k + beta2*(g_k - g_{k-1}))^2
    nesterov_est.mul_(beta3).add_(torch.square(torch.add(grad, grad_diff, alpha=beta2)), alpha=1-beta3)

    # set current grad as next step's prior_grad (cloned so it is detached from `p.grad` storage)
    prior_grad = grad.clone()
    return {'grad_avg':grad_avg, 'diff_avg':diff_avg, 'nesterov_est':nesterov_est, 'prior_grad':prior_grad}

adan_avgs.defaults = dict(beta1=0.98, beta2=0.92, beta3=0.99)

In [None]:
#|exporti
def adan_step(p, lr, eps, wd, beta1, beta2, beta3, step, grad_avg, diff_avg, nesterov_est, **kwargs):
    "Applies one Adan update with `lr` to `p` in place and returns `p`"
    # debias terms for the three moving averages
    db1 = debias(beta1, step)
    db2 = debias(beta2, step)
    db3 = debias(beta3, step)

    # applied λ: divisor (1 + lr*wd), computed from the scalar `lr`; 1 when weight decay is off
    decay = 1 if wd == 0 else (1+lr*wd)

    # η_k: per-element step size from the debiased `nesterov_est`
    step_size = lr/torch.sqrt(nesterov_est/db3+eps)

    # debiased m_k + beta2 * debiased v_k, scaled by η_k
    update = torch.add(grad_avg/db1, diff_avg/db2, alpha=beta2).mul_(step_size)

    # subtract the update from `p`, then apply the weight-decay divisor
    p.data.sub_(update).div_(decay)
    return p

## Adan -

In [None]:
#|export
def Adan(params, lr, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, wd=0.02, paper_init=False):
    "An `Optimizer` for Adan with `lr`, `beta`s, `eps`, `wd` and `params`"
    # callback list: state setup (with `paper_init` bound), moving averages, parameter step
    setup = partial(adan_setup, paper_init=paper_init)
    hypers = dict(lr=lr, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, wd=wd)
    return Optimizer(params, [setup, adan_avgs, adan_step], **hypers)

In [None]:
#|export
def AdanLargeBatchLR(bs):
    "Square root rule for scaling `Adan` learning rate for large-batch training"
    # learning rate 6.25e-3 at batch size 256, scaled by sqrt of the batch-size ratio
    base_bs, base_lr = 256, 6.25e-3
    scale = math.sqrt(bs/base_bs)
    return scale*base_lr