Skip to content

Commit

Permalink
Feature/callback docs (#279)
Browse files Browse the repository at this point in the history
* Fix init docstrings

* Fix tensorboard init docstrings

* Formatting some docstrings

* Fix Model init docstring
  • Loading branch information
MattPainter01 authored and ethanwharris committed Aug 3, 2018
1 parent eecff42 commit b0fea52
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 253 deletions.
80 changes: 39 additions & 41 deletions torchbearer/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Callback(object):
def on_start(self, state):
"""Perform some action with the given state as context at the start of a model fit.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -19,7 +19,7 @@ def on_start(self, state):
def on_start_epoch(self, state):
"""Perform some action with the given state as context at the start of each epoch.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -28,7 +28,7 @@ def on_start_epoch(self, state):
def on_start_training(self, state):
"""Perform some action with the given state as context at the start of the training loop.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -37,7 +37,7 @@ def on_start_training(self, state):
def on_sample(self, state):
"""Perform some action with the given state as context after data has been sampled from the generator.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -46,7 +46,7 @@ def on_sample(self, state):
def on_forward(self, state):
"""Perform some action with the given state as context after the forward pass (model output) has been completed.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -55,7 +55,7 @@ def on_forward(self, state):
def on_criterion(self, state):
"""Perform some action with the given state as context after the criterion has been evaluated.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -64,7 +64,7 @@ def on_criterion(self, state):
def on_backward(self, state):
"""Perform some action with the given state as context after backward has been called on the loss.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -73,7 +73,7 @@ def on_backward(self, state):
def on_step_training(self, state):
"""Perform some action with the given state as context after step has been called on the optimiser.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -82,7 +82,7 @@ def on_step_training(self, state):
def on_end_training(self, state):
"""Perform some action with the given state as context after the training loop has completed.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -91,7 +91,7 @@ def on_end_training(self, state):
def on_end_epoch(self, state):
"""Perform some action with the given state as context at the end of each epoch.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -100,7 +100,7 @@ def on_end_epoch(self, state):
def on_end(self, state):
"""Perform some action with the given state as context at the end of the model fitting.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -109,7 +109,7 @@ def on_end(self, state):
def on_start_validation(self, state):
"""Perform some action with the given state as context at the start of the validation loop.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -118,7 +118,7 @@ def on_start_validation(self, state):
def on_sample_validation(self, state):
"""Perform some action with the given state as context after data has been sampled from the validation generator.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -128,7 +128,7 @@ def on_forward_validation(self, state):
"""Perform some action with the given state as context after the forward pass (model output) has been completed
with the validation data.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -138,7 +138,7 @@ def on_criterion_validation(self, state):
"""Perform some action with the given state as context after the criterion evaluation has been completed
with the validation data.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -147,7 +147,7 @@ def on_criterion_validation(self, state):
def on_end_validation(self, state):
"""Perform some action with the given state as context at the end of the validation loop.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -156,25 +156,23 @@ def on_end_validation(self, state):
def on_step_validation(self, state):
"""Perform some action with the given state as context at the end of each validation step.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
pass


class CallbackList(Callback):
"""The :class:`CallbackList` class is a wrapper for a list of callbacks which acts as a single callback.
"""The :class:`CallbackList` class is a wrapper for a list of callbacks which acts as a single :class:`Callback` and internally calls each :class:`Callback` in the given list in turn.
:param callback_list:The list of callbacks to be wrapped. If the list contains a :class:`CallbackList`, this will be unwrapped.
:type callback_list:list
"""

def __init__(self, callback_list):
"""Create a new callback which wraps and internally calls each callback in the given list in turn.

:param callback_list:The list of callbacks to be wrapped. If the list contains a :class:`CallbackList`, this
will be unwrapped.
:type callback_list:list
"""
super().__init__()
self.callback_list = []
self.append(callback_list)
Expand All @@ -196,7 +194,7 @@ def append(self, callback_list):
def on_start(self, state):
"""Call on_start on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -205,7 +203,7 @@ def on_start(self, state):
def on_start_epoch(self, state):
"""Call on_start_epoch on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -214,7 +212,7 @@ def on_start_epoch(self, state):
def on_start_training(self, state):
"""Call on_start_training on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -223,7 +221,7 @@ def on_start_training(self, state):
def on_sample(self, state):
"""Call on_sample on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -232,7 +230,7 @@ def on_sample(self, state):
def on_forward(self, state):
"""Call on_forward on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -241,7 +239,7 @@ def on_forward(self, state):
def on_criterion(self, state):
"""Call on_criterion on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -250,7 +248,7 @@ def on_criterion(self, state):
def on_backward(self, state):
"""Call on_backward on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -259,7 +257,7 @@ def on_backward(self, state):
def on_step_training(self, state):
"""Call on_step_training on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -268,7 +266,7 @@ def on_step_training(self, state):
def on_end_training(self, state):
"""Call on_end_training on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -277,7 +275,7 @@ def on_end_training(self, state):
def on_end_epoch(self, state):
"""Call on_end_epoch on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -286,7 +284,7 @@ def on_end_epoch(self, state):
def on_end(self, state):
"""Call on_end on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -295,7 +293,7 @@ def on_end(self, state):
def on_start_validation(self, state):
"""Call on_start_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -304,7 +302,7 @@ def on_start_validation(self, state):
def on_sample_validation(self, state):
"""Call on_sample_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -313,7 +311,7 @@ def on_sample_validation(self, state):
def on_forward_validation(self, state):
"""Call on_forward_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -322,7 +320,7 @@ def on_forward_validation(self, state):
def on_criterion_validation(self, state):
"""Call on_criterion_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -331,7 +329,7 @@ def on_criterion_validation(self, state):
def on_end_validation(self, state):
"""Call on_end_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand All @@ -340,7 +338,7 @@ def on_end_validation(self, state):
def on_step_validation(self, state):
"""Call on_step_validation on each callback in turn with the given state.
:param state: The current state dict of the :class:`Model`.
:param state: The current state dict of the :class:`.Model`.
:type state: dict[str,any]
"""
Expand Down

0 comments on commit b0fea52

Please sign in to comment.