diff --git a/.travis.yml b/.travis.yml
index 029c66b3c285..9674bd18c806 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,8 +5,8 @@ python:
   - "3.6"
 
 env:
-  - PYTORCH_PACKAGE=pytorch-cpu
-  - PYTORCH_PACKAGE=pytorch-nightly-cpu
+  - PYTORCH_CHANNEL=pytorch
+  - PYTORCH_CHANNEL=pytorch-nightly
 
 stages:
   - Lint check
@@ -25,7 +25,7 @@ before_install: &before_install
   - conda update -q conda
   # Useful for debugging any issues with conda
   - conda info -a
-  - conda create -q -n test-environment -c pytorch python=$TRAVIS_PYTHON_VERSION $PYTORCH_PACKAGE
+  - conda create -q -n test-environment pytorch cpuonly python=$TRAVIS_PYTHON_VERSION -c $PYTORCH_CHANNEL
   - source activate test-environment
   - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install enum34; fi
   # Test contrib dependencies
@@ -39,7 +39,7 @@ install:
   - pip install numpy mock pytest codecov pytest-cov
   # Examples dependencies
   - pip install matplotlib pandas
-  - conda install torchvision-cpu -c pytorch
+  - conda install torchvision -c $PYTORCH_CHANNEL
   - pip install gym==0.10.11
 
 script:
diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py
index faa55c281da5..98d74d353109 100644
--- a/ignite/contrib/handlers/param_scheduler.py
+++ b/ignite/contrib/handlers/param_scheduler.py
@@ -492,7 +492,8 @@ def _replicate_lr_scheduler(lr_scheduler, new_optimizer_param_groups=None):
         if new_optimizer_param_groups is not None:
             dummy_optimizer.param_groups = new_optimizer_param_groups
         kwargs = lr_scheduler.state_dict()
-        del kwargs['base_lrs']
+        for k in ['base_lrs', '_step_count']:
+            del kwargs[k]
         copy_lr_scheduler = lr_scheduler_cls(optimizer=dummy_optimizer, **kwargs)
         copy_lr_scheduler.load_state_dict(lr_scheduler.state_dict())
         return copy_lr_scheduler
diff --git a/tests/ignite/contrib/handlers/test_param_scheduler.py b/tests/ignite/contrib/handlers/test_param_scheduler.py
index 8b26da61187a..390268f7a444 100644
--- a/tests/ignite/contrib/handlers/test_param_scheduler.py
+++ b/tests/ignite/contrib/handlers/test_param_scheduler.py
@@ -357,7 +357,11 @@ def _test(torch_lr_scheduler_cls, **kwargs):
         lrs = []
         lrs_true = []
 
-        trainer = Engine(lambda engine, batch: None)
+        def dummy_update(engine, batch):
+            optimizer1.step()
+            optimizer2.step()
+
+        trainer = Engine(dummy_update)
 
         @trainer.on(Events.ITERATION_COMPLETED)
         def torch_lr_scheduler_step(engine):
@@ -396,6 +400,7 @@ def save_true_lr(engine):
     init_lr_scheduler_state = dict(lr_scheduler.state_dict())
     copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
     for _ in range(10):
+        optimizer.step()
         lr_scheduler.step()
 
     assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state
@@ -444,7 +449,7 @@ def save_lr(engine):
         lrs.append(optimizer.param_groups[0]['lr'])
 
     trainer = Engine(lambda engine, batch: None)
-    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
     trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
 
     trainer.run([0] * 25, max_epochs=2)
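
For context, a minimal sketch of the PyTorch >= 1.1.0 behavior these changes accommodate: `optimizer.step()` is now expected before `lr_scheduler.step()` (PyTorch emits a `UserWarning` for the old order), and the scheduler's `state_dict()` gained a `_step_count` entry that is not a constructor argument, hence the extra `del` in `_replicate_lr_scheduler`. The tensor and optimizer names below are illustrative, not taken from the test suite.

```python
import torch
from torch.optim.lr_scheduler import StepLR

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

for _ in range(10):
    optimizer.step()   # since PyTorch 1.1.0, the optimizer must step first...
    scheduler.step()   # ...otherwise a UserWarning is raised

# On PyTorch >= 1.1.0, '_step_count' appears in the scheduler state dict,
# which is why _replicate_lr_scheduler drops it alongside 'base_lrs'
# before re-instantiating the scheduler class.
assert '_step_count' in scheduler.state_dict()
```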