diff --git a/test/test_ops.py b/test/test_ops.py index 8f961e37117..a1cb5aa33d6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -69,6 +69,15 @@ def forward(self, a): self.layer(a) +class PoolWrapper(nn.Module): + def __init__(self, pool: nn.Module): + super().__init__() + self.pool = pool + + def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor: + return self.pool(imgs, boxes) + + class RoIOpTester(ABC): dtype = torch.float64 @@ -150,6 +159,14 @@ def _helper_boxes_shape(self, func): boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) ops.roi_pool(a, [boxes], output_size=(2, 2)) + def _helper_jit_boxes_list(self, model): + x = torch.rand(2, 1, 10, 10) + roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t() + rois = [roi, roi] + scriped = torch.jit.script(model) + y = scriped(x, rois) + assert y.shape == (10, 1, 3, 3) + @abstractmethod def fn(*args, **kwargs): pass @@ -210,6 +227,10 @@ def get_slice(k, block): def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_pool) + def test_jit_boxes_list(self): + model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0)) + self._helper_jit_boxes_list(model) + class TestPSRoIPool(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): @@ -450,6 +471,10 @@ def test_qroi_align_multiple_images(self): with pytest.raises(RuntimeError, match="Only one image per batch is allowed"): ops.roi_align(qx, qrois, output_size=5) + def test_jit_boxes_list(self): + model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)) + self._helper_jit_boxes_list(model) + class TestPSRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index afe9e42af16..f331a37da4b 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -82,7 +82,7 @@ def __init__( self.sampling_ratio = sampling_ratio self.aligned = aligned - def forward(self, input: Tensor, rois: Tensor) -> Tensor: + def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor: return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) def __repr__(self) -> str: diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 50dc2f64421..9fd7bd84ee2 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -62,7 +62,7 @@ def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float): self.output_size = output_size self.spatial_scale = spatial_scale - def forward(self, input: Tensor, rois: Tensor) -> Tensor: + def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor: return roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: