Skip to content

Commit

Permalink
Fix/replay validation (#520)
Browse files Browse the repository at this point in the history
* Fix none steps validation bug

* Update changelog
  • Loading branch information
MattPainter01 committed Mar 1, 2019
1 parent 97c58a0 commit 8c7ad16
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
### Changed
### Deprecated
### Removed
### Fixed
- Fixed bug where replay errored when train or val steps were None

## [0.3.0] - 2019-02-28
### Added
- Added torchbearer.variational, a sub-package for implementations of state of the art variational auto-encoders
Expand Down
22 changes: 22 additions & 0 deletions tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,28 @@ def test_replay_callback_calls(self):
self.assertTrue(callback.on_sample.call_count == 100)
self.assertTrue(callback.on_sample_validation.call_count == 50)

def test_replay_none_train_steps(self):
t = Trial(MagicMock())
callback = MagicMock()
history = [((None, 5), {'test': i, 'val_test2': i+1}) for i in range(10)]

t.state[tb.HISTORY] = history
t.replay(callbacks=[callback], verbose=0)
self.assertEqual(callback.on_start.call_count, 1)
self.assertTrue(callback.on_sample.call_count == 0)
self.assertTrue(callback.on_sample_validation.call_count == 50)

def test_replay_none_validation_steps(self):
t = Trial(MagicMock())
callback = MagicMock()
history = [((10, None), {'test': i}) for i in range(10)]

t.state[tb.HISTORY] = history
t.replay(callbacks=[callback], verbose=0)
self.assertEqual(callback.on_start.call_count, 1)
self.assertTrue(callback.on_sample.call_count == 100)
self.assertTrue(callback.on_sample_validation.call_count == 0)

def test_replay_one_batch_true(self):
t = Trial(MagicMock())
callback = MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions torchbearer/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def _replay_pass(self, state, callback_list):
all_metrics = state[torchbearer.METRICS]

# Training pass
state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS]
state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS] if state[torchbearer.TRAIN_STEPS] is not None else 0
state[torchbearer.METRICS] = {key: all_metrics[key] for key in all_metrics.keys() if "val_" not in key}
callback_list.on_start_training(state)
for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
Expand All @@ -947,7 +947,7 @@ def _replay_pass(self, state, callback_list):

# Validation pass
if not state[torchbearer.STOP_TRAINING]:
state[torchbearer.STEPS] = state[torchbearer.VALIDATION_STEPS]
state[torchbearer.STEPS] = state[torchbearer.VALIDATION_STEPS] if state[torchbearer.VALIDATION_STEPS] is not None else 0
state[torchbearer.METRICS] = {key: all_metrics[key] for key in all_metrics.keys() if "val_" in key}
callback_list.on_start_validation(state)
for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
Expand Down

0 comments on commit 8c7ad16

Please sign in to comment.