diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 2012d44ca0b4..88125bd40796 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -392,13 +392,14 @@ def wrapper(*args, **kwargs): return wrapper - def add_event_handler(self, event_name, handler, *args, **kwargs): + def add_event_handler(self, event_name, handler, priority=0, *args, **kwargs): """Add an event handler to be executed when the specified event is fired. Args: event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events` or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`. handler (callable): the callable event handler that should be invoked + priority(int, optional): change this to determine execution group, Default is 0 (last group) *args: optional args to be passed to `handler`. **kwargs: optional keyword args to be passed to `handler`. @@ -440,7 +441,7 @@ def print_epoch(engine): event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else () Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs) - self._event_handlers[event_name].append((handler, args, kwargs)) + self._event_handlers[event_name].append((priority, handler, args, kwargs)) self._logger.debug("added handler for event %s.", event_name) return RemovableEventHandle(event_name, handler, self) @@ -460,7 +461,7 @@ def has_event_handler(self, handler, event_name=None): else: events = self._event_handlers for e in events: - for h, _, _ in self._event_handlers[e]: + for _, h, _, _ in self._event_handlers[e]: if h == handler: return True return False @@ -509,21 +510,27 @@ def _check_signature(self, fn, fn_description, *args, **kwargs): "({}).".format( fn, fn_description, fn_params, passed_params, exception_msg)) - def on(self, event_name, *args, **kwargs): + def on(self, event_name, priority=0, *args, **kwargs): """Decorator shortcut for add_event_handler. Args: event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events` or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`. + priority(int, optional): change this to determine execution group, Default is 0 (last group) *args: optional args to be passed to `handler`. **kwargs: optional keyword args to be passed to `handler`. """ def decorator(f): - self.add_event_handler(event_name, f, *args, **kwargs) + self.add_event_handler(event_name, f, priority, *args, **kwargs) return f return decorator + def _sort_handlers(self): + """sort all handlers list according to their priority, ascending""" + for key, handlers_list in self._event_handlers.items(): + self._event_handlers[key] = sorted(handlers_list, reverse=True) + def _fire_event(self, event_name, *event_args, **event_kwargs): """Execute all the handlers associated with given event. @@ -543,7 +550,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs): if event_name in self._allowed_events: self._logger.debug("firing handlers for event %s ", event_name) self.last_event_name = event_name - for func, args, kwargs in self._event_handlers[event_name]: + for _, func, args, kwargs in self._event_handlers[event_name]: kwargs.update(event_kwargs) func(self, *(event_args + args), **kwargs) @@ -580,7 +587,7 @@ def terminate_epoch(self): """Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration. """ self._logger.info("Terminate current epoch is signaled. " - "Current epoch iteration will stop after current iteration is finished.") + "Current epoch iteration will stop after current iteration is finished.") self.should_terminate_single_epoch = True def _run_once_on_dataset(self): @@ -639,6 +646,7 @@ def switch_batch(engine): self.state = State(dataloader=data, max_epochs=max_epochs, metrics={}) self.should_terminate = self.should_terminate_single_epoch = False + self._sort_handlers() # sort all handlers at the beginning of a run try: self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 8bf8178c49be..6c5cdd4eddc2 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -124,12 +124,13 @@ def completed(self, engine, name): result = result.item() engine.state.metrics[name] = result - def attach(self, engine, name): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name) + def attach(self, engine, name, priority=1): + # metric is prioritized higher by default + engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, priority, name) if not engine.has_event_handler(self.started, Events.EPOCH_STARTED): - engine.add_event_handler(Events.EPOCH_STARTED, self.started) + engine.add_event_handler(Events.EPOCH_STARTED, self.started, priority) if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed, priority) def __add__(self, other): from ignite.metrics import MetricsLambda diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 542fc379d58a..30fd5df6fe8b 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1175,3 +1175,22 @@ def test_state_get_event_attrib_value(): assert state.get_event_attrib_value(e) == state.epoch e = Events.EPOCH_COMPLETED(once=5) assert state.get_event_attrib_value(e) == state.epoch + + +def test_handlers_sorted(): + engine = DummyEngine() + + @engine.on(Events.ITERATION_COMPLETED, 2) + def prio2(_): pass + + @engine.on(Events.ITERATION_COMPLETED, 3) + def prio3(_): pass + + @engine.on(Events.ITERATION_COMPLETED) + def prio0(_): pass + + metric = MeanSquaredError() + metric.attach(engine, 'mock') + engine._sort_handlers() + prio_list = list(zip(*engine._event_handlers[Events.ITERATION_COMPLETED]))[0] + assert np.all(np.diff(prio_list) <= 0)