Introduce torch.onnx.dynamo_export API (#97920)
This is the first phase of the new ONNX exporter API for exporting from TorchDynamo and FX, and represents the beginning of a new era for exporting ONNX from PyTorch.

The API here is a starting point upon which we will layer more capability and expressiveness in subsequent phases. This first phase introduces the following into `torch.onnx`:

```python
def dynamo_export(
    model: torch.nn.Module,
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput:
    ...

class ExportOptions:
    opset_version: Optional[int] = None
    dynamic_shapes: Optional[bool] = None
    logger: Optional[logging.Logger] = None

class ExportOutputSerializer(Protocol):
    def serialize(
        self,
        export_output: ExportOutput,
        destination: io.BufferedIOBase,
    ) -> None:
        ...

class ExportOutput:
    model_proto: onnx.ModelProto

    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        serializer: Optional[ExportOutputSerializer] = None,
    ) -> None:
        ...
```
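To make the surface concrete, here is a minimal usage sketch of the API above. The `AddRelu` module, the logger name, the output file names, and `ProtoBytesSerializer` are hypothetical illustrations, not part of this PR; the serializer merely satisfies the `ExportOutputSerializer` protocol:

```python
import io
import logging

import torch


class AddRelu(torch.nn.Module):  # hypothetical toy model
    def forward(self, x):
        return (x + 1).relu()


# Export with explicit options; every ExportOptions field is optional.
export_output = torch.onnx.dynamo_export(
    AddRelu(),
    torch.randn(1, 1, 2),
    export_options=torch.onnx.ExportOptions(
        opset_version=17,
        dynamic_shapes=True,
        logger=logging.getLogger("my_export"),  # hypothetical logger name
    ),
)

# The in-memory ONNX model is exposed as a protobuf...
model_proto = export_output.model_proto

# ...and can be saved to a path or to any writable buffer.
export_output.save("add_relu.onnx")  # hypothetical file name

buffer = io.BytesIO()
export_output.save(buffer)


# Any object with a compatible `serialize` method satisfies the
# ExportOutputSerializer protocol; inheriting from it is optional.
class ProtoBytesSerializer:
    def serialize(
        self, export_output: torch.onnx.ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        # Write the raw protobuf bytes, like the default serializer does.
        destination.write(export_output.model_proto.SerializeToString())


export_output.save("add_relu_copy.onnx", serializer=ProtoBytesSerializer())
```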

In addition to the API in the first commit on this PR, there are a few existing experiments for exporting from TorchDynamo and FX to ONNX. This PR rationalizes them behind the new exporter API and adjusts their tests to use it:

- A base `FXGraphModuleExporter` class from which all exporters derive:
  - `DynamoExportExporter`: uses dynamo.export to acquire FX graph
  - `DynamoOptimizeExporter`: uses dynamo.optimize to acquire FX graph
  - `FXSymbolicTraceExporter`: uses FX symbolic tracing

The `dynamo_export` API currently uses `DynamoOptimizeExporter`.
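For background, the sketch below shows roughly how an FX graph can be acquired via TorchDynamo, which is what the Dynamo-based exporters do internally. It uses `torch._dynamo` APIs as they existed around this PR (they are internal and have changed since), so treat it as an illustration rather than this PR's actual implementation:

```python
import torch
import torch._dynamo


def func(x):
    return (x + 1).relu()


example_input = torch.randn(1, 1, 2)

# dynamo.export route: trace `func` into an FX GraphModule plus the
# guards under which the trace is valid.
graph_module, guards = torch._dynamo.export(func, example_input)
graph_module.print_readable()  # inspect the captured FX graph

# dynamo.optimize route: install a "backend" that captures the
# GraphModule handed to it during compilation.
captured = []


def capture_backend(gm: torch.fx.GraphModule, example_inputs):
    captured.append(gm)
    return gm.forward  # run the graph unmodified


torch._dynamo.optimize(capture_backend)(func)(example_input)
assert captured, "backend should have received an FX graph"
```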

### Next Steps (subsequent PRs):

* Combine `DynamoExportExporter` and `DynamoOptimizeExporter` into a single `DynamoExporter`.
* Make it easy to exercise `FXSymbolicTraceExporter` through the same API; `FXSymbolicTraceExporter` will eventually go away entirely once the Dynamo approach works for large models. Until then, we want to keep it around for experimentation and internal use.
* Parameterize (on `ExportOptions`) and consolidate Dynamo exporter tests.
  - This PR intentionally leaves the existing tests unchanged as much as possible except for the necessary plumbing.
* Subsequent API phases:
  - Diagnostics
  - Registry, dispatcher, and Custom Ops
  - Passes
  - Dynamic shapes

Fixes #94774

Pull Request resolved: #97920
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/thiagocrepaldi, https://github.com/shubhambhokare1
abock authored and ZainRizvi committed Apr 19, 2023
1 parent 95289bc commit 851fa7e
Showing 16 changed files with 800 additions and 365 deletions.
**docs/source/onnx.rst** (18 additions, 0 deletions)

```diff
@@ -719,3 +719,21 @@ Classes
     JitScalarType
     torch.onnx.verification.GraphInfo
     torch.onnx.verification.VerificationOptions
+
+Preview: torch.onnx TorchDynamo Exporter
+----------------------------------------
+
+.. warning::
+    The ONNX exporter for TorchDynamo is under active development and is
+    subject to rapid change.
+
+.. autofunction:: torch.onnx.dynamo_export
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    torch.onnx.ExportOptions
+    torch.onnx.ExportOutput
+    torch.onnx.ExportOutputSerializer
```
**test/onnx/dynamo/test_exporter_api.py** (new file, 152 additions)

```python
# Owner(s): ["module: onnx"]
import io
import logging
import unittest

import onnx

import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal.exporter import (
    _DEFAULT_OPSET_VERSION,
    ExportOutputSerializer,
    ProtobufExportOutputSerializer,
    ResolvedExportOptions,
)

from torch.testing._internal.common_utils import TemporaryFileName


class SampleModel(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        z = y.relu()
        return (y, z)


class TestExportOptionsAPI(unittest.TestCase):
    def test_opset_version_default(self):
        options = ResolvedExportOptions(None)
        self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)

    def test_opset_version_explicit(self):
        options = ResolvedExportOptions(ExportOptions(opset_version=3000))
        self.assertEquals(options.opset_version, 3000)

    def test_raise_on_invalid_argument_type(self):
        expected_exception_type = roar.BeartypeException
        with self.assertRaises(expected_exception_type):
            ExportOptions(opset_version="3000")  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            ExportOptions(dynamic_shapes=2)  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            ExportOptions(logger="DEBUG")  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            ResolvedExportOptions(options=12)  # type: ignore[arg-type]

    def test_dynamic_shapes_default(self):
        options = ResolvedExportOptions(None)
        self.assertFalse(options.dynamic_shapes)

    def test_dynamic_shapes_explicit(self):
        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=None))
        self.assertFalse(options.dynamic_shapes)
        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=True))
        self.assertTrue(options.dynamic_shapes)
        options = ResolvedExportOptions(ExportOptions(dynamic_shapes=False))
        self.assertFalse(options.dynamic_shapes)

    def test_logger_default(self):
        options = ResolvedExportOptions(None)
        self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx"))

    def test_logger_explicit(self):
        options = ResolvedExportOptions(ExportOptions(logger=logging.getLogger()))
        self.assertEquals(options.logger, logging.getLogger())
        self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))


class TestDynamoExportAPI(unittest.TestCase):
    def test_default_export(self):
        output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
        self.assertIsInstance(output, ExportOutput)
        self.assertIsInstance(output.model_proto, onnx.ModelProto)

    def test_export_with_options(self):
        self.assertIsInstance(
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    opset_version=17,
                    logger=logging.getLogger(),
                    dynamic_shapes=True,
                ),
            ),
            ExportOutput,
        )

    def test_save_to_file_default_serializer(self):
        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
            onnx.load(path)

    def test_save_to_existing_buffer_default_serializer(self):
        buffer = io.BytesIO()
        dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
        onnx.load(buffer)

    def test_save_to_file_using_specified_serializer(self):
        expected_buffer = "I am not actually ONNX"

        class CustomSerializer(ExportOutputSerializer):
            def serialize(
                self, export_output: ExportOutput, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_buffer.encode())

        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            with open(path, "r") as fp:
                self.assertEquals(fp.read(), expected_buffer)

    def test_save_to_file_using_specified_serializer_without_inheritance(self):
        expected_buffer = "I am not actually ONNX"

        # NOTE: Inheritance from `ExportOutputSerializer` is not required.
        # Because `ExportOutputSerializer` is a Protocol class.
        # `beartype` will not complain.
        class CustomSerializer:
            def serialize(
                self, export_output: ExportOutput, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_buffer.encode())

        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            with open(path, "r") as fp:
                self.assertEquals(fp.read(), expected_buffer)

    def test_raise_on_invalid_save_argument_type(self):
        with self.assertRaises(roar.BeartypeException):
            ExportOutput(torch.nn.Linear(2, 3))  # type: ignore[arg-type]
        export_output = ExportOutput(onnx.ModelProto())
        with self.assertRaises(roar.BeartypeException):
            export_output.save(None)  # type: ignore[arg-type]
        export_output.model_proto


class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
    def test_raise_on_invalid_argument_type(self):
        with self.assertRaises(roar.BeartypeException):
            serializer = ProtobufExportOutputSerializer()
            serializer.serialize(None, None)  # type: ignore[arg-type]


if __name__ == "__main__":
    unittest.main()
```
**test/onnx/test_fx_dynamic_with_onnxruntime.py** (20 additions, 12 deletions)

```diff
@@ -12,31 +12,31 @@
 
 import numpy as np
 
 import onnx.reference
 import onnx_test_common
 
 import onnxruntime  # type: ignore[import]
 
 import torch
-import torchvision
-from torch.onnx._internal import _beartype, diagnostics, fx as fx_onnx
+import torch.onnx
+from torch.onnx._internal import _beartype, diagnostics
 from torch.testing._internal import common_utils
 from torch.types import Number
 from torch.utils import _pytree as pytree
 
 _NumericType = Union[Number, torch.Tensor, np.ndarray]
 _ModelType = Union[torch.nn.Module, Callable]
-_ONNXModelType = Union["onnx.ModelProto", bytes, str, io.BytesIO]
 _InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
 _OutputsType = Sequence[_NumericType]
 
 
 @_beartype.beartype
 def _run_ort(
-    onnx_model: _ONNXModelType, pytorch_inputs: _InputArgsType
+    export_output: torch.onnx.ExportOutput, pytorch_inputs: _InputArgsType
 ) -> _OutputsType:
+    buffer = io.BytesIO()
+    export_output.save(buffer)
     session = onnxruntime.InferenceSession(
-        onnx_model, providers=["CPUExecutionProvider"]
+        buffer.getvalue(), providers=["CPUExecutionProvider"]
     )
     input_names = [ort_input.name for ort_input in session.get_inputs()]
     return session.run(
@@ -81,7 +81,7 @@ def _try_clone_model(model: _ModelType) -> _ModelType:
 
 @_beartype.beartype
 def compare_pytorch_onnx_with_ort(
-    onnx_model: Union["onnx.ModelProto", bytes],
+    export_output: torch.onnx.ExportOutput,
     model_input_args: _InputArgsType,
 ):
     # Inspect the model's signature. It will be used
@@ -101,7 +101,7 @@ def compare_pytorch_onnx_with_ort(
 
     pt_cloned_model = _try_clone_model(model)
     ref_outputs, _ = pytree.tree_flatten(pt_cloned_model(*model_input_args))
-    ort_outputs = _run_ort(onnx_model, bound.args)
+    ort_outputs = _run_ort(export_output, bound.args)
     for ref_output, ort_output in zip(ref_outputs, ort_outputs):
         torch.testing.assert_close(
             ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
@@ -110,13 +110,15 @@
     # Feed args and kwargs into exporter.
     # Note that exporter should flatten kwargs into positional args the exported model;
     # since ONNX doesn't represent kwargs.
-    onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(
+
+    onnx_model = torch.onnx.dynamo_export(
         model,
         *input_args,
-        opset_version=opset_version,
-        use_binary_format=True,
-        enable_dynamic_axes=True,  # export models with dynamic shapes
         **input_kwargs,
+        export_options=torch.onnx.ExportOptions(
+            opset_version=opset_version,
+            dynamic_shapes=True,
+        ),
     )
 
     compare_pytorch_onnx_with_ort(onnx_model, input_args)
@@ -148,7 +150,11 @@ def tearDown(self):
         "typing.Union[float, int, str, bytes, typing.Sequence[float],"
         " typing.Sequence[int], torch.Tensor], as [None, None]:"
     )
+    # When the skip reason above is addressed, annotate this test with
+    # @skipIfNoTorchVision
    def test_shufflenet_v2_dynamic_axes(self):
+        import torchvision
+
         model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
         dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
         test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
@@ -349,6 +355,7 @@ def forward(self, x):
             additional_test_inputs=[(x2,)],
         )
 
+    @unittest.skip("ORT segfault")
     def test_expand_as_fill_seperate_tensor(self):
         class Model(torch.nn.Module):
             def forward(self, x):
@@ -377,6 +384,7 @@ def forward(self, input):
             additional_test_inputs=[(another_x,)],
         )
 
+    @unittest.skip("ORT segfault")
     def test_flatten_dynamic_axes(self):
         class MyModule(torch.nn.Module):
             def forward(self, x):
```
**test/onnx/test_fx_to_onnx.py** (10 additions, 8 deletions)

```diff
@@ -5,25 +5,27 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.onnx._internal import fx as fx_onnx
+from torch.onnx import dynamo_export, ExportOptions
 from torch.testing._internal import common_utils
 
 
 class TestFxToOnnx(pytorch_test_common.ExportTestCase):
     def setUp(self):
         super().setUp()
-        self.opset_version = torch.onnx._constants.ONNX_DEFAULT_OPSET
+        self.export_options = ExportOptions()
 
     def test_simple_function(self):
         def func(x):
             y = x + 1
             z = y.relu()
             return (y, z)
 
-        _ = fx_onnx.export(func, torch.randn(1, 1, 2), opset_version=self.opset_version)
+        _ = dynamo_export(
+            func, torch.randn(1, 1, 2), export_options=self.export_options
+        )
 
     @unittest.skip(
-        "Conv Op is not supported at the time. https://github.com/microsoft/onnx-script/issues/397"
+        "max_pool2d is not supported in ATen Lib: https://github.com/microsoft/onnx-script/issues/585"
     )
     def test_mnist(self):
         class MNISTModel(nn.Module):
@@ -48,7 +50,7 @@ def forward(self, tensor_x: torch.Tensor):
             return output
 
         tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
-        _ = fx_onnx.export(MNISTModel(), tensor_x, opset_version=self.opset_version)
+        _ = dynamo_export(MNISTModel(), tensor_x, export_options=self.export_options)
 
     def test_trace_only_op_with_evaluator(self):
         model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
@@ -64,8 +66,8 @@ def forward(self, input):
                     torch.argmax(input, dim=1, keepdim=True),
                 )
 
-        _ = fx_onnx.export(
-            ArgminArgmaxModel(), model_input, opset_version=self.opset_version
+        _ = dynamo_export(
+            ArgminArgmaxModel(), model_input, export_options=self.export_options
         )
 
     def test_multiple_outputs_op_with_evaluator(self):
@@ -75,7 +77,7 @@ def forward(self, x):
             return torch.sum(values)
 
         x = torch.arange(1.0, 6.0, requires_grad=True)
-        _ = fx_onnx.export(TopKModel(), x, opset_version=self.opset_version)
+        _ = dynamo_export(TopKModel(), x, export_options=self.export_options)
 
 
 if __name__ == "__main__":
```
