Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNXProgram.__call__ API to run model with ONNX Runtime #113495

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6790d4f
Add ONNXProgram.__call__ API to run model with ONNX Runtime
Nov 10, 2023
3e6a4c4
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 13, 2023
9ff23aa
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 14, 2023
dd29573
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 14, 2023
941365e
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 14, 2023
6efecf2
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 14, 2023
15ef2d3
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 15, 2023
32346e0
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 15, 2023
04cab4e
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 15, 2023
e5814a1
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 15, 2023
b69c023
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 16, 2023
21472bb
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 16, 2023
608d53d
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
01629cc
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
9dcfbee
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
d6e7801
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
b27f472
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
25aa9cf
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 17, 2023
f5b9e17
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 20, 2023
d5a49b1
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 21, 2023
25989b8
Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"
Nov 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/onnx_dynamo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ API Reference
.. autoclass:: torch.onnx.ONNXProgramSerializer
:members:

.. autoclass:: torch.onnx.ONNXRuntimeOptions
:members:

.. autoclass:: torch.onnx.InvalidExportOptionsError
:members:

Expand Down
8 changes: 2 additions & 6 deletions test/onnx/onnx_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,11 @@ def _compare_pytorch_onnx_with_ort(
ref_input_args = input_args
ref_input_kwargs = input_kwargs

# Format original model inputs into the format expected by exported ONNX model.
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
*input_args, **input_kwargs
)

ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
ref_model(*ref_input_args, **ref_input_kwargs)
)
ort_outputs = run_ort(onnx_program, onnx_format_args)

ort_outputs = onnx_program(*input_args, **input_kwargs)

if len(ref_outputs) != len(ort_outputs):
raise AssertionError(
Expand Down
15 changes: 15 additions & 0 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,21 @@ def create_pytorch_only_extra_kwargs():
create_pytorch_only_extra_kwargs,
)

def test_execute_model_with___call__(self):
class Model(torch.nn.Module):
def forward(self, x):
return x + 1.0

input_x = torch.randn(1, 1, 2, dtype=torch.float)
onnx_program = torch.onnx.dynamo_export(
Model(),
input_x,
)

# The other tests use ONNXProgram.__call__ indirectly and check for output equality
# This test aims to ensure ONNXProgram.__call__ API runs successfully despite internal test infra code
_ = onnx_program(input_x)

def test_exported_program_as_input(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down
3 changes: 3 additions & 0 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ExportOptions,
ONNXProgram,
ONNXProgramSerializer,
ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
Expand Down Expand Up @@ -103,6 +104,7 @@
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",
Expand All @@ -118,6 +120,7 @@
ExportOptions.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXProgramSerializer.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
dynamo_export.__module__ = "torch.onnx"
InvalidExportOptionsError.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
Expand Down
70 changes: 68 additions & 2 deletions torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# necessary to surface onnx.ModelProto through ONNXProgram:
from __future__ import annotations
from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
annotations,
)

import abc

Expand Down Expand Up @@ -52,6 +53,7 @@
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
import onnx
import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
registration as torchlib_registry,
Expand Down Expand Up @@ -602,6 +604,41 @@ def serialize(
)


class ONNXRuntimeOptions:
"""Options to influence the execution of the ONNX model through ONNX Runtime.

Attributes:
session_options: ONNX Runtime session options.
execution_providers: ONNX Runtime execution providers to use during model execution.
execution_provider_options: ONNX Runtime execution provider options.
"""

session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
"""ONNX Runtime session options."""

execution_providers: Optional[
Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
] = None
"""ONNX Runtime execution providers to use during model execution."""

execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
"""ONNX Runtime execution provider options."""

@_beartype.beartype
def __init__(
self,
*,
session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
execution_providers: Optional[
Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
] = None,
execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None,
):
self.session_options = session_options
self.execution_providers = execution_providers
self.execution_provider_options = execution_provider_options


class ONNXProgram:
"""An in-memory representation of a PyTorch model that has been exported to ONNX.

Expand Down Expand Up @@ -643,6 +680,34 @@ def __init__(
self._fake_context = fake_context
self._export_exception = export_exception

def __call__(
self, *args: Any, options: Optional[ONNXRuntimeOptions] = None, **kwargs: Any
) -> Any:
"""Runs the ONNX model using ONNX Runtime

Args:
args: The positional inputs to the model.
kwargs: The keyword inputs to the model.
options: The options to use for running the model with ONNX Runtime.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is public api I think we should be careful with introducing arguments. Should we make options a dataclass?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We didn't use dataclass for torch.onnx.ExportOptions, but I do agree that having a defined type instead of Anywould be a more robust solution. I will change options beclass ONNXRuntimeOption`.

In the near future we probably will add at least some of the following members to it:

sess_options: Sequence[onnxruntime.SessionOptions] | None = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
provider_options: Sequence[dict[Any, Any]] | None = None,

so that onnxruntime.InferenceSession can be instantiated with any customization we need

Returns:
The model output as computed by ONNX Runtime
"""
import onnxruntime # type: ignore[import]

onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
options = options or ONNXRuntimeOptions()
providers = options.execution_providers or onnxruntime.get_available_providers()
onnx_model = self.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

onnxruntime_input = {
k.name: v.numpy(force=True)
for k, v in zip(ort_session.get_inputs(), onnx_input)
}

return ort_session.run(None, onnxruntime_input)

@property
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
Expand Down Expand Up @@ -1416,6 +1481,7 @@ def common_pre_export_passes(
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",
Expand Down
Loading