from enum import Enum
from ignite.engine import Events, State
class CustomPeriodicEvent(object):
    """Handler to define custom periodic events as a number of elapsed iterations/epochs for an engine.

    When a custom periodic event is created and attached to an engine, the following events are fired:

    1) K iterations is specified:

    - `Events.ITERATIONS_<K>_STARTED`
    - `Events.ITERATIONS_<K>_COMPLETED`

    2) K epochs is specified:

    - `Events.EPOCHS_<K>_STARTED`
    - `Events.EPOCHS_<K>_COMPLETED`

    The events are exposed as the ``Events`` enum attribute of the handler instance. While attached,
    the handler also maintains a counter on the engine state named ``iterations_<K>`` (or ``epochs_<K>``):
    it is reset to 0 when the engine starts and incremented each time a new period starts.

    Examples:

    .. code-block:: python

        from ignite.engine import Engine, Events
        from ignite.contrib.handlers import CustomPeriodicEvent

        # Let's define an event every 1000 iterations
        cpe1 = CustomPeriodicEvent(n_iterations=1000)
        cpe1.attach(trainer)

        # Let's define an event every 10 epochs
        cpe2 = CustomPeriodicEvent(n_epochs=10)
        cpe2.attach(trainer)

        def on_every_1000_iterations(engine):
            # run a computation after 1000 iterations
            # ...
            print(engine.state.iterations_1000)

        def on_every_10_epochs(engine):
            # run a computation every 10 epochs
            # ...
            print(engine.state.epochs_10)

        trainer.add_event_handler(cpe1.Events.ITERATIONS_1000_COMPLETED, on_every_1000_iterations)
        trainer.add_event_handler(cpe2.Events.EPOCHS_10_STARTED, on_every_10_epochs)

    Args:
        n_iterations (int, optional): number of iterations of the custom periodic event.
        n_epochs (int, optional): number of epochs of the custom periodic event. Both arguments are
            optional, but exactly one of them, either n_iterations or n_epochs, should be defined.

    Raises:
        ValueError: if an argument is not a positive integer, or if neither or both arguments are given.
    """

    def __init__(self, n_iterations=None, n_epochs=None):
        if n_iterations is not None and (not isinstance(n_iterations, int) or n_iterations < 1):
            raise ValueError("Argument n_iterations should be positive integer number")

        if n_epochs is not None and (not isinstance(n_epochs, int) or n_epochs < 1):
            raise ValueError("Argument n_epochs should be positive integer number")

        # Exactly one of the two periods must be provided. Non-None values are
        # already known to be >= 1 here, so comparing "is None" flags suffices.
        if (n_iterations is None) == (n_epochs is None):
            raise ValueError("Either n_iterations or n_epochs should defined")

        if n_iterations is not None:
            prefix = "iterations"
            self.state_attr = "iteration"  # engine.state attribute to watch
            self.period = n_iterations
        else:
            prefix = "epochs"
            self.state_attr = "epoch"
            self.period = n_epochs

        # Name of the per-period counter stored on engine.state, e.g. "iterations_1000".
        self.custom_state_attr = "{}_{}".format(prefix, self.period)

        event_name = "{}_{}".format(prefix.upper(), self.period)
        started_name = "{}_STARTED".format(event_name)
        completed_name = "{}_COMPLETED".format(event_name)
        # Per-instance enum holding the two custom events.
        self.Events = Enum("Events", " ".join([started_name, completed_name]))

        # Update State.event_to_attr
        for e in self.Events:
            State.event_to_attr[e] = self.custom_state_attr

        # Create aliases
        self._periodic_event_started = getattr(self.Events, started_name)
        self._periodic_event_completed = getattr(self.Events, completed_name)

    def _on_started(self, engine):
        """Reset the custom period counter on the engine state."""
        setattr(engine.state, self.custom_state_attr, 0)

    def _on_periodic_event_started(self, engine):
        """Increment the period counter and fire the STARTED event on the first step of a period."""
        current = getattr(engine.state, self.state_attr)
        # `(current - 1) % period == 0` matches `current % period == 1` for every
        # period > 1 and, unlike the latter, also holds when period == 1.
        if (current - 1) % self.period == 0:
            setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1)
            # NOTE(review): relies on the engine's `fire_event` API - confirm.
            engine.fire_event(self._periodic_event_started)

    def _on_periodic_event_completed(self, engine):
        """Fire the COMPLETED event on the last step of a period."""
        if getattr(engine.state, self.state_attr) % self.period == 0:
            # NOTE(review): relies on the engine's `fire_event` API - confirm.
            engine.fire_event(self._periodic_event_completed)

    def attach(self, engine):
        """Register the custom events on ``engine`` and hook up the handlers that fire them.

        Args:
            engine: engine object to attach the periodic events to.
        """
        # NOTE(review): custom events presumably must be registered on the engine
        # before they can be fired or have handlers added - confirm against the
        # Engine API.
        engine.register_events(*self.Events)

        engine.add_event_handler(Events.STARTED, self._on_started)
        engine.add_event_handler(
            getattr(Events, "{}_STARTED".format(self.state_attr.upper())),
            self._on_periodic_event_started,
        )
        engine.add_event_handler(
            getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())),
            self._on_periodic_event_completed,
        )
