Feature/#2465 lr scheduler attach events (#2496)
* adjust to fit the changes such as events

* #2465 make LRScheduler attachable to Events.ITERATION_STARTED

* #2465 add use_legacy test and adjust to fit the changes such as events

* bug fix

* #2465 modify docstring and warning message

* #2465 bug fix

* add keep_first_lr

* bug fix

* #2465 bug fix about keep_first_lr

* #2465 modify test to accommodate this change

* #2465 bug fix by adding type: ignore[attr-defined]

* remove keep_first_lr and adjust last_epoch in create_lr_scheduler_with_warmup

* #2465 add keep_first_lr

* WIP

* Added versionchanged

Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
3 people committed Apr 19, 2022
1 parent 727150e commit 545d125
Showing 4 changed files with 79 additions and 59 deletions.
3 changes: 0 additions & 3 deletions ignite/contrib/engines/common.py
@@ -13,7 +13,6 @@
from ignite.contrib.handlers import (
ClearMLLogger,
global_step_from_engine,
LRScheduler,
MLflowLogger,
NeptuneLogger,
PolyaxonLogger,
@@ -165,8 +164,6 @@ def _setup_common_training_handlers(
trainer.add_event_handler(
Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
)
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

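Note on the common.py change: with the special case removed, `_setup_common_training_handlers` attaches a wrapped `LRScheduler` to `Events.ITERATION_STARTED` like any other ignite parameter scheduler; only raw `torch.optim.lr_scheduler` objects keep the `Events.ITERATION_COMPLETED` path. A minimal sketch of the resulting usage through the public helper, assuming a recent ignite where `LRScheduler` lives under `ignite.handlers` (the model, optimizer and update function below are placeholders, not part of this diff):

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.contrib.engines.common import setup_common_training_handlers
from ignite.engine import Engine
from ignite.handlers import LRScheduler

# Placeholder model, optimizer and update function for illustration only.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def update_fn(engine, batch):
    optimizer.zero_grad()
    loss = model(torch.rand(4, 2)).sum()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_fn)

# A wrapped torch scheduler now goes through the same branch as ignite's own
# parameter schedulers and is attached internally to Events.ITERATION_STARTED.
lr_scheduler = LRScheduler(StepLR(optimizer, step_size=3, gamma=0.1))
setup_common_training_handlers(trainer, lr_scheduler=lr_scheduler, with_pbars=False)

trainer.run([0] * 8, max_epochs=2)
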
43 changes: 29 additions & 14 deletions ignite/handlers/param_scheduler.py
@@ -795,6 +795,7 @@ class LRScheduler(ParamScheduler):
lr_scheduler: lr_scheduler object to wrap.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
use_legacy: if True, scheduler should be attached to ``Events.ITERATION_COMPLETED``, (default=False).
Examples:
@@ -808,20 +809,14 @@ class LRScheduler(ParamScheduler):
from torch.optim.lr_scheduler import StepLR
torch_lr_scheduler = StepLR(default_optimizer, step_size=3, gamma=0.1)
scheduler = LRScheduler(torch_lr_scheduler)
default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
print(default_optimizer.param_groups[0]["lr"])
# In this example, we assume to have installed PyTorch>=1.1.0
# (with new `torch.optim.lr_scheduler` behaviour) and
# we attach scheduler to Events.ITERATION_COMPLETED
# instead of Events.ITERATION_STARTED to make sure to use
# the first lr value from the optimizer, otherwise it will be skipped:
default_trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
default_trainer.run([0] * 8, max_epochs=1)
.. testoutput::
@@ -836,9 +831,17 @@ def print_lr():
0.001...
.. versionadded:: 0.4.5
.. versionchanged:: 0.5.0
added ``use_legacy`` argument
"""

def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
def __init__(
self,
lr_scheduler: _LRScheduler,
save_history: bool = False,
use_legacy: bool = False,
):

if not isinstance(lr_scheduler, _LRScheduler):
raise TypeError(
@@ -852,11 +855,19 @@ def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
param_name="lr",
save_history=save_history,
)
if use_legacy:
warnings.warn(
"Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
"instead of Events.ITERATION_STARTED to make sure to use "
"the first lr value from the optimizer, otherwise it is will be skipped"
)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]

self._state_attrs += ["lr_scheduler"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]
super(LRScheduler, self).__call__(engine, name)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]

def get_param(self) -> Union[float, List[float]]:
"""Method to get current optimizer's parameter value"""
@@ -904,9 +915,9 @@ def simulate_values( # type: ignore[override]
values = []
scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
for i in range(num_events):
scheduler(engine=None)
params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
values.append([i] + params)
scheduler(engine=None)

obj = torch.load(cache_filepath.as_posix())
lr_scheduler.load_state_dict(obj["lr_scheduler"])
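With this reordering, `simulate_values` records the parameter value before stepping the wrapped scheduler, matching what a handler attached to `Events.ITERATION_STARTED` would apply, and the cached scheduler state is still restored afterwards. A small usage sketch (the optimizer and scheduler are throwaway objects for illustration):

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.handlers import LRScheduler

# Hypothetical optimizer/scheduler just to drive the simulation.
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# simulate_values returns [[event_index, lr], ...] without disturbing the
# scheduler used for training (its state is cached and restored internally).
values = LRScheduler.simulate_values(num_events=9, lr_scheduler=torch_lr_scheduler)
print(values)  # lr is recorded before the internal step, so 0.01 appears first
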
@@ -927,8 +938,7 @@ def create_lr_scheduler_with_warmup(
Helper method to create a learning rate scheduler with a linear warm-up.
Args:
lr_scheduler: learning rate scheduler
after the warm-up.
lr_scheduler: learning rate scheduler after the warm-up.
warmup_start_value: learning rate start value of the warm-up phase.
warmup_duration: warm-up phase duration, number of events.
warmup_end_value: learning rate end value of the warm-up phase, (default=None). If None,
@@ -1011,10 +1021,15 @@ def print_lr():

if isinstance(lr_scheduler, _LRScheduler):
init_lr = param_group["lr"]

if init_lr != param_group_warmup_end_value:
milestones_values.append((warmup_duration, init_lr))

# We need to advance torch lr_scheduler to avoid duplicated lr value
# given by PiecewiseLinear and LRScheduler.
# We suggest attaching the output scheduler to ITERATION_STARTED, but
# torch lr_scheduler works with ITERATION_COMPLETED
# See also https://github.com/pytorch/ignite/pull/2496#issuecomment-1065984440
lr_scheduler.last_epoch += 1
lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
else:
init_lr = lr_scheduler.get_param()
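The extra `lr_scheduler.last_epoch += 1` compensates for the warm-up `PiecewiseLinear` and the wrapped torch scheduler otherwise emitting the same lr at the hand-off point, now that the combined scheduler is meant to be attached to `Events.ITERATION_STARTED`. A hedged usage sketch of the helper after this change (ExponentialLR and the warm-up numbers are arbitrary choices, not taken from the diff):

import torch
from torch.optim.lr_scheduler import ExponentialLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = ExponentialLR(optimizer, gamma=0.98)

# Linear warm-up from 0.0 to the optimizer lr over 5 events, then ExponentialLR.
scheduler = create_lr_scheduler_with_warmup(
    torch_lr_scheduler,
    warmup_start_value=0.0,
    warmup_duration=5,
)

trainer = Engine(lambda engine, batch: None)
# The combined scheduler is attached once, to ITERATION_STARTED.
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run([0] * 10, max_epochs=2)
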
4 changes: 2 additions & 2 deletions tests/ignite/contrib/engines/test_common.py
@@ -128,8 +128,8 @@ def update_fn(engine, batch):

# Check LR scheduling
assert optimizer.param_groups[0]["lr"] <= lr * gamma ** (
num_iters * num_epochs / step_size
), f"{optimizer.param_groups[0]['lr']} vs {lr * gamma ** (num_iters * num_epochs / step_size)}"
(num_iters * num_epochs - 1) // step_size
), f"{optimizer.param_groups[0]['lr']} vs {lr * gamma ** ((num_iters * num_epochs - 1) // step_size)}"


def test_asserts_setup_common_training_handlers():
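The tightened bound reflects that, with the scheduler attached at ITERATION_STARTED and the initial lr no longer skipped, StepLR has applied `(num_iters * num_epochs - 1) // step_size` decays to the lr in effect at the last iteration, an integer rather than the old fractional exponent. Illustrative arithmetic only (the concrete numbers are not taken from the test):

# Hypothetical values; the real ones come from the test fixtures.
num_iters, num_epochs, step_size, gamma, lr = 10, 2, 5, 0.5, 0.1

total_iters = num_iters * num_epochs         # 20 scheduler events overall
num_decays = (total_iters - 1) // step_size  # 3 decays reflected in the bound
upper_bound = lr * gamma ** num_decays       # 0.1 * 0.5**3 == 0.0125
print(num_decays, upper_bound)
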
88 changes: 48 additions & 40 deletions tests/ignite/handlers/test_param_scheduler.py
@@ -648,53 +648,71 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
tensor = torch.zeros([1], requires_grad=True)
optimizer1 = torch.optim.SGD([tensor], lr=0.01)
optimizer2 = torch.optim.SGD([tensor], lr=0.01)
optimizer3 = torch.optim.SGD([tensor], lr=0.01)
opt_state_dict1 = optimizer1.state_dict()
opt_state_dict2 = optimizer2.state_dict()
opt_state_dict3 = optimizer3.state_dict()

torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1, **kwargs)
scheduler = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler.state_dict()
scheduler1 = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler1.state_dict()

torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
state_dict2 = torch_lr_scheduler2.state_dict()
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
state_dict2 = scheduler2.state_dict()

torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
state_dict3 = torch_lr_scheduler3.state_dict()

def dummy_update(engine, batch):
optimizer1.step()
optimizer2.step()
optimizer3.step()

trainer = Engine(dummy_update)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)

@trainer.on(Events.ITERATION_STARTED)
def save_lr(engine):
lrs.append(optimizer1.param_groups[0]["lr"])
def save_lr1(engine):
lrs1.append(optimizer1.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_lr2(engine):
lrs2.append(optimizer2.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_true_lr(engine):
lrs_true.append(optimizer2.param_groups[0]["lr"])
lrs_true.append(optimizer3.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_COMPLETED)
def torch_lr_scheduler_step(engine):
torch_lr_scheduler2.step()
torch_lr_scheduler3.step()

trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler2)

for _ in range(2):
lrs = []
lrs1 = []
lrs2 = []
lrs_true = []
data = [0] * 10
max_epochs = 2
trainer.run(data, max_epochs=max_epochs)
assert lrs_true == pytest.approx(lrs), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs} ({len(lrs)})"
assert lrs_true == pytest.approx(lrs1), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs1} ({len(lrs1)})"
assert lrs_true == pytest.approx(lrs2), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs2} ({len(lrs2)})"
optimizer1.load_state_dict(opt_state_dict1)
scheduler.load_state_dict(state_dict1)
scheduler1.load_state_dict(state_dict1)
optimizer2.load_state_dict(opt_state_dict2)
torch_lr_scheduler2.load_state_dict(state_dict2)
scheduler2.load_state_dict(state_dict2)
optimizer3.load_state_dict(opt_state_dict3)
torch_lr_scheduler3.load_state_dict(state_dict3)

optimizer3 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
optimizer4 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler4 = torch_lr_scheduler_cls(optimizer=optimizer4, **kwargs)

simulated_values = LRScheduler.simulate_values(num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler3)
assert lrs == pytest.approx([v for i, v in simulated_values])
simulated_values = LRScheduler.simulate_values(num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler4)
assert lrs1 == pytest.approx([v for i, v in simulated_values])
assert lrs2 == pytest.approx([v for i, v in simulated_values])


def test_piecewiselinear_asserts():
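The reworked test drives both attachment styles through one trainer and checks that they track the reference torch scheduler exactly. A condensed sketch of the pattern it exercises (the step size and lr are arbitrary; the warning text is matched only loosely):

import pytest
import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers import LRScheduler

tensor = torch.zeros([1], requires_grad=True)
opt1 = torch.optim.SGD([tensor], lr=0.01)
opt2 = torch.optim.SGD([tensor], lr=0.01)

scheduler1 = LRScheduler(StepLR(opt1, step_size=3, gamma=0.1))
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer"):
    scheduler2 = LRScheduler(StepLR(opt2, step_size=3, gamma=0.1), use_legacy=True)

trainer = Engine(lambda e, b: (opt1.step(), opt2.step()))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler2)

lrs1, lrs2 = [], []
trainer.add_event_handler(Events.ITERATION_STARTED, lambda e: lrs1.append(opt1.param_groups[0]["lr"]))
trainer.add_event_handler(Events.ITERATION_STARTED, lambda e: lrs2.append(opt2.param_groups[0]["lr"]))

trainer.run([0] * 10, max_epochs=2)
assert lrs1 == pytest.approx(lrs2)  # both styles produce the same lr sequence
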
@@ -813,11 +831,8 @@ def test_simulate_and_plot_values():

def _test(scheduler_cls, **scheduler_kwargs):

optimizer = None
event = Events.ITERATION_STARTED
if scheduler_cls == LRScheduler:
optimizer = scheduler_kwargs["lr_scheduler"].optimizer
event = Events.ITERATION_COMPLETED
elif scheduler_cls == ConcatScheduler:
optimizer = scheduler_kwargs["optimizer"]
del scheduler_kwargs["optimizer"]
@@ -828,7 +843,7 @@ def _test(scheduler_cls, **scheduler_kwargs):

max_epochs = 2
data = [0] * 10
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)
simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)

scheduler = scheduler_cls(**scheduler_kwargs)

@@ -838,15 +853,11 @@ def save_lr(engine):
lrs.append(optimizer.param_groups[0]["lr"])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(event, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, save_lr)
trainer.run(data, max_epochs=max_epochs)

# assert lrs == pytest.approx([v for i, v in simulated_values])

if scheduler_cls == LRScheduler or scheduler_cls == ConcatScheduler:
# As internal state of torch lr scheduler has been changed the following checks will fail
return
assert lrs == pytest.approx([v for i, v in simulated_values])

# reexecute to check if no internal changes
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs,
@@ -937,7 +948,7 @@ def test_create_lr_scheduler_with_warmup_asserts():


@pytest.mark.parametrize(
"lrsched_warmup_config",
"lr_scheduler_name, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value",
[
# A) opt lr != warmup_end_value
("ExponentialLR", 0.01, 0.05, 10, 0.2),
@@ -955,15 +966,9 @@ def test_create_lr_scheduler_with_warmup_asserts():
("ExponentialLR", 0.01, None, 10, 0.2 * 0.98),
],
)
def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):

(
lr_scheduler_name,
warmup_start_value,
warmup_end_value,
warmup_duration,
warmup_end_next_value,
) = lrsched_warmup_config
def test_create_lr_scheduler_with_warmup(
lr_scheduler_name, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value
):

t1 = torch.zeros([1], requires_grad=True)

@@ -981,6 +986,11 @@ def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):
num_iterations = 10
max_epochs = 20

if warmup_end_value is None:
expected_warmup_end_value = optimizer.param_groups[0]["lr"]
else:
expected_warmup_end_value = warmup_end_value

simulated_values = [None] * (num_iterations * max_epochs)
scheduler = create_lr_scheduler_with_warmup(
lr_scheduler,
@@ -989,8 +999,6 @@ def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):
warmup_duration=warmup_duration,
output_simulated_values=simulated_values,
)
if warmup_end_value is None:
warmup_end_value = optimizer.param_groups[0]["lr"]

state_dict = scheduler.state_dict()
trainer = Engine(lambda engine, batch: None)
@@ -1007,11 +1015,11 @@ def save_lr(engine):
lrs = []
trainer.run(data, max_epochs=max_epochs)

assert lrs == pytest.approx([v for i, v in simulated_values])
assert lrs == pytest.approx([v for _, v in simulated_values])

assert lrs[0] == pytest.approx(warmup_start_value), f"lrs={lrs[: warmup_duration + num_iterations]}"
assert lrs[warmup_duration - 1] == pytest.approx(
warmup_end_value
expected_warmup_end_value
), f"lrs={lrs[: warmup_duration + num_iterations]}"
assert lrs[warmup_duration] == pytest.approx(
warmup_end_next_value
