-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce torch.onnx.dynamo_export API
This is the first phase of the new ONNX exporter API for exporting from TorchDynamo and FX, and represents the beginning of a new era for exporting ONNX from PyTorch. The API here is a starting point upon which we will layer more capability and expressiveness in subsequent phases. This first phase introduces the following into `torch.onnx`: dynamo_export( model: torch.nn.Module, /, *model_args, export_options: Optional[ExportOptions] = None, **model_kwargs, ) -> ExportOutput class ExportOptions: opset_version: Optional[int] = None dynamic_shapes: Optional[bool] = None logger: Optional[logging.Logger] = None class ExportOutputSerializer(Protocol): def serialize( self, export_output: ExportOutput, destination: io.BufferedIOBase ) -> None: ... class ExportOutput: export_options: ExportOptions model_proto: onnx.ModelProto def save( self, destination: Union[str, io.BufferedIOBase], *, serializer: Optional[ExportOutputSerializer] = None ) -> None: ... Co-authored-by: Bowen Bao <bowbao@microsoft.com> Co-authored-by: Aaron Bockover <abock@microsoft.com>
- Loading branch information
1 parent
1f85390
commit 388b918
Showing
5 changed files
with
508 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Owner(s): ["module: onnx"] | ||
import io | ||
import logging | ||
import unittest | ||
|
||
import onnx | ||
|
||
import torch | ||
from beartype import roar | ||
from torch.onnx import dynamo_export, ExportOptions, ExportOutput | ||
from torch.onnx._internal.exporter import ( | ||
_DEFAULT_OPSET_VERSION, | ||
_resolve_export_options, | ||
ExportOutputSerializer, | ||
ProtobufExportOutputSerializer, | ||
) | ||
|
||
from torch.testing._internal.common_utils import TemporaryFileName | ||
|
||
|
||
class SampleModel(torch.nn.Module): | ||
def forward(self, x): | ||
y = x + 1 | ||
z = y.relu() | ||
return (y, z) | ||
|
||
|
||
class TestExportOptionsAPI(unittest.TestCase): | ||
def test_opset_version_default(self): | ||
options = _resolve_export_options(None) | ||
self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION) | ||
|
||
def test_opset_version_explicit(self): | ||
options = _resolve_export_options(ExportOptions(opset_version=3000)) | ||
self.assertEquals(options.opset_version, 3000) | ||
|
||
def test_raise_on_invalid_argument_type(self): | ||
expected_exception_type = roar.BeartypeException | ||
with self.assertRaises(expected_exception_type): | ||
ExportOptions(opset_version="3000") # type: ignore[arg-type] | ||
with self.assertRaises(expected_exception_type): | ||
ExportOptions(dynamic_shapes=2) # type: ignore[arg-type] | ||
with self.assertRaises(expected_exception_type): | ||
ExportOptions(logger="DEBUG") # type: ignore[arg-type] | ||
with self.assertRaises(expected_exception_type): | ||
_resolve_export_options(options=12) # type: ignore[arg-type] | ||
|
||
def test_dynamic_shapes_default(self): | ||
options = _resolve_export_options(None) | ||
self.assertIsNone(options.dynamic_shapes) | ||
|
||
def test_dynamic_shapes_explicit(self): | ||
options = _resolve_export_options(ExportOptions(dynamic_shapes=None)) | ||
self.assertIsNone(options.dynamic_shapes) | ||
options = _resolve_export_options(ExportOptions(dynamic_shapes=True)) | ||
self.assertTrue(options.dynamic_shapes) | ||
options = _resolve_export_options(ExportOptions(dynamic_shapes=False)) | ||
self.assertFalse(options.dynamic_shapes) | ||
|
||
def test_logger_default(self): | ||
options = _resolve_export_options(None) | ||
self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx")) | ||
|
||
def test_logger_explicit(self): | ||
options = _resolve_export_options(ExportOptions(logger=logging.getLogger())) | ||
self.assertEquals(options.logger, logging.getLogger()) | ||
self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx")) | ||
|
||
|
||
class TestDynamoExportAPI(unittest.TestCase): | ||
def test_default_export(self): | ||
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2)) | ||
self.assertIsInstance(output, ExportOutput) | ||
self.assertIsInstance(output.model_proto, onnx.ModelProto) | ||
|
||
def test_export_with_options(self): | ||
self.assertIsInstance( | ||
dynamo_export( | ||
SampleModel(), | ||
torch.randn(1, 1, 2), | ||
export_options=ExportOptions( | ||
opset_version=17, | ||
logger=logging.getLogger(), | ||
dynamic_shapes=True, | ||
), | ||
), | ||
ExportOutput, | ||
) | ||
|
||
def test_save_to_file_default_serializer(self): | ||
with TemporaryFileName() as path: | ||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path) | ||
onnx.load(path) | ||
|
||
def test_save_to_existing_buffer_default_serializer(self): | ||
buffer = io.BytesIO() | ||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer) | ||
onnx.load(buffer) | ||
|
||
def test_save_to_file_using_specified_serializer(self): | ||
expected_buffer = "I am not actually ONNX" | ||
|
||
class CustomSerializer(ExportOutputSerializer): | ||
def serialize( | ||
self, export_output: ExportOutput, destination: io.BufferedIOBase | ||
) -> None: | ||
destination.write(expected_buffer.encode()) | ||
|
||
with TemporaryFileName() as path: | ||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( | ||
path, serializer=CustomSerializer() | ||
) | ||
with open(path, "r") as fp: | ||
self.assertEquals(fp.read(), expected_buffer) | ||
|
||
def test_save_to_file_using_specified_serializer_without_inheritance(self): | ||
expected_buffer = "I am not actually ONNX" | ||
|
||
# NOTE: Inheritance from `ExportOutputSerializer` is not required. | ||
# Because `ExportOutputSerializer` is a Protocol class. | ||
# `beartype` will not complain. | ||
class CustomSerializer: | ||
def serialize( | ||
self, export_output: ExportOutput, destination: io.BufferedIOBase | ||
) -> None: | ||
destination.write(expected_buffer.encode()) | ||
|
||
with TemporaryFileName() as path: | ||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( | ||
path, serializer=CustomSerializer() | ||
) | ||
with open(path, "r") as fp: | ||
self.assertEquals(fp.read(), expected_buffer) | ||
|
||
def test_raise_on_invalid_save_argument_type(self): | ||
with self.assertRaises(roar.BeartypeException): | ||
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type] | ||
export_output = ExportOutput(onnx.ModelProto()) | ||
with self.assertRaises(roar.BeartypeException): | ||
export_output.save(None) # type: ignore[arg-type] | ||
export_output.model_proto | ||
|
||
|
||
class TestProtobufExportOutputSerializerAPI(unittest.TestCase): | ||
def test_raise_on_invalid_argument_type(self): | ||
with self.assertRaises(roar.BeartypeException): | ||
serializer = ProtobufExportOutputSerializer() | ||
serializer.serialize(None, None) # type: ignore[arg-type] | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.