[BC-breaking] Feature/#2465 lr scheduler attach events #2496

Merged

Changes from 3 commits
3 changes: 0 additions & 3 deletions ignite/contrib/engines/common.py
@@ -13,7 +13,6 @@
from ignite.contrib.handlers import (
ClearMLLogger,
global_step_from_engine,
LRScheduler,
MLflowLogger,
NeptuneLogger,
PolyaxonLogger,
@@ -165,8 +164,6 @@ def _setup_common_training_handlers(
trainer.add_event_handler(
Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
)
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

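In effect, an ignite LRScheduler instance passed to the common training handlers now falls through to the generic else branch and is attached at Events.ITERATION_STARTED, like any other ignite parameter scheduler. A minimal sketch of the resulting dispatch (simplified from the hunk above; attach_lr_scheduler is a hypothetical stand-in, not the real _setup_common_training_handlers):

# Simplified sketch of the dispatch after this change: only raw torch schedulers
# keep the manual step() at ITERATION_COMPLETED; everything else, including the
# ignite LRScheduler wrapper, is attached at ITERATION_STARTED.
from typing import cast

from torch.optim.lr_scheduler import _LRScheduler

from ignite.engine import Events


def attach_lr_scheduler(trainer, lr_scheduler):
    if isinstance(lr_scheduler, _LRScheduler):
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
        )
    else:
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)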
28 changes: 16 additions & 12 deletions ignite/handlers/param_scheduler.py
@@ -795,6 +795,7 @@ class LRScheduler(ParamScheduler):
lr_scheduler: lr_scheduler object to wrap.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
use_legacy: if True, LRScheduler keeps its previous behaviour and should be attached to Events.ITERATION_COMPLETED, (default=False).

Examples:

@@ -808,20 +809,16 @@ class LRScheduler(ParamScheduler):
from torch.optim.lr_scheduler import StepLR

torch_lr_scheduler = StepLR(default_optimizer, step_size=3, gamma=0.1)

scheduler = LRScheduler(torch_lr_scheduler)

# In this example, we assume PyTorch>=1.1.0 is installed
# (with the new `torch.optim.lr_scheduler` behaviour)
default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
print(default_optimizer.param_groups[0]["lr"])

# In this example, we assume to have installed PyTorch>=1.1.0
# (with new `torch.optim.lr_scheduler` behaviour) and
# we attach scheduler to Events.ITERATION_COMPLETED
# instead of Events.ITERATION_STARTED to make sure to use
# the first lr value from the optimizer, otherwise it will be skipped:
default_trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

default_trainer.run([0] * 8, max_epochs=1)

.. testoutput::
@@ -838,7 +835,7 @@ def print_lr():
.. versionadded:: 0.4.5
"""

def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False, use_legacy: bool = False):

if not isinstance(lr_scheduler, _LRScheduler):
raise TypeError(
@@ -852,11 +849,18 @@ def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False):
param_name="lr",
save_history=save_history,
)
if use_legacy:
warnings.warn(
"Please attach the scheduler to Events.ITERATION_COMPLETED "
"instead of Events.ITERATION_STARTED to make sure to use "
"the first lr value from the optimizer, otherwise it will be skipped"
)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]
self._state_attrs += ["lr_scheduler"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]
super(LRScheduler, self).__call__(engine, name)
self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined]

def get_param(self) -> Union[float, List[float]]:
"""Method to get current optimizer's parameter value"""
@@ -904,9 +908,9 @@ def simulate_values( # type: ignore[override]
values = []
scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
for i in range(num_events):
scheduler(engine=None)
params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
values.append([i] + params)
scheduler(engine=None)

obj = torch.load(cache_filepath.as_posix())
lr_scheduler.load_state_dict(obj["lr_scheduler"])
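In the new version the scheduler is called before the parameter value is appended, so the value recorded for event i is the lr actually used at iteration i when the wrapper is attached at Events.ITERATION_STARTED. A short usage sketch of simulate_values (assuming torch and ignite are installed; printed values are approximate due to float rounding):

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.handlers import LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Each entry is [event_index, lr used at that event]; the original scheduler and
# optimizer states are restored afterwards via the cache shown above.
values = LRScheduler.simulate_values(num_events=6, lr_scheduler=torch_lr_scheduler)
print(values)  # ~[[0, 0.01], [1, 0.01], [2, 0.01], [3, 0.001], [4, 0.001], [5, 0.001]]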
@@ -1015,7 +1019,7 @@ def print_lr():
if init_lr != param_group_warmup_end_value:
milestones_values.append((warmup_duration, init_lr))

lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history, use_legacy=True)
Collaborator

Can we check if we could use use_legacy=False ? Or what is blocking here ?

Contributor Author

If use_legacy=False, the result looks like the examples below:

  • example1
import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

def dummy_update(engine, batch):
    optimizer.step()

trainer = Engine(dummy_update)
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=0.1,
                                            warmup_duration=3)

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])

_ = trainer.run([0] * 8, max_epochs=1)
  • output1
0.0
0.05
0.1
0.01
0.01
0.01
0.01
0.001
  • example2
# imports and dummy_update as in example 1; only the initial lr differs
from torch.optim.lr_scheduler import StepLR

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

trainer = Engine(dummy_update)
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=0.1,
                                            warmup_duration=3)

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])

_ = trainer.run([0] * 8, max_epochs=1)
  • output2
0.0
0.05
0.1
0.1
0.1
0.1
0.010000000000000002
0.010000000000000002

The initial value of StepLR (0.1 or 0.01 here) is repeated one more time than step_size (=3) allows.
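For reference, a plain StepLR without the warmup wrapper holds each value for exactly step_size iterations. A standalone check, independent of this PR, reading the lr before stepping:

import torch
from torch.optim.lr_scheduler import StepLR

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

for _ in range(8):
    print(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()
# 0.01, 0.01, 0.01, 0.001, 0.001, 0.001, 0.0001, 0.0001 (up to float rounding)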

Collaborator

Thanks for the details @yuta0821. Can you please debug this a bit more and say explicitly which scheduler is responsible for adding the extra LR value (0.1 or 0.01)? I'm not quite sure I understand why exactly this happens. Thanks!

Contributor Author (yuta0821, Feb 28, 2022)

@vfdev-5 Thanks a lot for your comment !
Consider the case where warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=3.
In this case, milestones_values = [(0, 0.0), (2, 0.1)]
If the initial value of lr in the optimizer is different from the warmup_end_value, it is necessary to add the initial value to the end of milestones. Therefore, milestones_values = [(0, 0.0), (2, 0.1), (3, initial value of lr)]
This is because the LRScheduler updates the lr starting from the last value of the milestones_values.
After that, the following code is executed, which results in the initial value of lr being repeated.

super(LRScheduler, self).__call__(engine, name)
self.lr_scheduler.last_epoch += 1  # type: ignore[attr-defined]

Even if the initial value of lr in the optimizer is equal to warmup_end_value, the initial lr value is still applied one extra time.

In the end, since the first __call__ of LRScheduler runs with reference to the last value of milestones_values, the last value of milestones_values and the initial value of LRScheduler end up duplicated.

If we were to fix this bug without use_legacy, we might have to change a lot of code, such as the code related to PiecewiseLinear, which is beyond the scope of this PR.
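To make those milestones concrete, a small sketch with the values from this discussion (build_warmup_milestones is a hypothetical helper, not ignite API):

# Hypothetical helper mirroring the milestone construction described above.
def build_warmup_milestones(warmup_start_value, warmup_end_value, warmup_duration, init_lr):
    # PiecewiseLinear milestones as (event_index, lr_value) pairs
    milestones_values = [(0, warmup_start_value), (warmup_duration - 1, warmup_end_value)]
    if init_lr != warmup_end_value:
        # the wrapped scheduler resumes from its own initial lr right after the warmup
        milestones_values.append((warmup_duration, init_lr))
    return milestones_values


print(build_warmup_milestones(0.0, 0.1, 3, 0.01))  # [(0, 0.0), (2, 0.1), (3, 0.01)]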

Collaborator

@yuta0821 sorry for the delay and thanks for the explanation! I haven't checked it in detail but will do so as soon as possible on my side (~4-5 days).
@sdesrozis can you help with that if you have some bandwidth ?

Contributor Author

It performs the same operation as use_legacy, but wouldn't it be preferable to add a skip_initial_value argument as a variable to be used by the internal function?

Contributor

Yeah, expected rather than excepted 😅

As far as I understand, with use_legacy=False the first lr comes from the optimizer. Whatever schedulers are used, the scheduler concatenation will produce a repetition at each joint.

Having an internal option as you suggested sounds good to me. I mean, renaming use_legacy to skip_initial_value is fine. Although, we have to keep use_legacy for the users.

What do you think ?
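A minimal sketch of what keeping use_legacy public while routing it to an internal flag could look like (names follow the suggestion above, not the keep_first_lr approach proposed later in this thread):

# Hypothetical sketch only.
import warnings


class LRSchedulerSketch:
    def __init__(self, lr_scheduler, save_history=False, use_legacy=False, skip_initial_value=False):
        self.lr_scheduler = lr_scheduler
        if use_legacy:
            warnings.warn(
                "use_legacy=True reproduces the old behaviour; attach the scheduler "
                "to Events.ITERATION_COMPLETED in that case"
            )
            skip_initial_value = True
        # the internal flag drives the actual behaviour
        self._skip_initial_value = skip_initial_value
        if self._skip_initial_value:
            # pre-advance the wrapped scheduler, skipping its initial lr
            self.lr_scheduler.last_epoch += 1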

Contributor Author (yuta0821, Mar 6, 2022)

@sdesrozis
Sorry for having to consult you on this.
It seems that this problem can be solved by setting the internal flag keep_first_lr=True only when create_lr_scheduler_with_warmup is called, as shown below, to store the initial lr value in LRScheduler.

class LRScheduler(ParamScheduler):
    def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False, use_legacy: bool = False, keep_first_lr: bool = False):
        # ... rest of __init__ unchanged ...
        if keep_first_lr:
            # emulate the pytorch>=1.4 context manager and cache the initial lr
            self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
            self.first_lr = self.lr_scheduler.get_lr()
            self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]

    def get_param(self) -> Union[float, List[float]]:
        """Method to get current optimizer's parameter value"""
        if hasattr(self, "first_lr"):
            # return the cached initial lr exactly once, then fall back to the normal path
            lr_list = self.first_lr
            del self.first_lr
        else:
            ...  # existing logic: read the current lr from the wrapped torch scheduler
        # ... existing conversion of lr_list to a float or list ...

def create_lr_scheduler_with_warmup():  # signature elided
    if isinstance(lr_scheduler, _LRScheduler):
        init_lr = param_group["lr"]
        lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history, keep_first_lr=True)
    else:
        init_lr = lr_scheduler.get_param()

I am running the existing tests now. I will commit once all tests pass! -> Done!

Contributor

It seems that the tests are KO (failing)...

Contributor Author

Sorry, I added type: ignore[attr-defined] !

else:
init_lr = lr_scheduler.get_param()
if init_lr == param_group_warmup_end_value:
69 changes: 39 additions & 30 deletions tests/ignite/handlers/test_param_scheduler.py
@@ -643,55 +643,73 @@ def _test(torch_lr_scheduler_cls, **kwargs):
tensor = torch.zeros([1], requires_grad=True)
optimizer1 = torch.optim.SGD([tensor], lr=0.01)
optimizer2 = torch.optim.SGD([tensor], lr=0.01)
optimizer3 = torch.optim.SGD([tensor], lr=0.01)
opt_state_dict1 = optimizer1.state_dict()
opt_state_dict2 = optimizer2.state_dict()
opt_state_dict3 = optimizer3.state_dict()

torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1, **kwargs)
scheduler = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler.state_dict()
scheduler1 = LRScheduler(torch_lr_scheduler1)
state_dict1 = scheduler1.state_dict()

torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
state_dict2 = torch_lr_scheduler2.state_dict()
scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
state_dict2 = scheduler2.state_dict()

torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
state_dict3 = torch_lr_scheduler3.state_dict()

def dummy_update(engine, batch):
optimizer1.step()
optimizer2.step()
optimizer3.step()

trainer = Engine(dummy_update)

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)

@trainer.on(Events.ITERATION_STARTED)
def save_lr(engine):
lrs.append(optimizer1.param_groups[0]["lr"])
def save_lr1(engine):
lrs1.append(optimizer1.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_lr2(engine):
lrs2.append(optimizer2.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_STARTED)
def save_true_lr(engine):
lrs_true.append(optimizer2.param_groups[0]["lr"])
lrs_true.append(optimizer3.param_groups[0]["lr"])

@trainer.on(Events.ITERATION_COMPLETED)
def torch_lr_scheduler_step(engine):
torch_lr_scheduler2.step()
torch_lr_scheduler3.step()

trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler2)

for _ in range(2):
lrs = []
lrs1 = []
lrs2 = []
lrs_true = []
data = [0] * 10
max_epochs = 2
trainer.run(data, max_epochs=max_epochs)
assert lrs_true == pytest.approx(lrs), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs} ({len(lrs)})"
assert lrs_true == pytest.approx(lrs1), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs1} ({len(lrs1)})"
assert lrs_true == pytest.approx(lrs2), f"{_}: {lrs_true} ({len(lrs_true)}) vs {lrs2} ({len(lrs2)})"
optimizer1.load_state_dict(opt_state_dict1)
scheduler.load_state_dict(state_dict1)
scheduler1.load_state_dict(state_dict1)
optimizer2.load_state_dict(opt_state_dict2)
torch_lr_scheduler2.load_state_dict(state_dict2)
scheduler2.load_state_dict(state_dict2)
optimizer3.load_state_dict(opt_state_dict3)
torch_lr_scheduler3.load_state_dict(state_dict3)

optimizer3 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
optimizer4 = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler4 = torch_lr_scheduler_cls(optimizer=optimizer4, **kwargs)

simulated_values = LRScheduler.simulate_values(
num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler3
num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler4
)
assert lrs == pytest.approx([v for i, v in simulated_values])
assert lrs1 == pytest.approx([v for i, v in simulated_values])
assert lrs2 == pytest.approx([v for i, v in simulated_values])

_test(StepLR, step_size=5, gamma=0.5)
_test(ExponentialLR, gamma=0.78)
@@ -817,11 +835,8 @@ def test_simulate_and_plot_values():

def _test(scheduler_cls, **scheduler_kwargs):

optimizer = None
event = Events.ITERATION_STARTED
if scheduler_cls == LRScheduler:
optimizer = scheduler_kwargs["lr_scheduler"].optimizer
event = Events.ITERATION_COMPLETED
elif scheduler_cls == ConcatScheduler:
optimizer = scheduler_kwargs["optimizer"]
del scheduler_kwargs["optimizer"]
@@ -832,7 +847,7 @@ def _test(scheduler_cls, **scheduler_kwargs):

max_epochs = 2
data = [0] * 10
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)
simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)

scheduler = scheduler_cls(**scheduler_kwargs)

@@ -842,21 +857,15 @@ def save_lr(engine):
lrs.append(optimizer.param_groups[0]["lr"])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(event, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, save_lr)
trainer.run(data, max_epochs=max_epochs)

# assert lrs == pytest.approx([v for i, v in simulated_values])

if scheduler_cls == LRScheduler or scheduler_cls == ConcatScheduler:
# As internal state of torch lr scheduler has been changed the following checks will fail
return
assert lrs == pytest.approx([v for i, v in simulated_values])

# reexecute to check if no internal changes
# simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs,
# save_history=True, # this will be removed
# **scheduler_kwargs)
# assert lrs == pytest.approx([v for i, v in simulated_values])
simulated_values = scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)
assert lrs == pytest.approx([v for i, v in simulated_values])

# launch plot values
scheduler_cls.plot_values(num_events=len(data) * max_epochs, **scheduler_kwargs)