Skip to content

Commit

Permalink
[ONNX] Fix circular padding to support dynamic axes
Browse files Browse the repository at this point in the history
This commit fixes a bug where the ONNX exporter for circular padding
queried the input tensor shape in order to get the correct 'end' index
for a slice node. This doesn't work when the axis in question is
has dynamic size. The commit fixes this by setting the 'end' index to INT_MAX,
which is the recommended way of slicing to the end of a dimension
with unknown size per ONNX spec.
  • Loading branch information
ilyasher authored and pytorchmergebot committed Feb 28, 2023
1 parent 71ad100 commit d72682b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
15 changes: 15 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7557,6 +7557,21 @@ def forward(self, x):
x = torch.randn(2, 3, 6)
self.run_test(PadModel(), (x))

@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_circular_dynamic_axes(self):
class PadModel(torch.nn.Module):
def forward(self, x):
out = torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")
return out

x = torch.randn(4, 3, 5, 6)
self.run_test(
PadModel(),
x,
input_names=["input_1"],
dynamic_axes={"input_1": [0, 1, 2, 3]}
)

@skipIfUnsupportedMaxOpsetVersion(10)
@skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script
def test_unsupported_pad(self):
Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1912,12 +1912,10 @@ def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
for idx in range(ndim):
pad_r = padding[-(2 * idx + 1)]
pad_l = padding[-(2 * idx + 2)]
# get size for targeting the last idx, as Slice don't take start=[-1], end=[-1]
size = symbolic_helper._get_tensor_sizes(input)
tensors = []
if pad_l > 0:
left = symbolic_helper._slice_helper(
g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[size[2 + idx]]
g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
)
tensors.append(left)

Expand Down

0 comments on commit d72682b

Please sign in to comment.