Inconsistent values of lr_scheduler.get_lr and lr in optimizer.param_groups #20138

Open
lihuanglx opened this issue May 5, 2019 · 2 comments
Labels: module: optimizer (Related to torch.optim), triaged

Comments


lihuanglx commented May 5, 2019

🐛 Bug

After upgrading to 1.1.0, the value returned by lr_scheduler.get_lr() is confusing compared to the lr value inside optimizer.param_groups.

To Reproduce

Here I follow the new convention of putting lr_scheduler.step() at the end of each iteration; see the new documentation and #7889 (which is probably the root of this issue).

Code:

# 1.1.0
import torch

net = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for i in range(5):
    print(i, lr_scheduler.get_lr(), optimizer.param_groups[0]['lr'])
    lr_scheduler.step()

Output:

0 [0.1] 0.1
1 [0.08100000000000002] 0.09000000000000001
2 [0.07290000000000002] 0.08100000000000002
3 [0.06561000000000002] 0.07290000000000002
4 [0.05904900000000002] 0.06561000000000002

We get inconsistent values. The two are equal on the first line, but on every following line they differ by exactly the decay factor gamma = 0.9.
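
For reference, the gap is exactly one factor of gamma: get_lr() appears to derive its value from the rate already stored in the optimizer, so calling it between steps applies gamma one extra time. Below is a simplified sketch of that behaviour (not the actual PyTorch source; the class name is purely illustrative) which reproduces the numbers above:

import torch

class SketchExponentialLR:
    # Simplified stand-in for the observed ExponentialLR behaviour in 1.1.0.
    def __init__(self, optimizer, gamma):
        self.optimizer = optimizer
        self.gamma = gamma
        self.base_lrs = [g['lr'] for g in optimizer.param_groups]
        self.last_epoch = 0

    def get_lr(self):
        if self.last_epoch == 0:
            return list(self.base_lrs)
        # Derived from the lr already written into the optimizer, so a call
        # made between steps reports one extra factor of gamma.
        return [g['lr'] * self.gamma for g in self.optimizer.param_groups]

    def step(self):
        self.last_epoch += 1
        for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            group['lr'] = lr

net = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
sketch = SketchExponentialLR(optimizer, gamma=0.9)

for i in range(5):
    print(i, sketch.get_lr(), optimizer.param_groups[0]['lr'])
    sketch.step()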

Expected behavior

At the very least, the two values should be consistent, right?

In the old version 1.0.1, if we follow the previous convention of putting lr_scheduler.step() at the beginning of each iteration, the output values are reasonable and consistent:

Code:

# 1.0.1
import torch

net = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for i in range(5):
    lr_scheduler.step()
    print(i, lr_scheduler.get_lr(), optimizer.param_groups[0]['lr'])

Output:

0 [0.1] 0.1
1 [0.09000000000000001] 0.09000000000000001
2 [0.08100000000000002] 0.08100000000000002
3 [0.0729] 0.0729
4 [0.06561] 0.06561
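
In the meantime, under 1.1.0 the rate that optimizer.step() will actually apply is the one stored in optimizer.param_groups, so reading that directly (rather than get_lr()) avoids the ambiguity; a minimal sketch:

# 1.1.0 workaround sketch: take the effective lr from the optimizer itself
import torch

net = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for i in range(5):
    current_lr = optimizer.param_groups[0]['lr']  # rate used by this iteration's optimizer.step()
    print(i, current_lr)
    # ... forward/backward and optimizer.step() would go here ...
    lr_scheduler.step()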

Environment

I'll skip this part since the issue can be easily reproduced on a fresh 1.1.0 installation.

@vishwakftw added the module: optimizer label May 5, 2019
codexetreme (Contributor) commented:

If this is a bug, can I try to work on it?

vfdev-5 (Collaborator) commented May 5, 2019

@lihuanglx I would say this is also due to the recursive reimplementation of the LR schedulers in #14010, which allows multiple schedulers to be used simultaneously to modify the learning rates.
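
For context, a rough sketch of what that chainable design enables: two schedulers stepping the same optimizer, each multiplying its own factor into the stored lr (the scheduler choice and printed values are illustrative, assuming the recursive behaviour described above):

import torch

net = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=1.0)
# Two decays driving the same optimizer; each step multiplies its gamma
# into the lr currently stored in param_groups.
slow_decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
fast_decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for i in range(3):
    print(i, optimizer.param_groups[0]['lr'])  # expected: 1.0, 0.45, 0.2025
    slow_decay.step()
    fast_decay.step()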

@izdeby added the triaged label May 6, 2019