In [1]:
from __future__ import annotations

from fastai.vision.all import *
from fastai.metrics import ActivationType
from fastprogress.fastprogress import format_time

# Learner: Part 2

In the past two lessons, we've defined a few iterations of a MiniAI Learner: a Basics Learner, Basic Callbacks Learner, and two different implementations of a Flexible Learner.

The Lesson 16 Learner looks a bit more like the fastai Learner, which we covered in the `Learner_part_1.ipynb` notebook, than the Flexible Learner from Lesson 15 did.

## MiniAI Flexible Callbacks

This new version of the Flexible Learner calls callbacks via a Callback context manager `callback_ctx`:

In [2]:
def run_cbs(cbs, method_nm):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None: method()

In [3]:
@contextmanager
def callback_ctx(self, nm):
    try:
        self.callback(f'before_{nm}')
        yield
    except globals()[f'Cancel{nm.title()}Exception']: 
        pass
    finally: 
        self.callback(f'after_{nm}')

def callback(self, method_nm): 
    run_cbs(self.cbs, method_nm)

We then use it in code like this:

In [4]:
def fit(self, n_epochs):
    with self.callback_ctx('fit'):
        pass
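
Here is a small, self-contained sketch of `callback_ctx` in action. `ToyLearner`, `Tracer`, and `ToyCancelBatchException` are hypothetical stand-ins, and the cancel exception is looked up in a dict rather than with `globals()` so the sketch doesn't shadow fastai's own `Cancel*Exception` classes in this notebook. When a callback's `predict` raises the cancel exception, the rest of that batch is skipped, but `after_batch` still runs because it sits in the `finally` block.

```python
from contextlib import contextmanager
from operator import attrgetter

class ToyCancelBatchException(Exception): pass

# Hypothetical: look up cancel exceptions in a dict instead of globals()
TOY_CANCEL = {'batch': ToyCancelBatchException}

class Tracer:
    "Hypothetical callback: records events and cancels batches listed in `skip`"
    order = 0
    def __init__(self, skip=()):
        self.events, self.skip, self.learn = [], set(skip), None
    def before_batch(self): self.events.append(f'before_batch {self.learn.iter}')
    def predict(self):
        if self.learn.iter in self.skip: raise ToyCancelBatchException()
    def after_batch(self): self.events.append(f'after_batch {self.learn.iter}')

class ToyLearner:
    def __init__(self, cbs):
        self.cbs = cbs
        for cb in cbs: cb.learn = self

    def callback(self, method_nm):
        # Same logic as run_cbs: call `method_nm` on each callback that has it, in order
        for cb in sorted(self.cbs, key=attrgetter('order')):
            method = getattr(cb, method_nm, None)
            if method is not None: method()

    @contextmanager
    def callback_ctx(self, nm):
        try:
            self.callback(f'before_{nm}')
            yield
        except TOY_CANCEL[nm]:
            pass  # a cancellation only ends this `with` block
        finally:
            self.callback(f'after_{nm}')

    def fit(self, n_batches):
        self.finished = []
        for self.iter in range(n_batches):
            with self.callback_ctx('batch'):
                self.callback('predict')         # may raise ToyCancelBatchException
                self.finished.append(self.iter)  # skipped when the batch is cancelled

tracer = Tracer(skip={1})
learn = ToyLearner([tracer])
learn.fit(3)
print(learn.finished)  # [0, 2]
print(tracer.events)
```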

The Flexible Learner also defines `__getattr__`, so that calling `self.predict()` on the Learner runs the `predict` method of each Callback that defines one.

In [5]:
def __getattr__(self, name):
    if name in ('predict','get_loss','backward','step','zero_grad'): 
        return partial(self.callback, name)
    raise AttributeError(name)

This works because Python only calls `__getattr__` when normal attribute lookup fails. The Learner doesn't define `predict`, so `self.predict` falls through to `__getattr__`, which returns `partial(self.callback, 'predict')`. Calling that runs the `predict` method of every Callback that has one, sorted by `order`.
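
The fallback is easy to check in isolation. In this hypothetical sketch, `ToyDispatchLearner` never defines `predict`, so `learn.predict` goes through `__getattr__`; Callbacks without a `predict` method are skipped and the rest run in `order`. Any name outside the allowed list still raises `AttributeError`, so a typo isn't silently turned into a callback call.

```python
from functools import partial

class ToyDispatchLearner:
    "Hypothetical minimal Learner: `predict` is never defined on the class"
    def __init__(self, cbs): self.cbs = cbs

    def callback(self, method_nm):
        for cb in sorted(self.cbs, key=lambda cb: cb.order):
            method = getattr(cb, method_nm, None)
            if method is not None: method()

    def __getattr__(self, name):
        # Only reached after normal lookup (instance, class, bases) has failed
        if name in ('predict', 'get_loss', 'backward', 'step', 'zero_grad'):
            return partial(self.callback, name)
        raise AttributeError(name)

class PredictCB:
    "Hypothetical callback that only implements `predict`"
    def __init__(self, order, calls): self.order, self.calls = order, calls
    def predict(self): self.calls.append(f'predict (order {self.order})')

class SilentCB:
    "Hypothetical callback without a `predict` method; `callback` skips it"
    order = 5

calls = []
learn = ToyDispatchLearner([PredictCB(10, calls), SilentCB(), PredictCB(1, calls)])
learn.predict()  # resolved through __getattr__
print(calls)     # ['predict (order 1)', 'predict (order 10)']
```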

## fastai Learner Callbacks

This should look quite familiar from the Learner: Part 1 notebook, since it's an iteration on the same idea.

In fastai, instead of calling `self.predict`, we call the Learner with an event name:

```python
self("after_create")
```

This works because `Learner` defines `__call__` to take an `event_name`, so calling the Learner like a function fires an event. The event name (or list of names) is mapped to `_call_one`, which verifies that the event is a valid callback event, sorts the callbacks by `order`, and then calls each callback with that event.

Note: `event` is created with fastcore's `mk_class`, which builds a class whose attributes are the valid event names, so it works much like an Enum.

In [6]:
def __call__(self, event_name): 
    L(event_name).map(self._call_one)

def _call_one(self, event_name):
    if not hasattr(event, event_name): 
        raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): 
        cb(event_name)
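
A stripped-down, dependency-free version of this dispatch follows. `toy_event`, `RecordingCB`, and `ToyEventLearner` are hypothetical; a plain list replaces fastcore's `L`, and a plain class with string attributes stands in for the `mk_class`-generated `event`. The `hasattr` check is what turns a misspelled event name into an immediate error instead of a silently ignored call.

```python
class toy_event:
    "Hypothetical stand-in for fastai's `event`: one attribute per valid event name"
    before_fit, before_batch, after_pred, after_fit = 'before_fit', 'before_batch', 'after_pred', 'after_fit'

class RecordingCB:
    "Hypothetical callback; like fastai's Callback, it is called with an event name"
    def __init__(self, name, order, seen): self.name, self.order, self.seen = name, order, seen
    def __call__(self, event_name):
        method = getattr(self, event_name, None)
        if callable(method): method()  # ignore events this callback doesn't handle
    def before_fit(self): self.seen.append(f'{self.name}.before_fit')
    def after_pred(self): self.seen.append(f'{self.name}.after_pred')

class ToyEventLearner:
    def __init__(self, cbs): self.cbs = cbs

    def __call__(self, event_name):
        # Accept a single name or a list of names, like L(event_name).map(...)
        names = [event_name] if isinstance(event_name, str) else event_name
        for name in names: self._call_one(name)

    def _call_one(self, event_name):
        if not hasattr(toy_event, event_name): raise Exception(f'missing {event_name}')
        for cb in sorted(self.cbs, key=lambda cb: cb.order): cb(event_name)

seen = []
learn = ToyEventLearner([RecordingCB('b', 20, seen), RecordingCB('a', 10, seen)])
learn(['before_fit', 'after_pred'])
print(seen)  # ['a.before_fit', 'b.before_fit', 'a.after_pred', 'b.after_pred']
# learn('after_typo') would raise Exception('missing after_typo')
```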

fastai's Learner also has an equivalent of `callback_ctx` called `_with_events`, which fires the `before_` and `after_` events around a function `f` and handles cancellation exceptions via the `after_cancel_` event.

In [7]:
def _with_events(self, f, event_type, ex, final=noop):
    try: 
        self(f'before_{event_type}')
        f()
    except ex: 
        self(f'after_cancel_{event_type}')
    self(f'after_{event_type}')
    final()
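
The sketch below is a hypothetical, callback-free version of `_with_events` that simply records which events fire. `ToyEventLog` and `ToyCancelStepException` are stand-ins. It shows `after_cancel_step` appearing when `f` raises the expected cancel exception, and that an unrelated exception skips `after_step`, since that call isn't in a `finally` block.

```python
class ToyCancelStepException(Exception): pass

class ToyEventLog:
    "Hypothetical learner that records event names instead of running callbacks"
    def __init__(self): self.events = []
    def __call__(self, event_name): self.events.append(event_name)

    def _with_events(self, f, event_type, ex, final=lambda: None):
        try:
            self(f'before_{event_type}')
            f()
        except ex:
            self(f'after_cancel_{event_type}')
        self(f'after_{event_type}')  # not in a `finally` block
        final()

def cancelled_step(): raise ToyCancelStepException()

learn = ToyEventLog()
learn._with_events(lambda: None, 'step', ToyCancelStepException)    # normal step
learn._with_events(cancelled_step, 'step', ToyCancelStepException)  # cancelled step
print(learn.events)
# ['before_step', 'after_step', 'before_step', 'after_cancel_step', 'after_step']

other = ToyEventLog()
try:
    other._with_events(lambda: 1 / 0, 'step', ToyCancelStepException)
except ZeroDivisionError:
    pass
print(other.events)  # ['before_step']
```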

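Unlike MiniAI, `_with_events` lets Callbacks respond to a cancellation through the `after_cancel_{event_type}` event. Note also that `after_{event_type}` is called after the `try`/`except` rather than in a `finally` block, so it runs after a cancellation but not after any other exception.

MiniAI's `callback_ctx` catches the cancellation exception itself:
```python
except globals()[f'Cancel{nm.title()}Exception']: 
    pass
```
In fastai, `_with_events` does the catching, while the `__call__` method of Callback ensures cancellation exceptions raised inside a Callback propagate unchanged, and adds the Callback and event name to any other exception: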

In [8]:
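def __call__(self, event_name):
    # Simplified: fastai also checks `run_train`/`run_valid` here before running inner-loop events
    res = None
    if self.run: 
        try: 
            res = getcallable(self, event_name)()
        except (CancelBatchException, 
                CancelBackwardException, 
                CancelEpochException, 
                CancelFitException, 
                CancelStepException, 
                CancelTrainException, 
                CancelValidException): 
            # Cancellation exceptions are control flow: re-raise them untouched
            raise
        except Exception as e: 
            # Any other exception is re-raised with the Callback and event name added to its message
            raise modify_exception(e, f'Exception occurred in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
    return res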

## Training Loops

The MiniAI training loop is as follows, with 
```python
with self.callback_ctx('fit')
```
and 
```python
self.predict()
```
calling Callbacks as we previously discussed.

In [9]:
def fit(self, n_epochs):
    self.n_epochs = n_epochs
    self.epochs = range(n_epochs)
    self.opt = self.opt_func(self.model.parameters(), self.lr)
    with self.callback_ctx('fit'):
        for self.epoch in self.epochs:
            self.one_epoch(True)
            with torch.no_grad(): 
                self.one_epoch(False)

def one_epoch(self, train):
    self.model.train(train)
    self.dl = self.dls.train if train else self.dls.valid
    with self.callback_ctx('epoch'):
        for self.iter,self.batch in enumerate(self.dl):
            with self.callback_ctx('batch'):
                self.predict()
                self.get_loss()
                if self.model.training:
                    self.backward()
                    self.step()
                    self.zero_grad()

Meanwhile, the fastai training loop looks like this (for details, see the Learner: Part 1 notebook in the Lesson 15 folder):

In [10]:
# Performs optimizer, epoch, and hyperparameter setup. Then calls _with_events(_do_fit, 'fit')
def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
    pass

def _do_fit(self): pass # Loops through all the epochs and calls _with_events(_do_epoch, 'epoch')

def _do_epoch(self): pass # Calls _do_epoch_train then _do_epoch_validate

def _do_epoch_train(self): pass # Sets the Learner DataLoader to train and calls _with_events(all_batches, 'train')

def all_batches(self): pass # Enumerates through all batches calling one_batch

def one_batch(self, i, b): pass # Calls _set_device on the batch, optionally splits batch into Xs & Ys, and calls _with_events(_do_one_batch, 'batch')

def _do_one_batch(self): pass # The tight training loop: prediction, loss, then calls _with_events(_backward, 'backward') & _with_events(self._step, 'step')

def _backward(self): pass # Calls backward() on the loss

def _step(self): pass # Calls the optimizer step. After this _do_one_batch calls opt.zero_grad

def _do_epoch_validate(self): pass # Repeats the process on the validation set, minus the model update

Here, instead of Callbacks implementing `predict` and `backward` themselves, the Learner does the work and fires events around it: `after_pred` and `after_loss` directly, plus `before` and `after` events for `backward` and `step` via `_with_events`.

In [11]:
def _do_one_batch(self):
    self.pred = self.model(*self.xb)
    self('after_pred')
    if len(self.yb):
        self.loss_grad = self.loss_func(self.pred, *self.yb)
        self.loss = self.loss_grad.clone()
    self('after_loss')
    if not self.training or not len(self.yb): 
        return
    self._with_events(self._backward, 'backward', CancelBackwardException)
    self._with_events(self._step, 'step', CancelStepException)
    self.opt.zero_grad()
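
To see how these nested `_with_events` calls fit together, here is a hypothetical skeleton, `ToyLoop`, that keeps fastai's method names and event order but does no real work: each event is appended to a list. It assumes one epoch with one training batch and one validation batch, collapses `_do_epoch_train`/`_do_epoch_validate` into lambdas, and passes an empty tuple as the cancel exception (which catches nothing) to keep things short.

```python
class ToyLoop:
    "Hypothetical skeleton of fastai's nested training loop that only records events"
    def __init__(self, n_train=1, n_valid=1):
        self.n_train, self.n_valid, self.events = n_train, n_valid, []

    def __call__(self, event_name): self.events.append(event_name)

    def _with_events(self, f, event_type, ex=()):
        try:
            self(f'before_{event_type}')
            f()
        except ex:  # an empty tuple catches nothing
            self(f'after_cancel_{event_type}')
        self(f'after_{event_type}')

    def fit(self, n_epoch):
        self.n_epoch = n_epoch
        self._with_events(self._do_fit, 'fit')

    def _do_fit(self):
        for self.epoch in range(self.n_epoch):
            self._with_events(self._do_epoch, 'epoch')

    def _do_epoch(self):
        self.training = True
        self._with_events(lambda: self.all_batches(self.n_train), 'train')
        self.training = False
        self._with_events(lambda: self.all_batches(self.n_valid), 'validate')

    def all_batches(self, n):
        for self.iter in range(n):
            self._with_events(self._do_one_batch, 'batch')

    def _do_one_batch(self):
        self('after_pred')  # real code: self.pred = self.model(*self.xb) first
        self('after_loss')  # real code: computes self.loss first
        if not self.training: return
        self._with_events(lambda: None, 'backward')
        self._with_events(lambda: None, 'step')

loop = ToyLoop()
loop.fit(1)
print(loop.events)
```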

## Default Training Callbacks
Like the MiniAI Learner, fastai's Learner defines its training loop behavior in a set of default training Callbacks, listed in `defaults.callbacks`:

In [12]:
defaults.callbacks

[fastai.callback.core.TrainEvalCallback,
 fastai.learner.Recorder,
 fastai.learner.CastToTensor,
 fastai.callback.progress.ProgressCallback]

In this section we will go over all four of these default training Callbacks. I will save Recorder for last, as I will discuss fastai Metrics and Recorder at the same time.

### TrainEvalCallback

First is `TrainEvalCallback`, which does a lot of setup, switches the model between training and evaluation mode, and tracks basic statistics such as the training iteration count and the fraction of training completed (`pct_train`), primarily for other Callbacks to use.

In [None]:
class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    order,run_valid = -10,False

    def after_create(self): 
        self.learn.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        self.learn.epoch,self.learn.loss = 0,tensor(0.)
        self.learn.train_iter,self.learn.pct_train = 0,0.
        device = getattr(self.dls, 'device', default_device())
        self.model.to(device)
        if isinstance(self.loss_func, (nn.Module, BaseLoss)): 
            self.loss_func.to(device)
        if hasattr(self.model, 'reset'): 
            self.model.reset()

    def after_batch(self):
        "Update the iter counter (in training mode)"
        self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
        self.learn.train_iter += 1

    def before_train(self):
        "Set the model to training mode"
        self.learn.pct_train=self.epoch/self.n_epoch
        self.model.train()
        self.learn.training=True

    def before_validate(self):
        "Set the model to validation mode"
        self.model.eval()
        self.learn.training=False
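
`pct_train` is the fraction of the whole `fit` call completed so far, which other Callbacks (for example, fastai's hyperparameter schedulers) can read. Here is a hypothetical replay of its updates, assuming 2 epochs of 4 training batches, so each training batch adds 1/(4*2) = 0.125. Because `run_valid=False`, fastai skips this callback's `after_batch` during validation, so validation batches don't advance the counter, and `before_train` sets it to `epoch/n_epoch` at the start of every epoch.

```python
def trace_pct_train(n_epoch, n_iter):
    "Hypothetical replay of TrainEvalCallback's pct_train updates"
    pct_train, at_end_of_train = 0.0, []
    for epoch in range(n_epoch):
        pct_train = epoch / n_epoch                # before_train
        for _ in range(n_iter):
            pct_train += 1.0 / (n_iter * n_epoch)  # after_batch, training batches only
        at_end_of_train.append(pct_train)
    return at_end_of_train

print(trace_pct_train(n_epoch=2, n_iter=4))  # [0.5, 1.0]
```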

### ProgressCallback
`ProgressCallback` in fastai behaves quite similarly to MiniAI's progress bar callback, except that it has two progress bars, one for epochs and one for batches, and prints metrics in a table as training progresses. Like MiniAI's, it sets the Learner's `logger` attribute for printing, so other Callbacks like `Recorder` can use it.

In [None]:
class ProgressCallback(Callback):
    "A `Callback` to handle the display of progress bars"
    order,_stateattrs = 60,('mbar','pbar')

    def before_fit(self):
        assert hasattr(self.learn, 'recorder')
        if self.create_mbar:
            self.mbar = master_bar(list(range(self.n_epoch)))
        if self.learn.logger != noop:
            self.old_logger,self.learn.logger = self.logger,self._write_stats
            self._write_stats(self.recorder.metric_names)
        else: self.old_logger = noop

    def before_epoch(self):
        if getattr(self, 'mbar', False): 
            self.mbar.update(self.epoch)

    def before_train(self):    
        self._launch_pbar()

    def before_validate(self): 
        self._launch_pbar()

    def after_train(self):     
        self.pbar.on_iter_end()

    def after_validate(self):  
        self.pbar.on_iter_end()

    def after_batch(self):
        self.pbar.update(self.iter+1)
        if hasattr(self, 'smooth_loss'): 
            self.pbar.comment = f'{self.smooth_loss:.4f}'

    def _launch_pbar(self):
        self.pbar = progress_bar(self.dl, parent=getattr(self, 'mbar', None), leave=False)
        self.pbar.update(0)

    def after_fit(self):
        if getattr(self, 'mbar', False):
            self.mbar.on_iter_end()
            delattr(self, 'mbar')
        if hasattr(self, 'old_logger'): 
            self.learn.logger = self.old_logger

    def _write_stats(self, log):
        if getattr(self, 'mbar', False): 
            self.mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in log], table=True)
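
The logger hand-off can be sketched without fastprogress. `ToyLoggingLearner` and `ToyTableCallback` are hypothetical, and the table writer just collects formatted rows instead of drawing a table. Anything that calls `learn.logger` during `fit`, such as `Recorder.after_epoch`, writes through the callback's writer, and `after_fit` restores the original logger.

```python
class ToyLoggingLearner:
    "Hypothetical learner whose default logger is print"
    def __init__(self): self.logger = print

class ToyTableCallback:
    "Hypothetical sketch of ProgressCallback's logger swap"
    def __init__(self, learn): self.learn, self.rows = learn, []

    def before_fit(self):
        # Remember the current logger, then route all logging to our table writer
        self.old_logger, self.learn.logger = self.learn.logger, self._write_stats

    def _write_stats(self, log):
        # Same formatting rule as ProgressCallback._write_stats
        self.rows.append([f'{l:.6f}' if isinstance(l, float) else str(l) for l in log])

    def after_fit(self):
        self.learn.logger = self.old_logger  # put the original logger back

learn = ToyLoggingLearner()
table = ToyTableCallback(learn)
table.before_fit()
learn.logger([0, 0.5, 0.25, '00:01'])  # a Recorder-style row: epoch, losses, time
table.after_fit()
print(table.rows)  # [['0', '0.500000', '0.250000', '00:01']]
```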

### CastToTensor
`CastToTensor` is a very simple callback which casts any fastai subclassed Tensors (such as `TensorImage`) to a normal Tensor before passing them to the model.

Training on a subclassed Tensor can result in up to a forty percent decrease in GPU throughput when using Automatic Mixed Precision and Channels Last memory format on modern hardware. Thomas Capelle and I discovered this bug, which you can [read more about here](https://benjaminwarner.dev/2022/06/14/debugging-pytorch-performance-decrease.html).

This is important to note: any Callback that needs the Tensor subclass type must have an `order` lower than 9, so that it runs before `CastToTensor`.

In [None]:
def _cast_tensor(x): 
    if isinstance(x, tuple): 
        return tuple(_cast_tensor(x_) for x_ in x)
    else: 
        return cast(x, Tensor) if isinstance(x,torch.Tensor) else x

class CastToTensor(Callback):
    "Cast Subclassed Tensors to `Tensor`"
    order=9 # Right before MixedPrecision

    def before_batch(self):
        self.learn.xb = _cast_tensor(self.learn.xb)
        self.learn.yb = _cast_tensor(self.learn.yb)
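
The ordering constraint is easy to see with plain Python stand-ins. `PlainArray` and `ImageArray` are hypothetical substitutes for `torch.Tensor` and `TensorImage`, and each callback is reduced to an `(order, before_batch)` pair. A callback with `order` 5 still sees the subclass, while one with `order` 20 only sees the base type, because the order-9 cast has already run.

```python
class PlainArray:
    "Hypothetical stand-in for torch.Tensor"
    def __init__(self, data): self.data = data

class ImageArray(PlainArray):
    "Hypothetical stand-in for a Tensor subclass such as TensorImage"

def cast_to_plain(x):
    "Mirrors _cast_tensor: recurse into tuples and drop the subclass type"
    if isinstance(x, tuple): return tuple(cast_to_plain(x_) for x_ in x)
    return PlainArray(x.data) if isinstance(x, PlainArray) else x

class ToyBatchState:
    "Hypothetical holder for `xb`, playing the role of the Learner"
    def __init__(self, xb): self.xb = xb

def make_probe(seen, order):
    "Hypothetical callback that records the type of the first input it sees"
    return order, lambda state: seen.update({order: type(state.xb[0]).__name__})

def make_cast(order=9):
    "Hypothetical CastToTensor-like callback"
    return order, lambda state: setattr(state, 'xb', cast_to_plain(state.xb))

seen = {}
state = ToyBatchState((ImageArray([1, 2]), 3))
cbs = [make_probe(seen, 20), make_cast(), make_probe(seen, 5)]
for order, before_batch in sorted(cbs, key=lambda cb: cb[0]):
    before_batch(state)
print(seen)  # {5: 'ImageArray', 20: 'PlainArray'}
```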

## Metrics and Recorder
fastai has its own metrics system. When fastai was developed, `TorchMetrics` from Lightning.ai, `TorchEval` from PyTorch, and `Evaluate` from Hugging Face didn't exist yet, so fastai built its own metrics system from a combination of custom code and scikit-learn metrics.

All fastai metrics inherit from the base `Metric` class.

In [None]:
class Metric():
    "Blueprint for defining a metric"
    def reset(self):
        "Reset inner state to prepare for new computation"
        pass

    def accumulate(self, learn):
        "Use `learn` to update the state with new results"
        pass

    @property
    def value(self):
        "The value of the metric"
        raise NotImplementedError

    @property
    def name(self):
        "Name of the `Metric`, camel-cased and with Metric removed" 
        return class2attr(self, 'Metric')

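`AvgMetric` is the default fastai metric. It wraps a function `func`: each `accumulate` call computes `func` on the current batch and adds the result, weighted by the batch size, to a running total, and `value` returns that total divided by the number of samples seen.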

If you pass a functional metric to Learner

```python
Learner(..., metrics=accuracy)
```
it will automatically be converted to an `AvgMetric` behind the scenes. This behavior can cause issues if your metric cannot be averaged across batches. For example, the average of per-batch Root Mean Square Error (RMSE) isn't equal to the RMSE of the whole dataset, because the square root is applied to each batch before averaging.
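
A small worked example makes the RMSE problem concrete. With plain Python lists and made-up data, the first batch has errors of 1 and the second has errors of 3. Averaging per-batch RMSE weighted by batch size, as `AvgMetric` would, gives 2.0, while computing RMSE over all four samples at once gives sqrt(5), about 2.236. `rmse_list` is a hypothetical helper, not fastai's `rmse`.

```python
import math

def rmse_list(preds, targs):
    "Root mean square error of two equal-length lists"
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targs)) / len(preds))

# Made-up data: batch 1 has errors of 1, batch 2 has errors of 3
batches = [([1.0, 2.0], [0.0, 1.0]), ([3.0, 4.0], [0.0, 1.0])]

# AvgMetric style: batch-size-weighted mean of per-batch RMSE
total, count = 0.0, 0
for preds, targs in batches:
    total += rmse_list(preds, targs) * len(targs)
    count += len(targs)
avg_of_batches = total / count

# Accumulate-then-compute style: one RMSE over every sample
all_preds = [p for preds, _ in batches for p in preds]
all_targs = [t for _, targs in batches for t in targs]
whole_dataset = rmse_list(all_preds, all_targs)

print(avg_of_batches)  # 2.0
print(whole_dataset)   # 2.236..., which is sqrt(5)
```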

In [None]:
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): 
        self.func = func

    def reset(self): 
        self.total,self.count = 0.,0

    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
        self.count += bs

    @property
    def value(self): 
        return self.total/self.count if self.count != 0 else None

    @property
    def name(self):
        return self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__

fastai can also track the loss as a metric with `AvgLoss`, which averages each batch's mean loss weighted by batch size.

In [None]:
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self):           
        self.total,self.count = 0.,0

    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += learn.to_detach(learn.loss.mean())*bs
        self.count += bs

    @property
    def value(self): 
        return self.total/self.count if self.count != 0 else None

    @property
    def name(self):  
        return "loss"

`AvgSmoothLoss` tracks the exponentially smoothed loss, which is what `Recorder` displays as the training loss by default.

In [None]:
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98): 
        self.beta = beta

    def reset(self):               
        self.count,self.val = 0,tensor(0.)

    def accumulate(self, learn):
        self.count += 1
        self.val = torch.lerp(to_detach(learn.loss.mean()), self.val, self.beta)

    @property
    def value(self): 
        return self.val/(1-self.beta**self.count)
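
To see why `value` divides by `1 - beta**count`, here is a pure-Python sketch of the same update. `SmoothLossSketch` is hypothetical, and `lerp` is hand-written with the same formula as `torch.lerp` (`start + weight * (end - start)`), so `lerp(loss, val, beta)` equals `beta * val + (1 - beta) * loss`. Because `val` starts at 0, the raw average is biased toward 0 early on: with a constant loss of 2.0, the raw value after three steps is only about 0.1176, and the bias correction recovers 2.0.

```python
def lerp(start, end, weight):
    "Same formula as torch.lerp: start + weight * (end - start)"
    return start + weight * (end - start)

class SmoothLossSketch:
    "Hypothetical pure-Python version of AvgSmoothLoss"
    def __init__(self, beta=0.98):
        self.beta, self.count, self.val = beta, 0, 0.0

    def accumulate(self, loss):
        self.count += 1
        # beta * old_val + (1 - beta) * loss
        self.val = lerp(loss, self.val, self.beta)

    @property
    def value(self):
        # Undo the bias from starting the average at 0
        return self.val / (1 - self.beta ** self.count)

sl = SmoothLossSketch()
for _ in range(3):
    sl.accumulate(2.0)
print(sl.val)    # about 0.1176, biased toward the initial 0
print(sl.value)  # about 2.0 after bias correction
```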

`AccumMetric` solves the RMSE problem. It is also how fastai uses scikit-learn metrics, which are implemented with NumPy while our predictions and targets are Tensors.

Instead of calculating the metric and averaging it as we go, `AccumMetric` accumulates all the values, applying any `activation` and thresholding `thresh` along the way (removed for brevity, see the source code for full details). When `value` is called, `AccumMetric` calculates the metric on all the batches at once.

In [None]:
class AccumMetric(Metric):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, func, dim_argmax=None, activation=ActivationType.No, thresh=None, to_np=False,
                 invert_arg=False, flatten=True, name=None, **kwargs):
        store_attr('func,dim_argmax,activation,thresh,flatten')
        self._name = ifnone(name, self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__)
        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs

    def reset(self):
        "Clear all targs and preds"
        self.targs,self.preds = [],[]

    def accumulate(self, learn):
        "Store targs and preds from `learn`, using activation function and argmax as appropriate"
        pred = learn.pred
        # handle activations here
        self.accum_values(pred,learn.y,learn)

    def accum_values(self, preds, targs,learn=None):
        "Store targs and preds"
        to_d = learn.to_detach if learn is not None else to_detach
        preds,targs = to_d(preds),to_d(targs)
        if self.flatten: 
            preds,targs = flatten_check(preds,targs)
        self.preds.append(preds)
        self.targs.append(targs)

    @property
    def value(self):
        "Value of the metric using accumulated preds and targs"
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        if self.to_np: 
            preds,targs = preds.numpy(),targs.numpy()
        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)

I've reproduced the `AccumMetric` documentation here for more details:
> `func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).
> 
> For classification problems with single label, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, we can just apply the argmax. Pass along `dim_argmax` to have this done by `AccumMetric` (usually -1 will work pretty well). If you need to pass to your metrics the probabilities and not the predictions, use `softmax=True`.
> 
> For classification problems with multiple labels, or if your targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in your model) then be compared to a given threshold (to decide between 0 and 1), this is done by `AccumMetric` if you pass `sigmoid=True` and/or a value for `thresh`.
> 
> If you want to use a metric function from sklearn.metrics, you will need to convert predictions and labels to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds` which is the opposite from us, so you will need to pass `invert_arg=True` to make `AccumMetric` do the inversion for you.

Finally, there is a convenience function, `skm_to_fastai`, for creating a fastai metric from a scikit-learn metric using `AccumMetric`:

In [None]:
def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, activation=None, **kwargs):
    "Convert `func` from sklearn.metrics to a fastai metric"
    dim_argmax = axis if is_class and thresh is None else None
    if activation is None:
        activation = ActivationType.Sigmoid if (is_class and thresh is not None) else ActivationType.No
    return AccumMetric(func, dim_argmax=dim_argmax, activation=activation, thresh=thresh,
                       to_np=True, invert_arg=True, **kwargs)
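
Here is a hypothetical pure-Python sketch of what `skm_to_fastai` sets up for a single-label classifier: argmax over the last axis (`dim_argmax=-1`), store everything, then call a scikit-learn style function with the targets first (`invert_arg=True`). `sk_style_precision` is a hand-written stand-in for a scikit-learn metric so the sketch has no dependencies, and lists replace Tensors and NumPy arrays. Precision is used because it isn't symmetric, so swapping the arguments changes the result.

```python
def list_argmax(row):
    "Index of the largest value, a stand-in for Tensor.argmax(dim=-1) on one row"
    return max(range(len(row)), key=row.__getitem__)

def sk_style_precision(y_true, y_pred):
    "Hypothetical sklearn-style precision for class 1: targets first, predictions second"
    true_of_predicted_pos = [t for t, p in zip(y_true, y_pred) if p == 1]
    return sum(t == 1 for t in true_of_predicted_pos) / len(true_of_predicted_pos)

class ToyAccumMetric:
    "Hypothetical sketch of AccumMetric with dim_argmax=-1"
    def __init__(self, func, invert_arg=True):
        self.func, self.invert_arg = func, invert_arg
        self.reset()

    def reset(self):
        self.preds, self.targs = [], []

    def accumulate(self, logits, targs):
        # Turn each row of logits into a class index, then just store it
        self.preds += [list_argmax(row) for row in logits]
        self.targs += list(targs)

    @property
    def value(self):
        if not self.preds: return None
        # The metric is computed once, over every stored sample
        if self.invert_arg: return self.func(self.targs, self.preds)
        return self.func(self.preds, self.targs)

def precision_for(invert_arg):
    metric = ToyAccumMetric(sk_style_precision, invert_arg=invert_arg)
    metric.accumulate([[0.1, 2.0], [3.0, 0.5]], [1, 1])  # batch 1 predicts 1, 0
    metric.accumulate([[0.2, 0.9]], [1])                 # batch 2 predicts 1
    return metric.value

print(precision_for(invert_arg=True))   # 1.0: both predicted positives are correct
print(precision_for(invert_arg=False))  # 0.666...: targets and predictions swapped
```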

### Recorder

`Recorder` is the Callback that records the losses and metrics, allowing us and any loggers to capture and view them. `Recorder` can record all metrics on the training set, the validation set, or both, but it can't choose per metric. This basic setup, like building the list of metric names, all happens in `before_fit`.

In [None]:
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    _stateattrs=('lrs','iters','losses','values')
    remove_on_fetch,order = True,50

    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr('add_time,train_metrics,valid_metrics')
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def before_fit(self):
        "Prepare state for training"
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self.metrics.attrgot('name')
        if self.train_metrics and self.valid_metrics:
            names = L('loss') + names
            names = names.map('train_{}') + names.map('valid_{}')
        elif self.valid_metrics: 
            names = L('train_loss', 'valid_loss') + names
        else: 
            names = L('train_loss') + names
        if self.add_time: 
            names.append('time')
        self.metric_names = 'epoch'+names
        self.smooth_loss.reset()
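
The column names `before_fit` builds are easiest to see by example. `recorder_columns` is a hypothetical plain-list rewrite of that logic, with fastcore's `L` replaced by lists and metrics given by name; the branches are the same as in `Recorder.before_fit`.

```python
def recorder_columns(metrics, train_metrics=False, valid_metrics=True, add_time=True):
    "Hypothetical plain-list version of the names Recorder.before_fit builds"
    if train_metrics and valid_metrics:
        names = ['loss'] + metrics
        names = [f'train_{n}' for n in names] + [f'valid_{n}' for n in names]
    elif valid_metrics:
        names = ['train_loss', 'valid_loss'] + metrics
    else:
        names = ['train_loss'] + metrics
    if add_time:
        names.append('time')
    return ['epoch'] + names

print(recorder_columns(['accuracy']))
# ['epoch', 'train_loss', 'valid_loss', 'accuracy', 'time']
print(recorder_columns(['accuracy'], train_metrics=True))
# ['epoch', 'train_loss', 'train_accuracy', 'valid_loss', 'valid_accuracy', 'time']
```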

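All metrics need to be reset before they accumulate values for each epoch, which is handled in `before_train` and `before_validate`. Here `_train_mets` is the smooth loss followed by the metrics (if `train_metrics=True`), and `_valid_mets` is `AvgLoss` followed by the metrics (if `valid_metrics=True`). The `[1:]` in `before_train` skips the smooth loss, so it keeps smoothing across epochs instead of restarting.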

In [None]:
def before_train(self): 
    self._train_mets[1:].map(Self.reset())

def before_validate(self): 
    self._valid_mets.map(Self.reset())

`after_batch` calls the `accumulate` method of all the metrics 
```python
for met in mets: 
    met.accumulate(self.learn)
```
but only if the batch has labels. During training, it also records the current learning rate and the smooth loss.

In [None]:
def after_batch(self):
    "Update all metrics and records lr and smooth loss in training"
    if len(self.yb) == 0: return
    mets = self._train_mets if self.training else self._valid_mets
    for met in mets: 
        met.accumulate(self.learn)
    if not self.training: 
        return
    self.lrs.append(self.opt.hypers[-1]['lr'])
    self.losses.append(self.smooth_loss.value)
    self.learn.smooth_loss = self.smooth_loss.value

When a training or validation phase finishes, `Recorder` appends each metric's value to the list `self.log`, which is initialized with the epoch number at the start of each epoch:

In [None]:
def _maybe_item(t):
    t = t.value
    try: return t.item()
    except: return t

def after_train(self): 
    self.log += self._train_mets.map(_maybe_item)

def after_validate(self): 
    self.log += self._valid_mets.map(_maybe_item)

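Then, at the end of the epoch, `after_epoch` saves the epoch's values (everything except the epoch number) as `final_record`, optionally appends the elapsed time, and passes the whole row to `self.logger`: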

In [None]:
def after_epoch(self):
    "Store and log the loss/metric values"
    self.learn.final_record = self.log[1:].copy()
    self.values.append(self.learn.final_record)
    if self.add_time: 
        self.log.append(format_time(time.time() - self.start_epoch))
    self.logger(self.log)
    self.iters.append(self.smooth_loss.count)

## fastxtend Metrics

I created an improved version of fastai metrics as part of [fastxtend](https://fastxtend.benjaminwarner.dev), which is backwards compatible with fastai metrics.

1. fastxtend metrics can independently log on train, valid, or both train and valid
2. All fastxtend metrics can use the activation support of fastai's `AccumMetric`, inherited from `MetricX`
3. fastxtend metrics add `AvgSmoothMetric`, a metric version of `AvgSmoothLoss`

I use it mostly for [logging multiple losses](https://fastxtend.benjaminwarner.dev/multiloss.html) individually as metrics. You can check out [fastxtend metrics here](https://fastxtend.benjaminwarner.dev/metrics.html).