Propose to fix #650 (#651)
* Propose to fix #650

- we can check whether `self._monitor` is present in the given metrics dict (see the sketch after the change summary below)

* added warnings for #650

* updated fix for #650

* Fix indent and add tests

* Fix python3 tests

Co-authored-by: Matt Painter <mp2u16@ecs.soton.ac.uk>
Francesco Saverio Zuppichini and MattPainter01 committed Jan 16, 2020
1 parent 922397d commit c36e899
Showing 2 changed files with 71 additions and 5 deletions.
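As the commit message notes, the fix guards the scheduler step on whether the monitored key is actually present in state[torchbearer.METRICS], warning instead of raising a KeyError when it is missing. A minimal sketch of that guard pattern (illustrative only; the real callback code is in the torch_scheduler.py diff below, and the helper name here is hypothetical):

import warnings

def step_with_monitor(scheduler, metrics, monitor):
    # Only step the scheduler if the monitored metric was actually computed;
    # otherwise warn rather than raise a KeyError on the missing key (#650).
    if monitor in metrics:
        scheduler.step(metrics[monitor])
    else:
        warnings.warn("Monitor key [{}] was not found by the scheduler.".format(monitor), Warning)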
62 changes: 61 additions & 1 deletion tests/callbacks/test_torch_scheduler.py
@@ -1,5 +1,6 @@
 from unittest import TestCase
 from mock import patch, Mock
+import warnings
 
 import torchbearer
 from torchbearer.callbacks import TorchScheduler, LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR,\
@@ -89,7 +90,7 @@ def test_torch_scheduler_on_batch_no_monitor(self):
         mock_scheduler.reset_mock()
 
     def test_torch_scheduler_on_epoch_no_monitor(self):
-        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer'}
+        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {}}
         mock_scheduler = Mock()
         mock_scheduler.return_value = mock_scheduler
 
@@ -115,6 +116,65 @@ def test_torch_scheduler_on_epoch_no_monitor(self):
         mock_scheduler.assert_not_called()
         mock_scheduler.reset_mock()
 
+    def test_monitor_not_found(self):
+        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'not_test': 1.}}
+        mock_scheduler = Mock()
+        mock_scheduler.return_value = mock_scheduler
+
+        torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=False)
+        torch_scheduler.on_start(state)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_start_validation(state)
+            self.assertTrue(len(w) == 0)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_end_epoch(state)
+            self.assertTrue('[test] was not found' in str(w[0].message))
+
+    def test_monitor_found(self):
+        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'test': 1.}}
+        mock_scheduler = Mock()
+        mock_scheduler.return_value = mock_scheduler
+
+        torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=False)
+        torch_scheduler.on_start(state)
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_start_training(state)
+            self.assertTrue(len(w) == 0)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_start_validation(state)
+            self.assertTrue(len(w) == 0)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_end_epoch(state)
+            self.assertTrue(len(w) == 0)
+
+    def test_batch_monitor_not_found(self):
+        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'not_test': 1.}}
+        mock_scheduler = Mock()
+        mock_scheduler.return_value = mock_scheduler
+
+        torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True)
+        torch_scheduler.on_start(state)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_step_training(state)
+            self.assertTrue('[test] was not found' in str(w[0].message))
+
+    def test_batch_monitor_found(self):
+        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'test': 1.}}
+        mock_scheduler = Mock()
+        mock_scheduler.return_value = mock_scheduler
+
+        torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True)
+        torch_scheduler.on_start(state)
+
+        with warnings.catch_warnings(record=True) as w:
+            torch_scheduler.on_step_training(state)
+            self.assertTrue(len(w) == 0)
+
 
 class TestLambdaLR(TestCase):
     @patch('torch.optim.lr_scheduler.LambdaLR')
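These tests rely on warnings.catch_warnings(record=True) from the standard library, which temporarily installs a recording handler so emitted warnings land in a list instead of being printed. A small standalone illustration of the same pattern, outside torchbearer (the maybe_warn helper is hypothetical):

import warnings

def maybe_warn(metrics, monitor):
    # Mirror the callback's behaviour: warn when the monitored key is absent.
    if monitor not in metrics:
        warnings.warn("Monitor key [{}] was not found by the scheduler.".format(monitor), Warning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, bypassing duplicate filters
    maybe_warn({'acc': 0.9}, 'val_loss')

assert len(caught) == 1
assert '[val_loss] was not found' in str(caught[0].message)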
14 changes: 10 additions & 4 deletions torchbearer/callbacks/torch_scheduler.py
@@ -1,5 +1,5 @@
 import torchbearer
-
+import warnings
 from torchbearer.callbacks import Callback
 
 import torch
@@ -21,15 +21,21 @@ def on_sample(self, state):
 
     def on_step_training(self, state):
         if self._step_on_batch and self._monitor is not None:
-            self._scheduler.step(state[torchbearer.METRICS][self._monitor])
-
+            if self._monitor in state[torchbearer.METRICS]:
+                self._scheduler.step(state[torchbearer.METRICS][self._monitor])
+            else:
+                warnings.warn("Monitor key [{}] was not found by the scheduler.".format(self._monitor), Warning)
+
     def on_start_training(self, state):
         if not self._step_on_batch and self._monitor is None:
             self._scheduler.step(epoch=state[torchbearer.EPOCH])
 
     def on_end_epoch(self, state):
         if not self._step_on_batch and self._monitor is not None:
-            self._scheduler.step(state[torchbearer.METRICS][self._monitor], epoch=state[torchbearer.EPOCH])
+            if self._monitor in state[torchbearer.METRICS]:
+                self._scheduler.step(state[torchbearer.METRICS][self._monitor], epoch=state[torchbearer.EPOCH])
+            else:
+                warnings.warn("Monitor key [{}] was not found by the scheduler.".format(self._monitor), Warning)
 
 
 class LambdaLR(TorchScheduler):
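For reference, the warning path added in on_end_epoch above can be exercised directly, mirroring the minimal state dicts the tests build: construct a TorchScheduler with a monitor key, hand it a state whose METRICS lacks that key, and the callback warns instead of raising a KeyError. A rough sketch under those assumptions (it presumes, as the tests suggest, that on_start builds the wrapped scheduler from state[torchbearer.OPTIMIZER]; a real ReduceLROnPlateau stands in for the mocks):

import warnings

import torch
import torchbearer
from torchbearer.callbacks import TorchScheduler

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Step on a monitored metric at the end of each epoch, as in the tests above.
scheduler = TorchScheduler(lambda opt: torch.optim.lr_scheduler.ReduceLROnPlateau(opt),
                           monitor='val_loss', step_on_batch=False)

state = {torchbearer.EPOCH: 0,
         torchbearer.OPTIMIZER: optimizer,
         torchbearer.METRICS: {}}  # 'val_loss' is missing, so the callback should warn

scheduler.on_start(state)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scheduler.on_end_epoch(state)

print(caught[0].message if caught else 'no warning raised')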
