Tensorboard logging #512
Merged
Changes shown are from 11 of the 17 commits.
Commits (all by BenjaminBossan):

- 85618e2 Factor out common functionality from PrintLog
- 27d9a6e (WIP) TensorBoard callback for automatic logging to tensorboard
- 1ea6b4f Add tensorboard as optional dependency
- 4a8c00d Fix typo in requirements-dev.txt
- a494931 Tensorboard needs past which needs future
- c2a137c SummaryWriter also needs PIL
- 289bc7b Fix typo in docstring
- 4ff9260 Add TensorBoard to MNIST-torchvision notebook
- 4c24ef7 Improve docstrings for clarity
- e9a1fd5 Merge branch 'tensorboard-logging' of https://github.com/skorch-dev/s…
- 0ae86d9 Merge branch 'master' into tensorboard-logging
- 1389062 Add (failing) test for add_graph with dict input
- 3a35b7e Merge branch 'tensorboard-logging' of https://github.com/skorch-dev/s…
- 2dcdd42 Remove add_graph for now from TensorBoard
- ec56dec Add instruction how to start tensorboard
- 2dbee3d Satisfy pylint on tensorboard install
- 28d41bd Use contextlib.suppress instead of empty catch
requirements-dev.txt (additions follow from the commits "Tensorboard needs past which needs future", "SummaryWriter also needs PIL", and "Add tensorboard as optional dependency"):

```diff
@@ -1,12 +1,15 @@
 fire
 flaky
+future>=0.17.1
 jupyter
 matplotlib>=2.0.2
 numpydoc
 openpyxl
 pandas
+pillow
 pylint
 pytest>=3.4
 pytest-cov
 sphinx
 sphinx_rtd_theme
+tensorboard>=1.14.0
```
```diff
@@ -14,7 +14,34 @@
 from skorch.callbacks import Callback


-__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar']
+__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar', 'TensorBoard']
+
+
+def filter_log_keys(keys, keys_ignored=None):
+    """Filter out keys that are generally to be ignored.
+
+    This is used by several callbacks to filter out keys from history
+    that should not be logged.
+
+    Parameters
+    ----------
+    keys : iterable of str
+      All keys.
+
+    keys_ignored : iterable of str or None (default=None)
+      If not None, collection of extra keys to be ignored.
+
+    """
+    keys_ignored = keys_ignored or ()
+    for key in keys:
+        if not (
+                key == 'epoch' or
+                (key in keys_ignored) or
+                key.endswith('_best') or
+                key.endswith('_batch_count') or
+                key.startswith('event_')
+        ):
+            yield key


 class EpochTimer(Callback):
```
```diff
@@ -62,7 +89,9 @@ class PrintLog(Callback):
     ----------
     keys_ignored : str or list of str (default=None)
       Key or list of keys that should not be part of the printed
-      table. Note that keys ending on '_best' are also ignored.
+      table. Note that in addition to the keys provided by the user,
+      keys such as those starting with 'event_' or ending on '_best'
+      are ignored by default.

     sink : callable (default=print)
       The target that the output string is sent to. By default, the
```
```diff
@@ -92,8 +121,6 @@ def __init__(
             floatfmt='.4f',
             stralign='right',
     ):
-        if isinstance(keys_ignored, str):
-            keys_ignored = [keys_ignored]
         self.keys_ignored = keys_ignored
         self.sink = sink
         self.tablefmt = tablefmt
```
```diff
@@ -102,7 +129,11 @@ def __init__(

     def initialize(self):
         self.first_iteration_ = True
-        self.keys_ignored_ = set(self.keys_ignored or [])
+
+        keys_ignored = self.keys_ignored
+        if isinstance(keys_ignored, str):
+            keys_ignored = [keys_ignored]
+        self.keys_ignored_ = set(keys_ignored or [])
         self.keys_ignored_.add('batches')
         return self
```
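Note that the string normalization moved from ``__init__`` to ``initialize``, presumably to keep ``__init__`` free of argument munging, in line with the sklearn convention that constructor arguments are stored unmodified. A sketch of just the normalization logic (the helper name is ours, not skorch's):

```python
def normalize_keys_ignored(keys_ignored):
    """Mirror the normalization PrintLog.initialize() now performs."""
    # A bare string is wrapped in a list first; otherwise set('dur')
    # would yield {'d', 'u', 'r'} instead of {'dur'}.
    if isinstance(keys_ignored, str):
        keys_ignored = [keys_ignored]
    keys_ignored_ = set(keys_ignored or [])
    keys_ignored_.add('batches')  # 'batches' is always ignored
    return keys_ignored_

print(sorted(normalize_keys_ignored('dur')))        # ['batches', 'dur']
print(sorted(normalize_keys_ignored(None)))         # ['batches']
print(sorted(normalize_keys_ignored(['a', 'b'])))   # ['a', 'b', 'batches']
```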
```diff
@@ -140,24 +171,25 @@ def _sorted_keys(self, keys):
         * all remaining keys are sorted alphabetically.
         """
         sorted_keys = []

         # make sure 'epoch' comes first
         if ('epoch' in keys) and ('epoch' not in self.keys_ignored_):
             sorted_keys.append('epoch')

-        for key in sorted(keys):
-            if not (
-                (key in ('epoch', 'dur')) or
-                (key in self.keys_ignored_) or
-                key.endswith('_best') or
-                key.endswith('_batch_count') or
-                key.startswith('event_')
-            ):
+        # ignore keys like *_best or event_*
+        for key in filter_log_keys(sorted(keys), keys_ignored=self.keys_ignored_):
+            if key != 'dur':
                 sorted_keys.append(key)

         # add event_* keys
         for key in sorted(keys):
             if key.startswith('event_') and (key not in self.keys_ignored_):
                 sorted_keys.append(key)

         # make sure 'dur' comes last
         if ('dur' in keys) and ('dur' not in self.keys_ignored_):
             sorted_keys.append('dur')

         return sorted_keys
```
```diff
@@ -321,3 +353,163 @@ def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):

     def on_epoch_end(self, net, **kwargs):
         self.pbar.close()
+
+
+def rename_tensorboard_key(key):
+    """Rename keys from history to keys in TensorBoard
+
+    Specifically, prefixes all names with "Loss/" if they seem to be
+    losses.
+
+    """
+    if key.startswith('train') or key.startswith('valid'):
+        key = 'Loss/' + key
+    return key
+
+
+class TensorBoard(Callback):
+    """Logs results from history to TensorBoard
+
+    "TensorBoard provides the visualization and tooling needed for
+    machine learning experimentation" (tensorboard_)
+
+    Use this callback to automatically log all interesting values from
+    your net's history to tensorboard after each epoch. Additionally
+    logs the graph of your module.
+
+    The best way to log additional information is to subclass this
+    callback and add your code to one of the ``on_*`` methods.
+
+    Examples
+    --------
+    >>> # Example to log the bias parameter as a histogram
+    >>> def extract_bias(module):
+    ...     return module.hidden.bias
+    >>> class MyTensorBoard(TensorBoard):
+    ...     def on_epoch_end(self, net, **kwargs):
+    ...         bias = extract_bias(net.module_)
+    ...         epoch = net.history[-1, 'epoch']
+    ...         self.writer.add_histogram('bias', bias, global_step=epoch)
+    ...         super().on_epoch_end(net, **kwargs)  # call super last
+
+    Parameters
+    ----------
+    writer : torch.utils.tensorboard.writer.SummaryWriter
+      Instantiated ``SummaryWriter`` class.
+
+    include_graph : bool (default=True)
+      Whether to include a graph of the module. Turn this off if there
+      are problems while generating the graph.
+
+    close_after_train : bool (default=True)
+      Whether to close the ``SummaryWriter`` object once training
+      finishes. Set this parameter to False if you want to continue
+      logging with the same writer or if you use it as a context
+      manager.
+
+    keys_ignored : str or list of str (default=None)
+      Key or list of keys that should not be logged to
+      tensorboard. Note that in addition to the keys provided by the
+      user, keys such as those starting with 'event_' or ending on
+      '_best' are ignored by default.
+
+    key_mapper : callable or function (default=rename_tensorboard_key)
+      This function maps a key name from the history to a tag in
+      tensorboard. This is useful because tensorboard can
+      automatically group similar tags if their names start with the
+      same prefix, followed by a forward slash. By default, this
+      callback will prefix all keys that start with "train" or "valid"
+      with the "Loss/" prefix.
+
+    .. _tensorboard: https://www.tensorflow.org/tensorboard/
+
+    """
+    def __init__(
+            self,
+            writer,
+            include_graph=True,
+            close_after_train=True,
+            keys_ignored=None,
+            key_mapper=rename_tensorboard_key,
+    ):
+        self.writer = writer
+        self.include_graph = include_graph
+        self.close_after_train = close_after_train
+        self.keys_ignored = keys_ignored
+        self.key_mapper = key_mapper
+
+    def initialize(self):
+        self.first_batch_ = True
+
+        keys_ignored = self.keys_ignored
+        if isinstance(keys_ignored, str):
+            keys_ignored = [keys_ignored]
+        self.keys_ignored_ = set(keys_ignored or [])
+        self.keys_ignored_.add('batches')
+        return self
+
+    def add_graph(self, module, X):
+        """Add a graph to tensorboard
+
+        This requires to run the module with a sample from the
+        dataset.
+
+        """
+        self.writer.add_graph(module, X)
+
+    def on_batch_begin(self, net, X, **kwargs):
+        if self.first_batch_ and self.include_graph:
+            self.add_graph(net.module_, X)
+
+    def on_batch_end(self, net, **kwargs):
+        self.first_batch_ = False
+
+    def add_scalar_maybe(self, history, key, tag, global_step=None):
+        """Add a scalar value from the history to TensorBoard
+
+        Will catch errors like missing keys or wrong value types.
+
+        Parameters
+        ----------
+        history : skorch.History
+          History object saved as attribute on the neural net.
+
+        key : str
+          Key of the desired value in the history.
+
+        tag : str
+          Name of the tag used in TensorBoard.
+
+        global_step : int or None
+          Global step value to record.
+
+        """
+        hist = history[-1]
+        val = hist.get(key)
+        if val is None:
+            return
+
+        global_step = global_step if global_step is not None else hist['epoch']
+        try:
+            self.writer.add_scalar(
+                tag=tag,
+                scalar_value=val,
+                global_step=global_step,
+            )
+        except NotImplementedError:  # pytorch raises this on wrong types
+            pass
+
+    def on_epoch_end(self, net, **kwargs):
+        """Automatically log values from the last history step."""
+        history = net.history
+        hist = history[-1]
+        epoch = hist['epoch']
+
+        for key in filter_log_keys(hist, keys_ignored=self.keys_ignored_):
+            tag = self.key_mapper(key)
+            self.add_scalar_maybe(history, key=key, tag=tag, global_step=epoch)
+
+    def on_train_end(self, net, **kwargs):
+        if self.close_after_train:
+            self.writer.close()
```

A review comment attached to the ``try``/``except`` block suggested using ``contextlib.suppress`` instead of an empty catch (adopted later in commit 28d41bd, "Use contextlib.suppress instead of empty catch"):

```
from contextlib import suppress
with suppress(NotImplementedError):
    self.writer...
```
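The default ``key_mapper`` is small enough to demonstrate standalone; TensorBoard groups tags that share a prefix before a forward slash, which is what the "Loss/" prefix exploits. The function body is copied from the diff above:

```python
def rename_tensorboard_key(key):
    """Prefix train/valid metrics so TensorBoard groups them under "Loss/"."""
    if key.startswith('train') or key.startswith('valid'):
        key = 'Loss/' + key
    return key

print(rename_tensorboard_key('train_loss'))  # Loss/train_loss
print(rename_tensorboard_key('valid_acc'))   # Loss/valid_acc
print(rename_tensorboard_key('dur'))         # dur (left untouched)
```

Note the heuristic is purely name-based: ``valid_acc`` gets the "Loss/" prefix even though it is not a loss, which is why the docstring hedges with "if they seem to be losses" and why a custom ``key_mapper`` can be passed instead.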
Review discussion on ``super().on_epoch_end(net, **kwargs)  # call super last``:

Reviewer: This does not have to be last?

BenjaminBossan: True, it doesn't have to. The reason why I say this is that sometimes, calling ``super`` first can lead to subtle errors. E.g. during ``on_batch_end``, ``first_batch_`` is set to false. If a user were to override this and call ``super`` first, then implement something that depends on ``first_batch_`` being true, it would never trigger.
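The pitfall described above can be illustrated without skorch at all. A minimal sketch with stand-in classes (the class names are ours) where calling ``super`` first silently disables first-batch logic:

```python
class BaseCallback:
    """Stand-in for the TensorBoard callback's first_batch_ bookkeeping."""
    def __init__(self):
        self.first_batch_ = True
        self.log = []

    def on_batch_end(self, **kwargs):
        self.first_batch_ = False  # the flag is cleared here


class SuperFirst(BaseCallback):
    def on_batch_end(self, **kwargs):
        super().on_batch_end(**kwargs)  # super first: flag already cleared
        if self.first_batch_:
            self.log.append('first batch')  # never reached


class SuperLast(BaseCallback):
    def on_batch_end(self, **kwargs):
        if self.first_batch_:
            self.log.append('first batch')  # reached exactly once
        super().on_batch_end(**kwargs)  # super last, as the docstring advises


buggy, good = SuperFirst(), SuperLast()
for _ in range(3):
    buggy.on_batch_end()
    good.on_batch_end()

print(buggy.log)  # []
print(good.log)   # ['first batch']
```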