
optim.lr_scheduler.CyclicLR (master only: not released) is buggy when not using momentum #19003

Closed
Youssefares opened this issue Apr 7, 2019 · 22 comments
Labels
module: optimizer Related to torch.optim triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone

Comments

@Youssefares

Issue description

If I use an optimizer such as Adam, which has no momentum, and follow the message from line 578 (shown below) by passing cycle_momentum=False, then line 584 throws a KeyError because the key 'momentum' is never set in the optimizer's parameter groups.

if cycle_momentum:
    if 'momentum' not in optimizer.defaults:
        raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
    base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
    if last_epoch == -1:
        for momentum, group in zip(base_momentums, optimizer.param_groups):
            group['momentum'] = momentum
self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

  1. I believe the last two lines should be indented under the if-condition as well (see the sketch after this list).
  2. I also think the expected behavior should be for cycle_momentum to default to False when the passed-in optimizer doesn't support momentum, instead of raising an error.
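
A minimal sketch of suggestion (1) only (not an official patch); moving the two assignments inside the momentum branch avoids touching group['momentum'] when cycling is disabled:

if cycle_momentum:
    if 'momentum' not in optimizer.defaults:
        raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
    base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
    if last_epoch == -1:
        for momentum, group in zip(base_momentums, optimizer.param_groups):
            group['momentum'] = momentum
    # Only read momentum values from the param groups when momentum cycling is requested.
    self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
    self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

Suggestion (2) would amount to something like cycle_momentum = cycle_momentum and 'momentum' in optimizer.defaults before this block, though silently disabling it is a behavior change the maintainers may prefer to avoid.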

Code example

import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR  # copied from master in my setup (not in the 1.0.1 release)

optimizer = optim.Adam(params_to_update, lr=5e-4)  # params_to_update: the model parameters
scheduler = CyclicLR(optimizer, base_lr=5e-6, max_lr=5e-2, cycle_momentum=False, step_size_up=2500)

System Info

In my setup, I am using the stable version of PyTorch, but I copied CyclicLR and its related imports over from master and am using that in my project.

Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Debian GNU/Linux 9.8 (stretch)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla P4
Nvidia driver version: 410.72
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] intel-numpy==1.15.1
[pip3] numpy==1.15.1
[pip3] torch==1.0.1.post2
[pip3] torchvision==0.2.2.post3
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.1                      144
[conda] mkl-service               1.1.2            py37he904b0f_5
[conda] mkl_fft                   1.0.6            py37hd81dba3_0
[conda] mkl_random                1.0.2            py37hd81dba3_0
[conda] pytorch                   1.0.1           py3.7_cuda10.0.130_cudnn7.4.2_2    pytorch
[conda] torchvision               0.2.2                      py_3    pytorch
@umanwizard added the module: optimizer (Related to torch.optim) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Apr 8, 2019
@soumith added this to the 1.1 milestone Apr 9, 2019
@Youssefares
Author

Let me know if you'd like me to work on this 👨‍💻

@soumith
Member

soumith commented Apr 9, 2019

@Youssefares sure thing, go for it. Thank you!

@trias702

trias702 commented May 1, 2019

This bug is still present in the official PyTorch 1.1.0, which was released Apr 30, 2019. It's a really bad bug, since CyclicLR is a big feature of the new PyTorch release and this bug causes CyclicLR to fail on any optimizer that doesn't support momentum. It's a fairly easy fix too: just indent two lines of code.

@BorisMarjanovic

When will this bug be fixed?

@memahesh

This bug is still present. Is anyone working on a fix? If not, maybe I can work on it.

@memahesh

memahesh commented Jun 1, 2019

@soumith and @Youssefares
Can either of you tell me whether it's fixed or not?

@Youssefares
Author

@memahesh
I haven't had time to fix this, I'm afraid.
Sorry for any delay I might have caused. I think a pull request is already open for it, though, so you can follow the progress there.

@jeroenvuurens

I think you are correct: if you indent the second- and third-to-last lines in the __init__, it works.

@gchanan
Contributor

gchanan commented Jul 16, 2019

Fixed in #20401.

@gchanan closed this as completed Jul 16, 2019
@meghbhalerao

Hmm. Is anyone getting this error in PyTorch 1.6.0?

@ghost

ghost commented Sep 12, 2020

I use PyTorch 1.4, but I still get the following error when using Adam without momentum:
ValueError: optimizer must support momentum with `cycle_momentum` option enabled

What I do now instead is to use SGD with momentum ...
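
A minimal sketch of that workaround (the model and hyperparameters here are hypothetical placeholders; only the SGD-with-momentum substitution comes from this thread):

import torch
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 1)  # hypothetical model
# SGD exposes a momentum parameter, so the default cycle_momentum=True is accepted.
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4, momentum=0.9)
scheduler = CyclicLR(optimizer, base_lr=5e-6, max_lr=5e-2, step_size_up=2500)

for _ in range(5):  # call scheduler.step() after each optimizer.step() as usual
    optimizer.step()
    scheduler.step()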

@meghbhalerao

Oh, okay. So basically there are some compatibility issues between certain combinations of optimizer and LR scheduler?

@ghost

ghost commented Sep 12, 2020

Exactly. At least it seems so.

I hope this will be fixed at some point. :-)

@kaushal-py

The bug exists in PyTorch 1.6 as well.

@snk4tr

snk4tr commented Nov 6, 2020

The bug exists in PyTorch 1.7.0 😞

@Jungjee

Jungjee commented Nov 20, 2020

Does anyone know why this issue is closed even though the bug still exists in the latest stable release (1.7.0)?

Or is there a solution for this other than setting cycle_momentum=False?

@Moldoteck

Any updates for this?

@thomashirtz

thomashirtz commented Jun 12, 2021

Is someone currently working on this? (This issue is not solved.)

@patricio-astudillo

The bug exists in PyTorch 1.9.0

@cowwoc

cowwoc commented Aug 19, 2021

@gchanan Please reopen this issue. It is not fixed.

@thomashirtz

Does someone know how to fix it, or can explain it to me? I can try to make a PR.

@gonzrubio

The bug exists in PyTorch 2.0.0
