Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed parameter scheduler bug with CosineAnnealingWarmRestarts #2938

Merged
merged 29 commits into from
May 23, 2023
Merged
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bb5e244
remove codecov
AlexanderChaptykov Apr 13, 2023
f56b362
RankProcessFirst
AlexanderChaptykov Apr 18, 2023
5801dd5
annotations
AlexanderChaptykov Apr 18, 2023
34e77e7
Merge remote-tracking branch 'upstream/master'
AlexanderChaptykov Apr 18, 2023
86e564f
from class to contextlib
AlexanderChaptykov Apr 19, 2023
2f75b92
from class to contextlib and test
AlexanderChaptykov Apr 24, 2023
fcb555c
del test file
AlexanderChaptykov Apr 25, 2023
ccef9c2
uniq folder for test
AlexanderChaptykov Apr 25, 2023
37a9102
refactor tests + new assert_test
AlexanderChaptykov Apr 25, 2023
178c420
add to __all__, remove idist import
AlexanderChaptykov Apr 26, 2023
1cef268
Apply suggestions from code review
vfdev-5 Apr 26, 2023
b2897a8
Merge branch 'master' into master
vfdev-5 Apr 26, 2023
3f7dd99
Apply suggestions from code review
vfdev-5 Apr 26, 2023
aea674d
Update tests/ignite/distributed/utils/test_native.py
vfdev-5 Apr 26, 2023
8c7cebc
Added local arg and renamed function
vfdev-5 Apr 26, 2023
01636e2
Merge remote-tracking branch 'upstream/master'
AlexanderChaptykov May 8, 2023
bdcbad4
add proxy class
AlexanderChaptykov May 8, 2023
8ca28fd
annotation
AlexanderChaptykov May 8, 2023
92af29b
test, proxy class
AlexanderChaptykov May 22, 2023
45da45b
add optim
AlexanderChaptykov May 22, 2023
96cb1dc
name change
AlexanderChaptykov May 22, 2023
4a72e07
Merge branch 'pytorch:master' into bug_cosine_sched
AlexanderChaptykov May 22, 2023
ea8b803
test upd/ setter
AlexanderChaptykov May 23, 2023
83d22d9
Merge remote-tracking branch 'origin/bug_cosine_sched' into bug_cosin…
AlexanderChaptykov May 23, 2023
6e02a0f
class fix
AlexanderChaptykov May 23, 2023
165ed36
Fixed mypy issues
vfdev-5 May 23, 2023
5c9b99d
test upd
AlexanderChaptykov May 23, 2023
4445b02
Fixed failing test_lr_scheduler
vfdev-5 May 23, 2023
971fc64
Merge remote-tracking branch 'origin/bug_cosine_sched' into bug_cosin…
AlexanderChaptykov May 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 39 additions & 1 deletion ignite/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

# https://github.com/pytorch/ignite/issues/2773
Expand Down Expand Up @@ -852,6 +852,44 @@
f"torch.optim.lr_scheduler.{PyTorchLRScheduler.__name__}, "
f"but given {type(lr_scheduler)}"
)
if isinstance(lr_scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):

class _CosineAnnealingWarmRestarts(CosineAnnealingWarmRestarts):
def get_lr(self, epoch: Union[int, None] = None) -> Union[List[float], float]:
if epoch is None and self.last_epoch < 0:
epoch = 0

Check warning on line 860 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L857-L860

Added lines #L857 - L860 were not covered by tests

if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.T_cur = self.T_cur - self.T_i
self.T_i = self.T_i * self.T_mult

Check warning on line 867 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L862-L867

Added lines #L862 - L867 were not covered by tests
else:
if epoch < 0:
raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0

Check warning on line 873 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L869-L873

Added lines #L869 - L873 were not covered by tests
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)

Check warning on line 877 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L875-L877

Added lines #L875 - L877 were not covered by tests
else:
self.T_i = self.T_0
self.T_cur = epoch
self.last_epoch = math.floor(epoch)

Check warning on line 881 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L879-L881

Added lines #L879 - L881 were not covered by tests

return super(_CosineAnnealingWarmRestarts, self).get_lr()

Check warning on line 883 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L883

Added line #L883 was not covered by tests
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

lr_scheduler = _CosineAnnealingWarmRestarts(

Check warning on line 885 in ignite/handlers/param_scheduler.py

View check run for this annotation

Codecov / codecov/patch

ignite/handlers/param_scheduler.py#L885

Added line #L885 was not covered by tests
lr_scheduler.optimizer,
lr_scheduler.T_0,
lr_scheduler.T_mult,
lr_scheduler.eta_min,
lr_scheduler.last_epoch,
lr_scheduler.verbose,
)

self.lr_scheduler = lr_scheduler
super(LRScheduler, self).__init__(
Expand Down