[BC-breaking] Feature/#2465 lr scheduler attach events #2496

Merged
3 changes: 0 additions & 3 deletions ignite/contrib/engines/common.py
@@ -13,7 +13,6 @@
from ignite.contrib.handlers import (
ClearMLLogger,
global_step_from_engine,
LRScheduler,
MLflowLogger,
NeptuneLogger,
PolyaxonLogger,
@@ -165,8 +164,6 @@ def _setup_common_training_handlers(
trainer.add_event_handler(
Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
)
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

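For context, a minimal runnable sketch (not part of the diff) of the dispatch that remains in `_setup_common_training_handlers` after this change: a raw `torch.optim.lr_scheduler._LRScheduler` is still stepped on `Events.ITERATION_COMPLETED`, while any ignite scheduler, now including the `LRScheduler` wrapper, takes the generic branch and is attached on `Events.ITERATION_STARTED`. The helper name and the toy trainer/optimizer below are illustrative only.

from typing import cast

import torch
from torch.optim.lr_scheduler import StepLR, _LRScheduler

from ignite.engine import Engine, Events
from ignite.handlers import LRScheduler


def attach_lr_scheduler(trainer: Engine, lr_scheduler) -> None:
    # Simplified mirror of the dispatch kept in common.py after this PR.
    if isinstance(lr_scheduler, _LRScheduler):
        # Plain torch scheduler: step it once the iteration is done.
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
        )
    else:
        # Ignite schedulers, now including the LRScheduler wrapper, are applied
        # before the iteration so the first lr value is not skipped.
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)


trainer = Engine(lambda engine, batch: None)
optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
attach_lr_scheduler(trainer, LRScheduler(StepLR(optimizer, step_size=3, gamma=0.1)))
trainer.run([0] * 8, max_epochs=1)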
43 changes: 29 additions & 14 deletions ignite/handlers/param_scheduler.py
@@ -795,6 +795,7 @@ class LRScheduler(ParamScheduler):
lr_scheduler: lr_scheduler object to wrap.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
use_legacy: if True, scheduler should be attached to ``Events.ITERATION_COMPLETED``, (default=False).

Examples:

@@ -808,20 +809,14 @@ class LRScheduler(ParamScheduler):
from torch.optim.lr_scheduler import StepLR

torch_lr_scheduler = StepLR(default_optimizer, step_size=3, gamma=0.1)

scheduler = LRScheduler(torch_lr_scheduler)

default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
print(default_optimizer.param_groups[0]["lr"])

# In this example, we assume to have installed PyTorch>=1.1.0
# (with new `torch.optim.lr_scheduler` behaviour) and
# we attach scheduler to Events.ITERATION_COMPLETED
# instead of Events.ITERATION_STARTED to make sure to use
# the first lr value from the optimizer, otherwise it is will be skipped:
default_trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

default_trainer.run([0] * 8, max_epochs=1)

.. testoutput::
@@ -836,9 +831,17 @@ def print_lr():
0.001...

.. versionadded:: 0.4.5

.. versionchanged:: 0.5.0
added `use_legacy` argument
"""

def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
def __init__(
self,
lr_scheduler: _LRScheduler,
save_history: bool = False,
use_legacy: bool = False,
):

if not isinstance(lr_scheduler, _LRScheduler):
raise TypeError(
@@ -852,11 +855,19 @@ def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
param_name="lr",
save_history=save_history,
)
if use_legacy:
warnings.warn(
"Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
"instead of Events.ITERATION_STARTED to make sure to use "
"the first lr value from the optimizer, otherwise it is will be skipped"
)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]

self._state_attrs += ["lr_scheduler"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]
super(LRScheduler, self).__call__(engine, name)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]

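A usage sketch of the two modes (illustrative only; the toy trainer and optimizer below are not from the PR): by default `__call__` applies the wrapped scheduler's current lr and only then advances it, so the wrapper should be attached on `Events.ITERATION_STARTED`; passing `use_legacy=True` advances `last_epoch` once at construction and keeps the pre-0.5.0 attachment on `Events.ITERATION_COMPLETED`.

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers import LRScheduler

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
trainer = Engine(lambda engine, batch: None)

# New default: attach on ITERATION_STARTED; the initial lr (0.01) is applied
# on the first iteration and the wrapped scheduler is stepped afterwards.
scheduler = LRScheduler(StepLR(optimizer, step_size=3, gamma=0.1))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# Legacy mode (emits the warning above): keep the old attachment point instead.
# scheduler = LRScheduler(StepLR(optimizer, step_size=3, gamma=0.1), use_legacy=True)
# trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

trainer.run([0] * 8, max_epochs=1)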
def get_param(self) -> Union[float, List[float]]:
"""Method to get current optimizer's parameter value"""
@@ -904,9 +915,9 @@ def simulate_values( # type: ignore[override]
values = []
scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
for i in range(num_events):
scheduler(engine=None)
params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
values.append([i] + params)
scheduler(engine=None)

obj = torch.load(cache_filepath.as_posix())
lr_scheduler.load_state_dict(obj["lr_scheduler"])
@@ -927,8 +938,7 @@ def create_lr_scheduler_with_warmup(
Helper method to create a learning rate scheduler with a linear warm-up.

Args:
lr_scheduler: learning rate scheduler
after the warm-up.
lr_scheduler: learning rate scheduler after the warm-up.
warmup_start_value: learning rate start value of the warm-up phase.
warmup_duration: warm-up phase duration, number of events.
warmup_end_value: learning rate end value of the warm-up phase, (default=None). If None,
@@ -1011,10 +1021,15 @@ def print_lr():

if isinstance(lr_scheduler, _LRScheduler):
init_lr = param_group["lr"]

if init_lr != param_group_warmup_end_value:
milestones_values.append((warmup_duration, init_lr))

# We need to advance the torch lr_scheduler to avoid a duplicated lr value
# given by PiecewiseLinear and LRScheduler.
# We suggest attaching the output scheduler on ITERATION_STARTED, but the
# torch lr_scheduler works on ITERATION_COMPLETED.
# See also https://github.com/pytorch/ignite/pull/2496#issuecomment-1065984440
lr_scheduler.last_epoch += 1
lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
else:
init_lr = lr_scheduler.get_param()
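A hedged usage sketch of `create_lr_scheduler_with_warmup` after this change (parameter values are arbitrary, and it is assumed the helper is importable from `ignite.handlers` as in current releases): the combined scheduler is attached on `Events.ITERATION_STARTED`, and the wrapped torch scheduler is advanced by one step internally, as the comment above explains, so the warm-up end value is not repeated.

import torch
from torch.optim.lr_scheduler import ExponentialLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
torch_lr_scheduler = ExponentialLR(optimizer, gamma=0.98)

# Linear warm-up from 0.0 up to the optimizer lr over 3 events, then ExponentialLR.
scheduler = create_lr_scheduler_with_warmup(
    torch_lr_scheduler, warmup_start_value=0.0, warmup_duration=3
)

trainer = Engine(lambda engine, batch: None)
# The combined scheduler is attached on ITERATION_STARTED, not ITERATION_COMPLETED.
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run([0] * 10, max_epochs=1)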
4 changes: 2 additions & 2 deletions tests/ignite/contrib/engines/test_common.py
@@ -128,8 +128,8 @@ def update_fn(engine, batch):

# Check LR scheduling
assert optimizer.param_groups[0]["lr"] <= lr * gamma ** (
num_iters * num_epochs / step_size
), f"{optimizer.param_groups[0]['lr']} vs {lr * gamma ** (num_iters * num_epochs / step_size)}"
(num_iters * num_epochs - 1) // step_size
), f"{optimizer.param_groups[0]['lr']} vs {lr * gamma ** ((num_iters * num_epochs - 1) // step_size)}"


def test_asserts_setup_common_training_handlers():
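One way to read the new bound: with the scheduler applied on ITERATION_STARTED, the first iteration runs with the initial lr, so after N = num_iters * num_epochs iterations the last applied value is lr * gamma ** ((N - 1) // step_size); schedulers stepped on ITERATION_COMPLETED end at or below it. A tiny sanity check with made-up values (the real test fixtures may differ):

# Hypothetical numbers, only to illustrate the bound asserted above.
lr, gamma, step_size = 0.1, 0.5, 3
num_iters, num_epochs = 10, 2

n = num_iters * num_epochs
bound = lr * gamma ** ((n - 1) // step_size)  # lr applied on the last iteration
print(bound)  # 0.1 * 0.5 ** ((20 - 1) // 3) = 0.1 * 0.5 ** 6 = 0.0015625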
88 changes: 48 additions & 40 deletions tests/ignite/handlers/test_param_scheduler.py
@@ -648,53 +648,71 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
tensor = torch.zeros([1], requires_grad=True)
optimizer1 = torch.optim.SGD([tensor], lr=0.01)
optimizer2 = torch.optim.SGD([tensor], lr=0.01)
optimizer3 = torch.optim.SGD([tensor], lr=0.01)
opt_state_dict1 = optimizer1.state_dict()
opt_state_dict2 = optimizer2.state_dict()
opt_state_dict3 = optimizer3.state_dict()

torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1, **kwargs)
scheduler = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler.state_dict()
scheduler1 = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler1.state_dict()

torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
state_dict2 = torch_lr_scheduler2.state_dict()
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
state_dict2 = scheduler2.state_dict()

torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
state_dict3 = torch_lr_scheduler3.state_dict()

def dummy_update(engine, batch):
optimizer1.step()
optimizer2.step()
optimizer3.step()

trainer = Engine(dummy_update)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)

@trainer.on(Events.ITERATION_STARTED)
def save_lr(engine):
lrs.append(optimizer1.param_groups[0]["lr"])
def save_lr1(engine):
lrs1.append(optimizer1.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_lr2(engine):
lrs2.append(optimizer2.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_true_lr(engine):
lrs_true.append(optimizer2.param_groups[0]["lr"])
lrs_true.append(optimizer3.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_COMPLETED)
def torch_lr_scheduler_step(engine):
torch_lr_scheduler2.step()
torch_lr_scheduler3.step()

trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler2)

for _ in range(2):
lrs = []
lrs1 = []
lrs2 = []
lrs_true = []
data = [0] * 10
max_epochs = 2
trainer.run(data, max_epochs=max_epochs)
assert lrs_true == pytest.approx(lrs), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs} ({len(lrs)})"
assert lrs_true == pytest.approx(lrs1), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs1} ({len(lrs1)})"
assert lrs_true == pytest.approx(lrs2), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs2} ({len(lrs2)})"
optimizer1.load_state_dict(opt_state_dict1)
scheduler.load_state_dict(state_dict1)
scheduler1.load_state_dict(state_dict1)
optimizer2.load_state_dict(opt_state_dict2)
torch_lr_scheduler2.load_state_dict(state_dict2)
scheduler2.load_state_dict(state_dict2)
optimizer3.load_state_dict(opt_state_dict3)
torch_lr_scheduler3.load_state_dict(state_dict3)

optimizer3 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
optimizer4 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler4 = torch_lr_scheduler_cls(optimizer=optimizer4, **kwargs)

simulated_values = LRScheduler.simulate_values(num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler3)
assert lrs == pytest.approx([v for i, v in simulated_values])
simulated_values = LRScheduler.simulate_values(num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler4)
assert lrs1 == pytest.approx([v for i, v in simulated_values])
assert lrs2 == pytest.approx([v for i, v in simulated_values])


def test_piecewiselinear_asserts():
@@ -813,11 +831,8 @@ def test_simulate_and_plot_values():

def _test(scheduler_cls, **scheduler_kwargs):

optimizer = None
event = Events.ITERATION_STARTED
if scheduler_cls == LRScheduler:
optimizer = scheduler_kwargs["lr_scheduler"].optimizer
event = Events.ITERATION_COMPLETED
elif scheduler_cls == ConcatScheduler:
optimizer = scheduler_kwargs["optimizer"]
del scheduler_kwargs["optimizer"]
@@ -828,7 +843,7 @@ def _test(scheduler_cls, **scheduler_kwargs):

max_epochs = 2
data = [0] * 10
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)
simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)

scheduler = scheduler_cls(**scheduler_kwargs)

@@ -838,15 +853,11 @@ def save_lr(engine):
lrs.append(optimizer.param_groups[0]["lr"])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(event, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, save_lr)
trainer.run(data, max_epochs=max_epochs)

# assert lrs == pytest.approx([v for i, v in simulated_values])

if scheduler_cls == LRScheduler or scheduler_cls == ConcatScheduler:
# As internal state of torch lr scheduler has been changed the following checks will fail
return
assert lrs == pytest.approx([v for i, v in simulated_values])

# reexecute to check if no internal changes
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs,
@@ -937,7 +948,7 @@ def test_create_lr_scheduler_with_warmup_asserts():


@pytest.mark.parametrize(
"lrsched_warmup_config",
"lr_scheduler_name, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value",
[
# A) opt lr != warmup_end_value
("ExponentialLR", 0.01, 0.05, 10, 0.2),
@@ -955,15 +966,9 @@ def test_create_lr_scheduler_with_warmup_asserts():
("ExponentialLR", 0.01, None, 10, 0.2 * 0.98),
],
)
def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):

(
lr_scheduler_name,
warmup_start_value,
warmup_end_value,
warmup_duration,
warmup_end_next_value,
) = lrsched_warmup_config
def test_create_lr_scheduler_with_warmup(
lr_scheduler_name, warmup_start_value, warmup_end_value, warmup_duration, warmup_end_next_value
):

t1 = torch.zeros([1], requires_grad=True)

@@ -981,6 +986,11 @@ def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):
num_iterations = 10
max_epochs = 20

if warmup_end_value is None:
expected_warmup_end_value = optimizer.param_groups[0]["lr"]
else:
expected_warmup_end_value = warmup_end_value

simulated_values = [None] * (num_iterations * max_epochs)
scheduler = create_lr_scheduler_with_warmup(
lr_scheduler,
@@ -989,8 +999,6 @@ def test_create_lr_scheduler_with_warmup(lrsched_warmup_config):
warmup_duration=warmup_duration,
output_simulated_values=simulated_values,
)
if warmup_end_value is None:
warmup_end_value = optimizer.param_groups[0]["lr"]

state_dict = scheduler.state_dict()
trainer = Engine(lambda engine, batch: None)
@@ -1007,11 +1015,11 @@ def save_lr(engine):
lrs = []
trainer.run(data, max_epochs=max_epochs)

assert lrs == pytest.approx([v for i, v in simulated_values])
assert lrs == pytest.approx([v for _, v in simulated_values])

assert lrs[0] == pytest.approx(warmup_start_value), f"lrs={lrs[: warmup_duration + num_iterations]}"
assert lrs[warmup_duration - 1] == pytest.approx(
warmup_end_value
expected_warmup_end_value
), f"lrs={lrs[: warmup_duration + num_iterations]}"
assert lrs[warmup_duration] == pytest.approx(
warmup_end_next_value