Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def forward(self, a):
self.layer(a)


class PoolWrapper(nn.Module):
def __init__(self, pool: nn.Module):
super().__init__()
self.pool = pool

def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
return self.pool(imgs, boxes)


class RoIOpTester(ABC):
dtype = torch.float64

Expand Down Expand Up @@ -150,6 +159,14 @@ def _helper_boxes_shape(self, func):
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))

def _helper_jit_boxes_list(self, model):
x = torch.rand(2, 1, 10, 10)
roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
rois = [roi, roi]
scriped = torch.jit.script(model)
y = scriped(x, rois)
assert y.shape == (10, 1, 3, 3)

@abstractmethod
def fn(*args, **kwargs):
pass
Expand Down Expand Up @@ -210,6 +227,10 @@ def get_slice(k, block):
def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_pool)

def test_jit_boxes_list(self):
model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
self._helper_jit_boxes_list(model)


class TestPSRoIPool(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
Expand Down Expand Up @@ -450,6 +471,10 @@ def test_qroi_align_multiple_images(self):
with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
ops.roi_align(qx, qrois, output_size=5)

def test_jit_boxes_list(self):
model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
self._helper_jit_boxes_list(model)


class TestPSRoIAlign(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self.sampling_ratio = sampling_ratio
self.aligned = aligned

def forward(self, input: Tensor, rois: Tensor) -> Tensor:
def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
self.output_size = output_size
self.spatial_scale = spatial_scale

def forward(self, input: Tensor, rois: Tensor) -> Tensor:
def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
return roi_pool(input, rois, self.output_size, self.spatial_scale)

def __repr__(self) -> str:
Expand Down