.travis.yml (8 changes: 4 additions & 4 deletions)
@@ -5,8 +5,8 @@ python:
   - "3.6"
 
 env:
-  - PYTORCH_PACKAGE=pytorch-cpu
-  - PYTORCH_PACKAGE=pytorch-nightly-cpu
+  - PYTORCH_CHANNEL=pytorch
+  - PYTORCH_CHANNEL=pytorch-nightly
 
 stages:
   - Lint check
@@ -25,7 +25,7 @@ before_install: &before_install
   - conda update -q conda
   # Useful for debugging any issues with conda
   - conda info -a
-  - conda create -q -n test-environment -c pytorch python=$TRAVIS_PYTHON_VERSION $PYTORCH_PACKAGE
+  - conda create -q -n test-environment pytorch cpuonly python=$TRAVIS_PYTHON_VERSION -c $PYTORCH_CHANNEL
   - source activate test-environment
   - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install enum34; fi
   # Test contrib dependencies
@@ -39,7 +39,7 @@ install:
   - pip install numpy mock pytest codecov pytest-cov
   # Examples dependencies
   - pip install matplotlib pandas
-  - conda install torchvision-cpu -c pytorch
+  - conda install torchvision -c $PYTORCH_CHANNEL
   - pip install gym==0.10.11
 
 script:
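For context on the conda changes above: PyTorch 1.1.0 retired the separate `pytorch-cpu`/`torchvision-cpu` packages in favor of the regular packages plus the `cpuonly` meta-package, installed from a configurable channel. A minimal sanity check (not part of this PR) that the resolved environment really is a CPU-only build might look like:

```python
# Hedged sanity check, assuming the conda env above is active:
# verify the cpuonly packages resolved to a CUDA-free PyTorch build.
import torch
import torchvision

print(torch.__version__)        # e.g. "1.1.0"
print(torchvision.__version__)
# cpuonly builds ship without CUDA support
assert not torch.cuda.is_available()
```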
ignite/contrib/handlers/param_scheduler.py (3 changes: 2 additions & 1 deletion)
@@ -492,7 +492,8 @@ def _replicate_lr_scheduler(lr_scheduler, new_optimizer_param_groups=None):
     if new_optimizer_param_groups is not None:
         dummy_optimizer.param_groups = new_optimizer_param_groups
     kwargs = lr_scheduler.state_dict()
-    del kwargs['base_lrs']
+    for k in ['base_lrs', '_step_count']:
+        del kwargs[k]
     copy_lr_scheduler = lr_scheduler_cls(optimizer=dummy_optimizer, **kwargs)
     copy_lr_scheduler.load_state_dict(lr_scheduler.state_dict())
     return copy_lr_scheduler
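For context: `_replicate_lr_scheduler` rebuilds a scheduler by passing its `state_dict()` entries as constructor kwargs, and PyTorch 1.1.0 added an internal `_step_count` entry to that dict. Like `base_lrs`, it is not a constructor argument, so it must be stripped before re-instantiation. A minimal sketch, assuming a stock `StepLR` (exact `state_dict()` keys vary by torch version):

```python
# Sketch: why state_dict() keys cannot all be fed back to the constructor.
import torch
from torch.optim.lr_scheduler import StepLR

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

# Mixes constructor kwargs (step_size, gamma, last_epoch) with internal
# bookkeeping: 'base_lrs' always, '_step_count' since torch 1.1.0.
print(sorted(scheduler.state_dict()))

kwargs = scheduler.state_dict()
for k in ['base_lrs', '_step_count']:
    kwargs.pop(k, None)  # pop(..., None) also tolerates pre-1.1 versions
# Only now is StepLR(optimizer, **kwargs) constructible on torch 1.1 (the
# version this PR targets); the full state, including the popped keys, is
# then restored via load_state_dict(), as _replicate_lr_scheduler does.
```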
tests/ignite/contrib/handlers/test_param_scheduler.py (9 changes: 7 additions & 2 deletions)
@@ -357,7 +357,11 @@ def _test(torch_lr_scheduler_cls, **kwargs):
         lrs = []
         lrs_true = []
 
-        trainer = Engine(lambda engine, batch: None)
+        def dummy_update(engine, batch):
+            optimizer1.step()
+            optimizer2.step()
+
+        trainer = Engine(dummy_update)
 
         @trainer.on(Events.ITERATION_COMPLETED)
         def torch_lr_scheduler_step(engine):
@@ -396,6 +400,7 @@ def save_true_lr(engine):
     init_lr_scheduler_state = dict(lr_scheduler.state_dict())
     copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
     for _ in range(10):
+        optimizer.step()
         lr_scheduler.step()
 
     assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state
@@ -444,7 +449,7 @@ def save_lr(engine):
         lrs.append(optimizer.param_groups[0]['lr'])
 
     trainer = Engine(lambda engine, batch: None)
-    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
     trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
     trainer.run([0] * 25, max_epochs=2)
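The three test changes above all follow from the same PyTorch 1.1.0 behavior change: `lr_scheduler.step()` is now expected to run after `optimizer.step()`, otherwise torch emits a `UserWarning` and skips the first value of the schedule. Hence the dummy update function that actually steps the optimizers, the added `optimizer.step()` in the replication test, and attaching the scheduler handler to `ITERATION_COMPLETED` instead of `ITERATION_STARTED`. A minimal standalone sketch of the required ordering (not ignite code):

```python
# Sketch of the ordering rule behind these test changes (torch >= 1.1.0).
import torch
from torch.optim.lr_scheduler import ExponentialLR

tensor = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for _ in range(3):
    optimizer.step()   # the update the dummy_update handler now performs
    scheduler.step()   # hence scheduling on ITERATION_COMPLETED, not STARTED
```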