Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ before_install: &before_install
- source activate test-environment
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install enum34; fi
# Test contrib dependencies
- pip install tqdm scikit-learn tensorboardX visdom polyaxon-client mlflow tensorboard
- pip install tqdm scikit-learn matplotlib tensorboardX visdom polyaxon-client mlflow tensorboard
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install pynvml; fi
# Futures should be already installed via visdom -> tornado -> futures
# Let's reinstall it anyway to be sure
Expand All @@ -39,7 +39,7 @@ install:
- python setup.py install
- pip install numpy mock pytest codecov pytest-cov pytest-xdist
# Examples dependencies
- pip install matplotlib pandas
- pip install pandas
- pip install gym==0.10.11

script:
Expand Down
42 changes: 41 additions & 1 deletion ignite/contrib/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_param(self):

@classmethod
def simulate_values(cls, num_events, **scheduler_kwargs):
"""Method to simulate scheduled values during num_events events.
"""Method to simulate scheduled values during `num_events` events.

Args:
num_events (int): number of events during the simulation.
Expand Down Expand Up @@ -107,6 +107,46 @@ def simulate_values(cls, num_events, **scheduler_kwargs):
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values

@classmethod
def plot_values(cls, num_events, **scheduler_kwargs):
"""Method to plot simulated scheduled values during `num_events` events.

This class requires `matplotlib package <https://matplotlib.org/>`_ to be installed:

.. code-block:: bash

pip install matplotlib

Args:
num_events (int): number of events during the simulation.
**scheduler_kwargs : parameter scheduler configuration kwargs.

Returns:
matplotlib.lines.Line2D

Examples:

.. code-block:: python

import matplotlib.pylab as plt

plt.figure(figsize=(10, 7))
LinearCyclicalScheduler.plot_values(num_events=50, param_name='lr',
start_value=1e-1, end_value=1e-3, cycle_size=10))
"""
try:
import matplotlib.pylab as plt
except ImportError:
raise RuntimeError("This method requires matplotlib to be installed. "
"Please install it with command: \n pip install matplotlib")

values = cls.simulate_values(num_events=num_events, **scheduler_kwargs)
label = scheduler_kwargs.get("param_name", "learning rate")
ax = plt.plot([e for e, _ in values], [v for _, v in values], label=label)
plt.legend()
plt.grid(which='both')
return ax


class CyclicalScheduler(ParamScheduler):
"""An abstract class for updating an optimizer's parameter value over a
Expand Down
9 changes: 8 additions & 1 deletion tests/ignite/contrib/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,11 @@ def save_lr(engine):
_test(milestones_as_np_int=False)


def test_simulate_values():
def test_simulate_and_plot_values():

import matplotlib
matplotlib.use('Agg')

def _test(scheduler_cls, **scheduler_kwargs):

optimizer = None
Expand Down Expand Up @@ -511,6 +515,9 @@ def save_lr(engine):
**scheduler_kwargs)
assert lrs == pytest.approx([v for i, v in simulated_values])

# launch plot values
scheduler_cls.plot_values(num_events=len(data) * max_epochs, **scheduler_kwargs)

# LinearCyclicalScheduler
_test(LinearCyclicalScheduler, param_name="lr", start_value=1.0, end_value=0.0, cycle_size=10)

Expand Down