In [None]:
#|default_exp callback.channelslast

In [None]:
#|hide
from nbdev.showdoc import *

# Channels Last
> A Callback which converts a fastai `Learner` and its inputs to channels_last format.

When using Mixed Precision, image models trained in channels last format on Nvidia Tensor Cores can achieve an 8%-35% performance increase over the contiguous format.

Channels last memory format is only implemented for 4D NCHW Tensors. Not all PyTorch operators have been converted to support channels last. See [(Beta) Channels Last Memory Format in PyTorch](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) for more details.

Channels Last format can error out if `torch.backends.cudnn.benchmark = False`, e.g. via fast.ai's [no_random](https://docs.fast.ai/torch_core.html#no_random) context manager. If this occurs, use the `less_random` context manager instead. This allows reproducible training on the same GPU, PyTorch, and CUDA setup, at the expense of reduced reproducibility should any of those change.

In [None]:
#|export
from __future__ import annotations

from torch.cuda.amp import GradScaler

from fastai.torch_core import TensorImageBase, TensorMask
from fastai.learner import Learner
from fastai.vision.augment import DisplayedTransform
from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision
from fastai.basics import Pipeline

from fastxtend.imports import *

## Channels Last Transform -

In [None]:
#|export
class ChannelsLastTfm(DisplayedTransform):
    "Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback"
    # Large order so this runs after all other transforms when added to `batch_tfms`
    order = 110

    def encodes(self, x:TensorImageBase|TensorMask):
        "Return `x` converted to `torch.channels_last`; type dispatch restricts this to image and mask tensors"
        return x.to(memory_format=torch.channels_last)

Channels last format requires inputs to be 4D NCHW Tensors, so `ChannelsLastTfm` only encodes `TensorImageBase` and `TensorMask` inputs to channels last using fastcore's type dispatch.

To convert another input type to channels last format, patch `ChannelsLastTfm.encodes` to dispatch on that type.

## Channels Last -

In [None]:
#|export
class ChannelsLastCallback(Callback):
    "Channels last training using PyTorch's Channels Last Memory Format (beta)"
    # Ordered immediately after `MixedPrecision`
    order = MixedPrecision.order+1

    def __init__(self):
        # Pipeline wrapping the single `ChannelsLastTfm`, applied to each batch's inputs
        self._channels_last = Pipeline([ChannelsLastTfm()])

    def before_fit(self):
        "Convert the learner's model to channels last memory format"
        model = self.learn.model
        model.to(memory_format=torch.channels_last)

    def before_batch(self):
        "Replace the batch inputs `xb` with their channels last version"
        converted_xb = self._channels_last(self.xb)
        self.learn.xb = converted_xb

## Convenience Methods

In [None]:
#|export
@patch
@delegates(GradScaler)
def to_channelslast(self:Learner, to_fp16=True, **kwargs):
    """Set `Learner` and inputs to `channels_last` format and Mixed Precision by default.

    Each callback is added only if the `Learner` does not already have it, so calling
    this repeatedly is safe. `kwargs` are forwarded to `MixedPrecision` (and used only
    when `to_fp16` is True and Mixed Precision is added). Always returns the `Learner`,
    matching `to_contiguous`.
    """
    if not hasattr(self, 'channels_last'):
        self.add_cb(ChannelsLastCallback())
    # Checked independently of the channels last callback, so requesting fp16 on a
    # learner that is already channels last still enables Mixed Precision
    if to_fp16 and not hasattr(self, 'mixed_precision'):
        self.add_cb(MixedPrecision(**kwargs))
    return self

In [None]:
#|export
@patch
def to_contiguous(self:Learner, to_fp32=False):
    "Set `Learner` and inputs to `contiguous_format` (default format), optionally to single precision"
    # Revert the model's memory format before detaching the callback(s)
    self.model.to(memory_format=torch.contiguous_format)
    if not to_fp32:
        return self.remove_cb(ChannelsLastCallback)
    # Single precision requested: drop Mixed Precision as well
    return self.remove_cbs([ChannelsLastCallback, MixedPrecision])

# Test -

In [None]:
#|hide
from fastxtend.test_utils import *
from fastai.data.core import TfmdDL, DataLoaders
from fastai.optimizer import SGD
from torch.utils.data import TensorDataset

In [None]:
#|hide
class ChannelsLastInputTest(Callback):
    "Asserts that inputs are in channels last format"
    # Runs after `ChannelsLastCallback` so the converted inputs are checked
    order = ChannelsLastCallback.order+1
    def before_batch(self):
        inputs_ok = self.x.is_contiguous(memory_format=torch.channels_last)
        assert inputs_ok, "Input isn't channels last"

class ChannelsLastPredTest(Callback):
    "Asserts that predictions are in channels last format"
    # Runs before `MixedPrecision`, presumably so the prediction is checked before its after_pred handling
    order = MixedPrecision.order-1
    def after_pred(self):
        pred_ok = self.pred.is_contiguous(memory_format=torch.channels_last)
        assert pred_ok, "Model and/or output isn't channels last"

In [None]:
#|hide
#|cuda
def synth_dbunch(bs=16, n_train=10, n_valid=2, cuda=True):
    "Random `TensorImage` (3x32x32) `DataLoaders` with `n_train` train and `n_valid` valid batches of size `bs`"
    def make_ds(n_batches):
        images = TensorImage(torch.randn(bs*n_batches, 3, 32, 32))
        return TensorDataset(images)
    # Build both datasets before the loaders, as in the original ordering
    train_ds, valid_ds = make_ds(n_train), make_ds(n_valid)
    device = None
    if cuda: device = default_device()
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)
    valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)
    return DataLoaders(train_dl, valid_dl, device=device)

In [None]:
#|hide
#|cuda
class ConvModel(Module):
    "A single 1x1 convolution, so the 4D output can be checked for channels last format"
    def __init__(self): self.conv = nn.Conv2d(3, 32, 1)
    def forward(self, x): return self.conv(x)

# Placeholder loss; NOTE(review): presumably never called by synth_learner's loop (losses print empty) - confirm
def fakeloss(): pass

with no_random():
    learn = synth_learner(cbs=[MixedPrecision,ChannelsLastCallback,ChannelsLastInputTest,ChannelsLastPredTest], cuda=True, data=synth_dbunch())
    # Model is instantiated inside `no_random`, after the learner, as before
    learn.model = ConvModel()
    learn.opt_func = partial(SGD, mom=0.)
    learn.loss_func = fakeloss
    learn.fit(3)

epoch,train_loss,valid_loss,time
0,,,00:02
1,,,00:00
2,,,00:00
