From ba3d35fc3ba3c34a45eb98d3b0cca9920ff94855 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Fri, 28 Aug 2020 20:54:33 -0700 Subject: [PATCH 1/2] update len symbolic --- test/onnx/test_pytorch_onnx_onnxruntime.py | 9 +++++++++ torch/onnx/symbolic_opset11.py | 4 +++- torch/onnx/symbolic_opset9.py | 4 ++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 5959e9c7a5b5b..f5a79b3e2d394 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2576,6 +2576,15 @@ def forward(self, input): self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}}, test_with_inputs=(torch.randn(5, 5),)) + def test_len_list(self): + class LenListModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + return torch.ones(len(input.shape)) + + x = torch.randn(4, 5) + self.run_test(LenListModel(), x) + @skipIfUnsupportedMinOpsetVersion(11) def test_unbind_dynamic(self): class UnbindModel(torch.jit.ScriptModule): diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index b9ea02b31b11d..e3a4b9b40a40f 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -273,7 +273,9 @@ def masked_scatter(g, self, mask, source): def _len(g, self): - return g.op("SequenceLength", self) + if self.type().isSubtypeOf(torch._C.ListType.ofTensors()): + return g.op("SequenceLength", self) + return g.op("Size", self) def __getitem_(g, self, i): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 1ec83634332e8..d19c35ea20e0f 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -619,6 +619,10 @@ def floor(g, input): return g.op("Floor", input) +def _len(g, self): + return g.op("Size", self) + + @parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] From 596a73c0d0cca652e7f0648cd2968d8ff97b4780 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Sun, 30 Aug 2020 22:40:11 -0700 Subject: [PATCH 2/2] update test for len operator --- test/onnx/test_pytorch_onnx_onnxruntime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index f5a79b3e2d394..ebeac73430561 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2576,6 +2576,7 @@ def forward(self, input): self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}}, test_with_inputs=(torch.randn(5, 5),)) + @skipIfUnsupportedMinOpsetVersion(9) def test_len_list(self): class LenListModel(torch.jit.ScriptModule): @torch.jit.script_method