Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def func(z):
self.assertTrue(gradcheck(func, (x,)))
self.assertTrue(gradcheck(script_func, (x,)))

def test_boxes_shape(self):
self._test_boxes_shape()

def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))

# test boxes as List[Tensor[N, 4]]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))

def fn(*args, **kwargs):
pass

Expand Down Expand Up @@ -139,6 +155,9 @@ def get_slice(k, block):
y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
return y

def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_pool)


class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
Expand Down Expand Up @@ -183,6 +202,9 @@ def get_slice(k, block):
y[roi_idx, c_out, i, j] = t / area
return y

def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_pool)


def bilinear_interpolate(data, y, x, snap_border=False):
height, width = data.shape
Expand Down Expand Up @@ -266,6 +288,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
out_data[r, channel, i, j] = val
return out_data

def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
Expand Down Expand Up @@ -317,6 +342,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
out_data[r, c_out, i, j] = val
return out_data

def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_align)


class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold):
Expand Down
12 changes: 12 additions & 0 deletions torchvision/ops/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes):
ids = _cat(temp, dim=0)
rois = torch.cat([ids, concat_boxes], dim=1)
return rois


def check_roi_boxes_shape(boxes):
if isinstance(boxes, list):
for _tensor in boxes:
assert _tensor.size(1) == 4, \
'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'
elif isinstance(boxes, torch.Tensor):
assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]'
else:
assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
return
3 changes: 2 additions & 1 deletion torchvision/ops/ps_roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List

from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
Expand Down Expand Up @@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/ops/ps_roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List

from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
Expand All @@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2

from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
Expand Down Expand Up @@ -35,6 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/ops/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2

from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def roi_pool(input, boxes, output_size, spatial_scale=1.0):
Expand All @@ -27,6 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
Expand Down