diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 0cbab97aaffa..19357e49876b 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -82,7 +82,9 @@ def compute(self) -> Any: This is called at the end of each epoch. Returns: - Any: the actual quantity of interest. + Any: the actual quantity of interest. However, if a :class:`~collections.abc.Mapping` is returned, + it will be (shallow) flattened into `engine.state.metrics` when + :func:`~ignite.metrics.Metric.completed` is called. Raises: NotComputableError: raised when the metric cannot be computed. @@ -137,9 +139,14 @@ def iteration_completed(self, engine: Engine) -> None: def completed(self, engine: Engine, name: str) -> None: result = self.compute() - if torch.is_tensor(result) and len(result.shape) == 0: - result = result.item() - engine.state.metrics[name] = result + if isinstance(result, Mapping): + for key, value in result.items(): + engine.state.metrics[key] = value + else: + if isinstance(result, torch.Tensor) and len(result.size()) == 0: + result = result.item() + + engine.state.metrics[name] = result def attach(self, engine: Engine, name: str) -> None: """ diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index 7d90ad9e27c9..198e4c969720 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -1,3 +1,4 @@ +import numbers import os import sys @@ -632,3 +633,37 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib__sync_all_reduce(device) _test_distrib_sync_all_reduce_decorator(device) + + +def test_completed(): + class DummyMetric(Metric): + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + pass + + m = DummyMetric() + + # tensor + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value=torch.tensor(1.0)) + m.completed(engine, "metric") + assert engine.state.metrics == {"metric": 1.0} + assert isinstance(engine.state.metrics["metric"], numbers.Number) + + # mapping + engine = MagicMock(state=State(metrics={})) + metrics = {"foo": 1, "bar": torch.tensor(2.0), "baz": {"qux": "quux"}} + m.compute = MagicMock(return_value=metrics) + m.completed(engine, "metric") + assert engine.state.metrics == metrics + + # other + engine = MagicMock(state=State(metrics={})) + m.compute = MagicMock(return_value="foo") + m.completed(engine, "metric") + assert engine.state.metrics == {"metric": "foo"}