From 492300d368a4fbfd12f3a4d96b716eaa70523c02 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 11 Aug 2022 09:34:33 +0100 Subject: [PATCH 1/3] Fix typing jit issue on RoIPool and RoIAlign --- test/test_ops.py | 17 +++++++++++++++++ torchvision/ops/roi_align.py | 2 +- torchvision/ops/roi_pool.py | 2 +- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 8f961e37117..c14bdb9b4d8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -69,6 +69,15 @@ def forward(self, a): self.layer(a) +class PoolWrapper(torch.nn.Module): + def __init__(self, pool: nn.Module): + super().__init__() + self.pool = pool + + def forward(self, imgs: Tensor, boxes: List[Tensor]): + return self.pool(imgs, boxes) + + class RoIOpTester(ABC): dtype = torch.float64 @@ -210,6 +219,10 @@ def get_slice(k, block): def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_pool) + def test_jit_boxes_list(self): + model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0)) + torch.jit.script(model) + class TestPSRoIPool(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): @@ -450,6 +463,10 @@ def test_qroi_align_multiple_images(self): with pytest.raises(RuntimeError, match="Only one image per batch is allowed"): ops.roi_align(qx, qrois, output_size=5) + def test_jit_boxes_list(self): + model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)) + torch.jit.script(model) + class TestPSRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index afe9e42af16..f331a37da4b 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -82,7 +82,7 @@ def __init__( self.sampling_ratio = sampling_ratio self.aligned = aligned - def forward(self, input: Tensor, rois: Tensor) -> Tensor: + def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor: return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) def __repr__(self) -> str: diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 50dc2f64421..9fd7bd84ee2 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -62,7 +62,7 @@ def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float): self.output_size = output_size self.spatial_scale = spatial_scale - def forward(self, input: Tensor, rois: Tensor) -> Tensor: + def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor: return roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: From fdf79c20f62d1e84c66cc0f4e69e0e7473fcef95 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 11 Aug 2022 11:49:13 +0100 Subject: [PATCH 2/3] Fix nit. --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index c14bdb9b4d8..cf0d18eed47 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -69,7 +69,7 @@ def forward(self, a): self.layer(a) -class PoolWrapper(torch.nn.Module): +class PoolWrapper(nn.Module): def __init__(self, pool: nn.Module): super().__init__() self.pool = pool From 7213479a6836703e7c4630761d6c24b9a17d9831 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 11 Aug 2022 12:14:03 +0100 Subject: [PATCH 3/3] Address code review comments. --- test/test_ops.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index cf0d18eed47..a1cb5aa33d6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -74,7 +74,7 @@ def __init__(self, pool: nn.Module): super().__init__() self.pool = pool - def forward(self, imgs: Tensor, boxes: List[Tensor]): + def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor: return self.pool(imgs, boxes) @@ -159,6 +159,14 @@ def _helper_boxes_shape(self, func): boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) ops.roi_pool(a, [boxes], output_size=(2, 2)) + def _helper_jit_boxes_list(self, model): + x = torch.rand(2, 1, 10, 10) + roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t() + rois = [roi, roi] + scriped = torch.jit.script(model) + y = scriped(x, rois) + assert y.shape == (10, 1, 3, 3) + @abstractmethod def fn(*args, **kwargs): pass @@ -221,7 +229,7 @@ def test_boxes_shape(self): def test_jit_boxes_list(self): model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0)) - torch.jit.script(model) + self._helper_jit_boxes_list(model) class TestPSRoIPool(RoIOpTester): @@ -465,7 +473,7 @@ def test_qroi_align_multiple_images(self): def test_jit_boxes_list(self): model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)) - torch.jit.script(model) + self._helper_jit_boxes_list(model) class TestPSRoIAlign(RoIOpTester):