Skip to content

Commit

Permalink
Tensorboard logging (#512)
Browse files Browse the repository at this point in the history
* Factor out common functionality from PrintLog

.. that it will share with TensorBoard.

In addition, improve docstring and move the handling of keys_ignored
being a string to initialize method (so that init params are not
modified).

* (WIP) TensorBoard callback for automatic logging to tensorboard

There will be an update to one of the existing notebooks to use this
callback and show how to extend it.

* Add tensorboard as optional dependency

Skip TensorBoard tests if tensorboard is installed.

* Fix typo in requirements-dev.txt

* Tensorboard needs past which needs future

* SummaryWriter also needs PIL

* Fix typo in docstring

Co-Authored-By: Sergey Alexandrov <alexandrov88@gmail.com>

* Add TensorBoard to MNIST-torchvision notebook

Also demonstrates how to subclass TensorBoard callback.

* Improve docstrings for clarity

As suggested by taketwo

* Add (failing) test for add_graph with dict input

* Remove add_graph for now from TensorBoard

It works most of the time but not all of the time. Therefore, we
remove this functionality for now, so that we have at least the scalar
logging (which is more useful anyway).

Still show in the notebook how to add graphs.

* Add instruction how to start tensorboard

Addressing reviewer suggestion.

* Satisfy pylint on tensorboard install

Addresses reviewer comment

* Use contextlib.suppress instead of empty catch

Addresses reviewer comment
  • Loading branch information
BenjaminBossan authored and ottonemo committed Sep 12, 2019
1 parent 066eb39 commit b8cb53f
Show file tree
Hide file tree
Showing 10 changed files with 695 additions and 138 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- More careful check for wrong parameter names being passed to NeuralNet (#500)
- More helpful error messages when trying to predict using an uninitialized model
- Add TensorBoard callback for automatic logging to tensorboard

### Changed

Expand Down
Binary file added assets/tensorboard_digits.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/tensorboard_graph.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/tensorboard_scalars.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
- six=1.12.0
- sqlite=3.29.0
- tabulate=0.8.3
- tensorboard=1.14.0
- tk=8.6.8
- tqdm=4.32.1
- wheel=0.33.4
Expand Down
372 changes: 249 additions & 123 deletions notebooks/MNIST-torchvision.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
fire
flaky
future>=0.17.1
jupyter
matplotlib>=2.0.2
numpydoc
openpyxl
pandas
pillow
pylint
pytest>=3.4
pytest-cov
sphinx
sphinx_rtd_theme
tensorboard>=1.14.0
198 changes: 185 additions & 13 deletions skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
import time
from contextlib import suppress
from numbers import Number
from itertools import cycle

Expand All @@ -14,7 +15,34 @@
from skorch.callbacks import Callback


__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar']
__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar', 'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
"""Filter out keys that are generally to be ignored.
This is used by several callbacks to filter out keys from history
that should not be logged.
Parameters
----------
keys : iterable of str
All keys.
keys_ignored : iterable of str or None (default=None)
If not None, collection of extra keys to be ignored.
"""
keys_ignored = keys_ignored or ()
for key in keys:
if not (
key == 'epoch' or
(key in keys_ignored) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
yield key


class EpochTimer(Callback):
Expand Down Expand Up @@ -62,7 +90,9 @@ class PrintLog(Callback):
----------
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed
table. Note that keys ending on '_best' are also ignored.
table. Note that in addition to the keys provided by the user,
keys such as those starting with 'event_' or ending on '_best'
are ignored by default.
sink : callable (default=print)
The target that the output string is sent to. By default, the
Expand Down Expand Up @@ -92,8 +122,6 @@ def __init__(
floatfmt='.4f',
stralign='right',
):
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored = keys_ignored
self.sink = sink
self.tablefmt = tablefmt
Expand All @@ -102,7 +130,11 @@ def __init__(

def initialize(self):
self.first_iteration_ = True
self.keys_ignored_ = set(self.keys_ignored or [])

keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self

Expand Down Expand Up @@ -140,24 +172,25 @@ def _sorted_keys(self, keys):
* all remaining keys are sorted alphabetically.
"""
sorted_keys = []

# make sure 'epoch' comes first
if ('epoch' in keys) and ('epoch' not in self.keys_ignored_):
sorted_keys.append('epoch')

for key in sorted(keys):
if not (
(key in ('epoch', 'dur')) or
(key in self.keys_ignored_) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
# ignore keys like *_best or event_*
for key in filter_log_keys(sorted(keys), keys_ignored=self.keys_ignored_):
if key != 'dur':
sorted_keys.append(key)

# add event_* keys
for key in sorted(keys):
if key.startswith('event_') and (key not in self.keys_ignored_):
sorted_keys.append(key)

# make sure 'dur' comes last
if ('dur' in keys) and ('dur' not in self.keys_ignored_):
sorted_keys.append('dur')

return sorted_keys

def _yield_keys_formatted(self, row):
Expand Down Expand Up @@ -321,3 +354,142 @@ def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):

def on_epoch_end(self, net, **kwargs):
self.pbar.close()


def rename_tensorboard_key(key):
"""Rename keys from history to keys in TensorBoard
Specifically, prefixes all names with "Loss/" if they seem to be
losses.
"""
if key.startswith('train') or key.startswith('valid'):
key = 'Loss/' + key
return key


class TensorBoard(Callback):
"""Logs results from history to TensorBoard
"TensorBoard provides the visualization and tooling needed for
machine learning experimentation" (tensorboard_)
Use this callback to automatically log all interesting values from
your net's history to tensorboard after each epoch.
The best way to log additional information is to subclass this
callback and add your code to one of the ``on_*`` methods.
Examples
--------
>>> # Example to log the bias parameter as a histogram
>>> def extract_bias(module):
... return module.hidden.bias
>>> class MyTensorBoard(TensorBoard):
... def on_epoch_end(self, net, **kwargs):
... bias = extract_bias(net.module_)
... epoch = net.history[-1, 'epoch']
... self.writer.add_histogram('bias', bias, global_step=epoch)
... super().on_epoch_end(net, **kwargs) # call super last
Parameters
----------
writer : torch.utils.tensorboard.writer.SummaryWriter
Instantiated ``SummaryWriter`` class.
close_after_train : bool (default=True)
Whether to close the ``SummaryWriter`` object once training
finishes. Set this parameter to False if you want to continue
logging with the same writer or if you use it as a context
manager.
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to
tensorboard. Note that in addition to the keys provided by the
user, keys such as those starting with 'event_' or ending on
'_best' are ignored by default.
key_mapper : callable or function (default=rename_tensorboard_key)
This function maps a key name from the history to a tag in
tensorboard. This is useful because tensorboard can
automatically group similar tags if their names start with the
same prefix, followed by a forward slash. By default, this
callback will prefix all keys that start with "train" or "valid"
with the "Loss/" prefix.
.. _tensorboard: https://www.tensorflow.org/tensorboard/
"""
def __init__(
self,
writer,
close_after_train=True,
keys_ignored=None,
key_mapper=rename_tensorboard_key,
):
self.writer = writer
self.close_after_train = close_after_train
self.keys_ignored = keys_ignored
self.key_mapper = key_mapper

def initialize(self):
self.first_batch_ = True

keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self

def on_batch_end(self, net, **kwargs):
self.first_batch_ = False

def add_scalar_maybe(self, history, key, tag, global_step=None):
"""Add a scalar value from the history to TensorBoard
Will catch errors like missing keys or wrong value types.
Parameters
----------
history : skorch.History
History object saved as attribute on the neural net.
key : str
Key of the desired value in the history.
tag : str
Name of the tag used in TensorBoard.
global_step : int or None
Global step value to record.
"""
hist = history[-1]
val = hist.get(key)
if val is None:
return

global_step = global_step if global_step is not None else hist['epoch']
with suppress(NotImplementedError):
# pytorch raises NotImplementedError on wrong types
self.writer.add_scalar(
tag=tag,
scalar_value=val,
global_step=global_step,
)

def on_epoch_end(self, net, **kwargs):
"""Automatically log values from the last history step."""
history = net.history
hist = history[-1]
epoch = hist['epoch']

for key in filter_log_keys(hist, keys_ignored=self.keys_ignored_):
tag = self.key_mapper(key)
self.add_scalar_maybe(history, key=key, tag=tag, global_step=epoch)

def on_train_end(self, net, **kwargs):
if self.close_after_train:
self.writer.close()
Loading

0 comments on commit b8cb53f

Please sign in to comment.