Add ONNXProgram.__call__ API to run model with ONNX Runtime (#113495)
Currently, the user can use `torch.onnx.dynamo_export` to export a model to ONNX:

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

onnx_program = torch.onnx.dynamo_export(
    Model(),
    torch.randn(1, 1, 2, dtype=torch.float),
)
```

The next step is to instantiate an ONNX Runtime session to execute it:

```python
import onnxruntime  # type: ignore[import]

# Adapt the PyTorch inputs to the flattened, positional format the ONNX model expects.
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(
    torch.randn(1, 1, 2, dtype=torch.float)
)
providers = onnxruntime.get_available_providers()
onnx_model = onnx_program.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

def to_numpy(tensor):
    # ONNX Runtime consumes numpy arrays, so detach and move tensors to CPU.
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )

# Bind each ONNX graph input name to the corresponding numpy array.
onnxruntime_input = {
    k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}

ort_outputs = ort_session.run(None, onnxruntime_input)
```

This PR provides the `ONNXProgram.__call__` method as a facilitator that runs the model with ONNX Runtime under the hood, similar to how `torch.export.ExportedProgram.__call__` allows the underlying `torch.fx.GraphModule` to be executed.
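A minimal sketch of the new API, reusing `onnx_program` and the sample input from the export example above (pinning `CPUExecutionProvider` is illustrative only and assumes that provider is available):

```python
input_x = torch.randn(1, 1, 2, dtype=torch.float)

# ONNXProgram.__call__ adapts the inputs, builds an ONNX Runtime
# InferenceSession, and runs the model in a single step.
ort_outputs = onnx_program(input_x)

# Execution can be customized through ONNXRuntimeOptions, e.g. by
# selecting the execution providers (assumes the CPU provider exists).
options = torch.onnx.ONNXRuntimeOptions(execution_providers=["CPUExecutionProvider"])
ort_outputs = onnx_program(input_x, options=options)
```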
Pull Request resolved: #113495
Approved by: https://github.com/titaiwangms
Thiago Crepaldi authored and pytorchmergebot committed Nov 22, 2023 · 1 parent 044cd56 · commit 3f736c2
Showing 5 changed files with 91 additions and 8 deletions.
docs/source/onnx_dynamo.rst (3 additions, 0 deletions)

@@ -146,6 +146,9 @@ API Reference
 .. autoclass:: torch.onnx.ONNXProgramSerializer
     :members:

+.. autoclass:: torch.onnx.ONNXRuntimeOptions
+    :members:
+
 .. autoclass:: torch.onnx.InvalidExportOptionsError
     :members:
test/onnx/onnx_test_common.py (2 additions, 6 deletions)

@@ -439,15 +439,11 @@ def _compare_pytorch_onnx_with_ort(
     ref_input_args = input_args
     ref_input_kwargs = input_kwargs

-    # Format original model inputs into the format expected by exported ONNX model.
-    onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
-        *input_args, **input_kwargs
-    )
-
     ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
         ref_model(*ref_input_args, **ref_input_kwargs)
     )
-    ort_outputs = run_ort(onnx_program, onnx_format_args)
+
+    ort_outputs = onnx_program(*input_args, **input_kwargs)

     if len(ref_outputs) != len(ort_outputs):
         raise AssertionError(
test/onnx/test_fx_to_onnx_with_onnxruntime.py (15 additions, 0 deletions)

@@ -896,6 +896,21 @@ def create_pytorch_only_extra_kwargs():
             create_pytorch_only_extra_kwargs,
         )

+    def test_execute_model_with___call__(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x + 1.0
+
+        input_x = torch.randn(1, 1, 2, dtype=torch.float)
+        onnx_program = torch.onnx.dynamo_export(
+            Model(),
+            input_x,
+        )
+
+        # The other tests use ONNXProgram.__call__ indirectly and check for output equality
+        # This test aims to ensure ONNXProgram.__call__ API runs successfully despite internal test infra code
+        _ = onnx_program(input_x)
+
     def test_exported_program_as_input(self):
         class Model(torch.nn.Module):
             def forward(self, x):
torch/onnx/__init__.py (3 additions, 0 deletions)

@@ -48,6 +48,7 @@
     ExportOptions,
     ONNXProgram,
     ONNXProgramSerializer,
+    ONNXRuntimeOptions,
     InvalidExportOptionsError,
     OnnxExporterError,
     OnnxRegistry,
@@ -103,6 +104,7 @@
     "ExportOptions",
     "ONNXProgram",
     "ONNXProgramSerializer",
+    "ONNXRuntimeOptions",
     "InvalidExportOptionsError",
     "OnnxExporterError",
     "OnnxRegistry",
@@ -118,6 +120,7 @@
 ExportOptions.__module__ = "torch.onnx"
 ONNXProgram.__module__ = "torch.onnx"
 ONNXProgramSerializer.__module__ = "torch.onnx"
+ONNXRuntimeOptions.__module__ = "torch.onnx"
 dynamo_export.__module__ = "torch.onnx"
 InvalidExportOptionsError.__module__ = "torch.onnx"
 OnnxExporterError.__module__ = "torch.onnx"
torch/onnx/_internal/exporter.py (68 additions, 2 deletions)

@@ -1,5 +1,6 @@
-# necessary to surface onnx.ModelProto through ONNXProgram:
-from __future__ import annotations
+from __future__ import (  # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
+    annotations,
+)

 import abc
@@ -52,6 +53,7 @@
 # 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
 if TYPE_CHECKING:
     import onnx
+    import onnxruntime  # type: ignore[import]
     import onnxscript  # type: ignore[import]
     from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
         registration as torchlib_registry,
@@ -602,6 +604,41 @@ def serialize(
         )


+class ONNXRuntimeOptions:
+    """Options to influence the execution of the ONNX model through ONNX Runtime.
+
+    Attributes:
+        session_options: ONNX Runtime session options.
+        execution_providers: ONNX Runtime execution providers to use during model execution.
+        execution_provider_options: ONNX Runtime execution provider options.
+    """
+
+    session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
+    """ONNX Runtime session options."""
+
+    execution_providers: Optional[
+        Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+    ] = None
+    """ONNX Runtime execution providers to use during model execution."""
+
+    execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
+    """ONNX Runtime execution provider options."""
+
+    @_beartype.beartype
+    def __init__(
+        self,
+        *,
+        session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
+        execution_providers: Optional[
+            Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+        ] = None,
+        execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None,
+    ):
+        self.session_options = session_options
+        self.execution_providers = execution_providers
+        self.execution_provider_options = execution_provider_options
+
+
 class ONNXProgram:
     """An in-memory representation of a PyTorch model that has been exported to ONNX.
@@ -643,6 +680,34 @@ def __init__(
         self._fake_context = fake_context
         self._export_exception = export_exception

+    def __call__(
+        self, *args: Any, options: Optional[ONNXRuntimeOptions] = None, **kwargs: Any
+    ) -> Any:
+        """Runs the ONNX model using ONNX Runtime
+
+        Args:
+            args: The positional inputs to the model.
+            kwargs: The keyword inputs to the model.
+            options: The options to use for running the model with ONNX Runtime.
+
+        Returns:
+            The model output as computed by ONNX Runtime
+        """
+        import onnxruntime  # type: ignore[import]
+
+        onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
+        options = options or ONNXRuntimeOptions()
+        providers = options.execution_providers or onnxruntime.get_available_providers()
+        onnx_model = self.model_proto.SerializeToString()
+        ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
+
+        onnxruntime_input = {
+            k.name: v.numpy(force=True)
+            for k, v in zip(ort_session.get_inputs(), onnx_input)
+        }
+
+        return ort_session.run(None, onnxruntime_input)
+
     @property
     def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
         """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
@@ -1416,6 +1481,7 @@ def common_pre_export_passes(
     "ExportOptions",
     "ONNXProgram",
     "ONNXProgramSerializer",
+    "ONNXRuntimeOptions",
     "InvalidExportOptionsError",
     "OnnxExporterError",
     "OnnxRegistry",
