
Fix lr_scheduler's last_epoch value at the time of initialization (BC BREAKING!) #7889

Closed
wants to merge 4 commits into from

Conversation

@bado-lee
Contributor

commented May 28, 2018

Hello everyone :) !!

I've found that lr_scheduler is initialized with last_epoch set to -1.
As a result, even after the first step (not the one in __init__ but an explicit call to the scheduler's step()),
the learning rate of the scheduler's optimizer remains at its previous value.

>>> import torch
>>> cc = torch.nn.Conv2d(10,10,3)
>>> myinitial_lr = 0.1
>>> myoptimizer = torch.optim.Adam(cc.parameters(), lr=myinitial_lr)
>>> mylrdecay = 0.5
>>> myscheduler = torch.optim.lr_scheduler.ExponentialLR(myoptimizer,mylrdecay)

# get_lr() and the optimizer's lr are not consistent; last_epoch should be 0 instead of -1
>>> myscheduler.get_lr()
[0.2]    # this is because get_lr() computes the lr as 0.1 * 0.5^-1
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.1    # this is not consistent with the get_lr() value
>>> myscheduler.last_epoch
-1

# the values now seem to be in sync, but after the first step (decay) they should have been 0.05
>>> myscheduler.step()
>>> myscheduler.get_lr()
[0.1]    # this should be the value right after init, not after the first step
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.1    # since this is after the first step, it should have decayed to 0.05
>>> myscheduler.last_epoch
0

>>> myscheduler.step()
>>> myscheduler.last_epoch
1
>>> myscheduler.get_lr()
[0.05]
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.05
>>> myscheduler.last_epoch
1

The first problem is that, even right after initializing the lr_scheduler, you get inconsistent parameter values.

The second problem is that you are stuck with the same learning rate for the first two epochs unless the scheduler's step() is called at the beginning of the epoch loop.
Of course, you can avoid this by calling the scheduler's step() at the beginning,
but I don't think that is proper usage since, in the case of the optimizer, step() is called at the end of the iteration loop.
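For context, here is a minimal sketch of the training-loop shape this change is meant to support, with both step() calls at the end of the loop (illustrative code, not taken from the docs or from this patch):

```python
import torch

model = torch.nn.Conv2d(10, 10, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for epoch in range(5):
    # ... forward pass, loss.backward(), etc. would go here ...
    optimizer.step()   # the optimizer steps at the end of the iteration/epoch
    scheduler.step()   # with this fix, the scheduler can also step at the end,
                       # and the decayed lr takes effect from the next epoch
```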

I've avoided all of the above issues by setting last_epoch to 0 after initialization.

This also makes sense when you initialize with a last_epoch value other than -1.
For example, if you want to initialize with last_epoch 10,
the lr should not be set to a value decayed one step further, which is what the previous code does
(last_epoch gets +1 before computing base_lr * self.gamma ** self.last_epoch).

Instead, it should be set to the exact value for step 10.
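As a small worked example of this off-by-one (the numbers reuse the lr and decay from the snippet above; this is not scheduler source code):

```python
base_lr, gamma, last_epoch = 0.1, 0.5, 10

# intended lr when resuming exactly at step 10
lr_expected = base_lr * gamma ** last_epoch        # 0.1 * 0.5**10 ≈ 9.77e-05

# what the previous code effectively computes (one extra decay)
lr_previous = base_lr * gamma ** (last_epoch + 1)  # 0.1 * 0.5**11 ≈ 4.88e-05
```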

I hope this fix finds its way in with all your help :)
I'm really looking forward to, and excited about, becoming a contributor to PyTorch!
PyTorch rocks!!

@ezyang

Contributor

commented Jun 1, 2018

@pytorchbot retest this please

1 similar comment
@bado-lee

Contributor Author

commented Jun 11, 2018

@pytorchbot retest this please

@bado-lee

Contributor Author

commented Jun 12, 2018

@ezyang Hi can you please review this?

@ezyang

Contributor

commented Jun 12, 2018

@bado-lee I don't think the change you suggested here is right. In the example code in the docs, we clearly step the scheduler before training:

>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05     if epoch < 30
>>> # lr = 0.005    if 30 <= epoch < 80
>>> # lr = 0.0005   if epoch >= 80
>>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
>>> for epoch in range(100):
>>>     scheduler.step()
>>>     train(...)
>>>     validate(...)

Perhaps a more logical design would have been to step the scheduler at the same time you step the optimizer, but that's not how the API works today, and you'd be changing the behavior of code written by anyone who had been abiding by the API previously.

What do you think about this reasoning?

@bado-lee bado-lee force-pushed the bado-lee:master branch from f8e02a9 to 411af2c Jun 12, 2018

@bado-lee

Contributor Author

commented Jun 12, 2018

@ezyang
Thank you for your comment.
Yes, you are right about the docs, so I fixed them as well.

There are several points about this.

  1. Meaning of "Step" should be stepping forward from the initial status.
    Previous implementation requires step to be part of the initialization. Which is ambiguous.
    Which means, the very first step in the loop is actually an initialization rather than actual "Step".
    It will be clear if step can be called only when needed.

  2. This change fixes the inconsistency, before the very first step() is called, between the following two values (as stated above; see also the sketch after this list):

>>> myscheduler.get_lr()
>>> myscheduler.optimizer.param_groups[0]["lr"]
  3. There is already a case where the scheduler's step() should be called at the end of the loop:
class ReduceLROnPlateau(object):
...
Example:
    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
    >>> for epoch in range(10):
    >>>     train(...)
    >>>     val_loss = validate(...)
    >>>     # Note that step should be called after validate()
    >>>     scheduler.step(val_loss)
"""

I'll update if further fixes are required.

@ezyang

Contributor

commented Jun 12, 2018

Paging the backwards compatibility police @soumith @colesbury

@bado-lee

Contributor Author

commented Jun 12, 2018

Regarding backward compatibility and the documentation examples:
currently there are 6 types of lr scheduler, of which

1 has an example with step() at the end of the loop,

3 have examples with scheduler.step() at the beginning of the loop (I guess this was a necessary evil, given that the class they inherit from was not implemented correctly), and

2 have no example mentioning where to put step().

So I think it's better to make everything consistent, with the examples putting step() at the end of the loop.
Please consider that there will be confusion for newcomers in the future
if this is not corrected now.

@apaszke

Member

commented Jun 12, 2018

Didn't have time to read through all the comments here, but IIRC the current impl matches that of Keras, so it might not only be a question of BC, but also of avoiding adding confusion to an API ported from a different framework.

@bado-lee

Contributor Author

commented Jun 12, 2018

@apaszke Thanks for the comment.
I've looked at what you mentioned in Keras:

class LearningRateScheduler(Callback):
    """Learning rate scheduler.
    # Arguments
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0) and current learning rate
            and returns a new learning rate as output (float).
        verbose: int. 0: quiet, 1: update messages.
    """

    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = float(K.get_value(self.model.optimizer.lr))
        try:  # new API
            lr = self.schedule(epoch, lr)
        except TypeError:  # old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

Well, as you said, Keras specifically names this function "on_epoch_begin", and it is the counterpart of "step" in PyTorch. Yes, it is required to be called at the beginning of the loop.
But I think its behavior and intent are different from what is implemented in PyTorch.
The name is not "step", it's "on_epoch_begin", which is clearly different.
Furthermore, "on_epoch_begin" takes the epoch number as an input argument, whereas "step" does not.
So, IMHO, it's not accurate to say that this implementation matches Keras.
To me it's more reasonable for something named "step" to match the behavior of the optimizer's step() in PyTorch, which is called at the end of the loop.

@bado-lee

Contributor Author

commented Jun 12, 2018

https://discuss.pytorch.org/t/how-to-use-torch-optim-lr-scheduler-exponentiallr/12444/5
I think the above is what pretty much everyone would do, given that the name is "step".
But the current code will run with the same lr for the first 200 iterations instead of the first 100.
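For illustration, here is a sketch of the kind of loop I mean (my own reconstruction, assuming the scheduler is stepped once every 100 iterations; this is not the code from the linked thread):

```python
import torch

model = torch.nn.Conv2d(10, 10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for iteration in range(1, 401):
    # ... training step would go here ...
    optimizer.step()
    if iteration % 100 == 0:
        scheduler.step()
        print(iteration, optimizer.param_groups[0]["lr"])

# under the current (pre-fix) code, the first scheduler.step() only moves
# last_epoch from -1 to 0, so training runs at lr=0.1 for the first 200
# iterations; the first decay only lands at the end of iteration 200
```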

@bado-lee

Contributor Author

commented Jun 18, 2018

@soumith @colesbury Hi, can I get a status on this?

@yf225

Contributor

commented Jul 10, 2018

@soumith @colesbury What would be the next step for this PR?

@zdevito zdevito removed their request for review Feb 13, 2019

@gchanan gchanan removed their request for review Feb 28, 2019

@lironmo


commented Apr 11, 2019

any update?

@ezyang

Contributor

commented Apr 15, 2019

@pytorchbot rebase this please

@bado-lee

Contributor Author

commented Apr 17, 2019

@pytorchbot retest this please

@bado-lee

Contributor Author

commented Apr 17, 2019

@ezyang
Hi. Can you tell me how I should proceed with this request?
I'm also not sure about the current build failures. I think (hope) they are not related to my changes.

@bado-lee

Contributor Author

commented Apr 17, 2019

@pytorchbot rebase this please

facebook-github-bot added a commit that referenced this pull request May 3, 2019

Initialize last_epoch in _LRScheduler.__init__() (#20059)
Summary:
Class attributes should preferably be initialized explicitly within
the __init__() call. Otherwise, overriding step() is
prone to bugs.

This patch partially reverts #7889
Pull Request resolved: #20059

Differential Revision: D15195747

Pulled By: soumith

fbshipit-source-id: 3d1a51d8c725d6f14e3e91ee94c7bc7a7d6c1713
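For illustration, a minimal sketch of the pattern that commit describes (a hypothetical class body, not the actual _LRScheduler source):

```python
class _SchedulerSketch(object):
    def __init__(self, optimizer, last_epoch=-1):
        self.optimizer = optimizer
        # initialize the attribute explicitly in __init__ instead of relying on
        # a step() call to create it, so that subclasses overriding step()
        # can safely read self.last_epoch
        self.last_epoch = last_epoch

    def step(self, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
```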
@SsnL

Collaborator

commented May 3, 2019

This breaks BC so much! All the code out there that calls scheduler.step() before training now breaks silently? A lot of major repos are affected :(!

I agree that this is for the better. But I would prefer a deprecation mechanism that spans something like 2 versions, and that we emphasize it much, much more than a "pay attention to the order change" note in the release notes.

@vfdev-5

Contributor

commented May 3, 2019

A lot of major repos are affected :(!

@SsnL that's true, I had the same opinion when seeing this for the first time. However, in the worst case, IMO, the training starts from the second LR scheduling value instead of the first. Maybe this is not a big issue for iteration-wise lr scheduling (a bit more for epoch-wise).

@SsnL

Collaborator

commented May 3, 2019

@vfdev-5

  1. This generally causes the lr to decay faster than it used to, which could harm performance with hyperparameters tuned for the previous behavior.
  2. People do research on training with very few gradient-descent steps. This is devastating if they use lr schedulers.
  3. 1.0.0 and 1.0.1 are not so long ago. This patch makes it impossible for code to behave the same on both 1.0.1 and 1.1.
@vfdev-5

Contributor

commented May 3, 2019

@SsnL I totally agree with the 3rd point! The 2nd point is not clear, can you provide some details?

@SsnL

Collaborator

commented May 3, 2019

@vfdev-5 Yes, of course. I am trying to highlight the use cases this would affect most.
Some people investigate the behavior of gradient updates and look at the optimization process for something like 50 or 500 steps. Generally they use a fixed learning rate and don't really rely on lr schedulers. But if any of these projects use one to update the lr per iteration, this change would have a major effect.

@vfdev-5

Contributor

commented May 3, 2019

@SsnL I see, thanks!

I made some plots to compare 1.0.1 vs 1.1.0 when the code is written for 1.0.1:

import torch
import numpy as np
import matplotlib.pylab as plt
%matplotlib inline

from torch.optim.lr_scheduler import MultiStepLR

t = torch.tensor([0.0], requires_grad=True)
opt = torch.optim.SGD([t], lr=0.01)

lr_scheduler = MultiStepLR(opt, milestones=[10, 20, 30])

lrs = []
for e in range(35):
    lr_scheduler.step()
    lrs.append((e, opt.param_groups[0]['lr']))

lrs = np.array(lrs)
plt.plot(lrs[:, 0], lrs[:, 1])

For 1.0.1:
[plot: lr vs. epoch for 1.0.1]

lrs[0, :], lrs[9, :], lrs[10, :]
> (array([0.  , 0.01]), array([9.  , 0.01]), array([1.e+01, 1.e-03]))

For 1.1.0 it is just shifted by one iteration:
[plot: lr vs. epoch for 1.1.0]

lrs[0, :], lrs[9, :], lrs[10, :]
> (array([0.  , 0.01]), array([9.e+00, 1.e-03]), array([1.e+01, 1.e-03]))
@SsnL

Collaborator

commented May 3, 2019

@vfdev-5 Thanks for the plots :). That seems to cover most cases. But there does seem to be something weird going on with CosineAnnealingLR (#20086).

@vfdev-5

Contributor

commented May 3, 2019

@SsnL yeah, very weird behaviour. It seems like it is also because of the recursive implementation...

@lihuanglx


commented May 5, 2019

It's hard to believe this patch was merged into the stable release. From my point of view, this "fix" stems entirely from personal preference rather than real need. I don't see any benefit from this change, either in terms of efficiency or in terms of usability, and I don't see any real problem with the old implementation and the old convention of putting scheduler.step() at the beginning of each iteration. But it got merged, horribly breaks almost all current projects based on PyTorch, and seems to lead to bugs (#20086, #20138). Normally such huge compatibility-breaking changes should only be considered if they greatly improve the software; otherwise it is not responsible toward the users, especially considering the enormous size of the PyTorch community. @ezyang @soumith

@ezyang

Contributor

commented May 6, 2019

I'm going to go ahead and revert this. @soumith, it would be great if the revert could be cherry-picked to 1.1 as well. This shouldn't have been merged so close to the 1.1 branch cut; that is my mistake. @bado-lee, I can help you put together a deprecation schedule.

@ezyang

Contributor

commented May 6, 2019

Actually, I am catching up on the discussion. I see that #20059 was merged while I was away. @lihuanglx, is that patch sufficient to eliminate the BC-breakage, or is a more complete revert needed?

ezyang added a commit that referenced this pull request May 6, 2019

ezyang added a commit that referenced this pull request May 6, 2019

Revert "Fix lr_scheduler's last_epoch value at the time of initializa…
…tion (BC BREAKING!) (#7889)"

This reverts commit 3608490.

gh-metadata: pytorch pytorch 20147 gh/ezyang/117/head
@SsnL

Collaborator

commented May 6, 2019

@ezyang I don't think #20059 changes the behavior.

Also, if we still want to make this change, please consider my suggestion at #20124 on adding a warning.
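For illustration, one possible shape for such a warning (my own sketch of the idea behind #20124, not the actual patch): wrap optimizer.step() so the scheduler can tell whether it has been called before scheduler.step().

```python
import functools
import warnings

def patch_track_step_order(optimizer, scheduler):
    """Illustrative only: warn if scheduler.step() runs before optimizer.step()."""
    optimizer._step_called = False

    original_opt_step = optimizer.step

    @functools.wraps(original_opt_step)
    def opt_step(*args, **kwargs):
        optimizer._step_called = True
        return original_opt_step(*args, **kwargs)

    original_sched_step = scheduler.step

    @functools.wraps(original_sched_step)
    def sched_step(*args, **kwargs):
        if not optimizer._step_called:
            warnings.warn("lr_scheduler.step() was called before optimizer.step(); "
                          "with the 1.1-style schedulers, call it afterwards.",
                          UserWarning)
        return original_sched_step(*args, **kwargs)

    optimizer.step = opt_step
    scheduler.step = sched_step
```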

@waleedka


commented May 6, 2019

Sure, this should have been planned better, but I think the fix is valid. It was confusing to have to call the optimizer's step() at the end and the scheduler's step() at the beginning, and this PR fixes that inconsistency. Since it's already out there, reverting it might cause more damage and confusion. I second @SsnL's suggestion of showing a warning when the old pattern is detected and leaving it at that.

@bado-lee

Contributor Author

commented May 6, 2019

@ezyang I saw your revert commit. Has it been decided to revert the change?
Honestly, it also occurs to me that this could have been smoother if the change had been warned about for a couple of minor versions beforehand. But I do not doubt its validity. I understand that this change can affect many people and even be shocking to some, but that doesn't mean the fix should be neglected.
Also, since it's already released in version 1.1, reverting it would only cause more damage.
I think @SsnL's suggestion would be sufficient for now.

@ezyang

Contributor

commented May 6, 2019

Hey @bado-lee, I've closed the revert commit, so we're going to keep it :) Is there an outstanding PR to add a warning for the old pattern? I'd be happy to help review.

@bado-lee

Contributor Author

commented May 6, 2019

@ezyang I think I can work on #20124 if there isn't a PR already. I'll let you know when done. Thank you :)

@vfdev-5

Contributor

commented May 6, 2019

@bado-lee I was also thinking of sending a PR.

@bado-lee

Contributor Author

commented May 6, 2019

@vfdev-5 Oh in that case please be my guest. I would be more than happy to review it :) btw thanks for the analysis!

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019

Fix lr_scheduler's last_epoch value at the time of initialization (BC BREAKING!) (pytorch#7889)

Pull Request resolved: pytorch#7889

Differential Revision: D15012769

Pulled By: ezyang

fbshipit-source-id: 258fc3009ea7b7390a3cf2e8a3682eafb506b08b
@gchanan

Contributor

commented May 8, 2019

One thing I don't quite understand:

There's a discussion above about the Keras API (on_epoch_begin) and some comments indicating that the name makes what's going on clearer, but no discussion about introducing something like that as a future fix that doesn't break BC.

Should we do something like that?
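For illustration, a hypothetical sketch of what "something like that" could look like (the names are made up; nothing like this exists in torch.optim.lr_scheduler):

```python
class EpochBeginScheduler(object):
    """Hypothetical Keras-style scheduler; explicit about when it is called."""

    def __init__(self, optimizer, schedule):
        self.optimizer = optimizer
        self.schedule = schedule          # maps an epoch index to a learning rate

    def on_epoch_begin(self, epoch):
        lr = self.schedule(epoch)
        for group in self.optimizer.param_groups:
            group['lr'] = lr
```

The existing step()-based schedulers would keep their current semantics, so code relying on them would not break.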

@ezyang

Contributor

commented Jun 11, 2019

@ezyang ezyang added the open source label Jun 24, 2019
