In [None]:
#|default_exp callback.amp

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

In [None]:
#|export
from __future__ import annotations

from torch.cuda.amp import GradScaler, autocast

from fastai.callback.core import Callback, CancelStepException
from fastai.learner import Learner
from fastai.torch_basics import ismin_torch

from fastxtend.imports import *

# Automatic Mixed Precision

> Mixed precision training using PyTorch's AMP

With supported hardware, fastxtend supports training in both float16 and bfloat16 automatic mixed precision.

For details on float16 mixed precision training, please see the fastai [mixed precision documentation](https://docs.fast.ai/callback.fp16.html).

:::{.callout-note}
BF16 mixed precision support has been [upstreamed](https://github.com/fastai/fastai/pull/3929) into fastai 2.7.13.
:::

## MixedPrecision Callback
fastxtend's <code>MixedPrecision</code> is a drop-in replacement for `fastai.callback.fp16.MixedPrecision` for float16 mixed precision. To train in bfloat16 mixed precision, set `amp_mode='bf16'` or use `Learner.to_bf16`.

In [None]:
#|export
class AMPMode(str, Enum):
    "Supported AMP precision modes; members are `str`s, so they compare equal to 'fp16' / 'bf16'"
    FP16 = "fp16"
    BF16 = "bf16"

In [None]:
#|export
@delegates(GradScaler)
class MixedPrecision(Callback):
    "Mixed precision training using PyTorch's Automatic Mixed Precision (AMP)"
    order = 10
    def __init__(self,
        amp_mode:str|AMPMode=AMPMode.FP16, # Mixed Precision training mode. Supports fp16 and bf16.
        **kwargs # Additional arguments passed to `GradScaler`. The scaler is disabled for bf16.
    ):
        # Validate early: `AMPMode(...)` raises `ValueError` for unknown modes
        amp_mode = AMPMode(amp_mode)
        store_attr(names='amp_mode')
        self.kwargs = kwargs
        # True between `before_batch` entering `autocast` and the matching exit
        self._autocast_active = False

    def before_fit(self):
        "Choose the autocast dtype, check PyTorch/GPU support, and create the `GradScaler`"
        if self.amp_mode == AMPMode.BF16:
            if not ismin_torch("1.10"):
                raise ValueError("PyTorch 1.10 or greater required for bfloat16 mixed precision training.")
            if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
                raise ValueError("Unsupported GPU for bfloat16 mixed precision training.")
            dtype = torch.bfloat16
        elif self.amp_mode == AMPMode.FP16:
            dtype = torch.float16
        else:
            raise ValueError(f"Unrecognized precision: {self.amp_mode=}")

        # Close a context left open by an interrupted previous fit before replacing it
        self._exit_autocast()
        # `autocast` dtype should not be set before PyTorch 1.10.
        self.autocast = autocast(dtype=dtype) if ismin_torch("1.10") else autocast()

        # `GradScaler` is not needed for bfloat16 as fp32 and bf16 have the same range
        self.kwargs['enabled'] = dtype == torch.float16
        self.learn.scaler,self.scales = GradScaler(**self.kwargs),L()

    def before_batch(self):
        "Enter `autocast` so the forward pass and loss computation run in mixed precision"
        self.autocast.__enter__()
        self._autocast_active = True

    def after_pred(self):
        "Apply `to_float` to predictions before the loss is computed (presumably upcasting half-precision outputs)"
        self.learn.pred = to_float(self.pred)

    def after_loss(self):
        "Leave `autocast` before the backward pass"
        self._exit_autocast()

    def after_batch(self):
        "Exit `autocast` if the batch was cancelled after `before_batch` but before `after_loss` ran"
        self._exit_autocast()

    def _exit_autocast(self):
        "Exit `self.autocast` exactly once per `before_batch` entry; no-op if it is not currently entered"
        # `getattr` default also covers instances unpickled from versions without this flag
        if getattr(self, '_autocast_active', False):
            self._autocast_active = False
            self.autocast.__exit__(None, None, None)

    def before_backward(self):
        "Scale the loss with `GradScaler` before backward (the scaler is disabled for bf16)"
        self.learn.loss_grad = self.scaler.scale(self.loss_grad)

    def before_step(self):
        "Step through `GradScaler` using `self` as a fake optimizer; cancel the real step if `GradScaler` skipped it."
        # `skipped` starts True; the fake `step` below resets it to False only if `GradScaler` actually calls it
        self.skipped=True
        self.scaler.step(self)
        if self.skipped:
            raise CancelStepException()
        self.scales.append(self.scaler.get_scale())

    def after_step(self):
        "Update the loss scale for the next iteration"
        self.learn.scaler.update()

    def after_fit(self):
        "Release the per-fit AMP state"
        self._exit_autocast()
        self.autocast,self.learn.scaler,self.scales = None,None,None

    @property
    def param_groups(self):
        "Pretend to be an optimizer for `GradScaler`"
        return self.opt.param_groups

    def step(self, *args, **kwargs):
        "Fake optimizer step to detect whether this batch was skipped from `GradScaler`"
        self.skipped=False

`amp_mode` accepts both <code>AMPMode</code> enums and 'fp16' or 'bf16' strings.

Passing `GradScaler` arguments to <code>MixedPrecision</code> when training in bfloat16 has no effect, as bfloat16 mixed precision does not use a gradient scaler.

## Convenience Methods
In addition to the fastai convenience methods, fastxtend adds <code>Learner.to_bf16</code> for training in bfloat16 mixed precision.

In [None]:
#|export
@patch
@delegates(GradScaler)
def to_fp16(self:Learner, **kwargs):
    "Set `Learner` to float16 mixed precision using PyTorch AMP; `kwargs` are forwarded to `MixedPrecision`"
    amp_cb = MixedPrecision(**kwargs)
    return self.add_cb(amp_cb)

In [None]:
#|export
@patch
def to_bf16(self:Learner):
    "Set `Learner` to bfloat16 mixed precision by attaching a bf16 `MixedPrecision` callback"
    bf16_cb = MixedPrecision(amp_mode=AMPMode.BF16)
    return self.add_cb(bf16_cb)

In [None]:
#|export
@patch
def to_fp32(self:Learner):
    "Return `Learner` to float32 precision by removing the `MixedPrecision` callback"
    amp_cb_type = MixedPrecision
    return self.remove_cb(amp_cb_type)

## Tests -

In [None]:
#|hide
#|cuda
from fastxtend.test_utils import *
from fastai.optimizer import SGD
from fastcore.basics import listify, flatten

In [None]:
#|hide
#|cuda
class FP16TestCallback(Callback):
    "Test helper: assert the first prediction tensor is `float16`"
    # Order 9 runs before `MixedPrecision` (order 10), i.e. before its `after_pred` applies `to_float`
    order = 9
    def after_pred(self):
        preds = listify(flatten(self.pred))
        assert preds[0].dtype==torch.float16

In [None]:
#|hide
#|cuda
class BF16TestCallback(Callback):
    "Test helper: assert the first prediction tensor is `bfloat16`"
    # Order 9 runs before `MixedPrecision` (order 10), i.e. before its `after_pred` applies `to_float`
    order = 9
    def after_pred(self):
        preds = listify(flatten(self.pred))
        assert preds[0].dtype==torch.bfloat16

In [None]:
#|hide
#|cuda
# fp16 smoke test: train a two-layer linear model under `MixedPrecision` while `FP16TestCallback`
# checks that predictions are float16, then check the last recorded value improved over the first epoch.
set_seed(99, True)
# `MixedPrecision` is passed as a class; presumably instantiated with its defaults (fp16) — confirm
learn = synth_learner(cbs=[MixedPrecision,FP16TestCallback], cuda=True)
learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()
learn.opt_func = partial(SGD, mom=0.)
# One parameter group per linear layer
learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]
learn.fit(3)
# Compare the last column of the final epoch's recorder row with the first epoch's
assert learn.recorder.values[-1][-1]<learn.recorder.values[0][-1]

In [None]:
#|hide
#|cuda
#Multioutput version
# Same fp16 smoke test, but the model returns a tuple of predictions; `FP16TestCallback`
# flattens the tuple and checks the dtype of the first element.
set_seed(99, True)
learn = synth_learner(cbs=[MixedPrecision('fp16'),FP16TestCallback], cuda=True)
class MultiOutputModel(Module):
    "Two independent `nn.Linear(1,1)` heads applied to the same input; `forward` returns both outputs"
    def __init__(self): self.linear1, self.linear2 = nn.Linear(1,1), nn.Linear(1,1)
    def forward(self,x): return self.linear1(x), self.linear2(x)
# Summed absolute error on the first output plus half-weighted absolute error on the second
def multioutputloss(pred, val): return ((val-pred[0]).abs() + 0.5 * (val-pred[1]).abs()).sum()
learn.model = MultiOutputModel()
learn.opt_func = partial(SGD, mom=0.)
# One parameter group per head
learn.splitter = lambda m: [list(m.linear1.parameters()), list(m.linear2.parameters())]
learn.loss_func=multioutputloss
learn.fit(3)
# Compare the last column of the final epoch's recorder row with the first epoch's
assert learn.recorder.values[-1][-1]<learn.recorder.values[0][-1]

In [None]:
#|hide
#|cuda
# bf16 smoke test, mirroring the fp16 one; skipped entirely on GPUs without bfloat16 support
if torch.cuda.is_bf16_supported():
    set_seed(99, True)
    learn = synth_learner(cbs=[MixedPrecision(amp_mode=AMPMode.BF16),BF16TestCallback], cuda=True)
    learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()
    learn.opt_func = partial(SGD, mom=0.)
    # One parameter group per linear layer
    learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]
    learn.fit(3)
    # Compare the last column of the final epoch's recorder row with the first epoch's
    assert learn.recorder.values[-1][-1]<learn.recorder.values[0][-1]