In [None]:
#|default_exp optimizer.eightbit
#|default_cls_lvl 2

In [None]:
#|exporti
# Contains code from:
# bitsandbytes - MIT License - Copyright (c) Facebook, Inc. and its affiliates.
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# 8-Bit Optimizers
> bitsandbytes 8-bit optimizers with full fastai compatibility

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizers can reduce optimizer memory usage up to 75% compared to 32-bit optimizers.

While it is possible to use bitsandbytes optimizers[^PyTorch] with fastai via `fastai.optimizer.OptimWrapper`, this doesn't provide compatibility with all fastai optimizer features. fastxtend adds full fastai compatibility to bitsandbytes 8-bit optimizers, including per-parameter weight decay, automatic weight decay exclusion for normalization and bias terms, and discriminative learning rate support.

To use 8-bit optimizers, install bitsandbytes on a machine with a CUDA device:

```bash
pip install bitsandbytes
```

then import fastxtend optimizers after importing fastai

```python
from fastxtend.vision.all import *
# or just import fastxtend optimizers
from fastxtend.optimizer.all import *

opt_func = adam(eightbit=True)
Learner(..., opt_func=opt_func)
```

If training NLP models, you may need to replace the PyTorch embedding layer with a bitsandbytes layer: `torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`.

Check out the bitsandbytes [readme](https://github.com/TimDettmers/bitsandbytes#using-the-8-bit-optimizers) for more details on using 8-bit optimizers.

:::{.callout-note collapse="true"}
#### Note: Modification of Synchronize
bitsandbytes calls [`torch.cuda.synchronize`](https://pytorch.org/docs/stable/generated/torch.cuda.synchronize.html) after each optimizer step. This prevents starting the next optimizer step until the current step finishes, which may increase optimizer wallclock time.

fastxtend adds `sync_each_step=False` as an argument to all 8-bit optimizers, disabling the per-step `torch.cuda.synchronize`. Set `sync_each_step=True` to match bitsandbytes behavior.
:::

:::{.callout-note collapse="true"}
#### Note: Suppress Import Warnings
fastxtend suppresses the bitsandbytes import warning message. To view the message, import bitsandbytes separately.
:::

[^PyTorch]: Or any PyTorch-compatible optimizer.

In [None]:
#|export
from __future__ import annotations

from bitsandbytes.optim.optimizer import Optimizer1State, Optimizer2State, MockArgs
import bitsandbytes.functional as BF

from fastcore.basics import even_mults

from fastxtend.imports import *

In [None]:
#|hide
from nbdev.showdoc import *

## fastai and bitsandbytes Compatibility

In [None]:
#|exporti
def _convert_params(o:list, **defaults) -> list:
    "Convert fastai param_lists to PyTorch param_groups, adding defaults if group doesn't have it"
    splitter = []
    for group in o:
        if isinstance(group, dict):
            splitter.append({**defaults, **group})
        else:
            splitter.append({'params':group, **defaults})
    return splitter

In [None]:
#|exporti
class EightBitFastaiAdapter:
    "Base for adding fastai optimizer functionality to EightBit Optimizers"
    _keep_on_clear = ['force_train', 'do_wd']
    def get_config(self, gindex, pindex, group):
        config = {}
        config["mom"] = group["mom"]
        config["sqr_mom"] = group["sqr_mom"]
        config["wd"] = group["wd"]
        config["eps"] = group["eps"]
        config["lr"] = group["lr"]
        config["optim_bits"] = self.args.optim_bits
        config["min_8bit_size"] = self.args.min_8bit_size
        config["percentile_clipping"] = self.args.percentile_clipping
        config["block_wise"] = self.args.block_wise
        config["max_unorm"] = self.args.max_unorm
        config["skip_zeros"] = self.args.skip_zeros

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def get_params(self,
        n:slice|int=slice(None), # Extended slicing over the optimizer `param_lists`
        with_grad:bool=False # Get all param tuples. If `True` select only those with a gradient
    ):
        "Slice of parameters and parameter states"
        return L((p, self.state[p]) for pg in self.param_groups[n] for p in pg['params']
                    if (hasattr(p, 'grad') and p.grad is not None) or with_grad==False)

    def clear_state(self):
        "Reset the state of the optimizer"
        for p,state in self.get_params():
            self.state[p] = {k: state[k] for k in self._keep_on_clear if k in state}

    def _set_require_grad(self,
        rg:bool, # Requires grad: if `True` sets gradient for parameters, else uses state `state["force_train"]`
        p:Tensor, # Parameter to set gradient
        state:dict, # Parameter's state dict
    ):
        p.requires_grad_(rg or state.get('force_train', False))

    def freeze_to(self, n:int):
        "Freeze parameter groups up to `n`"
        self.frozen_idx = n if n >= 0 else len(self.param_groups) + n
        if self.frozen_idx >= len(self.param_groups):
            warn(f"Freezing {self.frozen_idx} groups; model has {len(self.param_groups)}; whole model is frozen.")
        for o in self.get_params(slice(n, None)):
            self._set_require_grad(True, *o)
        for o in self.get_params(slice(None, n)):
            self._set_require_grad(False, *o)

    def freeze(self):
        "Freeze up to last parameter group"
        assert(len(self.param_groups) > 1)
        self.freeze_to(-1)

    def unfreeze(self):
        "Unfreeze the entire model"
        self.freeze_to(0)

    @property
    def hypers(self):
        return [{k:v for k,v in pg.items() if k != 'params'} for pg in self.param_groups]

    def set_hypers(self, **kwargs):
        "`set_hyper` for all `kwargs`"
        L(kwargs.items()).starmap(self.set_hyper)

    def _set_hyper(self, k, v):
        "Set the value(s) in `v` for hyper-parameter `k`"
        for v_,h in zip(v, self.param_groups):
            h[k] = v_

    def set_hyper(self, k, v):
        "Set the value(s) in `v` for hyper-parameter `k`"
        if isinstance(v, slice):
            if v.start:
                v = even_mults(v.start, v.stop, len(self.param_groups))
            else:
                v = [v.stop/10]*(len(self.param_groups)-1) + [v.stop]
        v = L(v, use_list=None)
        if len(v)==1:
            v = v*len(self.param_groups)
        assert len(v) == len(self.param_groups), f"Trying to set {len(v)} values for {k} but there are {len(self.param_groups)} parameter groups."
        self._set_hyper(k, v)

    @property
    def param_lists(self):
        return [pg['params'] for pg in self.param_groups]

    @param_lists.setter
    def param_lists(self, v):
        for pg,v_ in zip(self.param_groups,v):
            pg['params'] = v_

In [None]:
show_doc(EightBitFastaiAdapter)

In [None]:
#|exporti
class EightBitCommon:
    "Shared `step` for EightBit Optimizers with optional synchronization after each parameter update"
    @torch.no_grad()
    def step(self, closure=None):
        """Update every parameter which has a gradient.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        if closure is None:
            loss = None
        else:
            # `step` runs under `no_grad`, so gradients are re-enabled for the closure
            with torch.enable_grad():
                loss = closure()

        # One-time setup on the first call
        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        #if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is not None:
                    # Initialize state for parameters which have no 'step' entry yet
                    if 'step' not in self.state[p]:
                        self.init_state(group, p, gindex, pindex)

                    self.prefetch_state(p)
                    self.update_step(group, p, gindex, pindex)
                    if self.sync_each_step:
                        torch.cuda.synchronize()

        # all paged operation are asynchronous, we need to sync to make sure
        # all tensors are in the right state; the same single sync is issued
        # when synchronizing after each parameter update is turned off
        if self.is_paged or not self.sync_each_step:
            torch.cuda.synchronize()

        return loss

In [None]:
show_doc(EightBitCommon)

In [None]:
#|exporti
class EightBit1StateOptimizer(EightBitCommon, EightBitFastaiAdapter, Optimizer1State):
    "Gives `Optimizer1State` fastai optimizer functionality & compatibility, using fastai hyperparameter names"
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        mom=0.9,
        sqr_mom=0.0,
        eps=1e-8,
        wd=0.0,
        optim_bits=8,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False,
        sync_each_step=False
    ):
        # Reject out-of-range hyperparameters up front
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= mom < 1.0:
            raise ValueError(f"Invalid mom value: {mom}")
        if not 0.0 <= sqr_mom < 1.0:
            raise ValueError(f"Invalid sqr_mom value: {sqr_mom}")
        if not 0.0 <= wd:
            raise ValueError(f"Invalid weight_decay value: {wd}")

        defaults = {'lr': lr, 'mom': mom, 'sqr_mom': sqr_mom, 'eps': eps, 'wd': wd}
        params = L(params)
        # A list of parameter lists (fastai style) is converted to param_groups carrying `defaults`
        if isinstance(params[0], (L,list)):
            params = _convert_params(params, **defaults)
        # Initialize starting from the class after `Optimizer1State` in the MRO, skipping `Optimizer1State.__init__`
        super(Optimizer1State, self).__init__(params, defaults, optim_bits, is_paged)

        # Without user supplied `args`, collect the 8-bit settings into a `MockArgs`
        if args is None:
            args = MockArgs({
                "optim_bits": optim_bits,
                "min_8bit_size": min_8bit_size,
                "percentile_clipping": percentile_clipping,
                "block_wise": block_wise,
                "max_unorm": max_unorm,
                "skip_zeros": skip_zeros,
            })
        self.args = args

        self.optimizer_name = optimizer_name
        self.sync_each_step = sync_each_step

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        "Update `p` with the bitsandbytes routine selected by the dtype of `state1` and the `block_wise` setting"
        state, grad = self.state[p], p.grad
        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        gnorm_scale = 1.0
        if config["percentile_clipping"] < 100:
            _, _, gnorm_scale = BF.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )

        # A `wd` stored in the parameter's state overrides the config value; `do_wd=False` disables weight decay
        wd = state.get('wd', config['wd']) if state.get('do_wd', True) else 0.0
        state1 = state["state1"]

        if state1.dtype == torch.float:
            # 32-bit optimizer state
            BF.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state1,
                config["mom"],
                config["eps"],
                step,
                config["lr"],
                None,
                config['sqr_mom'],
                wd,
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )
        elif state1.dtype == torch.uint8:
            if config["block_wise"]:
                # 8-bit block-wise state; the second-state arguments are None for one-state optimizers
                BF.optimizer_update_8bit_blockwise(
                    self.optimizer_name,
                    grad,
                    p,
                    state1,
                    None,
                    config["mom"],
                    config['sqr_mom'],
                    config["eps"],
                    step,
                    config["lr"],
                    state["qmap1"],
                    None,
                    state["absmax1"],
                    None,
                    wd,
                    gnorm_scale=gnorm_scale,
                    skip_zeros=config["skip_zeros"],
                )
            else:
                # 8-bit state without block-wise quantization
                BF.optimizer_update_8bit(
                    self.optimizer_name,
                    grad,
                    p,
                    state1,
                    None,
                    config["mom"],
                    config['sqr_mom'],
                    config["eps"],
                    step,
                    config["lr"],
                    state["qmap1"],
                    None,
                    state["max1"],
                    None,
                    state["new_max1"],
                    None,
                    wd,
                    gnorm_scale=gnorm_scale,
                    unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                    max_unorm=config["max_unorm"],
                )

                # swap maxes
                state["max1"], state["new_max1"] = state["new_max1"], state["max1"]

In [None]:
show_doc(EightBit1StateOptimizer)

In [None]:
#|exporti
class EightBit2StateOptimizer(EightBitCommon, EightBitFastaiAdapter, Optimizer2State):
    "Gives `Optimizer2State` fastai optimizer functionality & compatibility, using fastai hyperparameter names"
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        mom=0.9,
        sqr_mom=0.999,
        eps=1e-8,
        wd=0.0,
        optim_bits=8,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False,
        sync_each_step=False
    ):
        # Reject out-of-range hyperparameters up front
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= mom < 1.0:
            raise ValueError(f"Invalid mom value: {mom}")
        if not 0.0 <= sqr_mom < 1.0:
            raise ValueError(f"Invalid sqr_mom value: {sqr_mom}")
        if not 0.0 <= wd:
            raise ValueError(f"Invalid weight_decay value: {wd}")

        defaults = {'lr': lr, 'mom': mom, 'sqr_mom': sqr_mom, 'eps': eps, 'wd': wd}
        params = L(params)
        # A list of parameter lists (fastai style) is converted to param_groups carrying `defaults`
        if isinstance(params[0], (L,list)):
            params = _convert_params(params, **defaults)
        # Initialize starting from the class after `Optimizer2State` in the MRO, skipping `Optimizer2State.__init__`
        super(Optimizer2State, self).__init__(params, defaults, optim_bits, is_paged)

        # Without user supplied `args`, collect the 8-bit settings into a `MockArgs`
        if args is None:
            args = MockArgs({
                "optim_bits": optim_bits,
                "min_8bit_size": min_8bit_size,
                "percentile_clipping": percentile_clipping,
                "block_wise": block_wise,
                "max_unorm": max_unorm,
                "skip_zeros": skip_zeros,
            })
        self.args = args

        self.optimizer_name = optimizer_name
        self.sync_each_step = sync_each_step

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        "Update `p` with the bitsandbytes routine selected by the dtype of `state1` and the `block_wise` setting"
        state, grad = self.state[p], p.grad
        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        gnorm_scale = 1.0
        if config["percentile_clipping"] < 100:
            _, _, gnorm_scale = BF.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )

        # A `wd` stored in the parameter's state overrides the config value; `do_wd=False` disables weight decay
        wd = state.get('wd', config['wd']) if state.get('do_wd', True) else 0.0
        state1 = state["state1"]

        if state1.dtype == torch.float:
            # 32-bit optimizer state
            BF.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state1,
                config["mom"],
                config["eps"],
                step,
                config["lr"],
                state["state2"],
                config["sqr_mom"],
                wd,
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )
        elif state1.dtype == torch.uint8:
            if config["block_wise"]:
                # 8-bit block-wise state
                BF.optimizer_update_8bit_blockwise(
                    self.optimizer_name,
                    grad,
                    p,
                    state1,
                    state["state2"],
                    config["mom"],
                    config['sqr_mom'],
                    config["eps"],
                    step,
                    config["lr"],
                    state["qmap1"],
                    state["qmap2"],
                    state["absmax1"],
                    state["absmax2"],
                    wd,
                    gnorm_scale=gnorm_scale,
                    skip_zeros=config["skip_zeros"],
                )
            else:
                # 8-bit state without block-wise quantization
                BF.optimizer_update_8bit(
                    self.optimizer_name,
                    grad,
                    p,
                    state1,
                    state["state2"],
                    config["mom"],
                    config['sqr_mom'],
                    config["eps"],
                    step,
                    config["lr"],
                    state["qmap1"],
                    state["qmap2"],
                    state["max1"],
                    state["max2"],
                    state["new_max1"],
                    state["new_max2"],
                    wd,
                    gnorm_scale=gnorm_scale,
                    unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                    max_unorm=config["max_unorm"],
                )

                # swap maxes
                state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
                state["max2"], state["new_max2"] = state["new_max2"], state["max2"]

In [None]:
show_doc(EightBit2StateOptimizer)

## 8-bit Optimizers

In [None]:
#|export
class SGD8bitOptimizer(EightBit1StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit SGD with momentum optimizer"
    def __init__(
        self,
        params,
        lr,
        mom,
        wd=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        sync_each_step=False
    ):
        # Zero momentum is rejected
        if mom == 0:
            raise NotImplementedError(f"8-bit SGD without momentum {mom=} is not supported")
        super().__init__(
            optimizer_name="momentum",
            params=params,
            lr=lr,
            mom=mom,
            sqr_mom=0.0,
            eps=0.0,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=block_wise,
            sync_each_step=sync_each_step,
        )

In [None]:
#|export
class RMSProp8bitOptimizer(EightBit1StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit RMSProp optimizer"
    def __init__(
        self,
        params,
        lr=1e-2,
        sqr_mom=0.99,
        eps=1e-8,
        wd=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        sync_each_step=False
    ):
        if sqr_mom == 0:
            raise NotImplementedError(f"8-bit RMSProp with {sqr_mom=} is not supported")
        # `sqr_mom` is handed to the base class in its `mom` slot, with the base `sqr_mom` set to zero
        super().__init__(
            optimizer_name="rmsprop",
            params=params,
            lr=lr,
            mom=sqr_mom,
            sqr_mom=0,
            eps=eps,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=block_wise,
            sync_each_step=sync_each_step,
        )

In [None]:
#|export
class AdamW8bitOptimizer(EightBit2StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit AdamW optimizer, optionally paged"
    def __init__(self,
        params,
        lr=1e-3,
        mom=0.9,
        sqr_mom=0.99,
        eps=1e-8,
        wd=1e-2,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
        sync_each_step=False
    ):
        super().__init__(
            optimizer_name="adam",
            params=params,
            lr=lr,
            mom=mom,
            sqr_mom=sqr_mom,
            eps=eps,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=block_wise,
            is_paged=is_paged,
            sync_each_step=sync_each_step,
        )

In [None]:
#|export
class LARS8bitOptimizer(EightBit1StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit LARS optimizer. Requires non-zero `mom`; the default `mom=0` raises"
    def __init__(
        self,
        params,
        lr,
        mom=0,
        wd=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        trust_coeff=0.02,
        sync_each_step=False
    ):
        if mom == 0:
            raise NotImplementedError(f"8-bit LARS without momentum {mom=} is not supported")
        # `trust_coeff` is passed as the base class `max_unorm`, and block-wise quantization is turned off
        super().__init__(
            optimizer_name="lars",
            params=params,
            lr=lr,
            mom=mom,
            sqr_mom=0.0,
            eps=0.0,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=False,
            max_unorm=trust_coeff,
            sync_each_step=sync_each_step,
        )

In [None]:
#|export
class LAMB8bitOptimizer(EightBit2StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit LAMB optimizer"
    def __init__(
        self,
        params,
        lr=1e-3,
        mom=0.9,
        sqr_mom=0.999,
        eps=1e-8,
        wd=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        sync_each_step=False
    ):
        # `max_unorm` is fixed at 1.0 for LAMB
        super().__init__(
            optimizer_name="lamb",
            params=params,
            lr=lr,
            mom=mom,
            sqr_mom=sqr_mom,
            eps=eps,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=block_wise,
            max_unorm=1.0,
            sync_each_step=sync_each_step,
        )

In [None]:
#|export
class Lion8bitOptimizer(EightBit1StateOptimizer):
    "A fastai-compatible bitsandbytes 8-bit Lion optimizer, optionally paged"
    def __init__(self,
        params,
        lr=1e-4,
        beta1=0.9,
        beta2=0.99,
        wd=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
        sync_each_step=False
    ):
        # `beta1` and `beta2` fill the base class `mom` and `sqr_mom` slots, and `eps` is set to zero
        super().__init__(
            optimizer_name="lion",
            params=params,
            lr=lr,
            mom=beta1,
            sqr_mom=beta2,
            eps=0.,
            wd=wd,
            optim_bits=8,
            args=args,
            min_8bit_size=min_8bit_size,
            percentile_clipping=percentile_clipping,
            block_wise=block_wise,
            is_paged=is_paged,
            sync_each_step=sync_each_step,
        )

## fastai Compatibility Tests -

In [None]:
#|hide
from fastxtend.test_utils import *

In [None]:
#|hide
# Freezing and unfreezing should toggle `requires_grad` per parameter group for both optimizer bases
params = [tst_params(), tst_params(), tst_params()]
req_grad = Self.requires_grad()
for opt in (EightBit1StateOptimizer('momentum', params, lr=0.1), EightBit2StateOptimizer('adam', params, lr=0.1)):
    # Freeze only the first parameter group
    opt.freeze_to(1)
    test_eq(L(params[0]).map(req_grad), [False]*4)
    for i in (1, 2):
        test_eq(L(params[i]).map(req_grad), [True]*4)

    # After unfreezing, the previously frozen group is trainable again
    opt.unfreeze()
    for i in range(2):
        test_eq(L(params[i]).map(req_grad), [True]*4)

In [None]:
#|hide
# Parameters flagged `force_train` must stay trainable inside a frozen group
# (`req_grad` is defined by the freeze/unfreeze test cell above)
params = [tst_params(), tst_params(), tst_params()]
for opt in (EightBit1StateOptimizer('momentum', params, lr=0.1), EightBit2StateOptimizer('adam', params, lr=0.1)):
    # Flag the second and fourth parameter of the middle group
    for p in L(params[1])[[1,3]]:
        opt.state[p] = dict(force_train=True)
    opt.freeze()
    expected = ([False]*4, [False, True, False, True], [True]*4)
    for group, flags in zip(params, expected):
        test_eq(L(group).map(req_grad), flags)

In [None]:
#|hide
# `set_hypers` should broadcast scalar hyper-parameters to every parameter group
params = [tst_params(), tst_params(), tst_params()]
new_hypers = dict(lr=3e-3, mom=0.98, sqr_mom=0.9999, eps=1e-6, wd=0.01)
for opt in (EightBit1StateOptimizer('momentum', params), EightBit2StateOptimizer('adam', params)):
    opt.set_hypers(**new_hypers)
    test_eq(opt.hypers, [new_hypers]*3)

## 8-bit Optimizer Tests -

In [None]:
#|hide
#|cuda
from fastai.optimizer import (weight_decay, l2_reg, average_grad, momentum_step,
                              average_sqr_grad, rms_prop_step, step_stat, adam_step,
                              larc_layer_lr, larc_step, lamb_step, Optimizer)

In [None]:
#|hide
#|cuda
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=5):
    "Allow up to `max_error_count` elements of `a` and `b` to not be close; beyond that, fail via `torch.testing.assert_close`"
    mismatches = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    if mismatches <= max_error_count:
        return
    print(f"Too many values not close: assert {mismatches} < {max_error_count}")
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

def tst_param(shape):
    "Create a random float32 CUDA tensor of `shape` scaled by 0.1, with a random gradient scaled by 0.01, for testing"
    param = 0.1 * torch.rand(shape, device='cuda', dtype=torch.float32)
    param.grad = 0.01 * torch.rand(shape, device='cuda', dtype=param.dtype)
    return param

def tst_params():
    "Create four test parameters of size 4096 plus an independent copy of each (values and gradients cloned)"
    originals = [tst_param(4096) for _ in range(4)]
    copies = []
    for src in originals:
        dup = src.clone()
        dup.grad = src.grad.clone()
        copies.append(dup)
    return originals, copies

In [None]:
#|hide
#|cuda
def SGD(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    wd:float=0., # Optional L2 weight decay
    eightbit:bool=False, # Use fused 8-bit implementation
    **eightbitargs
) -> Optimizer|SGD8bitOptimizer:
    "Reference fastai SGD with momentum and L2 regularization, or `SGD8bitOptimizer` if `eightbit`"
    if not eightbit:
        return Optimizer(params, [l2_reg, average_grad, momentum_step], lr=lr, mom=mom, wd=wd)
    return SGD8bitOptimizer(params, lr=lr, mom=mom, wd=wd, **eightbitargs)

In [None]:
#|hide
#|cuda
# Two consecutive SGD steps: the fastai reference and the 8-bit optimizer should stay approximately equal
params_org, params_8bit = tst_params()
opt_org = SGD(params_org, lr=0.01)
opt_8bit = SGD(params_8bit, lr=0.01, eightbit=True)
for step_idx in range(2):
    opt_org.step()
    opt_8bit.step()
    for p, e in zip(params_org, params_8bit):
        assert_most_approx_close(p, e)

In [None]:
#|hide
#|cuda
def RMSProp(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    eightbit:bool=False, # Use fused 8-bit implementation
    **eightbitargs
) -> Optimizer|RMSProp8bitOptimizer:
    "Reference fastai RMSProp with L2 regularization, or `RMSProp8bitOptimizer` if `eightbit`"
    if not eightbit:
        return Optimizer(params, [l2_reg, average_sqr_grad, rms_prop_step],
                         lr=lr, sqr_mom=sqr_mom, wd=wd, eps=eps)
    return RMSProp8bitOptimizer(params, lr=lr, sqr_mom=sqr_mom, eps=eps, wd=wd, **eightbitargs)

In [None]:
#|hide
#|cuda
# Two consecutive RMSProp steps: the fastai reference and the 8-bit optimizer should stay approximately equal
params_org, params_8bit = tst_params()
opt_org = RMSProp(params_org, lr=1e-3)
opt_8bit = RMSProp(params_8bit, lr=1e-3, eightbit=True)
for step_idx in range(2):
    opt_org.step()
    opt_8bit.step()
    for p, e in zip(params_org, params_8bit):
        assert_most_approx_close(p, e)

In [None]:
#|hide
#|cuda
def Adam(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0.01, # Optional weight decay (true or L2)
    eightbit:bool=False, # Use fused 8-bit implementation
    **eightbitargs
) -> Optimizer|AdamW8bitOptimizer:
    "Reference fastai Adam using the `weight_decay` callback, or `AdamW8bitOptimizer` if `eightbit`"
    if not eightbit:
        cbs = [weight_decay, partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return AdamW8bitOptimizer(params, lr=lr, mom=mom, sqr_mom=sqr_mom,
                              eps=eps, wd=wd, **eightbitargs)

In [None]:
#|hide
#|cuda
# Two consecutive Adam steps: the fastai reference and the 8-bit optimizer should stay approximately equal
params_org, params_8bit = tst_params()
opt_org = Adam(params_org, lr=1e-3)
opt_8bit = Adam(params_8bit, lr=1e-3, eightbit=True)
for step_idx in range(2):
    opt_org.step()
    opt_8bit.step()
    for p, e in zip(params_org, params_8bit):
        assert_most_approx_close(p, e)

In [None]:
#|hide
#|cuda
def Lars(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    trust_coeff:float=0.02, # Trust coefficient for calculating layerwise LR
    eps:float=1e-8, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    eightbit:bool=False, # Use fused 8-bit implementation. Only supports LARS: `clip=False`
    **eightbitargs
) -> Optimizer|LARS8bitOptimizer:
    "Reference fastai LARS (`clip=False`) with L2 regularization, or `LARS8bitOptimizer` if `eightbit`"
    if not eightbit:
        cbs = [l2_reg, average_grad, partial(larc_layer_lr, clip=False), larc_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)
    # `eps` is not forwarded to the 8-bit optimizer
    return LARS8bitOptimizer(params, lr=lr, mom=mom, wd=wd, trust_coeff=trust_coeff, **eightbitargs)

In [None]:
#|hide
#|cuda
# Two consecutive LARS steps: the fastai reference and the 8-bit optimizer should stay approximately equal
params_org, params_8bit = tst_params()
opt_org = Lars(params_org, lr=0.01)
opt_8bit = Lars(params_8bit, lr=0.01, eightbit=True)
for step_idx in range(2):
    opt_org.step()
    opt_8bit.step()
    for p, e in zip(params_org, params_8bit):
        assert_most_approx_close(p, e)

In [None]:
#|hide
#|cuda
def Lamb(
    params:Listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    mom:float=0.9, # Gradient moving average (β1) coefficient
    sqr_mom:float=0.99, # Gradient squared moving average (β2) coefficient
    eps:float=1e-5, # Added for numerical stability
    wd:float=0., # Optional weight decay (true or L2)
    eightbit:bool=False, # Use fused 8-bit implementation. Only supports true weight decay
    **eightbitargs
) -> Optimizer|LAMB8bitOptimizer:
    "Reference fastai LAMB using the `weight_decay` callback, or `LAMB8bitOptimizer` if `eightbit`"
    if not eightbit:
        cbs = [weight_decay, partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
        return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return LAMB8bitOptimizer(params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, **eightbitargs)

In [None]:
#|hide
#|cuda
# Two consecutive LAMB steps: the fastai reference and the 8-bit optimizer should stay approximately equal
params_org, params_8bit = tst_params()
opt_org = Lamb(params_org, lr=0.01)
opt_8bit = Lamb(params_8bit, lr=0.01, eightbit=True)
for step_idx in range(2):
    opt_org.step()
    opt_8bit.step()
    for p, e in zip(params_org, params_8bit):
        assert_most_approx_close(p, e)

## 8-bit Training Test -

In [None]:
#|hide
#|cuda
from packaging.version import parse
import fastai

from fastcore.basics import num_cpus

if parse(fastai.__version__) < parse('2.7.11'):
    from fastxtend.callback.channelslast import *
else:
    from fastai.callback.channelslast import *
from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.learner import Learner
from fastai.vision.augment import Resize
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastxtend.metrics import *
from fastxtend.vision.models.xresnet import xresnext34

In [None]:
#|hide
#|cuda
# End-to-end check: train XResNeXt34 on Imagenette (160px variant, resized to 128) with the 8-bit optimizer
imagenette = untar_data(URLs.IMAGENETTE_160)

with less_random():
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=Resize(128),
                       batch_tfms=[Normalize.from_stats(*imagenet_stats)])

    dls = dblock.dataloaders(imagenette, bs=64,
                             num_workers=num_cpus())

    # `Adam` defaults to `eightbit=False`, so the 8-bit implementation has to be
    # requested explicitly for this cell to exercise `AdamW8bitOptimizer`
    learn = Learner(dls, xresnext34(n_out=dls.c), opt_func=partial(Adam, eightbit=True),
                    loss_func=nn.CrossEntropyLoss(label_smoothing=0.1),
                    metrics=Accuracy()).to_channelslast()

    learn.fit_one_cycle(5, 3e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.959734,2.83424,0.320764,00:12
1,1.543875,1.48775,0.585223,00:08
2,1.225693,1.226562,0.696306,00:08
3,1.037938,1.019766,0.786242,00:08
4,0.917183,0.970842,0.808408,00:08
