-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add timer callback * Add timings to state after each stage * Remove unnecessary stages * Update changelog * Update timer test
- Loading branch information
1 parent
13c755f
commit 7c88f1a
Showing
5 changed files
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from unittest import TestCase | ||
from unittest.mock import patch, Mock | ||
|
||
from torchbearer.callbacks import TimerCallback | ||
|
||
|
||
class TestTimer(TestCase): | ||
@patch('time.time') | ||
def test_update_time(self, time): | ||
time.return_value = 0 | ||
timer = TimerCallback() | ||
time.return_value = 1 | ||
timer.update_time('test', {}) | ||
self.assertTrue(timer.get_timings()['test'] == 1) | ||
|
||
time.return_value = 3 | ||
timer.update_time('test_2', {}) | ||
self.assertTrue(timer.get_timings()['test_2'] == 2) | ||
|
||
def test_calls(self): | ||
timer = TimerCallback() | ||
timer.update_time = Mock() | ||
|
||
timer.on_start({}) | ||
timer.on_start_training({}) | ||
timer.on_start_epoch({}) | ||
timer.on_sample({}) | ||
timer.on_forward({}) | ||
timer.on_criterion({}) | ||
timer.on_backward({}) | ||
timer.on_step_training({}) | ||
timer.on_start_validation({}) | ||
timer.on_sample_validation({}) | ||
timer.on_forward_validation({}) | ||
timer.on_criterion_validation({}) | ||
timer.on_step_validation({}) | ||
self.assertTrue(timer.update_time.call_count == 13) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import time | ||
from torchbearer.callbacks import Callback | ||
import torchbearer | ||
|
||
|
||
class TimerCallback(Callback): | ||
def __init__(self): | ||
""" Timer callback that aggregates timings for each stage of model execution | ||
""" | ||
super().__init__() | ||
self.t0 = time.time() | ||
self.time_dict = {} | ||
|
||
def update_time(self, text, state): | ||
self.time_dict[text] = time.time() - self.t0 | ||
state[torchbearer.TIMINGS] = self.time_dict | ||
self.t0 = time.time() | ||
|
||
def on_start(self, state): | ||
self.t0 = time.time() | ||
self.update_time('OnStart', state) | ||
|
||
def on_start_training(self, state): | ||
super().on_start_training(state) | ||
self.update_time('OnStartTraining', state) | ||
|
||
def on_start_epoch(self, state): | ||
super().on_start_epoch(state) | ||
self.update_time('OnStartEpoch', state) | ||
|
||
def on_sample(self, state): | ||
super().on_sample(state) | ||
self.update_time('OnSample', state) | ||
|
||
def on_forward(self, state): | ||
super().on_forward(state) | ||
self.update_time('OnForward', state) | ||
|
||
def on_criterion(self, state): | ||
super().on_criterion(state) | ||
self.update_time('OnCriterion', state) | ||
|
||
def on_backward(self, state): | ||
super().on_backward(state) | ||
self.update_time('OnBackward', state) | ||
|
||
def on_step_training(self, state): | ||
super().on_step_training(state) | ||
self.update_time('OnStep', state) | ||
|
||
def on_start_validation(self, state): | ||
super().on_start_validation(state) | ||
self.update_time('OnStartValidation', state) | ||
|
||
def on_sample_validation(self, state): | ||
super().on_sample_validation(state) | ||
self.update_time('OnSampleValidation', state) | ||
|
||
def on_forward_validation(self, state): | ||
super().on_forward_validation(state) | ||
self.update_time('OnForwardValidation', state) | ||
|
||
def on_criterion_validation(self, state): | ||
super().on_criterion_validation(state) | ||
self.update_time('OnCriterionValidation', state) | ||
|
||
def on_step_validation(self, state): | ||
super().on_step_validation(state) | ||
self.update_time('OnStepValidation', state) | ||
|
||
def get_timings(self): | ||
return self.time_dict | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters