[ONNX] Update unsafe_chunk() method to support new version 13 of Split operator. (#51415) (#51524)

Summary:
Pull Request resolved: #51524

* Add unsafe_chunk() support and a test for opset 13 (see the export sketch after this summary).

* Use _unsqueeze_helper instead of the Unsqueeze operator.

* Cast the splits into long.

* Change the test to a fixed dimension.

* Update test_pytorch_onnx_onnxruntime.py.

* Disable test_loop_with_list for opset 13.
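
For context, a minimal export sketch (not part of this commit) showing the pattern the new test exercises; the module and the output file name below are illustrative:

    import torch

    class ChunkModel(torch.nn.Module):
        def forward(self, x):
            # unsafe_chunk splits x into up to 3 pieces along dim=1
            return torch.unsafe_chunk(x, 3, dim=1)

    model = ChunkModel().eval()
    x = torch.randn(1, 18)
    # Export with the opset that carries the updated Split symbolic.
    torch.onnx.export(model, x, "unsafe_chunk.onnx", opset_version=13,
                      input_names=["x"])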

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26203123

Pulled By: SplitInfinity

fbshipit-source-id: b273aeff8339faa0e8e9f1fcfbf877d1b703209f

Co-authored-by: Negin Raoof <neginmr@utexas.edu>
2 people authored and facebook-github-bot committed Feb 4, 2021
1 parent 8ae6b0c commit ba824eb
Showing 2 changed files with 37 additions and 2 deletions.
12 changes: 12 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -6521,6 +6521,18 @@ def forward(self, input_data, prev_state):
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        self.run_test(model, (random_data, empty_tensor))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_unsafe_chunk(self):
        class ChunkModel(torch.nn.Module):
            def forward(self, x):
                return torch.unsafe_chunk(x, 3, dim=1)

        model = ChunkModel()
        model.eval()
        x = torch.randn(1, 18)
        self.run_test(model, x, input_names=['x'])


def make_test(name, base, layer, bidirectional, initial_state,
              variable_length, dropout,
27 changes: 25 additions & 2 deletions torch/onnx/symbolic_opset13.py
@@ -52,8 +52,8 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None):
         # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
         if sym_help._is_packed_list(split_size_or_sizes) and \
                 len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
-            split_sizes = [g.op("Unsqueeze", v, g.op("Constant", value_t=torch.tensor([0])))
-                           for v in sym_help._unpack_list(split_size_or_sizes)]
+            split_sizes = [sym_help._unsqueeze_helper(g, v, [0]) for v in sym_help._unpack_list(split_size_or_sizes)]
+
             start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
             axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
             res = []
@@ -153,3 +153,26 @@ def reduce_dim(g, self, dim, keepdim, dtype):
    return reduce

sum = _reduce_with_dtype('ReduceSum', 'sum')

@parse_args('v', 'i', 'i', 'i')
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op("SplitToSequence",
                    self,
                    g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
                    axis_i=dim, keepdims_i=0)

    size = sym_help._get_tensor_dim_size(self, dim)
    if size is None:
        return _unimplemented('unsafe_chunk', 'unknown dimension size')
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request for dynamic sizes from
    # any user's module.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
