Skip to content

Commit

Permalink
Feature/timer callback (#252)
Browse files Browse the repository at this point in the history
* Add timer callback

* Add timings to state after each stage

* Remove unnecessary stages

* Update changelog

* Update timer test
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 25, 2018
1 parent 13c755f commit 7c88f1a
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added a on_validation_criterion callback hook
- Added a DatasetValidationSplitter which can be used to create a validation split if required for datasets like Cifar10 or MNIST
- Added simple timer callback
### Changed
### Deprecated
### Removed
Expand Down
37 changes: 37 additions & 0 deletions tests/callbacks/test_timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from unittest import TestCase
from unittest.mock import patch, Mock

from torchbearer.callbacks import TimerCallback


class TestTimer(TestCase):
@patch('time.time')
def test_update_time(self, time):
time.return_value = 0
timer = TimerCallback()
time.return_value = 1
timer.update_time('test', {})
self.assertTrue(timer.get_timings()['test'] == 1)

time.return_value = 3
timer.update_time('test_2', {})
self.assertTrue(timer.get_timings()['test_2'] == 2)

def test_calls(self):
timer = TimerCallback()
timer.update_time = Mock()

timer.on_start({})
timer.on_start_training({})
timer.on_start_epoch({})
timer.on_sample({})
timer.on_forward({})
timer.on_criterion({})
timer.on_backward({})
timer.on_step_training({})
timer.on_start_validation({})
timer.on_sample_validation({})
timer.on_forward_validation({})
timer.on_criterion_validation({})
timer.on_step_validation({})
self.assertTrue(timer.update_time.call_count == 13)
5 changes: 5 additions & 0 deletions torchbearer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
:members:
:undoc-members:
.. automodule:: torchbearer.callbacks.timer
:members:
:undoc-members:
Tensorboard
------------------------------------
Expand Down Expand Up @@ -80,3 +84,4 @@
from .weight_decay import *
from .aggregate_predictions import *
from .decorators import *
from .timer import *
74 changes: 74 additions & 0 deletions torchbearer/callbacks/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time
from torchbearer.callbacks import Callback
import torchbearer


class TimerCallback(Callback):
def __init__(self):
""" Timer callback that aggregates timings for each stage of model execution
"""
super().__init__()
self.t0 = time.time()
self.time_dict = {}

def update_time(self, text, state):
self.time_dict[text] = time.time() - self.t0
state[torchbearer.TIMINGS] = self.time_dict
self.t0 = time.time()

def on_start(self, state):
self.t0 = time.time()
self.update_time('OnStart', state)

def on_start_training(self, state):
super().on_start_training(state)
self.update_time('OnStartTraining', state)

def on_start_epoch(self, state):
super().on_start_epoch(state)
self.update_time('OnStartEpoch', state)

def on_sample(self, state):
super().on_sample(state)
self.update_time('OnSample', state)

def on_forward(self, state):
super().on_forward(state)
self.update_time('OnForward', state)

def on_criterion(self, state):
super().on_criterion(state)
self.update_time('OnCriterion', state)

def on_backward(self, state):
super().on_backward(state)
self.update_time('OnBackward', state)

def on_step_training(self, state):
super().on_step_training(state)
self.update_time('OnStep', state)

def on_start_validation(self, state):
super().on_start_validation(state)
self.update_time('OnStartValidation', state)

def on_sample_validation(self, state):
super().on_sample_validation(state)
self.update_time('OnSampleValidation', state)

def on_forward_validation(self, state):
super().on_forward_validation(state)
self.update_time('OnForwardValidation', state)

def on_criterion_validation(self, state):
super().on_criterion_validation(state)
self.update_time('OnCriterionValidation', state)

def on_step_validation(self, state):
super().on_step_validation(state)
self.update_time('OnStepValidation', state)

def get_timings(self):
return self.time_dict


1 change: 1 addition & 0 deletions torchbearer/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ def state_key(key):
LOSS = state_key('loss')
FINAL_PREDICTIONS = state_key('final_predictions')
BATCH = state_key('t')
TIMINGS = state_key('timings')

0 comments on commit 7c88f1a

Please sign in to comment.