Skip to content

Commit

Permalink
fix issue 918 (#919)
Browse files Browse the repository at this point in the history
* fix issue 918

* add missing test

* fix flake

Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr>
  • Loading branch information
sdesrozis and Desroziers committed Apr 14, 2020
1 parent d62c4e9 commit 10fd602
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
17 changes: 10 additions & 7 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import weakref
import random
import warnings
import functools
from typing import Union, Optional, Callable, Iterable, Iterator, Any, Tuple, List

import torch
Expand Down Expand Up @@ -197,12 +198,14 @@ class TBPTT_Events(EventEnum):
if event_to_attr and e in event_to_attr:
State.event_to_attr[e] = event_to_attr[e]

@staticmethod
def _handler_wrapper(handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
def wrapper(engine: Engine, *args, **kwargs) -> Any:
event = engine.state.get_event_attrib_value(event_name)
if event_filter(engine, event):
return handler(engine, *args, **kwargs)
def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
# signature of the following wrapper will be inspected during registering to check if engine is necessary
# we have to build a wrapper with relevant signature : solution is functools.wrapsgit s
@functools.wraps(handler)
def wrapper(*args, **kwargs) -> Any:
event = self.state.get_event_attrib_value(event_name)
if event_filter(self, event):
return handler(*args, **kwargs)

# setup input handler as parent to make has_event_handler work
wrapper._parent = weakref.ref(handler)
Expand Down Expand Up @@ -261,7 +264,7 @@ def execute_something():
and event_name.filter != CallableEventWithFilter.default_event_filter
):
event_filter = event_name.filter
handler = Engine._handler_wrapper(handler, event_name, event_filter)
handler = self._handler_wrapper(handler, event_name, event_filter)

if event_name not in self._allowed_events:
self.logger.error("attempt to add event handler to an invalid event %s.", event_name)
Expand Down
9 changes: 9 additions & 0 deletions tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,20 @@ def assert_every(engine):
assert counter_every[0] == getattr(engine.state, event_attr)
num_calls[0] += 1

@engine.on(event_name(every=every))
def assert_every_no_engine():
assert getattr(engine.state, event_attr) % every == 0
assert counter_every[0] == getattr(engine.state, event_attr)

@engine.on(event_name)
def assert_(engine):
counter[0] += 1
assert getattr(engine.state, event_attr) == counter[0]

@engine.on(event_name)
def assert_no_engine():
assert getattr(engine.state, event_attr) == counter[0]

engine.run(data, max_epochs=5)

assert num_calls[0] == true_num_calls
Expand Down

0 comments on commit 10fd602

Please sign in to comment.