Skip to content
10 changes: 10 additions & 0 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, 2)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])

def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
Expand All @@ -150,6 +155,11 @@ def test_roi_align_aligned(self):
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])

@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
Expand Down
6 changes: 6 additions & 0 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampli
" ONNX forces ROIs to be 1x1 or larger.")
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
rois = g.op("Sub", rois, scale)

# ONNX doesn't support negative sampling_ratio
if sampling_ratio < 0:
warnings.warn("ONNX doesn't support negative sampling ratio,"
"therefore is is set to 0 in order to be exported.")
sampling_ratio = 0
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)

Expand Down