Skip to content

Commit

Permalink
Introduce torch.onnx.dynamo_export API
Browse files Browse the repository at this point in the history
This is the first phase of the new ONNX exporter API for exporting from
TorchDynamo and FX, and represents the beginning of a new era for exporting
ONNX from PyTorch.

The API here is a starting point upon which we will layer more capability
and expressiveness in subsequent phases. This first phase introduces the
following into `torch.onnx`:

dynamo_export(
    model: torch.nn.Module,
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput

class ExportOptions:
    opset_version: Optional[int] = None
    dynamic_shapes: Optional[bool] = None
    logger: Optional[logging.Logger] = None

class ExportOutputSerializer(Protocol):
    def serialize(
        self,
        export_output: ExportOutput,
        destination: io.BufferedIOBase
    ) -> None:
        ...

class ExportOutput:
    export_options: ExportOptions
    model_proto: onnx.ModelProto

    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        serializer: Optional[ExportOutputSerializer] = None
    ) -> None:
        ...

Co-authored-by: Bowen Bao <bowbao@microsoft.com>
Co-authored-by: Aaron Bockover <abock@microsoft.com>
  • Loading branch information
3 people committed Mar 29, 2023
1 parent 2f86c9b commit 3a01a1a
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 1 deletion.
152 changes: 152 additions & 0 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Owner(s): ["module: onnx"]
import io
import logging
import unittest

import onnx

import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
_resolve_export_options,
ExportOutputSerializer,
ProtobufExportOutputSerializer,
)

from torch.testing._internal.common_utils import TemporaryFileName


class SampleModel(torch.nn.Module):
def forward(self, x):
y = x + 1
z = y.relu()
return (y, z)


class TestExportOptionsAPI(unittest.TestCase):
def test_opset_version_default(self):
options = _resolve_export_options(None)
self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)

def test_opset_version_explicit(self):
options = _resolve_export_options(ExportOptions(opset_version=3000))
self.assertEquals(options.opset_version, 3000)

def test_raise_on_invalid_argument_type(self):
expected_exception_type = roar.BeartypeException
with self.assertRaises(expected_exception_type):
options = ExportOptions(opset_version="3000") # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
options = ExportOptions(dynamic_shapes=2) # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
options = ExportOptions(logger="DEBUG") # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
_resolve_export_options(options=12) # type: ignore[arg-type]

def test_dynamic_shapes_default(self):
options = _resolve_export_options(None)
self.assertIsNone(options.dynamic_shapes)

def test_dynamic_shapes_explicit(self):
options = _resolve_export_options(ExportOptions(dynamic_shapes=None))
self.assertIsNone(options.dynamic_shapes)
options = _resolve_export_options(ExportOptions(dynamic_shapes=True))
self.assertTrue(options.dynamic_shapes)
options = _resolve_export_options(ExportOptions(dynamic_shapes=False))
self.assertFalse(options.dynamic_shapes)

def test_logger_default(self):
options = _resolve_export_options(None)
self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx"))

def test_logger_explicit(self):
options = _resolve_export_options(ExportOptions(logger=logging.getLogger()))
self.assertEquals(options.logger, logging.getLogger())
self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))


class TestDynamoExportAPI(unittest.TestCase):
def test_default_export(self):
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
self.assertIsInstance(output, ExportOutput)
self.assertIsInstance(output.model_proto, onnx.ModelProto)

def test_export_with_options(self):
self.assertIsInstance(
dynamo_export(
SampleModel(),
torch.randn(1, 1, 2),
export_options=ExportOptions(
opset_version=17,
logger=logging.getLogger(),
dynamic_shapes=True,
),
),
ExportOutput,
)

def test_save_to_file_default_serializer(self):
with TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
onnx.load(path)

def test_save_to_existing_buffer_default_serializer(self):
buffer = io.BytesIO()
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
onnx.load(buffer)

def test_save_to_file_using_specified_serializer(self):
expected_buffer = "I am not actually ONNX"

class CustomSerializer(ExportOutputSerializer):
def serialize(
self, export_output: ExportOutput, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())

with TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path, "r") as fp:
self.assertEquals(fp.read(), expected_buffer)

def test_save_to_file_using_specified_serializer_without_inheritance(self):
expected_buffer = "I am not actually ONNX"

# NOTE: Inheritance from `ExportOutputSerializer` is not required.
# Because `ExportOutputSerializer` is a Protocol class.
# `beartype` will not complain.
class CustomSerializer:
def serialize(
self, export_output: ExportOutput, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())

with TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path, "r") as fp:
self.assertEquals(fp.read(), expected_buffer)

def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(ExportOptions(), onnx.ModelProto())
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
export_output.model_proto


class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
def test_raise_on_invalid_argument_type(self):
with self.assertRaises(roar.BeartypeException):
serializer = ProtobufExportOutputSerializer()
serializer.serialize(None, None) # type: ignore[arg-type]


if __name__ == "__main__":
unittest.main()
12 changes: 12 additions & 0 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
"""ONNX exporter."""

from ._internal.exporter import (
ExportOptions,
ExportOutput,
ExportOutputSerializer,
dynamo_export,
)

from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (
Expand Down Expand Up @@ -81,6 +88,11 @@
"enable_log",
# Errors
"CheckerError", # Backwards compatibility
# Dynamo Exporter
"ExportOptions",
"ExportOutput",
"ExportOutputSerializer",
"dynamo_export",
]

# Set namespace for exposed private names
Expand Down

0 comments on commit 3a01a1a

Please sign in to comment.