From d92bbf20e34feacab3941c5e27e73fe687303d63 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 31 May 2021 15:18:14 +0300 Subject: [PATCH 1/5] Introduce _updated flag --- ignite/metrics/precision.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 2c9449646b17..363fd1ac6fe9 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -22,6 +22,7 @@ def __init__( self._average = average self.eps = 1e-20 + self._updated = False super(_BasePrecisionRecall, self).__init__( output_transform=output_transform, is_multilabel=is_multilabel, device=device ) @@ -30,6 +31,7 @@ def __init__( def reset(self) -> None: self._true_positives = 0 # type: Union[int, torch.Tensor] self._positives = 0 # type: Union[int, torch.Tensor] + self._updated = False if self._is_multilabel: init_value = 0.0 if self._average else [] @@ -39,8 +41,7 @@ def reset(self) -> None: super(_BasePrecisionRecall, self).reset() def compute(self) -> Union[torch.Tensor, float]: - is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 - if is_scalar and self._positives == 0: + if not self._updated: raise NotComputableError( f"{self.__class__.__name__} must have at least one example before it can be computed." ) @@ -173,3 +174,5 @@ def update(self, output: Sequence[torch.Tensor]) -> None: else: self._true_positives += true_positives self._positives += all_positives + + self._updated = True From b45d838e1a92ecad280f841389f22ed12a60fae9 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 31 May 2021 15:18:37 +0300 Subject: [PATCH 2/5] Adjust prec tests --- tests/ignite/metrics/test_precision.py | 67 ++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index bff02cc65c27..4972ec1085bd 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -15,45 +15,58 @@ def test_no_update(): precision = Precision() + assert precision._updated is False with pytest.raises(NotComputableError, match=r"Precision must have at least one example before it can be computed"): precision.compute() + assert precision._updated is False precision = Precision(is_multilabel=True, average=True) + assert precision._updated is False with pytest.raises(NotComputableError, match=r"Precision must have at least one example before it can be computed"): precision.compute() + assert precision._updated is False def test_binary_wrong_inputs(): pr = Precision() + assert pr._updated is False with pytest.raises(ValueError, match=r"For binary cases, y must be comprised of 0's and 1's"): # y has not only 0 or 1 values pr.update((torch.randint(0, 2, size=(10,)).long(), torch.arange(0, 10).long())) + assert pr._updated is False with pytest.raises(ValueError, match=r"For binary cases, y_pred must be comprised of 0's and 1's"): # y_pred values are not thresholded to 0, 1 values pr.update((torch.rand(10,), torch.randint(0, 2, size=(10,)).long(),)) + assert pr._updated is False with pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long())) + assert pr._updated is False with pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes pr.update((torch.randint(0, 2, size=(10, 5, 6)).long(), torch.randint(0, 2, size=(10,)).long())) + assert pr._updated is False with 
pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long())) + assert pr._updated is False @pytest.mark.parametrize("average", [False, True]) def test_binary_input(average): pr = Precision(average=average) + assert pr._updated is False def _test(y_pred, y, batch_size): pr.reset() + assert pr._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -66,6 +79,7 @@ def _test(y_pred, y, batch_size): np_y_pred = y_pred.numpy().ravel() assert pr._type == "binary" + assert pr._updated is True assert isinstance(pr.compute(), float if average else torch.Tensor) pr_compute = pr.compute() if average else pr.compute().numpy() assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute) @@ -104,51 +118,64 @@ def get_test_cases(): def test_multiclass_wrong_inputs(): pr = Precision() + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes pr.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).long())) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes pr.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).long())) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long())) + assert pr._updated is False pr = Precision(average=True) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long())) pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long())) + assert pr._updated is True with pytest.raises(ValueError): # incompatible shapes between two updates pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) + assert pr._updated is True pr = Precision(average=False) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long())) pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long())) + assert pr._updated is True with pytest.raises(ValueError): # incompatible shapes between two updates pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) + assert pr._updated is True @pytest.mark.parametrize("average", [False, True]) def test_multiclass_input(average): pr = Precision(average=average) + assert pr._updated is False def _test(y_pred, y, batch_size): pr.reset() + assert pr._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -162,6 +189,7 @@ def _test(y_pred, y, batch_size): np_y = y.numpy().ravel() assert pr._type == "multiclass" + assert pr._updated is True assert isinstance(pr.compute(), float if average else torch.Tensor) pr_compute = pr.compute() if average else pr.compute().numpy() sk_average_parameter = "macro" if average else None @@ -204,23 +232,28 @@ def get_test_cases(): def test_multilabel_wrong_inputs(): pr = Precision(average=True, is_multilabel=True) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes pr.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, 
size=(10,)).long())) + assert pr._updated is False with pytest.raises(ValueError): # incompatible y_pred pr.update((torch.rand(10, 5), torch.randint(0, 2, size=(10, 5)).long())) + assert pr._updated is False with pytest.raises(ValueError): # incompatible y pr.update((torch.randint(0, 5, size=(10, 5, 6)), torch.rand(10))) + assert pr._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates pr.update((torch.randint(0, 2, size=(20, 5)), torch.randint(0, 2, size=(20, 5)).long())) pr.update((torch.randint(0, 2, size=(20, 6)), torch.randint(0, 2, size=(20, 6)).long())) + assert pr._updated is True def to_numpy_multilabel(y): @@ -235,9 +268,12 @@ def to_numpy_multilabel(y): def test_multilabel_input(average): pr = Precision(average=average, is_multilabel=True) + assert pr._updated is False def _test(y_pred, y, batch_size): pr.reset() + assert pr._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -250,6 +286,7 @@ def _test(y_pred, y, batch_size): np_y = to_numpy_multilabel(y) assert pr._type == "multilabel" + assert pr._updated is True pr_compute = pr.compute() if average else pr.compute().mean().item() with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UndefinedMetricWarning) @@ -257,9 +294,15 @@ def _test(y_pred, y, batch_size): pr1 = Precision(is_multilabel=True, average=True) pr2 = Precision(is_multilabel=True, average=False) + assert pr1._updated is False + assert pr2._updated is False pr1.update((y_pred, y)) pr2.update((y_pred, y)) + assert pr1._updated is True + assert pr2._updated is True assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) + assert pr1._updated is True + assert pr2._updated is True def get_test_cases(): @@ -298,10 +341,12 @@ def test_incorrect_type(): def _test(average): pr = Precision(average=average) + assert pr._updated is False y_pred = torch.softmax(torch.rand(4, 4), dim=1) y = torch.ones(4).long() pr.update((y_pred, y)) + assert pr._updated is True y_pred = torch.randint(0, 2, size=(4,)) y = torch.ones(4).long() @@ -309,15 +354,21 @@ def _test(average): with pytest.raises(RuntimeError): pr.update((y_pred, y)) + assert pr._updated is True + _test(average=True) _test(average=False) pr1 = Precision(is_multilabel=True, average=True) pr2 = Precision(is_multilabel=True, average=False) + assert pr1._updated is False + assert pr2._updated is False y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() pr1.update((y_pred, y)) pr2.update((y_pred, y)) + assert pr1._updated is True + assert pr2._updated is True assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) @@ -325,12 +376,16 @@ def test_incorrect_y_classes(): def _test(average): pr = Precision(average=average) + assert pr._updated is False + y_pred = torch.randint(0, 2, size=(10, 4)).float() y = torch.randint(4, 5, size=(10,)).long() with pytest.raises(ValueError): pr.update((y_pred, y)) + assert pr._updated is False + _test(average=True) _test(average=False) @@ -360,11 +415,13 @@ def update(engine, i): pr = Precision(average=average, device=metric_device) pr.attach(engine, "pr") + assert pr._updated is False data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) assert "pr" in engine.state.metrics + assert pr._updated is True res = engine.state.metrics["pr"] if isinstance(res, torch.Tensor): assert res.device == metric_device @@ -413,11 +470,13 @@ def update(engine, i): pr = Precision(average=average, 
is_multilabel=True, device=metric_device) pr.attach(engine, "pr") + assert pr._updated is False data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) assert "pr" in engine.state.metrics + assert pr._updated is True res = engine.state.metrics["pr"] res2 = pr.compute() if isinstance(res, torch.Tensor): @@ -447,10 +506,14 @@ def update(engine, i): pr1 = Precision(is_multilabel=True, average=True) pr2 = Precision(is_multilabel=True, average=False) + assert pr1._updated is False + assert pr2._updated is False y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() pr1.update((y_pred, y)) pr2.update((y_pred, y)) + assert pr1._updated is True + assert pr2._updated is True assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) @@ -460,6 +523,7 @@ def _test_distrib_accumulator_device(device): def _test(average, metric_device): pr = Precision(average=average, device=metric_device) assert pr._device == metric_device + assert pr._updated is False # Since the shape of the accumulated amount isn't known before the first update # call, the internal variables aren't tensors on the right device yet. @@ -467,6 +531,7 @@ def _test(average, metric_device): y = torch.randint(0, 2, size=(10,)).long() pr.update((y_pred, y)) + assert pr._updated is True assert ( pr._true_positives.device == metric_device ), f"{type(pr._true_positives.device)}:{pr._true_positives.device} vs {type(metric_device)}:{metric_device}" @@ -488,6 +553,7 @@ def _test_distrib_multilabel_accumulator_device(device): def _test(average, metric_device): pr = Precision(is_multilabel=True, average=average, device=metric_device) + assert pr._updated is False assert pr._device == metric_device assert ( pr._true_positives.device == metric_device @@ -500,6 +566,7 @@ def _test(average, metric_device): y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() pr.update((y_pred, y)) + assert pr._updated is True assert ( pr._true_positives.device == metric_device ), f"{type(pr._true_positives.device)}:{pr._true_positives.device} vs {type(metric_device)}:{metric_device}" From a6fcc0fcf3c7d3af0e4cdca318ece540aaa42b6b Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 31 May 2021 15:18:57 +0300 Subject: [PATCH 3/5] Fix recall accordingly --- ignite/metrics/recall.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index 9b9c05c9b663..f2c2d997a3c5 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -121,3 +121,5 @@ def update(self, output: Sequence[torch.Tensor]) -> None: else: self._true_positives += true_positives self._positives += actual_positives + + self._updated = True From 098a734e966a59b4e0f53872080d3cef830c8230 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 31 May 2021 15:19:14 +0300 Subject: [PATCH 4/5] Adjust rec tests --- tests/ignite/metrics/test_recall.py | 67 +++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index fe9b14e93dc8..2cc94573ee0d 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -15,45 +15,58 @@ def test_no_update(): recall = Recall() + assert recall._updated is False with pytest.raises(NotComputableError, match=r"Recall must have at least one example before it can be computed"): recall.compute() + assert recall._updated is False recall = Recall(is_multilabel=True, average=True) + assert recall._updated is False with 
pytest.raises(NotComputableError, match=r"Recall must have at least one example before it can be computed"): recall.compute() + assert recall._updated is False def test_binary_wrong_inputs(): re = Recall() + assert re._updated is False with pytest.raises(ValueError, match=r"For binary cases, y must be comprised of 0's and 1's"): # y has not only 0 or 1 values re.update((torch.randint(0, 2, size=(10,)), torch.arange(0, 10).long())) + assert re._updated is False with pytest.raises(ValueError, match=r"For binary cases, y_pred must be comprised of 0's and 1's"): # y_pred values are not thresholded to 0, 1 values re.update((torch.rand(10, 1), torch.randint(0, 2, size=(10,)).long())) + assert re._updated is False with pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes re.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10, 5)).long())) + assert re._updated is False with pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes re.update((torch.randint(0, 2, size=(10, 5, 6)), torch.randint(0, 2, size=(10,)).long())) + assert re._updated is False with pytest.raises(ValueError, match=r"y must have shape of"): # incompatible shapes re.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10, 5, 6)).long())) + assert re._updated is False @pytest.mark.parametrize("average", [False, True]) def test_binary_input(average): re = Recall(average=average) + assert re._updated is False def _test(y_pred, y, batch_size): re.reset() + assert re._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -66,6 +79,7 @@ def _test(y_pred, y, batch_size): np_y_pred = y_pred.numpy().ravel() assert re._type == "binary" + assert re._updated is True assert isinstance(re.compute(), float if average else torch.Tensor) re_compute = re.compute() if average else re.compute().numpy() assert recall_score(np_y, np_y_pred, average="binary") == pytest.approx(re_compute) @@ -104,51 +118,64 @@ def get_test_cases(): def test_multiclass_wrong_inputs(): re = Recall() + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes re.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).long())) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes re.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).long())) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes re.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long())) + assert re._updated is False re = Recall(average=True) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates re.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long())) re.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long())) + assert re._updated is True with pytest.raises(ValueError): # incompatible shapes between two updates re.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) re.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) + assert re._updated is True re = Recall(average=False) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates re.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long())) re.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long())) + assert re._updated is True with pytest.raises(ValueError): # incompatible shapes between two updates 
re.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) re.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) + assert re._updated is True @pytest.mark.parametrize("average", [False, True]) def test_multiclass_input(average): re = Recall(average=average) + assert re._updated is False def _test(y_pred, y, batch_size): re.reset() + assert re._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -162,6 +189,7 @@ def _test(y_pred, y, batch_size): np_y = y.numpy().ravel() assert re._type == "multiclass" + assert re._updated is True assert isinstance(re.compute(), float if average else torch.Tensor) re_compute = re.compute() if average else re.compute().numpy() sk_average_parameter = "macro" if average else None @@ -204,23 +232,28 @@ def get_test_cases(): def test_multilabel_wrong_inputs(): re = Recall(average=True, is_multilabel=True) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes re.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)).long())) + assert re._updated is False with pytest.raises(ValueError): # incompatible y_pred re.update((torch.rand(10, 5), torch.randint(0, 2, size=(10, 5)).long())) + assert re._updated is False with pytest.raises(ValueError): # incompatible y re.update((torch.randint(0, 5, size=(10, 5, 6)), torch.rand(10))) + assert re._updated is False with pytest.raises(ValueError): # incompatible shapes between two updates re.update((torch.randint(0, 2, size=(20, 5)), torch.randint(0, 2, size=(20, 5)).long())) re.update((torch.randint(0, 2, size=(20, 6)), torch.randint(0, 2, size=(20, 6)).long())) + assert re._updated is True def to_numpy_multilabel(y): @@ -235,9 +268,12 @@ def to_numpy_multilabel(y): def test_multilabel_input(average): re = Recall(average=average, is_multilabel=True) + assert re._updated is False def _test(y_pred, y, batch_size): re.reset() + assert re._updated is False + if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): @@ -250,6 +286,7 @@ def _test(y_pred, y, batch_size): np_y = to_numpy_multilabel(y) assert re._type == "multilabel" + assert re._updated is True re_compute = re.compute() if average else re.compute().mean().item() with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UndefinedMetricWarning) @@ -257,9 +294,15 @@ def _test(y_pred, y, batch_size): re1 = Recall(is_multilabel=True, average=True) re2 = Recall(is_multilabel=True, average=False) + assert re1._updated is False + assert re2._updated is False re1.update((y_pred, y)) re2.update((y_pred, y)) + assert re1._updated is True + assert re2._updated is True assert re1.compute() == pytest.approx(re2.compute().mean().item()) + assert re1._updated is True + assert re2._updated is True def get_test_cases(): @@ -298,10 +341,12 @@ def test_incorrect_type(): def _test(average): re = Recall(average=average) + assert re._updated is False y_pred = torch.softmax(torch.rand(4, 4), dim=1) y = torch.ones(4).long() re.update((y_pred, y)) + assert re._updated is True y_pred = torch.zeros(4,) y = torch.ones(4).long() @@ -309,15 +354,21 @@ def _test(average): with pytest.raises(RuntimeError): re.update((y_pred, y)) + assert re._updated is True + _test(average=True) _test(average=False) re1 = Recall(is_multilabel=True, average=True) re2 = Recall(is_multilabel=True, average=False) + assert re1._updated is False + assert re2._updated is False y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y 
= torch.randint(0, 2, size=(10, 4, 20, 23)).long() re1.update((y_pred, y)) re2.update((y_pred, y)) + assert re1._updated is True + assert re2._updated is True assert re1.compute() == pytest.approx(re2.compute().mean().item()) @@ -325,12 +376,16 @@ def test_incorrect_y_classes(): def _test(average): re = Recall(average=average) + assert re._updated is False + y_pred = torch.randint(0, 2, size=(10, 4)).float() y = torch.randint(4, 5, size=(10,)).long() with pytest.raises(ValueError): re.update((y_pred, y)) + assert re._updated is False + _test(average=True) _test(average=False) @@ -361,11 +416,13 @@ def update(engine, i): re = Recall(average=average, device=metric_device) re.attach(engine, "re") + assert re._updated is False data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) assert "re" in engine.state.metrics + assert re._updated is True res = engine.state.metrics["re"] if isinstance(res, torch.Tensor): assert res.device == metric_device @@ -414,11 +471,13 @@ def update(engine, i): re = Recall(average=average, is_multilabel=True, device=metric_device) re.attach(engine, "re") + assert re._updated is False data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) assert "re" in engine.state.metrics + assert re._updated is True res = engine.state.metrics["re"] res2 = re.compute() if isinstance(res, torch.Tensor): @@ -448,10 +507,14 @@ def update(engine, i): re1 = Recall(is_multilabel=True, average=True) re2 = Recall(is_multilabel=True, average=False) + assert re1._updated is False + assert re2._updated is False y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() re1.update((y_pred, y)) re2.update((y_pred, y)) + assert re1._updated is True + assert re2._updated is True assert re1.compute() == pytest.approx(re2.compute().mean().item()) @@ -461,6 +524,7 @@ def _test_distrib_accumulator_device(device): def _test(average, metric_device): re = Recall(average=average, device=metric_device) assert re._device == metric_device + assert re._updated is False # Since the shape of the accumulated amount isn't known before the first update # call, the internal variables aren't tensors on the right device yet. 
@@ -468,6 +532,7 @@ def _test(average, metric_device): y = torch.randint(0, 2, size=(10,)).long() re.update((y_reed, y)) + assert re._updated is True assert ( re._true_positives.device == metric_device ), f"{type(re._true_positives.device)}:{re._true_positives.device} vs {type(metric_device)}:{metric_device}" @@ -489,6 +554,7 @@ def _test_distrib_multilabel_accumulator_device(device): def _test(average, metric_device): re = Recall(is_multilabel=True, average=average, device=metric_device) + assert re._updated is False assert re._device == metric_device assert ( re._true_positives.device == metric_device @@ -501,6 +567,7 @@ def _test(average, metric_device): y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() re.update((y_reed, y)) + assert re._updated is True assert ( re._true_positives.device == metric_device ), f"{type(re._true_positives.device)}:{re._true_positives.device} vs {type(metric_device)}:{metric_device}" From a47d8f7e3413e648214e817d0be42c3b2362dbe2 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 31 May 2021 16:14:41 +0300 Subject: [PATCH 5/5] Add all zeros y_pred test case --- tests/ignite/metrics/test_precision.py | 9 +++++++++ tests/ignite/metrics/test_recall.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 4972ec1085bd..919f4604d9ce 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -105,6 +105,9 @@ def get_test_cases(): # updated batches (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16), (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10,)), torch.randint(0, 2, size=(10,)), 1), + (torch.zeros(size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1), ] return test_cases @@ -219,6 +222,9 @@ def get_test_cases(): # updated batches (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16), (torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1), + (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1), ] return test_cases @@ -325,6 +331,9 @@ def get_test_cases(): # updated batches (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16), (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1), + (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1), ] return test_cases diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 2cc94573ee0d..333c466e86eb 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -105,6 +105,9 @@ def get_test_cases(): # updated batches (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16), (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10,)), torch.randint(0, 2, size=(10,)), 1), + (torch.zeros(size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1), ] return test_cases @@ -219,6 +222,9 @@ def get_test_cases(): # updated batches (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16), (torch.rand(50, 
7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1), + (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1), ] return test_cases @@ -325,6 +331,9 @@ def get_test_cases(): # updated batches (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16), (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16), + # Corner case with all zeros predictions + (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1), + (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1), ] return test_cases
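
Note on the behaviour this series establishes (not part of the patches themselves): PATCH 1/5 replaces the old `self._positives == 0` check in `compute()` with an explicit `_updated` flag that `update()` sets and `reset()` clears, and PATCH 5/5 adds the all-zeros-prediction corner case that motivated it; PATCH 3/5 and 4/5 apply the same change to Recall. A minimal sketch of the resulting behaviour, assuming the public `ignite.metrics.Precision` API exercised in the tests above:

    import torch
    from ignite.exceptions import NotComputableError
    from ignite.metrics import Precision

    precision = Precision()  # default: binary input, average=False

    # Before any update() call the metric is not computable; with this series
    # that is decided by the _updated flag rather than by _positives == 0.
    try:
        precision.compute()
    except NotComputableError:
        print("no update yet -> NotComputableError")

    # Corner case from PATCH 5/5: y_pred is all zeros, so _positives stays 0,
    # but an update did happen, so compute() now returns a zero precision value
    # instead of wrongly raising "must have at least one example".
    y_pred = torch.zeros(size=(10,))
    y = torch.randint(0, 2, size=(10,)).long()
    precision.update((y_pred, y))
    print(precision.compute())  # tensor(0.) with average=False (the default)

The same reasoning applies to Recall: after an update with all-zero predictions, `compute()` returns a value rather than raising, which is exactly what the added "Corner case with all zeros predictions" test cases check.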