From 57fcab9a222927fb947781ff81f01d3222f16c35 Mon Sep 17 00:00:00 2001 From: isolet Date: Thu, 23 Apr 2020 10:07:31 +0800 Subject: [PATCH 1/4] Add support for nested metric values. (#959) --- ignite/metrics/metric.py | 28 ++++++++-- tests/ignite/metrics/test_metric.py | 87 +++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 0cbab97aaffa..e691718011d8 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -4,7 +4,7 @@ from collections.abc import Mapping import warnings -from typing import Callable, Union, Optional, Any +from typing import Any, Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -82,7 +82,9 @@ def compute(self) -> Any: This is called at the end of each epoch. Returns: - Any: the actual quantity of interest. + Any: the actual quantity of interest. However, if a :class:`dict` is returned, it should contain + :class:`numbers.Number` or :class:`torch.tensor` values and will flattened into + `engine.state.metrics` when :func:`~ignite.metrics.Metric.completed` is called. Raises: NotComputableError: raised when the metric cannot be computed. @@ -114,6 +116,19 @@ def _sync_all_reduce(self, tensor: Union[torch.Tensor, numbers.Number]) -> Union return tensor.item() return tensor + def ensure_metric_value(self, value: Any, raise_error=True) -> Union[numbers.Number, torch.Tensor]: + if isinstance(value, torch.Tensor): + if len(value.size()) == 0: + value = value.item() + elif not isinstance(value, numbers.Number) and raise_error: + raise TypeError( + "Metric value should has type of `numbers.Number` or `torch.Tensor`, but given `{}`".format( + value.__class__.__name__ + ) + ) + + return value + def started(self, engine: Engine) -> None: self.reset() @@ -137,9 +152,12 @@ def iteration_completed(self, engine: Engine) -> None: def completed(self, engine: Engine, name: str) -> None: result = self.compute() - if torch.is_tensor(result) and len(result.shape) == 0: - result = result.item() - engine.state.metrics[name] = result + if isinstance(result, dict): + for key, value in result.items(): + engine.state.metrics[key] = self.ensure_metric_value(value, raise_error=True) + else: + result = self.ensure_metric_value(result, raise_error=False) + engine.state.metrics[name] = result def attach(self, engine: Engine, name: str) -> None: """ diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index 7d90ad9e27c9..fe2dd5f6e31b 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -1,3 +1,4 @@ +import numbers import os import sys @@ -632,3 +633,89 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib__sync_all_reduce(device) _test_distrib_sync_all_reduce_decorator(device) + + +def test_ensure_metric_value(): + class DummyMetric(Metric): + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + pass + + m = DummyMetric() + + # int + v_int = 1 + v_int_ensured = m.ensure_metric_value(v_int) + assert v_int_ensured == v_int + assert isinstance(v_int_ensured, int) + + # float + v_float = 1.0 + v_float_ensured = m.ensure_metric_value(v_float) + assert v_float_ensured == v_float + assert isinstance(v_float_ensured, float) + + # 0d tensor + v_tensor_0d = torch.tensor(1.0) + v_tensor_0d_ensured = m.ensure_metric_value(v_tensor_0d) + assert v_tensor_0d_ensured == v_tensor_0d + assert isinstance(v_tensor_0d_ensured, numbers.Number) + + # other value type + v_str = "dummy" + v_str_ensured = m.ensure_metric_value(v_str, raise_error=False) + assert v_str_ensured == v_str + assert isinstance(v_str, str) + with pytest.raises( + TypeError, match=r"Metric value should has type of `numbers.Number` or `torch.Tensor`, but given `str`", + ): + m.ensure_metric_value(v_str, raise_error=True) + + +def test_completed(): + class DummyMetric(Metric): + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + pass + + m = DummyMetric() + + # plain number values + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value=1) + m.completed(engine, "metric") + assert engine.state.metrics == {"metric": 1} + + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value=torch.tensor(1.0)) + m.completed(engine, "metric") + assert engine.state.metrics == {"metric": 1.0} + assert isinstance(engine.state.metrics["metric"], numbers.Number) + + # dict with number values + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value={"foo": 1, "bar": torch.tensor(2.0)}) + m.completed(engine, "metric") + assert engine.state.metrics == {"foo": 1, "bar": 2.0} + + # dict with non number values + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value={"foo": 1, "bar": "str"}) + with raises(TypeError): + m.completed(engine, "metric") + + # other + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value="foo") + m.completed(engine, "metric") + assert engine.state.metrics == {"metric": "foo"} From 36d01ef68645f863168a4a5a9cd4a91ea3d96785 Mon Sep 17 00:00:00 2001 From: isolet Date: Fri, 24 Apr 2020 15:29:53 +0800 Subject: [PATCH 2/4] Shallow faltten mapping metrics. --- ignite/metrics/metric.py | 27 ++++--------- tests/ignite/metrics/test_metric.py | 62 +++-------------------------- 2 files changed, 13 insertions(+), 76 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index e691718011d8..0a11c0960a63 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -82,9 +82,9 @@ def compute(self) -> Any: This is called at the end of each epoch. Returns: - Any: the actual quantity of interest. However, if a :class:`dict` is returned, it should contain - :class:`numbers.Number` or :class:`torch.tensor` values and will flattened into - `engine.state.metrics` when :func:`~ignite.metrics.Metric.completed` is called. + Any: the actual quantity of interest. However, if a :class:`~collections.abc.Mapping` is returned, + it will be (shallow) flattened into `engine.state.metrics` when + :func:`~ignite.metrics.Metric.completed` is called. Raises: NotComputableError: raised when the metric cannot be computed. @@ -116,19 +116,6 @@ def _sync_all_reduce(self, tensor: Union[torch.Tensor, numbers.Number]) -> Union return tensor.item() return tensor - def ensure_metric_value(self, value: Any, raise_error=True) -> Union[numbers.Number, torch.Tensor]: - if isinstance(value, torch.Tensor): - if len(value.size()) == 0: - value = value.item() - elif not isinstance(value, numbers.Number) and raise_error: - raise TypeError( - "Metric value should has type of `numbers.Number` or `torch.Tensor`, but given `{}`".format( - value.__class__.__name__ - ) - ) - - return value - def started(self, engine: Engine) -> None: self.reset() @@ -152,11 +139,13 @@ def iteration_completed(self, engine: Engine) -> None: def completed(self, engine: Engine, name: str) -> None: result = self.compute() - if isinstance(result, dict): + if isinstance(result, Mapping): for key, value in result.items(): - engine.state.metrics[key] = self.ensure_metric_value(value, raise_error=True) + engine.state.metrics[key] = value else: - result = self.ensure_metric_value(result, raise_error=False) + if isinstance(result, torch.Tensor) and len(result.size()) == 0: + result = result.item() + engine.state.metrics[name] = result def attach(self, engine: Engine, name: str) -> None: diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index fe2dd5f6e31b..198e4c969720 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -635,48 +635,6 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): _test_distrib_sync_all_reduce_decorator(device) -def test_ensure_metric_value(): - class DummyMetric(Metric): - def reset(self): - pass - - def compute(self): - pass - - def update(self, output): - pass - - m = DummyMetric() - - # int - v_int = 1 - v_int_ensured = m.ensure_metric_value(v_int) - assert v_int_ensured == v_int - assert isinstance(v_int_ensured, int) - - # float - v_float = 1.0 - v_float_ensured = m.ensure_metric_value(v_float) - assert v_float_ensured == v_float - assert isinstance(v_float_ensured, float) - - # 0d tensor - v_tensor_0d = torch.tensor(1.0) - v_tensor_0d_ensured = m.ensure_metric_value(v_tensor_0d) - assert v_tensor_0d_ensured == v_tensor_0d - assert isinstance(v_tensor_0d_ensured, numbers.Number) - - # other value type - v_str = "dummy" - v_str_ensured = m.ensure_metric_value(v_str, raise_error=False) - assert v_str_ensured == v_str - assert isinstance(v_str, str) - with pytest.raises( - TypeError, match=r"Metric value should has type of `numbers.Number` or `torch.Tensor`, but given `str`", - ): - m.ensure_metric_value(v_str, raise_error=True) - - def test_completed(): class DummyMetric(Metric): def reset(self): @@ -690,29 +648,19 @@ def update(self, output): m = DummyMetric() - # plain number values - engine = MagicMock(state=State(metrics={})) - m.compute = MagicMock(return_value=1) - m.completed(engine, "metric") - assert engine.state.metrics == {"metric": 1} - + # tensor engine = MagicMock(state=State(metrics={})) m.compute = MagicMock(return_value=torch.tensor(1.0)) m.completed(engine, "metric") assert engine.state.metrics == {"metric": 1.0} assert isinstance(engine.state.metrics["metric"], numbers.Number) - # dict with number values + # mapping engine = MagicMock(state=State(metrics={})) - m.compute = MagicMock(return_value={"foo": 1, "bar": torch.tensor(2.0)}) + metrics = {"foo": 1, "bar": torch.tensor(2.0), "baz": {"qux": "quux"}} + m.compute = MagicMock(return_value=metrics) m.completed(engine, "metric") - assert engine.state.metrics == {"foo": 1, "bar": 2.0} - - # dict with non number values - engine = MagicMock(state=State(metrics={})) - m.compute = MagicMock(return_value={"foo": 1, "bar": "str"}) - with raises(TypeError): - m.completed(engine, "metric") + assert engine.state.metrics == metrics # other engine = MagicMock(state=State(metrics={})) From 32c816ec8650c28d02cd841abf7919d2f2775756 Mon Sep 17 00:00:00 2001 From: isolet Date: Fri, 24 Apr 2020 15:43:35 +0800 Subject: [PATCH 3/4] Remove unused typings. --- ignite/metrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 0a11c0960a63..b06de07a28c2 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -4,7 +4,7 @@ from collections.abc import Mapping import warnings -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist From 6278225568c6b0e6f49f2fcd85073c3dfb1c5250 Mon Sep 17 00:00:00 2001 From: isolet Date: Fri, 24 Apr 2020 15:44:42 +0800 Subject: [PATCH 4/4] Remove unexpected changes. --- ignite/metrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index b06de07a28c2..19357e49876b 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -4,7 +4,7 @@ from collections.abc import Mapping import warnings -from typing import Any, Callable, Optional, Union +from typing import Callable, Union, Optional, Any import torch import torch.distributed as dist