
WarmRestarts seems not working with create_lr_scheduler_with_warmup function #2910

Closed
developer0hye opened this issue Apr 11, 2023 · 5 comments · Fixed by #2938

Comments

developer0hye commented Apr 11, 2023

🐛 Bug description

[Image attachment: learning-rate plot]

Environment

  • PyTorch Version (e.g., 1.4): 2.0.0+cu118
  • Ignite Version (e.g., 0.3.0): 0.4.11
  • OS (e.g., Linux): Linux 8082c16db087 5.10.147+#1 SMP Sat Dec 10 16:00:40 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
  • How you installed Ignite (conda, pip, source): pip
  • Python version: 3.9.16
  • Any other relevant information:
developer0hye (Author):

import torch
import torchvision
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
import matplotlib.pyplot as plt
import numpy as np

model = torchvision.models.resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Wrap CosineAnnealingWarmRestarts (restart period T_0=10 epochs) with a
# 10-epoch linear warmup from 1e-10 up to 1e-3.
scheduler = create_lr_scheduler_with_warmup(
    torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, verbose=False
    ),
    warmup_start_value=1e-10,
    warmup_end_value=1e-3,
    warmup_duration=10,
)

# Record the LR over 100 epochs. The handler is normally attached to an
# ignite Engine; calling it with engine=None advances it manually.
lr = []
for epoch in range(100):
    scheduler(None)
    lr.append(optimizer.param_groups[0]["lr"])

plt.plot(np.arange(len(lr)), lr, label="lr")
plt.show()
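[Editorial note] For comparison, here is a minimal sketch (an assumption for illustration: it uses a single dummy parameter rather than resnet18) of driving CosineAnnealingWarmRestarts directly via step(), which restarts as expected every T_0 epochs; this is the behavior the warmup wrapper should reproduce once the warmup phase is over:

```python
# Driving CosineAnnealingWarmRestarts directly via step() restarts the
# LR every T_0 epochs. (Sketch uses a single dummy parameter.)
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=1
)

lrs = []
for epoch in range(30):
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

# The LR returns to its maximum at each restart boundary (epochs 0, 10, 20).
assert abs(lrs[0] - 1e-3) < 1e-12
assert abs(lrs[10] - 1e-3) < 1e-12
assert abs(lrs[20] - 1e-3) < 1e-12
```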

sadra-barikbin (Collaborator) commented Apr 11, 2023

@developer0hye, thank you for reporting this bug. It occurs because we do not call the step method of the torch schedulers; instead we increment their last_epoch attribute to move them forward. CosineAnnealingWarmRestarts happens to be the only torch scheduler that computes its new LRs from attributes other than last_epoch (namely T_cur and T_i), and those attributes are only updated from last_epoch inside the step method itself. @vfdev-5, why don't we call step, so as to avoid meddling in the torch schedulers' move-forward logic?
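[Editorial note] The difference can be seen in a minimal sketch (an assumption for illustration: this mimics how the wrapper advanced the scheduler before the fix, by bumping last_epoch and re-reading the LRs rather than calling step()):

```python
# Bumping last_epoch freezes CosineAnnealingWarmRestarts, because its LR
# formula reads T_cur/T_i, which only step() updates.
import warnings
import torch

def make_sched():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([param], lr=1e-3)
    return opt, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10)

# 1) Advance by incrementing last_epoch only: T_cur never changes, so the
#    computed LR stays frozen at the base value.
opt_a, sched_a = make_sched()
lrs_bump = []
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # get_lr() warns when called outside step()
    for _ in range(5):
        sched_a.last_epoch += 1
        lrs_bump.append(sched_a.get_lr()[0])

# 2) Advance via step(): T_cur is updated, so the LR follows the cosine curve.
opt_b, sched_b = make_sched()
lrs_step = []
for _ in range(5):
    sched_b.step()
    lrs_step.append(opt_b.param_groups[0]["lr"])

assert min(lrs_bump) == max(lrs_bump)  # frozen at 1e-3
assert lrs_step[0] > lrs_step[-1]      # actually annealing
```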

vfdev-5 (Collaborator) commented Apr 11, 2023

This is related to #813 (comment)

vfdev-5 (Collaborator) commented May 23, 2023

Hey @developer0hye, we just fixed this issue; after midnight the fix will be available in the pytorch-ignite nightly build:

pip install --upgrade --pre pytorch-ignite

in case you would like to give it a try. It works as reported in the PR: #2938 (comment)

developer0hye (Author):

@vfdev-5 Thanks!
