Skip to content

Commit

Permalink
Add validation criterion callback hook (#235)
Browse files Browse the repository at this point in the history
* Add validation criterion callback hook

* Fix bug in decorator tests
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 23, 2018
1 parent 210b04a commit 50dd297
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_empty_methods(self):
self.assertIsNone(callback.on_forward_validation({}))
self.assertIsNone(callback.on_end_validation({}))
self.assertIsNone(callback.on_step_validation({}))
self.assertIsNone(callback.on_criterion_validation({}))


class TestCallbackList(TestCase):
Expand Down
8 changes: 7 additions & 1 deletion tests/callbacks/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def example(state):
state = 'test'
self.assertTrue(callbacks.on_forward_validation(example).on_forward_validation(state) == state)

def test_on_criterion_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_criterion_validation(example).on_criterion_validation(state) == state)

def test_on_end_validation(self):
def example(state):
return state
Expand All @@ -108,7 +114,7 @@ def example(state):
state = {'test': 'test', torchbearer.LOSS: 0}
callbacks.add_to_loss(example).on_criterion(state)
self.assertTrue(state[torchbearer.LOSS] == 1)
callbacks.add_to_loss(example).on_step_validation(state)
callbacks.add_to_loss(example).on_criterion_validation(state)
self.assertTrue(state[torchbearer.LOSS] == 2)


19 changes: 19 additions & 0 deletions torchbearer/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ def on_forward_validation(self, state):
"""
pass

def on_criterion_validation(self, state):
"""Perform some action with the given state as context after the criterion evaluation has been completed
with the validation data.
:param state: The current state dict of the :class:`Model`.
:type state: dict[str,any]
"""
pass

def on_end_validation(self, state):
"""Perform some action with the given state as context at the end of the validation loop.
Expand Down Expand Up @@ -306,6 +316,15 @@ def on_forward_validation(self, state):
"""
self._for_list(lambda callback: callback.on_forward_validation(state))

def on_criterion_validation(self, state):
"""Call on_criterion_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:type state: dict[str,any]
"""
self._for_list(lambda callback: callback.on_criterion_validation(state))

def on_end_validation(self, state):
"""Call on_end_validation on each callback in turn with the given state.
Expand Down
16 changes: 15 additions & 1 deletion torchbearer/callbacks/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,20 @@ def on_forward_validation(func):
return new_callback


def on_criterion_validation(func):
""" The :func:`on_criterion_validation` decorator is used to initialise a :class:`.Callback` with :meth:`~.Callback.on_criterion_validation`
calling the decorated function
:param func: The function(state) to *decorate*
:type func: function
:return: Initialised callback with :meth:`.Callback.on_criterion_validation` calling func
:rtype: :class:`.Callback`
"""
new_callback = Callback()
new_callback.on_criterion_validation = func
return new_callback


def on_end_validation(func):
""" The :func:`on_end_validation` decorator is used to initialise a :class:`.Callback` with :meth:`~.Callback.on_end_validation`
calling the decorated function
Expand Down Expand Up @@ -239,5 +253,5 @@ def add_to_loss_func(state):

new_callback = Callback()
new_callback.on_criterion = add_to_loss_func
new_callback.on_step_validation = add_to_loss_func
new_callback.on_criterion_validation = add_to_loss_func
return new_callback
1 change: 1 addition & 0 deletions torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _test_loop(self, state, callbacks, pass_state, batch_loader, num_steps=None)
# Loss and metrics
if torchbearer.Y_TRUE in state:
state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])
callbacks.on_criterion_validation(state)
state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state)

callbacks.on_step_validation(state)
Expand Down

0 comments on commit 50dd297

Please sign in to comment.