Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions ignite/contrib/handlers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
from torch.optim import Optimizer

from ignite.engine import Engine, Events, State
from ignite.engine import Engine, Events, EventsList, State
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle


Expand Down Expand Up @@ -147,26 +147,34 @@ class BaseLogger(metaclass=ABCMeta):
"""

def attach(
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter]
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList]
) -> RemovableEventHandle:
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.

Args:
engine (Engine): engine object.
log_handler (callable): a logging handler to execute
event_name: event to attach the logging handler to. Valid events are from
:class:`~ignite.engine.events.Events` or any `event_name` added by
:meth:`~ignite.engine.engine.Engine.register_events`.
:class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name`
added by :meth:`~ignite.engine.engine.Engine.register_events`.

Returns:
:class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
"""
name = event_name
if isinstance(event_name, EventsList):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to use return engine.add_event_handler(event_name, log_handler, self, event_name) for any type of events. Below for loop if considered without check if name not in State.event_to_attr is exactly what is done by Engine.add_event_handler.
I think we can simply generalize event checking like

event_to_check = event_name if isinstance(event_name, EventsList) else [event_name, ]
for name in event_to_check:
    if name not in State.event_to_attr:
        raise RuntimeError(f"Unknown event name '{name}'")

return engine.add_event_handler(event_name, log_handler, self, event_name)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That does not work.

The last argument of the method add_event_handler is a list of events. Internally in engine, there is a loop where the handler is attached at every event in the list. But the list is forwarded to each handler while just each event should be.

If there was no argument to the method, it would work.

Thoughts ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right. In this case, my suggestion is not valid. OK, let's do it as you proposed.

for name in event_name:
if name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{name}'")
engine.add_event_handler(name, log_handler, self, name)

return RemovableEventHandle(event_name, log_handler, engine)

else:

if name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{name}'")
if event_name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{event_name}'")

return engine.add_event_handler(event_name, log_handler, self, name)
return engine.add_event_handler(event_name, log_handler, self, event_name)

def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
"""Shortcut method to attach `OutputHandler` to the logger.
Expand Down
4 changes: 3 additions & 1 deletion ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ignite.distributed as idist
from ignite.engine.deterministic import DeterministicEngine
from ignite.engine.engine import Engine
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State
from ignite.metrics import Metric
from ignite.utils import convert_tensor

Expand All @@ -21,8 +21,10 @@
"Engine",
"DeterministicEngine",
"Events",
"EventsList",
"EventEnum",
"CallableEventWithFilter",
"RemovableEventHandle",
]


Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from ignite.engine.engine import Engine

__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"]
__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State", "EventsList", "RemovableEventHandle"]


class CallableEventWithFilter:
Expand Down
19 changes: 16 additions & 3 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call

import pytest
import torch

from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, Events, State
from ignite.engine import Engine, Events, EventsList, State
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer


Expand Down Expand Up @@ -122,7 +122,12 @@ def update_fn(engine, batch):

trainer.run(data, max_epochs=n_epochs)

mock_log_handler.assert_called_with(trainer, logger, event)
if isinstance(event, EventsList):
events = [e for e in event]
else:
events = [event]
calls = [call(trainer, logger, e) for e in events]
mock_log_handler.assert_has_calls(calls)
assert mock_log_handler.call_count == n_calls

_test(Events.ITERATION_STARTED, len(data) * n_epochs)
Expand All @@ -134,6 +139,8 @@ def update_fn(engine, batch):

_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)

_test(Events.STARTED | Events.COMPLETED, 2)


def test_attach_wrong_event_name():

Expand All @@ -144,6 +151,12 @@ def test_attach_wrong_event_name():
with pytest.raises(RuntimeError, match="Unknown event name"):
logger.attach(trainer, log_handler=mock_log_handler, event_name="unknown")

events_list = EventsList()
events_list._events = ["unknown"]

with pytest.raises(RuntimeError, match="Unknown event name"):
logger.attach(trainer, log_handler=mock_log_handler, event_name=events_list)


def test_attach_on_custom_event():
n_epochs = 10
Expand Down