Skip to content

Commit

Permalink
give an option to store EpochOutputStore data on engine.state (#1974
Browse files Browse the repository at this point in the history
)

* fix typo

* store epoch_output_store.data on engine.state

* update docstring

* update description
  • Loading branch information
radekosmulski committed May 1, 2021
1 parent 1305177 commit e8418eb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
23 changes: 18 additions & 5 deletions ignite/contrib/handlers/stores.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

from ignite.engine import Engine, Events

Expand All @@ -21,15 +21,17 @@ class EpochOutputStore:
eos = EpochOutputStore()
trainer = create_supervised_trainer(model, optimizer, loss)
train_evaluator = create_supervised_evaluator(model, metrics)
eos.attach(train_evaluator)
eos.attach(train_evaluator, 'output')
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
train_evaluator.run(train_loader)
output = eos.data
output = train_evaluator.output
# do something with output, e.g., plotting
.. versionadded:: 0.4.2
.. versionchanged:: 0.5.0
`attach` now accepts an optional argument `name`
"""

def __init__(self, output_transform: Callable = lambda x: x):
Expand All @@ -45,8 +47,19 @@ def update(self, engine: Engine) -> None:
output = self.output_transform(engine.state.output)
self.data.append(output)

def attach(self, engine: Engine) -> None:
def store(self, engine: Engine) -> None:
"""Store `self.data` on `engine.state.{self.name}`"""
setattr(engine.state, self.name, self.data)

def attach(self, engine: Engine, name: Optional[str] = None) -> None:
"""Attaching `reset` method at EPOCH_STARTED and
`update` method at ITERATION_COMPLETED."""
`update` method at ITERATION_COMPLETED.
If `name` is passed, will store `self.data` on `engine.state`
under `name`.
"""
engine.add_event_handler(Events.EPOCH_STARTED, self.reset)
engine.add_event_handler(Events.ITERATION_COMPLETED, self.update)
if name:
self.name = name
engine.add_event_handler(Events.EPOCH_COMPLETED, self.store)
8 changes: 7 additions & 1 deletion tests/ignite/contrib/handlers/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_no_transform(dummy_evaluator, eos):
assert eos.data == [(1, 0)]


def test_tranform(dummy_evaluator):
def test_transform(dummy_evaluator):
eos = EpochOutputStore(output_transform=lambda x: x[0])
eos.attach(dummy_evaluator)

Expand Down Expand Up @@ -56,3 +56,9 @@ def test_attatch(dummy_evaluator, eos):
eos.attach(dummy_evaluator)
assert dummy_evaluator.has_event_handler(eos.reset, Events.EPOCH_STARTED)
assert dummy_evaluator.has_event_handler(eos.update, Events.ITERATION_COMPLETED)


def test_store_data(dummy_evaluator, eos):
eos.attach(dummy_evaluator, name="eval_data")
dummy_evaluator.run(range(1))
assert dummy_evaluator.state.eval_data == eos.data

0 comments on commit e8418eb

Please sign in to comment.