From 6790d4f635ad478cb6850cc64b22a3b73a68e7c0 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Fri, 10 Nov 2023 22:43:22 +0000
Subject: [PATCH] Add ONNXProgram.__call__ API to run model with ONNX Runtime

Currently the user can use torch.onnx.dynamo_export to export a model to ONNX:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


onnx_program = torch.onnx.dynamo_export(
    Model(),
    torch.randn(1, 1, 2, dtype=torch.float),
)
```

The next step would be to instantiate an ONNX Runtime session to execute it:

```python
import onnxruntime  # type: ignore[import]

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(
    torch.randn(1, 1, 2, dtype=torch.float)
)
providers = onnxruntime.get_available_providers()
onnx_model = onnx_program.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)


def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )


onnxruntime_input = {
    k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}
ort_outputs = ort_session.run(None, onnxruntime_input)
```

This PR provides the ONNXProgram.__call__ method as a facilitator that runs the model with ONNX Runtime under the hood, similar to torch.export.ExportedProgram.__call__, which allows the underlying torch.fx.GraphModule to be executed.
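With this change, the ONNX Runtime boilerplate above collapses to a single call. As an illustrative sketch, reusing `onnx_program` from the first snippet (the input tensor shape is just the example's):

```python
# __call__ adapts the PyTorch inputs, builds an onnxruntime.InferenceSession
# and returns the outputs it computes.
ort_outputs = onnx_program(torch.randn(1, 1, 2, dtype=torch.float))
```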
[ghstack-poisoned]
---
 test/onnx/onnx_test_common.py                 |  9 +++++-
 test/onnx/test_fx_to_onnx_with_onnxruntime.py | 14 ++++++++
 torch/onnx/_internal/exporter.py              | 32 +++++++++++++++++++
 3 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index f89a28bd63aab..d4e6f99c1a491 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -324,6 +324,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
         atol,
         rtol,
         has_mutation=has_mutation,
+        model_type=model_type,
     )
     # This confirms the exported model accepts different input shapes
     # when dynamic shape is enabled.
@@ -347,6 +348,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
         atol,
         rtol,
         has_mutation=has_mutation,
+        model_type=model_type,
     )


@@ -423,6 +425,7 @@ def _compare_pytorch_onnx_with_ort(
     atol: Optional[float] = None,
     rtol: Optional[float] = None,
     has_mutation: bool = False,
+    model_type: Optional[str] = None,
 ):
     if has_mutation:
         ref_model = _try_clone_model(model)
@@ -440,7 +443,11 @@ def _compare_pytorch_onnx_with_ort(
     ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
         ref_model(*ref_input_args, **ref_input_kwargs)
     )
-    ort_outputs = run_ort(onnx_program, onnx_format_args)
+
+    if model_type == "torch.export.ExportedProgram":
+        ort_outputs = onnx_program(*input_args, **input_kwargs)
+    else:
+        ort_outputs = run_ort(onnx_program, onnx_format_args)
     # When model is a torch.export.ExportedProgram, the number of outputs in the ONNX model can be greater
     # than the number of outputs in the original model. This is because the ONNX model may contain
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 6cf38591e04e0..13a2df9507072 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -889,6 +889,20 @@ def create_pytorch_only_extra_kwargs():
             create_pytorch_only_extra_kwargs,
         )

+    def test_execute_model_with___call__(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x + 1.0
+
+        onnx_program = torch.onnx.dynamo_export(
+            Model(),
+            torch.randn(1, 1, 2, dtype=torch.float),
+        )
+
+        # The other tests use ONNXProgram.__call__ indirectly and check for output equality.
+        # This test aims to ensure the ONNXProgram.__call__ API runs successfully despite internal test infra code.
+        _ = onnx_program(torch.randn(1, 1, 2, dtype=torch.float))
+
     def test_exported_program_as_input(self):
         class Model(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py
index 58da81a1898e1..d3ce07c998783 100644
--- a/torch/onnx/_internal/exporter.py
+++ b/torch/onnx/_internal/exporter.py
@@ -633,6 +633,38 @@ def __init__(
         self._fake_context = fake_context
         self._export_exception = export_exception

+    def __call__(self, *args: Any, options=None, **kwargs: Any) -> Any:
+        """Runs the ONNX model using ONNX Runtime.
+
+        Args:
+            args: The positional inputs to the model.
+            kwargs: The keyword inputs to the model.
+            options: The options to use for running the model with ONNX Runtime.
+
+        Returns:
+            The model output, as computed by ONNX Runtime.
+        """
+        import onnxruntime  # type: ignore[import]
+
+        onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
+        options = options or {}
+        providers = options.get("providers", onnxruntime.get_available_providers())
+        onnx_model = self.model_proto.SerializeToString()
+        ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
+
+        def to_numpy(tensor):
+            return (
+                tensor.detach().cpu().numpy()
+                if tensor.requires_grad
+                else tensor.cpu().numpy()
+            )
+
+        onnxruntime_input = {
+            k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
+        }
+
+        return ort_session.run(None, onnxruntime_input)
+
     @property
     def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
         """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
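A minimal sketch of the new `options` parameter in use: the `"providers"` key and the example model come from the hunks above, while the choice of `CPUExecutionProvider` is just an illustrative assumption.

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


onnx_program = torch.onnx.dynamo_export(
    Model(),
    torch.randn(1, 1, 2, dtype=torch.float),
)

# Pin ONNX Runtime to the CPU execution provider; when `options` is
# omitted, __call__ falls back to onnxruntime.get_available_providers().
ort_outputs = onnx_program(
    torch.randn(1, 1, 2, dtype=torch.float),
    options={"providers": ["CPUExecutionProvider"]},
)
```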