[ONNX] Update unsafe_chunk() method to support new version 13 of Split operator. (#51415) #51524

Closed · wants to merge 3 commits
12 changes: 12 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -6521,6 +6521,18 @@ def forward(self, input_data, prev_state):
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        self.run_test(model, (random_data, empty_tensor))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_unsafe_chunk(self):
        class ChunkModel(torch.nn.Module):
            def forward(self, x):
                return torch.unsafe_chunk(x, 3, dim=1)

        model = ChunkModel()
        model.eval()
        x = torch.randn(1, 18)
        self.run_test(model, x, input_names=['x'])


def make_test(name, base, layer, bidirectional, initial_state,
              variable_length, dropout,
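For context, a minimal eager-mode sketch (mine, not part of the diff) of the behavior the new test checks against ONNX Runtime: torch.unsafe_chunk on a (1, 18) input with 3 chunks along dim=1 returns three (1, 6) tensors.

# Sketch of the eager behavior the test compares the exported graph against
# (illustrative, not part of the PR).
import torch

x = torch.randn(1, 18)
chunks = torch.unsafe_chunk(x, 3, dim=1)
assert len(chunks) == 3
assert all(c.shape == (1, 6) for c in chunks)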
27 changes: 25 additions & 2 deletions torch/onnx/symbolic_opset13.py
@@ -52,8 +52,8 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if sym_help._is_packed_list(split_size_or_sizes) and \
                len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
-            split_sizes = [g.op("Unsqueeze", v, g.op("Constant", value_t=torch.tensor([0])))
-                           for v in sym_help._unpack_list(split_size_or_sizes)]
+            split_sizes = [sym_help._unsqueeze_helper(g, v, [0]) for v in sym_help._unpack_list(split_size_or_sizes)]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
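Aside (my reading, not part of the diff): the change above replaces an inline Unsqueeze construction with sym_help._unsqueeze_helper, which centralizes the opset-13 pattern of passing the unsqueeze axes as a constant tensor input rather than an attribute. A rough sketch of that pattern, with a hypothetical helper name:

# Illustrative sketch of the opset-13 Unsqueeze pattern the helper wraps;
# _unsqueeze_opset13 is a made-up name, not the torch.onnx implementation.
import torch

def _unsqueeze_opset13(g, input, axes):
    # In opset 13, Unsqueeze takes its axes as a tensor input instead of an attribute.
    axes_const = g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long))
    return g.op("Unsqueeze", input, axes_const)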
@@ -153,3 +153,26 @@ def reduce_dim(g, self, dim, keepdim, dtype):
    return reduce

sum = _reduce_with_dtype('ReduceSum', 'sum')

@parse_args('v', 'i', 'i', 'i')
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op("SplitToSequence",
                    self,
                    g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
                    axis_i=dim, keepdims_i=0)

    size = sym_help._get_tensor_dim_size(self, dim)
    if size is None:
        return _unimplemented('unsafe_chunk', 'unknown dimension size')
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep the
    # split sizes as a constant unless we see a request for dynamic splits
    # from users' modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
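As a quick check of the splitting arithmetic above (my own sketch, not part of the PR): for the test case, size 18 with chunks=3 gives split_size = (18 + 3 - 1) // 3 = 6 and splits = [6, 6, 6]; an uneven case such as size 7 with chunks=3 gives split_size = 3 and splits = [3, 3, 1].

# Standalone sketch of the split-size computation used by unsafe_chunk above
# (illustrative; chunk_splits is not a real torch.onnx helper).
def chunk_splits(size, chunks):
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return splits

print(chunk_splits(18, 3))  # [6, 6, 6]
print(chunk_splits(7, 3))   # [3, 3, 1]

Exporting the test's ChunkModel with torch.onnx.export(model, x, "unsafe_chunk.onnx", opset_version=13) (file name illustrative) goes through this symbolic and emits a Split node whose split sizes are supplied as a tensor input, as required by version 13 of the Split operator.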