
Fixed parameter scheduler bug with CosineAnnealingWarmRestarts #2938

Status: Merged (29 commits, May 23, 2023)

Commits (changes from all 29 commits):
bb5e244
remove codecov
AlexanderChaptykov Apr 13, 2023
f56b362
RankProcessFirst
AlexanderChaptykov Apr 18, 2023
5801dd5
annotations
AlexanderChaptykov Apr 18, 2023
34e77e7
Merge remote-tracking branch 'upstream/master'
AlexanderChaptykov Apr 18, 2023
86e564f
from class to contextlib
AlexanderChaptykov Apr 19, 2023
2f75b92
from class to contextlib and test
AlexanderChaptykov Apr 24, 2023
fcb555c
del test file
AlexanderChaptykov Apr 25, 2023
ccef9c2
uniq folder for test
AlexanderChaptykov Apr 25, 2023
37a9102
refactor tests + new assert_test
AlexanderChaptykov Apr 25, 2023
178c420
add to __all__, remove idist import
AlexanderChaptykov Apr 26, 2023
1cef268
Apply suggestions from code review
vfdev-5 Apr 26, 2023
b2897a8
Merge branch 'master' into master
vfdev-5 Apr 26, 2023
3f7dd99
Apply suggestions from code review
vfdev-5 Apr 26, 2023
aea674d
Update tests/ignite/distributed/utils/test_native.py
vfdev-5 Apr 26, 2023
8c7cebc
Added local arg and renamed function
vfdev-5 Apr 26, 2023
01636e2
Merge remote-tracking branch 'upstream/master'
AlexanderChaptykov May 8, 2023
bdcbad4
add proxy class
AlexanderChaptykov May 8, 2023
8ca28fd
annotation
AlexanderChaptykov May 8, 2023
92af29b
test, proxy class
AlexanderChaptykov May 22, 2023
45da45b
add optim
AlexanderChaptykov May 22, 2023
96cb1dc
name change
AlexanderChaptykov May 22, 2023
4a72e07
Merge branch 'pytorch:master' into bug_cosine_sched
AlexanderChaptykov May 22, 2023
ea8b803
test upd/ setter
AlexanderChaptykov May 23, 2023
83d22d9
Merge remote-tracking branch 'origin/bug_cosine_sched' into bug_cosin…
AlexanderChaptykov May 23, 2023
6e02a0f
class fix
AlexanderChaptykov May 23, 2023
165ed36
Fixed mypy issues
vfdev-5 May 23, 2023
5c9b99d
test upd
AlexanderChaptykov May 23, 2023
4445b02
Fixed failing test_lr_scheduler
vfdev-5 May 23, 2023
971fc64
Merge remote-tracking branch 'origin/bug_cosine_sched' into bug_cosin…
AlexanderChaptykov May 23, 2023
ignite/handlers/param_scheduler.py (68 changes: 63 additions & 5 deletions)

@@ -10,7 +10,7 @@
from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

# https://github.com/pytorch/ignite/issues/2773
@@ -792,6 +792,61 @@
        return output


class _CosineAnnealingWarmRestarts:
    def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
        self._lr_scheduler = lr_scheduler

    @property
    def last_epoch(self) -> int:
        return self._lr_scheduler.last_epoch

    @last_epoch.setter
    def last_epoch(self, value: int) -> None:
        self._lr_scheduler.last_epoch = value

    @property
    def optimizer(self) -> torch.optim.Optimizer:
        return self._lr_scheduler.optimizer

    def get_lr(self, epoch: Optional[int] = None) -> List[float]:
        # TODO: Remove this workaround when pytorch has fixed wrong type hints:
        # https://github.com/pytorch/pytorch/pull/102067
        # Replace below T_mult -> self._lr_scheduler.T_mult
        # Replace below eta_min -> self._lr_scheduler.eta_min
        T_mult = cast(int, self._lr_scheduler.T_mult)
        eta_min = cast(float, self._lr_scheduler.eta_min)

        if epoch is None and self.last_epoch < 0:
            epoch = 0
        if epoch is None:
            epoch = self.last_epoch + 1
            self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
            if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
                self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
                self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self._lr_scheduler.T_0:
                if T_mult == 1:
                    self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
                else:
                    n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
                    self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
                    self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
            else:
                self._lr_scheduler.T_i = self._lr_scheduler.T_0
                self._lr_scheduler.T_cur = epoch

        self.last_epoch = math.floor(epoch)

        return [
            eta_min
            + (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
            for base_lr in self._lr_scheduler.base_lrs
        ]
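
For context: torch's `CosineAnnealingWarmRestarts` only advances its cycle bookkeeping (`T_cur`, `T_i`, `last_epoch`) inside `step()`, while ignite's `LRScheduler` drives a wrapped scheduler exclusively through `get_lr()` (via `get_param()`). That mismatch is presumably the bug in the PR title: without this proxy, the learning rate appears to stay frozen at its initial value because `T_cur` is never updated. The proxy folds the `step()` bookkeeping into `get_lr()`. The parity check below is my own sketch, not part of the PR; it imports the private proxy purely for illustration and assumes it remains importable from `ignite.handlers.param_scheduler`.

```python
import math

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Private import, purely for illustration.
from ignite.handlers.param_scheduler import _CosineAnnealingWarmRestarts


def make_optimizer() -> torch.optim.SGD:
    return torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.2)


# Reference schedule: the raw torch scheduler advanced via step().
ref_opt = make_optimizer()
ref_sched = CosineAnnealingWarmRestarts(ref_opt, T_0=12, T_mult=3)
ref_lrs = []
for _ in range(100):
    ref_lrs.append(ref_opt.param_groups[0]["lr"])
    ref_sched.step()

# Proxy schedule: advanced purely through get_lr(), the way LRScheduler.get_param() uses it.
proxy_opt = make_optimizer()
proxy = _CosineAnnealingWarmRestarts(CosineAnnealingWarmRestarts(proxy_opt, T_0=12, T_mult=3))
proxy_lrs = [proxy_opt.param_groups[0]["lr"]]  # epoch-0 lr, set by the torch constructor
proxy_lrs += [proxy.get_lr()[0] for _ in range(99)]  # each call computes the next epoch's lr

assert all(math.isclose(a, b) for a, b in zip(ref_lrs, proxy_lrs))
```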


class LRScheduler(ParamScheduler):
    """A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.

@@ -853,7 +908,10 @@
                f"but given {type(lr_scheduler)}"
            )

-        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler: Union[PyTorchLRScheduler, _CosineAnnealingWarmRestarts] = lr_scheduler
+        if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
+            self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)
+
        super(LRScheduler, self).__init__(
            optimizer=self.lr_scheduler.optimizer,
            param_name="lr",
@@ -863,7 +921,7 @@
            warnings.warn(
                "Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
                "instead of Events.ITERATION_STARTED to make sure to use "
-                "the first lr value from the optimizer, otherwise it is will be skipped"
+                "the first lr value from the optimizer, otherwise it will be skipped"
            )
            self.lr_scheduler.last_epoch += 1

@@ -876,9 +934,9 @@
    def get_param(self) -> Union[float, List[float]]:
        """Method to get current optimizer's parameter value"""
        # Emulate context manager for pytorch>=1.4
-        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[union-attr]
        lr_list = cast(List[float], self.lr_scheduler.get_lr())
-        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[union-attr]
        if len(lr_list) == 1:
            return lr_list[0]
        else:
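With the constructor change above, the wrapping works end to end. A minimal usage sketch (mine, not from the PR): `LRScheduler` silently substitutes the proxy when handed a `CosineAnnealingWarmRestarts`, and here it is driven by calling it directly, the same way the new test calls `scheduler(None)`; the base lr and `T_0`/`T_mult` values are arbitrary.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers.param_scheduler import LRScheduler

tensor = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.2)

# LRScheduler detects the torch scheduler type and wraps it in the proxy internally.
scheduler = LRScheduler(CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2))

lrs = []
for _ in range(30):
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler(None)  # ParamScheduler.__call__ updates the optimizer's lr; no Engine needed here

assert lrs[0] == 0.2   # the first lr comes straight from the optimizer
assert lrs[10] == 0.2  # warm restart after T_0=10 calls brings lr back to the base value
```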
tests/ignite/handlers/test_param_scheduler.py (46 changes: 44 additions & 2 deletions)

@@ -3,7 +3,7 @@
import numpy as np
import pytest
import torch
-from torch.optim.lr_scheduler import ExponentialLR, StepLR
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import (
@@ -650,7 +650,7 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
    state_dict1 = scheduler1.state_dict()

    torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
-    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it is will be skipped"):
+    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
        scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
    state_dict2 = scheduler2.state_dict()

@@ -1362,3 +1362,45 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
    with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
        metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
        ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)


@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
lr = 0.2
steps = 200
warm_steps = 50
warm_start = 0.023

def get_optim():
t1 = torch.zeros([1], requires_grad=True)
return torch.optim.SGD([t1], lr=lr)

def get_cos_shed():
return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)

optimizer = get_optim()
scheduler = get_cos_shed()
cosine_lrs = []
for i in range(steps):
cosine_lrs.append(optimizer.param_groups[0]["lr"])
scheduler.step()

optimizer = get_optim()
scheduler = create_lr_scheduler_with_warmup(
get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
)

warm_lrs = []
real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
for epoch in range(real_warm_steps + steps):
scheduler(None)
warm_lrs.append(optimizer.param_groups[0]["lr"])

if warmup_end_value is not None:
np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
assert warm_lrs[real_warm_steps:] == cosine_lrs
else:
np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
assert warm_lrs[real_warm_steps:] == cosine_lrs
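
The new test drives the combined schedule by calling `scheduler(None)` in a loop; in an actual training script the same object would normally be attached to an engine event. A rough sketch of that usage (my assumption: attaching to `Events.ITERATION_STARTED`, following the `create_lr_scheduler_with_warmup` documentation example; the warmup constants mirror the test above):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup

weight = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([weight], lr=0.2)

scheduler = create_lr_scheduler_with_warmup(
    CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3),
    warmup_start_value=0.023,
    warmup_end_value=0.23,
    warmup_duration=50,
)

trainer = Engine(lambda engine, batch: None)  # no-op update step, only here to drive events
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

lrs = []


@trainer.on(Events.ITERATION_COMPLETED)
def record_lr(engine: Engine) -> None:
    lrs.append(optimizer.param_groups[0]["lr"])


trainer.run(range(10), max_epochs=25)  # 250 iterations: 50 warmup steps, then cosine annealing with restarts
```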