-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update timer callback * Update tests * Make timer callback also a metric * Add callback list to state and update timer metric * Update code style * Moved timer callback to metrics * Fix tests * Update CHANGELOG.md
- Loading branch information
1 parent
cf24e5e
commit 6fdd87b
Showing
10 changed files
with
269 additions
and
137 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from unittest import TestCase | ||
from unittest.mock import Mock, MagicMock, patch | ||
|
||
from torchbearer.metrics.timer import TimerMetric, _TimerMetric | ||
import torchbearer | ||
|
||
|
||
class TestTimer(TestCase): | ||
def test_update_time(self): | ||
timer = TimerMetric('test') | ||
timerMetric = TimerMetric('test2') | ||
timerMetric.process = Mock(return_value=1) | ||
timer.update_time('test', timerMetric, {}) | ||
self.assertTrue(timer.get_timings()['test'] == 1) | ||
|
||
timerMetric.process = Mock(return_value=2) | ||
timer.update_time('test_2', timerMetric, {}) | ||
self.assertTrue(timer.get_timings()['test_2'] == 2) | ||
|
||
def test_calls(self): | ||
timer = TimerMetric('test') | ||
timer.batch_timer = MagicMock() | ||
timer.epoch_timer = MagicMock() | ||
timer.train_timer = MagicMock() | ||
timer.total_timer = MagicMock() | ||
timer.valid_timer = MagicMock() | ||
|
||
timer.on_start({}) | ||
timer.on_start_training({}) | ||
timer.on_start_epoch({}) | ||
timer.on_sample({}) | ||
timer.on_forward({}) | ||
timer.on_criterion({}) | ||
timer.on_backward({}) | ||
timer.on_step_training({}) | ||
timer.on_start_validation({}) | ||
timer.on_sample_validation({}) | ||
timer.on_forward_validation({}) | ||
timer.on_criterion_validation({}) | ||
timer.on_step_validation({}) | ||
timer.on_end_training({}) | ||
timer.on_end_validation({}) | ||
timer.on_end_epoch({}) | ||
timer.on_end({}) | ||
|
||
self.assertTrue(timer.batch_timer.process.call_count == 11) | ||
self.assertTrue(timer.total_timer.process.call_count == 1) | ||
self.assertTrue(timer.epoch_timer.process.call_count == 1) | ||
self.assertTrue(timer.train_timer.process.call_count == 2) | ||
self.assertTrue(timer.valid_timer.process.call_count == 1) | ||
|
||
def test_process(self): | ||
timer = TimerMetric((torchbearer.metrics.timer.ON_FORWARD, )) | ||
timer.time_dict = {torchbearer.metrics.timer.ON_FORWARD: 1, 'test': 2} | ||
self.assertTrue(timer.process({})[torchbearer.metrics.timer.ON_FORWARD] == 1) | ||
|
||
def test_reset(self): | ||
state = {torchbearer.CALLBACK_LIST: torchbearer.callbacks.CallbackList([])} | ||
|
||
timer = TimerMetric() | ||
self.assertTrue(state[torchbearer.CALLBACK_LIST].callback_list == []) | ||
timer.reset(state) | ||
self.assertIsInstance(state[torchbearer.CALLBACK_LIST].callback_list[0], TimerMetric) | ||
|
||
timer.reset(state) | ||
self.assertTrue(len(state[torchbearer.CALLBACK_LIST].callback_list) == 1) | ||
|
||
|
||
class TestTimerMetric(TestCase): | ||
@patch('time.time') | ||
def test_process(self, time): | ||
time.return_value = 1 | ||
timer_metric = _TimerMetric('test') | ||
time.return_value = 2 | ||
dt = timer_metric.process({}) | ||
|
||
self.assertTrue(dt == 1) | ||
|
||
@patch('time.time') | ||
def test_reset(self, time): | ||
time.return_value = 1 | ||
timer_metric = _TimerMetric('test') | ||
|
||
time.return_value = 3 | ||
timer_metric.reset({}) | ||
self.assertTrue(timer_metric.t == 3) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,4 +84,3 @@ | |
from .weight_decay import * | ||
from .aggregate_predictions import * | ||
from .decorators import * | ||
from .timer import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,3 +43,4 @@ | |
from .decorators import * | ||
from .roc_auc_score import * | ||
from .primitives import * | ||
from .timer import TimerMetric |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import time | ||
from torchbearer.callbacks import Callback | ||
import torchbearer | ||
from torchbearer.metrics import Metric | ||
|
||
ON_START_TRAINING = 'on_start_training' | ||
ON_START_EPOCH = 'on_start_epoch' | ||
ON_SAMPLE = 'on_sample' | ||
ON_FORWARD = 'on_forward' | ||
ON_CRITERION = 'on_criterion' | ||
ON_BACKWARD = 'on_backward' | ||
ON_STEP_TRAINING = 'on_step_training' | ||
ON_START_VALIDATION = 'on_start_validation' | ||
ON_SAMPLE_VALIDATION = 'on_sample_validation' | ||
ON_FORWARD_VALIDATION = 'on_forward_validaiton' | ||
ON_CRITERION_VALIDATION = 'on_criterion_validation' | ||
ON_STEP_VALIDATION = 'on_step_validation' | ||
TRAIN_TIME = 'train_time' | ||
TOTAL_TIME = 'total_time' | ||
VALIDATION_TIME = 'validation_time' | ||
|
||
|
||
class TimerMetric(Callback, Metric): | ||
def __init__(self, time_keys=()): | ||
""" Timer callback that aggregates timings for each stage of model execution | ||
""" | ||
super(TimerMetric, self).__init__(name='timer') | ||
self.t0 = time.time() | ||
self.time_dict = {} | ||
# self.init_keys() | ||
self.batch_timer = _TimerMetric('t_batch') | ||
self.epoch_timer = _TimerMetric('t_epoch') | ||
self.train_timer = _TimerMetric('t_train') | ||
self.valid_timer = _TimerMetric('t_valid') | ||
self.total_timer = _TimerMetric('t_total') | ||
self.time_keys = time_keys | ||
self.added_callback = False | ||
|
||
def update_time(self, text, metric, state): | ||
self.time_dict[text] = metric.process(state) | ||
state[torchbearer.TIMINGS] = self.time_dict | ||
|
||
def process(self, *args): | ||
super().process(*args) | ||
d_out = {key: self.time_dict[key] for key in self.time_keys if key in self.time_dict} | ||
return d_out | ||
|
||
def reset(self, state): | ||
super().reset(state) | ||
if not self.added_callback: | ||
state[torchbearer.CALLBACK_LIST].append([self]) | ||
self.added_callback = True | ||
|
||
def on_start(self, state): | ||
self.t0 = time.time() | ||
self.batch_timer.reset(state) | ||
self.epoch_timer.reset(state) | ||
self.train_timer.reset(state) | ||
self.valid_timer.reset(state) | ||
self.total_timer.reset(state) | ||
|
||
def on_start_training(self, state): | ||
super().on_start_training(state) | ||
self.update_time(ON_START_TRAINING, self.batch_timer, state) | ||
self.update_time(ON_START_TRAINING, self.train_timer, state) | ||
|
||
def on_start_epoch(self, state): | ||
super().on_start_epoch(state) | ||
self.update_time(ON_START_EPOCH, self.epoch_timer, state) | ||
|
||
def on_sample(self, state): | ||
super().on_sample(state) | ||
self.update_time(ON_SAMPLE, self.batch_timer, state) | ||
|
||
def on_forward(self, state): | ||
super().on_forward(state) | ||
self.update_time(ON_FORWARD, self.batch_timer, state) | ||
|
||
def on_criterion(self, state): | ||
super().on_criterion(state) | ||
self.update_time(ON_CRITERION, self.batch_timer, state) | ||
|
||
def on_backward(self, state): | ||
super().on_backward(state) | ||
self.update_time(ON_BACKWARD, self.batch_timer, state) | ||
|
||
def on_step_training(self, state): | ||
super().on_step_training(state) | ||
self.update_time(ON_STEP_TRAINING, self.batch_timer, state) | ||
|
||
def on_start_validation(self, state): | ||
super().on_start_validation(state) | ||
self.update_time(ON_START_VALIDATION, self.batch_timer, state) | ||
|
||
def on_sample_validation(self, state): | ||
super().on_sample_validation(state) | ||
self.update_time(ON_SAMPLE_VALIDATION, self.batch_timer, state) | ||
|
||
def on_forward_validation(self, state): | ||
super().on_forward_validation(state) | ||
self.update_time(ON_FORWARD_VALIDATION, self.batch_timer, state) | ||
|
||
def on_criterion_validation(self, state): | ||
super().on_criterion_validation(state) | ||
self.update_time(ON_CRITERION_VALIDATION, self.batch_timer, state) | ||
|
||
def on_step_validation(self, state): | ||
super().on_step_validation(state) | ||
self.update_time(ON_STEP_VALIDATION, self.batch_timer, state) | ||
|
||
def on_end_training(self, state): | ||
super().on_end_training(state) | ||
self.valid_timer.reset(state) | ||
self.batch_timer.reset(state) | ||
self.update_time(TRAIN_TIME, self.train_timer, state) | ||
|
||
def on_end_epoch(self, state): | ||
super().on_end_epoch(state) | ||
self.batch_timer.reset(state) | ||
self.train_timer.reset(state) | ||
|
||
def on_end(self, state): | ||
super().on_end(state) | ||
self.update_time(TOTAL_TIME, self.total_timer, state) | ||
print(self.time_dict) | ||
|
||
def on_end_validation(self, state): | ||
super().on_end_validation(state) | ||
self.update_time(VALIDATION_TIME, self.valid_timer, state) | ||
|
||
def get_timings(self): | ||
return self.time_dict | ||
|
||
|
||
class _TimerMetric(Metric): | ||
def __init__(self, name): | ||
super().__init__(name) | ||
self.t = time.time() | ||
|
||
def process(self, *args): | ||
super().process(*args) | ||
dt = time.time() - self.t | ||
self.t = time.time() | ||
return dt | ||
|
||
def reset(self, state): | ||
super().reset(state) | ||
self.t = time.time() |
Oops, something went wrong.