In [None]:
#|default_exp optimizer.lion
#|default_cls_lvl 2

In [None]:
#|exporti
# Lion implementation based on the paper's code release
# https://github.com/google/automl/tree/master/lion - Apache License 2.0 - Copyright 2020 Google Research

# Lion: Evo**L**ved S**i**gn M**o**me**n**tum Optimizer
> With fastai native and fused ForEach implementations

In [None]:
#|export
from __future__ import annotations

import numpy as np

from fastai.optimizer import Optimizer

from fastxtend.optimizer.foreach import ForEachOptimizer
from fastxtend.imports import *

In [None]:
#|hide
from fastxtend.test_utils import *

## Lion Fastai Callback -

In [None]:
#|exporti
def lion_step(
    p:Tensor, # Parameter to update in place; its `.grad` must be populated
    lr:float, # Learning rate
    wd:float, # True (decoupled) weight decay
    beta1:float, # Interpolation coefficient between momentum and gradient for the update direction
    beta2:float, # Decay coefficient of the momentum buffer
    grad_avg:Tensor|None=None, # Momentum buffer from the previous step, created on the first step
    do_wd:bool=True, # Apply weight decay to this parameter
    **kwargs
):
    "Perform one Lion update on `p` in place and return the new optimizer state"
    # Lazily create the momentum buffer on the first step
    if grad_avg is None:
        grad_avg = torch.zeros_like(p, memory_format=torch.preserve_format)

    grad = p.grad.data

    # Decoupled weight decay: p = p * (1 - lr*wd)
    if do_wd and wd != 0:
        p.data.mul_(1-lr*wd)

    # Lion step: p = p - lr * sign(beta1*m + (1-beta1)*g)
    # Built with a single temporary and in-place ops, the same op sequence as the ForEach step
    update = grad_avg.mul(beta1).add_(grad, alpha=1-beta1)
    p.data.add_(update.sign_(), alpha=-lr)

    # Update the momentum buffer in place: m = beta2*m + (1-beta2)*g
    grad_avg.mul_(beta2).add_(grad, alpha=1-beta2)

    return {'grad_avg': grad_avg}

lion_step.defaults = dict(beta1=0.9, beta2=0.99)

## Lion ForEach -

In [None]:
#|exporti
def lion_foreach_step(
    p:list[Tensor], # Parameters with gradients, updated in place
    g:list[Tensor], # Gradients, one per entry of `p`
    grad_avg:list[Tensor], # Momentum buffers, one per entry of `p`, updated in place
    do_wd:np.ndarray[Any, bool], # Per-parameter flag: apply weight decay or not
    lr:float, # Learning rate
    wd:float, # True (decoupled) weight decay
    beta1:float, # Interpolation coefficient between momentum and gradient for the update direction
    beta2:float, # Decay coefficient of the momentum buffers
    **kwargs
):
    "Perform one fused ForEach Lion update on lists of parameters, gradients, and momentum buffers"
    # `torch._foreach_*` ops reject empty tensor lists, which happens when no
    # parameter in a param group currently has a gradient: nothing to update
    if len(p) == 0:
        return

    # Decoupled weight decay, scaled by 1 (i.e. skipped) where `do_wd` is False
    if wd != 0:
        wd_scale = np.where(do_wd, 1-lr*wd, 1.)
        torch._foreach_mul_(p, scalars=wd_scale.tolist())

    # Lion step: p = p - lr * sign(beta1*m + (1-beta1)*g)
    update = torch._foreach_mul(grad_avg, scalar=beta1)
    torch._foreach_add_(update, g, alpha=1-beta1)
    for u in update: u.sign_()
    torch._foreach_add_(p, update, alpha=-lr)

    # Update the momentum buffers in place: m = beta2*m + (1-beta2)*g
    torch._foreach_mul_(grad_avg, scalar=beta2)
    torch._foreach_add_(grad_avg, g, alpha=1-beta2)

In [None]:
#|exporti
class LionForEachOptimizer(ForEachOptimizer):
    "An `Optimizer` with a modified step for Lion ForEach"
    def __init__(self,
        params:Listified[Tensor], # Model parameters
        opt_step:Callable, # `ForEachOptimizer` optimizer step
        **defaults # Optimizer specific hyper parameters
    ):
        super().__init__(params, opt_step, **defaults)

    @torch.no_grad()
    def step(self, closure=None):
        "Gather each param group's parameters with gradients and run `opt_step` on them once"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            pl, gl, grad_avg, do_wd = [], [], [], []

            for p in pg:
                if getattr(p, 'grad', None) is None: continue
                state = self.state[p]

                # Lazily allocate the momentum buffer on the parameter's first step
                if 'grad_avg' not in state:
                    state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                pl.append(p)
                gl.append(p.grad)
                grad_avg.append(state['grad_avg'])
                do_wd.append(state.get('do_wd', True))

            # `torch._foreach_*` ops reject empty tensor lists, so skip groups
            # where no parameter currently has a gradient
            if not pl: continue

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, do_wd=np.array(do_wd, dtype=bool), **hyper)

Lion was introduced by Chen et al. in *[Symbolic Discovery of Optimization Algorithms](https://arxiv.org/abs/2302.06675)*.

In addition to a fastai native implementation, `Lion` has a fused ForEach implementation. See the [Fused Optimizer](optimizer.fused.html) documentation for more details.

## Lion -

In [None]:
#|export
def Lion(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    beta1:float=0.9, # Update gradient moving average (β1) coefficient
    beta2:float=0.99, # Gradient moving average (β2) coefficient
    wd:float=0.1, # True weight decay
    foreach:bool=False, # Use fused ForEach implementation
) -> Optimizer|LionForEachOptimizer:
    "A fastai Lion optimizer with an optional fused ForEach implementation"
    hypers = dict(lr=lr, beta1=beta1, beta2=beta2, wd=wd)
    if not foreach:
        # fastai native path: `lion_step` is called once per parameter
        return Optimizer(params, [lion_step], **hypers)
    # Fused path: `lion_foreach_step` is called once per parameter group
    return LionForEachOptimizer(params, lion_foreach_step, **hypers)

In [None]:
#|export
def lion(
    beta1:float=0.9, # Update gradient moving average (β1) coefficient
    beta2:float=0.99, # Gradient moving average (β2) coefficient
    wd:float=0.1, # True weight decay
    foreach:bool=False, # Use fused ForEach implementation
) -> Callable[..., Optimizer|LionForEachOptimizer]:
    "Partial function for the Lion optimizer with a fused ForEach implementation"
    # Returns a callable, not an optimizer: `params` and `lr` are still required
    # when it is called, while the hyperparameters here are pre-bound
    return partialler(Lion, beta1=beta1, beta2=beta2, wd=wd, foreach=foreach)

In [None]:
#|hide
# Tests contain code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

test_steps = 25 # extra optimizer steps run when comparing the native and ForEach implementations

def tst_param(val, grad=None):
    "Return a float tensor holding `val` whose `.grad` holds `grad`, or `val/10` when `grad` is None"
    grad_val = val/10 if grad is None else grad
    param = tensor([val]).float()
    param.grad = tensor([grad_val]).float()
    return param

def tst_params():
    "Return test parameters with values 0 to 3, each with a gradient of one tenth its value"
    return L.range(4).map(tst_param)

params_org = tst_params()
opt_org = Lion(params_org, lr=0.01)
opt_org.step()

params_for = tst_params()
opt_for = Lion(params_for, lr=0.01, foreach=True)
opt_for.step()

# Test values from paper implementation: https://github.com/google/automl/tree/master/lion
test_close([p.item() for p in params_org], [0., 0.9890, 1.9880, 2.9869999])
test_close([p.item() for p in params_org], [p.item() for p in params_for])
# After one step from zero, the momentum is (1-beta2)*grad = 0.01*grad
test_close([opt_org.state[p]['grad_avg'].item() for p in params_org], [0., 0.001, 0.002, 0.003])

for _ in range(test_steps):
    opt_org.step()
    opt_for.step()

# Both implementations must agree on the parameters and on the momentum buffers
test_close([p.item() for p in params_org], [p.item() for p in params_for])
test_close([opt_org.state[p]['grad_avg'].item() for p in params_org],
           [opt_for.state[p]['grad_avg'].item() for p in params_for])