Skip to content

Commit

Permalink
Feature/timer update (#268)
Browse files Browse the repository at this point in the history
* Update timer callback

* Update tests

* Make timer callback also a metric

* Add callback list to state and update timer metric

* Update code style

* Moved timer callback to metrics

* Fix tests

* Update CHANGELOG.md
  • Loading branch information
MattPainter01 authored and ethanwharris committed Aug 1, 2018
1 parent cf24e5e commit 6fdd87b
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 137 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
### Added
### Changed
- Timer callback can now also be used as a metric which allows display of specified timings to printers and has been moved to metrics.
### Deprecated
### Removed
### Fixed
Expand Down
37 changes: 0 additions & 37 deletions tests/callbacks/test_timer.py

This file was deleted.

87 changes: 87 additions & 0 deletions tests/metrics/test_timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from unittest import TestCase
from unittest.mock import Mock, MagicMock, patch

from torchbearer.metrics.timer import TimerMetric, _TimerMetric
import torchbearer


class TestTimer(TestCase):
def test_update_time(self):
timer = TimerMetric('test')
timerMetric = TimerMetric('test2')
timerMetric.process = Mock(return_value=1)
timer.update_time('test', timerMetric, {})
self.assertTrue(timer.get_timings()['test'] == 1)

timerMetric.process = Mock(return_value=2)
timer.update_time('test_2', timerMetric, {})
self.assertTrue(timer.get_timings()['test_2'] == 2)

def test_calls(self):
timer = TimerMetric('test')
timer.batch_timer = MagicMock()
timer.epoch_timer = MagicMock()
timer.train_timer = MagicMock()
timer.total_timer = MagicMock()
timer.valid_timer = MagicMock()

timer.on_start({})
timer.on_start_training({})
timer.on_start_epoch({})
timer.on_sample({})
timer.on_forward({})
timer.on_criterion({})
timer.on_backward({})
timer.on_step_training({})
timer.on_start_validation({})
timer.on_sample_validation({})
timer.on_forward_validation({})
timer.on_criterion_validation({})
timer.on_step_validation({})
timer.on_end_training({})
timer.on_end_validation({})
timer.on_end_epoch({})
timer.on_end({})

self.assertTrue(timer.batch_timer.process.call_count == 11)
self.assertTrue(timer.total_timer.process.call_count == 1)
self.assertTrue(timer.epoch_timer.process.call_count == 1)
self.assertTrue(timer.train_timer.process.call_count == 2)
self.assertTrue(timer.valid_timer.process.call_count == 1)

def test_process(self):
timer = TimerMetric((torchbearer.metrics.timer.ON_FORWARD, ))
timer.time_dict = {torchbearer.metrics.timer.ON_FORWARD: 1, 'test': 2}
self.assertTrue(timer.process({})[torchbearer.metrics.timer.ON_FORWARD] == 1)

def test_reset(self):
state = {torchbearer.CALLBACK_LIST: torchbearer.callbacks.CallbackList([])}

timer = TimerMetric()
self.assertTrue(state[torchbearer.CALLBACK_LIST].callback_list == [])
timer.reset(state)
self.assertIsInstance(state[torchbearer.CALLBACK_LIST].callback_list[0], TimerMetric)

timer.reset(state)
self.assertTrue(len(state[torchbearer.CALLBACK_LIST].callback_list) == 1)


class TestTimerMetric(TestCase):
@patch('time.time')
def test_process(self, time):
time.return_value = 1
timer_metric = _TimerMetric('test')
time.return_value = 2
dt = timer_metric.process({})

self.assertTrue(dt == 1)

@patch('time.time')
def test_reset(self, time):
time.return_value = 1
timer_metric = _TimerMetric('test')

time.return_value = 3
timer_metric.reset({})
self.assertTrue(timer_metric.t == 3)

1 change: 0 additions & 1 deletion torchbearer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,3 @@
from .weight_decay import *
from .aggregate_predictions import *
from .decorators import *
from .timer import *
15 changes: 9 additions & 6 deletions torchbearer/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,21 @@ def __init__(self, callback_list):
"""
super().__init__()
self.callback_list = []
self.append(callback_list)

for i in range(len(callback_list)):
callback = callback_list[i]
def _for_list(self, function):
for callback in self.callback_list:
function(callback)

def __iter__(self):
return self.callback_list.__iter__()

def append(self, callback_list):
for callback in callback_list:
if isinstance(callback, CallbackList):
self.callback_list = self.callback_list + callback.callback_list
else:
self.callback_list.append(callback)

def _for_list(self, function):
for callback in self.callback_list:
function(callback)

def on_start(self, state):
"""Call on_start on each callback in turn with the given state.
Expand Down
74 changes: 0 additions & 74 deletions torchbearer/callbacks/timer.py

This file was deleted.

1 change: 1 addition & 0 deletions torchbearer/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@
from .decorators import *
from .roc_auc_score import *
from .primitives import *
from .timer import TimerMetric
148 changes: 148 additions & 0 deletions torchbearer/metrics/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import time
from torchbearer.callbacks import Callback
import torchbearer
from torchbearer.metrics import Metric

ON_START_TRAINING = 'on_start_training'
ON_START_EPOCH = 'on_start_epoch'
ON_SAMPLE = 'on_sample'
ON_FORWARD = 'on_forward'
ON_CRITERION = 'on_criterion'
ON_BACKWARD = 'on_backward'
ON_STEP_TRAINING = 'on_step_training'
ON_START_VALIDATION = 'on_start_validation'
ON_SAMPLE_VALIDATION = 'on_sample_validation'
ON_FORWARD_VALIDATION = 'on_forward_validaiton'
ON_CRITERION_VALIDATION = 'on_criterion_validation'
ON_STEP_VALIDATION = 'on_step_validation'
TRAIN_TIME = 'train_time'
TOTAL_TIME = 'total_time'
VALIDATION_TIME = 'validation_time'


class TimerMetric(Callback, Metric):
def __init__(self, time_keys=()):
""" Timer callback that aggregates timings for each stage of model execution
"""
super(TimerMetric, self).__init__(name='timer')
self.t0 = time.time()
self.time_dict = {}
# self.init_keys()
self.batch_timer = _TimerMetric('t_batch')
self.epoch_timer = _TimerMetric('t_epoch')
self.train_timer = _TimerMetric('t_train')
self.valid_timer = _TimerMetric('t_valid')
self.total_timer = _TimerMetric('t_total')
self.time_keys = time_keys
self.added_callback = False

def update_time(self, text, metric, state):
self.time_dict[text] = metric.process(state)
state[torchbearer.TIMINGS] = self.time_dict

def process(self, *args):
super().process(*args)
d_out = {key: self.time_dict[key] for key in self.time_keys if key in self.time_dict}
return d_out

def reset(self, state):
super().reset(state)
if not self.added_callback:
state[torchbearer.CALLBACK_LIST].append([self])
self.added_callback = True

def on_start(self, state):
self.t0 = time.time()
self.batch_timer.reset(state)
self.epoch_timer.reset(state)
self.train_timer.reset(state)
self.valid_timer.reset(state)
self.total_timer.reset(state)

def on_start_training(self, state):
super().on_start_training(state)
self.update_time(ON_START_TRAINING, self.batch_timer, state)
self.update_time(ON_START_TRAINING, self.train_timer, state)

def on_start_epoch(self, state):
super().on_start_epoch(state)
self.update_time(ON_START_EPOCH, self.epoch_timer, state)

def on_sample(self, state):
super().on_sample(state)
self.update_time(ON_SAMPLE, self.batch_timer, state)

def on_forward(self, state):
super().on_forward(state)
self.update_time(ON_FORWARD, self.batch_timer, state)

def on_criterion(self, state):
super().on_criterion(state)
self.update_time(ON_CRITERION, self.batch_timer, state)

def on_backward(self, state):
super().on_backward(state)
self.update_time(ON_BACKWARD, self.batch_timer, state)

def on_step_training(self, state):
super().on_step_training(state)
self.update_time(ON_STEP_TRAINING, self.batch_timer, state)

def on_start_validation(self, state):
super().on_start_validation(state)
self.update_time(ON_START_VALIDATION, self.batch_timer, state)

def on_sample_validation(self, state):
super().on_sample_validation(state)
self.update_time(ON_SAMPLE_VALIDATION, self.batch_timer, state)

def on_forward_validation(self, state):
super().on_forward_validation(state)
self.update_time(ON_FORWARD_VALIDATION, self.batch_timer, state)

def on_criterion_validation(self, state):
super().on_criterion_validation(state)
self.update_time(ON_CRITERION_VALIDATION, self.batch_timer, state)

def on_step_validation(self, state):
super().on_step_validation(state)
self.update_time(ON_STEP_VALIDATION, self.batch_timer, state)

def on_end_training(self, state):
super().on_end_training(state)
self.valid_timer.reset(state)
self.batch_timer.reset(state)
self.update_time(TRAIN_TIME, self.train_timer, state)

def on_end_epoch(self, state):
super().on_end_epoch(state)
self.batch_timer.reset(state)
self.train_timer.reset(state)

def on_end(self, state):
super().on_end(state)
self.update_time(TOTAL_TIME, self.total_timer, state)
print(self.time_dict)

def on_end_validation(self, state):
super().on_end_validation(state)
self.update_time(VALIDATION_TIME, self.valid_timer, state)

def get_timings(self):
return self.time_dict


class _TimerMetric(Metric):
def __init__(self, name):
super().__init__(name)
self.t = time.time()

def process(self, *args):
super().process(*args)
dt = time.time() - self.t
self.t = time.time()
return dt

def reset(self, state):
super().reset(state)
self.t = time.time()

0 comments on commit 6fdd87b

Please sign in to comment.