From f54f7a203bcc95443ff80f405967b701f3b2f5bd Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 24 Feb 2021 22:11:07 +0000 Subject: [PATCH] Added eps to avoid nans in canberra error --- ignite/contrib/metrics/regression/canberra_metric.py | 2 +- .../contrib/metrics/regression/test_canberra_metric.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ignite/contrib/metrics/regression/canberra_metric.py b/ignite/contrib/metrics/regression/canberra_metric.py index 5bb3cb8cca67..c3fc17e06cf1 100644 --- a/ignite/contrib/metrics/regression/canberra_metric.py +++ b/ignite/contrib/metrics/regression/canberra_metric.py @@ -39,7 +39,7 @@ def reset(self) -> None: def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output - errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y)) + errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y) + 1e-15) self._sum_of_errors += torch.sum(errors).to(self._device) @sync_all_reduce("_sum_of_errors") diff --git a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py index 04974c21e563..43209b0b1fab 100644 --- a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py +++ b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py @@ -63,6 +63,12 @@ def test_compute(): assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum) +def test_error_is_not_nan(): + m = CanberraMetric() + m.update((torch.zeros(4), torch.zeros(4))) + assert not (torch.isnan(m._sum_of_errors).any() or torch.isinf(m._sum_of_errors).any()), m._sum_of_errors + + def _test_distrib_compute(device): rank = idist.get_rank()