-
-
Notifications
You must be signed in to change notification settings - Fork 655
Closed
Labels
Description
🚀 Feature
As discussed with @vfdev-5 in #309, it would sometimes be useful to have a handler that stores the full history of output predictions, e.g. for visualization purposes. Below is my first attempt at implementing it.
import torch
from ignite.engine import Events
class EpochOutputStore(object):
    """EpochOutputStore handler to save output prediction and target history
    after every epoch, could be useful for e.g., visualization purposes.

    The store is cleared when the attached engine starts an epoch and receives
    one ``(y_pred, y)`` pair per completed iteration, so after an epoch it
    holds the outputs of that epoch only.

    Note:
        This can potentially lead to a memory error if the output data is
        larger than available RAM. Stored tensors are detached from the
        autograd graph so that computation graphs are not kept alive, but
        they are not moved off the device they were produced on.

    Args:
        output_transform (callable, optional): a callable that is used to
            transform the :class:`~ignite.engine.engine.Engine`'s
            ``process_function``'s output into the form ``(y_pred, y)``, e.g.,
            ``lambda output: (output[0], output[1])``. Defaults to the
            identity, which suits engines whose output already is
            ``(y_pred, y)``.

    Examples:
        .. code-block:: python

            import ...

            eos = EpochOutputStore()
            trainer = create_supervised_trainer(model, optimizer, loss)
            train_evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})
            eos.attach(train_evaluator)

            @trainer.on(Events.EPOCH_COMPLETED)
            def log_training_results(engine):
                train_evaluator.run(train_loader)
                y_pred, y = eos.get_output()
                # plottings
    """

    def __init__(self, output_transform=lambda x: x):
        # Both lists stay ``None`` until the first ``reset`` (or ``update``).
        self.predictions = None
        self.targets = None
        self.output_transform = output_transform

    def reset(self, engine=None):
        """Clear the stored prediction/target history.

        Args:
            engine: ignored. Accepted so the method can be registered as an
                event handler whether or not the engine passes itself to
                handlers.
        """
        self.predictions = []
        self.targets = []

    def update(self, engine):
        """Append the current iteration's ``(y_pred, y)`` to the history.

        Args:
            engine: the engine whose ``state.output`` is passed through
                ``output_transform``.
        """
        # Tolerate being called before any reset (e.g. attached mid-epoch or
        # driven manually) instead of failing with ``None.append``.
        if self.predictions is None or self.targets is None:
            self.reset()
        y_pred, y = self.output_transform(engine.state.output)
        # Detach so stored outputs do not keep autograd graphs (and all their
        # intermediate activations) alive; the history is only for inspection.
        self.predictions.append(self._detach(y_pred))
        self.targets.append(self._detach(y))

    @staticmethod
    def _detach(value):
        """Return ``value`` detached from autograd if it is a tensor, else unchanged."""
        return value.detach() if isinstance(value, torch.Tensor) else value

    def attach(self, engine):
        """Register the store on ``engine``.

        The history is reset on ``EPOCH_STARTED`` and extended on every
        ``ITERATION_COMPLETED``.
        """
        engine.add_event_handler(Events.EPOCH_STARTED, self.reset)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.update)

    def get_output(self, to_numpy=False):
        """Return the stored history concatenated along the first dimension.

        Args:
            to_numpy (bool): if True, move the results to CPU and convert them
                to numpy arrays.

        Returns:
            tuple: ``(predictions, targets)`` as tensors, or as numpy arrays
            when ``to_numpy`` is True.

        Raises:
            RuntimeError: if nothing has been stored yet.
        """
        if not self.predictions or not self.targets:
            raise RuntimeError(
                "EpochOutputStore has no stored output; attach it to an engine "
                "and run at least one iteration before calling get_output()."
            )
        prediction_tensor = torch.cat(self.predictions, dim=0)
        target_tensor = torch.cat(self.targets, dim=0)
        if to_numpy:
            # Detach before moving to CPU so the copy carries no graph metadata.
            prediction_tensor = prediction_tensor.detach().cpu().numpy()
            target_tensor = target_tensor.detach().cpu().numpy()
        return prediction_tensor, target_tensor