Update torch.onnx.OnnxRegistry usage in DORT tests (microsoft#17009)
Update the usage of torch.onnx.OnnxRegistry, as it is now officially
published in PyTorch: pytorch/pytorch#106140.
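
For context, the now-public torch.onnx.OnnxRegistry.register_op API takes an onnxscript-decorated function as the custom exporter. A minimal self-contained sketch of the pattern, assuming PyTorch 2.1+ and onnxscript; the opset domain and function body below are illustrative, not part of this change:

    import onnxscript
    import torch
    from onnxscript.onnx_opset import opset18 as op

    # An onnxscript function serving as the ONNX decomposition for
    # aten::mul.Tensor (illustrative; the tests define their own).
    custom_opset = onnxscript.values.Opset(domain="custom.aten", version=1)

    @onnxscript.script(custom_opset)
    def custom_aten_mul(x, y):
        return op.Mul(x, y)

    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(
        function=custom_aten_mul, namespace="aten", op_name="mul", overload="Tensor"
    )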

---------

Co-authored-by: Wei-Sheng Chin <wechi@microsoft.com>
titaiwangms and wschin committed Aug 7, 2023
1 parent 4e6ea73 commit 8a335b8
Showing 1 changed file with 23 additions and 19 deletions.
@@ -11,7 +11,6 @@
 from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
 from torch.library import Library
-from torch.onnx._internal.exporter import ExportOptions

 import onnxruntime
 from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
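
The internal import removed above has public counterparts. A one-line sketch of the replacement imports, assuming PyTorch 2.1+ where these names are exported (the tests below simply reference them inline as torch.onnx.ExportOptions):

    # Public replacements for torch.onnx._internal.exporter.ExportOptions:
    from torch.onnx import ExportOptions, OnnxRegistry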
@@ -99,25 +98,25 @@ def test_export_aten_mul_as_onnx_custom_op_and_run_ort(self):
"""
torch._dynamo.reset()

# Create executor of ONNX model.
# We will register a custom exporter for aten.mul.Tensor
# in the following step.
ort_backend = OrtBackend(
ep="CPUExecutionProvider",
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
onnx_exporter_options=ExportOptions(dynamic_shapes=True),
)
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
# exporter.
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see
# the sub-graph representing "aten::mul.Tensor".
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
function=custom_exporter_for_aten_add_Tensor,
namespace="aten",
op_name="mul",
overload="Tensor",
)

# In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`.
ort_backend = OrtBackend(
ep="CPUExecutionProvider",
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
)

# Wrap ORT executor as a Dynamo backend.
aot_ort = aot_autograd(
fw_compiler=ort_backend,
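
For reference, the aot_ort object built above is an ordinary Dynamo backend. A hedged usage sketch of driving it (the compiled function is a placeholder, not from the test):

    # Compile a toy multiplication with the DORT backend assembled above.
    # dynamic=True mirrors the dynamic_shapes=True passed to ExportOptions.
    compiled = torch.compile(lambda x, y: x * y, backend=aot_ort, dynamic=True)
    result = compiled(torch.rand(2, 3), torch.rand(2, 3))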
@@ -159,21 +158,26 @@ def bar_impl(self: torch.Tensor) -> torch.Tensor:

         foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")

+        # Ask exporter to map "torch.ops.foo.bar" to
+        # custom_exporter_for_foo_bar_default.
+        onnx_registry = torch.onnx.OnnxRegistry()
+        onnx_registry.register_op(
+            function=custom_exporter_for_foo_bar_default,
+            namespace="foo",
+            op_name="bar",
+        )
+
         # Create executor of ONNX model.
         ort_backend = OrtBackend(
-            ep="CPUExecutionProvider", session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
+            ep="CPUExecutionProvider",
+            session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
+            onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry),
         )
         # Allow torch.ops.foo.bar.default to be sent to DORT.
         # _support_dict tells Dynamo which ops to send to DORT.
         ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
-        # Ask exporter to map "torch.ops.foo.bar" to
-        # custom_exporter_for_foo_bar_default.
-        # TODO(wechi): Redesign API to expose this better.
-        ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
-            function=custom_exporter_for_foo_bar_default,
-            namespace="foo",
-            op_name="bar",
-        )

         # Wrap ORT executor as a Dynamo backend.
         aot_ort = aot_autograd(
             fw_compiler=ort_backend,
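
For context, the torch.ops.foo.bar operator exercised in this hunk is declared with torch.library earlier in the test. A minimal sketch of that registration pattern; the schema and implementation body are assumptions, only the names mirror the visible context:

    import torch
    from torch.library import Library

    # Declare a custom "foo::bar" operator and give it a composite
    # implementation so Dynamo can trace and dispatch it.
    foo_lib = Library("foo", "DEF")
    bar_name = foo_lib.define("bar(Tensor self) -> Tensor")

    def bar_impl(self: torch.Tensor) -> torch.Tensor:
        # Placeholder math; the real test supplies its own body.
        return self + self

    foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")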
