In [1]:
from fastai.vision.all import *
from fastai.callback.core import _events, _inner_loop
from torch.cuda.amp import GradScaler,autocast

# Callbacks

Recall the full minimal training loop from the lesson.

fastai callbacks can be used to modify this training loop without overcomplicating the training loop code. This also allows us to easily mix and match different callbacks with each other and keep the compatibility code between training methods outside of the main training loop.

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train `model` for `epochs` epochs, validating after each one.

    Each epoch runs one optimizer step per batch of `train_dl`, then computes the
    size-weighted mean loss and accuracy over `valid_dl` with gradients disabled
    and prints them. Returns the `(loss, accuracy)` pair of the last epoch.
    """
    for epoch in range(epochs):
        # Training phase: forward, backward, update, then clear the gradients.
        model.train()
        for inputs, targets in train_dl:
            batch_loss = loss_func(model(inputs), targets)
            batch_loss.backward()
            opt.step()
            opt.zero_grad()

        # Validation phase: accumulate sums weighted by batch size so that a
        # smaller final batch does not skew the averages.
        model.eval()
        loss_sum, acc_sum, n_seen = 0., 0., 0
        with torch.no_grad():
            for inputs, targets in valid_dl:
                preds = model(inputs)
                bs = len(inputs)
                loss_sum += loss_func(preds, targets).item() * bs
                acc_sum += accuracy(preds, targets).item() * bs
                n_seen += bs
        print(epoch, loss_sum/n_seen, acc_sum/n_seen)
    return loss_sum/n_seen, acc_sum/n_seen

A callback can implement actions on the following events:

- `after_create`: called after the `Learner` is created
- `before_fit`: called before starting training or inference, ideal for initial setup.
- `before_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- `before_train`: called at the beginning of the training part of an epoch.
- `before_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
- `before_backward`: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)
- `after_backward`: called after the backward pass, but before the update of the parameters. Generally `before_step` should be used instead.
- `before_step`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
- `after_step`: called after the step and before the gradients are zeroed.
- `after_batch`: called at the end of a batch, for any clean-up before the next one.
- `after_train`: called at the end of the training phase of an epoch.
- `before_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- `after_validate`: called at the end of the validation part of an epoch.
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.
- `after_fit`: called at the end of training, for final clean-up.

(the above is excerpted from the fastai docs on Callbacks)

Every callback has:
 - `order`: defines when the callback is run relative to the other callbacks
 - `learn`: easy access to `Learner` via `self.learn`
 - `run`: controls if the callback runs
 - `run_train`: controls if the callback runs on training loop events
 - `run_valid`: controls if the callback runs on validation loop events

In [None]:
# NOTE(review): `funcs_kwargs` presumably lets any name listed in `_methods` be passed
# to `__init__` as a keyword and bound as a method -- confirm against fastcore.
@funcs_kwargs(as_method=True)
class Callback(Stateful,GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    # Defines when this callback is run relative to the others
    order = 0
    # The `Learner` this callback is attached to (`None` until one is assigned)
    learn = None
    # Controls whether the callback runs at all
    run = True
    # Controls whether the callback runs on training loop events
    run_train = True
    # Controls whether the callback runs on validation loop events
    run_valid = True
    # Event names imported from `fastai.callback.core`
    _methods = _events

    def __init__(self, **kwargs): 
        # Any keyword still present here was not recognised, so it is rejected.
        # NOTE(review): the message says "events", although "arguments" may be the more
        # accurate word; being an `assert`, the check is also skipped under `python -O`.
        assert not kwargs, f'Passed unknown events: {kwargs}'

The `__call__` method determines if an individual callback is called using `run`, `run_train`, and `run_valid`.

Note the following events:

`after_create`, `before_fit`, `before_epoch`, `before_train`, `after_train`, `before_validate`, `after_validate`, `after_epoch`, `after_fit`

are always called by `Learner`.

There are also `after_cancel_*` methods for all callback events that have a fastai cancel exception

In [None]:
def __call__(self, event_name):
    "Call `self.{event_name}` if it's defined and return its result (`None` if it did not run)"
    # Events that are not in `_inner_loop` always pass this check. The others pass only if
    # the flag matching the current phase is set; when `training` is missing, either flag
    # is enough because the two `getattr` defaults differ.
    _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
            (self.run_valid and not getattr(self, 'training', False)))
    res = None
    if self.run and _run:
        try:
            res = getcallable(self, event_name)()
        except (CancelBatchException,
                CancelBackwardException,
                CancelEpochException,
                CancelFitException,
                CancelStepException,
                CancelTrainException,
                CancelValidException
            ):
            # The fastai cancel exceptions are control flow: re-raise them untouched.
            raise
        except Exception as e:
            # `e.args` is empty for e.g. a bare `assert` or `raise ValueError()`. Indexing it
            # unguarded would raise an IndexError here and hide the real error.
            detail = e.args[0] if e.args else ''
            raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{detail}', replace=True)
    if event_name=='after_fit':
        self.run=True #Reset self.run to True at each end of fit
    return res

Thanks to inheriting from `GetAttr`:

```python
class Callback(Stateful,GetAttr):
```

we can read any `Learner` attribute via `self.var` instead of `self.learn.var`. To write to those attributes, however, we must use `self.learn.var`: assigning to `self.var` only creates an attribute on the callback that shadows the `Learner`'s.

The `__setattr__` in `Callback` emits a warning when this happens.

In [None]:
def __setattr__(self, name, value):
    "Set `name` on the `Callback`, warning first if the `Learner` already has that attribute"
    shadows_learner = hasattr(self.learn, name)
    if shadows_learner:
        # The value is still stored on the callback below, not on the learner.
        msg = f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this"
        warn(msg)
    super().__setattr__(name, value)

# Callback Examples

fastai uses callbacks to simplify and extend the training loop. For example, `TrainEvalCallback` switches the model between train and eval mode and also tracks basic training statistics such as the iteration count and the fraction of training completed.

In [None]:
class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    order = -10
    run_valid = False

    def after_create(self):
        "Give the `Learner` an initial `n_epoch` of 1"
        self.learn.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        learn = self.learn
        learn.epoch = 0
        learn.loss = tensor(0.)
        learn.train_iter = 0
        learn.pct_train = 0.
        # Use the device of the dataloaders if they define one, else the default device.
        device = getattr(self.dls, 'device', default_device())
        self.model.to(device)
        loss_is_module = isinstance(self.loss_func, (nn.Module, BaseLoss))
        if loss_is_module:
            self.loss_func.to(device)
        if hasattr(self.model, 'reset'):
            self.model.reset()

    def after_batch(self):
        "Update the iter counter (in training mode)"
        # Each batch advances progress by one part in (batches per epoch * epochs).
        increment = 1./(self.n_iter*self.n_epoch)
        self.learn.pct_train += increment
        self.learn.train_iter += 1

    def before_train(self):
        "Set the model to training mode"
        # Re-derive progress from the epoch index at the start of each training phase.
        self.learn.pct_train = self.epoch/self.n_epoch
        self.model.train()
        self.learn.training = True

    def before_validate(self):
        "Set the model to validation mode"
        self.model.eval()
        self.learn.training = False

`MixedPrecision` is an example of modifying the training loop via callback. Here we apply PyTorch's `autocast` and gradient scaling to train models using Automatic Mixed Precision.

We can also see a use of `CancelStepException` to make fastai optimizers compatible with `GradScaler`, which expects a PyTorch optimizer.

In [None]:
class MixedPrecision(Callback):
    "Mixed precision training using Pytorch's `autocast` and `GradScaler`"
    # Runs after callbacks with a lower `order` (e.g. `TrainEvalCallback` at -10)
    order = 10
    def __init__(self, **kwargs): 
        "Store `kwargs`; they are forwarded to `GradScaler` in `before_fit`"
        self.kwargs = kwargs

    def before_fit(self): 
        "Create the `autocast` context, the `GradScaler` (stored on the `Learner`) and an empty list of scales"
        self.autocast,self.learn.scaler,self.scales = autocast(),GradScaler(**self.kwargs),L()

    def before_batch(self):
        "Enter the `autocast` context; it is exited again in `after_loss`"
        self.autocast.__enter__()

    def after_pred(self):
        "Convert the predictions with `to_float` if they came out as float16"
        # Only the dtype of the first flattened prediction is inspected.
        # NOTE(review): `to_float` presumably casts to float32 -- confirm.
        if next(flatten(self.pred)).dtype==torch.float16: 
            self.learn.pred = to_float(self.pred)

    def after_loss(self): 
        "Exit the `autocast` context entered in `before_batch`"
        self.autocast.__exit__(None, None, None)

    def before_backward(self): 
        "Replace `loss_grad` with its scaled version from the `GradScaler`"
        self.learn.loss_grad = self.scaler.scale(self.loss_grad)

    def before_step(self):
        "Use `self` as a fake optimizer: `self.skipped` stays True unless `GradScaler.step` calls `self.step`, and a skipped step is cancelled"
        # Assume the step is skipped until the scaler calls `self.step`.
        self.skipped=True
        self.scaler.step(self)
        if self.skipped:
            # The scaler did not call `self.step` (e.g. gradients overflowed): cancel the step.
            raise CancelStepException()
        # Record the scale in use for every step that was not skipped.
        self.scales.append(self.scaler.get_scale())

    def after_step(self): 
        "Let the `GradScaler` update its scale"
        self.learn.scaler.update()

    @property 
    def param_groups(self): 
        "Pretend to be an optimizer for `GradScaler`"
        return self.opt.param_groups

    def step(self, *args, **kwargs): 
        "Fake optimizer step to detect whether this batch was skipped from `GradScaler`"
        # Being called at all means the scaler did not skip this step.
        self.skipped=False

    def after_fit(self): 
        "Drop the `autocast` context, the `GradScaler` and the recorded scales"
        self.autocast,self.learn.scaler,self.scales = None,None,None

We can also modify training data via a Callback. Here is how fastai applies MixUp to images and labels (in combination with `MixHandler` to correctly apply the loss):

In [None]:
class MixUp(MixHandler):
    "Implementation of https://arxiv.org/abs/1710.09412"
    def __init__(self,
        alpha:float=.4 # Determine `Beta` distribution in range (0.,inf]
    ):
        super().__init__(alpha)

    def before_batch(self):
        "Blend xb and yb with another random item in a second batch (xb1,yb1) with `lam` weights"
        bs, device = self.y.size(0), self.x.device
        # Draw one mixing coefficient per item, then keep the larger of lam and 1-lam.
        lam = self.distrib.sample((bs,)).squeeze().to(device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        # A random permutation of the batch supplies the partner item for each sample.
        shuffle = torch.randperm(bs).to(device)
        xb1 = tuple(L(self.xb).itemgot(shuffle))
        self.yb1 = tuple(L(self.yb).itemgot(shuffle))
        # `lam` gains trailing axes so it broadcasts against the inputs; `torch.lerp`
        # then puts weight `lam` on the original item and 1-lam on its partner.
        x_weight = unsqueeze(self.lam, n=self.x.dim()-1)
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=x_weight))

        if not self.stack_y:
            # Targets are blended the same way unless `stack_y` is set.
            y_weight = unsqueeze(self.lam, n=self.y.dim()-1)
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=y_weight))