From a73d60f990255e184f7f2dbf185c4b0d0b7d72ea Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 5 Jan 2021 10:17:45 +0100 Subject: [PATCH 1/6] use events list for loggers --- ignite/contrib/handlers/base_logger.py | 21 +++++++++++++++------ ignite/engine/__init__.py | 4 +++- ignite/engine/events.py | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 44707d668d79..8a368345882a 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch.optim import Optimizer -from ignite.engine import Engine, Events, State +from ignite.engine import Engine, Events, EventsList, State from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle @@ -147,7 +147,7 @@ class BaseLogger(metaclass=ABCMeta): """ def attach( - self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter] + self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList] ) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. @@ -161,12 +161,21 @@ def attach( Returns: :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler. """ - name = event_name + if isinstance(event_name, EventsList): + for name in event_name: + if name not in State.event_to_attr: + raise RuntimeError(f"Unknown event name '{name}'") + engine.add_event_handler(name, log_handler, self, name) - if name not in State.event_to_attr: - raise RuntimeError(f"Unknown event name '{name}'") + return RemovableEventHandle(event_name, log_handler, engine) - return engine.add_event_handler(event_name, log_handler, self, name) + else: + name = event_name + + if name not in State.event_to_attr: + raise RuntimeError(f"Unknown event name '{name}'") + + return engine.add_event_handler(event_name, log_handler, self, name) def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 3fbc83861b33..39aa79ee9fb2 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -6,7 +6,7 @@ import ignite.distributed as idist from ignite.engine.deterministic import DeterministicEngine from ignite.engine.engine import Engine -from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State +from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, State, RemovableEventHandle from ignite.metrics import Metric from ignite.utils import convert_tensor @@ -21,8 +21,10 @@ "Engine", "DeterministicEngine", "Events", + "EventsList", "EventEnum", "CallableEventWithFilter", + "RemovableEventHandle", ] diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 4f07f5765082..aef9c97bb108 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ignite.engine.engine import Engine -__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"] +__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State", "EventsList", "RemovableEventHandle"] class CallableEventWithFilter: From 3ede8289c1e8e5d7c9dfe43bd658f2608990b8ee Mon Sep 17 00:00:00 2001 From: sdesrozis Date: Tue, 5 Jan 2021 09:19:38 +0000 Subject: [PATCH 2/6] autopep8 fix --- ignite/engine/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 39aa79ee9fb2..5330393a2cc8 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -6,7 +6,7 @@ import ignite.distributed as idist from ignite.engine.deterministic import DeterministicEngine from ignite.engine.engine import Engine -from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, State, RemovableEventHandle +from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State from ignite.metrics import Metric from ignite.utils import convert_tensor From ecd9e6a89d06e7b9906250cf84ed547b6c195207 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 5 Jan 2021 10:36:58 +0100 Subject: [PATCH 3/6] add test --- tests/ignite/contrib/handlers/test_base_logger.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/ignite/contrib/handlers/test_base_logger.py b/tests/ignite/contrib/handlers/test_base_logger.py index 3abc162f693b..3f1890289ad7 100644 --- a/tests/ignite/contrib/handlers/test_base_logger.py +++ b/tests/ignite/contrib/handlers/test_base_logger.py @@ -1,11 +1,11 @@ import math -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest import torch from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, Events, State +from ignite.engine import Engine, Events, EventsList, State from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer @@ -122,7 +122,12 @@ def update_fn(engine, batch): trainer.run(data, max_epochs=n_epochs) - mock_log_handler.assert_called_with(trainer, logger, event) + if isinstance(event, EventsList): + events = [e for e in event] + else: + events = [event] + calls = [call(trainer, logger, e) for e in events] + mock_log_handler.assert_has_calls(calls) assert mock_log_handler.call_count == n_calls _test(Events.ITERATION_STARTED, len(data) * n_epochs) @@ -134,6 +139,8 @@ def update_fn(engine, batch): _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs) + _test(Events.STARTED | Events.COMPLETED, 2) + def test_attach_wrong_event_name(): From 957e066059455338a2c572181924014d99d82fa1 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 5 Jan 2021 14:27:58 +0100 Subject: [PATCH 4/6] fix mypy --- ignite/contrib/handlers/base_logger.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 8a368345882a..5c7a57d69de8 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -170,12 +170,11 @@ def attach( return RemovableEventHandle(event_name, log_handler, engine) else: - name = event_name - if name not in State.event_to_attr: - raise RuntimeError(f"Unknown event name '{name}'") + if event_name not in State.event_to_attr: + raise RuntimeError(f"Unknown event name '{event_name}'") - return engine.add_event_handler(event_name, log_handler, self, name) + return engine.add_event_handler(event_name, log_handler, self, event_name) def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. From 1e8bf30972aa9c63b5268b74bb4cdab6a1159b67 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 5 Jan 2021 18:31:56 +0100 Subject: [PATCH 5/6] add test to catch error --- tests/ignite/contrib/handlers/test_base_logger.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/ignite/contrib/handlers/test_base_logger.py b/tests/ignite/contrib/handlers/test_base_logger.py index 3f1890289ad7..4b6b8350dbe7 100644 --- a/tests/ignite/contrib/handlers/test_base_logger.py +++ b/tests/ignite/contrib/handlers/test_base_logger.py @@ -151,6 +151,12 @@ def test_attach_wrong_event_name(): with pytest.raises(RuntimeError, match="Unknown event name"): logger.attach(trainer, log_handler=mock_log_handler, event_name="unknown") + events_list = EventsList() + events_list._events = ["unknown"] + + with pytest.raises(RuntimeError, match="Unknown event name"): + logger.attach(trainer, log_handler=mock_log_handler, event_name=events_list) + def test_attach_on_custom_event(): n_epochs = 10 From 5d51fd5e90d7e8697b0a34f4dda3d0ab98e33766 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 5 Jan 2021 18:32:16 +0100 Subject: [PATCH 6/6] improve docstring --- ignite/contrib/handlers/base_logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 5c7a57d69de8..c3776835449d 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -155,8 +155,8 @@ def attach( engine (Engine): engine object. log_handler (callable): a logging handler to execute event_name: event to attach the logging handler to. Valid events are from - :class:`~ignite.engine.events.Events` or any `event_name` added by - :meth:`~ignite.engine.engine.Engine.register_events`. + :class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name` + added by :meth:`~ignite.engine.engine.Engine.register_events`. Returns: :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.