---
aliases:
    - optimizer.fused.html
    - optimizer.foreach.html
    - optimizer.torchscript.html
    - optimizer.adan.html
    - optimizer.lion.html
    - optimizer.stableadam.html
---

In [None]:
#|default_exp optimizer.optimizers

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Optimizers
> Fast fastai optimizers with optimi low precision and bitsandbytes 8-bit implementations

documentation placeholder

In [None]:
#|export
from __future__ import annotations

from fastcore.basics import partialler

from fastai.optimizer import (Optimizer, weight_decay, l2_reg, average_grad, average_sqr_grad,
                               step_stat, qhadam_step, larc_layer_lr, larc_step, lamb_step, rms_prop_step)

try:
    from fastxtend.optimizer.optimi import (AdamOptimiOptimizer, AdanOptimiOptimizer, LionOptimiOptimizer,
                                            RAdamOptimiOptimizer, RangerOptimiOptimizer, SGDOptimiOptimizer,
                                            StableAdamWOptimiOptimizer)
    OPTIMI = True
except ImportError:
    OPTIMI = False

try:
    from packaging.version import parse
    import bitsandbytes
    from fastxtend.optimizer.eightbit import (SGD8bitOptimizer, RMSProp8bitOptimizer, AdamW8bitOptimizer,
                                              LARS8bitOptimizer, LAMB8bitOptimizer, Lion8bitOptimizer)
    EIGHTBIT = True
except ImportError:
    EIGHTBIT = False

from fastxtend.imports import *

## Optimi and Eight-bit Optimizers

These optimizers support both optimi's low precision and bitsandbytes' eight-bit implementations.

### Adam -

In [None]:
#|export
def Adam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0.01, # Optional weight decay
    decouple_wd:bool=True, # Apply decoupled weight decay (AdamW) instead of L2 penalty (Adam)
    decouple_lr:bool=False, # Apply fully decoupled weight decay. Unsupported for `eightbit=True`.
    kahan_sum:bool|None=None, # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach:bool|None=None, # Use ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation instead of optimi's implementation.
    **eightbitargs # Additional eight-bit arguments. See `AdamW8bitOptimizer` for details.
) -> AdamOptimiOptimizer|AdamW8bitOptimizer:
    "A fastai Adam/AdamW optimizer with low precision, foreach, and eight-bit implementations"
    # Default path: optimi implementation (supports every option above)
    if not eightbit:
        if not OPTIMI:
            raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
        return AdamOptimiOptimizer(params, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd,
                                   eps=eps, decouple_wd=decouple_wd, decouple_lr=decouple_lr,
                                   kahan_sum=kahan_sum, foreach=foreach)

    # Eight-bit path: bitsandbytes implementation
    if not EIGHTBIT:
        raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')
    # The eight-bit path only handles standard decoupled weight decay. Reject the other
    # weight decay modes, and report which argument is the problem.
    if wd > 0 and not decouple_wd:
        raise NotImplementedError(f'Eight-bit Adam only supports decoupled weight decay: {decouple_wd=}')
    if wd > 0 and decouple_lr:
        raise NotImplementedError(f'Eight-bit Adam does not support fully decoupled weight decay: {decouple_lr=}')
    return AdamW8bitOptimizer(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, **eightbitargs)

In [None]:
#|export
def adam(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0.01, # Optional weight decay
    decouple_wd:bool=True, # Apply decoupled weight decay (AdamW) instead of L2 penalty (Adam)
    decouple_lr:bool=False, # Apply fully decoupled weight decay. Unsupported for `eightbit=True`.
    kahan_sum:bool|None=None, # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach:bool|None=None, # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation instead of optimi's implementation.
    **eightbitargs # Additional eight-bit arguments. See `AdamW8bitOptimizer` for details.
) -> AdamOptimiOptimizer|AdamW8bitOptimizer:
    "Partial function for the Adam/AdamW optimizer with low precision, foreach, and eight-bit implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    adam_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd,
                       decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach,
                       eightbit=eightbit)
    return partialler(Adam, **adam_kwargs, **eightbitargs)

### Lion -

In [None]:
#|export
def Lion(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    beta1: float = 0.9,  # Update gradient moving average (β1) coefficient
    beta2: float = 0.99,  # Gradient moving average (β2) coefficient
    wd: float = 0.1,  # Decoupled weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay. Unsupported for `eightbit=True`.
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit: bool = False,  # Use bitsandbytes' eight-bit implementation
    **eightbitargs  # Additional eight-bit arguments. See `Lion8bitOptimizer` for details.
) -> LionOptimiOptimizer | Lion8bitOptimizer:
    "A fastai-compatible Lion optimizer with low precision, foreach, and eight-bit implementations"

    if not eightbit:
        if OPTIMI:
            return LionOptimiOptimizer(params, lr=lr, beta1=beta1, beta2=beta2, wd=wd,
                                       decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach)
        else:
            raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    else:
        if EIGHTBIT:
            # Lion's weight decay is already decoupled; the eight-bit path only lacks the
            # *fully* decoupled (`decouple_lr`) variant, so name that in the error.
            if decouple_lr and wd > 0:
                raise NotImplementedError(f'Eight-bit Lion does not support fully decoupled weight decay: {decouple_lr=}')
            return Lion8bitOptimizer(params, lr=lr, beta1=beta1, beta2=beta2, wd=wd, **eightbitargs)
        else:
            raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')


In [None]:
#|export
def lion(
    beta1: float = 0.9,  # Update gradient moving average (β1) coefficient
    beta2: float = 0.99,  # Gradient moving average (β2) coefficient
    wd: float = 0.1,  # Decoupled weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay. Unsupported for `eightbit=True`.
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit: bool = False,  # Use bitsandbytes' eight-bit implementation
    **eightbitargs  # Additional eight-bit arguments. See `Lion8bitOptimizer` for details.
) -> LionOptimiOptimizer | Lion8bitOptimizer:
    "Partial function for the Lion optimizer with low precision, foreach, and eight-bit implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    lion_kwargs = dict(beta1=beta1, beta2=beta2, wd=wd, decouple_lr=decouple_lr,
                       kahan_sum=kahan_sum, foreach=foreach, eightbit=eightbit)
    return partialler(Lion, **lion_kwargs, **eightbitargs)


### SGD -

In [None]:
#|export
def SGD(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    mom: float = 0.,  # Gradient moving average (β1) coefficient
    wd: float = 0.,  # Optional weight decay (decoupled or L2)
    decouple_wd: bool = True,  # Apply decoupled weight decay (SGDW) or L2 regularization (SGD). `eightbit=True` only supports L2.
    decouple_lr: bool = False,  # Apply fully decoupled weight decay. Unsupported for `eightbit=True`.
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit: bool = False,  # Use bitsandbytes' eight-bit implementation
    **eightbitargs  # Additional eight-bit arguments. See `SGD8bitOptimizer` for details.
) -> SGDOptimiOptimizer | SGD8bitOptimizer:
    "A fastai-compatible SGD optimizer with low precision, foreach, and eight-bit implementations"

    if not eightbit:
        if OPTIMI:
            return SGDOptimiOptimizer(params, lr=lr, mom=mom, wd=wd, decouple_wd=decouple_wd,
                                      decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach)
        else:
            raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    else:
        if EIGHTBIT:
            # The eight-bit path only applies L2 weight decay. `decouple_lr` also requests
            # (fully) decoupled weight decay, so reject it too instead of silently
            # falling back to L2, matching the checks in `Adam` and `Lion`.
            if wd > 0 and (decouple_wd or decouple_lr):
                raise NotImplementedError(f'Eight-bit SGD only supports L2 weight decay: {decouple_wd=}, {decouple_lr=}')
            return SGD8bitOptimizer(params, lr=lr, mom=mom, wd=wd, **eightbitargs)
        else:
            raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')


In [None]:
#|export
def sgd(
    mom: float = 0.,  # Gradient moving average (β1) coefficient
    wd: float = 0.,  # Optional weight decay (decoupled or L2)
    decouple_wd: bool = True,  # Apply decoupled weight decay (SGDW) or L2 regularization (SGD)
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters. Unsupported for `eightbit=True`.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
    eightbit: bool = False,  # Use bitsandbytes' eight-bit implementation
    **eightbitargs  # Additional eight-bit arguments. See `SGD8bitOptimizer` for details.
) -> SGDOptimiOptimizer | SGD8bitOptimizer:
    "Partial function for the SGD optimizer with low precision, foreach, and eight-bit implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    sgd_kwargs = dict(mom=mom, wd=wd, decouple_wd=decouple_wd, decouple_lr=decouple_lr,
                      kahan_sum=kahan_sum, foreach=foreach, eightbit=eightbit)
    return partialler(SGD, **sgd_kwargs, **eightbitargs)


## Optimi-only Optimizers

These optimizers only use the optimi implementation and do not have an eight-bit version.

### Adan -

In [None]:
#|export
def Adan(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    beta1: float = 0.98,  # Gradient moving average (β1) coefficient
    beta2: float = 0.92,  # Gradient difference moving average (β2) coefficient
    beta3: float = 0.99,  # Gradient squared moving average (β3) coefficient
    eps: float = 1e-8,  # Added for numerical stability
    wd: float = 0.02,  # Decoupled weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    adam_wd: bool = False,  # Apply weight decay before parameter update (Adam-style), instead of after the update per Adan algorithm
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> AdanOptimiOptimizer:
    "A fastai-compatible Adan optimizer with low precision and foreach implementations"
    # Adan only has an optimi implementation, so fail fast when optimi is missing
    if not OPTIMI:
        raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    hparams = dict(lr=lr, beta1=beta1, beta2=beta2, beta3=beta3, wd=wd, eps=eps,
                   decouple_lr=decouple_lr, adam_wd=adam_wd, kahan_sum=kahan_sum, foreach=foreach)
    return AdanOptimiOptimizer(params, **hparams)


In [None]:
#|export
def adan(
    beta1: float = 0.98,  # Gradient moving average (β1) coefficient
    beta2: float = 0.92,  # Gradient difference moving average (β2) coefficient
    beta3: float = 0.99,  # Gradient squared moving average (β3) coefficient
    eps: float = 1e-8,  # Added for numerical stability
    wd: float = 0.02,  # Decoupled weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    adam_wd: bool = False,  # Apply weight decay before parameter update (Adam-style), instead of after the update per Adan algorithm
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> AdanOptimiOptimizer:
    "A partial function for the Adan optimizer with low precision and foreach implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    adan_kwargs = dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, wd=wd,
                       decouple_lr=decouple_lr, adam_wd=adam_wd, kahan_sum=kahan_sum,
                       foreach=foreach)
    return partialler(Adan, **adan_kwargs)


### RAdam -

In [None]:
#|export
def RAdam(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    mom: float = 0.9,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-5,  # Added for numerical stability
    wd: float = 0.,  # Optional weight decay (decoupled or L2)
    decouple_wd: bool = True,  # Apply decoupled weight decay (RAdamW) or L2 regularization (RAdam)
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> RAdamOptimiOptimizer:
    "A fastai-compatible RAdam optimizer with low precision and foreach implementations"
    # RAdam only has an optimi implementation, so fail fast when optimi is missing
    if not OPTIMI:
        raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    hparams = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, eps=eps,
                   decouple_wd=decouple_wd, decouple_lr=decouple_lr,
                   kahan_sum=kahan_sum, foreach=foreach)
    return RAdamOptimiOptimizer(params, **hparams)


In [None]:
#|export
def radam(
    mom: float = 0.9,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-5,  # Added for numerical stability
    wd: float = 0.,  # Optional weight decay (decoupled or L2)
    decouple_wd: bool = True,  # Apply decoupled weight decay (RAdamW) or L2 regularization (RAdam)
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> RAdamOptimiOptimizer:
    "Partial function for the RAdam optimizer with low precision and foreach implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    radam_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd,
                        decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach)
    return partialler(RAdam, **radam_kwargs)


### Ranger -

In [None]:
#|export
def Ranger(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    mom: float = 0.95,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-6,  # Added for numerical stability
    wd: float = 0.01,  # Optional weight decay (decoupled or L2)
    k: int = 6,  # How often to conduct Lookahead step
    alpha: float = 0.5,  # Slow weight moving average coefficient
    decouple_wd: bool = True,  # Apply decoupled weight decay (RangerW) or L2 regularization (Ranger)
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> RangerOptimiOptimizer:
    "Convenience method for `Lookahead` with `RAdam` with low precision and foreach implementations"
    # Ranger only has an optimi implementation, so fail fast when optimi is missing
    if not OPTIMI:
        raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    hparams = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, k=k, alpha=alpha,
                   decouple_wd=decouple_wd, decouple_lr=decouple_lr,
                   kahan_sum=kahan_sum, foreach=foreach)
    return RangerOptimiOptimizer(params, **hparams)


In [None]:
#|export
def ranger(
    mom: float = 0.95,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-6,  # Added for numerical stability
    wd: float = 0.01,  # Optional weight decay (decoupled or L2)
    k: int = 6,  # How often to conduct Lookahead step
    alpha: float = 0.5,  # Slow weight moving average coefficient
    decouple_wd: bool = True,  # Apply decoupled weight decay (RangerW) or L2 regularization (Ranger)
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> RangerOptimiOptimizer:
    "Partial function for the Ranger optimizer using RAdam with low precision and foreach implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    ranger_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, k=k, alpha=alpha,
                         decouple_wd=decouple_wd, decouple_lr=decouple_lr,
                         kahan_sum=kahan_sum, foreach=foreach)
    return partialler(Ranger, **ranger_kwargs)


### StableAdamW -

In [None]:
#|export
def StableAdamW(
    params: Listified[Tensor],  # Model parameters or parameter groups
    lr: float,  # Default learning rate
    mom: float = 0.9,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-5,  # Added for numerical stability
    wd: float = 0.01,  # Optional weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> StableAdamWOptimiOptimizer:
    "A fastai-compatible StableAdamW optimizer with low precision and foreach implementations"
    # StableAdamW only has an optimi implementation, so fail fast when optimi is missing
    if not OPTIMI:
        raise ImportError('optimi package not found. Run `pip install torch-optimi`.')
    hparams = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                   decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach)
    return StableAdamWOptimiOptimizer(params, **hparams)


In [None]:
#|export
def stableadamw(
    mom: float = 0.9,  # Gradient moving average (β1) coefficient
    sqr_mom: float = 0.99,  # Gradient squared moving average (β2) coefficient
    eps: float = 1e-5,  # Added for numerical stability
    wd: float = 0.01,  # Optional weight decay
    decouple_lr: bool = False,  # Apply fully decoupled weight decay
    kahan_sum: bool | None = None,  # More accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters.
    foreach: bool | None = None,  # Use faster ForEach implementation. If unspecified, tries to use foreach over for-loop implementation.
) -> StableAdamWOptimiOptimizer:
    "Partial function for the StableAdamW optimizer with low precision and foreach implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    stableadamw_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                              decouple_lr=decouple_lr, kahan_sum=kahan_sum, foreach=foreach)
    return partialler(StableAdamW, **stableadamw_kwargs)

## fastai with Eight-bit Optimizers

These optimizers use either the fastai implementation or the bitsandbytes eight-bit implementation.

### Larc -

In [None]:
#|export
def Larc(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    clip:bool=True, # LARC if clip=True, LARS if clip=False
    trust_coeff:float=0.02, # Trust coefficient for calculating layerwise LR
    eps:float=1e-8, # Added for numerical stability. Not passed to the eight-bit implementation
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization. `eightbit=True` only supports L2: `decouple_wd=False`
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports LARS: `clip=False`
    hide_warning:bool=False, # Hide the warning about falling back to the fastai implementation
    **eightbitargs # Additional eight-bit arguments. See `LARS8bitOptimizer` for details. Ignored if `eightbit=False`
) -> Optimizer|LARS8bitOptimizer:
    "A fastai LARC/LARS optimizer with eight-bit implementations"
    if eightbit:
        if EIGHTBIT:
            # The eight-bit path is LARS-only and applies L2 weight decay
            if clip:
                raise NotImplementedError(f'{eightbit=} only supports the LARS optimizer. Set `clip=False`.')
            if decouple_wd and wd > 0:
                raise NotImplementedError(f'8-bit LARS only supports L2 weight decay: {decouple_wd=}')
            return LARS8bitOptimizer(params, lr=lr, mom=mom, wd=wd, trust_coeff=trust_coeff, **eightbitargs)
        else:
            raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')
    else:
        if not hide_warning:
            warn("fastxtend doesn't have a non-eight-bit Larc implementation, using the"
                 " fastai implementation. Pass `hide_warning=True` to hide this message.")
        # Assemble the fastai callback pipeline: weight decay, optional momentum, then the LARC/LARS step
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        if mom!=0.: cbs.append(average_grad)
        cbs += [partial(larc_layer_lr, clip=clip), larc_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)

In [None]:
#|export
def larc(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    clip:bool=True, # LARC if clip=True, LARS if clip=False
    trust_coeff:float=0.02, # Trust coefficient for calculating layerwise LR
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports LARS
    hide_warning:bool=False, # Hide warning
    **eightbitargs # Additional eight-bit arguments. See `LARS8bitOptimizer` for details.
) -> Optimizer|LARS8bitOptimizer:
    "Partial function for the LARC/LARS optimizer with fastai and eight-bit implementations"
    # Forward `eightbitargs` as well, otherwise eight-bit options passed here would be silently
    # dropped before reaching `Larc`; this matches `lamb` and `rmsprop`
    return partialler(Larc, mom=mom, clip=clip, eps=eps, trust_coeff=trust_coeff,
                      wd=wd, decouple_wd=decouple_wd, eightbit=eightbit,
                      hide_warning=hide_warning, **eightbitargs)

### Lamb -

In [None]:
#|export
def Lamb(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization. `eightbit=True` only supports decoupled weight decay
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports decoupled weight decay
    hide_warning:bool=False, # Hide the fastai fallback warning. With `eightbit=True`, also skips the bitsandbytes<=0.43.1 check
    **eightbitargs # Additional eight-bit arguments. See `LAMB8bitOptimizer` for details. Ignored if `eightbit=False`
) -> Optimizer|LAMB8bitOptimizer:
    "A fastai LAMB optimizer with eight-bit implementations"
    if eightbit:
        if EIGHTBIT:
            # Known-broken bitsandbytes versions require an explicit opt-in via `hide_warning`
            if parse(bitsandbytes.__version__) <= parse('0.43.1') and not hide_warning:
                raise ValueError("8-bit LAMB in bitsandbytes<=0.43.1 errors out on weights too small to quantize. "
                                 "Pass `hide_warning=True` to ignore and use anyway.")
            if not decouple_wd and wd > 0:
                raise NotImplementedError(f'8-bit LAMB only supports Decoupled weight decay: {decouple_wd=}')
            return LAMB8bitOptimizer(params, lr=lr, mom=mom, sqr_mom=sqr_mom,
                                     eps=eps, wd=wd, **eightbitargs)
        else:
            raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')
    else:
        if not hide_warning:
            warn("fastxtend doesn't have a non-eight-bit Lamb implementation, using the"
                 " fastai implementation. Pass `hide_warning=True` to hide this message.")
        # fastai LAMB pipeline: weight decay, dampened first moment, second moment, step count, LAMB step
        cbs = [weight_decay] if decouple_wd else [l2_reg]
        cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|export
def lamb(
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports decoupled weight decay
    hide_warning:bool=False, # Hide warning
    **eightbitargs # Additional eight-bit arguments. See `LAMB8bitOptimizer` for details.
) -> Optimizer|LAMB8bitOptimizer:
    "Partial function for the LAMB optimizer with fastai and eight-bit implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    lamb_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, decouple_wd=decouple_wd,
                       eightbit=eightbit, hide_warning=hide_warning)
    return partialler(Lamb, **lamb_kwargs, **eightbitargs)

### RMSProp -

In [None]:
#|export
def RMSProp(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0., # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization. `eightbit=True` only supports L2: `decouple_wd=False`
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports L2 weight decay and `mom=0`
    hide_warning:bool=False, # Hide the warning about falling back to the fastai implementation
    **eightbitargs # Additional eight-bit arguments. See `RMSProp8bitOptimizer` for details. Ignored if `eightbit=False`
) -> Optimizer|RMSProp8bitOptimizer:
    "A fastai RMSProp/RMSPropW optimizer with eight-bit implementations"
    if not eightbit:
        if not hide_warning:
            warn("fastxtend doesn't have a non-eight-bit RMSProp implementation, using the"
                 " fastai implementation. Pass `hide_warning=True` to hide this message.")
        # fastai pipeline: weight decay, gradient statistics (first moment only with momentum), then the step
        decay_cb = weight_decay if decouple_wd else l2_reg
        stat_cbs = [average_sqr_grad] if mom == 0. else [average_grad, average_sqr_grad]
        return Optimizer(params, [decay_cb, *stat_cbs, rms_prop_step],
                         lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, eps=eps)

    if not EIGHTBIT:
        raise ImportError(f'{eightbit=}. bitsandbytes package not found. Run `pip install bitsandbytes`.')
    # Eight-bit path: L2 weight decay only and no momentum
    if decouple_wd and wd > 0:
        raise NotImplementedError(f'8-bit RMSProp only supports L2 weight decay: {decouple_wd=}')
    if mom > 0:
        raise NotImplementedError(f'8-bit RMSProp does not use momentum: {mom=}')
    return RMSProp8bitOptimizer(params, lr=lr, sqr_mom=sqr_mom, eps=eps, wd=wd, **eightbitargs)

In [None]:
#|export
def rmsprop(
    mom:float=0., # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay or L2 regularization
    eightbit:bool=False, # Use bitsandbytes' eight-bit implementation. Only supports L2 weight decay and `mom=0`
    hide_warning:bool=False, # Hide warning
    **eightbitargs # Additional eight-bit arguments. See `RMSProp8bitOptimizer` for details.
) -> Optimizer|RMSProp8bitOptimizer:
    "Partial function for the RMSProp/RMSPropW optimizer with fastai and eight-bit implementations"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    rmsprop_kwargs = dict(mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd,
                          decouple_wd=decouple_wd, eightbit=eightbit,
                          hide_warning=hide_warning)
    return partialler(RMSProp, **rmsprop_kwargs, **eightbitargs)

## fastai-only Optimizers

This optimizer only has a fastai implementation.

In [None]:
#|export
def QHAdam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.999, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.999, # Gradient squared moving average (β2) coefficient
    nu_1:float=0.7, # QH immediate discount factor
    nu_2:float=1.0, # QH momentum discount factor
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay (QHAdamW) or L2 regularization (QHAdam)
    hide_warning:bool=False, # Hide the warning about using the fastai implementation
) -> Optimizer:
    "The fastai QHAdam/QHAdamW optimizer"
    if not hide_warning:
        warn("fastxtend doesn't have a QHAdam implementation, using the fastai"
             " implementation. Pass `hide_warning=True` to hide this message.")
    # Callback order: weight decay, dampened first moment, second moment, step count, QHAdam step
    decay_cb = weight_decay if decouple_wd else l2_reg
    cbs = [decay_cb, partial(average_grad, dampening=True), average_sqr_grad, step_stat, qhadam_step]
    return Optimizer(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
#|export
def qhadam(
    mom:float=0.999, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.999, # Gradient squared moving average (β2) coefficient
    nu_1:float=0.7, # QH immediate discount factor
    nu_2:float=1.0, # QH momentum discount factor
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (decoupled or L2)
    decouple_wd:bool=True, # Apply decoupled weight decay (QHAdamW) or L2 regularization (QHAdam)
    hide_warning:bool=False, # Hide warning
) -> Optimizer:
    "Partial function for the fastai QHAdam/QHAdamW optimizer"
    # Bind the hyperparameters now; `params` and `lr` are left for the caller (e.g. a fastai `Learner`)
    qhadam_kwargs = dict(mom=mom, sqr_mom=sqr_mom, nu_1=nu_1, nu_2=nu_2, eps=eps,
                         wd=wd, decouple_wd=decouple_wd, hide_warning=hide_warning)
    return partialler(QHAdam, **qhadam_kwargs)

## Tests -

In [None]:
#|hide
#|cuda
import inspect
from itertools import product

from torch.utils.data import TensorDataset

from fastai.basics import default_device
from fastai.data.core import TfmdDL, DataLoaders
from fastai.learner import Learner

In [None]:
#|hide
#|cuda
class MLP(torch.nn.Module):
    "Small test model: Linear -> Mish -> LayerNorm -> Linear with a single output column"
    def __init__(self, input_size, hidden_size, device, dtype):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=True, device=device, dtype=dtype)
        self.act = torch.nn.Mish()
        # Build the norm with the same device/dtype as the linear layers so the model is
        # consistent for any requested dtype, not only LayerNorm's float32 default
        self.norm = torch.nn.LayerNorm(hidden_size, device=device, dtype=dtype)
        self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        return self.fc2(self.norm(self.act(self.fc1(x))))

In [None]:
#|hide
#|cuda
def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):
    "Synthetic regression `DataLoaders`: 256 input features and a single-column target"
    def get_data(n):
        x = torch.randn(bs*n, 256)
        # Reduce the features to one column so the target is (N, 1), matching the model's
        # output; `a*x + b` alone is (N, 256) and would be broadcast against it by `MSELoss`
        y = a*x.mean(dim=1, keepdim=True) + b + 0.1*torch.randn(bs*n, 1)
        return TensorDataset(x, y)
    train_ds = get_data(n_train)
    valid_ds = get_data(n_valid)
    device = default_device() if cuda else None
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)
    valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)
    return DataLoaders(train_dl, valid_dl, device=device)


def synth_learner(n_trn=10, n_val=2, optimizer=None, cuda=False, lr=1e-3, **kwargs):
    "Build a `Learner` on `synth_dbunch` data with an `MLP`; `optimizer` is a fastai `opt_func` (default: `sgd()`)"
    # `sgd` itself is a factory that *returns* an opt_func, so it can't be the default directly:
    # fastai would call it as `sgd(params, lr=...)` and misbind the arguments
    opt_func = sgd() if optimizer is None else optimizer
    data = synth_dbunch(n_train=n_trn, n_valid=n_val, cuda=cuda)
    model = MLP(256, 512, device=default_device() if cuda else None, dtype=torch.float32)
    return Learner(data, model, lr=lr, opt_func=opt_func, loss_func=nn.MSELoss(), **kwargs)

In [None]:
#|hide
#|cuda
def dict_product(params):
    "Yield one dict per combination of the value lists in `params`, keyed like `params`"
    names = list(params)
    for values in product(*(params[name] for name in names)):
        yield {name: value for name, value in zip(names, values)}

In [None]:
#|hide
#|cuda
def filter_arguments(optimizer, args):
    "Keep only the entries of `args` whose key names a parameter in `optimizer`'s signature"
    accepted = set(inspect.signature(optimizer).parameters)
    return dict((name, value) for name, value in args.items() if name in accepted)

In [None]:
#|hide
#|cuda

# Since all of these optimizers have consistency tests in their respective libraries
# and most in their notebooks, this test makes sure the integration works without errors.
def test_optimizers(optimizers, eightbit, optimi):
    "Fit a small `synth_learner` with each optimizer factory across a grid of arguments"
    params = {
        'wd': [0, 1e-2],
        'decouple_wd': [True, False],
        'decouple_lr': [True, False],
        'foreach': [True, False],
        'eightbit': [True, False]
    }
    # Drop grid axes the optimizers under test don't support
    if not eightbit:
        params.pop('eightbit')
    if not optimi:
        params.pop('foreach')
    for optimizer in optimizers:
        for args in dict_product(params):
            # Skip eight-bit runs combined with `decouple_lr` or `foreach`: the eight-bit paths
            # either reject `decouple_lr` or don't take it, and never receive `foreach`, so
            # these combinations only repeat other eight-bit runs. (`pass` here was a no-op.)
            if args.get('eightbit', False) and (args['decouple_lr'] or args.get('foreach', False)):
                continue
            opt = optimizer(**filter_arguments(optimizer, args))
            learn = synth_learner(optimizer=opt, cuda=torch.cuda.is_available())
            try:
                with learn.no_logging():
                    learn.fit(5)
                    if args.get('eightbit', False):
                        # Confirm the eight-bit path was used: its optimizer state is stored as uint8
                        assert learn.opt.state[next(learn.model.parameters())]['state1'].dtype == torch.uint8
            # Unsupported argument combinations raise NotImplementedError; those are expected
            except NotImplementedError:
                pass

In [None]:
#|hide
#|cuda
# Optimizers with both optimi and bitsandbytes eight-bit implementations: test both argument grids
test_optimizers((adam, lion, sgd), eightbit=True, optimi=True)

In [None]:
#|hide
#|cuda
# optimi-only optimizers: no eight-bit axis in the argument grid
test_optimizers((adan, radam, ranger, stableadamw), eightbit=False, optimi=True)

In [None]:
#|hide
#|cuda

# lamb's eight-bit path doesn't work with bitsandbytes <= 0.43.1 (`Lamb` raises a ValueError
# for those versions unless `hide_warning=True`), so only include lamb on newer versions
if parse(bitsandbytes.__version__) > parse('0.43.1'):
    optimizers = (larc, lamb, rmsprop)
else:
    optimizers = (larc, rmsprop)
# fastai-or-eight-bit optimizers: no optimi `foreach` axis in the argument grid
test_optimizers(optimizers, eightbit=True, optimi=False)

In [None]:
#|hide
#|cuda
# fastai-only optimizer: neither the eight-bit nor the optimi `foreach` axis applies
test_optimizers((qhadam,), eightbit=False, optimi=False)