Introduce torch.onnx.dynamo_export API
This is the first phase of the new ONNX exporter API for exporting from
TorchDynamo and FX, and represents the beginning of a new era for exporting
ONNX from PyTorch.

The API here is a starting point upon which we will layer more capability
and expressiveness in subsequent phases. This first phase introduces the
following into `torch.onnx`:

dynamo_export(
    model: torch.nn.Module,
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput

class ExportOptions:
    opset_version: Optional[int] = None
    dynamic_shapes: Optional[bool] = None
    logger: Optional[logging.Logger] = None

class ExportOutputSerializer(Protocol):
    def serialize(
        self,
        export_output: ExportOutput,
        destination: io.BufferedIOBase
    ) -> None:
        ...

class ExportOutput:
    export_options: ExportOptions
    model_proto: onnx.ModelProto

    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        serializer: Optional[ExportOutputSerializer] = None
    ) -> None:
        ...
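
For illustration, here is a minimal usage sketch against the API above. It is
not part of the commit; `MyModel`, its example input, and `JsonSerializer` are
hypothetical placeholders.

    import torch

    model = MyModel()                   # hypothetical torch.nn.Module
    example_input = torch.randn(1, 10)  # hypothetical example input

    # Export via TorchDynamo/FX and save with the default protobuf serializer.
    export_output = torch.onnx.dynamo_export(
        model,
        example_input,
        export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
    )
    export_output.save("model.onnx")

    # Any object with a compatible `serialize` method satisfies the
    # `ExportOutputSerializer` Protocol; inheriting from it is optional.
    class JsonSerializer:
        def serialize(self, export_output, destination) -> None:
            destination.write(b"{}")  # hypothetical alternative wire format

    export_output.save("model.json", serializer=JsonSerializer())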

Co-authored-by: Bowen Bao <bowbao@microsoft.com>
Co-authored-by: Aaron Bockover <abock@microsoft.com>
3 people committed Mar 30, 2023
1 parent 1f85390 commit 882c198
Showing 5 changed files with 461 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -27,7 +27,7 @@ pip_install \
transformers==4.25.1

# TODO: change this when onnx-script is on testPypi
pip_install "onnx-script@git+https://github.com/microsoft/onnx-script@1e8d764a9be04323d7171e4d5f511332790cb809"
pip_install "onnx-script@git+https://github.com/microsoft/onnx-script@25b095b4cf22381d73d2e167ae1dd0873b2e5a1f"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
152 changes: 152 additions & 0 deletions test/onnx/dynamo/test_exporter_api.py
@@ -0,0 +1,152 @@
# Owner(s): ["module: onnx"]
import io
import logging
import unittest

import onnx

import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal.exporter import (
    _DEFAULT_OPSET_VERSION,
    _resolve_export_options,
    ExportOutputSerializer,
    ProtobufExportOutputSerializer,
)

from torch.testing._internal.common_utils import TemporaryFileName


class SampleModel(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        z = y.relu()
        return (y, z)


class TestExportOptionsAPI(unittest.TestCase):
    def test_opset_version_default(self):
        options = _resolve_export_options(None)
        self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)

    def test_opset_version_explicit(self):
        options = _resolve_export_options(ExportOptions(opset_version=3000))
        self.assertEquals(options.opset_version, 3000)

    def test_raise_on_invalid_argument_type(self):
        expected_exception_type = roar.BeartypeException
        with self.assertRaises(expected_exception_type):
            ExportOptions(opset_version="3000")  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            ExportOptions(dynamic_shapes=2)  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            ExportOptions(logger="DEBUG")  # type: ignore[arg-type]
        with self.assertRaises(expected_exception_type):
            _resolve_export_options(options=12)  # type: ignore[arg-type]

    def test_dynamic_shapes_default(self):
        options = _resolve_export_options(None)
        self.assertIsNone(options.dynamic_shapes)

    def test_dynamic_shapes_explicit(self):
        options = _resolve_export_options(ExportOptions(dynamic_shapes=None))
        self.assertIsNone(options.dynamic_shapes)
        options = _resolve_export_options(ExportOptions(dynamic_shapes=True))
        self.assertTrue(options.dynamic_shapes)
        options = _resolve_export_options(ExportOptions(dynamic_shapes=False))
        self.assertFalse(options.dynamic_shapes)

    def test_logger_default(self):
        options = _resolve_export_options(None)
        self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx"))

    def test_logger_explicit(self):
        options = _resolve_export_options(ExportOptions(logger=logging.getLogger()))
        self.assertEquals(options.logger, logging.getLogger())
        self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))


class TestDynamoExportAPI(unittest.TestCase):
    def test_default_export(self):
        output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
        self.assertIsInstance(output, ExportOutput)
        self.assertIsInstance(output.model_proto, onnx.ModelProto)

    def test_export_with_options(self):
        self.assertIsInstance(
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    opset_version=17,
                    logger=logging.getLogger(),
                    dynamic_shapes=True,
                ),
            ),
            ExportOutput,
        )

    def test_save_to_file_default_serializer(self):
        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
            onnx.load(path)

    def test_save_to_existing_buffer_default_serializer(self):
        buffer = io.BytesIO()
        dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
        onnx.load(buffer)

    def test_save_to_file_using_specified_serializer(self):
        expected_buffer = "I am not actually ONNX"

        class CustomSerializer(ExportOutputSerializer):
            def serialize(
                self, export_output: ExportOutput, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_buffer.encode())

        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            with open(path, "r") as fp:
                self.assertEquals(fp.read(), expected_buffer)

    def test_save_to_file_using_specified_serializer_without_inheritance(self):
        expected_buffer = "I am not actually ONNX"

        # NOTE: Inheriting from `ExportOutputSerializer` is not required, because
        # `ExportOutputSerializer` is a Protocol class; `beartype` will not complain.
        class CustomSerializer:
            def serialize(
                self, export_output: ExportOutput, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_buffer.encode())

        with TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            with open(path, "r") as fp:
                self.assertEquals(fp.read(), expected_buffer)

    def test_raise_on_invalid_save_argument_type(self):
        with self.assertRaises(roar.BeartypeException):
            ExportOutput(torch.nn.Linear(2, 3))  # type: ignore[arg-type]
        export_output = ExportOutput(onnx.ModelProto())
        with self.assertRaises(roar.BeartypeException):
            export_output.save(None)  # type: ignore[arg-type]
        export_output.model_proto


class TestProtobufExportOutputSerializerAPI(unittest.TestCase):
    def test_raise_on_invalid_argument_type(self):
        with self.assertRaises(roar.BeartypeException):
            serializer = ProtobufExportOutputSerializer()
            serializer.serialize(None, None)  # type: ignore[arg-type]


if __name__ == "__main__":
    unittest.main()
12 changes: 12 additions & 0 deletions torch/onnx/__init__.py
@@ -45,6 +45,13 @@
    unregister_custom_op_symbolic,
)

from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import
    ExportOptions,
    ExportOutput,
    ExportOutputSerializer,
    dynamo_export,
)

__all__ = [
    # Modules
    "symbolic_helper",
@@ -81,6 +88,11 @@
"enable_log",
# Errors
"CheckerError", # Backwards compatibility
# Dynamo Exporter
"ExportOptions",
"ExportOutput",
"ExportOutputSerializer",
"dynamo_export",
]

# Set namespace for exposed private names
