EpochOutputStorage #1224

@ZhiliangWu

Description

🚀 Feature

As discussed with @vfdev-5 in #309, it could sometimes be useful to provide a handler that stores the full output prediction history, e.g., for visualization purposes. Below is my first attempt at implementing it; a rough usage sketch follows the code.

import torch
from ignite.engine import Events

class EpochOutputStore(object):
    """EpochOutputStore handler to save output prediction and target history
    after every epoch, could be useful for e.g., visualization purposes.

    Note:
        This can potentially lead to a memory error if the output data is
    larger than available RAM.

    Args:
         output_transform (callable, optional): a callable that is used to
         transform the :class:`~ignite.engine.engine.Engine`'s
         ``process_function``'s output into the form `y_pred, y`, e.g.,
         lambda x, y, y_pred: y_pred, y

    Examples:
    .. code-block:: python
        import ...

        eos = EpochOutputStore()
        trainer = create_supervised_trainer(model, optimizer, loss)
        train_evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})
        eos.attach(train_evaluator)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            train_evaluator.run(train_loader)
            y_pred, y = eos.get_output()
            # plottings

    """
    def __init__(self, output_transform=lambda x: x):
        self.predictions = None
        self.targets = None
        self.output_transform = output_transform

    def reset(self):
        self.predictions = []
        self.targets = []

    def update(self, engine):
        y_pred, y = self.output_transform(engine.state.output)
        self.predictions.append(y_pred)
        self.targets.append(y)

    def attach(self, engine):
        # Reset the stores at the start of every epoch and collect the
        # transformed output after every iteration.
        engine.add_event_handler(Events.EPOCH_STARTED, self.reset)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.update)

    def get_output(self, to_numpy=False):
        # Concatenate the per-iteration batches into single tensors.
        prediction_tensor = torch.cat(self.predictions, dim=0)
        target_tensor = torch.cat(self.targets, dim=0)

        if to_numpy:
            prediction_tensor = prediction_tensor.detach().cpu().numpy()
            target_tensor = target_tensor.detach().cpu().numpy()

        return prediction_tensor, target_tensor
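
For reference, here is a rough usage sketch (not part of the handler itself) showing how the stored epoch history could be pulled out as NumPy arrays for plotting after each evaluation run. The model, optimizer, loss and train_loader names, as well as the matplotlib histogram, are placeholders for whatever the user has set up.

import matplotlib.pyplot as plt
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy

# Assumes `model`, `optimizer`, `loss` and `train_loader` are defined elsewhere.
trainer = create_supervised_trainer(model, optimizer, loss)
train_evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})

eos = EpochOutputStore()
eos.attach(train_evaluator)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    train_evaluator.run(train_loader)
    # Pull the full epoch history as NumPy arrays for plotting.
    y_pred, y = eos.get_output(to_numpy=True)
    plt.hist(y_pred.ravel(), bins=50)  # e.g. inspect the prediction distribution
    plt.savefig("predictions_epoch_{}.png".format(engine.state.epoch))
    plt.close()

trainer.run(train_loader, max_epochs=5)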
