In [None]:
from __future__ import annotations

from fastai.vision.all import *

# Learner: Part 1

In the past couple of lessons, we've been working with a simple training loop like the one below. It works, but we must rewrite portions of it every time we want to change something about training. Beyond that annoyance, every addition or deletion is a chance to introduce bugs. Some of these bugs can be quite insidious, because they don't prevent training but quietly hinder model performance.

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        with torch.no_grad():
            tot_loss,tot_acc,count = 0.,0.,0
            for xb,yb in valid_dl:
                pred = model(xb)
                n = len(xb)
                count += n
                tot_loss += loss_func(pred,yb).item()*n
                tot_acc  += accuracy (pred,yb).item()*n
        print(epoch, tot_loss/count, tot_acc/count)
    return tot_loss/count, tot_acc/count

To solve this, in the course we are building MiniAI, a minimal, flexible framework with callbacks.

In practice, we'd use an existing framework with callback and/or extension support, like Composer, Lightning, or fastai, or a framework like Accelerate that handles the minimal training loop modifications for us.

In this notebook we'll look at the fastai `Learner`, focusing on how it adds, removes, and calls callbacks, and on how fastai defines the training loop.

## Reviewing Callbacks

The lesson14 folder has a notebook on how fastai defines callbacks, which might be useful to look at before continuing so the material is fresh.

## Adding and Removing Callbacks

fastai's `Learner` defines two methods for adding callbacks: `add_cb` and `add_cbs`. The latter is a convenience method that adds multiple callbacks by calling `add_cb` on each one.

In this simplified `__init__` method, we can see that `Learner` adds the default callbacks and then the callbacks we pass to the `Learner`.

In [None]:
class Learner(GetAttr):
    _default='model'
    def __init__(self,
        cbs:Callback|list|None=None, # `Callback`s to add to `Learner`
        default_cbs:bool=True # Include default callbacks
    ):
        store_attr(but='dls,cbs')
        self.cbs = L()
        if default_cbs: 
            self.add_cbs(L(defaults.callbacks))
        self.add_cbs(cbs)
        self("after_create")

`add_cbs` loops through all callbacks and calls `add_cb`.

`add_cb` initializes the callback if needed, sets the `Callback.learn` attribute, sets the callback as an attribute of `Learner`, and then adds the callback to the list of all callbacks.

In [None]:
def add_cbs(self, cbs):
    L(cbs).map(self.add_cb)
    return self

def add_cb(self, cb):
    if isinstance(cb, type): 
        cb = cb()
    cb.learn = self
    setattr(self, cb.name, cb)
    self.cbs.append(cb)
    return self

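`remove_cbs` and `remove_cb` do the inverse, with some special handling: if a callback class is passed instead of an instance, `_grab_cbs` finds every added callback of that class so each one can be removed.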

In [None]:
def remove_cbs(self, cbs):
    L(cbs).map(self.remove_cb)
    return self

def _grab_cbs(self, cb_cls): 
    return L(cb for cb in self.cbs if isinstance(cb, cb_cls))

def remove_cb(self, cb):
    if isinstance(cb, type): 
        self.remove_cbs(self._grab_cbs(cb))
    else:
        cb.learn = None
        if hasattr(self, cb.name): 
            delattr(self, cb.name)
        if cb in self.cbs: 
            self.cbs.remove(cb)
    return self
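
To see this bookkeeping in isolation, here is a minimal sketch that mimics `add_cb` and `remove_cb` with plain Python. `MiniLearner`, `MiniCallback`, `PrintCB`, and `OtherCB` are hypothetical stand-ins rather than fastai classes, and the lowercase `name` property is only a simplification of how fastai derives a callback's attribute name.

```python
class MiniCallback:
    "Hypothetical stand-in for fastai's `Callback`"
    learn = None
    @property
    def name(self):
        # Simplification: use the lowercased class name as the attribute name
        return type(self).__name__.lower()

class PrintCB(MiniCallback): pass
class OtherCB(MiniCallback): pass

class MiniLearner:
    "Hypothetical stand-in that keeps only the callback bookkeeping"
    def __init__(self):
        self.cbs = []

    def add_cb(self, cb):
        if isinstance(cb, type):
            cb = cb()                      # accept a class or an instance
        cb.learn = self
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        if isinstance(cb, type):
            # A class was passed: remove every added instance of it
            for o in [o for o in self.cbs if isinstance(o, cb)]:
                self.remove_cb(o)
        else:
            cb.learn = None
            if hasattr(self, cb.name):
                delattr(self, cb.name)
            if cb in self.cbs:
                self.cbs.remove(cb)
        return self

learn = MiniLearner().add_cb(PrintCB).add_cb(OtherCB())
names_before = [cb.name for cb in learn.cbs]
learn.remove_cb(PrintCB)                   # remove by class
names_after = [cb.name for cb in learn.cbs]
print(names_before, names_after)
```

Because both methods return `self`, calls can be chained, and removing by class cleans up the matching attribute on the learner as well as the list entry.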

## Calling Callbacks

At the end of the `Learner.__init__` method, there was this line:

```python
self("after_create")
```
which was `Learner` calling the `after_create` method of any added Callbacks. 

This is one of the two ways that `Learner` invokes callbacks. It works because `Learner` defines the Python `__call__` method to take an `event_name`. Each event name is passed to `_call_one`, which verifies that it is a valid callback event, sorts the callbacks by their `order`, and then calls each callback's method for that event.

Note: `event` is created with fastcore's `mk_class`; it effectively works as an Enum, but as a plain class.

In [None]:
def __call__(self, event_name): 
    L(event_name).map(self._call_one)

def _call_one(self, event_name):
    if not hasattr(event, event_name): 
        raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): 
        cb(event_name)

The other way `Learner` calls fastai callbacks is via the `_with_events` method. This method fires the "before" event, runs the wrapped function, fires the matching "after_cancel" event if the given cancel exception was raised, and then fires the "after" event. It is used throughout the `Learner` training loop to define when callback events are called.

`f` and `final` are callables we pass to `_with_events`: `f` is the work wrapped by the events, and `final` runs after the "after" event. As we'll see in the next section, they are what moves execution through the training loop.

In [None]:
def _with_events(self, f, event_type, ex, final=noop):
    try: 
        self(f'before_{event_type}')
        f()
    except ex: 
        self(f'after_cancel_{event_type}')
    self(f'after_{event_type}')
    final()
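
To make the control flow concrete, the following sketch re-implements the same try/except shape as a free function and records which events fire. `with_events`, `run`, and the locally defined `CancelBatchException` are stand-ins for illustration, not the fastai objects.

```python
class CancelBatchException(Exception):
    "Local stand-in for fastai's exception of the same name"

def noop(): pass

def with_events(fire, f, event_type, ex, final=noop):
    # Same shape as `_with_events`, with `fire` playing the role of `self(...)`
    try:
        fire(f'before_{event_type}')
        f()
    except ex:
        fire(f'after_cancel_{event_type}')
    fire(f'after_{event_type}')
    final()

def run(f):
    events = []
    with_events(events.append, f, 'batch', CancelBatchException)
    return events

def normal(): pass
def cancelled(): raise CancelBatchException()

normal_events = run(normal)
cancelled_events = run(cancelled)
print(normal_events)
print(cancelled_events)
```

Note that the "after" event still fires when the wrapped function is cancelled; only exceptions of the type passed as `ex` are caught, so any other error propagates as usual.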

## Training Loop Overview

Once again as a reminder, this is the simplified training loop that we are recreating in a framework to be modified with Callbacks.

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
            
        model.eval()
        with torch.no_grad():
            tot_loss,tot_acc,count = 0.,0.,0
            for xb,yb in valid_dl:
                pred = model(xb)
                n = len(xb)
                count += n
                tot_loss += loss_func(pred,yb).item()*n
                tot_acc  += accuracy (pred,yb).item()*n
        print(epoch, tot_loss/count, tot_acc/count)
    return tot_loss/count, tot_acc/count

I will first show the training loop method headers, so the entire fastai training loop fits on one screen. Then we will look at the code for each method.

Remember that:
```python
self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
```
calls any Callback's "before_fit", then calls `_do_fit`, handles any `CancelFitException`s, then calls any Callback's "after_fit", before finally calling `_end_cleanup`.

For conciseness, the comments below leave out the exception and `final` arguments. You'll see them in the code.

In [1]:
# Performs optimizer, epoch, and hyperparameter setup. Then calls _with_events(_do_fit, 'fit')
def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
    pass

def _do_fit(self): pass # Loops through all the epochs and calls _with_events(_do_epoch, 'epoch')

def _do_epoch(self): pass # Calls _do_epoch_train then _do_epoch_validate

def _do_epoch_train(self): pass # Sets the Learner DataLoader to train and calls _with_events(all_batches, 'train')

def all_batches(self): pass # Enumerates through all batches calling one_batch

def one_batch(self, i, b): pass # Calls _set_device on the batch, optionally splits batch into Xs & Ys, and calls _with_events(_do_one_batch, 'batch')

def _do_one_batch(self): pass # The tight training loop: prediction, loss, then calls _with_events(_backward, 'backward') & _with_events(self._step, 'step')

def _backward(self): pass # Calls backward on the loss

def _step(self): pass # Calls the optimizer step. After this _do_one_batch calls opt.zero_grad

def _do_epoch_validate(self): pass # Repeat the process on the validation set, minus the model update
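
The nesting described by these headers can be traced with a small pure-Python sketch that only records event names. Everything here is hypothetical scaffolding (no model, data, or fastai objects), and cancel-exception handling is omitted; it assumes one epoch with a single training batch and a single validation batch.

```python
events = []

def with_events(f, event_type):
    # Cancel-exception handling and `final` are left out to keep the trace short
    events.append(f'before_{event_type}')
    f()
    events.append(f'after_{event_type}')

def do_one_batch(training):
    events.extend(['after_pred', 'after_loss'])
    if not training:
        return                             # validation stops after the loss
    with_events(lambda: None, 'backward')
    with_events(lambda: None, 'step')

def all_batches(n_batches, training):
    for _ in range(n_batches):
        with_events(lambda: do_one_batch(training), 'batch')

def do_epoch():
    with_events(lambda: all_batches(1, True), 'train')
    with_events(lambda: all_batches(1, False), 'validate')

def do_fit(n_epoch):
    for _ in range(n_epoch):
        with_events(do_epoch, 'epoch')

def fit(n_epoch):
    with_events(lambda: do_fit(n_epoch), 'fit')

fit(1)
print('\n'.join(events))
```

Reading the printed trace top to bottom gives the order in which a callback's methods would be invoked during a single epoch.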

## Training Loop Code

`fit` initializes the optimizer via `create_opt` (if needed), sets hyperparameters such as the learning rate and weight decay on the optimizer, and stores the number of epochs. It then calls any Callback's "before_fit", then calls `_do_fit`, handles any `CancelFitException`s, then calls any Callback's "after_fit", before finally calling `_end_cleanup`.

In [None]:
def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
    "Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`."
    if start_epoch != 0:
        cbs = L(cbs) + SkipToEpoch(start_epoch)
    with self.added_cbs(cbs):
        if reset_opt or not self.opt: 
            self.create_opt()
        if wd is None: 
            wd = self.wd
        if wd is not None: 
            self.opt.set_hypers(wd=wd)
        self.opt.set_hypers(lr=self.lr if lr is None else lr)
        self.n_epoch = n_epoch
        self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
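
The `with self.added_cbs(cbs):` line is what makes callbacks passed to `fit` temporary. The sketch below assumes `added_cbs` behaves like a context manager that adds the callbacks on entry and removes them on exit; `MiniLearner` is a hypothetical stand-in and the callbacks are plain strings for brevity.

```python
from contextlib import contextmanager

class MiniLearner:
    "Hypothetical stand-in that only tracks a list of callbacks"
    def __init__(self):
        self.cbs = ['default_cb']

    @contextmanager
    def added_cbs(self, cbs):
        # Assumed behavior: add on entry, always remove on exit
        self.cbs.extend(cbs)
        try:
            yield
        finally:
            for cb in cbs:
                self.cbs.remove(cb)

learn = MiniLearner()
with learn.added_cbs(['temporary_cb']):
    during = list(learn.cbs)               # callbacks visible while "fitting"
after = list(learn.cbs)                    # back to the original list
print(during, after)
```

Under that assumption, callbacks given to `fit` only live for the duration of that call, while callbacks given to `Learner.__init__` persist.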

`_do_fit` loops through all the epochs. For each one it calls any Callback's "before_epoch" method, then calls `_do_epoch`, handles any `CancelEpochException`s, and finally calls any Callback's "after_epoch" method.

In [None]:
def _do_fit(self):
    for epoch in range(self.n_epoch):
        self.epoch=epoch
        self._with_events(self._do_epoch, 'epoch', CancelEpochException)

`_do_epoch` calls `_do_epoch_train` then `_do_epoch_validate`

In [None]:
def _do_epoch(self):
    self._do_epoch_train()
    self._do_epoch_validate()

`_do_epoch_train` sets the `Learner` DataLoader to train and calls any Callback's "before_train" method, then calls `all_batches`, handles any `CancelTrainException`s, and finally calls any Callback's "after_train" method.

In [None]:
def _do_epoch_train(self):
    self.dl = self.dls.train
    self._with_events(self.all_batches, 'train', CancelTrainException)

`all_batches` enumerates through the DataLoader, hopefully with multiprocessing, and calls `one_batch` with each batch and its index.

In [None]:
def all_batches(self):
    self.n_iter = len(self.dl)
    for o in enumerate(self.dl): 
        self.one_batch(*o)

`one_batch` sets the batch device (if needed), splits the batch into samples and labels (if applicable) and calls any Callback's "before_batch" method, then calls `_do_one_batch`, handles any `CancelBatchException`s, and finally calls any Callback's "after_batch" method.

In [None]:
def one_batch(self, i, b):
    self.iter = i
    b = self._set_device(b)
    self._split(b)
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)

`_set_device` makes sure the batch is on the same device as the model. If a batch is already on that device, then calling `to_device` doesn't do anything to it, other than wasting a few CPU cycles.

In [None]:
def _set_device(self, b):
    model_device = next(self.model.parameters()).device
    dls_device = getattr(self.dls, 'device', default_device())
    if model_device == dls_device: 
        return to_device(b, dls_device)
    else: 
        return to_device(b, model_device)

`_split` uses the `DataLoader`'s `n_inp` attribute (if it exists), which we set via the `DataBlock`, to split the batch into inputs `xb` and labels `yb`. Without it, every element except the last is treated as an input.

In [None]:
def _split(self, b):
    i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
    self.xb,self.yb = b[:i],b[i:]
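
The slicing rule is easy to check on plain tuples. The function below is a standalone sketch of the same logic, with a hypothetical `n_inp` argument standing in for the attribute lookup on `self.dls` and strings standing in for tensors.

```python
def split(b, n_inp=None):
    # Default mirrors the code above: all but the last element are inputs,
    # unless the batch has a single element
    i = n_inp if n_inp is not None else (1 if len(b) == 1 else len(b) - 1)
    return b[:i], b[i:]

xb, yb = split(('img', 'label'))                       # usual (x, y) batch
xb_only, yb_empty = split(('img',))                    # unlabeled batch
xb2, yb2 = split(('img', 'mask', 'label'), n_inp=1)    # one input, two targets
print(xb, yb)
print(xb_only, yb_empty)
print(xb2, yb2)
```

Both halves are always tuples, which is why `_do_one_batch` unpacks them with `*self.xb` and `*self.yb` and checks `len(self.yb)` to detect an unlabeled batch.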

`_do_one_batch` handles the direct model training loop.
- Creates the model predictions from inputs
- Calls any Callback's "after_pred"
- Calculates the loss if labels exist
- Calls any Callback's "after_loss"
- Returns if the Learner isn't in training mode or there are no labels
- Calls any Callback's "before_backward" method, then calls `_backward`, handles any `CancelBackwardException`s, and finally calls any Callback's "after_backward" method
- Calls any Callback's "before_step" method, then calls `_step`, handles any `CancelStepException`s, and finally calls any Callback's "after_step" method
- Finally zeroes the gradients of the optimizer

In [None]:
def _do_one_batch(self):
    self.pred = self.model(*self.xb)
    self('after_pred')
    if len(self.yb):
        self.loss_grad = self.loss_func(self.pred, *self.yb)
        self.loss = self.loss_grad.clone()
    self('after_loss')
    if not self.training or not len(self.yb): 
        return
    self._with_events(self._backward, 'backward', CancelBackwardException)
    self._with_events(self._step, 'step', CancelStepException)
    self.opt.zero_grad()

`_backward` calls backward on the loss to perform the backward pass, which computes the gradients. It is wrapped in the "before_backward" and "after_backward" Callback events for compatibility with Accelerate, which handles multi-GPU training for fastai.

In [None]:
def _backward(self):
    self.loss_grad.backward()

`_step` calls the Optimizer's step method. It is wrapped in the "before_step" and "after_step" Callback events so the Mixed Precision callback can emulate a PyTorch Optimizer while using the Automatic Mixed Precision `GradScaler`.

In [None]:
def _step(self):
    self.opt.step()

Finally, `_do_epoch_validate` runs the whole loop again on the validation set, except under `torch.no_grad()` and without updating the model.

In [None]:
def _do_epoch_validate(self, ds_idx=1, dl=None):
    if dl is None: 
        dl = self.dls[ds_idx]
    self.dl = dl
    with torch.no_grad(): 
        self._with_events(self.all_batches, 'validate', CancelValidException)

## To Be Continued

To be continued after the next course lesson