Skip to content

Commit

Permalink
Issue 874 : add custom events for Metrics (#979)
Browse files Browse the repository at this point in the history
* first attempt to fix issue_874

* autopep8 fix

* add usage

* fix bug

* autopep8 fix

* add missing cls in module

* fix bug : completed must be registered last

* autopep8 fix

* add test

* autopep8 fix

* attempt to move usage from __init__ to attach()

* autopep8 fix

* refactor metrics : split attachment from metric

* autopep8 fix

* remove EngineMetric

* add str in api

* add doc

* autopep8 fix

* improve doc

* improve doc

* fix filtered batch usage

* fix _check_signature for wrappers

* disable _assert_non_filtered_event

* add test for batch filtered usage

* revert follow_wrapped

* fix bug due to decoration and wrappers

* fix tests

* autopep8 fix

* Update test_metric.py

* remove _assert_non_filtered_event

* autopep8 fix

* update doc

* autopep8 fix

* #1004 allows to use torch.no_grad decorator in class

* autopep8 fix

* improve doc

* Update metrics.rst

* Fixed flake8 and improved docs

Co-authored-by: AutoPEP8 <>
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr>
  • Loading branch information
3 people committed May 18, 2020
1 parent 5fa1796 commit 5d26c97
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 59 deletions.
28 changes: 26 additions & 2 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,30 @@ We can check this implementation in a simple case:
print(m._num_correct, m._num_examples, res)
# Out: 1 3 0.3333333333333333
Metrics and its usages
^^^^^^^^^^^^^^^^^^^^^^

By default, `Metrics` are epoch-wise, it means

- `reset()` is triggered every :attr:`~ignite.engine.Events.EPOCH_STARTED`
- `update(output)` is triggered every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`
- `compute()` is triggered every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`

Usages can be user defined by creating a class inheriting for :class:`~ignite.metrics.MetricUsage`. See the list below of usages.

Complete list of usages
```````````````````````

- :class:`~ignite.metrics.MetricUsage`
- :class:`~ignite.metrics.EpochWise`
- :class:`~ignite.metrics.BatchWise`
- :class:`~ignite.metrics.BatchFiltered`

Metrics and distributed computations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the above example, `CustomAccuracy` constructor has `device` argument and `reset`, `update`, `compute` methods are decorated with `reinit__is_reduced`, `sync_all_reduce`. The purpose of these features is to adapt metrics in distributed computations on CUDA devices and assuming the backend to support `"all_reduce" operation <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`_. User can specify the device (by default, `cuda`) at metric's initialization. This device _can_ be used to store internal variables on and to collect all results from all participating devices. More precisely, in the above example we added `@sync_all_reduce("_num_examples", "_num_correct")` over `compute` method. This means that when `compute` method is called, metric's interal variables `self._num_examples` and `self._num_correct` are summed up over all participating devices. Therefore, once collected, these internal variables can be used to compute the final metric value.


Complete list of metrics
------------------------

Expand All @@ -210,7 +227,6 @@ Complete list of metrics
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
- :class:`~ignite.metrics.VariableAccumulation`


.. currentmodule:: ignite.metrics

.. autoclass:: Accuracy
Expand Down Expand Up @@ -255,3 +271,11 @@ Complete list of metrics
.. autoclass:: TopKCategoricalAccuracy

.. autoclass:: VariableAccumulation

.. autoclass:: MetricUsage

.. autoclass:: EpochWise

.. autoclass:: BatchWise

.. autoclass:: BatchFiltered
5 changes: 1 addition & 4 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class TBPTT_Events(EventEnum):

def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
# signature of the following wrapper will be inspected during registering to check if engine is necessary
# we have to build a wrapper with relevant signature : solution is functools.wrapsgit s
# we have to build a wrapper with relevant signature : solution is functools.wraps
@functools.wraps(handler)
def wrapper(*args, **kwargs) -> Any:
event = self.state.get_event_attrib_value(event_name)
Expand Down Expand Up @@ -301,8 +301,6 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
to ``None`` to search all events.
"""
if event_name is not None:
self._assert_non_filtered_event(event_name)

if event_name not in self._event_handlers:
return False
events = [event_name]
Expand All @@ -328,7 +326,6 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
event_name: The event the handler attached to.
"""
self._assert_non_filtered_event(event_name)
if event_name not in self._event_handlers:
raise ValueError("Input event name '{}' does not exist".format(event_name))

Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import Metric
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
Expand Down
211 changes: 186 additions & 25 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,103 @@

from ignite.engine import Engine, Events

__all__ = ["Metric"]
__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]


class MetricUsage:
"""
Base class for all usages of metrics.
A usage of metric defines the events when a metric starts to compute, updates and completes.
Valid events are from :class:`~ignite.engine.Events`.
Args:
started: event when the metric starts to compute. This event will be associated to
:meth:`~ignite.metrics.Metric.started`.
completed: event when the metric completes. This event will be associated to
:meth:`~ignite.metrics.Metric.completed`.
iteration_completed: event when the metric updates. This event will be associated to
:meth:`~ignite.metrics.Metric.iteration_completed`.
"""

def __init__(self, started, completed, iteration_completed):
self.__started = started
self.__completed = completed
self.__iteration_completed = iteration_completed

@property
def STARTED(self):
return self.__started

@property
def COMPLETED(self):
return self.__completed

@property
def ITERATION_COMPLETED(self):
return self.__iteration_completed


class EpochWise(MetricUsage):
"""
Epoch-wise usage of Metrics. It's the default and most common usage of metrics.
Metric's methods are triggered on the following engine events:
- :meth:`~ignite.metrics.Metric.started` on every :attr:`~ignite.engine.Events.EPOCH_STARTED`.
- :meth:`~ignite.metrics.Metric.iteration_completed` on every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- :meth:`~ignite.metrics.Metric.completed` on every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`.
"""

def __init__(self):
super(EpochWise, self).__init__(
started=Events.EPOCH_STARTED,
completed=Events.EPOCH_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchWise(MetricUsage):
"""
Batch-wise usage of Metrics.
Metric's methods are triggered on the following engine events:
- :meth:`~ignite.metrics.Metric.started` on every :attr:`~ignite.engine.Events.ITERATION_STARTED`.
- :meth:`~ignite.metrics.Metric.iteration_completed` on every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- :meth:`~ignite.metrics.Metric.completed` on every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
"""

def __init__(self):
super(BatchWise, self).__init__(
started=Events.ITERATION_STARTED,
completed=Events.ITERATION_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchFiltered(MetricUsage):
"""
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
Metric's methods are triggered on the following engine events:
- :meth:`~ignite.metrics.Metric.started` on every :attr:`~ignite.engine.Events.EPOCH_STARTED`.
- :meth:`~ignite.metrics.Metric.iteration_completed` on filtered :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- :meth:`~ignite.metrics.Metric.completed` on every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`.
Args:
args (sequence): arguments for the setup of :attr:`~ignite.engine.Events.ITERATION_COMPLETED` handled by
:meth:`~ignite.metrics.Metric.iteration_completed`.
"""

def __init__(self, *args, **kwargs):
super(BatchFiltered, self).__init__(
started=Events.EPOCH_STARTED,
completed=Events.EPOCH_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED(*args, **kwargs),
)


class Metric(metaclass=ABCMeta):
Expand All @@ -23,7 +119,7 @@ class Metric(metaclass=ABCMeta):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
device (str of torch.device, optional): device specification in case of distributed computation usage.
device (str or torch.device, optional): device specification in case of distributed computation usage.
In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
initialized and available, device is set to `cuda`.
Expand All @@ -32,7 +128,9 @@ class Metric(metaclass=ABCMeta):

_required_output_keys = ("y_pred", "y")

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None,
):
self._output_transform = output_transform

# Check device if distributed is initialized:
Expand All @@ -57,7 +155,7 @@ def reset(self) -> None:
"""
Resets the metric to it's initial state.
This is called at the start of each epoch.
By default, this is called at the start of each epoch.
"""
pass

Expand All @@ -66,7 +164,7 @@ def update(self, output) -> None:
"""
Updates the metric's state using the passed batch output.
This is called once for each batch.
By default, this is called once for each batch.
Args:
output: the is the output from the engine's process function.
Expand All @@ -78,7 +176,7 @@ def compute(self) -> Any:
"""
Computes the metric based on it's accumulated state.
This is called at the end of each epoch.
By default, this is called at the end of each epoch.
Returns:
Any: the actual quantity of interest. However, if a :class:`~collections.abc.Mapping` is returned,
Expand Down Expand Up @@ -116,10 +214,23 @@ def _sync_all_reduce(self, tensor: Union[torch.Tensor, numbers.Number]) -> Union
return tensor

def started(self, engine: Engine) -> None:
"""Helper method to start data gathering for metric's computation. It is automatically attached to the
`engine` with :meth:`~ignite.metrics.Metric.attach`.
Args:
engine (Engine): the engine to which the metric must be attached
"""
self.reset()

@torch.no_grad()
def iteration_completed(self, engine: Engine) -> None:
"""Helper method to update metric's computation. It is automatically attached to the
`engine` with :meth:`~ignite.metrics.Metric.attach`.
Args:
engine (Engine): the engine to which the metric must be attached
"""

output = self._output_transform(engine.state.output)
if isinstance(output, Mapping):
if self._required_output_keys is None:
Expand All @@ -137,6 +248,12 @@ def iteration_completed(self, engine: Engine) -> None:
self.update(output)

def completed(self, engine: Engine, name: str) -> None:
"""Helper method to compute metric's value and put into the engine. It is automatically attached to the
`engine` with :meth:`~ignite.metrics.Metric.attach`.
Args:
engine (Engine): the engine to which the metric must be attached
"""
result = self.compute()
if isinstance(result, Mapping):
for key, value in result.items():
Expand All @@ -147,14 +264,28 @@ def completed(self, engine: Engine, name: str) -> None:

engine.state.metrics[name] = result

def attach(self, engine: Engine, name: str) -> None:
def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
if isinstance(usage, str):
if usage == "epoch_wise":
usage = EpochWise()
elif usage == "batch_wise":
usage = BatchWise()
else:
raise ValueError("usage should be 'epoch_wise' or 'batch_wise', get {}".format(usage))
if not isinstance(usage, MetricUsage):
raise TypeError("Unhandled usage type {}".format(type(usage)))
return usage

def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = EpochWise()) -> None:
"""
Attaches current metric to provided engine. On the end of engine's run,
`engine.state.metrics` dictionary will contain computed metric's value under provided name.
Attaches current metric to provided engine. On the end of engine's run, `engine.state.metrics` dictionary will
contain computed metric's value under provided name.
Args:
engine (Engine): the engine to which the metric must be attached
name (str): the name of the metric to attach
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.
Example:
Expand All @@ -166,14 +297,26 @@ def attach(self, engine: Engine, name: str) -> None:
assert "mymetric" in engine.run(data).metrics
assert metric.is_attached(engine)
"""
engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
def detach(self, engine: Engine) -> None:
Example with usage:
.. code-block:: python
metric = ...
metric.attach(engine, "mymetric", usage="batch_wise")
assert "mymetric" in engine.run(data).metrics
assert metric.is_attached(engine, usage="batch_wise")
"""
usage = self._check_usage(usage)
if not engine.has_event_handler(self.started, usage.STARTED):
engine.add_event_handler(usage.STARTED, self.started)
if not engine.has_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED):
engine.add_event_handler(usage.ITERATION_COMPLETED, self.iteration_completed)
engine.add_event_handler(usage.COMPLETED, self.completed, name)

def detach(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> None:
"""
Detaches current metric from the engine and no metric's computation is done during the run.
This method in conjunction with :meth:`~ignite.metrics.Metric.attach` can be useful if several
Expand All @@ -182,6 +325,8 @@ def detach(self, engine: Engine) -> None:
Args:
engine (Engine): the engine from which the metric must be detached
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.
Example:
Expand All @@ -194,23 +339,39 @@ def detach(self, engine: Engine) -> None:
assert "mymetric" not in engine.run(data).metrics
assert not metric.is_attached(engine)
Example with usage:
.. code-block:: python
metric = ...
engine = ...
metric.detach(engine, usage="batch_wise")
assert "mymetric" not in engine.run(data).metrics
assert not metric.is_attached(engine, usage="batch_wise")
"""
if engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED):
engine.remove_event_handler(self.completed, Events.EPOCH_COMPLETED)
if engine.has_event_handler(self.started, Events.EPOCH_STARTED):
engine.remove_event_handler(self.started, Events.EPOCH_STARTED)
if engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.remove_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED)

def is_attached(self, engine: Engine) -> bool:
usage = self._check_usage(usage)
if engine.has_event_handler(self.completed, usage.COMPLETED):
engine.remove_event_handler(self.completed, usage.COMPLETED)
if engine.has_event_handler(self.started, usage.STARTED):
engine.remove_event_handler(self.started, usage.STARTED)
if engine.has_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED):
engine.remove_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED)

def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> bool:
"""
Checks if current metric is attached to provided engine. If attached, metric's computed
value is written to `engine.state.metrics` dictionary.
Args:
engine (Engine): the engine checked from which the metric should be attached
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.
"""
return engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED)
usage = self._check_usage(usage)
return engine.has_event_handler(self.completed, usage.COMPLETED)

def __add__(self, other):
from ignite.metrics.metrics_lambda import MetricsLambda
Expand Down
Loading

0 comments on commit 5d26c97

Please sign in to comment.