Add ONNXProgram.__call__ API to run model with ONNX Runtime #113495

Closed
Commits (21)

6790d4f (Nov 10, 2023): Add ONNXProgram.__call__ API to run model with ONNX Runtime
Followed by 20 update commits, each titled "Update on "Add ONNXProgram.__call__ API to run model with ONNX Runtime"":
3e6a4c4 (Nov 13); 9ff23aa, dd29573, 941365e, 6efecf2 (Nov 14); 15ef2d3, 32346e0, 04cab4e, e5814a1 (Nov 15); b69c023, 21472bb (Nov 16); 608d53d, 01629cc, 9dcfbee, d6e7801, b27f472, 25aa9cf (Nov 17); f5b9e17 (Nov 20); d5a49b1, 25989b8 (Nov 21)
9 changes: 8 additions & 1 deletion test/onnx/onnx_test_common.py
@@ -324,6 +324,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
atol,
rtol,
has_mutation=has_mutation,
model_type=model_type,
)
# This confirms the exported model accepts different input shapes
# when dynamic shape is enabled.
@@ -347,6 +348,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
atol,
rtol,
has_mutation=has_mutation,
model_type=model_type,
)


@@ -423,6 +425,7 @@ def _compare_pytorch_onnx_with_ort(
atol: Optional[float] = None,
rtol: Optional[float] = None,
has_mutation: bool = False,
model_type: Optional[str] = None,
):
if has_mutation:
ref_model = _try_clone_model(model)
@@ -440,7 +443,11 @@
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
ref_model(*ref_input_args, **ref_input_kwargs)
)
ort_outputs = run_ort(onnx_program, onnx_format_args)

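# Exercise the new ONNXProgram.__call__ API when the input is an ExportedProgram;
# otherwise fall back to the existing run_ort helper.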
if model_type == "torch.export.ExportedProgram":
ort_outputs = onnx_program(*input_args, **input_kwargs)
else:
ort_outputs = run_ort(onnx_program, onnx_format_args)

# When model is a torch.export.ExportedProgram, the number of outputs in the ONNX model can be greater
# than the number of outputs in the original model. This is because the ONNX model may contain
14 changes: 14 additions & 0 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -889,6 +889,20 @@ def create_pytorch_only_extra_kwargs():
create_pytorch_only_extra_kwargs,
)

def test_execute_model_with___call__(self):
class Model(torch.nn.Module):
def forward(self, x):
return x + 1.0

x = torch.randn(1, 1, 2, dtype=torch.float)
onnx_program = torch.onnx.dynamo_export(Model(), x)

# The other tests use ONNXProgram.__call__ indirectly and check for output equality.
# This test aims to ensure the ONNXProgram.__call__ API runs successfully despite
# internal test infra code.
_ = onnx_program(x)

def test_exported_program_as_input(self):
class Model(torch.nn.Module):
def forward(self, x):
32 changes: 32 additions & 0 deletions torch/onnx/_internal/exporter.py
@@ -633,6 +633,38 @@ def __init__(
self._fake_context = fake_context
self._export_exception = export_exception

def __call__(self, *args: Any, options=None, **kwargs: Any) -> Any:
"""Runs the ONNX model using ONNX Runtime

Args:
args: The positional inputs to the model.
kwargs: The keyword inputs to the model.
options: The options to use for running the model with ONNX Runtime.
Collaborator:

Since this is a public API, I think we should be careful with introducing arguments. Should we make options a dataclass?

Collaborator (author):

We didn't use a dataclass for torch.onnx.ExportOptions, but I do agree that having a defined type instead of `Any` would be a more robust solution. I will change `options` to be `class ONNXRuntimeOptions`.

In the near future we probably will add at least some of the following members to it:

sess_options: Sequence[onnxruntime.SessionOptions] | None = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
provider_options: Sequence[dict[Any, Any]] | None = None,

so that onnxruntime.InferenceSession can be instantiated with any customization we need.
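
For illustration, a rough sketch of what such an options class could look like, built from the members listed above (hypothetical; not code from this PR):

# Hypothetical sketch only, not part of this PR. Field names mirror the
# keyword arguments of onnxruntime.InferenceSession.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union

@dataclass
class ONNXRuntimeOptions:
    sess_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
    providers: Optional[Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]] = None
    provider_options: Optional[Sequence[Dict[Any, Any]]] = None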


Returns:
The model output as computed by ONNX Runtime
"""
import onnxruntime # type: ignore[import]

onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
options = options or {}
providers = options.get("providers", onnxruntime.get_available_providers())
onnx_model = self.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

def to_numpy(tensor):
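# Tensors that require grad must be detached before numpy() conversion;
# calling .numpy() directly on such a tensor raises a RuntimeError.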
return (
tensor.detach().cpu().numpy()
if tensor.requires_grad
else tensor.cpu().numpy()
)

onnxruntime_input = {
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}

return ort_session.run(None, onnxruntime_input)

@property
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
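
For reference, a minimal usage sketch of the API added in this diff (the toy Model and input are adapted from the test above; the options dict mirrors what __call__ reads):

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

x = torch.randn(1, 1, 2, dtype=torch.float)
onnx_program = torch.onnx.dynamo_export(Model(), x)

# Run the exported model with ONNX Runtime via the new __call__ API.
ort_outputs = onnx_program(x)

# Pin execution to the CPU provider through the "providers" key read by __call__.
ort_outputs_cpu = onnx_program(x, options={"providers": ["CPUExecutionProvider"]})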