diff --git a/test/test_ops.py b/test/test_ops.py index d100714ad73..6b35b4f0091 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,13 +2,13 @@ import os from abc import ABC, abstractmethod from functools import lru_cache -from typing import Tuple +from typing import Callable, List, Tuple import numpy as np import pytest import torch import torch.fx -from common_utils import needs_cuda, cpu_and_gpu, assert_equal +from common_utils import assert_equal, cpu_and_gpu, needs_cuda from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -1101,114 +1101,149 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) -class TestBoxArea: - def test_box_area(self): - def area_check(box, expected, tolerance=1e-4): - out = ops.box_area(box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) - - # Check for int boxes - for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) - expected = torch.tensor([10000, 0]) - area_check(box_tensor, expected) - - # Check for float32 and float64 boxes - for dtype in [torch.float32, torch.float64]: - box_tensor = torch.tensor( - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - dtype=dtype, - ) - expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) - area_check(box_tensor, expected, tolerance=0.05) - - # Check for float16 box - box_tensor = torch.tensor( - [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], - dtype=torch.float16, - ) - expected = torch.tensor([605113.875, 600495.1875, 592247.25]) - area_check(box_tensor, expected) - - def test_box_area_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) - TOLERANCE = 1e-3 - expected = ops.box_area(box_tensor) - scripted_fn = torch.jit.script(ops.box_area) - scripted_area = scripted_fn(box_tensor) - torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=TOLERANCE) +class BoxTestBase(ABC): + @abstractmethod + def _target_fn(self) -> Tuple[bool, Callable]: + pass + def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor: + is_binary_fn = self._target_fn()[0] + target_fn = self._target_fn()[1] + box_operation = torch.jit.script(target_fn) if run_as_script else target_fn + return box_operation(box, box) if is_binary_fn else box_operation(box) -class TestBoxIou: - def test_iou(self): - def iou_check(box, expected, tolerance=1e-4): - out = ops.box_iou(box, box) + def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: + def assert_close(box: Tensor, expected: Tensor, tolerance): + out = self._perform_box_operation(box) torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) - # Check for int boxes - for dtype in [torch.int16, torch.int32, torch.int64]: - box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) - expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) - iou_check(box, expected) + for dtype in dtypes: + actual_box = torch.tensor(test_input, dtype=dtype) + expected_box = torch.tensor(expected) + assert_close(actual_box, expected_box, tolerance) - # Check for float boxes - for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor( - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - dtype=dtype, - ) - expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) - iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) + def _run_jit_test(self, test_input: List) -> None: + box_tensor = torch.tensor(test_input, dtype=torch.float) + expected = self._perform_box_operation(box_tensor, True) + scripted_area = self._perform_box_operation(box_tensor, True) + torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) - def test_iou_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - TOLERANCE = 1e-3 - expected = ops.box_iou(box_tensor, box_tensor) - scripted_fn = torch.jit.script(ops.box_iou) - scripted_iou = scripted_fn(box_tensor, box_tensor) - torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE) +class TestBoxArea(BoxTestBase): + def _target_fn(self) -> Tuple[bool, Callable]: + return (False, ops.box_area) -class TestGenBoxIou: - def test_gen_iou(self): - def gen_iou_check(box, expected, tolerance=1e-4): - out = ops.generalized_box_iou(box, box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) + def _generate_int_input() -> List[List[int]]: + return [[0, 0, 100, 100], [0, 0, 0, 0]] - # Check for int boxes - for dtype in [torch.int16, torch.int32, torch.int64]: - box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) - expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) - gen_iou_check(box, expected) + def _generate_int_expected() -> List[int]: + return [10000, 0] - # Check for float boxes - for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor( - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - dtype=dtype, - ) - expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) - gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) + def _generate_float_input(index: int) -> List[List[float]]: + return [ + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], + ][index] + + def _generate_float_expected(index: int) -> List[float]: + return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), + [torch.int8, torch.int16, torch.int32, torch.int64], + 1e-4, + _generate_int_expected(), + ), + pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)), + pytest.param(_generate_float_input(1), [torch.float16], 1e-4, _generate_float_expected(1)), + ], + ) + def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: + self._run_test(test_input, dtypes, tolerance, expected) + + def test_box_area_jit(self) -> None: + self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]]) + + +class TestBoxIou(BoxTestBase): + def _target_fn(self) -> Tuple[bool, Callable]: + return (True, ops.box_iou) + + def _generate_int_input() -> List[List[int]]: + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + + def _generate_int_expected() -> List[List[float]]: + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_input() -> List[List[float]]: + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] - def test_giou_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - TOLERANCE = 1e-3 - expected = ops.generalized_box_iou(box_tensor, box_tensor) - scripted_fn = torch.jit.script(ops.generalized_box_iou) - scripted_iou = scripted_fn(box_tensor, box_tensor) - torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE) + def _generate_float_expected() -> List[List[float]]: + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()), + ], + ) + def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: + self._run_test(test_input, dtypes, tolerance, expected) + + def test_iou_jit(self) -> None: + self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestGenBoxIou(BoxTestBase): + def _target_fn(self) -> Tuple[bool, Callable]: + return (True, ops.generalized_box_iou) + + def _generate_int_input() -> List[List[int]]: + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + + def _generate_int_expected() -> List[List[float]]: + return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + + def _generate_float_input() -> List[List[float]]: + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + def _generate_float_expected() -> List[List[float]]: + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), + ], + ) + def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: + self._run_test(test_input, dtypes, tolerance, expected) + + def test_giou_jit(self) -> None: + self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) class TestMasksToBoxes: