From 2d1977f9d4a6f429f2d3c61a40007bfebebf2065 Mon Sep 17 00:00:00 2001 From: beyondyoni <48648837+beyondyoni@users.noreply.github.com> Date: Mon, 20 Jan 2020 09:28:20 +0200 Subject: [PATCH 1/5] adding ability to prioritize handlers --- ignite/engine/engine.py | 48 +++++++++++++++++------------- ignite/metrics/metric.py | 9 +++--- tests/ignite/engine/test_engine.py | 19 ++++++++++++ 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 2012d44ca0b4..cbc88b472e9e 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -306,8 +306,8 @@ def compute_mean_std(engine, batch): """ def __init__(self, process_function): self._event_handlers = defaultdict(list) - self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) - self._logger.addHandler(logging.NullHandler()) + self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) + self.logger.addHandler(logging.NullHandler()) self._process_function = process_function self.last_event_name = None self.should_terminate = False @@ -392,13 +392,14 @@ def wrapper(*args, **kwargs): return wrapper - def add_event_handler(self, event_name, handler, *args, **kwargs): + def add_event_handler(self, event_name, handler, priority=0, *args, **kwargs): """Add an event handler to be executed when the specified event is fired. Args: event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events` or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`. handler (callable): the callable event handler that should be invoked + priority(int, optional): change this to determine execution group, Default is 0 (last group) *args: optional args to be passed to `handler`. **kwargs: optional keyword args to be passed to `handler`. @@ -434,14 +435,14 @@ def print_epoch(engine): handler = self._handler_wrapper(handler, event_name, event_filter) if event_name not in self._allowed_events: - self._logger.error("attempt to add event handler to an invalid event %s.", event_name) + self.logger.error("attempt to add event handler to an invalid event %s.", event_name) raise ValueError("Event {} is not a valid event for this Engine.".format(event_name)) event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else () Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs) - self._event_handlers[event_name].append((handler, args, kwargs)) - self._logger.debug("added handler for event %s.", event_name) + self._event_handlers[event_name].append((priority, handler, args, kwargs)) + self.logger.debug("added handler for event %s.", event_name) return RemovableEventHandle(event_name, handler, self) @@ -460,7 +461,7 @@ def has_event_handler(self, handler, event_name=None): else: events = self._event_handlers for e in events: - for h, _, _ in self._event_handlers[e]: + for _, h, _, _ in self._event_handlers[e]: if h == handler: return True return False @@ -507,23 +508,29 @@ def _check_signature(self, fn, fn_description, *args, **kwargs): raise ValueError("Error adding {} '{}': " "takes parameters {} but will be called with {} " "({}).".format( - fn, fn_description, fn_params, passed_params, exception_msg)) + fn, fn_description, fn_params, passed_params, exception_msg)) - def on(self, event_name, *args, **kwargs): + def on(self, event_name, priority=0, *args, **kwargs): """Decorator shortcut for add_event_handler. Args: event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events` or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`. + priority(int, optional): change this to determine execution group, Default is 0 (last group) *args: optional args to be passed to `handler`. **kwargs: optional keyword args to be passed to `handler`. """ def decorator(f): - self.add_event_handler(event_name, f, *args, **kwargs) + self.add_event_handler(event_name, f, priority, *args, **kwargs) return f return decorator + def _sort_handlers(self): + """sort all handlers list according to their priority, ascending""" + for key, handlers_list in self._event_handlers.items(): + self._event_handlers[key] = sorted(handlers_list, reverse=True) + def _fire_event(self, event_name, *event_args, **event_kwargs): """Execute all the handlers associated with given event. @@ -541,9 +548,9 @@ def _fire_event(self, event_name, *event_args, **event_kwargs): """ if event_name in self._allowed_events: - self._logger.debug("firing handlers for event %s ", event_name) + self.logger.debug("firing handlers for event %s ", event_name) self.last_event_name = event_name - for func, args, kwargs in self._event_handlers[event_name]: + for _, func, args, kwargs in self._event_handlers[event_name]: kwargs.update(event_kwargs) func(self, *(event_args + args), **kwargs) @@ -573,14 +580,14 @@ def fire_event(self, event_name): def terminate(self): """Sends terminate signal to the engine, so that it terminates completely the run after the current iteration. """ - self._logger.info("Terminate signaled. Engine will stop after current iteration is finished.") + self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True def terminate_epoch(self): """Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration. """ - self._logger.info("Terminate current epoch is signaled. " - "Current epoch iteration will stop after current iteration is finished.") + self.logger.info("Terminate current epoch is signaled. " + "Current epoch iteration will stop after current iteration is finished.") self.should_terminate_single_epoch = True def _run_once_on_dataset(self): @@ -598,7 +605,7 @@ def _run_once_on_dataset(self): break except BaseException as e: - self._logger.error("Current run is terminating due to exception: %s.", str(e)) + self.logger.error("Current run is terminating due to exception: %s.", str(e)) self._handle_exception(e) time_taken = time.time() - start_time @@ -639,16 +646,17 @@ def switch_batch(engine): self.state = State(dataloader=data, max_epochs=max_epochs, metrics={}) self.should_terminate = self.should_terminate_single_epoch = False + self._sort_handlers() # sort all handlers at the beginning of a run try: - self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) + self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) start_time = time.time() self._fire_event(Events.STARTED) while self.state.epoch < max_epochs and not self.should_terminate: self.state.epoch += 1 self._fire_event(Events.EPOCH_STARTED) hours, mins, secs = self._run_once_on_dataset() - self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) + self.logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) if self.should_terminate: break self._fire_event(Events.EPOCH_COMPLETED) @@ -656,10 +664,10 @@ def switch_batch(engine): self._fire_event(Events.COMPLETED) time_taken = time.time() - start_time hours, mins, secs = _to_hours_mins_secs(time_taken) - self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) + self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) except BaseException as e: - self._logger.error("Engine run is terminating due to exception: %s.", str(e)) + self.logger.error("Engine run is terminating due to exception: %s.", str(e)) self._handle_exception(e) return self.state diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 8bf8178c49be..6c5cdd4eddc2 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -124,12 +124,13 @@ def completed(self, engine, name): result = result.item() engine.state.metrics[name] = result - def attach(self, engine, name): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name) + def attach(self, engine, name, priority=1): + # metric is prioritized higher by default + engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, priority, name) if not engine.has_event_handler(self.started, Events.EPOCH_STARTED): - engine.add_event_handler(Events.EPOCH_STARTED, self.started) + engine.add_event_handler(Events.EPOCH_STARTED, self.started, priority) if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed, priority) def __add__(self, other): from ignite.metrics import MetricsLambda diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 542fc379d58a..30fd5df6fe8b 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1175,3 +1175,22 @@ def test_state_get_event_attrib_value(): assert state.get_event_attrib_value(e) == state.epoch e = Events.EPOCH_COMPLETED(once=5) assert state.get_event_attrib_value(e) == state.epoch + + +def test_handlers_sorted(): + engine = DummyEngine() + + @engine.on(Events.ITERATION_COMPLETED, 2) + def prio2(_): pass + + @engine.on(Events.ITERATION_COMPLETED, 3) + def prio3(_): pass + + @engine.on(Events.ITERATION_COMPLETED) + def prio0(_): pass + + metric = MeanSquaredError() + metric.attach(engine, 'mock') + engine._sort_handlers() + prio_list = list(zip(*engine._event_handlers[Events.ITERATION_COMPLETED]))[0] + assert np.all(np.diff(prio_list) <= 0) From 449724154c2cc66bc8d67c9f8d1133a39ee5c21c Mon Sep 17 00:00:00 2001 From: beyondyoni <48648837+beyondyoni@users.noreply.github.com> Date: Mon, 20 Jan 2020 09:38:54 +0200 Subject: [PATCH 2/5] adding ability to prioritize handlers --- ignite/engine/engine.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index cbc88b472e9e..749e546e5678 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -306,8 +306,8 @@ def compute_mean_std(engine, batch): """ def __init__(self, process_function): self._event_handlers = defaultdict(list) - self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) - self.logger.addHandler(logging.NullHandler()) + self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) + self._logger.addHandler(logging.NullHandler()) self._process_function = process_function self.last_event_name = None self.should_terminate = False @@ -435,14 +435,14 @@ def print_epoch(engine): handler = self._handler_wrapper(handler, event_name, event_filter) if event_name not in self._allowed_events: - self.logger.error("attempt to add event handler to an invalid event %s.", event_name) + self._logger.error("attempt to add event handler to an invalid event %s.", event_name) raise ValueError("Event {} is not a valid event for this Engine.".format(event_name)) event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else () Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs) self._event_handlers[event_name].append((priority, handler, args, kwargs)) - self.logger.debug("added handler for event %s.", event_name) + self._logger.debug("added handler for event %s.", event_name) return RemovableEventHandle(event_name, handler, self) @@ -548,7 +548,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs): """ if event_name in self._allowed_events: - self.logger.debug("firing handlers for event %s ", event_name) + self._logger.debug("firing handlers for event %s ", event_name) self.last_event_name = event_name for _, func, args, kwargs in self._event_handlers[event_name]: kwargs.update(event_kwargs) @@ -580,13 +580,13 @@ def fire_event(self, event_name): def terminate(self): """Sends terminate signal to the engine, so that it terminates completely the run after the current iteration. """ - self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") + self._logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True def terminate_epoch(self): """Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration. """ - self.logger.info("Terminate current epoch is signaled. " + self._logger.info("Terminate current epoch is signaled. " "Current epoch iteration will stop after current iteration is finished.") self.should_terminate_single_epoch = True @@ -605,7 +605,7 @@ def _run_once_on_dataset(self): break except BaseException as e: - self.logger.error("Current run is terminating due to exception: %s.", str(e)) + self._logger.error("Current run is terminating due to exception: %s.", str(e)) self._handle_exception(e) time_taken = time.time() - start_time @@ -649,14 +649,14 @@ def switch_batch(engine): self._sort_handlers() # sort all handlers at the beginning of a run try: - self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) + self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) start_time = time.time() self._fire_event(Events.STARTED) while self.state.epoch < max_epochs and not self.should_terminate: self.state.epoch += 1 self._fire_event(Events.EPOCH_STARTED) hours, mins, secs = self._run_once_on_dataset() - self.logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) + self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) if self.should_terminate: break self._fire_event(Events.EPOCH_COMPLETED) @@ -664,10 +664,10 @@ def switch_batch(engine): self._fire_event(Events.COMPLETED) time_taken = time.time() - start_time hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) + self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) except BaseException as e: - self.logger.error("Engine run is terminating due to exception: %s.", str(e)) + self._logger.error("Engine run is terminating due to exception: %s.", str(e)) self._handle_exception(e) return self.state From d4cd80259d634789bd59ea6b0bacefbd192a00fa Mon Sep 17 00:00:00 2001 From: beyondyoni <48648837+beyondyoni@users.noreply.github.com> Date: Mon, 20 Jan 2020 09:40:07 +0200 Subject: [PATCH 3/5] adding ability to prioritize handlers --- ignite/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 749e546e5678..ce93f7765ff9 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -508,7 +508,7 @@ def _check_signature(self, fn, fn_description, *args, **kwargs): raise ValueError("Error adding {} '{}': " "takes parameters {} but will be called with {} " "({}).".format( - fn, fn_description, fn_params, passed_params, exception_msg)) + fn, fn_description, fn_params, passed_params, exception_msg)) def on(self, event_name, priority=0, *args, **kwargs): """Decorator shortcut for add_event_handler. From ede1c48aaac952fa16c310d84ead304e461a9a63 Mon Sep 17 00:00:00 2001 From: beyondyoni <48648837+beyondyoni@users.noreply.github.com> Date: Mon, 20 Jan 2020 09:40:54 +0200 Subject: [PATCH 4/5] adding ability to prioritize handlers --- ignite/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index ce93f7765ff9..9159d5492edd 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -508,7 +508,7 @@ def _check_signature(self, fn, fn_description, *args, **kwargs): raise ValueError("Error adding {} '{}': " "takes parameters {} but will be called with {} " "({}).".format( - fn, fn_description, fn_params, passed_params, exception_msg)) + fn, fn_description, fn_params, passed_params, exception_msg)) def on(self, event_name, priority=0, *args, **kwargs): """Decorator shortcut for add_event_handler. From 7b5a1551d8d11eaba897ce7ba80869527466bf12 Mon Sep 17 00:00:00 2001 From: beyondyoni <48648837+beyondyoni@users.noreply.github.com> Date: Mon, 20 Jan 2020 09:42:05 +0200 Subject: [PATCH 5/5] adding ability to prioritize handlers --- ignite/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 9159d5492edd..88125bd40796 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -508,7 +508,7 @@ def _check_signature(self, fn, fn_description, *args, **kwargs): raise ValueError("Error adding {} '{}': " "takes parameters {} but will be called with {} " "({}).".format( - fn, fn_description, fn_params, passed_params, exception_msg)) + fn, fn_description, fn_params, passed_params, exception_msg)) def on(self, event_name, priority=0, *args, **kwargs): """Decorator shortcut for add_event_handler.