
Bug in CosineAnnealingWarmRestarts #49630

Open
heilaw opened this issue Dec 19, 2020 · 5 comments
Labels
module: optimizer Related to torch.optim · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


heilaw commented Dec 19, 2020

🐛 Bug

CosineAnnealingWarmRestarts raises an AttributeError if it is initialized with last_epoch not equal to -1, which usually happens when one wants to resume training a network.

To Reproduce

Steps to reproduce the behavior:

  1. Initialize the scheduler with last_epoch set to where training left off:
scheduler = CosineAnnealingWarmRestarts(
      optimizer, restart_iter,
      T_mult=restart_mult,
      last_epoch=_last_iter  # where we left off; _last_iter is greater than -1
)
  2. The code raises this error:
  File "/foo/torch/optim/lr_scheduler.py", line 965, in __init__
    super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)
  File "/foo/torch/optim/lr_scheduler.py", line 79, in __init__
    self.step()
  File "/foo/torch/optim/lr_scheduler.py", line 1008, in step
    self.T_cur = self.T_cur + 1
AttributeError: 'CosineAnnealingWarmRestarts' object has no attribute 'T_cur'

Expected behavior

No attribute error.
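Until this is fixed, a possible workaround (a sketch, not an official recommendation; the model, optimizer, and checkpoint names here are illustrative) is to construct the scheduler with the default last_epoch=-1 and then restore its progress with load_state_dict, which bypasses the failing constructor path entirely:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# First run: train for a few iterations, then checkpoint the scheduler state.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
for _ in range(5):
    optimizer.step()
    scheduler.step()
checkpoint = {"scheduler": scheduler.state_dict()}

# Resume: build the scheduler with the default last_epoch=-1 (no error is
# raised), then restore last_epoch / T_cur from the checkpoint instead of
# passing last_epoch to the constructor.
resumed = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
resumed.load_state_dict(checkpoint["scheduler"])
```

This keeps the restart bookkeeping (last_epoch, T_cur) consistent with the first run without ever hitting the constructor's self.step() call with a non-default last_epoch.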

Environment

Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
  • PyTorch Version (e.g., 1.0): 1.7.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.7
  • CUDA/cuDNN version: 11/6.0.21
  • GPU models and configuration: RTX 3090
  • Any other relevant information:

Additional context

After some investigation, I believe I found the code that causes this bug. When we initialize CosineAnnealingWarmRestarts, it calls its parent (_LRScheduler) constructor (line 965 in torch/optim/lr_scheduler.py). The _LRScheduler constructor in turn calls CosineAnnealingWarmRestarts' step function (line 79 in torch/optim/lr_scheduler.py) before the initialization has finished. In the step function, because epoch is None and last_epoch is no longer -1, line 1008 is executed, which reads self.T_cur. But at this point self.T_cur has not been set yet, as it is only assigned after the parent constructor returns (line 967).
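The initialization-order problem described above can be mimicked without torch. SchedulerBase and BuggyWarmRestarts below are hypothetical stand-ins for _LRScheduler and CosineAnnealingWarmRestarts, stripped down to just the attribute-ordering issue:

```python
class SchedulerBase:
    """Stand-in for _LRScheduler: its constructor calls self.step()."""
    def __init__(self, last_epoch=-1):
        self.last_epoch = last_epoch
        self.step()  # mirrors _LRScheduler.__init__ calling self.step()

    def step(self, epoch=None):
        self.last_epoch += 1

class BuggyWarmRestarts(SchedulerBase):
    """Stand-in for CosineAnnealingWarmRestarts with the same bug."""
    def __init__(self, last_epoch=-1):
        super().__init__(last_epoch)  # runs the overridden step() below...
        self.T_cur = last_epoch       # ...before T_cur has been assigned

    def step(self, epoch=None):
        if epoch is None and self.last_epoch >= 0:
            self.T_cur = self.T_cur + 1  # AttributeError when resuming
        self.last_epoch += 1

BuggyWarmRestarts(last_epoch=-1)  # fine: the T_cur branch is skipped
try:
    BuggyWarmRestarts(last_epoch=5)  # resume: step() touches T_cur too early
except AttributeError as e:
    print("reproduced:", e)
```

Per the analysis above, moving the self.T_cur assignment above the super().__init__() call would make the constructor safe, since step() would then find the attribute already set.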

cc @vincentqb

@VitalyFedyunin
Contributor

Triage review: we probably need a new 'scheduler' label.

@VitalyFedyunin
Contributor

The optimizer label is just fine for it.

@mruberry mruberry added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Dec 29, 2020
@mruberry
Collaborator

Thank you for reporting this issue, @heilaw. I think we would accept a PR fixing this.


ziyuang commented Nov 26, 2021

Not merged yet?

@hugefrog

The bug still exists in PyTorch 1.10.
