Revert "Fix lr_scheduler's last_epoch value at the time of initialization (BC BREAKING!) (#7889)"

This reverts commit 3608490.
ezyang committed May 6, 2019
1 parent 23ba056 commit 8f51cbb
Showing 2 changed files with 5 additions and 11 deletions.
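
For context, the docstring changes below go back to recommending a scheduler.step() call at the top of each epoch. A minimal usage sketch of that restored pattern (the model, optimizer settings, and the commented-out train/validate calls are placeholders for illustration, not part of this commit):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(10, 2)                  # placeholder model
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    scheduler.step()    # pattern restored by this revert: step before the epoch's work
    # train(...)        # placeholder training step
    # validate(...)     # placeholder validation step
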
test/test_optim.py (7 changes: 1 addition & 6 deletions)
@@ -13,8 +13,7 @@
 from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
     ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
     CyclicLR, CosineAnnealingWarmRestarts
-from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
-    skipIfRocm
+from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -283,7 +282,6 @@ def test_sgd_sparse(self):
             [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
         )

-    @skipIfRocm
     def test_adam(self):
         self._test_basic_cases(
             lambda weight, bias: optim.Adam([weight, bias], lr=1e-3)
@@ -389,7 +387,6 @@ def test_adagrad_sparse(self):
             lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)]
         )

-    @skipIfRocm
     def test_adamax(self):
         self._test_basic_cases(
             lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1)
@@ -414,7 +411,6 @@ def test_rmsprop(self):
         with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
             optim.RMSprop(None, lr=1e-2, momentum=-1.0)

-    @skipIfRocm
     def test_asgd(self):
         self._test_basic_cases(
             lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100)
@@ -439,7 +435,6 @@ def test_rprop(self):
         with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
             optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5))

-    @skipIfRocm
     def test_lbfgs(self):
         self._test_basic_cases(
             lambda weight, bias: optim.LBFGS([weight, bias]),
torch/optim/lr_scheduler.py (9 changes: 4 additions & 5 deletions)
@@ -16,15 +16,14 @@ def __init__(self, optimizer, last_epoch=-1):
         if last_epoch == -1:
             for group in optimizer.param_groups:
                 group.setdefault('initial_lr', group['lr'])
-            last_epoch = 0
         else:
             for i, group in enumerate(optimizer.param_groups):
                 if 'initial_lr' not in group:
                     raise KeyError("param 'initial_lr' is not specified "
                                    "in param_groups[{}] when resuming an optimizer".format(i))
         self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        self.step(last_epoch + 1)
         self.last_epoch = last_epoch
-        self.step(last_epoch)

     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
@@ -71,9 +70,9 @@ class LambdaLR(_LRScheduler):
         >>> lambda2 = lambda epoch: 0.95 ** epoch
         >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
         >>> for epoch in range(100):
+        >>>     scheduler.step()
         >>>     train(...)
         >>>     validate(...)
-        >>>     scheduler.step()
     """

     def __init__(self, optimizer, lr_lambda, last_epoch=-1):
@@ -145,9 +144,9 @@ class StepLR(_LRScheduler):
         >>> # ...
         >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
         >>> for epoch in range(100):
+        >>>     scheduler.step()
         >>>     train(...)
         >>>     validate(...)
-        >>>     scheduler.step()
     """

     def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
@@ -182,9 +181,9 @@ class MultiStepLR(_LRScheduler):
         >>> # lr = 0.0005 if epoch >= 80
         >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
         >>> for epoch in range(100):
+        >>>     scheduler.step()
         >>>     train(...)
         >>>     validate(...)
-        >>>     scheduler.step()
     """

     def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
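
A minimal sketch of the initialization behavior restored above, assuming the step() semantics of _LRScheduler at the time (epoch defaults to last_epoch + 1); the model and optimizer below are placeholders for illustration:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 2)               # placeholder module
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# With the restored __init__, step(last_epoch + 1) runs during construction and
# self.last_epoch is then reassigned from the constructor argument, so a fresh
# scheduler reports -1 here; the reverted change (#7889) left it at 0 instead.
print(scheduler.last_epoch)                 # -1 with the restored behavior
print(optimizer.param_groups[0]['lr'])      # still the base lr (0.1) right after construction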
