Skip to content

Commit

Permalink
Rename torch.onnx.ExportOutput* to ONNXProgram* (pytorch#112263)
Browse files Browse the repository at this point in the history
Since PyTorch 2.1, torch.export API was introduced and the term "export"
got overloaded due to the already existing torch.onnx.export API.

The torch.onnx.dynamo_export API was introduced on pyTorch 2.0 and it
exposed a torch.onnx.ExportOutput which now can be confused with
torch.export.export output

To prevent such ambiguity and standardize names around the new
torch.export.ExportedProgram, this PR renames torch.onnx.ExportOutput to
torch.onnx.ONNXProgram

Pull Request resolved: pytorch#112263
Approved by: https://github.com/BowenBao
ghstack dependencies: pytorch#112444
  • Loading branch information
Thiago Crepaldi authored and xuhancn committed Nov 8, 2023
1 parent d9dcb46 commit dc0dc91
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 193 deletions.
30 changes: 15 additions & 15 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,31 +1527,31 @@ class OnnxModelFromDynamo(OnnxModel):
def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool):
super().__init__(output_directory, model, example_inputs, dynamic_shapes)
self._dynamic_shapes = dynamic_shapes
self._export_output = self._export(model, example_inputs, self.model_path)
self._onnx_program = self._export(model, example_inputs, self.model_path)
# Clear the model proto to save memory.
# The model proto is saved to disk and no longer needed from `export_output`.
# `export_output` is kept for i/o adapter usage.
self._export_output.model_proto.Clear()
# The model proto is saved to disk and no longer needed from `onnx_program`.
# `onnx_program` is kept for i/o adapter usage.
self._onnx_program.model_proto.Clear()
self.onnx_session = self._init_ort_session(self.model_path)

def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ExportOutput:
) -> torch.onnx.ONNXProgram:
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
export_output = torch.onnx.dynamo_export(
onnx_program = torch.onnx.dynamo_export(
model, *example_args, **example_kwargs, export_options=options
)

export_output.save(output_path)
return export_output
onnx_program.save(output_path)
return onnx_program

def format_pt_inputs(self, pt_inputs):
pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs)
return self._export_output.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)
return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)

def format_pt_outputs(self, pt_outputs):
return self._export_output.adapt_torch_outputs_to_onnx(pt_outputs)
return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs)


class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
Expand All @@ -1561,10 +1561,10 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):

def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ExportOutput:
) -> torch.onnx.ONNXProgram:
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
export_output = torch.onnx.dynamo_export(
onnx_program = torch.onnx.dynamo_export(
model, *example_args, **example_kwargs, export_options=options
)
# Apply AOT inline post export.
Expand All @@ -1575,12 +1575,12 @@ def _export(
# Workaround for inliner not supporting with models larger than 2GB.
# Save model to disk first separating out external data,
# and load back without external data for inliner to work on.
model_proto = export_output.model_proto
model_proto = onnx_program.model_proto
onnx.save_model(model_proto, output_path, save_as_external_data=True)
model_proto = onnx.load(output_path, load_external_data=False)
model_proto = onnx.inliner.inline_local_functions(model_proto)
onnx.save_model(model_proto, output_path)
return export_output
return onnx_program


class _OnnxPatch:
Expand Down Expand Up @@ -1786,7 +1786,7 @@ def run_n_iterations_onnx(model, inputs, n=2):
return outputs
except exporter.OnnxExporterError as e:
# `torch.onnx.dynamo_export` raises error that encloses diagnostics.
diagnostic_context = e.export_output.diagnostic_context
diagnostic_context = e.onnx_program.diagnostic_context
for parsed_error in parser.parse_diagnostic_context(diagnostic_context):
output_csv(
output_error_filename, parsed_error.headers, parsed_error.row
Expand Down
18 changes: 9 additions & 9 deletions docs/source/onnx_dynamo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ The exporter is designed to be modular and extensible. It is composed of the fol
- **ONNX Registry**: :class:`OnnxRegistry` is the registry of ONNX operators and functions.
- **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model.
- **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models.
- **ONNX Export Output**: :class:`ExportOutput` is the output of the exporter that contains the exported ONNX graph and diagnostics.
- **ONNX Export Output Serializer**: :class:`ExportOutputSerializer` serializes the exported model to a file.
- **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics.
- **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file.
- **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter.

Dependencies
Expand Down Expand Up @@ -74,17 +74,17 @@ See below a demonstration of exporter API in action with a simple Multilayer Per
model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
export_output = torch.onnx.dynamo_export(model, tensor_x)
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
As the code above shows, all you need is to provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input.
The exporter will then return an instance of :class:`torch.onnx.ExportOutput` that contains the exported ONNX graph along with extra information.
The exporter will then return an instance of :class:`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information.

The in-memory model available through ``export_output.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
The ONNX model may then be serialized into a `Protobuf file <https://protobuf.dev/>`_ using the :meth:`torch.onnx.ExportOutput.save` API.
The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
The ONNX model may then be serialized into a `Protobuf file <https://protobuf.dev/>`_ using the :meth:`torch.onnx.ONNXProgram.save` API.

.. code-block:: python
export_output.save("mlp.onnx")
onnx_program.save("mlp.onnx")
Inspecting the ONNX model using GUI
-----------------------------------
Expand Down Expand Up @@ -140,10 +140,10 @@ API Reference

.. autofunction:: torch.onnx.enable_fake_mode

.. autoclass:: torch.onnx.ExportOutput
.. autoclass:: torch.onnx.ONNXProgram
:members:

.. autoclass:: torch.onnx.ExportOutputSerializer
.. autoclass:: torch.onnx.ONNXProgramSerializer
:members:

.. autoclass:: torch.onnx.InvalidExportOptionsError
Expand Down
58 changes: 29 additions & 29 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import onnx
import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.exporter import (
ExportOutputSerializer,
LargeProtobufExportOutputSerializer,
ProtobufExportOutputSerializer,
LargeProtobufONNXProgramSerializer,
ONNXProgramSerializer,
ProtobufONNXProgramSerializer,
ResolvedExportOptions,
)
from torch.onnx._internal.fx import diagnostics
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_dynamic_shapes_explicit(self):
class TestDynamoExportAPI(common_utils.TestCase):
def test_default_export(self):
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
self.assertIsInstance(output, ExportOutput)
self.assertIsInstance(output, ONNXProgram)
self.assertIsInstance(output.model_proto, onnx.ModelProto)

def test_export_with_options(self):
Expand All @@ -73,7 +73,7 @@ def test_export_with_options(self):
dynamic_shapes=True,
),
),
ExportOutput,
ONNXProgram,
)

def test_save_to_file_default_serializer(self):
Expand All @@ -89,9 +89,9 @@ def test_save_to_existing_buffer_default_serializer(self):
def test_save_to_file_using_specified_serializer(self):
expected_buffer = "I am not actually ONNX"

class CustomSerializer(ExportOutputSerializer):
class CustomSerializer(ONNXProgramSerializer):
def serialize(
self, export_output: ExportOutput, destination: io.BufferedIOBase
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())

Expand All @@ -105,12 +105,12 @@ def serialize(
def test_save_to_file_using_specified_serializer_without_inheritance(self):
expected_buffer = "I am not actually ONNX"

# NOTE: Inheritance from `ExportOutputSerializer` is not required.
# Because `ExportOutputSerializer` is a Protocol class.
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
# Because `ONNXProgramSerializer` is a Protocol class.
# `beartype` will not complain.
class CustomSerializer:
def serialize(
self, export_output: ExportOutput, destination: io.BufferedIOBase
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())

Expand Down Expand Up @@ -146,17 +146,17 @@ def forward(self, x):
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
self.assertTrue(os.path.exists(exporter._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH))

def test_export_output_accessible_from_exception_when_export_failed(self):
def test_onnx_program_accessible_from_exception_when_export_failed(self):
class ModelWithExportError(torch.nn.Module):
def forward(self, x):
raise RuntimeError("Export error")

with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError)
self.assertIsInstance(cm.exception.export_output, ExportOutput)
self.assertIsInstance(cm.exception.onnx_program, ONNXProgram)

def test_access_export_output_model_proto_raises_when_export_output_is_emitted_from_failed_export(
def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export(
self,
):
class ModelWithExportError(torch.nn.Module):
Expand All @@ -165,9 +165,9 @@ def forward(self, x):

with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
export_output = cm.exception.export_output
onnx_program = cm.exception.onnx_program
with self.assertRaises(RuntimeError):
export_output.model_proto
onnx_program.model_proto

def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
self,
Expand All @@ -185,39 +185,39 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i

def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(
ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
onnx_program = ONNXProgram(
onnx.ModelProto(),
io_adapter.InputAdapter(),
io_adapter.OutputAdapter(),
diagnostics.DiagnosticContext("test", "1.0"),
fake_context=None,
)
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
export_output.model_proto
onnx_program.save(None) # type: ignore[arg-type]
onnx_program.model_proto


class TestProtobufExportOutputSerializerAPI(common_utils.TestCase):
class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
with self.assertRaises(roar.BeartypeException):
serializer = ProtobufExportOutputSerializer()
serializer = ProtobufONNXProgramSerializer()
serializer.serialize(None, None) # type: ignore[arg-type]

def test_serialize_raises_when_model_greater_than_2gb(self):
export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
serializer = ProtobufExportOutputSerializer()
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
serializer = ProtobufONNXProgramSerializer()
with self.assertRaisesRegex(ValueError, "exceeds maximum protobuf size of 2GB"):
serializer.serialize(export_output, io.BytesIO())
serializer.serialize(onnx_program, io.BytesIO())


class TestLargeProtobufExportOutputSerializerAPI(common_utils.TestCase):
class TestLargeProtobufONNXProgramSerializerAPI(common_utils.TestCase):
def test_serialize_succeeds_when_model_greater_than_2gb(self):
export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
with common_utils.TemporaryFileName() as path:
serializer = LargeProtobufExportOutputSerializer(path)
serializer = LargeProtobufONNXProgramSerializer(path)
# `io.BytesIO()` is unused, but required by the Protocol interface.
serializer.serialize(export_output, io.BytesIO())
serializer.serialize(onnx_program, io.BytesIO())


if __name__ == "__main__":
Expand Down
32 changes: 16 additions & 16 deletions test/onnx/onnx_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
return verification.verify(*args, options=options, **kwargs)


def assert_dynamic_shapes(export_output: torch.onnx.ExportOutput, dynamic_shapes: bool):
def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool):
"""Assert whether the exported model has dynamic shapes or not.
Args:
export_output (torch.onnx.ExportOutput): The output of torch.onnx.dynamo_export.
onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export.
dynamic_shapes (bool): Whether the exported model has dynamic shapes or not.
When True, raises if graph inputs don't have at least one dynamic dimension
When False, raises if graph inputs have at least one dynamic dimension.
Expand All @@ -101,7 +101,7 @@ def assert_dynamic_shapes(export_output: torch.onnx.ExportOutput, dynamic_shapes
if dynamic_shapes is None:
return

model_proto = export_output.model_proto
model_proto = onnx_program.model_proto
# Process graph inputs
dynamic_inputs = []
for inp in model_proto.graph.input:
Expand Down Expand Up @@ -270,7 +270,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
# since ONNX doesn't represent kwargs.
export_error: Optional[torch.onnx.OnnxExporterError] = None
try:
export_output = torch.onnx.dynamo_export(
onnx_program = torch.onnx.dynamo_export(
ref_model,
*ref_input_args,
**ref_input_kwargs,
Expand All @@ -284,10 +284,10 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
)
except torch.onnx.OnnxExporterError as e:
export_error = e
export_output = e.export_output
onnx_program = e.onnx_program

if verbose and diagnostics.is_onnx_diagnostics_log_artifact_enabled():
export_output.save_diagnostics(
onnx_program.save_diagnostics(
f"test_report_{self._testMethodName}"
f"_op_level_debug_{self.op_level_debug}"
f"_dynamic_axes_{self.dynamic_shapes}"
Expand All @@ -298,10 +298,10 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
raise export_error

if not skip_dynamic_shapes_check:
assert_dynamic_shapes(export_output, self.dynamic_shapes)
assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

_compare_pytorch_onnx_with_ort(
export_output,
onnx_program,
model,
input_args,
input_kwargs,
Expand All @@ -324,7 +324,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
else {}
)
_compare_pytorch_onnx_with_ort(
export_output,
onnx_program,
model,
additional_input_args,
additional_input_kwargs,
Expand All @@ -336,15 +336,15 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(

@_beartype.beartype
def run_ort(
onnx_model: Union[str, torch.onnx.ExportOutput],
onnx_model: Union[str, torch.onnx.ONNXProgram],
pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
"""Run ORT on the given ONNX model and inputs
Used in test_fx_to_onnx_with_onnxruntime.py
Args:
onnx_model (Union[str, torch.onnx.ExportOutput]): Converter ONNX model
onnx_model (Union[str, torch.onnx.ONNXProgram]): Converter ONNX model
pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs
Raises:
Expand All @@ -353,7 +353,7 @@ def run_ort(
Returns:
_OutputsType: ONNX model predictions
"""
if isinstance(onnx_model, torch.onnx.ExportOutput):
if isinstance(onnx_model, torch.onnx.ONNXProgram):
buffer = io.BytesIO()
onnx_model.save(buffer)
ort_model = buffer.getvalue()
Expand Down Expand Up @@ -398,7 +398,7 @@ def _try_clone_inputs(input_args, input_kwargs):

@_beartype.beartype
def _compare_pytorch_onnx_with_ort(
export_output: torch.onnx.ExportOutput,
onnx_program: torch.onnx.ONNXProgram,
model: _ModelType,
input_args: Sequence[_InputArgsType],
input_kwargs: Mapping[str, _InputArgsType],
Expand All @@ -415,14 +415,14 @@ def _compare_pytorch_onnx_with_ort(
ref_input_kwargs = input_kwargs

# Format original model inputs into the format expected by exported ONNX model.
onnx_format_args = export_output.adapt_torch_inputs_to_onnx(
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
*input_args, **input_kwargs
)

ref_outputs = export_output.adapt_torch_outputs_to_onnx(
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
ref_model(*ref_input_args, **ref_input_kwargs)
)
ort_outputs = run_ort(export_output, onnx_format_args)
ort_outputs = run_ort(onnx_program, onnx_format_args)

if len(ref_outputs) != len(ort_outputs):
raise AssertionError(
Expand Down

0 comments on commit dc0dc91

Please sign in to comment.