Skip to content

Commit

Permalink
Fix kwarg bug (#604)
Browse files Browse the repository at this point in the history
* Fix kwarg bug

* Fix broken tests
  • Loading branch information
ethanwharris committed Jul 3, 2019
1 parent 5a16d7a commit 8340a1f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in the `ClassAppearanceModel` callback
- Fixed a bug where the state given to predict was not a State object
- Fixed a bug in `ImagingCallback` that would sometimes cause `make_grid` to throw an error
- Fixed a bug where the verbose argument would not work unless given as a keyword argument

## [0.3.2] - 2019-05-28
### Added
Expand Down
36 changes: 34 additions & 2 deletions tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,7 +1930,7 @@ def test_predict(self):
clist = MagicMock()
t.state = {torchbearer.TEST_GENERATOR: generator, torchbearer.CALLBACK_LIST: clist, torchbearer.TEST_STEPS: steps,
torchbearer.TEST_DATA: (generator, steps), torchbearer.LOADER: None}
metrics = t.predict(state)
metrics = t.predict()

self.assertEqual(clist.on_start.call_count, 1)
self.assertEqual(clist.on_start_epoch.call_count, 1)
Expand All @@ -1954,7 +1954,7 @@ def test_predict_none(self):
test_pass_mock = t._test_pass = Mock(return_value={torchbearer.FINAL_PREDICTIONS: 1})
t.state = {torchbearer.TEST_GENERATOR: generator, torchbearer.CALLBACK_LIST: None, torchbearer.TEST_STEPS: steps,
torchbearer.TEST_DATA: (generator, steps), torchbearer.LOADER: None}
metrics = t.predict(state)
metrics = t.predict()

self.assertTrue(eval_mock.call_count == 0)

Expand Down Expand Up @@ -2442,6 +2442,22 @@ def test_func(self, verbose=0):
self.assertEqual(c_inj.call_count, 1)
get_print_mock.assert_called_once_with(validation_label_letter='v', verbose=0)

@patch('torchbearer.trial.get_printer')
@patch('torchbearer.trial.CallbackListInjection')
def test_inject_printer_no_kwargs(self, c_inj, get_print_mock):
callback_list = torchbearer.callbacks.CallbackList([])

class SomeClass:
@torchbearer.inject_printer('v')
def test_func(self, verbose=0):
pass

t = SomeClass()
t.state = {torchbearer.CALLBACK_LIST: callback_list}
t.test_func(1)
self.assertEqual(c_inj.call_count, 1)
get_print_mock.assert_called_once_with(validation_label_letter='v', verbose=1)

@patch('torchbearer.trial.get_printer')
@patch('torchbearer.trial.CallbackListInjection')
def test_inject_printer_tqdm_on_epoch(self, c_inj, get_print_mock):
Expand Down Expand Up @@ -2638,6 +2654,22 @@ def test_func(self, data_key=None):
self.assertTrue(t.state[torchbearer.GENERATOR] == test_generator)
self.assertTrue(t.state[torchbearer.STEPS] == test_steps)

def test_inject_sampler_data_key_no_kwargs(self):
generator = MagicMock()
test_generator = 'test'
test_steps = 1

class SomeClass:
@torchbearer.inject_sampler(torchbearer.GENERATOR, load_batch_predict)
def test_func(self, data_key=None):
pass

t = SomeClass()
t.state = {torchbearer.GENERATOR: (generator, None), torchbearer.TEST_GENERATOR: (test_generator, test_steps), torchbearer.LOADER: None}
t.test_func(torchbearer.TEST_GENERATOR)
self.assertTrue(t.state[torchbearer.GENERATOR] == test_generator)
self.assertTrue(t.state[torchbearer.STEPS] == test_steps)

@patch('torchbearer.trial.CallbackListInjection')
def test_inject_callback(self, c_inj):
callback_list = torchbearer.callbacks.CallbackList([])
Expand Down
9 changes: 7 additions & 2 deletions torchbearer/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ def inject_printer(validation_label_letter='v'):
Returns:
A decorator
"""
from inspect import getcallargs

def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
verbose = kwargs['verbose'] if 'verbose' in kwargs else get_default(func, 'verbose') # Populate default value
call_args = getcallargs(func, self, *args, **kwargs)
verbose = call_args['verbose'] if 'verbose' in call_args else get_default(func, 'verbose') # Populate default value
verbose = self.verbose if verbose == -1 else verbose

printer = get_printer(verbose=verbose, validation_label_letter=validation_label_letter)
Expand Down Expand Up @@ -249,6 +251,8 @@ def inject_sampler(data_key, batch_sampler):
Returns:
The decorator
"""
from inspect import getcallargs

def decorator(func):
def infinite_wrapper(self, key, generator, steps, sampler):
if generator is not None and steps is not None:
Expand All @@ -271,7 +275,8 @@ def infinite_wrapper(self, key, generator, steps, sampler):
def wrapper(self, *args, **kwargs):
sampler = batch_sampler

key = kwargs['data_key'] if 'data_key' in kwargs else data_key # Populate default value
call_args = getcallargs(func, self, *args, **kwargs)
key = call_args['data_key'] if 'data_key' in call_args else data_key # Populate default value
generator, steps = self.state[key] if self.state[key] is not None else (None, None)

if self.state[torchbearer.LOADER] is not None:
Expand Down

0 comments on commit 8340a1f

Please sign in to comment.