Skip to content

Commit

Permalink
Fix/val crit state (#539)
Browse files Browse the repository at this point in the history
* Validation criterion now tries state first

* Update changelog

* Add tests
  • Loading branch information
MattPainter01 authored and ethanwharris committed Apr 12, 2019
1 parent a640c8f commit 098d5a1
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where the notebook check raised ModuleNotFoundError when IPython not installed
- Fixed a memory leak with metrics that causes issues with very long epochs
- Fixed a bug with the once and once_per_epoch decorators
- Fixed a bug where the test criterion wouldn't accept a function of state

## [0.3.0] - 2019-02-28
### Added
Expand Down
86 changes: 84 additions & 2 deletions tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,6 @@ def test_fit_criterion(self):
loss = torch.tensor([2.0], requires_grad=True)
def crit_sig(y_pred, y_true):
return loss
# criterion = Mock(return_value=loss)
criterion = create_autospec(crit_sig)

metric_list = MagicMock()
Expand All @@ -1088,6 +1087,45 @@ def crit_sig(y_pred, y_true):
self.assertTrue(criterion.call_args_list[0][0][0] == 5)
self.assertTrue(criterion.call_args_list[0][0][1].item() == 1.0)

def test_fit_criterion_passed_state(self):
data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])),
(torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
train_steps = len(data)
epochs = 1
torchmodel = MagicMock()
torchmodel.return_value = 5
optimizer = MagicMock()
optimizer.step = lambda closure: closure()

loss = torch.tensor([2.0], requires_grad=True)
def crit_sig(state):
return loss
criterion = create_autospec(crit_sig)

metric_list = MagicMock()
callback_list = MagicMock()
tb.CallbackListInjection = Mock(return_value=callback_list)

state = make_state[
tb.MAX_EPOCHS: epochs, tb.STOP_TRAINING: False, tb.MODEL: torchmodel, tb.CRITERION: criterion,
tb.OPTIMIZER: optimizer, tb.INF_TRAIN_LOADING: False,
tb.METRIC_LIST: metric_list, tb.CALLBACK_LIST: callback_list, tb.DEVICE: 'cpu',
tb.DATA_TYPE: torch.float,
tb.HISTORY: [], tb.TRAIN_GENERATOR: generator, tb.TRAIN_STEPS: train_steps, tb.EPOCH: 0,
tb.BACKWARD_ARGS: {}
]

torchbearertrial = Trial(torchmodel, optimizer, criterion, [], callbacks=[])
torchbearertrial.train = Mock()
torchbearertrial.pass_state = True
torchbearertrial.state = {tb.TRAIN_GENERATOR: generator, tb.CALLBACK_LIST: callback_list,
tb.TRAIN_DATA: (generator, train_steps), tb.INF_TRAIN_LOADING: False,}

torchbearertrial._fit_pass(state)
self.assertTrue(criterion.call_count == 3)
self.assertTrue(criterion.call_args_list[0][0][0] == state)

def test_fit_backward(self):
data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])),
(torch.Tensor([3]), torch.Tensor([3]))]
Expand Down Expand Up @@ -1466,8 +1504,12 @@ def test_criterion(self):
torchmodel.return_value = 5
optimizer = MagicMock()

def spec_crit(y_pred, y_true):
pass

loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)
criterion = create_autospec(spec_crit)
criterion.return_value = loss

metric_list = MagicMock()
metric_list.process.return_value = {'test': 0}
Expand All @@ -1493,6 +1535,46 @@ def test_criterion(self):
self.assertTrue(criterion.call_args_list[0][0][0] == 5)
self.assertTrue(criterion.call_args_list[0][0][1].item() == 1.0)

def test_criterion_passed_state(self):
data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])),
(torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
steps = len(data)
epochs = 1
torchmodel = MagicMock()
torchmodel.return_value = 5
optimizer = MagicMock()

def spec_crit(state):
pass

loss = torch.tensor([2.0], requires_grad=True)
criterion = create_autospec(spec_crit)
criterion.return_value = loss

metric_list = MagicMock()
metric_list.process.return_value = {'test': 0}
metric_list.process_final.return_value = {'test': 2}
callback_list = MagicMock()
tb.CallbackListInjection = Mock(return_value=callback_list)

state = make_state[
tb.MAX_EPOCHS: epochs, tb.STOP_TRAINING: False, tb.MODEL: torchmodel, tb.CRITERION: criterion,
tb.OPTIMIZER: optimizer,
tb.METRIC_LIST: metric_list, tb.CALLBACK_LIST: callback_list, tb.DEVICE: 'cpu',
tb.DATA_TYPE: torch.float, tb.HISTORY: [], tb.GENERATOR: generator, tb.STEPS: steps, tb.EPOCH: 0,
tb.X: data[0][0], tb.Y_TRUE: data[0][1], tb.SAMPLER: tb.trial.Sampler(load_batch_standard)
]

torchbearertrial = Trial(torchmodel, optimizer, criterion, [], callbacks=[])
torchbearertrial.train = Mock()
torchbearertrial.pass_state = False
torchbearertrial.state = {tb.GENERATOR: generator, tb.CALLBACK_LIST: callback_list}

torchbearertrial._test_pass(state)
self.assertTrue(criterion.call_count == 3)
self.assertTrue(criterion.call_args_list[0][0][0] == state)

def test_metric_process(self):
data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])),
(torch.Tensor([3]), torch.Tensor([3]))]
Expand Down
6 changes: 5 additions & 1 deletion torchbearer/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,11 @@ def _test_pass(self, state):

# Loss and metrics
if torchbearer.Y_TRUE in state:
state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED],
# Loss Calculation
try:
state[torchbearer.LOSS] = state[torchbearer.CRITERION](state)
except TypeError:
state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED],
state[torchbearer.Y_TRUE])
state[torchbearer.CALLBACK_LIST].on_criterion_validation(state)
state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state.data)
Expand Down

0 comments on commit 098d5a1

Please sign in to comment.