From 991dc74e4d48df1bafcc0926678a9ed56a190852 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 11 Mar 2020 15:15:27 -0700 Subject: [PATCH 1/6] add checkout/assert in roi_pool --- torchvision/ops/roi_pool.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index f94373436db..00c69c52452 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -27,6 +27,15 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0): Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ + if isinstance(boxes, list): + for _tensor in boxes: + assert _tensor.size(1) == 4, \ + 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' + elif isinstance(boxes, torch.Tensor): + assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' + else: + assert False, 'boxes shape is not correct' + rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): From 17a7748f00ed51f39d618625103150d44a44510f Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 11 Mar 2020 15:17:31 -0700 Subject: [PATCH 2/6] add checkout/assert in roi_align --- torchvision/ops/roi_align.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 0e8a978aff0..9b772be9cd2 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -35,6 +35,15 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ + if isinstance(boxes, list): + for _tensor in boxes: + assert _tensor.size(1) == 4, \ + 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' + elif isinstance(boxes, torch.Tensor): + assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' + else: + assert False, 'boxes shape is not correct' + rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): From 2edb8546c23e3fa235c2d256388ffaef4a384f24 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 12 Mar 2020 11:09:23 -0700 Subject: [PATCH 3/6] move check_roi_boxes_shape func to ops/_utils.py --- torchvision/ops/_utils.py | 12 ++++++++++++ torchvision/ops/ps_roi_align.py | 3 ++- torchvision/ops/ps_roi_pool.py | 3 ++- torchvision/ops/roi_align.py | 12 ++---------- torchvision/ops/roi_pool.py | 12 ++---------- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 269abaf7db3..714022f0421 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes): ids = _cat(temp, dim=0) rois = torch.cat([ids, concat_boxes], dim=1) return rois + + +def check_roi_boxes_shape(boxes): + if isinstance(boxes, list): + for _tensor in boxes: + assert _tensor.size(1) == 4, \ + 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' + elif isinstance(boxes, torch.Tensor): + assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' + else: + assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]' + return diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index 4d265096a67..c0c761b72cc 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -4,7 +4,7 @@ from torch.nn.modules.utils import _pair from torch.jit.annotations import List -from ._utils import convert_boxes_to_roi_format +from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): @@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1 Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ + check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index a033d15fff6..710f2cb0195 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -4,7 +4,7 @@ from torch.nn.modules.utils import _pair from torch.jit.annotations import List -from ._utils import convert_boxes_to_roi_format +from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0): @@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0): Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ + check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 9b772be9cd2..14224d8a83e 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -4,7 +4,7 @@ from torch.nn.modules.utils import _pair from torch.jit.annotations import List, BroadcastingList2 -from ._utils import convert_boxes_to_roi_format +from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False): @@ -35,15 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ - if isinstance(boxes, list): - for _tensor in boxes: - assert _tensor.size(1) == 4, \ - 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' - elif isinstance(boxes, torch.Tensor): - assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' - else: - assert False, 'boxes shape is not correct' - + check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 00c69c52452..10232f16b4a 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -4,7 +4,7 @@ from torch.nn.modules.utils import _pair from torch.jit.annotations import List, BroadcastingList2 -from ._utils import convert_boxes_to_roi_format +from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape def roi_pool(input, boxes, output_size, spatial_scale=1.0): @@ -27,15 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0): Returns: output (Tensor[K, C, output_size[0], output_size[1]]) """ - if isinstance(boxes, list): - for _tensor in boxes: - assert _tensor.size(1) == 4, \ - 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' - elif isinstance(boxes, torch.Tensor): - assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' - else: - assert False, 'boxes shape is not correct' - + check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): From 5e8d3f44c11f1b7ce9136629cca7b0868fef998e Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 12 Mar 2020 11:41:11 -0700 Subject: [PATCH 4/6] add tests --- test/test_ops.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 92782cf4400..c98c8aace3b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -51,6 +51,9 @@ def _test_forward(self, device, contiguous): def _test_backward(self, device, contiguous): pass + def test_boxes_shape(self): + self._test_boxes_shape(self) + class RoIOpTester(OpTester): def _test_forward(self, device, contiguous): @@ -91,6 +94,19 @@ def func(z): self.assertTrue(gradcheck(func, (x,))) self.assertTrue(gradcheck(script_func, (x,))) + def _helper_boxes_shape(self, func): + # test boxes as Tensor[N, 5] + with self.assertRaises(AssertionError): + a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) + boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype) + func(a, boxes, output_size=(2, 2)) + + # test boxes as List[Tensor[N, 4]] + with self.assertRaises(AssertionError): + a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) + boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) + ops.roi_pool(a, [boxes], output_size=(2, 2)) + def fn(*args, **kwargs): pass @@ -139,6 +155,9 @@ def get_slice(k, block): y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0] return y + def _test_boxes_shape(self): + self._helper_boxes_shape(ops.roi_pool) + class PSRoIPoolTester(RoIOpTester, unittest.TestCase): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): @@ -183,6 +202,9 @@ def get_slice(k, block): y[roi_idx, c_out, i, j] = t / area return y + def _test_boxes_shape(self): + self._helper_boxes_shape(ops.ps_roi_pool) + def bilinear_interpolate(data, y, x, snap_border=False): height, width = data.shape @@ -266,6 +288,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r out_data[r, channel, i, j] = val return out_data + def _test_boxes_shape(self): + self._helper_boxes_shape(ops.roi_align) + class PSRoIAlignTester(RoIOpTester, unittest.TestCase): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): @@ -317,6 +342,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, out_data[r, c_out, i, j] = val return out_data + def _test_boxes_shape(self): + self._helper_boxes_shape(ops.ps_roi_align) + class NMSTester(unittest.TestCase): def reference_nms(self, boxes, scores, iou_threshold): From 4a11721c973e1d9ec1c2db36bf45f91e90d47fc4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 13 Mar 2020 06:41:12 -0700 Subject: [PATCH 5/6] fix CI --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index c98c8aace3b..c32cb1bbc46 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -52,7 +52,7 @@ def _test_backward(self, device, contiguous): pass def test_boxes_shape(self): - self._test_boxes_shape(self) + self._test_boxes_shape() class RoIOpTester(OpTester): From 2388ab454b6b50ae9f480c2ac7c7ebe683338f7b Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 13 Mar 2020 07:52:40 -0700 Subject: [PATCH 6/6] fix CI --- test/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c32cb1bbc46..71ae9c4dbbc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -51,9 +51,6 @@ def _test_forward(self, device, contiguous): def _test_backward(self, device, contiguous): pass - def test_boxes_shape(self): - self._test_boxes_shape() - class RoIOpTester(OpTester): def _test_forward(self, device, contiguous): @@ -94,6 +91,9 @@ def func(z): self.assertTrue(gradcheck(func, (x,))) self.assertTrue(gradcheck(script_func, (x,))) + def test_boxes_shape(self): + self._test_boxes_shape() + def _helper_boxes_shape(self, func): # test boxes as Tensor[N, 5] with self.assertRaises(AssertionError):