---
aliases:
    - callback.simpleprofiler.html
---

In [None]:
#|default_exp callback.profiler

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Profiler
> Callbacks which add throughput and simple profilers to fastai. Inspired by PyTorch Lightning's SimpleProfiler.

In [None]:
#|export
import locale
import time
import pandas as pd
import numpy as np
from pathlib import Path
from packaging.version import parse

from fastcore.foundation import docs
from fastcore.basics import mk_class, noop, in_notebook

import fastai
from fastai.learner import Learner, Recorder
from fastai.callback.core import *

from fastxtend.imports import *
from fastxtend.utils import scale_time

if in_notebook():
    from IPython.display import display

In [None]:
#|hide
from fastai.test_utils import synth_learner

Because fastxtend profilers modify the fastai data loading loop, they are not imported by any of the fastxtend `all` imports and must be imported separately:

```python
from fastxtend.callback import profiler
```

::: {.callout-warning}
The Throughput and Simple profilers are untested with distributed training.
:::

Jump to usage [examples](#examples).

## Events
fastai callbacks do not have an event which is called directly before drawing a batch. fastxtend profilers add a new callback event called `before_draw`.

With a fastxtend profiler imported, a callback can implement actions on the following events:

- `after_create`: called after the `Learner` is created.
- `before_fit`: called before starting training or inference, ideal for initial setup.
- `before_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- `before_train`: called at the beginning of the training part of an epoch.
- **`before_draw`**: called at the beginning of each batch, just before drawing said batch.
- `before_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
- `before_backward`: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used).
- `before_step`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
- `after_step`: called after the step and before the gradients are zeroed.
- `after_batch`: called at the end of a batch, for any clean-up before the next one.
- `after_train`: called at the end of the training phase of an epoch.
- `before_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- `after_validate`: called at the end of the validation part of an epoch.
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.
- `after_fit`: called at the end of training, for final clean-up.

## Implement before_draw -

#|hide

To add `before_draw` as a callable event, it first needs to be added to both the `_inner_loop` and `_events` lists of fastai events. Since fastai 2.7.0 added new backward events, both lists depend on the installed fastai version.

In [None]:
#|exporti
if parse(fastai.__version__) >= parse('2.7.0'):
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()
else:
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()

In [None]:
#|exporti
if parse(fastai.__version__) >= parse('2.7.0'):
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
        after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
        after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')
else:
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
        after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
        after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")

#|hide

Next, `Callback` needs to be modified to be aware of the new event.

In [None]:
#|exporti
@patch
def __call__(self:Callback, event_name):
    "Call `self.{event_name}` if it's defined"
    _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
            (self.run_valid and not getattr(self, 'training', False)))
    res = None
    if self.run and _run:
        try: res = getattr(self, event_name, noop)()
        except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
        except Exception as e:
            e.args = [f'Exception occurred in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}']
            raise
    if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
    return res
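The `_run` expression above decides whether a callback handles an event: events outside `_inner_loop` always run, while inner-loop events such as `before_draw` respect the callback's `run_train` and `run_valid` flags. The standalone function below is a hypothetical restatement of that condition (the name `should_run` is illustrative, not part of fastai or fastxtend). It uses the pre-2.7.0 `_inner_loop` events, assumes `training` is always set, and ignores the separate `self.run` flag.

```python
# Hypothetical restatement of the `_run` condition in the patched `Callback.__call__`
INNER_LOOP = {"before_draw", "before_batch", "after_pred", "after_loss",
              "before_backward", "before_step", "after_step",
              "after_cancel_batch", "after_batch"}

def should_run(event_name: str, training: bool,
               run_train: bool = True, run_valid: bool = True) -> bool:
    "Return True if a callback with these flags would handle `event_name`."
    return (event_name not in INNER_LOOP       # outer-loop events always fire
            or (run_train and training)        # inner-loop events during training
            or (run_valid and not training))   # inner-loop events during validation

# A training-only callback skips `before_draw` during validation...
print(should_run("before_draw", training=False, run_valid=False))  # False
# ...but still receives outer-loop events such as `before_fit`
print(should_run("before_fit", training=False, run_valid=False))   # True
```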

#|hide

Then `Learner._call_one` needs to be aware of the `before_draw` event.

In [None]:
#|exporti
@patch
def _call_one(self:Learner, event_name):
    if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): cb(event_name)

#|hide

Finally, `Learner.all_batches` can be modified to call `before_draw` when iterating through a dataloader.

In [None]:
#|exporti
@patch
def all_batches(self:Learner):
    self.n_iter = len(self.dl)
    self.it = iter(self.dl)
    for i in range(self.n_iter):
        self("before_draw")
        self.one_batch(i, next(self.it))
    del(self.it)
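In the patched loop, `before_draw` fires before `next(self.it)` requests a batch, and `before_batch` fires inside `one_batch` after the batch has arrived, so the gap between the two events is time spent waiting on the dataloader. The plain-Python sketch below mimics that loop without fastai. `ToyLearner`, `DrawTimer`, and `SlowLoader` are hypothetical stand-ins, and the 10 ms sleep only simulates data loading.

```python
import time

class DrawTimer:
    "Hypothetical callback: records how long each batch takes to draw."
    def __init__(self): self.draw_times = []
    def before_draw(self): self._start = time.perf_counter()
    def before_batch(self): self.draw_times.append(time.perf_counter() - self._start)

class ToyLearner:
    "Hypothetical stand-in mirroring the patched `Learner.all_batches` loop."
    def __init__(self, dl, cbs): self.dl, self.cbs, self.events = dl, cbs, []
    def __call__(self, event_name):
        self.events.append(event_name)
        for cb in self.cbs: getattr(cb, event_name, lambda: None)()
    def one_batch(self, i, b):
        self("before_batch")   # the batch has already been drawn at this point
        self("after_batch")
    def all_batches(self):
        self.it = iter(self.dl)
        for i in range(len(self.dl)):
            self("before_draw")                # fires before the batch is requested
            self.one_batch(i, next(self.it))   # drawing happens inside next()
        del self.it

class SlowLoader:
    "Hypothetical dataloader that sleeps 10 ms per batch to simulate loading."
    def __init__(self, n): self.n = n
    def __len__(self): return self.n
    def __iter__(self):
        for i in range(self.n):
            time.sleep(0.01)
            yield i

timer = DrawTimer()
toy = ToyLearner(SlowLoader(3), [timer])
toy.all_batches()
print(toy.events[:3])   # ['before_draw', 'before_batch', 'after_batch']
print([round(t, 3) for t in timer.draw_times])
```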

In [None]:
#|export
_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',
         'Start Batch Loop', 'before_draw', 'before_batch', 'after_pred', 'after_loss', 'before_backward',
         'before_step', 'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop', 'End Train',
         'after_cancel_train', 'after_train', 'Start Valid', 'before_validate', 'Start Batch Loop',
         '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
         'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
         'after_cancel_fit', 'after_fit']

In [None]:
#|exporti
@patch
def show_training_loop(self:Learner):
    indent = 0
    for s in _loop:
        if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
        elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
        else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))

## Throughput
The Throughput profiler only measures the `step`, `draw`, and `batch` actions. To use it, both `ThroughputCallback` and `ThroughputPostCallback` must be added to the `Learner`. The recommended way to use it is via `Learner.profile`.
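The pair relies on callback order. `ThroughputCallback` uses `TrainEvalCallback.order+1`, so it starts timers such as the batch timer in `before_batch` early, while `ThroughputPostCallback` uses `Recorder.order-1`, so it stops them in `after_batch` late. Since `Learner` calls callbacks sorted by `order`, each measured interval includes the work of the callbacks ordered in between. The plain-Python sketch below illustrates this bracketing; the class names and order values are hypothetical, and the callbacks take a `state` dict for simplicity, whereas real fastai callbacks take no arguments.

```python
import time

class StartTimer:
    order = -9   # hypothetical: runs early, like ThroughputCallback
    def before_batch(self, state): state["start"] = time.perf_counter()

class SlowCallback:
    order = 0    # hypothetical: any callback ordered between the pair
    def after_batch(self, state): time.sleep(0.02)

class StopTimer:
    order = 49   # hypothetical: runs late, like ThroughputPostCallback
    def after_batch(self, state):
        state["elapsed"] = time.perf_counter() - state["start"]

def call_event(cbs, event_name, state):
    "Call `event_name` on each callback in ascending `order`, as `Learner` does."
    for cb in sorted(cbs, key=lambda cb: cb.order):
        method = getattr(cb, event_name, None)
        if method is not None: method(state)

state = {}
cbs = [StopTimer(), SlowCallback(), StartTimer()]   # registration order does not matter
call_event(cbs, "before_batch", state)
call_event(cbs, "after_batch", state)
print(f"{state['elapsed']*1000:.0f} ms")   # includes SlowCallback's 20 ms sleep
```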

In [None]:
#|exporti
_phase = ['fit', 'epoch', 'train', 'valid']
_epoch = ['train', 'valid']
_train_full  = ['step', 'draw', 'batch', 'forward', 'loss', 'backward', 'opt_step', 'zero_grad']
_valid_full  = ['step', 'draw', 'batch', 'predict', 'loss']
_train_short = ['step', 'draw', 'batch']
_valid_short = _train_short

In [None]:
#|export
class ThroughputCallback(Callback):
    """
    Adds a throughput profiler to the fastai `Learner`. Optionally showing formatted report or saving unformatted results as csv.

    Pair with ThroughputPostCallback to profile training performance.

    Post fit, access report & results via `Learner.profile_report` & `Learner.profile_results`.
    """
    order,remove_on_fetch = TrainEvalCallback.order+1,True
    def __init__(self,
        show_report:bool=True, # Display formatted report post profile
        plain:bool=False, # For Jupyter Notebooks, display plain report
        markdown:bool=False, # Display markdown formatted report
        save_csv:bool=False,  # Save raw results to csv
        csv_name:str='throughput.csv', # CSV save location
        rolling_average:int=10, # Number of batches to average throughput over
        drop_first_batch:bool=True, # Drop the first batch from profiling
        logger_callback='wandb' # Log report and samples/second to `logger_callback` using `Callback.name`
    ):
        store_attr(but='csv_name,rolling_average,drop_first_batch')
        self.csv_name = Path(csv_name)
        self._drop = int(drop_first_batch)
        self._rolling_average = rolling_average
        self._log_after_batch = getattr(self, f'_{logger_callback}_log_after_batch', noop)
        self._log_after_fit   = getattr(self, f'_{logger_callback}_log_after_fit', noop)
        self._phase, self._train, self._valid = _phase, _train_short, _valid_short

    def before_fit(self):
        self.has_logger = hasattr(self.learn, self.logger_callback) and not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds")
        self._raw_values, self._processed_samples = {}, {}
        for p in _phase:
            self._raw_values[p] = []
        for p in _epoch:
            for a in getattr(self, f'_{p}'):
                if a!='samples':
                    self._raw_values[f'{p}_{a}'] = []
            self._raw_values[f'{p}_bs'] = []
        self._fit_start = time.perf_counter()

    def before_epoch(self):
        self._epoch_start = time.perf_counter()

    def before_train(self):
        self._train_start = time.perf_counter()

    def before_validate(self):
        self._validate_start = time.perf_counter()

    def before_draw(self):
        if self.training:
            self._train_draw_start = time.perf_counter()
        else:
            self._valid_draw_start = time.perf_counter()

    def before_batch(self):
        if self.training:
            self._raw_values['train_draw'].append(time.perf_counter() - self._train_draw_start)
            self._train_batch_start = time.perf_counter()
        else:
            self._raw_values['valid_draw'].append(time.perf_counter() - self._valid_draw_start)
            self._valid_batch_start = time.perf_counter()

    def _samples_per_second(self, bs, action, epoch='train'):
        if action in ['step', 'draw']:
            batch = np.mean(self._raw_values[f'{epoch}_batch'][-self._rolling_average:])
            draw = np.mean(self._raw_values[f'{epoch}_draw'][-self._rolling_average:])
            return -((bs/batch if action=='draw' else 0) - bs/(draw+batch))
        else:
            return bs/np.mean(self._raw_values[f'{epoch}_{action}'][-self._rolling_average:])

    def _generate_report(self):
        total_time = self._raw_values['fit'][0]
        self.report = pd.DataFrame(columns=['Phase', 'Action', 'Mean Duration', 'Duration Std Dev',
                                            'Number of Calls', 'Samples/Second', 'Total Time', 'Percent of Total'])
        for p in _phase:
            if p == 'fit':
                self._append_to_df(['fit', p, 0, 0, 1, '-', total_time, f'{self._calc_percent(total_time):.0%}'])
            elif p == 'epoch':
                self._append_to_df(self._create_overview_row('fit', p, self._raw_values[p], None))
            else:
                self._append_to_df(self._create_overview_row('fit', p, self._raw_values[p], np.array(self._raw_values[f'{p}_bs'])))

        for p in _epoch:
            bs = np.array(self._raw_values[f'{p}_bs'])
            for a in getattr(self, f'_{p}'):
                if a == 'step':
                    values = np.array(self._raw_values[f'{p}_draw']) + np.array(self._raw_values[f'{p}_batch'])
                else:
                    values = np.array(self._raw_values[f'{p}_{a}'])
                self._append_to_df(self._create_detail_row(p, a, values, bs))

        self.learn.profile_results = self.report.copy()
        for c in ['Mean Duration', 'Duration Std Dev', 'Total Time']:
            self.report[c] = self.report[c].apply(scale_time)
        self.report[['Phase', 'Action']] = self.report[['Phase', 'Action']].where(~self.report[['Phase', 'Action']].duplicated(), '')
        self.report['Phase']  = self.report['Phase'].where(~self.report['Phase'].duplicated(), '')
        self.report['Action'] = self.report['Action'].where(self.report['Phase'] != self.report['Action']).fillna('')
        self.learn.profile_report = self.report

    def _display_report(self):
        if self.show_report:
            if self.markdown:
                print(self.report.to_markdown(index=False))
            else:
                if in_notebook() and not self.plain:
                    with pd.option_context('display.max_rows', len(self.report.index)):
                        s = self.report.style.set_caption("Profiling Results").hide(axis='index')
                        display(s)
                else:
                    print('Profiling Results')
                    print(self.report.to_string(index=False))
            if self._drop > 0:
                print(f'Batch dropped. train and valid phases show {self._drop} less batch than fit.')
        if self.save_csv:
            # Create the directory that will actually hold the CSV file
            csv_path = self.path/self.csv_name
            csv_path.parent.mkdir(parents=True, exist_ok=True)
            self.learn.profile_results.to_csv(csv_path, index=False)

    def _append_to_df(self, row):
        self.report.loc[len(self.report.index)] = row

    def _calc_percent(self, time):
        return time / self._raw_values['fit'][0]

    def _create_overview_row(self, phase, action, input, bs=None):
        if bs is not None:
            draw = np.array(self._raw_values[f'{action}_draw'])
            batch = np.array(self._raw_values[f'{action}_batch'])
            self._processed_samples[f'{phase}_{action}'] = bs/(draw+batch)
            sam_per_sec = f'{int(np.around(self._processed_samples[f"{phase}_{action}"].mean())):,d}'
        else:
            sam_per_sec = '-'
        return [phase, action, np.mean(input), np.std(input), len(input), sam_per_sec,
                np.sum(input), f'{self._calc_percent(np.sum(input)):.0%}']

    def _create_detail_row(self, phase, action, input, bs=None):
        input = input[self._drop:]
        if bs is None or action=='zero_grad':
            sam_per_sec = '-'
        elif action == 'draw':
            bs = np.array(bs[self._drop:])
            batch = np.array(self._raw_values[f'{phase}_batch'][self._drop:])
            self._processed_samples[f'{phase}_{action}'] = -(bs/batch - bs/(input+batch))
            sam_per_sec = f'{int(np.around(self._processed_samples[f"{phase}_{action}"].mean())):,d}'
        else:
            bs = np.array(bs[self._drop:])
            self._processed_samples[f'{phase}_{action}'] = bs/input
            sam_per_sec = f'{int(np.around(self._processed_samples[f"{phase}_{action}"].mean())):,d}'
        return [phase, action, np.mean(input), np.std(input), len(input), sam_per_sec,
                np.sum(input), f'{self._calc_percent(np.sum(input)):.0%}']

In [None]:
#|export
class ThroughputPostCallback(Callback):
    "Required pair with `ThroughputCallback` to profile training performance. Removes itself after training is over."
    order,remove_on_fetch = Recorder.order-1,True
    def __init__(self):
        self._log_full = False
        self._phase, self._train, self._valid = _phase, _train_short, _valid_short

    def before_fit(self):
        self.profiler = self.learn.throughput
        self.has_logger = self.profiler.has_logger
        self._start_train_logging, self._start_valid_logging = False, False
        self.n_train_batches = len(self.dls.train)
        self.n_valid_batches = len(self.dls.valid)
        self._rolling_average = self.profiler._rolling_average
        self._iter = -self.profiler._drop

    def after_train(self):
        self.profiler._raw_values['train'].append(time.perf_counter() - self.profiler._train_start)

    def after_validate(self):
        self.profiler._raw_values['valid'].append(time.perf_counter() - self.profiler._validate_start)

    def after_batch(self):
        if self.training:
            self.profiler._raw_values['train_batch'].append(time.perf_counter() - self.profiler._train_batch_start)
            self.profiler._raw_values['train_bs'].append(find_bs(self.learn.yb))
            if self.has_logger and self._iter >= self._rolling_average and self._iter % self._rolling_average == 0:
                self.profiler._log_after_batch(self._train)
            self._iter += 1
        else:
            self.profiler._raw_values['valid_batch'].append(time.perf_counter() - self.profiler._valid_batch_start)
            self.profiler._raw_values['valid_bs'].append(find_bs(self.learn.yb))

    def after_epoch(self):
        self.profiler._raw_values['epoch'].append(time.perf_counter() - self.profiler._epoch_start)

    def _after_fit(self, callbacks):
        self.profiler._raw_values['fit'].append(time.perf_counter() - self.profiler._fit_start)
        self.profiler._generate_report()
        if self.has_logger: self.profiler._log_after_fit()
        if not hasattr(self.learn, 'lr_finder'):
            self.profiler._display_report()
            self.learn.remove_cbs(callbacks)

    def after_fit(self):
        self._after_fit([ThroughputCallback, ThroughputPostCallback])

## Simple Profiler
To use it, both `SimpleProfilerCallback` and `SimpleProfilerPostCallback` must be added to the `Learner`. The recommended way to use it is via `Learner.profile`.

In [None]:
#|export
class SimpleProfilerCallback(ThroughputCallback):
    """
    Adds a simple profiler to the fastai `Learner`. Optionally showing formatted report or saving unformatted results as csv.

    Pair with SimpleProfilerPostCallback to profile training performance.

    Post fit, access report & results via `Learner.profile_report` & `Learner.profile_results`.
    """
    order,remove_on_fetch = TrainEvalCallback.order+1,True
    def __init__(self,
        show_report:bool=True, # Display formatted report post profile
        plain:bool=False, # For Jupyter Notebooks, display plain report
        markdown:bool=False, # Display markdown formatted report
        save_csv:bool=False,  # Save raw results to csv
        csv_name:str='simpleprofiler.csv', # CSV save location
        rolling_average:int=10, # Number of batches to average throughput over
        drop_first_batch:bool=True, # Drop the first batch from profiling
        logger_callback='wandb' # Log report and samples/second to `logger_callback` using `Callback.name`
    ):
        super().__init__(show_report=show_report, plain=plain, markdown=markdown, save_csv=save_csv,
                         csv_name=csv_name, rolling_average=rolling_average, drop_first_batch=drop_first_batch,
                         logger_callback=logger_callback)
        self._phase, self._train, self._valid = _phase, _train_full, _valid_full

    def before_backward(self):
        self._backward_start = time.perf_counter()

    def before_step(self):
        self._raw_values['train_backward'].append(time.perf_counter() - self._backward_start)
        self._step_start = time.perf_counter()

    def after_batch(self):
        if self.training:
            self._raw_values['train_zero_grad'].append(time.perf_counter() - self._zero_start)

In [None]:
#|export
class SimpleProfilerPostCallback(ThroughputPostCallback):
    "Required pair with `SimpleProfilerCallback` to profile training performance. Removes itself after training is over."
    order,remove_on_fetch = Recorder.order-1,True
    def __init__(self):
        self._log_full = True
        self._phase, self._train, self._valid = _phase, _train_full, _valid_full

    def before_fit(self):
        self.profiler = self.learn.simple_profiler
        self._start_logging = self.profiler._rolling_average + self.profiler._drop
        self.has_logger = self.profiler.has_logger
        self._start_train_logging, self._start_valid_logging = False, False
        self.n_train_batches = len(self.dls.train)
        self.n_valid_batches = len(self.dls.valid)
        # Required by `after_batch`, which is inherited from `ThroughputPostCallback`
        self._rolling_average = self.profiler._rolling_average
        self._iter = -self.profiler._drop

    def after_pred(self):
        if self.training:
            self.profiler._raw_values['train_forward'].append(time.perf_counter() - self.profiler._train_batch_start)
            self.profiler._train_loss_start = time.perf_counter()
        else:
            self.profiler._raw_values['valid_predict'].append(time.perf_counter() - self.profiler._valid_batch_start)
            self.profiler._valid_loss_start = time.perf_counter()

    def after_loss(self):
        if self.training:
            self.profiler._raw_values['train_loss'].append(time.perf_counter() - self.profiler._train_loss_start)
        else:
            self.profiler._raw_values['valid_loss'].append(time.perf_counter() - self.profiler._valid_loss_start)

    def after_step(self):
        self.profiler._raw_values['train_opt_step'].append(time.perf_counter() - self.profiler._step_start)
        self.profiler._zero_start = time.perf_counter()

    def after_fit(self):
        self._after_fit([SimpleProfilerCallback, SimpleProfilerPostCallback])

## Convenience Method
`Learner.profile` is the easy and recommended way to use a fastxtend profiler.

In [None]:
#|export
class ProfileMode(Enum):
    "Profile enum for `Learner.profile`"
    Throughput = 'throughput'
    Simple     = 'simple'

In [None]:
#|export
@patch
def profile(self:Learner,
        mode:ProfileMode=ProfileMode.Throughput, # Which profiler to use. Throughput or Simple.
        show_report:bool=True, # Display formatted report post profile
        plain:bool=False, # For Jupyter Notebooks, display plain report
        markdown:bool=False, # Display markdown formatted report
        save_csv:bool=False,  # Save raw results to csv
        csv_name:str='simpleprofiler.csv', # CSV save location
        rolling_average:int=10, # Number of batches to average throughput over
        drop_first_batch:bool=True, # Drop the first batch from profiling
        logger_callback='wandb' # Log report and samples/second to `logger_callback` using `Callback.name`
    ):
    "Run a fastxtend profiler which removes itself when finished training."
    if mode == ProfileMode.Throughput:
        self.add_cbs([ThroughputCallback(show_report=show_report, plain=plain, markdown=markdown, 
                                         save_csv=save_csv, csv_name=csv_name, rolling_average=rolling_average, 
                                         drop_first_batch=drop_first_batch, logger_callback=logger_callback),
                      ThroughputPostCallback()
                ])
    if mode == ProfileMode.Simple:
        self.add_cbs([SimpleProfilerCallback(show_report=show_report, plain=plain, markdown=markdown, 
                                             save_csv=save_csv, csv_name=csv_name, rolling_average=rolling_average, 
                                             drop_first_batch=drop_first_batch, logger_callback=logger_callback),
                      SimpleProfilerPostCallback()
                ])
    return self

## Output

The Simple Profiler report contains the following items, divided into three phases (Fit, Train, and Valid):

Fit:

- `fit`: total time taken to fit the model.
- `epoch`: duration of each epoch, covering both training and validation. The total epoch time is often nearly identical to the fit time.
- `train`: duration of each training epoch.
- `valid`: duration of each validation epoch.

Train:

- `step`: total duration of all batch steps including drawing the batch. Measured from `before_draw` to `after_batch`.
- `draw`: time spent waiting for a batch to be drawn. Measured from `before_draw` to `before_batch`. Ideally this value should be as close to zero as possible. 
- `batch`: total duration of all batch steps except drawing the batch. Measured from `before_batch` to `after_batch`.
- `forward`: duration of the forward pass and any additional batch modifications. Measured from `before_batch` to `after_pred`.
- `loss`: duration of calculating loss. Measured from `after_pred` to `after_loss`.
- `backward`: duration of the backward pass. Measured from `before_backward` to `before_step`.
- `opt_step`: duration of the optimizer step. Measured from `before_step` to `after_step`.
- `zero_grad`: duration of the `zero_grad` step. Measured from `after_step` to `after_batch`.

Valid:

- `step`: total duration of all batch steps including drawing the batch. Measured from `before_draw` to `after_batch`.
- `draw`: time spent waiting for a batch to be drawn. Measured from `before_draw` to `before_batch`. Ideally this value should be as close to zero as possible. 
- `batch`: total duration of all batch steps except drawing the batch. Measured from `before_batch` to `after_batch`.
- `predict`: duration of the prediction pass and any additional batch modifications. Measured from `before_batch` to `after_pred`.
- `loss`: duration of calculating loss. Measured from `after_pred` to `after_loss`.

The Throughput profiler only contains `step`, `draw`, and `batch`.
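For most actions, Samples/Second is the batch size divided by the action's duration. `step` divides by the draw plus batch time, and `draw` instead reports `step` throughput minus `batch` throughput. That is why `draw` is negative: it is the throughput lost while waiting on the dataloader. The sketch below approximately reproduces the train rows of the Throughput example in the Examples section from their mean durations. The profiler itself averages per-batch ratios, so its figures can differ slightly from this ratio-of-means approximation.

```python
# Mean per-batch durations (seconds) and batch size from the Throughput
# example's train rows. Ratio of means, so only an approximation.
bs = 64
draw = 0.004269   # train draw mean: 4.269 ms
batch = 0.08235   # train batch mean: 82.35 ms

batch_tp = bs / batch           # samples/second excluding data loading
step_tp = bs / (draw + batch)   # samples/second including data loading
draw_tp = step_tp - batch_tp    # throughput lost while waiting on the dataloader

print(f"step {step_tp:.0f}, draw {draw_tp:.0f}, batch {batch_tp:.0f}")
# step 739, draw -38, batch 777
```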

## Examples
These examples train `xresnext50` on Imagenette with an image size of 224 and a batch size of 64 on an RTX 3080 Ti.

In [None]:
#|hide
#|slow
from fastcore.basics import num_cpus

from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.vision.augment import Resize
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.models.xresnet import xresnext50
from fastxtend.optimizer.fused import adam
from fastxtend.metrics import *
from fastxtend.utils import *

In [None]:
#|hide
#|slow
#|cuda
imagenette = untar_data(URLs.IMAGENETTE_320)

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                    splitter=GrandparentSplitter(valid_name='val'),
                    get_items=get_image_files, get_y=parent_label,
                    item_tfms=Resize(224),
                    batch_tfms=[Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus(), pin_memory=True)

In [None]:
#|slow
#|cuda
learn = Learner(dls, xresnext50(n_out=dls.c), opt_func=adam(foreach=True),
                metrics=Accuracy()).to_channelslast().profile()
learn.fit_one_cycle(2, 3e-3)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.501953 | 1.734705 | 0.472357 | 00:18 |
| 1 | 1.040516 | 0.913281 | 0.712866 | 00:16 |

| Phase | Action | Mean Duration | Duration Std Dev | Number of Calls | Samples/Second | Total Time | Percent of Total |
|---|---|---|---|---|---|---|---|
| fit |  | - | - | 1 | - | 35.63 s | 100% |
|  | epoch | 17.81 s | 838.2ms | 2 | - | 35.63 s | 100% |
|  | train | 14.24 s | 797.1ms | 2 | 678 | 28.49 s | 80% |
|  | valid | 3.565 s | 39.48ms | 2 | 1311 | 7.130 s | 20% |
| train | step | 86.62ms | 41.67ms | 293 | 739 | 25.38 s | 71% |
|  | draw | 4.269ms | 37.39ms | 293 | -38 | 1.251 s | 4% |
|  | batch | 82.35ms | 4.472ms | 293 | 777 | 24.13 s | 68% |
| valid | step | 43.05ms | 63.38ms | 123 | 1470 | 5.295 s | 15% |
|  | draw | 14.46ms | 60.89ms | 123 | -744 | 1.779 s | 5% |
|  | batch | 28.59ms | 11.42ms | 123 | 2214 | 3.516 s | 10% |


Batch dropped. train and valid phases show 1 less batch than fit.


In [None]:
#|hide
#|slow
#|cuda
free_gpu_memory(learn)

In [None]:
#|slow
#|cuda
learn = Learner(dls, xresnext50(n_out=dls.c), opt_func=adam(foreach=True),
                metrics=Accuracy()).to_channelslast().profile(ProfileMode.Simple)
learn.fit_one_cycle(2, 3e-3)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.49755 | 2.453694 | 0.428535 | 00:17 |
| 1 | 0.997146 | 0.888791 | 0.723057 | 00:17 |

| Phase | Action | Mean Duration | Duration Std Dev | Number of Calls | Samples/Second | Total Time | Percent of Total |
|---|---|---|---|---|---|---|---|
| fit |  | - | - | 1 | - | 34.55 s | 100% |
|  | epoch | 17.27 s | 44.73ms | 2 | - | 34.54 s | 100% |
|  | train | 13.64 s | 4.756ms | 2 | 709 | 27.28 s | 79% |
|  | valid | 3.629 s | 48.68ms | 2 | 1291 | 7.259 s | 21% |
| train | step | 87.64ms | 44.58ms | 293 | 730 | 25.68 s | 74% |
|  | draw | 4.428ms | 39.70ms | 293 | -39 | 1.297 s | 4% |
|  | batch | 83.22ms | 6.353ms | 293 | 769 | 24.38 s | 71% |
|  | forward | 16.65ms | 5.732ms | 293 | 3843 | 4.880 s | 14% |
|  | loss | 771.3µs | 196.1µs | 293 | 82977 | 226.0ms | 1% |
|  | backward | 19.10ms | 5.501ms | 293 | 3351 | 5.597 s | 16% |


Batch dropped. train and valid phases show 1 less batch than fit.


## New Training Loop
The `show_training_loop` output below shows where the new `before_draw` event fits into the training loop.

In [None]:
learn = synth_learner()
learn.show_training_loop()

Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_draw    : []
         - before_batch   : [CastToTensor]
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, Pro

## Weights & Biases Logging

If Weights & Biases is installed and the [`WandbCallback`](https://docs.fast.ai/callback.wandb.html) is added to the `Learner`, the profiler callbacks automatically log training samples/second for each profiled action as wandb charts: `step`, `draw`, and `batch` for the Throughput profiler, plus `forward`, `loss`, `backward`, `opt_step`, and `zero_grad` for the Simple Profiler.

They also log two tables to the active wandb run:

* `profile_report`: formatted report from the profiler
* `profile_results`: raw results from the profiler
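Samples/second is logged periodically rather than every batch. `ThroughputPostCallback.before_fit` starts a counter at minus the number of dropped batches, and `after_batch` logs only when that counter is at least `rolling_average` and divisible by it. Because the counter is reset only in `before_fit`, it keeps counting training batches across epochs. The sketch below, using a hypothetical helper name, lists which training batches would trigger a log:

```python
def logged_batches(n_batches: int, rolling_average: int = 10, drop: int = 1) -> list[int]:
    "Hypothetical helper: indices of training batches that trigger a throughput log."
    logged, it = [], -drop   # the counter starts at -drop, as in `before_fit`
    for i in range(n_batches):
        if it >= rolling_average and it % rolling_average == 0:
            logged.append(i)
        it += 1
    return logged

print(logged_batches(35))           # [11, 21, 31]
print(logged_batches(35, drop=0))   # [10, 20, 30]
```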

## Extend to other Loggers

To extend profiler logging to other loggers, follow the Weights & Biases code below and patch `ThroughputCallback` with `_{Callback.name}_log_after_batch` and `_{Callback.name}_log_after_fit` methods, where `Callback.name` is the [name of the logger callback](https://docs.fast.ai/callback.core.html#Callback.name).

`SimpleProfilerCallback` inherits from `ThroughputCallback`, so only one set of patches is needed.

In [None]:
#|exporti
def convert_to_int(s):
    try:
        return int(s.replace(",", ""))
    except ValueError:
        return s

In [None]:
#|exports
try:
    import wandb

    @patch
    def _wandb_log_after_batch(self:ThroughputCallback, actions:list[str]):
        bs = np.mean(self._raw_values[f'train_bs'][-self._rolling_average:])
        logs = {f'throughput/{action}': self._samples_per_second(bs, action) for action in actions}
        wandb.log(logs, self.learn.wandb._wandb_step+1)

    @patch
    def _wandb_log_after_fit(self:ThroughputCallback):
        for t in self.learn.profile_results.itertuples():
            # `t._6` is the 'Samples/Second' column, which is '-' when not applicable
            if isinstance(convert_to_int(t._6), int):
                wandb.summary[f'{t.Phase}/{t.Action}_throughput'] = self._processed_samples[f'{t.Phase}_{t.Action}']

            if t.Phase == 'fit':
                # Overview rows (fit, epoch, train, valid) are stored under the action name
                values = self._raw_values[t.Action]
            else:
                # Optionally drop the first batch in the train/valid phases
                values = self._raw_values[f'{t.Phase}_{t.Action}'][self._drop:]
            wandb.summary[f'{t.Phase}/{t.Action}_duration'] = values

        report = wandb.Table(dataframe=self.learn.profile_report)
        results = wandb.Table(dataframe=self.learn.profile_results)

        wandb.log({"profile_report": report})
        wandb.log({"profile_results": results})
except ImportError:
    pass

Then, to use it, pass `logger_callback='{Callback.name}'` to `Learner.profile()`.

`ThroughputCallback` sets its `_log_after_batch` method by looking up `f'_{self.logger_callback}_log_after_batch'`, so the patched method must use exactly that name:

```python
self._log_after_batch = getattr(self, f'_{self.logger_callback}_log_after_batch', noop)
```

`ThroughputCallback._log_after_fit` is resolved the same way.
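This lookup is plain attribute resolution: `getattr` returns the patched method if one exists for the logger name and otherwise falls back to a do-nothing default, so an unrecognized logger name silently disables logging. The self-contained sketch below shows the pattern, using `setattr` in place of fastcore's `@patch`; `ToyProfiler`, `do_nothing`, and the `mylogger` name are hypothetical.

```python
def do_nothing(*args, **kwargs):
    "Hypothetical stand-in for fastcore's `noop` fallback."
    return None

class ToyProfiler:
    "Hypothetical stand-in for `ThroughputCallback`'s logger lookup."
    def __init__(self, logger_callback: str = "wandb"):
        self._log_after_batch = getattr(self, f"_{logger_callback}_log_after_batch", do_nothing)

# Attach a logger-specific method to the class, as `@patch` would
def _mylogger_log_after_batch(self, actions):
    return {f"throughput/{a}": 0.0 for a in actions}
setattr(ToyProfiler, "_mylogger_log_after_batch", _mylogger_log_after_batch)

print(ToyProfiler("mylogger")._log_after_batch(["draw"]))  # {'throughput/draw': 0.0}
print(ToyProfiler("unknown")._log_after_batch(["draw"]))   # None (falls back to do_nothing)
```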