In [None]:
# default_exp basics.training_loop

# Visualizing the fastai Training Loop Further
> Extending fastai's `show_training_loop` to be more verbose about event triggers

In [None]:
#hide
from nbdev.showdoc import *
from wwf.utils import *

In [None]:
#hide_input
state_versions(['fastai', 'fastcore', 'wwf'])


---
This article is also a Jupyter Notebook that can be run from top to bottom. It
contains code snippets that you can run in any environment.

Below are the versions of `fastai`, `fastcore`, and `wwf` in use at the time of writing:
* `fastai`: 2.2.5 
* `fastcore`: 1.3.19 
* `wwf`: 0.0.10 
---

## Understanding fastai's Training Loop

`fastai`'s training loop is unusual in that everything is handled through `Callbacks`. This means you should *never* need to modify `Learner`'s source code in order to change how the training loop behaves.

Instead, we can hook into the various trigger points (events) that `Callbacks` expose. fastai already provides a way to show which `Callbacks` are called during the training loop: `Learner.show_training_loop`.
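To make the idea concrete, here is a minimal, dependency-free sketch of a callback-driven loop. `ToyLearner`, `ToyCallback`, and `CountBatches` are hypothetical names invented for this illustration; they are not fastai classes, and the real `Learner` fires many more events. The point is only that the loop itself never changes: new behavior is added by writing a callback with methods named after the events it cares about.

```python
# Simplified illustration only: these classes are hypothetical stand-ins,
# not fastai's actual `Learner` or `Callback`.
class ToyCallback:
    "Base class: subclasses define methods named after the events they handle"
    def __call__(self, event_name):
        # Run the handler for `event_name` if this callback defines one
        handler = getattr(self, event_name, None)
        if handler is not None: handler()

class CountBatches(ToyCallback):
    def before_fit(self):
        "Reset the batch counter"
        self.n_batches = 0
    def after_batch(self):
        "Count one more finished batch"
        self.n_batches += 1

class ToyLearner:
    def __init__(self, cbs): self.cbs = list(cbs)

    def _fire(self, event_name):
        # Every registered callback gets a chance to react to the event
        for cb in self.cbs: cb(event_name)

    def fit(self, n_epoch, n_batches):
        self._fire('before_fit')
        for _ in range(n_epoch):
            self._fire('before_epoch')
            for _ in range(n_batches):
                self._fire('before_batch')
                # ... forward pass, loss, backward pass, and step would go here ...
                self._fire('after_batch')
            self._fire('after_epoch')
        self._fire('after_fit')

counter = CountBatches()
ToyLearner([counter]).fit(n_epoch=2, n_batches=3)
print(counter.n_batches)  # 2 epochs * 3 batches = 6
```

Note that `ToyLearner.fit` contains no counting logic at all; swapping `CountBatches` for a different callback changes the behavior without touching the loop.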

## `show_training_loop`

The goal of `show_training_loop` is to show the user which `Callbacks` are triggered at each event of fastai's training cycle. An example is provided below:

In [None]:
from fastai.callback.all import *
from fastai.test_utils import synth_learner

learn = synth_learner()
learn.show_training_loop()

Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_

As we can see, every major stage is bracketed by a `Start` and an `End` line, with the intermediate events at each level listed in between. Next to each event are the `Callbacks` that get triggered by it.

However, I think we can take this a step further and show *just* what happens during each step. To that end, I've written a revised version of `Learner.show_training_loop`:

In [None]:
#exporti
from fastcore.xtras import dict2obj

_event2doc = {
    'after_create': "Called after the `Learner` is created",
    'before_fit': "Called before starting training or inference, ideal for initial setup",
    'before_epoch': "Called at the beginning of each epoch, useful for any behavior you need to reset at each epoch",
    'before_train': "Called at the beginning of the training part of an epoch",
    'before_batch': "Called at the beginning of each batch, just after drawing said batch.\nIt can be used to do any setup necessary for the batch or to change the input/target before it goes in the model",
    'after_pred': "Called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss",
    'after_loss': "Called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss",
    'before_backward': "Called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)",
    'before_step': "Called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update",
    'after_step': "Called after the step and before the gradients are zeroed",
    'after_batch': "Called at the end of a batch, for any clean-up before the next one",
    'after_train': "Called at the end of the training phase of an epoch",
    'before_validate': "Called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation",
    'after_validate': "Called at the end of the validation part of an epoch",
    'after_epoch': "Called at the end of an epoch, for any clean-up before the next one",
    'after_fit': "Called at the end of training, for final clean-up",
    'after_cancel_batch': "Reached immediately after a CancelBatchException before proceeding to after_batch",
    'after_cancel_train': "Reached immediately after a CancelTrainException before proceeding to after_epoch",
    'after_cancel_validate': "Reached immediately after a CancelValidException before proceeding to after_epoch",
    'after_cancel_epoch': "Reached immediately after a CancelEpochException before proceeding to after_epoch",
    'after_cancel_fit': "Reached immediately after a CancelFitException before proceeding to after_fit"
}

event2doc = dict2obj(_event2doc)
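As a rough sketch of what the attribute-style mapping above allows, the snippet below looks up an event's description by name. It uses the standard library's `SimpleNamespace` as a stand-in for `dict2obj` (an assumption made to keep the example dependency-free), and `describe_event` is a hypothetical helper, not part of `wwf` or fastai.

```python
from types import SimpleNamespace

# Two entries copied from `_event2doc` above; the full mapping works the same way
_docs = {
    'before_fit': "Called before starting training or inference, ideal for initial setup",
    'after_fit': "Called at the end of training, for final clean-up",
}

# Stand-in for `dict2obj` (assumption): expose the dictionary keys as attributes
docs = SimpleNamespace(**_docs)

def describe_event(event):
    "Hypothetical helper: format an event name together with its description"
    return f'{event}: {getattr(docs, event, "No description available")}'

print(describe_event('before_fit'))
print(docs.after_fit)  # attribute access instead of `_docs['after_fit']`
```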

In [None]:
#export
from typing import Union, List

from fastai.callback.core import Callback
from fastcore.dispatch import patch
from fastcore.xtras import is_listy, listify
from fastai.learner import _loop, Learner # `_loop` lists every step of the training loop (events plus Start/End markers)


In [None]:
#export
def _print_cb(cb:Callback, event:str, indent:int=0):
    "Prints what `cb` does during `event` with potential `indent`"
    # The description comes from the docstring of the callback's handler for `event`
    doc = getattr(cb, event).__doc__
    if doc is not None:
        print(f'{" "*(indent+4)} - {cb}: \n{" "*(indent+8)} - {doc}')
    else:
        print(f'{" "*(indent+4)} - {cb}')

In [None]:
#export
@patch
def show_training_loop(self:Learner, verbose:bool=False, cbs:Union[None,list,Callback]=None):
    "Show each step in the training loop, potentially with Callback event descriptions"
    # Normalize `cbs` to a list (`None` becomes an empty list) and attach them temporarily
    cbs = listify(cbs)
    self.add_cbs(cbs)
    indent = 0
    for s in _loop:
        # `Start ...`/`End ...` entries mark nesting levels; everything else is an event name
        if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 3
        elif s.startswith('End'): indent -= 3; print(f'{" "*indent}{s}')
        else:
            if not verbose:
                print(f'{" "*indent} - {s}:', self.ordered_cbs(s))
            else:
                print(f'{" "*indent} - {s}:')
                for cb in self.ordered_cbs(s):
                    _print_cb(cb, s, indent)
    # Detach the temporary callbacks so the `Learner` is left as it was
    self.remove_cbs(cbs)

With this new version we can pass in a `verbose` flag. For every event, each `Callback` triggered by it is listed together with the documentation string of its handler for that event, so we can see what happens at each step, as shown below:

In [None]:
learn.show_training_loop(verbose=True)

Start Fit
    - before_fit:
        - TrainEvalCallback: 
            - Set the iter and epoch counters to 0, put the model and the right device
        - Recorder: 
            - Prepare state for training
        - ProgressCallback: 
            - Setup the master bar over the epochs
   Start Epoch Loop
       - before_epoch:
           - Recorder: 
               - Set timer if `self.add_time=True`
           - ProgressCallback: 
               - Update the master bar
      Start Train
          - before_train:
              - TrainEvalCallback: 
                  - Set the model in training mode
              - Recorder: 
                  - Reset loss and metrics state
              - ProgressCallback: 
                  - Launch a progress bar over the training dataloader
         Start Batch Loop
             - before_batch:
             - after_pred:
             - after_loss:
             - before_backward:
             - before_step:
             - after_step:
             - 
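The descriptions in the verbose output come straight from the docstrings of each callback's event methods. The sketch below isolates that lookup with a hypothetical `ToyRecorder` class (a stand-in, not fastai's `Recorder`) and a hypothetical `event_doc` helper, to show why an event method without a docstring is listed by name only.

```python
# Hypothetical stand-in for a fastai `Callback`, used only to show the lookup
class ToyRecorder:
    def before_fit(self):
        "Prepare state for training"
    def after_batch(self): pass  # deliberately left without a docstring
    def __repr__(self): return type(self).__name__

def event_doc(cb, event):
    "Docstring of `cb`'s handler for `event`, or `None` if it has none"
    return getattr(cb, event).__doc__

cb = ToyRecorder()
print(cb, '->', event_doc(cb, 'before_fit'))   # has a description
print(cb, '->', event_doc(cb, 'after_batch'))  # `None`: only the name would be shown
```

In practice this means that giving the event methods of your own callbacks a short docstring is what makes them show up with a description in the verbose view.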

## Usage Example

To use this functionality, import it from `wwf`:

In [None]:
from wwf.basics.training_loop import *

Then call `learn.show_training_loop(verbose=True)`.