[FR] Warn if scheduler.step() is called but optim.step has not been called #20124

Closed
ssnl opened this issue May 4, 2019 · 12 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: optimizer Related to torch.optim triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ssnl
Collaborator

ssnl commented May 4, 2019

In 1.1 we made a major BC breaking change, where the order of calling lr schedulers should be changed from

for e in range(nepochs):
  scheduler.step()
  train()

to

for e in range(nepochs):
  train()
  scheduler.step()

This silently breaks a lot of code and makes it impossible to write code that behaves consistently on 1.0.1 and 1.1. So I propose adding a warning in scheduler.step that looks at the corresponding optimizer and checks whether its .step has been called.

If it has not been called, this is a sign that the user is using scheduler.step() with the old pattern. I can't think of a reasonable case where this would produce a false positive.
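
A minimal sketch of what the proposed check could look like (hypothetical names only, not the actual torch.optim.lr_scheduler implementation): the scheduler wraps the optimizer's step so that scheduler.step() can tell whether it has been called.

# Hypothetical sketch of the proposed check; names are illustrative.
import warnings

class SchedulerWithCheck:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        original_step = optimizer.step

        def wrapped_step(*args, **kwargs):
            # Record that the optimizer has actually stepped.
            wrapped_step.called = True
            return original_step(*args, **kwargs)

        wrapped_step.called = False
        optimizer.step = wrapped_step  # the "monkey patching" discussed later in this thread

    def step(self):
        if not self.optimizer.step.called:
            warnings.warn(
                "Detected call of scheduler.step() before optimizer.step(); "
                "since 1.1.0 the optimizer step should come first.",
                UserWarning,
            )
        # ... compute and apply the new learning rates here ...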

@ssnl
Collaborator Author

ssnl commented May 4, 2019

cc @vfdev-5

@vishwakftw vishwakftw added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: optimizer Related to torch.optim labels May 4, 2019
vfdev-5 added a commit to vfdev-5/pytorch that referenced this issue May 6, 2019
Detect old pattern of using lr scheduler
@izdeby izdeby added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 6, 2019
@vfdev-5 vfdev-5 mentioned this issue May 7, 2019
@apaszke
Contributor

apaszke commented May 7, 2019

Why did that change??

facebook-github-bot pushed a commit that referenced this issue Jun 6, 2019
Summary:
This PR addresses the problem described in this comment: #20203 (comment)
and the previously coded bad behaviour:
- a warning was raised every time an lr scheduler was initialized

Now the code checks that:
- on the second call of `lr_scheduler.step`, `optimizer.step` has already been called; otherwise a warning is raised (as was done in #20203)
- if the optimizer's step is overridden -> raise another warning, once, to make the user aware of the new pattern:
`opt.step()` -> `lrs.step()`, since we cannot check this case.

Now the tests check that:
- at initialization (`lrs = StepLR(...)`) there are no warnings
- if we replace `optimizer.step` by something else (similarly to the [code of nvidia/apex](https://github.com/NVIDIA/apex/blob/master/apex/amp/_process_optimizer.py#L287)), the other warning is raised.

cc ezyang

PS. Honestly, I would say that there is a lot of overhead introduced for simple warnings. I hope all these checks will be removed in a future version (`1.2.0` or later)...
Pull Request resolved: #21460

Differential Revision: D15701776

Pulled By: ezyang

fbshipit-source-id: eac5712b9146d9d3392a30f6339cd33d90c497c7
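
For illustration, the "overridden step" case described in the summary above could be reproduced roughly like this (hypothetical snippet, loosely modelled on apex-style patching; not the actual apex or PyTorch code):

# Replacing optimizer.step after the scheduler has wrapped it removes the
# wrapper's marker, so the scheduler can no longer tell whether the optimizer
# stepped and falls back to the second warning.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

original_step = optimizer.step

def patched_step(*args, **kwargs):
    # Stand-in for whatever the patched step would do (e.g. loss scaling).
    return original_step(*args, **kwargs)

optimizer.step = patched_step  # the scheduler's wrapper is now bypassed

loss = model(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()  # expected to warn that optimizer.step() has been overridden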
@wwoods

wwoods commented Aug 22, 2019

Like @apaszke, I'm wondering why this change happened. It is a rather busybody change that does not improve user-facing PyTorch code in any way, as far as I can tell. Furthermore, it complains about perfectly fine code that calls step(epoch) explicitly rather than relying on last_epoch in the _LRScheduler code. It now issues a warning incorrectly: calling step(0) prior to the optimizer step results in a warning, even though the behavior is correct.
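
For illustration only, the explicit-epoch pattern described above might look like this (hypothetical setup; on the 1.2-era versions discussed here the first scheduler call is expected to warn even though the schedule is correct):

# Passing the epoch explicitly to scheduler.step(); calling step(0) before the
# first optimizer.step() triggers the warning despite the correct behavior.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for epoch in range(3):
    scheduler.step(epoch)  # explicit epoch; warns on the first call
    loss = model(torch.randn(8, 4)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()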

@vincentqb
Contributor

The change was introduced in #7889, and documented in #20203.

@ruppeshnalwaya1993

ruppeshnalwaya1993 commented Nov 2, 2019

@ssnl The warning says that in PyTorch 1.1.0 and later the optimizer step should come before the lr scheduler step. But the warning does not appear in PyTorch 1.1.0; it only appears in 1.2.0 and later. Is this intended or a mistake?

1.1.0 file
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py

@ssnl
Collaborator Author

ssnl commented Nov 2, 2019 via email

@antimora

@ssnl, @vincentqb the implementation assumes the scheduler (fully initialized) exists prior to calling optimizer.step(). The scheduler monkey-patches the optimizer object! I get the warning even though scheduler.step() is called after optimizer.step(), because my scheduler is initialized after the optimizer.step() call (lazy-evaluation pattern). I wish the implementation did not monkey-patch the optimizer object. This behavior is definitely not documented and leads to chasing the issue in the wrong places.
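
A hypothetical reproduction of the lazy-initialization pattern described above (not the author's actual code): the scheduler wraps optimizer.step only when it is constructed, so the optimizer steps that happened earlier are invisible to it.

# The scheduler is created lazily, after optimizer.step() has already run, so
# its wrapped counter never saw that call and the warning fires anyway.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = None

for epoch in range(3):
    loss = model(torch.randn(8, 4)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is None:
        # Lazily constructed only after the first optimizer step.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    scheduler.step()  # warns on the first call, despite the correct ordering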

@vfdev-5
Collaborator

vfdev-5 commented Jun 29, 2020

As discussed in #21460 (comment), maybe it is time to remove it ...
cc @ezyang

@ssnl
Collaborator Author

ssnl commented Jun 29, 2020

@antimora Yes I fully agree that it is bad practice. It was a decision we had to make because of the silent correctness issues. I vote for removing it. Let's see what @ezyang thinks!

@vincentqb
Contributor

Since the warning was added in 1.2 and we are now about to release 1.6, I also agree we can remove this warning. I would make sure the documentation does specify the order, though :) Thoughts?

@antimora

antimora commented Jul 5, 2020

I am in favor of removing the warning. Should I file a ticket to get it done?

(CCing @ezyang since everyone is CCing him here)

@ezyang
Contributor

ezyang commented Jul 6, 2020

Yeah let's can it. cc @soumith

facebook-github-bot pushed a commit that referenced this issue Nov 10, 2021
Summary:
Fixes #67601.

As simple a fix as I could make it. I even managed to delete some testing code!

I checked calling `super()` and, as I had feared, it doesn't work out of the box, so perhaps that ought to be revisited later.

As it stands, #20124 still applies to the chained scheduler, but I think this change is still an improvement.

Pull Request resolved: #68010

Reviewed By: zou3519

Differential Revision: D32278139

Pulled By: albanD

fbshipit-source-id: 4c6f9f1b2822affdf63a6d22ddfdbcb1c6afd579