From 3f3a0af451c11c32345c26d5a9c039ac41e22bc3 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 1 Feb 2021 17:22:11 -0800 Subject: [PATCH] [ONNX] Update unsafe_chunk() method to support new version 13 of Split operator. (#51415) * def unsafe_chunk() support and test in ops13. * Use _unsqueeze_helper instead of Unsqueeze operator * Cast the splits into long. * Change the test to a fixed dimension. * Update test_pytorch_onnx_onnxruntime.py * Disable test_loop_with_list for opset 13. Co-authored-by: Negin Raoof [ghstack-poisoned] --- test/onnx/test_pytorch_onnx_onnxruntime.py | 12 ++++++++++ torch/onnx/symbolic_opset13.py | 27 ++++++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 68bff7138193..60a9a463a913 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -6521,6 +6521,18 @@ def forward(self, input_data, prev_state): empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) self.run_test(model, (random_data, empty_tensor)) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() + def test_unsafe_chunk(self): + class ChunkModel(torch.nn.Module): + def forward(self, x): + return torch.unsafe_chunk(x, 3, dim=1) + + model = ChunkModel() + model.eval() + x = torch.randn(1, 18) + self.run_test(model, x, input_names=['x']) + def make_test(name, base, layer, bidirectional, initial_state, variable_length, dropout, diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index b60b6927c87c..facd7b6de39f 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -52,8 +52,8 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
if sym_help._is_packed_list(split_size_or_sizes) and \ len(sym_help._unpack_list(split_size_or_sizes)) == _outputs: - split_sizes = [g.op("Unsqueeze", v, g.op("Constant", value_t=torch.tensor([0]))) - for v in sym_help._unpack_list(split_size_or_sizes)] + split_sizes = [sym_help._unsqueeze_helper(g, v, [0]) for v in sym_help._unpack_list(split_size_or_sizes)] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) res = [] @@ -153,3 +153,26 @@ def reduce_dim(g, self, dim, keepdim, dtype): return reduce sum = _reduce_with_dtype('ReduceSum', 'sum') + +@parse_args('v', 'i', 'i', 'i') +def unsafe_chunk(g, self, chunks, dim, _outputs=None): + if _outputs is None: + return g.op("SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, keepdims_i=0) + + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return _unimplemented('unsafe_chunk', 'unknown dimension size') + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + + # TODO: So far we don't have a module using this method. We'll keep + # this as a constant unless we see a request of dynamics in any + # user's modules. + splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)