In [None]:
#|default_exp callback.channelslast

In [None]:
#|hide
from nbdev.showdoc import *

# Channels Last
> A Callback which converts a fastai `Learner` and input to channels_last format.

Using Mixed Precision, image models trained in channels last format on Nvidia Tensor Cores can achieve 8%-35% increased performance over contiguous format. 

Channels last memory format is only implemented for 4D NCHW Tensors. Not all PyTorch operators have been converted to support channels last. See [(Beta) Channels Last Memory Format in PyTorch](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) for more details.

Channels Last format can error out if `torch.backends.cudnn.benchmark = False`, e.g. via fast.ai's [no_random](https://docs.fast.ai/torch_core.html#no_random) context manager. If this occurs, use the `less_random` context manager instead. This will allow reproducible training on the same GPU, PyTorch, and CUDA setup at the expense of less reproducibility should any of those change.

In [None]:
#|export
from __future__ import annotations
import torch
from fastai.torch_core import TensorImageBase, TensorMask
from fastai.learner import Learner
from fastai.vision.augment import DisplayedTransform
from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision
from fastai.callback.mixup import MixHandler
from fastai.basics import Pipeline
from torch.cuda.amp import GradScaler

from fastxtend.imports import *

## Channels Last Transform -

In [None]:
#|export
class ChannelsLastTfm(DisplayedTransform):
    "Convert image and mask tensors to the `channels_last` memory format. Used by `ChannelsLastCallback`"
    def encodes(self, x:TensorImageBase|TensorMask):
        # Only the memory format is requested here; no other `.to` argument is passed.
        # NOTE(review): the notebook prose says channels_last is only implemented for
        # 4D NCHW tensors - confirm inputs reaching this transform are always 4D.
        channels_last_x = x.to(memory_format=torch.channels_last)
        return channels_last_x

## Channels Last Callback -

In [None]:
#|export
class ChannelsLastCallback(Callback):
    "Train a `Learner` in PyTorch's Channels Last Memory Format (beta): converts the model once per fit and the inputs every batch"
    # Ordered immediately after `MixHandler`.
    # NOTE(review): presumably so inputs are converted after mixup-style callbacks have modified them - confirm.
    order = 1 + MixHandler.order

    def __init__(self):
        # Pipeline with a single transform, applied to `xb` in `before_batch`
        self._channels_last = Pipeline([ChannelsLastTfm()])

    def before_fit(self):
        "Switch the model to `channels_last` before training starts"
        model = self.learn.model
        # `nn.Module.to` modifies the module in place, so the return value is not needed
        model.to(memory_format=torch.channels_last)

    def before_batch(self):
        "Replace the current batch inputs with their `channels_last` version"
        converted_xb = self._channels_last(self.xb)
        self.learn.xb = converted_xb

## Convenience Methods

In [None]:
#|export
@patch
@delegates(GradScaler)
def to_channelslast(self:Learner,
    to_fp16=True, # Also add `MixedPrecision` unless the `Learner` already has a `mixed_precision` attribute
    **kwargs # Passed to `MixedPrecision` when it is added here
):
    "Set `Learner` and inputs to `channels_last` format and Mixed Precision by default"
    new_cbs = []
    # Decide each callback independently so repeated calls (or a call on a `Learner` that
    # already trains in channels last) never stack a second `ChannelsLastCallback`.
    if not any(isinstance(cb, ChannelsLastCallback) for cb in self.cbs):
        new_cbs.append(ChannelsLastCallback())
    if to_fp16 and not hasattr(self, 'mixed_precision'):
        new_cbs.append(MixedPrecision(**kwargs))
    # `ChannelsLastCallback` is appended first, keeping the original registration order.
    # `add_cbs` returns the `Learner`, so the call stays chainable even when nothing is added.
    return self.add_cbs(new_cbs)

In [None]:
#|export
@patch
def to_contiguous(self:Learner,
    to_fp32=False # Also remove `MixedPrecision`, returning the `Learner` to single precision
):
    "Set `Learner` and inputs to `contiguous_format` (default format), optionally to single precision"
    self.model.to(memory_format=torch.contiguous_format)
    # Remove by *class*, not by instance: a freshly constructed `ChannelsLastCallback()` is
    # not the object stored in `self.cbs`, so passing one would leave the attached callback
    # in place and the inputs would still be converted to channels last on every batch.
    cb_types = [ChannelsLastCallback]
    if to_fp32: cb_types.append(MixedPrecision)
    # `remove_cbs` returns the `Learner`, so the call stays chainable.
    return self.remove_cbs(cb_types)