10 changes: 6 additions & 4 deletions docs/source/metrics.rst
@@ -223,7 +223,7 @@ specific condition (e.g. ignore user-defined classes):
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
@sync_all_reduce("_num_examples", "_num_correct:SUM")
def compute(self):
if self._num_examples == 0:
raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
@@ -288,9 +288,11 @@ Metrics and distributed computations
In the above example, ``CustomAccuracy`` has ``reset``, ``update``, ``compute`` methods decorated
with :meth:`~ignite.metrics.metric.reinit__is_reduced`, :meth:`~ignite.metrics.metric.sync_all_reduce`. The purpose of these features is to adapt metrics in distributed
computations on supported backends and devices (see :doc:`distributed` for more details). More precisely, in the above
example we added ``@sync_all_reduce("_num_examples", "_num_correct")`` over ``compute`` method. This means that when ``compute``
method is called, metric's interal variables ``self._num_examples`` and ``self._num_correct`` are summed up over all participating
devices. Therefore, once collected, these internal variables can be used to compute the final metric value.
example we added ``@sync_all_reduce("_num_examples", "_num_correct:SUM")`` over the ``compute`` method. This means that when the ``compute``
method is called, the metric's internal variables ``self._num_examples`` and ``self._num_correct`` are reduced across all participating
devices. The reduction operation is given as a suffix of the attribute name, as in ``"_num_correct:SUM"``; when no suffix is given, as for
``"_num_examples"``, the default operation ``SUM`` is used. Four reduction operations are currently supported: ``SUM``, ``MAX``, ``MIN`` and ``PRODUCT``.
Once collected, these internal variables can be used to compute the final metric value.
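
For illustration, here is a minimal sketch of a hypothetical custom metric (``MinMaxCount`` below is not part of the library) that combines the default ``SUM`` reduction with explicit ``MAX`` and ``MIN`` reductions:

.. code-block:: python

    from ignite.exceptions import NotComputableError
    from ignite.metrics import Metric
    from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


    class MinMaxCount(Metric):

        @reinit__is_reduced
        def reset(self):
            self._num_examples = 0
            self._max_pred = float("-inf")
            self._min_pred = float("inf")

        @reinit__is_reduced
        def update(self, output):
            y_pred, y = output
            self._num_examples += y.shape[0]
            self._max_pred = max(self._max_pred, y_pred.max().item())
            self._min_pred = min(self._min_pred, y_pred.min().item())

        # "_num_examples" uses the default SUM reduction, while the suffixes
        # request MAX and MIN reductions for the other two attributes
        @sync_all_reduce("_num_examples", "_max_pred:MAX", "_min_pred:MIN")
        def compute(self):
            if self._num_examples == 0:
                raise NotComputableError("MinMaxCount must have at least one example before it can be computed.")
            return self._num_examples, self._max_pred, self._min_pred

Keeping the reduction specification next to the attribute name in the decorator makes the distributed behaviour of ``compute`` explicit without changing ``update``.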

Complete list of metrics
------------------------
13 changes: 11 additions & 2 deletions ignite/metrics/metric.py
@@ -528,13 +528,15 @@ def __setstate__(self, d: Dict) -> None:

def sync_all_reduce(*attrs: Any) -> Callable:
"""Helper decorator for distributed configuration to collect instance attribute value
across all participating processes.
across all participating processes and apply the specified reduction operation.

See :doc:`metrics` on how to use it.

Args:
attrs: attribute names of the decorated class. An attribute name may be suffixed with the reduction operation to apply, e.g. ``"_num_correct:SUM"``. Supported operations are ``SUM``, ``MAX``, ``MIN`` and ``PRODUCT``; if no operation is given, ``SUM`` is used.

.. versionchanged:: 0.5.0
- Ability to handle different reduction operations (SUM, MAX, MIN, PRODUCT).
"""

def wrapper(func: Callable) -> Callable:
@@ -548,9 +550,16 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
if len(attrs) > 0 and not self._is_reduced:
if ws > 1:
for attr in attrs:
op_kwargs = {}
# an attribute can be passed as "<name>:<REDUCTION_OP>", e.g. "_num_correct:SUM"
if ":" in attr:
attr, op = attr.split(":")
valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
if op not in valid_ops:
raise ValueError(f"Reduction operation is not valid (expected: {valid_ops}, got: {op})")
op_kwargs["op"] = op
t = getattr(self, attr, None)
if t is not None:
t = idist.all_reduce(t)
t = idist.all_reduce(t, **op_kwargs)
self._is_reduced = True
setattr(self, attr, t)
else:
93 changes: 89 additions & 4 deletions tests/ignite/metrics/test_metric.py
@@ -12,7 +12,7 @@
import ignite.distributed as idist
from ignite.engine import Engine, Events, State
from ignite.metrics import ConfusionMatrix, Precision, Recall
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced, sync_all_reduce


class DummyMetric1(Metric):
@@ -538,13 +538,60 @@ def update(self, output):
pass


def _test_distrib_sync_all_reduce_decorator(device):
def _test_invalid_sync_all_reduce(device):
class InvalidMetric(Metric):
@reinit__is_reduced
def reset(self):
self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], requires_grad=False)
self.c = 0.0
self.n = 0
self.m = -1

from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
def compute(self):
pass

def update(self):
pass

@sync_all_reduce("a:sum")
def invalid_reduction_op_1(self):
pass

@sync_all_reduce("c:MaX")
def invalid_reduction_op_2(self):
pass

@sync_all_reduce("n:MINN")
def invalid_reduction_op_3(self):
pass

@sync_all_reduce("m:PROduCT")
def invalid_reduction_op_4(self):
pass

metric_device = device if torch.device(device).type != "xla" else "cpu"
m = InvalidMetric(device=metric_device)
m.reset()

if idist.get_world_size() > 1:
with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
m.invalid_reduction_op_1()

with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
m.invalid_reduction_op_2()

with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
m.invalid_reduction_op_3()

with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
m.invalid_reduction_op_4()


def _test_distrib_sync_all_reduce_decorator(device):
class DummyMetric(Metric):
@reinit__is_reduced
def reset(self):
# SUM op
self.a = torch.tensor([0.0, 1.0, 2.0, 3.0], device=self._device, requires_grad=False)
self.a_nocomp = self.a.clone().to("cpu")
self.b = torch.tensor(1.0, dtype=torch.float64, device=self._device, requires_grad=False)
@@ -554,20 +601,51 @@ def reset(self):
self.n = 0
self.n_nocomp = self.n

@sync_all_reduce("a", "b", "c", "n")
# MAX op
self.m = -1

# MIN op
self.k = 10000

# initialize number of updates to test (MAX, MIN) ops
self.num_updates = 0

# PRODUCT op
self.prod = torch.tensor([2.0, 3.0], device=self._device, requires_grad=False)
self.prod_nocomp = self.prod.clone().to("cpu")

@sync_all_reduce("a", "b", "c", "n:SUM", "m:MAX", "k:MIN", "prod:PRODUCT")
def compute(self):
assert (self.a.cpu() == (self.a_nocomp + 10) * idist.get_world_size()).all()
assert (self.b.cpu() == (self.b_nocomp - 5) * idist.get_world_size()).all()
assert self.c == pytest.approx((self.c_nocomp + 1.23456) * idist.get_world_size())
assert self.n == (self.n_nocomp + 1) * idist.get_world_size()
assert self.m == self.num_updates * (idist.get_world_size() - 1) - 1
assert self.k == 10000 - self.num_updates * (idist.get_world_size() - 1)
temp_prod_nocomp = 5 * self.prod_nocomp  # recompute the expected local product
temp_prod_nocomp = temp_prod_nocomp.pow(idist.get_world_size())
assert (self.prod.cpu() == temp_prod_nocomp).all()

@reinit__is_reduced
def update(self, output):
# SUM op
self.n += 1
self.c += 1.23456
self.a += 10.0
self.b -= 5.0

# MAX op
self.m += idist.get_rank()

# MIN op
self.k -= idist.get_rank()

# number of updates for (MAX, MIN) ops
self.num_updates += 1

# PRODUCT op
self.prod *= 5

metric_device = device if torch.device(device).type != "xla" else "cpu"
m = DummyMetric(device=metric_device)
m.update(None)
@@ -588,6 +666,7 @@ def test_distrib_gpu(distributed_context_single_node_nccl):

device = f"cuda:{distributed_context_single_node_nccl['local_rank']}"
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)


@pytest.mark.distributed
@@ -596,6 +675,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo):

device = "cpu"
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)


@pytest.mark.distributed
@@ -607,6 +687,7 @@ def test_distrib_hvd(gloo_hvd_executor):
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_sync_all_reduce_decorator, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_invalid_sync_all_reduce, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@@ -615,6 +696,7 @@ def test_distrib_hvd(gloo_hvd_executor):
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = "cpu"
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)


@pytest.mark.multinode_distributed
@@ -623,6 +705,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = f"cuda:{distributed_context_multi_node_nccl['local_rank']}"
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)


@pytest.mark.tpu
@@ -632,12 +715,14 @@ def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_sync_all_reduce_decorator(device)
_test_creating_on_xla_fails(device)
_test_invalid_sync_all_reduce(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_sync_all_reduce_decorator(device)
_test_creating_on_xla_fails(device)
_test_invalid_sync_all_reduce(device)


@pytest.mark.tpu