diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index eddd6b713fde..c47a68cf9ca4 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -223,7 +223,7 @@ specific condition (e.g. ignore user-defined classes):
             self._num_correct += torch.sum(correct).to(self._device)
             self._num_examples += correct.shape[0]

-        @sync_all_reduce("_num_examples", "_num_correct")
+        @sync_all_reduce("_num_examples", "_num_correct:SUM")
         def compute(self):
             if self._num_examples == 0:
                 raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
@@ -288,9 +288,11 @@ Metrics and distributed computations
 In the above example, ``CustomAccuracy`` has ``reset``, ``update``, ``compute`` methods decorated with
 :meth:`~ignite.metrics.metric.reinit__is_reduced`, :meth:`~ignite.metrics.metric.sync_all_reduce`. The purpose of these features is to adapt metrics in distributed
 computations on supported backend and devices (see :doc:`distributed` for more details). More precisely, in the above
-example we added ``@sync_all_reduce("_num_examples", "_num_correct")`` over ``compute`` method. This means that when ``compute``
-method is called, metric's interal variables ``self._num_examples`` and ``self._num_correct`` are summed up over all participating
-devices. Therefore, once collected, these internal variables can be used to compute the final metric value.
+example we added ``@sync_all_reduce("_num_examples", "_num_correct:SUM")`` over ``compute`` method. This means that when ``compute``
+method is called, the metric's internal variables ``self._num_examples`` and ``self._num_correct`` are summed up over all participating
+devices. The reduction operation is given after the attribute name, as in ``_num_correct:SUM``; if it is omitted, as for
+``_num_examples``, the default operation ``SUM`` is used. Four reduction operations are currently supported: SUM, MAX, MIN and PRODUCT.
+Therefore, once collected, these internal variables can be used to compute the final metric value.

 Complete list of metrics
 ------------------------
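To make the documented ``attribute:OP`` syntax concrete, here is a minimal usage sketch (not part of the patch; the metric name and its attributes are illustrative) that mixes an explicit MAX reduction with the default SUM, following the ``CustomAccuracy`` pattern above:

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class MaxAbsError(Metric):
    # Illustrative metric: worst absolute error seen on any process, plus a global example count.

    @reinit__is_reduced
    def reset(self):
        self._max_abs_error = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output
        err = (y_pred.detach() - y).abs().max().to(self._device)
        self._max_abs_error = torch.max(self._max_abs_error, err)
        self._num_examples += y.shape[0]

    # "_max_abs_error:MAX" asks for a MAX reduction; "_num_examples" keeps the default SUM.
    @sync_all_reduce("_max_abs_error:MAX", "_num_examples")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError("MaxAbsError must have at least one example before it can be computed.")
        return self._max_abs_error.item()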
""" def wrapper(func: Callable) -> Callable: @@ -548,9 +550,16 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable: if len(attrs) > 0 and not self._is_reduced: if ws > 1: for attr in attrs: + op_kwargs = {} + if ":" in attr: + attr, op = attr.split(":") + valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"] + if op not in valid_ops: + raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}") + op_kwargs["op"] = op t = getattr(self, attr, None) if t is not None: - t = idist.all_reduce(t) + t = idist.all_reduce(t, **op_kwargs) self._is_reduced = True setattr(self, attr, t) else: diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index e30c67e0668d..e8a5f6f39721 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -12,7 +12,7 @@ import ignite.distributed as idist from ignite.engine import Engine, Events, State from ignite.metrics import ConfusionMatrix, Precision, Recall -from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced +from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced, sync_all_reduce class DummyMetric1(Metric): @@ -538,13 +538,60 @@ def update(self, output): pass -def _test_distrib_sync_all_reduce_decorator(device): +def _test_invalid_sync_all_reduce(device): + class InvalidMetric(Metric): + @reinit__is_reduced + def reset(self): + self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], requires_grad=False) + self.c = 0.0 + self.n = 0 + self.m = -1 - from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce + def compute(self): + pass + + def update(self): + pass + + @sync_all_reduce("a:sum") + def invalid_reduction_op_1(self): + pass + + @sync_all_reduce("c:MaX") + def invalid_reduction_op_2(self): + pass + + @sync_all_reduce("n:MINN") + def invalid_reduction_op_3(self): + pass + + @sync_all_reduce("m:PROduCT") + def invalid_reduction_op_4(self): + pass + metric_device = device if torch.device(device).type != "xla" else "cpu" + m = InvalidMetric(device=metric_device) + m.reset() + + if idist.get_world_size() > 1: + with pytest.raises(ValueError, match=r"Reduction operation is not valid"): + m.invalid_reduction_op_1() + + with pytest.raises(ValueError, match=r"Reduction operation is not valid"): + m.invalid_reduction_op_2() + + with pytest.raises(ValueError, match=r"Reduction operation is not valid"): + m.invalid_reduction_op_3() + + with pytest.raises(ValueError, match=r"Reduction operation is not valid"): + m.invalid_reduction_op_4() + + +def _test_distrib_sync_all_reduce_decorator(device): class DummyMetric(Metric): @reinit__is_reduced def reset(self): + # SUM op self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], device=self._device, requires_grad=False) self.a_nocomp = self.a.clone().to("cpu") self.b = torch.tensor(1.0, dtype=torch.float64, device=self._device, requires_grad=False) @@ -554,20 +601,51 @@ def reset(self): self.n = 0 self.n_nocomp = self.n - @sync_all_reduce("a", "b", "c", "n") + # MAX op + self.m = -1 + + # MIN op + self.k = 10000 + + # initialize number of updates to test (MAX, MIN) ops + self.num_updates = 0 + + # PRODUCT op + self.prod = torch.tensor([2.0, 3.0], device=self._device, requires_grad=False) + self.prod_nocomp = self.prod.clone().to("cpu") + + @sync_all_reduce("a", "b", "c", "n:SUM", "m:MAX", "k:MIN", "prod:PRODUCT") def compute(self): assert (self.a.cpu() == (self.a_nocomp + 10) * idist.get_world_size()).all() assert (self.b.cpu() == 
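Stripped of the decorator plumbing, the new attribute handling in the hunk above amounts to the small parser sketched below (``parse_reduction_attr`` is a hypothetical helper introduced only for illustration; the patch keeps this logic inline in ``another_wrapper``):

from typing import Any, Dict, Tuple

VALID_OPS = ["MIN", "MAX", "SUM", "PRODUCT"]


def parse_reduction_attr(attr: str) -> Tuple[str, Dict[str, Any]]:
    # Split an "attribute:OP" spec into the attribute name and the keyword arguments
    # forwarded to idist.all_reduce; an empty dict keeps the default SUM reduction.
    op_kwargs: Dict[str, Any] = {}
    if ":" in attr:
        attr, op = attr.split(":")
        if op not in VALID_OPS:
            raise ValueError(f"Reduction operation is not valid (expected: {VALID_OPS}, got: {op})")
        op_kwargs["op"] = op
    return attr, op_kwargs


# "attr:OP" selects an explicit reduction; a bare attribute name keeps the default.
assert parse_reduction_attr("_num_correct:SUM") == ("_num_correct", {"op": "SUM"})
assert parse_reduction_attr("m:MAX") == ("m", {"op": "MAX"})
assert parse_reduction_attr("_num_examples") == ("_num_examples", {})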
diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py
index e30c67e0668d..e8a5f6f39721 100644
--- a/tests/ignite/metrics/test_metric.py
+++ b/tests/ignite/metrics/test_metric.py
@@ -12,7 +12,7 @@
 import ignite.distributed as idist
 from ignite.engine import Engine, Events, State
 from ignite.metrics import ConfusionMatrix, Precision, Recall
-from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced
+from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced, sync_all_reduce


 class DummyMetric1(Metric):
@@ -538,13 +538,60 @@ def update(self, output):
         pass


-def _test_distrib_sync_all_reduce_decorator(device):
+def _test_invalid_sync_all_reduce(device):
+    class InvalidMetric(Metric):
+        @reinit__is_reduced
+        def reset(self):
+            self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], requires_grad=False)
+            self.c = 0.0
+            self.n = 0
+            self.m = -1

-    from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
+        def compute(self):
+            pass
+
+        def update(self):
+            pass
+
+        @sync_all_reduce("a:sum")
+        def invalid_reduction_op_1(self):
+            pass
+
+        @sync_all_reduce("c:MaX")
+        def invalid_reduction_op_2(self):
+            pass
+
+        @sync_all_reduce("n:MINN")
+        def invalid_reduction_op_3(self):
+            pass
+
+        @sync_all_reduce("m:PROduCT")
+        def invalid_reduction_op_4(self):
+            pass

+    metric_device = device if torch.device(device).type != "xla" else "cpu"
+    m = InvalidMetric(device=metric_device)
+    m.reset()
+
+    if idist.get_world_size() > 1:
+        with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
+            m.invalid_reduction_op_1()
+
+        with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
+            m.invalid_reduction_op_2()
+
+        with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
+            m.invalid_reduction_op_3()
+
+        with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
+            m.invalid_reduction_op_4()
+
+
+def _test_distrib_sync_all_reduce_decorator(device):
     class DummyMetric(Metric):
         @reinit__is_reduced
         def reset(self):
+            # SUM op
             self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], device=self._device, requires_grad=False)
             self.a_nocomp = self.a.clone().to("cpu")
             self.b = torch.tensor(1.0, dtype=torch.float64, device=self._device, requires_grad=False)
@@ -554,20 +601,51 @@ def reset(self):
             self.n = 0
             self.n_nocomp = self.n

-        @sync_all_reduce("a", "b", "c", "n")
+            # MAX op
+            self.m = -1
+
+            # MIN op
+            self.k = 10000
+
+            # number of updates, used to check the MAX and MIN ops
+            self.num_updates = 0
+
+            # PRODUCT op
+            self.prod = torch.tensor([2.0, 3.0], device=self._device, requires_grad=False)
+            self.prod_nocomp = self.prod.clone().to("cpu")
+
+        @sync_all_reduce("a", "b", "c", "n:SUM", "m:MAX", "k:MIN", "prod:PRODUCT")
         def compute(self):
             assert (self.a.cpu() == (self.a_nocomp + 10) * idist.get_world_size()).all()
             assert (self.b.cpu() == (self.b_nocomp - 5) * idist.get_world_size()).all()
             assert self.c == pytest.approx((self.c_nocomp + 1.23456) * idist.get_world_size())
             assert self.n == (self.n_nocomp + 1) * idist.get_world_size()
+            assert self.m == self.num_updates * (idist.get_world_size() - 1) - 1
+            assert self.k == 10000 - self.num_updates * (idist.get_world_size() - 1)
+            temp_prod_nocomp = 5 * self.prod_nocomp  # recompute the expected product in a new variable
+            temp_prod_nocomp = temp_prod_nocomp.pow(idist.get_world_size())
+            assert (self.prod.cpu() == temp_prod_nocomp).all()

         @reinit__is_reduced
         def update(self, output):
+            # SUM op
             self.n += 1
             self.c += 1.23456
             self.a += 10.0
             self.b -= 5.0
+            # MAX op
+            self.m += idist.get_rank()
+
+            # MIN op
+            self.k -= idist.get_rank()
+
+            # number of updates for the MAX and MIN ops
+            self.num_updates += 1
+
+            # PRODUCT op
+            self.prod *= 5
+

     metric_device = device if torch.device(device).type != "xla" else "cpu"
     m = DummyMetric(device=metric_device)
     m.update(None)
@@ -588,6 +666,7 @@ def test_distrib_gpu(distributed_context_single_node_nccl):

     device = f"cuda:{distributed_context_single_node_nccl['local_rank']}"
     _test_distrib_sync_all_reduce_decorator(device)
+    _test_invalid_sync_all_reduce(device)


 @pytest.mark.distributed
@@ -596,6 +675,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo):

     device = "cpu"
     _test_distrib_sync_all_reduce_decorator(device)
+    _test_invalid_sync_all_reduce(device)


 @pytest.mark.distributed
@@ -607,6 +687,7 @@ def test_distrib_hvd(gloo_hvd_executor):
     nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

     gloo_hvd_executor(_test_distrib_sync_all_reduce_decorator, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_invalid_sync_all_reduce, (device,), np=nproc, do_init=True)


 @pytest.mark.multinode_distributed
@@ -615,6 +696,7 @@ def test_distrib_hvd(gloo_hvd_executor):
 def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
     device = "cpu"
     _test_distrib_sync_all_reduce_decorator(device)
+    _test_invalid_sync_all_reduce(device)


 @pytest.mark.multinode_distributed
@@ -623,6 +705,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
 def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
     device = f"cuda:{distributed_context_multi_node_nccl['local_rank']}"
     _test_distrib_sync_all_reduce_decorator(device)
+    _test_invalid_sync_all_reduce(device)


 @pytest.mark.tpu
@@ -632,12 +715,14 @@ def test_distrib_single_device_xla():
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
+    _test_invalid_sync_all_reduce(device)


 def _test_distrib_xla_nprocs(index):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
+    _test_invalid_sync_all_reduce(device)


 @pytest.mark.tpu
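Note that the invalid-op validation only fires when the world size is greater than one, which is why the tests guard on ``idist.get_world_size() > 1``. A quick manual check therefore needs more than one process; the sketch below is illustrative and not part of the patch, and assumes ``idist.spawn`` with the ``gloo`` backend is available in the environment:

import ignite.distributed as idist
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


class MinRank(Metric):
    # Illustrative metric: reduces the local rank with MIN, so every process computes 0.

    @reinit__is_reduced
    def reset(self):
        self._rank = 0

    @reinit__is_reduced
    def update(self, output):
        self._rank = idist.get_rank()

    @sync_all_reduce("_rank:MIN")
    def compute(self):
        return self._rank


def _check(local_rank):
    m = MinRank(device="cpu")
    m.update(None)
    assert m.compute() == 0  # MIN over ranks 0..world_size-1


if __name__ == "__main__":
    # Spawn two gloo processes; each runs _check with its local rank.
    idist.spawn("gloo", _check, args=(), nproc_per_node=2)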