Skip to content

Commit

Permalink
[ONNX] Fix circular padding to support dynamic axes (#95647)
Browse files Browse the repository at this point in the history
This commit fixes a bug where the ONNX exporter for circular padding queried the input tensor shape in order to get the correct 'end' index for a slice node. This doesn't work when the axis in question is has dynamic size. The commit fixes this by setting the 'end' index to INT_MAX, which is the recommended way of slicing to the end of a dimension with unknown size per ONNX spec.

See https://onnx.ai/onnx/operators/onnx__Slice.html

Also adds a regression test.

Pull Request resolved: #95647
Approved by: https://github.com/BowenBao
  • Loading branch information
ilyasher authored and pytorchmergebot committed Mar 10, 2023
1 parent faa4cb2 commit 6154be1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
15 changes: 15 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7585,6 +7585,21 @@ def forward(self, x):
x = torch.randn(2, 3, 6)
self.run_test(PadModel(), (x))

@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_circular_dynamic_axes(self):
class PadModel(torch.nn.Module):
def forward(self, x):
out = torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")
return out

x = torch.randn(4, 3, 5, 6)
self.run_test(
PadModel(),
x,
input_names=["input_1"],
dynamic_axes={"input_1": [0, 1, 2, 3]},
)

@skipIfUnsupportedMaxOpsetVersion(10)
@skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script
def test_unsupported_pad(self):
Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,12 +1915,10 @@ def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
for idx in range(ndim):
pad_r = padding[-(2 * idx + 1)]
pad_l = padding[-(2 * idx + 2)]
# get size for targeting the last idx, as Slice don't take start=[-1], end=[-1]
size = symbolic_helper._get_tensor_sizes(input)
tensors = []
if pad_l > 0:
left = symbolic_helper._slice_helper(
g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[size[2 + idx]]
g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
)
tensors.append(left)

Expand Down

0 comments on commit 6154be1

Please sign in to comment.