
Conversation


@vfdev-5 vfdev-5 commented Dec 23, 2019

Fixes #667 and #545

Description:

  • Added state_dict/load_state_dict
    scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1, 0, 10)
    state_dict = scheduler.state_dict()
    # ....
    # restore previous state
    scheduler.load_state_dict(state_dict)
  • Breaking change (BC) in how schedulers are attached to an optimizer with multiple param groups

We need to keep a reference to the optimizer and fetch its param group at runtime. If we kept only the param_group dict, then once the user calls optimizer.load_state_dict(...), ignite's parameter scheduler would lose the param_group it acts on, because the optimizer rebuilds its param group dicts.
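As an illustration (a minimal sketch, not taken from the PR), this is the kind of stale-reference behaviour being avoided:

import torch

t = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([t], lr=0.1)

cached_group = optimizer.param_groups[0]           # what an old-style scheduler would hold on to
optimizer.load_state_dict(optimizer.state_dict())  # e.g. restoring a checkpoint

# load_state_dict rebuilds the param group dicts, so the cached dict is now detached
cached_group['lr'] = 0.5
print(optimizer.param_groups[0]['lr'])  # still 0.1 -- the scheduler's updates no longer reach the optimizer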

The API change is:
Previously,

scheduler1 = LinearCyclicalScheduler(optimizer.param_groups[0], 'lr', 1e-7, 1e-5, len(train_loader))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1, "lr (base)")
scheduler2 = CosineAnnealingScheduler(optimizer.param_groups[1], 'lr', 1e-5, 1e-3, len(train_loader))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2, "lr (fc)")

now becomes

scheduler1 = LinearCyclicalScheduler(optimizer, 'lr', 1e-7, 1e-5, len(train_loader), param_group_index=0)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1, "lr (base)")
scheduler2 = CosineAnnealingScheduler(optimizer, 'lr', 1e-5, 1e-3, len(train_loader), param_group_index=1)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2, "lr (fc)")
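As a usage sketch (not from the PR itself, names are illustrative), the new state_dict/load_state_dict pairs naturally with this API for checkpointing, assuming torch and the schedulers from the snippet above:

# capture the schedulers' progress, e.g. when saving a checkpoint
to_save = {"scheduler1": scheduler1.state_dict(), "scheduler2": scheduler2.state_dict()}
torch.save(to_save, "checkpoint.pt")  # hypothetical checkpoint path

# later: re-create the optimizer/schedulers with the same arguments, then restore
checkpoint = torch.load("checkpoint.pt")
scheduler1.load_state_dict(checkpoint["scheduler1"])
scheduler2.load_state_dict(checkpoint["scheduler2"])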

This change also fixes the NVIDIA/Apex issue #545 for opt_level="O2". The following example now works:

import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear

t = torch.tensor([0])
optimizer = torch.optim.SGD([t], lr=0.0)

milestones_values = [(10, 1.0), (20, 0.0)]
lr_scheduler = PiecewiseLinear(optimizer, param_name='lr', milestones_values=milestones_values)

model = torch.nn.Sequential(torch.nn.Linear(10, 5)).to('cuda')

from apex import amp
# with opt_level="O2", amp.initialize rebuilds the optimizer's param groups,
# which previously broke schedulers that held on to a param_group dict (#545)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", num_losses=1)

lrs = []

def save_lr(engine):
    # record the lr applied at each iteration
    lrs.append(optimizer.param_groups[0]['lr'])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
trainer.run([0] * 25, max_epochs=2)

import matplotlib.pylab as plt
%matplotlib inline
plt.plot(lrs)
  • Updated docs

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@vfdev-5 vfdev-5 merged commit aaf420b into pytorch:master Dec 26, 2019
@vfdev-5 vfdev-5 deleted the issue_667 branch December 26, 2019 00:49