Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 874 : add custom events for Metrics #979

Merged
merged 59 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
17d1d65
first attempt to fix issue_874
sdesrozis Apr 25, 2020
a871d9a
autopep8 fix
Apr 25, 2020
3c41699
add usage
sdesrozis Apr 25, 2020
406a0e6
fix bug
sdesrozis Apr 25, 2020
e52b50e
autopep8 fix
Apr 25, 2020
db31275
add missing cls in module
sdesrozis Apr 25, 2020
8b119da
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
sdesrozis Apr 25, 2020
c5b9f9a
Merge branch 'master' into issue_874
sdesrozis Apr 25, 2020
746473a
fix bug : completed must be registered last
sdesrozis Apr 25, 2020
1a6b5ab
autopep8 fix
Apr 25, 2020
a1305ea
add test
sdesrozis Apr 26, 2020
f2e1377
autopep8 fix
Apr 26, 2020
d5de866
attempt to move usage from __init__ to attach()
sdesrozis Apr 26, 2020
7825135
autopep8 fix
Apr 26, 2020
21ef18b
refactor metrics : split attachment from metric
sdesrozis Apr 26, 2020
c9e05cd
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
sdesrozis Apr 26, 2020
f9a9500
autopep8 fix
Apr 26, 2020
24080cf
Merge branch 'master' into issue_874
sdesrozis Apr 26, 2020
7deb841
remove EngineMetric
sdesrozis Apr 26, 2020
6ed830c
add str in api
sdesrozis Apr 26, 2020
574fc23
add doc
sdesrozis Apr 26, 2020
4b937c6
autopep8 fix
Apr 26, 2020
51ace55
improve doc
sdesrozis Apr 26, 2020
5381432
improve doc
sdesrozis Apr 26, 2020
a0a7085
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
sdesrozis Apr 26, 2020
1fc018c
Merge branch 'master' into issue_874
vfdev-5 Apr 26, 2020
c13a1ed
Merge branch 'master' into issue_874
sdesrozis Apr 26, 2020
9b0b998
Merge branch 'master' into issue_874
sdesrozis Apr 26, 2020
8271ab7
Merge branch 'master' into issue_874
sdesrozis Apr 30, 2020
2019ea6
Merge branch 'master' into issue_874
sdesrozis May 1, 2020
83e7dc4
Merge branch 'master' into issue_874
sdesrozis May 1, 2020
341f886
fix filtered batch usage
sdesrozis May 1, 2020
1f3dbd3
fix _check_signature for wrappers
sdesrozis May 1, 2020
ec4f472
disable _assert_non_filtered_event
sdesrozis May 1, 2020
e111bc8
add test for batch filtered usage
sdesrozis May 1, 2020
9bfe2b8
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
sdesrozis May 1, 2020
d709f65
revert follow_wrapped
sdesrozis May 1, 2020
461061b
fix bug due to decoration and wrappers
sdesrozis May 1, 2020
2cf1717
fix tests
sdesrozis May 1, 2020
68854fc
autopep8 fix
May 1, 2020
febed1c
Update test_metric.py
sdesrozis May 2, 2020
779c8ac
remove _assert_non_filtered_event
sdesrozis May 2, 2020
df929a2
Merge branch 'master' into issue_874
sdesrozis May 5, 2020
f84dc55
Merge branch 'master' into issue_874
sdesrozis May 9, 2020
094974a
Merge branch 'master' into issue_874
vfdev-5 May 11, 2020
27e7f2f
autopep8 fix
May 11, 2020
ee32012
Merge branch 'master' into issue_874
vfdev-5 May 12, 2020
211e7c8
Merge branch 'master' into issue_874
vfdev-5 May 13, 2020
7c7a60d
Merge branch 'master' into issue_874
sdesrozis May 16, 2020
f7670c0
update doc
sdesrozis May 16, 2020
965765b
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
sdesrozis May 16, 2020
77f0eea
autopep8 fix
May 16, 2020
fd2e6d1
Merge branch 'master' into issue_874
sdesrozis May 17, 2020
015839c
#1004 allows to use torch.no_grad decorator in class
May 18, 2020
330b614
autopep8 fix
May 18, 2020
3515d6d
improve doc
May 18, 2020
b9b71af
Merge branch 'issue_874' of github.com:sdesrozis/ignite into issue_874
May 18, 2020
1f4df05
Update metrics.rst
vfdev-5 May 18, 2020
31d4843
Fixed flake8 and improved docs
vfdev-5 May 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 26 additions & 2 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,30 @@ We can check this implementation in a simple case:
print(m._num_correct, m._num_examples, res)
# Out: 1 3 0.3333333333333333

Metrics and its usages
^^^^^^^^^^^^^^^^^^^^^^

By default, `Metrics` are epoch-wise, it means

- `reset()` is triggered every :attr:`~ignite.engine.Events.EPOCH_STARTED`
- `update(output)` is triggered every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`
- `compute()` is triggered every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`

Usages can be user defined by creating a class inheriting for :class:`~ignite.metrics.MetricUsage`. See the list below of usages.

Complete list of usages
```````````````````````

- :class:`~ignite.metrics.MetricUsage`
- :class:`~ignite.metrics.EpochWise`
- :class:`~ignite.metrics.BatchWise`
- :class:`~ignite.metrics.BatchFiltered`

Metrics and distributed computations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the above example, `CustomAccuracy` constructor has `device` argument and `reset`, `update`, `compute` methods are decorated with `reinit__is_reduced`, `sync_all_reduce`. The purpose of these features is to adapt metrics in distributed computations on CUDA devices and assuming the backend to support `"all_reduce" operation <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`_. User can specify the device (by default, `cuda`) at metric's initialization. This device _can_ be used to store internal variables on and to collect all results from all participating devices. More precisely, in the above example we added `@sync_all_reduce("_num_examples", "_num_correct")` over `compute` method. This means that when `compute` method is called, metric's interal variables `self._num_examples` and `self._num_correct` are summed up over all participating devices. Therefore, once collected, these internal variables can be used to compute the final metric value.


Complete list of metrics
------------------------

Expand All @@ -210,7 +227,6 @@ Complete list of metrics
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
- :class:`~ignite.metrics.VariableAccumulation`


.. currentmodule:: ignite.metrics

.. autoclass:: Accuracy
Expand Down Expand Up @@ -255,3 +271,11 @@ Complete list of metrics
.. autoclass:: TopKCategoricalAccuracy

.. autoclass:: VariableAccumulation

.. autoclass:: MetricUsage

.. autoclass:: EpochWise

.. autoclass:: BatchWise

.. autoclass:: BatchFiltered
5 changes: 1 addition & 4 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class TBPTT_Events(EventEnum):

def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
# signature of the following wrapper will be inspected during registering to check if engine is necessary
# we have to build a wrapper with relevant signature : solution is functools.wrapsgit s
# we have to build a wrapper with relevant signature : solution is functools.wraps
@functools.wraps(handler)
def wrapper(*args, **kwargs) -> Any:
event = self.state.get_event_attrib_value(event_name)
Expand Down Expand Up @@ -301,8 +301,6 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
to ``None`` to search all events.
"""
if event_name is not None:
self._assert_non_filtered_event(event_name)

if event_name not in self._event_handlers:
return False
events = [event_name]
Expand All @@ -328,7 +326,6 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
event_name: The event the handler attached to.

"""
self._assert_non_filtered_event(event_name)
if event_name not in self._event_handlers:
raise ValueError("Input event name '{}' does not exist".format(event_name))

Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import Metric
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
Expand Down
179 changes: 157 additions & 22 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,96 @@

from ignite.engine import Engine, Events

__all__ = ["Metric"]
__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]


class MetricUsage:
"""
Base class for all usages of Metrics.

A usage of metric defines the events when a metric starts to compute, updates and completes.
Valid events are from :class:`~ignite.engine.Events`.

Args:
started: event when the metric starts to compute. This event will be associated to
:func:`~ignite.metrics.Metric.started`.
completed: event when the metric completes. This event will be associated to
:func:`~ignite.metrics.Metric.completed`.
iteration_completed: event when the metric updates. This event will be associated to
:func:`~ignite.metrics.Metric.iteration_completed`.
"""

def __init__(self, started, completed, iteration_completed):
self.__started = started
self.__completed = completed
self.__iteration_completed = iteration_completed

@property
def STARTED(self):
return self.__started

@property
def COMPLETED(self):
return self.__completed

@property
def ITERATION_COMPLETED(self):
return self.__iteration_completed


class EpochWise(MetricUsage):
"""
Epoch-wise usage of Metrics. It's the default and most common usage of metrics.

- The method :func:`~ignite.metrics.Metric.started` is triggered every :attr:`~ignite.engine.Events.EPOCH_STARTED`.
- The method :func:`~ignite.metrics.Metric.iteration_completed` is triggered every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- The method :func:`~ignite.metrics.Metric.completed` is triggered every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`.
"""

def __init__(self):
super(EpochWise, self).__init__(
started=Events.EPOCH_STARTED,
completed=Events.EPOCH_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchWise(MetricUsage):
"""
Batch-wise usage of Metrics.

- The method :func:`~ignite.metrics.Metric.started` is triggered every :attr:`~ignite.engine.Events.ITERATION_STARTED`.
- The method :func:`~ignite.metrics.Metric.iteration_completed` is triggered every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- The method :func:`~ignite.metrics.Metric.completed` is triggered every :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
"""

def __init__(self):
super(BatchWise, self).__init__(
started=Events.ITERATION_STARTED,
completed=Events.ITERATION_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchFiltered(MetricUsage):
"""
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.

- The method :func:`~ignite.metrics.Metric.started` is triggered every :attr:`~ignite.engine.Events.EPOCH_STARTED`.
- The method :func:`~ignite.metrics.Metric.iteration_completed` is triggered every filtered :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.
- The method :func:`~ignite.metrics.Metric.completed` is triggered every :attr:`~ignite.engine.Events.EPOCH_COMPLETED`.

Args:
args (sequence): arguments for the setup of :attr:`~ignite.engine.Events.ITERATION_COMPLETED`.

"""

def __init__(self, *args, **kwargs):
super(BatchFiltered, self).__init__(
started=Events.EPOCH_STARTED,
completed=Events.EPOCH_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED(*args, **kwargs),
)


class Metric(metaclass=ABCMeta):
Expand All @@ -23,7 +112,7 @@ class Metric(metaclass=ABCMeta):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
device (str of torch.device, optional): device specification in case of distributed computation usage.
device (str or torch.device, optional): device specification in case of distributed computation usage.
In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
initialized and available, device is set to `cuda`.
Expand All @@ -32,7 +121,9 @@ class Metric(metaclass=ABCMeta):

_required_output_keys = ("y_pred", "y")

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None,
):
self._output_transform = output_transform

# Check device if distributed is initialized:
Expand Down Expand Up @@ -147,14 +238,28 @@ def completed(self, engine: Engine, name: str) -> None:

engine.state.metrics[name] = result

def attach(self, engine: Engine, name: str) -> None:
def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
if isinstance(usage, str):
if usage == "epoch_wise":
usage = EpochWise()
elif usage == "batch_wise":
usage = BatchWise()
else:
raise ValueError("usage should be 'epoch_wise' or 'batch_wise', get {}".format(usage))
if not isinstance(usage, MetricUsage):
raise TypeError("Unhandled usage type {}".format(type(usage)))
return usage

def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = EpochWise()) -> None:
"""
Attaches current metric to provided engine. On the end of engine's run,
`engine.state.metrics` dictionary will contain computed metric's value under provided name.
Attaches current metric to provided engine. On the end of engine's run, `engine.state.metrics` dictionary will
contain computed metric's value under provided name.

Args:
engine (Engine): the engine to which the metric must be attached
name (str): the name of the metric to attach
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.

Example:

Expand All @@ -166,14 +271,26 @@ def attach(self, engine: Engine, name: str) -> None:
assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine)
"""
engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)

def detach(self, engine: Engine) -> None:
Example with usage:

.. code-block:: python

metric = ...
metric.attach(engine, "mymetric", usage="batch_wise")

assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine, usage="batch_wise")
"""
usage = self._check_usage(usage)
if not engine.has_event_handler(self.started, usage.STARTED):
engine.add_event_handler(usage.STARTED, self.started)
if not engine.has_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED):
engine.add_event_handler(usage.ITERATION_COMPLETED, self.iteration_completed)
engine.add_event_handler(usage.COMPLETED, self.completed, name)

def detach(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> None:
"""
Detaches current metric from the engine and no metric's computation is done during the run.
This method in conjunction with :meth:`~ignite.metrics.Metric.attach` can be useful if several
Expand All @@ -182,6 +299,8 @@ def detach(self, engine: Engine) -> None:

Args:
engine (Engine): the engine from which the metric must be detached
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.

Example:

Expand All @@ -194,23 +313,39 @@ def detach(self, engine: Engine) -> None:
assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine)

Example with usage:

.. code-block:: python

metric = ...
engine = ...
metric.detach(engine, usage="batch_wise")

assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine, usage="batch_wise")
"""
if engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED):
engine.remove_event_handler(self.completed, Events.EPOCH_COMPLETED)
if engine.has_event_handler(self.started, Events.EPOCH_STARTED):
engine.remove_event_handler(self.started, Events.EPOCH_STARTED)
if engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.remove_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED)

def is_attached(self, engine: Engine) -> bool:
usage = self._check_usage(usage)
if engine.has_event_handler(self.completed, usage.COMPLETED):
engine.remove_event_handler(self.completed, usage.COMPLETED)
if engine.has_event_handler(self.started, usage.STARTED):
engine.remove_event_handler(self.started, usage.STARTED)
if engine.has_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED):
engine.remove_event_handler(self.iteration_completed, usage.ITERATION_COMPLETED)

def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> bool:
"""
Checks if current metric is attached to provided engine. If attached, metric's computed
value is written to `engine.state.metrics` dictionary.

Args:
engine (Engine): the engine checked from which the metric should be attached
usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
'epoch_wise' (default) or 'batch_wise'.
"""
return engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED)
usage = self._check_usage(usage)
return engine.has_event_handler(self.completed, usage.COMPLETED)

def __add__(self, other):
from ignite.metrics.metrics_lambda import MetricsLambda
Expand Down