Skip to content
Permalink
Browse files

Tensorboard logging (#512)

* Factor out common functionality from PrintLog

.. that it will share with TensorBoard.

In addition, improve docstring and move the handling of keys_ignored
being a string to initialize method (so that init params are not
modified).

* (WIP) TensorBoard callback for automatic logging to tensorboard

There will be an update to one of the existing notebooks to use this
callback and show how to extend it.

* Add tensorboard as optional dependency

Skip TensorBoard tests if tensorboard is installed.

* Fix typo in requirements-dev.txt

* Tensorboard needs past which needs future

* SummaryWriter also needs PIL

* Fix typo in docstring

Co-Authored-By: Sergey Alexandrov <alexandrov88@gmail.com>

* Add TensorBoard to MNIST-torchvision notebook

Also demonstrates how to subclass TensorBoard callback.

* Improve docstrings for clarity

As suggested by taketwo

* Add (failing) test for add_graph with dict input

* Remove add_graph for now from TensorBoard

It works most of the time but not all of the time. Therefore, we
remove this functionality for now, so that we have at least the scalar
logging (which is more useful anyway).

Still show in the notebook how to add graphs.

* Add instruction how to start tensorboard

Addressing reviewer suggestion.

* Satisfy pylint on tensorboard install

Addresses reviewer comment

* Use contextlib.suppress instead of empty catch

Addresses reviewer comment
  • Loading branch information...
BenjaminBossan authored and ottonemo committed Sep 12, 2019
1 parent 066eb39 commit b8cb53fae69ef2d3d8eb591af525ca54354a985a
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- More careful check for wrong parameter names being passed to NeuralNet (#500)
- More helpful error messages when trying to predict using an uninitialized model
- Add TensorBoard callback for automatic logging to tensorboard

### Changed

Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -27,6 +27,7 @@ dependencies:
- six=1.12.0
- sqlite=3.29.0
- tabulate=0.8.3
- tensorboard=1.14.0
- tk=8.6.8
- tqdm=4.32.1
- wheel=0.33.4

Large diffs are not rendered by default.

@@ -1,12 +1,15 @@
fire
flaky
future>=0.17.1
jupyter
matplotlib>=2.0.2
numpydoc
openpyxl
pandas
pillow
pylint
pytest>=3.4
pytest-cov
sphinx
sphinx_rtd_theme
tensorboard>=1.14.0
@@ -2,6 +2,7 @@

import sys
import time
from contextlib import suppress
from numbers import Number
from itertools import cycle

@@ -14,7 +15,34 @@
from skorch.callbacks import Callback


__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar']
__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar', 'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
"""Filter out keys that are generally to be ignored.
This is used by several callbacks to filter out keys from history
that should not be logged.
Parameters
----------
keys : iterable of str
All keys.
keys_ignored : iterable of str or None (default=None)
If not None, collection of extra keys to be ignored.
"""
keys_ignored = keys_ignored or ()
for key in keys:
if not (
key == 'epoch' or
(key in keys_ignored) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
yield key


class EpochTimer(Callback):
@@ -62,7 +90,9 @@ class PrintLog(Callback):
----------
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed
table. Note that keys ending on '_best' are also ignored.
table. Note that in addition to the keys provided by the user,
keys such as those starting with 'event_' or ending on '_best'
are ignored by default.
sink : callable (default=print)
The target that the output string is sent to. By default, the
@@ -92,8 +122,6 @@ def __init__(
floatfmt='.4f',
stralign='right',
):
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored = keys_ignored
self.sink = sink
self.tablefmt = tablefmt
@@ -102,7 +130,11 @@ def __init__(

def initialize(self):
self.first_iteration_ = True
self.keys_ignored_ = set(self.keys_ignored or [])

keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self

@@ -140,24 +172,25 @@ def _sorted_keys(self, keys):
* all remaining keys are sorted alphabetically.
"""
sorted_keys = []

# make sure 'epoch' comes first
if ('epoch' in keys) and ('epoch' not in self.keys_ignored_):
sorted_keys.append('epoch')

for key in sorted(keys):
if not (
(key in ('epoch', 'dur')) or
(key in self.keys_ignored_) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
# ignore keys like *_best or event_*
for key in filter_log_keys(sorted(keys), keys_ignored=self.keys_ignored_):
if key != 'dur':
sorted_keys.append(key)

# add event_* keys
for key in sorted(keys):
if key.startswith('event_') and (key not in self.keys_ignored_):
sorted_keys.append(key)

# make sure 'dur' comes last
if ('dur' in keys) and ('dur' not in self.keys_ignored_):
sorted_keys.append('dur')

return sorted_keys

def _yield_keys_formatted(self, row):
@@ -321,3 +354,142 @@ def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):

def on_epoch_end(self, net, **kwargs):
self.pbar.close()


def rename_tensorboard_key(key):
"""Rename keys from history to keys in TensorBoard
Specifically, prefixes all names with "Loss/" if they seem to be
losses.
"""
if key.startswith('train') or key.startswith('valid'):
key = 'Loss/' + key
return key


class TensorBoard(Callback):
"""Logs results from history to TensorBoard
"TensorBoard provides the visualization and tooling needed for
machine learning experimentation" (tensorboard_)
Use this callback to automatically log all interesting values from
your net's history to tensorboard after each epoch.
The best way to log additional information is to subclass this
callback and add your code to one of the ``on_*`` methods.
Examples
--------
>>> # Example to log the bias parameter as a histogram
>>> def extract_bias(module):
... return module.hidden.bias
>>> class MyTensorBoard(TensorBoard):
... def on_epoch_end(self, net, **kwargs):
... bias = extract_bias(net.module_)
... epoch = net.history[-1, 'epoch']
... self.writer.add_histogram('bias', bias, global_step=epoch)
... super().on_epoch_end(net, **kwargs) # call super last
Parameters
----------
writer : torch.utils.tensorboard.writer.SummaryWriter
Instantiated ``SummaryWriter`` class.
close_after_train : bool (default=True)
Whether to close the ``SummaryWriter`` object once training
finishes. Set this parameter to False if you want to continue
logging with the same writer or if you use it as a context
manager.
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to
tensorboard. Note that in addition to the keys provided by the
user, keys such as those starting with 'event_' or ending on
'_best' are ignored by default.
key_mapper : callable or function (default=rename_tensorboard_key)
This function maps a key name from the history to a tag in
tensorboard. This is useful because tensorboard can
automatically group similar tags if their names start with the
same prefix, followed by a forward slash. By default, this
callback will prefix all keys that start with "train" or "valid"
with the "Loss/" prefix.
.. _tensorboard: https://www.tensorflow.org/tensorboard/
"""
def __init__(
self,
writer,
close_after_train=True,
keys_ignored=None,
key_mapper=rename_tensorboard_key,
):
self.writer = writer
self.close_after_train = close_after_train
self.keys_ignored = keys_ignored
self.key_mapper = key_mapper

def initialize(self):
self.first_batch_ = True

keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self

def on_batch_end(self, net, **kwargs):
self.first_batch_ = False

def add_scalar_maybe(self, history, key, tag, global_step=None):
"""Add a scalar value from the history to TensorBoard
Will catch errors like missing keys or wrong value types.
Parameters
----------
history : skorch.History
History object saved as attribute on the neural net.
key : str
Key of the desired value in the history.
tag : str
Name of the tag used in TensorBoard.
global_step : int or None
Global step value to record.
"""
hist = history[-1]
val = hist.get(key)
if val is None:
return

global_step = global_step if global_step is not None else hist['epoch']
with suppress(NotImplementedError):
# pytorch raises NotImplementedError on wrong types
self.writer.add_scalar(
tag=tag,
scalar_value=val,
global_step=global_step,
)

def on_epoch_end(self, net, **kwargs):
"""Automatically log values from the last history step."""
history = net.history
hist = history[-1]
epoch = hist['epoch']

for key in filter_log_keys(hist, keys_ignored=self.keys_ignored_):
tag = self.key_mapper(key)
self.add_scalar_maybe(history, key=key, tag=tag, global_step=epoch)

def on_train_end(self, net, **kwargs):
if self.close_after_train:
self.writer.close()

0 comments on commit b8cb53f

Please sign in to comment.
You can’t perform that action at this time.