Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,14 @@ def wrapper(*args, **kwargs):

return wrapper

def add_event_handler(self, event_name, handler, *args, **kwargs):
def add_event_handler(self, event_name, handler, priority=0, *args, **kwargs):
"""Add an event handler to be executed when the specified event is fired.

Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events`
or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`.
handler (callable): the callable event handler that should be invoked
priority(int, optional): change this to determine execution group, Default is 0 (last group)
*args: optional args to be passed to `handler`.
**kwargs: optional keyword args to be passed to `handler`.

Expand Down Expand Up @@ -440,7 +441,7 @@ def print_epoch(engine):
event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs)

self._event_handlers[event_name].append((handler, args, kwargs))
self._event_handlers[event_name].append((priority, handler, args, kwargs))
self._logger.debug("added handler for event %s.", event_name)

return RemovableEventHandle(event_name, handler, self)
Expand All @@ -460,7 +461,7 @@ def has_event_handler(self, handler, event_name=None):
else:
events = self._event_handlers
for e in events:
for h, _, _ in self._event_handlers[e]:
for _, h, _, _ in self._event_handlers[e]:
if h == handler:
return True
return False
Expand Down Expand Up @@ -509,21 +510,27 @@ def _check_signature(self, fn, fn_description, *args, **kwargs):
"({}).".format(
fn, fn_description, fn_params, passed_params, exception_msg))

def on(self, event_name, *args, **kwargs):
def on(self, event_name, priority=0, *args, **kwargs):
"""Decorator shortcut for add_event_handler.

Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events` or
any `event_name` added by :meth:`~ignite.engine.Engine.register_events`.
priority(int, optional): change this to determine execution group, Default is 0 (last group)
*args: optional args to be passed to `handler`.
**kwargs: optional keyword args to be passed to `handler`.

"""
def decorator(f):
self.add_event_handler(event_name, f, *args, **kwargs)
self.add_event_handler(event_name, f, priority, *args, **kwargs)
return f
return decorator

def _sort_handlers(self):
"""sort all handlers list according to their priority, ascending"""
for key, handlers_list in self._event_handlers.items():
self._event_handlers[key] = sorted(handlers_list, reverse=True)

def _fire_event(self, event_name, *event_args, **event_kwargs):
"""Execute all the handlers associated with given event.

Expand All @@ -543,7 +550,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs):
if event_name in self._allowed_events:
self._logger.debug("firing handlers for event %s ", event_name)
self.last_event_name = event_name
for func, args, kwargs in self._event_handlers[event_name]:
for _, func, args, kwargs in self._event_handlers[event_name]:
kwargs.update(event_kwargs)
func(self, *(event_args + args), **kwargs)

Expand Down Expand Up @@ -580,7 +587,7 @@ def terminate_epoch(self):
"""Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration.
"""
self._logger.info("Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished.")
"Current epoch iteration will stop after current iteration is finished.")
self.should_terminate_single_epoch = True

def _run_once_on_dataset(self):
Expand Down Expand Up @@ -639,6 +646,7 @@ def switch_batch(engine):

self.state = State(dataloader=data, max_epochs=max_epochs, metrics={})
self.should_terminate = self.should_terminate_single_epoch = False
self._sort_handlers() # sort all handlers at the beginning of a run

try:
self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
Expand Down
9 changes: 5 additions & 4 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,13 @@ def completed(self, engine, name):
result = result.item()
engine.state.metrics[name] = result

def attach(self, engine, name):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
def attach(self, engine, name, priority=1):
# metric is prioritized higher by default
engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, priority, name)
if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
engine.add_event_handler(Events.EPOCH_STARTED, self.started, priority)
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed, priority)

def __add__(self, other):
from ignite.metrics import MetricsLambda
Expand Down
19 changes: 19 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,22 @@ def test_state_get_event_attrib_value():
assert state.get_event_attrib_value(e) == state.epoch
e = Events.EPOCH_COMPLETED(once=5)
assert state.get_event_attrib_value(e) == state.epoch


def test_handlers_sorted():
engine = DummyEngine()

@engine.on(Events.ITERATION_COMPLETED, 2)
def prio2(_): pass

@engine.on(Events.ITERATION_COMPLETED, 3)
def prio3(_): pass

@engine.on(Events.ITERATION_COMPLETED)
def prio0(_): pass

metric = MeanSquaredError()
metric.attach(engine, 'mock')
engine._sort_handlers()
prio_list = list(zip(*engine._event_handlers[Events.ITERATION_COMPLETED]))[0]
assert np.all(np.diff(prio_list) <= 0)