In [None]:
#|default_exp optimizer.sophia
#|default_cls_lvl 2

In [None]:
#|exporti
# Sophia implementation based on the paper's code release
# https://github.com/Liuhong99/Sophia - MIT License - Copyright 2023 Hong Liu

# Sophia: **S**econd-**o**rder Cli**p**ped Stoc**h**astic Optimiz**a**tion
> With fastai native and fused ForEach implementations

Sophia was introduced by Liu et al in *[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342)*. Sophia is a second-order optimizer that leverages a light-weight Hessian estimate as a pre-conditioner, which is supposed to handle the Large Language Model (LLM) loss landscape better than [AdamW](https://openreview.net/forum?id=Bkg6RiCqY7). The Hessian pre-conditioner is more aggressive than [`AdamW`](optimizer.fused.html#adam-optimizer), with stronger update penalties in sharp dimensions, which can lead to a more uniform loss decrease across parameters and faster convergence. Additionally, Sophia applies element-wise clipping to updates which allows infrequent and stochastic updates to the Hessian estimate, reducing optimizer wall-clock time.

:::{.callout-important}
`Sophia` will not update the Hessian estimate unless the `SophiaCallback` is added to `fastai.learner.Learner`.
:::

In addition to a fastai native implementation, `Sophia` has a fused ForEach implementation. See the [Fused Optimizer](optimizer.fused.html) documentation for more details.

In [None]:
#|export
from __future__ import annotations
from typing import Optional, Dict

import numpy as np

from torch.distributions import Categorical
from torch.nn import CrossEntropyLoss

from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision
from fastai.losses import CrossEntropyLossFlat, LabelSmoothingCrossEntropy, LabelSmoothingCrossEntropyFlat
from fastai.optimizer import Optimizer, _update

from fastxtend.optimizer.foreach import ForEachOptimizer

from fastxtend.imports import *

In [None]:
#|exporti
def sophia_step(p:Tensor, lr:float, eps:float, wd:float, mom:float, hess_mom:float,
                rho:float, bs:int, hessian_step:bool, grad_avg:Tensor|None=None,
                hessian:Tensor|None=None, do_wd:bool=True, **kwargs):
    "Updates Sophia's Hessian estimate if `hessian_step`, else updates the gradient moving average and performs the Sophia step with `lr` on `p`"
    # first call for `p`: create both state tensors, zero-initialized
    if grad_avg is None:
        grad_avg, hessian = (torch.zeros_like(p, memory_format=torch.preserve_format) for _ in range(2))

    grad = p.grad.data

    if hessian_step:
        # hessian estimate: moving average of the squared gradient, `p` is left untouched
        hessian.mul_(hess_mom).addcmul_(grad, grad, value=1-hess_mom)
        return {'grad_avg': grad_avg, 'hessian': hessian}

    # weight decay applied directly to the parameter
    if do_wd and wd!=0:
        p.data.mul_(1-lr*wd)

    # gradient moving average
    grad_avg.mul_(mom).add_(grad, alpha=1-mom)

    # element-wise update size |grad_avg| / (rho*bs*hessian + eps), clipped to at most 1
    denom = hessian.mul(rho * bs).add(eps)
    ratio = grad_avg.abs().div(denom).clamp(max=1)

    # move against the sign of the moving average, scaled by the clipped ratio
    p.data.addcmul_(grad_avg.sign(), ratio, value=-lr)

    return {'grad_avg': grad_avg, 'hessian': hessian}

sophia_step.defaults = dict(mom=0.9, hess_mom=0.99)

In [None]:
#|exporti
class SophiaOptimizer(Optimizer):
    "A fastai `Optimizer` with a modified step which passes `bs` and `hessian_step` to its step callbacks"
    def __init__(self,
        params:Tensor|Iterable, # Model parameters
        cbs:callable|MutableSequence, # `Optimizer` step callbacks
        **defaults # Hyper parameters default values
    ):
        super().__init__(params, cbs, **defaults)
        # start with a zero batch size and no pending Hessian update
        self.update_sophia_hypers(0, False)

    def update_sophia_hypers(self, bs, hessian_step):
        "Store the batch size and Hessian-update flag used by the next `step`"
        self._bs, self._hessian_step = bs, hessian_step

    def step(self, closure=None):
        "Run each step callback on every parameter with a gradient, forwarding `bs` and `hessian_step`"
        if closure is not None:
            raise NotImplementedError("fastai optimizers currently do not support closure")
        sophia_hypers = dict(bs=self._bs, hessian_step=self._hessian_step)
        for p,_,state,hyper in self.all_params(with_grad=True):
            for cb in self.cbs:
                # hypers override same-named state entries
                state = _update(state, cb(p, **{**state, **hyper}, **sophia_hypers))
            self.state[p] = state

    def clear_state(self):
        "Clear the optimizer state and reset the Sophia batch size and Hessian-update flag"
        super().clear_state()
        self.update_sophia_hypers(0, False)

In [None]:
#|exporti
def sophia_foreach_step(p:list[Tensor], g:list[Tensor], grad_avg:list[Tensor], hessian:list[Tensor],
                        do_wd:np.ndarray[Any, bool], lr:float, wd:float, mom:float, hess_mom:float,
                        eps:float, rho:float, bs:int, hessian_step:bool, **kwargs):
    "ForEach Sophia step: updates `hessian` in place if `hessian_step`, else updates `grad_avg` and `p` in place"
    if hessian_step:
        # hessian estimate: moving average of the squared gradient, parameters are left untouched
        torch._foreach_mul_(hessian, scalar=hess_mom)
        torch._foreach_addcmul_(hessian, g, g, value=1-hess_mom)
        return

    # weight decay: one multiplier per tensor, 1 where `do_wd` is False
    if wd != 0:
        decay = np.where(do_wd, 1-lr*wd, 1.)
        torch._foreach_mul_(p, scalars=decay.tolist())

    # gradient moving average
    torch._foreach_mul_(grad_avg, scalar=mom)
    torch._foreach_add_(grad_avg, g, alpha=1-mom)

    # element-wise update size |grad_avg| / (rho*bs*hessian + eps), clipped to at most 1
    denom = torch._foreach_mul(hessian, scalar=rho*bs)
    torch._foreach_add_(denom, scalar=eps)
    ratio = torch._foreach_abs(grad_avg)
    torch._foreach_div_(ratio, denom)
    torch._foreach_clamp_max_(ratio, scalar=1)

    # move against the sign of the moving average, scaled by the clipped ratio
    signs = [avg.sign() for avg in grad_avg]
    torch._foreach_addcmul_(p, signs, ratio, value=-lr)

In [None]:
#|exporti
class SophiaForEachOptimizer(ForEachOptimizer):
    "A `ForEachOptimizer` with a modified step for `sophia_foreach_step`"
    def __init__(self,
        params:Listified[Tensor], # Model parameters
        opt_step:Callable, # `ForEachOptimizer` optimizer step
        **defaults # Optimizer specific hyper parameters default values
    ):
        super().__init__(params, opt_step, True, **defaults)
        # start with a zero batch size and no pending Hessian update
        self.update_sophia_hypers(0, False)

    def update_sophia_hypers(self, bs, hessian_step):
        "Store the batch size and Hessian-update flag used by the next `step`"
        self._bs, self._hessian_step = bs, hessian_step

    def clear_state(self):
        "Clear the optimizer state and reset the Sophia batch size and Hessian-update flag"
        super().clear_state()
        self.update_sophia_hypers(0, False)

    @torch.no_grad()
    def step(self, closure=None):
        "Gather the parameters with gradients and their state per group, then call `opt_step` once per group"
        if closure is not None:
            raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            pl, gl, grad_avg, hessian, do_wd = [], [], [], [], []

            for p in pg:
                # only parameters which currently hold a gradient take part in the step
                if getattr(p, 'grad', None) is None:
                    continue
                state = self.state[p]

                # first time this parameter is seen: create zeroed state tensors
                if 'grad_avg' not in state:
                    for key in ('grad_avg', 'hessian'):
                        state[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

                pl.append(p)
                gl.append(p.grad)
                grad_avg.append(state['grad_avg'])
                hessian.append(state['hessian'])
                do_wd.append(state.get('do_wd', True))

            self.opt_step(p=pl, g=gl, grad_avg=grad_avg, hessian=hessian,
                          do_wd=np.array(do_wd, dtype=bool), bs=self._bs,
                          hessian_step=self._hessian_step, **hyper)

In [None]:
#|export
def Sophia(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.965, # Gradient moving average (β1) coefficient
    hess_mom:float=0.99, # Hessian moving average (β2) coefficient
    rho:float=0.4, # Maximum update size, set higher for more aggressive updates
    eps:float=1e-15, # Added for numerical stability
    wd:float=0.01, # Optional weight decay
    foreach:bool=False, # Use fused ForEach implementation
) -> SophiaOptimizer|SophiaForEachOptimizer:
    "A fastai Sophia optimizer with a fused ForEach implementation"
    # both implementations receive the same hyperparameters
    hypers = dict(lr=lr, mom=mom, hess_mom=hess_mom, rho=rho, eps=eps, wd=wd)
    if foreach:
        return SophiaForEachOptimizer(params, sophia_foreach_step, **hypers)
    return SophiaOptimizer(params, [sophia_step], **hypers)

In [None]:
#|export
def sophia(
    mom:float=0.965, # Gradient moving average (β1) coefficient
    hess_mom:float=0.99, # Hessian moving average (β2) coefficient
    rho:float=0.4, # Maximum update size, set higher for more aggressive updates
    eps:float=1e-15, # Added for numerical stability
    wd:float=0.01, # Optional weight decay
    foreach:bool=False, # Use fused ForEach implementation
) -> SophiaOptimizer|SophiaForEachOptimizer:
    "Partial function for the Sophia optimizer with a fused ForEach implementation"
    # bind every hyperparameter except `params` and `lr`, which are supplied later
    hypers = dict(mom=mom, hess_mom=hess_mom, eps=eps,
                  rho=rho, wd=wd, foreach=foreach)
    return partialler(Sophia, **hypers)

In [None]:
#|exporti
class SophiaHessian(str, Enum):
    "Hessian estimators for the Sophia optimizer, as a string enum for autocomplete"
    # currently the only member; not referenced elsewhere in the visible code
    sophiag = 'sophiag'

In [None]:
#|export
class SophiaCallback(Callback):
    "Modifies the training loop for the Sophia Optimizer. Required for Sophia to run."
    order,run_valid = MixedPrecision.order+1,False
    def __init__(self,
        hessian_update:int=10, # Update Sophia's Hessian estimate every `hessian_update` Optimizer steps
        # hessian_est:str|SophiaHessian=SophiaHessian.sophiag # Sophia's Hessian estimator. Defaults to SophiaG's Gauss-Newton-Bartlett
    ):
        store_attr()

    def before_fit(self):
        "Validate the optimizer, warn on non-CrossEntropy losses, and reset the step counter"
        if not isinstance(self.learn.opt, (SophiaOptimizer, SophiaForEachOptimizer)):
            raise ValueError("`SophiaCallback` only supports the `Sophia` optimizer")
        if not isinstance(self.learn.loss_fn, (CrossEntropyLoss, CrossEntropyLossFlat,
                                               LabelSmoothingCrossEntropy,
                                               LabelSmoothingCrossEntropyFlat)):
            warn('Non-CrossEntropy loss detected, SophiaG assumes data is in a categorical distrobution.')
        self._step_iter = 0
        self._hessian_step = False

    # NOTE(review): confirm the Learner in use dispatches a `before_loss` event - TODO verify
    @torch.no_grad()
    def before_loss(self):
        "On every `hessian_update`-th step, replace the targets with labels sampled from the predictions"
        # `x % n` is always less than `n`, so the last step of each interval is `n - 1`.
        # With the default of 10 this selects steps 9, 19, 29, ...
        if self._step_iter % self.hessian_update == self.hessian_update - 1:
            dist = Categorical(logits=self.pred)
            # NOTE(review): `yb` is assigned a bare tensor here - confirm the loss call does not expect a tuple
            self.learn.yb = dist.sample()
            self._hessian_step = True

    def before_step(self):
        "Pass the current batch size and Hessian-update flag to the optimizer"
        self.learn.opt.update_sophia_hypers(find_bs(self.learn.yb), self._hessian_step)

    def after_step(self):
        "Advance the step counter and clear the Hessian-update flag"
        self._step_iter += 1
        self._hessian_step = False

`SophiaCallback` expects the loss function to be a CrossEntropy loss, and only supports single target and single loss function training.

## Hyperparameters

Hyperparameter notes from Liu et al:

1) Sophia hyperparameters should be similar to AdamW
2) $\rho$ (`rho`) should be in $[0.01, 0.1]$. A larger $\rho$ means more aggressive updates
3) Sophia may benefit from slightly higher weight decay and learning rate compared to AdamW

# Tests -

In [None]:
#|hide
from fastxtend.test_utils import *

In [None]:
#|hide
# Tests contain code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

test_steps = 25

# one regular (non-Hessian) step with the fastai native implementation
params_org = tst_params()
opt_org = Sophia(params_org, lr=0.01)
opt_org.update_sophia_hypers(64, False)
opt_org.step()

# the same step with the fused ForEach implementation
params_for = tst_params()
opt_for = Sophia(params_for, lr=0.01, foreach=True)
opt_for.update_sophia_hypers(64, False)
opt_for.step()

test_close([p.item() for p in params_org], [0.0, 0.9899, 1.9898, 2.9897])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

# every tenth iteration is a Hessian update step for both optimizers
for i in range(test_steps):
    opt_org.update_sophia_hypers(64, i % 10 == 9)
    opt_for.update_sophia_hypers(64, i % 10 == 9)
    opt_org.step()
    opt_for.step()

# Sophia numerical values from SophiaG https://github.com/Liuhong99/Sophia
test_close([p.item() for p in params_org], [0.0, 0.757878, 1.755481, 2.753083])
test_close([p.item() for p in params_org], [p.item() for p in params_for])

[0.0, 0.757878, 1.755481, 2.753083]
