Merge pull request #473 from thomasjpfan/lr_scheduler_1_1

[MRG] ENH Updates schedulers for pytorch 1.1
BenjaminBossan committed May 14, 2019
2 parents 82efd5a + af04b55 commit 8f7fa15db9ca78b21b04d6539e80803a359016a2
Showing with 42 additions and 10 deletions.
  1. +31 −7 skorch/callbacks/lr_scheduler.py
  2. +11 −3 skorch/tests/callbacks/test_lr_scheduler.py
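
PyTorch 1.1 moved the recommended scheduler.step() call to after optimizer.step() and added a native torch.optim.lr_scheduler.CyclicLR. Accordingly, this commit shifts the LRScheduler callback's stepping from the on_*_begin hooks to the on_*_end hooks, steps torch's CyclicLR once per training batch, and deprecates skorch's hand-rolled CyclicLR. The snippet below is a minimal usage sketch, not part of the diff; the parameter values are arbitrary and only mirror the updated tests.

from torch.optim.lr_scheduler import CyclicLR
from skorch.callbacks import LRScheduler

# Wrap torch's native CyclicLR (new in PyTorch 1.1); keyword arguments are
# forwarded to the scheduler, and the callback now steps it once per batch.
lr_sch = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3, step_size_up=4)

# simulate() dry-runs the policy on a throwaway optimizer and returns the
# learning rate that would be active at each step.
print(lr_sch.simulate(steps=8, initial_lr=1e-3))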
skorch/callbacks/lr_scheduler.py
@@ -13,6 +13,7 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
+from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
 from torch.optim.optimizer import Optimizer
 from skorch.callbacks import Callback

@@ -88,8 +89,8 @@ def simulate(self, steps, initial_lr):
             step = sch.step
         lrs = []
         for _ in range(steps):
+            lrs.append(opt.param_groups[0]['lr'])
             step()
-            lrs.append(sch.get_lr()[0])

         return np.array(lrs)
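
Context for the hunk above: with the PyTorch 1.1 conventions, scheduler.step() follows optimizer.step(), and reading the rate through get_lr() right after stepping can disagree with the value the optimizer just used, so simulate now records opt.param_groups[0]['lr'] before stepping. A small illustration of that ordering, a sketch that assumes nothing beyond a throwaway SGD parameter:

import torch
from torch.optim.lr_scheduler import StepLR

param = torch.ones(1, requires_grad=True)   # throwaway parameter
opt = torch.optim.SGD([param], lr=1.0)
sch = StepLR(opt, step_size=2, gamma=0.1)

lrs = []
for _ in range(6):
    lrs.append(opt.param_groups[0]['lr'])   # read the lr actually in use
    opt.step()                              # optimizer first (PyTorch >= 1.1) ...
    sch.step()                              # ... then the scheduler
print(lrs)  # roughly [1.0, 1.0, 0.1, 0.1, 0.01, 0.01], as the test below expects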

@@ -124,7 +125,7 @@ def on_train_begin(self, net, **kwargs):
             net, self.policy_, **self.kwargs
         )

-    def on_epoch_begin(self, net, **kwargs):
+    def on_epoch_end(self, net, **kwargs):
         epoch = len(net.history) - 1
         if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
             if callable(self.monitor):
@@ -141,15 +142,17 @@ def on_epoch_begin(self, net, **kwargs):
         else:
             self.lr_scheduler_.step(epoch)

-    def on_batch_begin(self, net, training, **kwargs):
+    def on_batch_end(self, net, training, **kwargs):
         if (
                 training and
                 hasattr(self.lr_scheduler_, 'batch_step') and
                 callable(self.lr_scheduler_.batch_step)
         ):
             self.lr_scheduler_.batch_step(self.batch_idx_)

-    def on_batch_end(self, net, training, **kwargs):
+        if isinstance(self.lr_scheduler_, TorchCyclicLR):
+            self.lr_scheduler_.step(self.batch_idx_)
+
         if training:
             self.batch_idx_ += 1
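
Per-batch stepping now happens in on_batch_end. The sketch below restates the dispatch from the hunk above as a standalone function; the function name is invented for illustration, and scheduler / batch_idx stand in for self.lr_scheduler_ and self.batch_idx_:

from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR

def step_scheduler_per_batch(scheduler, batch_idx, training=True):
    # skorch's (now deprecated) CyclicLR exposes batch_step() ...
    if training and callable(getattr(scheduler, 'batch_step', None)):
        scheduler.batch_step(batch_idx)
    # ... while torch's native CyclicLR is driven through step(); the explicit
    # index argument matches PyTorch 1.1 (later releases deprecate it)
    if isinstance(scheduler, TorchCyclicLR):
        scheduler.step(batch_idx)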

@@ -337,11 +340,33 @@ def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
                  gamma=1., scale_fn=None, scale_mode='cycle',
                  last_batch_idx=-1, step_size=None):

+        # TODO: Remove class in 0.7
+        warnings.warn(
+            "skorch.callbacks.CyclicLR is deprecated, please use "
+            "torch.optim.lr_scheduler.CyclicLR instead",
+            DeprecationWarning
+        )
+
         if not isinstance(optimizer, Optimizer):
             raise TypeError('{} is not an Optimizer'.format(
                 type(optimizer).__name__))
         self.optimizer = optimizer
-        self.base_lrs = _check_lr('base_lr', optimizer, base_lr)
+
+        # copied from torch.optim._lr_scheduler._LRScheduler
+        base_lrs = _check_lr('base_lr', optimizer, base_lr)
+        if last_batch_idx == -1:
+            for lr, group in zip(base_lrs, optimizer.param_groups):
+                group['lr'] = lr
+            for group in optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+            last_batch_idx = 0
+        else:
+            for i, group in enumerate(optimizer.param_groups):
+                if 'initial_lr' not in group:
+                    raise KeyError("param 'initial_lr' is not specified "
+                                   "in param_groups[{}] when resuming an optimizer".format(i))
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+
         self.max_lrs = _check_lr('max_lr', optimizer, max_lr)

         # TODO: Remove warning in a future release
@@ -378,8 +403,7 @@ def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
         self.scale_fn = scale_fn
         self.scale_mode = scale_mode

-        self.batch_step(last_batch_idx + 1)
-        self.last_batch_idx = last_batch_idx
+        self.batch_step(last_batch_idx)

     def step(self, epoch=None):
         """Not used by ``CyclicLR``, use batch_step instead."""
skorch/tests/callbacks/test_lr_scheduler.py
@@ -12,11 +12,13 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
+from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR

 from skorch import NeuralNetClassifier
 from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler, CyclicLR


+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 class TestLRCallbacks:

     @pytest.mark.parametrize('policy', [StepLR, 'StepLR'])
@@ -26,7 +28,7 @@ def test_simulate_lrs_epoch_step(self, policy):
         expected = np.array([1.0, 1.0, 0.1, 0.1, 0.01, 0.01])
         assert np.allclose(expected, lrs)

-    @pytest.mark.parametrize('policy', [CyclicLR, 'CyclicLR'])
+    @pytest.mark.parametrize('policy', [CyclicLR, 'CyclicLR', TorchCyclicLR])
     def test_simulate_lrs_batch_step(self, policy):
         lr_sch = LRScheduler(
             policy, base_lr=1, max_lr=5, step_size_up=4)
@@ -96,6 +98,7 @@ def test_lr_callback_steps_correctly(

     @pytest.mark.parametrize('policy, kwargs', [
         ('CyclicLR', {}),
+        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
     ])
     def test_lr_callback_batch_steps_correctly(
             self,
@@ -125,6 +128,7 @@ def test_lr_callback_batch_steps_correctly(

     @pytest.mark.parametrize('policy, kwargs', [
         ('CyclicLR', {}),
+        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
     ])
     def test_lr_callback_batch_steps_correctly_fallback(
             self,
@@ -376,6 +380,7 @@ def test_restarts_with_multiple_groups(self, classifier_module):
         )


+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 class TestCyclicLR():

     @pytest.fixture(params=[1, 3])
@@ -477,9 +482,12 @@ def test_exp_range_mode_step_size_up_down(self, init_optimizer, num_groups):
         self._test_cycle_lr(init_optimizer, scheduler, targets)

     def test_batch_idx_with_none(self, init_optimizer):
-        scheduler = CyclicLR(init_optimizer)
+        with pytest.warns(DeprecationWarning):
+            scheduler = CyclicLR(init_optimizer)
+        for p_group in init_optimizer.param_groups:
+            assert p_group['initial_lr']
         scheduler.batch_step()
-        assert scheduler.last_batch_idx == 0
+        assert scheduler.last_batch_idx == 1

     def test_scale_fn(self, init_optimizer):
         def scale_fn(x):
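
The reworked test_batch_idx_with_none covers both constructor changes: instantiation now emits the DeprecationWarning, every param group gains an initial_lr entry, and a single batch_step() call leaves last_batch_idx at 1 rather than 0. The same checks outside pytest, as a sketch under the same assumptions as above:

import warnings
import torch
from skorch.callbacks.lr_scheduler import CyclicLR

opt = torch.optim.SGD([torch.ones(1, requires_grad=True)], lr=0.05)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)   # silence the new warning
    scheduler = CyclicLR(opt)

assert all('initial_lr' in group for group in opt.param_groups)
scheduler.batch_step()
assert scheduler.last_batch_idx == 1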
