Skip to content

Commit

Permalink
Feature/tensorboard list and dict (#563)
Browse files Browse the repository at this point in the history
* Add handling of list and dicts to tensorboard

* Fix test

* Update changelog

* Add silent fail test

* Add warning messages when failing to log

* Add warning messages when failing to log
  • Loading branch information
MattPainter01 committed Jun 11, 2019
1 parent 50caf34 commit 924336d
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
- Removed the variational sub-package, this will now be packaged separately
### Fixed
- Fixed a bug where list or dictionary metrics would cause the tensorboard callback to error

## [0.3.2] - 2019-05-28
### Added
Expand Down
166 changes: 165 additions & 1 deletion tests/callbacks/test_tensor_board.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,179 @@
import os
from unittest import TestCase
import warnings

import torch
import torch.nn as nn
from mock import patch, Mock, ANY
from mock import patch, Mock, ANY, MagicMock

import torchbearer
from torchbearer.callbacks import TensorBoard, TensorBoardImages, TensorBoardProjector, TensorBoardText


class TestTensorBoard(TestCase):

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_single(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': 1, 'test2': [1, 2, 3], 'test3': [[1], [2], [3, 4]]}}
tb.add_metric(fn_test(NotImplementedError, [list]), 'single', state[torchbearer.METRICS]['test'])

self.assertTrue(mock_fn.call_args_list[0][0] == ('single', 1))

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_list(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': 1, 'test2': [1, 2, 3], 'test3': [[1], [2], [3, 4]]}}
tb.add_metric(fn_test(NotImplementedError, [list]), 'single', state[torchbearer.METRICS]['test2'])

self.assertTrue(mock_fn.call_args_list[0][0] == ('single_0', 1))
self.assertTrue(mock_fn.call_args_list[1][0] == ('single_1', 2))
self.assertTrue(mock_fn.call_args_list[2][0] == ('single_2', 3))


@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_list_of_list(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': 1, 'test2': [1, 2, 3], 'test3': [[1], 2, [3, 4]]}}
tb.add_metric(fn_test(NotImplementedError, [list]), 'single', state[torchbearer.METRICS]['test3'])

self.assertTrue(mock_fn.call_args_list[0][0] == ('single_0_0', 1))
self.assertTrue(mock_fn.call_args_list[1][0] == ('single_1', 2))
self.assertTrue(mock_fn.call_args_list[2][0] == ('single_2_0', 3))
self.assertTrue(mock_fn.call_args_list[3][0] == ('single_2_1', 4))

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_dict(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': {'key1': 2, 'key2': 3}}}
tb.add_metric(fn_test(NotImplementedError, [list, dict]), 'single', state[torchbearer.METRICS]['test'])

call_args = list(mock_fn.call_args_list)
call_args.sort()
self.assertTrue(call_args[0][0] == ('single_key1', 2))
self.assertTrue(call_args[1][0] == ('single_key2', 3))

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_dict_and_list(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': {'key1': 2, 'key2': [3, 4]}}}
tb.add_metric(fn_test(NotImplementedError, [list, dict]), 'single', state[torchbearer.METRICS]['test'])

call_args = list(mock_fn.call_args_list)
call_args.sort()
self.assertTrue(call_args[0][0] == ('single_key1', 2))
self.assertTrue(call_args[1][0] == ('single_key2_0', 3))
self.assertTrue(call_args[2][0] == ('single_key2_1', 4))

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_fail_iterable(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': 0.1}}
with warnings.catch_warnings(record=True) as w:
tb.add_metric(fn_test(NotImplementedError, [list, dict, float]), 'single', state[torchbearer.METRICS]['test'])
self.assertTrue(len(w) == 1)

call_args = list(mock_fn.call_args_list)
call_args.sort()
self.assertTrue(len(call_args) == 0)

@patch('tensorboardX.SummaryWriter')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
@patch('torchbearer.callbacks.tensor_board.os.makedirs')
def test_add_metric_fail(self, _, __, writer):
mock_fn = MagicMock()

def fn_test(ex, types):
def fn_test_1(tag, metric, *args, **kwargs):
if type(metric) in types:
raise ex
else:
mock_fn(tag, metric)
return fn_test_1

tb = TensorBoard()
state = {torchbearer.METRICS: {'test': 0.1}}
with warnings.catch_warnings(record=True) as w:
tb.add_metric(fn_test(Exception, [float]), 'single', state[torchbearer.METRICS]['test'])
self.assertTrue(len(w) == 1)

call_args = list(mock_fn.call_args_list)
call_args.sort()
self.assertTrue(len(call_args) == 0)


@patch('tensorboardX.SummaryWriter')
@patch('visdom.Visdom')
@patch('torchbearer.callbacks.tensor_board.os.path.isdir')
Expand Down
29 changes: 23 additions & 6 deletions torchbearer/callbacks/tensor_board.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
import warnings

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -164,6 +165,22 @@ def on_start(self, state):
self.log_dir = os.path.join(self.log_dir, state[torchbearer.MODEL].__class__.__name__ + '_' + self.comment)
self.writer = self.get_writer(visdom=self.visdom, visdom_params=self.visdom_params)

@staticmethod
def add_metric(add_fn, tag, metric, *args, **kwargs):
try:
add_fn(tag, metric, *args, **kwargs)
except NotImplementedError:
try:
for key, met in enumerate(metric):
if isinstance(metric, dict):
key, met = met, metric[met]

AbstractTensorBoard.add_metric(add_fn, tag+'_{}'.format(key), met, *args, **kwargs)
except TypeError as e:
warnings.warn('Failed to log metric to tensorboard with error: {}'.format(e))
except Exception as e:
warnings.warn('Failed to log metric to tensorboard with error: {}'.format(e))

def on_end(self, state):
self.close_writer()

Expand Down Expand Up @@ -227,21 +244,21 @@ def on_step_training(self, state):
if self.write_batch_metrics and state[torchbearer.BATCH] % self.batch_step_size == 0:
for metric in state[torchbearer.METRICS]:
if self.visdom:
self.batch_writer.add_scalar(metric, state[torchbearer.METRICS][metric],
self.add_metric(self.batch_writer.add_scalar, metric, state[torchbearer.METRICS][metric],
state[torchbearer.EPOCH] * state[torchbearer.TRAIN_STEPS] + state[
torchbearer.BATCH], main_tag='batch')
else:
self.batch_writer.add_scalar('batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])
self.add_metric(self.batch_writer.add_scalar, 'batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])

def on_step_validation(self, state):
if self.write_batch_metrics and state[torchbearer.BATCH] % self.batch_step_size == 0:
for metric in state[torchbearer.METRICS]:
if self.visdom:
self.batch_writer.add_scalar(metric, state[torchbearer.METRICS][metric],
self.add_metric(self.batch_writer.add_scalar, metric, state[torchbearer.METRICS][metric],
state[torchbearer.EPOCH] * state[torchbearer.TRAIN_STEPS] + state[
torchbearer.BATCH], main_tag='batch')
else:
self.batch_writer.add_scalar('batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])
self.add_metric(self.batch_writer.add_scalar, 'batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])

def on_end_epoch(self, state):
if self.write_batch_metrics and not self.visdom:
Expand All @@ -250,10 +267,10 @@ def on_end_epoch(self, state):
if self.write_epoch_metrics:
for metric in state[torchbearer.METRICS]:
if self.visdom:
self.writer.add_scalar(metric, state[torchbearer.METRICS][metric], state[torchbearer.EPOCH],
self.add_metric(self.writer.add_scalar, metric, state[torchbearer.METRICS][metric], state[torchbearer.EPOCH],
main_tag='epoch')
else:
self.writer.add_scalar('epoch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.EPOCH])
self.add_metric(self.writer.add_scalar, 'epoch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.EPOCH])

def on_end(self, state):
super(TensorBoard, self).on_end(state)
Expand Down

0 comments on commit 924336d

Please sign in to comment.