From 3f736c2d77b5ded0f38a44213883088113c42ff8 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Tue, 21 Nov 2023 19:54:50 +0000
Subject: [PATCH] Add ONNXProgram.__call__ API to run model with ONNX Runtime
 (#113495)

Currently the user can use torch.onnx.dynamo_export to export the model to ONNX.

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

input_x = torch.randn(1, 1, 2, dtype=torch.float)
onnx_program = torch.onnx.dynamo_export(
    Model(),
    input_x,
)
```

The next step would be instantiating an ONNX Runtime session to execute it:

```python
import onnxruntime  # type: ignore[import]

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_x)
providers = onnxruntime.get_available_providers()
onnx_model = onnx_program.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )

onnxruntime_input = {
    k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}

ort_outputs = ort_session.run(None, onnxruntime_input)
```

This PR provides the `ONNXProgram.__call__` method as a facilitator that runs the model
through ONNX Runtime under the hood, similar to how `torch.export.ExportedProgram.__call__`
allows the underlying `torch.fx.GraphModule` to be executed (a usage sketch follows the
patch below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113495
Approved by: https://github.com/titaiwangms
---
 docs/source/onnx_dynamo.rst                   |  3 +
 test/onnx/onnx_test_common.py                 |  8 +--
 test/onnx/test_fx_to_onnx_with_onnxruntime.py | 15 ++++
 torch/onnx/__init__.py                        |  3 +
 torch/onnx/_internal/exporter.py              | 70 ++++++++++++++++++-
 5 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst
index a156c51310c3..09a09bc3a300 100644
--- a/docs/source/onnx_dynamo.rst
+++ b/docs/source/onnx_dynamo.rst
@@ -146,6 +146,9 @@ API Reference
 .. autoclass:: torch.onnx.ONNXProgramSerializer
     :members:
 
+.. autoclass:: torch.onnx.ONNXRuntimeOptions
+    :members:
+
 .. autoclass:: torch.onnx.InvalidExportOptionsError
     :members:
 
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index bcc8aa7a012d..2892a23f520a 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -439,15 +439,11 @@ def _compare_pytorch_onnx_with_ort(
         ref_input_args = input_args
         ref_input_kwargs = input_kwargs
 
-    # Format original model inputs into the format expected by exported ONNX model.
-    onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
-        *input_args, **input_kwargs
-    )
-
     ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
         ref_model(*ref_input_args, **ref_input_kwargs)
     )
-    ort_outputs = run_ort(onnx_program, onnx_format_args)
+
+    ort_outputs = onnx_program(*input_args, **input_kwargs)
 
     if len(ref_outputs) != len(ort_outputs):
         raise AssertionError(
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 728430cba994..26fa6f215bec 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -896,6 +896,21 @@ def create_pytorch_only_extra_kwargs():
             create_pytorch_only_extra_kwargs,
         )
 
+    def test_execute_model_with___call__(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x + 1.0
+
+        input_x = torch.randn(1, 1, 2, dtype=torch.float)
+        onnx_program = torch.onnx.dynamo_export(
+            Model(),
+            input_x,
+        )
+
+        # The other tests use ONNXProgram.__call__ indirectly and check for output equality
+        # This test aims to ensure ONNXProgram.__call__ API runs successfully despite internal test infra code
+        _ = onnx_program(input_x)
+
     def test_exported_program_as_input(self):
         class Model(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index e50dfb33004c..ad3af0984d4d 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -48,6 +48,7 @@
     ExportOptions,
     ONNXProgram,
     ONNXProgramSerializer,
+    ONNXRuntimeOptions,
     InvalidExportOptionsError,
     OnnxExporterError,
     OnnxRegistry,
@@ -103,6 +104,7 @@
     "ExportOptions",
     "ONNXProgram",
     "ONNXProgramSerializer",
+    "ONNXRuntimeOptions",
     "InvalidExportOptionsError",
     "OnnxExporterError",
     "OnnxRegistry",
@@ -118,6 +120,7 @@
 ExportOptions.__module__ = "torch.onnx"
 ONNXProgram.__module__ = "torch.onnx"
 ONNXProgramSerializer.__module__ = "torch.onnx"
+ONNXRuntimeOptions.__module__ = "torch.onnx"
 dynamo_export.__module__ = "torch.onnx"
 InvalidExportOptionsError.__module__ = "torch.onnx"
 OnnxExporterError.__module__ = "torch.onnx"
diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py
index fbda341ee039..807ef52a0483 100644
--- a/torch/onnx/_internal/exporter.py
+++ b/torch/onnx/_internal/exporter.py
@@ -1,5 +1,6 @@
-# necessary to surface onnx.ModelProto through ONNXProgram:
-from __future__ import annotations
+from __future__ import (  # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
+    annotations,
+)
 
 import abc
 
@@ -52,6 +53,7 @@ # 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
 if TYPE_CHECKING:
     import onnx
+    import onnxruntime  # type: ignore[import]
     import onnxscript  # type: ignore[import]
     from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
         registration as torchlib_registry,
     )
@@ -602,6 +604,41 @@ def serialize(
             )
 
 
+class ONNXRuntimeOptions:
+    """Options to influence the execution of the ONNX model through ONNX Runtime.
+
+    Attributes:
+        session_options: ONNX Runtime session options.
+        execution_providers: ONNX Runtime execution providers to use during model execution.
+        execution_provider_options: ONNX Runtime execution provider options.
+ """ + + session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None + """ONNX Runtime session options.""" + + execution_providers: Optional[ + Sequence[Union[str, Tuple[str, Dict[Any, Any]]]] + ] = None + """ONNX Runtime execution providers to use during model execution.""" + + execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None + """ONNX Runtime execution provider options.""" + + @_beartype.beartype + def __init__( + self, + *, + session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None, + execution_providers: Optional[ + Sequence[Union[str, Tuple[str, Dict[Any, Any]]]] + ] = None, + execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None, + ): + self.session_options = session_options + self.execution_providers = execution_providers + self.execution_provider_options = execution_provider_options + + class ONNXProgram: """An in-memory representation of a PyTorch model that has been exported to ONNX. @@ -643,6 +680,34 @@ def __init__( self._fake_context = fake_context self._export_exception = export_exception + def __call__( + self, *args: Any, options: Optional[ONNXRuntimeOptions] = None, **kwargs: Any + ) -> Any: + """Runs the ONNX model using ONNX Runtime + + Args: + args: The positional inputs to the model. + kwargs: The keyword inputs to the model. + options: The options to use for running the model with ONNX Runtime. + + Returns: + The model output as computed by ONNX Runtime + """ + import onnxruntime # type: ignore[import] + + onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs) + options = options or ONNXRuntimeOptions() + providers = options.execution_providers or onnxruntime.get_available_providers() + onnx_model = self.model_proto.SerializeToString() + ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers) + + onnxruntime_input = { + k.name: v.numpy(force=True) + for k, v in zip(ort_session.get_inputs(), onnx_input) + } + + return ort_session.run(None, onnxruntime_input) + @property def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined] """The exported ONNX model as an :py:obj:`onnx.ModelProto`.""" @@ -1416,6 +1481,7 @@ def common_pre_export_passes( "ExportOptions", "ONNXProgram", "ONNXProgramSerializer", + "ONNXRuntimeOptions", "InvalidExportOptionsError", "OnnxExporterError", "OnnxRegistry",