Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorboard logging #512

Merged
merged 17 commits into from Sep 12, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- More careful check for wrong parameter names being passed to NeuralNet (#500)
- More helpful error messages when trying to predict using an uninitialized model
- Add TensorBoard callback for automatic logging to tensorboard

### Changed

Expand Down
Binary file added assets/tensorboard_digits.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/tensorboard_scalars.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions environment.yml
Expand Up @@ -27,6 +27,7 @@ dependencies:
- six=1.12.0
- sqlite=3.29.0
- tabulate=0.8.3
- tensorboard=1.14.0
- tk=8.6.8
- tqdm=4.32.1
- wheel=0.33.4
Expand Down
260 changes: 167 additions & 93 deletions notebooks/MNIST-torchvision.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements-dev.txt
@@ -1,12 +1,15 @@
fire
flaky
future>=0.17.1
jupyter
matplotlib>=2.0.2
numpydoc
openpyxl
pandas
pillow
pylint
pytest>=3.4
pytest-cov
sphinx
sphinx_rtd_theme
tensorboard>=1.14.0
218 changes: 205 additions & 13 deletions skorch/callbacks/logging.py
Expand Up @@ -14,7 +14,34 @@
from skorch.callbacks import Callback


__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar']
__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar', 'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
    """Filter out keys that are generally to be ignored.

    This is used by several callbacks to filter out keys from history
    that should not be logged.

    Parameters
    ----------
    keys : iterable of str
      All keys.

    keys_ignored : iterable of str or None (default=None)
      If not None, collection of extra keys to be ignored.

    """
    ignored = set(keys_ignored) if keys_ignored else set()

    def _keep(key):
        # Drop bookkeeping keys: the epoch counter, anything the caller
        # asked to skip, and auto-generated *_best / *_batch_count /
        # event_* entries.
        if key == 'epoch' or key in ignored:
            return False
        if key.endswith('_best') or key.endswith('_batch_count'):
            return False
        return not key.startswith('event_')

    return (key for key in keys if _keep(key))


class EpochTimer(Callback):
Expand Down Expand Up @@ -62,7 +89,9 @@ class PrintLog(Callback):
----------
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed
table. Note that keys ending on '_best' are also ignored.
table. Note that in addition to the keys provided by the user,
keys such as those starting with 'event_' or ending on '_best'
are ignored by default.

sink : callable (default=print)
The target that the output string is sent to. By default, the
Expand Down Expand Up @@ -92,8 +121,6 @@ def __init__(
floatfmt='.4f',
stralign='right',
):
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored = keys_ignored
self.sink = sink
self.tablefmt = tablefmt
Expand All @@ -102,7 +129,11 @@ def __init__(

def initialize(self):
    """Reset per-fit state: the header flag and the resolved ignore-set."""
    self.first_iteration_ = True

    ignored = self.keys_ignored
    # A single string is treated as a one-element collection.
    if isinstance(ignored, str):
        ignored = [ignored]
    # 'batches' is always suppressed since it holds per-batch sub-records.
    self.keys_ignored_ = {'batches'}
    self.keys_ignored_.update(ignored or [])
    return self

Expand Down Expand Up @@ -140,24 +171,25 @@ def _sorted_keys(self, keys):
* all remaining keys are sorted alphabetically.
"""
sorted_keys = []

# make sure 'epoch' comes first
if ('epoch' in keys) and ('epoch' not in self.keys_ignored_):
sorted_keys.append('epoch')

for key in sorted(keys):
if not (
(key in ('epoch', 'dur')) or
(key in self.keys_ignored_) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
# ignore keys like *_best or event_*
for key in filter_log_keys(sorted(keys), keys_ignored=self.keys_ignored_):
if key != 'dur':
sorted_keys.append(key)

# add event_* keys
for key in sorted(keys):
if key.startswith('event_') and (key not in self.keys_ignored_):
sorted_keys.append(key)

# make sure 'dur' comes last
if ('dur' in keys) and ('dur' not in self.keys_ignored_):
sorted_keys.append('dur')

return sorted_keys

def _yield_keys_formatted(self, row):
Expand Down Expand Up @@ -321,3 +353,163 @@ def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):

def on_epoch_end(self, net, **kwargs):
self.pbar.close()


def rename_tensorboard_key(key):
    """Rename keys from history to keys in TensorBoard

    Specifically, prefixes all names with "Loss/" if they seem to be
    losses.

    """
    # History entries beginning with 'train' or 'valid' are loss-like
    # and are grouped under the "Loss/" tag prefix in TensorBoard.
    return 'Loss/' + key if key.startswith(('train', 'valid')) else key


class TensorBoard(Callback):
"""Logs results from history to TensorBoard

"TensorBoard provides the visualization and tooling needed for
machine learning experimentation" (tensorboard_)

Use this callback to automatically log all interesting values from
your net's history to tensorboard after each epoch. Additionally
logs the graph of your module.

The best way to log additional information is to subclass this
callback and add your code to one of the ``on_*`` methods.

Examples
--------
>>> # Example to log the bias parameter as a histogram
>>> def extract_bias(module):
... return module.hidden.bias

>>> class MyTensorBoard(TensorBoard):
... def on_epoch_end(self, net, **kwargs):
... bias = extract_bias(net.module_)
... epoch = net.history[-1, 'epoch']
... self.writer.add_histogram('bias', bias, global_step=epoch)
... super().on_epoch_end(net, **kwargs) # call super last
    Note: ``super().on_epoch_end`` does not strictly have to be called
    last, but calling super first can lead to subtle errors. For
    example, ``on_batch_end`` sets ``first_batch_`` to False; a subclass
    that calls super first and then implements something that depends on
    ``first_batch_`` being True would never trigger.

Parameters
----------
writer : torch.utils.tensorboard.writer.SummaryWriter
Instantiated ``SummaryWriter`` class.

include_graph : bool (default=True)
Whether to include a graph of the module. Turn this off if there
are problems while generating the graph.

close_after_train : bool (default=True)
Whether to close the ``SummaryWriter`` object once training
finishes. Set this parameter to False if you want to continue
logging with the same writer or if you use it as a context
manager.

keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to
tensorboard. Note that in addition to the keys provided by the
user, keys such as those starting with 'event_' or ending on
'_best' are ignored by default.

key_mapper : callable or function (default=rename_tensorboard_key)
This function maps a key name from the history to a tag in
tensorboard. This is useful because tensorboard can
automatically group similar tags if their names start with the
same prefix, followed by a forward slash. By default, this
callback will prefix all keys that start with "train" or "valid"
with the "Loss/" prefix.

.. _tensorboard: https://www.tensorflow.org/tensorboard/

"""
def __init__(
        self,
        writer,
        include_graph=True,
        close_after_train=True,
        keys_ignored=None,
        key_mapper=rename_tensorboard_key,
):
    # Store constructor arguments unmodified; derived state (e.g. the
    # resolved ignore-set) is computed later in ``initialize``.
    self.writer = writer
    self.include_graph = include_graph
    self.close_after_train = close_after_train
    self.keys_ignored = keys_ignored
    self.key_mapper = key_mapper

def initialize(self):
    """Prepare per-fit state: the first-batch flag and the ignore-set."""
    self.first_batch_ = True

    ignored = self.keys_ignored
    if isinstance(ignored, str):
        # Accept a bare string as shorthand for a one-key collection.
        ignored = [ignored]
    # 'batches' holds nested per-batch records and is never logged.
    self.keys_ignored_ = {'batches', *(ignored or ())}
    return self

def add_graph(self, module, X):
    """Add a graph to tensorboard

    This requires to run the module with a sample from the
    dataset.

    """
    # Delegates to SummaryWriter.add_graph, which traces ``module``
    # using the sample batch ``X``.
    self.writer.add_graph(module, X)

def on_batch_begin(self, net, X, **kwargs):
    """On the very first batch, optionally log the module graph."""
    # Graph tracing needs a concrete input sample, so it happens here
    # rather than at train begin; it runs at most once per fit.
    if self.include_graph and self.first_batch_:
        self.add_graph(net.module_, X)

def on_batch_end(self, net, **kwargs):
    # Once any batch has completed, graph logging is disabled for the
    # rest of training (see ``on_batch_begin``).
    self.first_batch_ = False

def add_scalar_maybe(self, history, key, tag, global_step=None):
    """Add a scalar value from the history to TensorBoard

    Will catch errors like missing keys or wrong value types.

    Parameters
    ----------
    history : skorch.History
      History object saved as attribute on the neural net.

    key : str
      Key of the desired value in the history.

    tag : str
      Name of the tag used in TensorBoard.

    global_step : int or None
      Global step value to record.

    """
    hist = history[-1]
    val = hist.get(key)
    if val is None:
        # Missing key in the last history step: nothing to log.
        return

    # Default to the epoch number so scalars line up on the x axis.
    global_step = global_step if global_step is not None else hist['epoch']
    try:
        self.writer.add_scalar(
            tag=tag,
            scalar_value=val,
            global_step=global_step,
        )
    except NotImplementedError:  # pytorch raises this on wrong types
        pass

def on_epoch_end(self, net, **kwargs):
    """Automatically log values from the last history step."""
    history = net.history
    hist = history[-1]
    epoch = hist['epoch']

    # Map each loggable history key to a TensorBoard tag and record it.
    for key in filter_log_keys(hist, keys_ignored=self.keys_ignored_):
        self.add_scalar_maybe(
            history, key=key, tag=self.key_mapper(key), global_step=epoch)

def on_train_end(self, net, **kwargs):
    # Release the writer's resources unless the user wants to keep
    # logging with the same writer after training finishes.
    if self.close_after_train:
        self.writer.close()