From 249c13caa54a8c89e084232a5781a2cdd52961bb Mon Sep 17 00:00:00 2001 From: David Date: Fri, 11 Dec 2020 11:29:07 -0800 Subject: [PATCH 1/2] Adapt to new torch export API for dictionary --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- test/test_onnx.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 664afba9323..79f1353d8da 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -157,7 +157,7 @@ jobs: # need to install torchvision dependencies due to transitive imports pip install --user --progress-bar off --editable . pip install --user onnx - pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 + pip install --user onnxruntime python test/test_onnx.py binary_linux_wheel: diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index d84272fec61..809bf9175a4 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -157,7 +157,7 @@ jobs: # need to install torchvision dependencies due to transitive imports pip install --user --progress-bar off --editable . pip install --user onnx - pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 + pip install --user onnxruntime python test/test_onnx.py binary_linux_wheel: diff --git a/test/test_onnx.py b/test/test_onnx.py index 54a4e385a8d..20897ef22bb 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -33,8 +33,12 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta model.eval() onnx_io = io.BytesIO() + if isinstance(inputs_list[0][-1], dict): + torch_onnx_input = inputs_list[0] + ({},) + else: + torch_onnx_input = inputs_list[0] # export to onnx with the first input - torch.onnx.export(model, inputs_list[0], onnx_io, + torch.onnx.export(model, torch_onnx_input, onnx_io, do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) # validate the exported model with onnx runtime From 91484cdddac07077c3128903c97446b1cb0661b5 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 11 Dec 2020 11:29:07 -0800 Subject: [PATCH 2/2] Adapt to new torch export API for dictionary --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- test/test_onnx.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 664afba9323..79f1353d8da 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -157,7 +157,7 @@ jobs: # need to install torchvision dependencies due to transitive imports pip install --user --progress-bar off --editable . pip install --user onnx - pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 + pip install --user onnxruntime python test/test_onnx.py binary_linux_wheel: diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index d84272fec61..809bf9175a4 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -157,7 +157,7 @@ jobs: # need to install torchvision dependencies due to transitive imports pip install --user --progress-bar off --editable . pip install --user onnx - pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 + pip install --user onnxruntime python test/test_onnx.py binary_linux_wheel: diff --git a/test/test_onnx.py b/test/test_onnx.py index 54a4e385a8d..5bb8eba2530 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -33,8 +33,12 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta model.eval() onnx_io = io.BytesIO() + if isinstance(inputs_list[0][-1], dict): + torch_onnx_input = inputs_list[0] + ({},) + else: + torch_onnx_input = inputs_list[0] # export to onnx with the first input - torch.onnx.export(model, inputs_list[0], onnx_io, + torch.onnx.export(model, torch_onnx_input, onnx_io, do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) # validate the exported model with onnx runtime