From e4433edc77bdb4b8f17b5fe65c8de536d47cba62 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 15 Oct 2020 09:24:34 -0700 Subject: [PATCH 1/8] add export for is_floating_point --- torch/onnx/symbolic_opset9.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index eed84e437b2c..ae8cca4b9933 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2194,6 +2194,12 @@ def prim_shape(g, self): return g.op('Shape', self) +def is_floating_point(g, self): + if sym_help._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + @parse_args('v', 'i') def one_hot(g, self, num_classes): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) From 87c56997daa5ad42ccd59fa9dee2d9a10f6159a9 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 15 Oct 2020 09:47:35 -0700 Subject: [PATCH 2/8] add tests --- test/onnx/test_pytorch_onnx_onnxruntime.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 54cfe965cd4f..1988ed464996 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -849,6 +849,32 @@ def test_avgpool_3d_ceil(self): x = torch.randn(20, 16, 50, 44, 31) self.run_test(model, x) + def test_floating_point(self): + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.is_floating_point(): + return x.new_zeros(x.shape[1:]) + return x.new_zeros(x.shape) + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + + @skipIfONNXShapeInference(False) + def test_floating_point_infer_dtype(self): + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x.new_zeros(x.shape[1:]) + return x.new_zeros(x.shape) + return x + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + def test_arithmetic(self): class ArithmeticModule(torch.nn.Module): def forward(self, x): From 5229a1267847da82be14c2b78579c53f5475d98d Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 15 Oct 2020 10:04:15 -0700 Subject: [PATCH 3/8] add a warning if a type cannot be inferred --- torch/onnx/symbolic_helper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 8932955703a7..68d651ffe95b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -229,6 +229,8 @@ def _is_fp(value): return (type == 'torch.float32') or (type == 'torch.float64') or (type == 'torch.float16') else: type = value.type().scalarType() + if type is None: + warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.") return (type == 'Float') or (type == 'Double') or (type == 'Half') return False From 4d3025e651c705d6c1b40ce9a24983f1f8652e2f Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 15 Oct 2020 10:16:31 -0700 Subject: [PATCH 4/8] disable pytorch_shape_inference tests in CI --- scripts/onnx/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh index 8b6fc6c4cf63..5e950c7243b1 100755 --- a/scripts/onnx/test.sh +++ b/scripts/onnx/test.sh @@ -51,6 +51,7 @@ pytest "${args[@]}" \ --ignore "$top_dir/test/onnx/test_custom_ops.py" \ --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \ --ignore "$top_dir/test/onnx/test_utility_funs.py" \ + --ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \ "${test_paths[@]}" # onnxruntime only support py3 From e68990e5f113f67127a793457dd87f598922f6ce Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Fri, 16 Oct 2020 10:22:10 -0700 Subject: [PATCH 5/8] add more tests --- test/onnx/test_pytorch_onnx_onnxruntime.py | 19 ++++++++++++++++--- torch/onnx/symbolic_helper.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 1988ed464996..c7c078d08485 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -868,11 +868,24 @@ def forward(self, x): if x.size(0) > 1: a = x + 2 if a.is_floating_point(): - return x.new_zeros(x.shape[1:]) - return x.new_zeros(x.shape) + return x.new_zeros(x.shape[1:], dtype=torch.long) + return x.new_zeros(x.shape, dtype=torch.long) return x - x = torch.randn(2, 3, 4) + x = torch.randn(2, 3, 4).to(torch.int32) + self.run_test(FloatingPoint(), x) + + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x + 1 + return x + return x + + x = torch.randn(2, 3, 4).to(torch.int32) self.run_test(FloatingPoint(), x) def test_arithmetic(self): diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 68d651ffe95b..2f7fdd05148b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -231,7 +231,7 @@ def _is_fp(value): type = value.type().scalarType() if type is None: warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.") - return (type == 'Float') or (type == 'Double') or (type == 'Half') + return (type == 'Float') or (type == 'Double') or (type == 'Half') or (type is None) return False From 38e9edaf1b831077e04f7867c64f142584e0efa8 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Fri, 16 Oct 2020 13:19:17 -0700 Subject: [PATCH 6/8] add tests --- test/onnx/test_pytorch_onnx_onnxruntime.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index c7c078d08485..92b647580239 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -854,12 +854,25 @@ class FloatingPoint(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): if x.is_floating_point(): - return x.new_zeros(x.shape[1:]) + return x.new_zeros(x.shape) return x.new_zeros(x.shape) x = torch.randn(2, 3, 4) self.run_test(FloatingPoint(), x) + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x + 1 + return x + return x + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + @skipIfONNXShapeInference(False) def test_floating_point_infer_dtype(self): class FloatingPoint(torch.jit.ScriptModule): From 812757aca7eb966f6977e0c1f2d2018f564acc14 Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 22 Oct 2020 17:51:47 -0700 Subject: [PATCH 7/8] fix keypoint rcnn test --- test/onnx/test_pytorch_onnx_onnxruntime.py | 8 ++++---- torch/onnx/symbolic_helper.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 4b56ad1848a0..64d585f58e8e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -867,7 +867,7 @@ def forward(self, x): a = x + 2 if a.is_floating_point(): return x + 1 - return x + return x + 1 return x x = torch.randn(2, 3, 4) @@ -881,11 +881,11 @@ def forward(self, x): if x.size(0) > 1: a = x + 2 if a.is_floating_point(): - return x.new_zeros(x.shape[1:], dtype=torch.long) - return x.new_zeros(x.shape, dtype=torch.long) + return x.new_zeros(x.shape[1:]) + return x.new_zeros(x.shape) return x - x = torch.randn(2, 3, 4).to(torch.int32) + x = torch.randn(2, 3, 4) self.run_test(FloatingPoint(), x) class FloatingPoint(torch.jit.ScriptModule): diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 2f7fdd05148b..68d651ffe95b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -231,7 +231,7 @@ def _is_fp(value): type = value.type().scalarType() if type is None: warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.") - return (type == 'Float') or (type == 'Double') or (type == 'Half') or (type is None) + return (type == 'Float') or (type == 'Double') or (type == 'Half') return False From 2893ae3655f5d71212df3baf0a7cc346fd38abdf Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Fri, 23 Oct 2020 09:48:17 -0700 Subject: [PATCH 8/8] update test --- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 64d585f58e8e..0ada7abc3efa 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -849,6 +849,7 @@ def test_avgpool_3d_ceil(self): x = torch.randn(20, 16, 50, 44, 31) self.run_test(model, x) + @skipIfUnsupportedMinOpsetVersion(9) def test_floating_point(self): class FloatingPoint(torch.jit.ScriptModule): @torch.jit.script_method @@ -873,6 +874,7 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(FloatingPoint(), x) + @skipIfUnsupportedMinOpsetVersion(9) @skipIfONNXShapeInference(False) def test_floating_point_infer_dtype(self): class FloatingPoint(torch.jit.ScriptModule):