From dc0dc91d57ef6d4a1adc7fab119ad02ffa5684fc Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 1 Nov 2023 23:54:07 +0000 Subject: [PATCH] Rename torch.onnx.ExportOutput* to ONNXProgram* (#112263) Since PyTorch 2.1, torch.export API was introduced and the term "export" got overloaded due to the already existing torch.onnx.export API. The torch.onnx.dynamo_export API was introduced on pyTorch 2.0 and it exposed a torch.onnx.ExportOutput which now can be confused with torch.export.export output To prevent such ambiguity and standardize names around the new torch.export.ExportedProgram, this PR renames torch.onnx.ExportOutput to torch.onnx.ONNXProgram Pull Request resolved: https://github.com/pytorch/pytorch/pull/112263 Approved by: https://github.com/BowenBao ghstack dependencies: #112444 --- benchmarks/dynamo/common.py | 30 +++--- docs/source/onnx_dynamo.rst | 18 ++-- test/onnx/dynamo/test_exporter_api.py | 58 +++++----- test/onnx/onnx_test_common.py | 32 +++--- test/onnx/test_fx_passes.py | 12 +-- test/onnx/test_fx_to_onnx.py | 76 +++++++------ test/onnx/test_fx_to_onnx_with_onnxruntime.py | 42 ++++---- torch/onnx/__init__.py | 12 +-- torch/onnx/_internal/exporter.py | 102 +++++++++--------- 9 files changed, 189 insertions(+), 193 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 7e4ab0ae9449..4919a9dc1efa 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1527,31 +1527,31 @@ class OnnxModelFromDynamo(OnnxModel): def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool): super().__init__(output_directory, model, example_inputs, dynamic_shapes) self._dynamic_shapes = dynamic_shapes - self._export_output = self._export(model, example_inputs, self.model_path) + self._onnx_program = self._export(model, example_inputs, self.model_path) # Clear the model proto to save memory. - # The model proto is saved to disk and no longer needed from `export_output`. - # `export_output` is kept for i/o adapter usage. - self._export_output.model_proto.Clear() + # The model proto is saved to disk and no longer needed from `onnx_program`. + # `onnx_program` is kept for i/o adapter usage. + self._onnx_program.model_proto.Clear() self.onnx_session = self._init_ort_session(self.model_path) def _export( self, model, example_inputs, output_path: str - ) -> torch.onnx.ExportOutput: + ) -> torch.onnx.ONNXProgram: example_args, example_kwargs = _normalize_bench_inputs(example_inputs) options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, *example_args, **example_kwargs, export_options=options ) - export_output.save(output_path) - return export_output + onnx_program.save(output_path) + return onnx_program def format_pt_inputs(self, pt_inputs): pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs) - return self._export_output.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs) + return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs) def format_pt_outputs(self, pt_outputs): - return self._export_output.adapt_torch_outputs_to_onnx(pt_outputs) + return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs) class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo): @@ -1561,10 +1561,10 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo): def _export( self, model, example_inputs, output_path: str - ) -> torch.onnx.ExportOutput: + ) -> torch.onnx.ONNXProgram: example_args, example_kwargs = _normalize_bench_inputs(example_inputs) options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, *example_args, **example_kwargs, export_options=options ) # Apply AOT inline post export. @@ -1575,12 +1575,12 @@ def _export( # Workaround for inliner not supporting with models larger than 2GB. # Save model to disk first separating out external data, # and load back without external data for inliner to work on. - model_proto = export_output.model_proto + model_proto = onnx_program.model_proto onnx.save_model(model_proto, output_path, save_as_external_data=True) model_proto = onnx.load(output_path, load_external_data=False) model_proto = onnx.inliner.inline_local_functions(model_proto) onnx.save_model(model_proto, output_path) - return export_output + return onnx_program class _OnnxPatch: @@ -1786,7 +1786,7 @@ def run_n_iterations_onnx(model, inputs, n=2): return outputs except exporter.OnnxExporterError as e: # `torch.onnx.dynamo_export` raises error that encloses diagnostics. - diagnostic_context = e.export_output.diagnostic_context + diagnostic_context = e.onnx_program.diagnostic_context for parsed_error in parser.parse_diagnostic_context(diagnostic_context): output_csv( output_error_filename, parsed_error.headers, parsed_error.row diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst index afb06de674be..a156c51310c3 100644 --- a/docs/source/onnx_dynamo.rst +++ b/docs/source/onnx_dynamo.rst @@ -27,8 +27,8 @@ The exporter is designed to be modular and extensible. It is composed of the fol - **ONNX Registry**: :class:`OnnxRegistry` is the registry of ONNX operators and functions. - **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model. - **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models. - - **ONNX Export Output**: :class:`ExportOutput` is the output of the exporter that contains the exported ONNX graph and diagnostics. - - **ONNX Export Output Serializer**: :class:`ExportOutputSerializer` serializes the exported model to a file. + - **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics. + - **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file. - **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter. Dependencies @@ -74,17 +74,17 @@ See below a demonstration of exporter API in action with a simple Multilayer Per model = MLPModel() tensor_x = torch.rand((97, 8), dtype=torch.float32) - export_output = torch.onnx.dynamo_export(model, tensor_x) + onnx_program = torch.onnx.dynamo_export(model, tensor_x) As the code above shows, all you need is to provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input. -The exporter will then return an instance of :class:`torch.onnx.ExportOutput` that contains the exported ONNX graph along with extra information. +The exporter will then return an instance of :class:`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. -The in-memory model available through ``export_output.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec `_. -The ONNX model may then be serialized into a `Protobuf file `_ using the :meth:`torch.onnx.ExportOutput.save` API. +The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec `_. +The ONNX model may then be serialized into a `Protobuf file `_ using the :meth:`torch.onnx.ONNXProgram.save` API. .. code-block:: python - export_output.save("mlp.onnx") + onnx_program.save("mlp.onnx") Inspecting the ONNX model using GUI ----------------------------------- @@ -140,10 +140,10 @@ API Reference .. autofunction:: torch.onnx.enable_fake_mode -.. autoclass:: torch.onnx.ExportOutput +.. autoclass:: torch.onnx.ONNXProgram :members: -.. autoclass:: torch.onnx.ExportOutputSerializer +.. autoclass:: torch.onnx.ONNXProgramSerializer :members: .. autoclass:: torch.onnx.InvalidExportOptionsError diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index 3f10bb9d6076..8d32101e0c1b 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -5,12 +5,12 @@ import onnx import torch from beartype import roar -from torch.onnx import dynamo_export, ExportOptions, ExportOutput +from torch.onnx import dynamo_export, ExportOptions, ONNXProgram from torch.onnx._internal import exporter, io_adapter from torch.onnx._internal.exporter import ( - ExportOutputSerializer, - LargeProtobufExportOutputSerializer, - ProtobufExportOutputSerializer, + LargeProtobufONNXProgramSerializer, + ONNXProgramSerializer, + ProtobufONNXProgramSerializer, ResolvedExportOptions, ) from torch.onnx._internal.fx import diagnostics @@ -61,7 +61,7 @@ def test_dynamic_shapes_explicit(self): class TestDynamoExportAPI(common_utils.TestCase): def test_default_export(self): output = dynamo_export(SampleModel(), torch.randn(1, 1, 2)) - self.assertIsInstance(output, ExportOutput) + self.assertIsInstance(output, ONNXProgram) self.assertIsInstance(output.model_proto, onnx.ModelProto) def test_export_with_options(self): @@ -73,7 +73,7 @@ def test_export_with_options(self): dynamic_shapes=True, ), ), - ExportOutput, + ONNXProgram, ) def test_save_to_file_default_serializer(self): @@ -89,9 +89,9 @@ def test_save_to_existing_buffer_default_serializer(self): def test_save_to_file_using_specified_serializer(self): expected_buffer = "I am not actually ONNX" - class CustomSerializer(ExportOutputSerializer): + class CustomSerializer(ONNXProgramSerializer): def serialize( - self, export_output: ExportOutput, destination: io.BufferedIOBase + self, onnx_program: ONNXProgram, destination: io.BufferedIOBase ) -> None: destination.write(expected_buffer.encode()) @@ -105,12 +105,12 @@ def serialize( def test_save_to_file_using_specified_serializer_without_inheritance(self): expected_buffer = "I am not actually ONNX" - # NOTE: Inheritance from `ExportOutputSerializer` is not required. - # Because `ExportOutputSerializer` is a Protocol class. + # NOTE: Inheritance from `ONNXProgramSerializer` is not required. + # Because `ONNXProgramSerializer` is a Protocol class. # `beartype` will not complain. class CustomSerializer: def serialize( - self, export_output: ExportOutput, destination: io.BufferedIOBase + self, onnx_program: ONNXProgram, destination: io.BufferedIOBase ) -> None: destination.write(expected_buffer.encode()) @@ -146,7 +146,7 @@ def forward(self, x): dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) self.assertTrue(os.path.exists(exporter._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH)) - def test_export_output_accessible_from_exception_when_export_failed(self): + def test_onnx_program_accessible_from_exception_when_export_failed(self): class ModelWithExportError(torch.nn.Module): def forward(self, x): raise RuntimeError("Export error") @@ -154,9 +154,9 @@ def forward(self, x): with self.assertRaises(torch.onnx.OnnxExporterError) as cm: dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError) - self.assertIsInstance(cm.exception.export_output, ExportOutput) + self.assertIsInstance(cm.exception.onnx_program, ONNXProgram) - def test_access_export_output_model_proto_raises_when_export_output_is_emitted_from_failed_export( + def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export( self, ): class ModelWithExportError(torch.nn.Module): @@ -165,9 +165,9 @@ def forward(self, x): with self.assertRaises(torch.onnx.OnnxExporterError) as cm: dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - export_output = cm.exception.export_output + onnx_program = cm.exception.onnx_program with self.assertRaises(RuntimeError): - export_output.model_proto + onnx_program.model_proto def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true( self, @@ -185,8 +185,8 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i def test_raise_on_invalid_save_argument_type(self): with self.assertRaises(roar.BeartypeException): - ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type] - export_output = ExportOutput( + ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type] + onnx_program = ONNXProgram( onnx.ModelProto(), io_adapter.InputAdapter(), io_adapter.OutputAdapter(), @@ -194,30 +194,30 @@ def test_raise_on_invalid_save_argument_type(self): fake_context=None, ) with self.assertRaises(roar.BeartypeException): - export_output.save(None) # type: ignore[arg-type] - export_output.model_proto + onnx_program.save(None) # type: ignore[arg-type] + onnx_program.model_proto -class TestProtobufExportOutputSerializerAPI(common_utils.TestCase): +class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase): def test_raise_on_invalid_argument_type(self): with self.assertRaises(roar.BeartypeException): - serializer = ProtobufExportOutputSerializer() + serializer = ProtobufONNXProgramSerializer() serializer.serialize(None, None) # type: ignore[arg-type] def test_serialize_raises_when_model_greater_than_2gb(self): - export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1)) - serializer = ProtobufExportOutputSerializer() + onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1)) + serializer = ProtobufONNXProgramSerializer() with self.assertRaisesRegex(ValueError, "exceeds maximum protobuf size of 2GB"): - serializer.serialize(export_output, io.BytesIO()) + serializer.serialize(onnx_program, io.BytesIO()) -class TestLargeProtobufExportOutputSerializerAPI(common_utils.TestCase): +class TestLargeProtobufONNXProgramSerializerAPI(common_utils.TestCase): def test_serialize_succeeds_when_model_greater_than_2gb(self): - export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1)) + onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1)) with common_utils.TemporaryFileName() as path: - serializer = LargeProtobufExportOutputSerializer(path) + serializer = LargeProtobufONNXProgramSerializer(path) # `io.BytesIO()` is unused, but required by the Protocol interface. - serializer.serialize(export_output, io.BytesIO()) + serializer.serialize(onnx_program, io.BytesIO()) if __name__ == "__main__": diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 6409edf7b225..e6af3a56870f 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -85,11 +85,11 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): return verification.verify(*args, options=options, **kwargs) -def assert_dynamic_shapes(export_output: torch.onnx.ExportOutput, dynamic_shapes: bool): +def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool): """Assert whether the exported model has dynamic shapes or not. Args: - export_output (torch.onnx.ExportOutput): The output of torch.onnx.dynamo_export. + onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export. dynamic_shapes (bool): Whether the exported model has dynamic shapes or not. When True, raises if graph inputs don't have at least one dynamic dimension When False, raises if graph inputs have at least one dynamic dimension. @@ -101,7 +101,7 @@ def assert_dynamic_shapes(export_output: torch.onnx.ExportOutput, dynamic_shapes if dynamic_shapes is None: return - model_proto = export_output.model_proto + model_proto = onnx_program.model_proto # Process graph inputs dynamic_inputs = [] for inp in model_proto.graph.input: @@ -270,7 +270,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( # since ONNX doesn't represent kwargs. export_error: Optional[torch.onnx.OnnxExporterError] = None try: - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( ref_model, *ref_input_args, **ref_input_kwargs, @@ -284,10 +284,10 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( ) except torch.onnx.OnnxExporterError as e: export_error = e - export_output = e.export_output + onnx_program = e.onnx_program if verbose and diagnostics.is_onnx_diagnostics_log_artifact_enabled(): - export_output.save_diagnostics( + onnx_program.save_diagnostics( f"test_report_{self._testMethodName}" f"_op_level_debug_{self.op_level_debug}" f"_dynamic_axes_{self.dynamic_shapes}" @@ -298,10 +298,10 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( raise export_error if not skip_dynamic_shapes_check: - assert_dynamic_shapes(export_output, self.dynamic_shapes) + assert_dynamic_shapes(onnx_program, self.dynamic_shapes) _compare_pytorch_onnx_with_ort( - export_output, + onnx_program, model, input_args, input_kwargs, @@ -324,7 +324,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( else {} ) _compare_pytorch_onnx_with_ort( - export_output, + onnx_program, model, additional_input_args, additional_input_kwargs, @@ -336,7 +336,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( @_beartype.beartype def run_ort( - onnx_model: Union[str, torch.onnx.ExportOutput], + onnx_model: Union[str, torch.onnx.ONNXProgram], pytorch_inputs: Sequence[_InputArgsType], ) -> _OutputsType: """Run ORT on the given ONNX model and inputs @@ -344,7 +344,7 @@ def run_ort( Used in test_fx_to_onnx_with_onnxruntime.py Args: - onnx_model (Union[str, torch.onnx.ExportOutput]): Converter ONNX model + onnx_model (Union[str, torch.onnx.ONNXProgram]): Converter ONNX model pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs Raises: @@ -353,7 +353,7 @@ def run_ort( Returns: _OutputsType: ONNX model predictions """ - if isinstance(onnx_model, torch.onnx.ExportOutput): + if isinstance(onnx_model, torch.onnx.ONNXProgram): buffer = io.BytesIO() onnx_model.save(buffer) ort_model = buffer.getvalue() @@ -398,7 +398,7 @@ def _try_clone_inputs(input_args, input_kwargs): @_beartype.beartype def _compare_pytorch_onnx_with_ort( - export_output: torch.onnx.ExportOutput, + onnx_program: torch.onnx.ONNXProgram, model: _ModelType, input_args: Sequence[_InputArgsType], input_kwargs: Mapping[str, _InputArgsType], @@ -415,14 +415,14 @@ def _compare_pytorch_onnx_with_ort( ref_input_kwargs = input_kwargs # Format original model inputs into the format expected by exported ONNX model. - onnx_format_args = export_output.adapt_torch_inputs_to_onnx( + onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx( *input_args, **input_kwargs ) - ref_outputs = export_output.adapt_torch_outputs_to_onnx( + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( ref_model(*ref_input_args, **ref_input_kwargs) ) - ort_outputs = run_ort(export_output, onnx_format_args) + ort_outputs = run_ort(onnx_program, onnx_format_args) if len(ref_outputs) != len(ort_outputs): raise AssertionError( diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py index 76823fd6f306..4b175c3fcced 100644 --- a/test/onnx/test_fx_passes.py +++ b/test/onnx/test_fx_passes.py @@ -123,10 +123,10 @@ def forward(self, x, y): unused_relu_result = self.unused_relu(x) return result - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( TestModule(), torch.randn(3), torch.randn(3) ) - model_proto = export_output.model_proto + model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn( "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names @@ -146,10 +146,10 @@ def forward(self, x, y): out = self.relu(out) return out - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( TestModule(), torch.randn(3), torch.randn(3) ) - model_proto = export_output.model_proto + model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names) self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names) @@ -178,10 +178,10 @@ def forward(self, x, y): out = self.inner_module.relu(out) return out - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( TestModule(), torch.randn(3), torch.randn(3) ) - model_proto = export_output.model_proto + model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn( "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 8effa50e0bfa..5f93ca0785a1 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -120,12 +120,12 @@ def forward(self, tensor_x: torch.Tensor): return output tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32) - export_output = dynamo_export( + onnx_program = dynamo_export( MNISTModel(), tensor_x, export_options=ExportOptions(op_level_debug=True) ) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostic_rule, diagnostics.levels.NONE, expected_node="aten.convolution.default", @@ -156,9 +156,7 @@ def forward(self, x): return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) - export_output = dynamo_export( - TopKModel(), x, export_options=self.export_options - ) + onnx_program = dynamo_export(TopKModel(), x, export_options=self.export_options) def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self): class EmbedModelWithoutPaddingIdx(torch.nn.Module): @@ -169,14 +167,14 @@ def forward(self, input, emb): x = torch.randint(4, (4, 3, 2)) embedding_matrix = torch.rand(10, 3) - export_output = dynamo_export( + onnx_program = dynamo_export( model, x, embedding_matrix, export_options=ExportOptions(op_level_debug=True), ) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostics.rules.op_level_debugging, diagnostics.levels.WARNING, expected_node="aten.embedding.default", @@ -190,10 +188,10 @@ def forward(self, input): return input.new_zeros(()) x = torch.randn((2, 3), dtype=torch.float32) - export_output = dynamo_export(TraceModel(), x) + onnx_program = dynamo_export(TraceModel(), x) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostics.rules.find_opschema_matched_symbolic_function, diagnostics.levels.WARNING, expected_node="aten.new_zeros.default", @@ -213,11 +211,11 @@ def forward(self, input): return self.conv2(input) x = torch.randn(20, 16, 50, 50) - export_output = dynamo_export( + onnx_program = dynamo_export( TraceModel(), x, export_options=ExportOptions(op_level_debug=False) ) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostics.rules.find_opschema_matched_symbolic_function, diagnostics.levels.NONE, expected_node="aten.convolution.default", @@ -242,11 +240,11 @@ def forward(self, input): onnx_registry._registry.pop(aten_add_Tensor) x = torch.tensor(3) - export_output = dynamo_export( + onnx_program = dynamo_export( TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry) ) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostics.rules.find_operator_overloads_in_onnx_registry, diagnostics.levels.WARNING, expected_node="aten.add.Tensor", @@ -258,9 +256,9 @@ def forward(self, input): return torch.ops.aten.clone(input, memory_format=torch.preserve_format) x = torch.tensor(3) - export_output = dynamo_export(CustomModule(), x) + onnx_program = dynamo_export(CustomModule(), x) assert_has_diagnostics( - export_output.diagnostic_context, + onnx_program.diagnostic_context, diagnostics.rules.find_opschema_matched_symbolic_function, diagnostics.levels.NONE, expected_node="aten.clone.default", @@ -302,8 +300,8 @@ def forward(self, tensor_x: torch.Tensor): tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32) model = MNISTModel() - export_output = torch.onnx.dynamo_export(model, tensor_x) - model_proto = export_output.model_proto + onnx_program = torch.onnx.dynamo_export(model, tensor_x) + model_proto = onnx_program.model_proto self.assertEqual( {initializer.name for initializer in model_proto.graph.initializer}, {*model.state_dict().keys()}, @@ -323,25 +321,25 @@ def forward(self, x): x = torch.rand(5, 2, 2) model = Model() export_options = ExportOptions(fake_context=fake_context) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, x, export_options=export_options ) assert ( - export_output is not None - ), "ExportOutput must be created on successful export" + onnx_program is not None + ), "ONNXProgram must be created on successful export" assert ( - export_output.model_proto is not None + onnx_program.model_proto is not None ), "A model protobuf must be created on a successful export" - onnx.checker.check_model(export_output.model_proto, full_check=True) + onnx.checker.check_model(onnx_program.model_proto, full_check=True) assert ( - len(export_output.model_proto.graph.initializer) == 0 + len(onnx_program.model_proto.graph.initializer) == 0 ), "Initializers cannot exist when fake mode is enabled" # Variant 1: Save ONNX proto using Model's state_dict() with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file: model_state_dict = Model().state_dict() # Create a state_dict for testing - export_output.save(tmp_onnx_file.name, model_state_dict=model_state_dict) + onnx_program.save(tmp_onnx_file.name, model_state_dict=model_state_dict) assert ( len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2 ), "Initializers must be present after loading it from model_state_dict" @@ -355,7 +353,7 @@ def forward(self, x): torch.save( Model().state_dict(), tmp_checkpoint_file.name ) # Create checkpoint file for testing - export_output.save( + onnx_program.save( tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name ) assert ( @@ -424,15 +422,15 @@ def test_fake_tensor_mode_huggingface_gpt2(self): position_ids = position_ids.unsqueeze(0).view(-1, seq) export_options = torch.onnx.ExportOptions(fake_context=fake_context) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, export_options=export_options, ) - onnx.checker.check_model(export_output.model_proto) - onnx.shape_inference.infer_shapes(export_output.model_proto) + onnx.checker.check_model(onnx_program.model_proto) + onnx.shape_inference.infer_shapes(onnx_program.model_proto) def test_fake_tensor_mode_huggingface_open_llama(self): config = transformers.OpenLlamaConfig( @@ -448,15 +446,15 @@ def test_fake_tensor_mode_huggingface_open_llama(self): position_ids = position_ids.unsqueeze(0).view(-1, seq) export_options = torch.onnx.ExportOptions(fake_context=fake_context) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, export_options=export_options, ) - onnx.checker.check_model(export_output.model_proto) - onnx.shape_inference.infer_shapes(export_output.model_proto) + onnx.checker.check_model(onnx_program.model_proto) + onnx.shape_inference.infer_shapes(onnx_program.model_proto) @pytorch_test_common.xfail( "This is addressed in main branch of transformers." @@ -476,15 +474,15 @@ def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self): position_ids = position_ids.unsqueeze(0).view(-1, seq) export_options = torch.onnx.ExportOptions(fake_context=fake_context) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, export_options=export_options, ) - onnx.checker.check_model(export_output.model_proto) - onnx.shape_inference.infer_shapes(export_output.model_proto) + onnx.checker.check_model(onnx_program.model_proto) + onnx.shape_inference.infer_shapes(onnx_program.model_proto) @pytorch_test_common.xfail( "Not decorated with xfail because CI doesn't have enough memory to run and then fail." @@ -501,14 +499,14 @@ def test_fake_tensor_mode_huggingface_tiiuae_falcon(self): attention_mask = torch.ones(batch, seq, dtype=torch.bool) export_options = torch.onnx.ExportOptions(fake_context=fake_context) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, input_ids=input_ids, attention_mask=attention_mask, export_options=export_options, ) - onnx.checker.check_model(export_output.model_proto) - onnx.shape_inference.infer_shapes(export_output.model_proto) + onnx.checker.check_model(onnx_program.model_proto) + onnx.shape_inference.infer_shapes(onnx_program.model_proto) def test_exported_program_input_with_custom_fx_tracer(self): from torch.onnx._internal import exporter @@ -529,12 +527,12 @@ def forward(self, x): dynamo_graph_extractor.DynamoExport() ) # Override fx_tracer to an unsupported tracer with self.assertRaises(torch.onnx.OnnxExporterError): - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( exported_program, x, export_options=export_options, ) - self.assertTrue(export_output._export_exception is not None) + self.assertTrue(onnx_program._export_exception is not None) with self.assertRaises(torch.onnx.InvalidExportOptionsError): raise self._export_exception diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index f0ee29ba34a1..1618833d8859 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -141,7 +141,7 @@ def func(x, b=1.0): tensor_x = torch.randn(1, 1, 2, dtype=torch.float32) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( func, tensor_x, 8.0, @@ -150,17 +150,17 @@ def func(x, b=1.0): dynamic_shapes=self.dynamic_shapes, ), ) - onnx_test_common.assert_dynamic_shapes(export_output, self.dynamic_shapes) - onnx_format_args = export_output.adapt_torch_inputs_to_onnx(tensor_x, 8.0) - ref_outputs = export_output.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0)) - ort_outputs = onnx_test_common.run_ort(export_output, onnx_format_args) + onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) + onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 8.0) + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0)) + ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args) for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close(ref_output, torch.tensor(ort_output)) # test on different non-tensor input - xfail - onnx_format_args = export_output.adapt_torch_inputs_to_onnx(tensor_x, 9.0) - ref_outputs = export_output.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0)) - _ = onnx_test_common.run_ort(export_output, onnx_format_args) + onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 9.0) + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0)) + _ = onnx_test_common.run_ort(onnx_program, onnx_format_args) for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close(ref_output, torch.tensor(ort_output)) @@ -688,14 +688,14 @@ def _test_fx_symbolic_tracer_large_scale_exporter( export_options.fx_tracer = ( fx_symbolic_graph_extractor.FXSymbolicTracer() ) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, export_options=export_options, ) - onnx_model = export_output.model_proto + onnx_model = onnx_program.model_proto - onnx_test_common.assert_dynamic_shapes(export_output, self.dynamic_shapes) + onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) # Tasks done by the following block. # 1. Iterate through all tensors stored in ctx.paths (the file content is loaded torch.load) @@ -709,7 +709,7 @@ def _test_fx_symbolic_tracer_large_scale_exporter( onnx_model_location = model_name + "_external_data.onnx" onnx_initializer_location = model_name + "_initializers" # TODO: We are using the internal `save_model_with_external_data` instead of public - # `ExportOutput.save` because we need to rename ONNX initializers before saving. + # `ONNXProgram.save` because we need to rename ONNX initializers before saving. # This is only needed/allowed because we are using `fx_tracer=FXSymbolicTracer`, # which is not an official FX tracer. fx_serialization.save_model_with_external_data( @@ -724,11 +724,11 @@ def _test_fx_symbolic_tracer_large_scale_exporter( args = create_args() kwargs = create_pytorch_only_kwargs() # Original outputs. - ref_outputs = export_output.adapt_torch_outputs_to_onnx( + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( model(*args, **kwargs) ) # ORT outputs. - args_not_none = export_output.adapt_torch_inputs_to_onnx(*args) + args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args) # Drop Parameters and buffers added by fx_serialization.save_model_with_external_data args_not_none = args_not_none[: len(args) - len(kwargs)] @@ -921,7 +921,7 @@ def _test_fake_tensor_mode_exporter( ) if export_within_fake_mode: - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, **fake_kwargs, @@ -929,14 +929,14 @@ def _test_fake_tensor_mode_exporter( ) if not export_within_fake_mode: - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, **fake_kwargs, export_options=export_options ) - onnx_test_common.assert_dynamic_shapes(export_output, self.dynamic_shapes) + onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file: - export_output.save( + onnx_program.save( tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name ) @@ -944,13 +944,11 @@ def _test_fake_tensor_mode_exporter( args = create_args() kwargs = create_kwargs() # Original outputs. - ref_outputs = export_output.adapt_torch_outputs_to_onnx( + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( real_model(*args, **kwargs) ) # ORT outputs. - args_not_none = export_output.adapt_torch_inputs_to_onnx( - *args, **kwargs - ) + args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, **kwargs) ort_outputs = onnx_test_common.run_ort( tmp_onnx_file.name, diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 86f3887c02f0..e50dfb33004c 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -46,8 +46,8 @@ from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import DiagnosticOptions, ExportOptions, - ExportOutput, - ExportOutputSerializer, + ONNXProgram, + ONNXProgramSerializer, InvalidExportOptionsError, OnnxExporterError, OnnxRegistry, @@ -101,8 +101,8 @@ # Dynamo Exporter "DiagnosticOptions", "ExportOptions", - "ExportOutput", - "ExportOutputSerializer", + "ONNXProgram", + "ONNXProgramSerializer", "InvalidExportOptionsError", "OnnxExporterError", "OnnxRegistry", @@ -116,8 +116,8 @@ ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" ExportOptions.__module__ = "torch.onnx" -ExportOutput.__module__ = "torch.onnx" -ExportOutputSerializer.__module__ = "torch.onnx" +ONNXProgram.__module__ = "torch.onnx" +ONNXProgramSerializer.__module__ = "torch.onnx" dynamo_export.__module__ = "torch.onnx" InvalidExportOptionsError.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index b3ab705c0873..86742589dfe9 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -1,4 +1,4 @@ -# necessary to surface onnx.ModelProto through ExportOutput: +# necessary to surface onnx.ModelProto through ONNXProgram: from __future__ import annotations import abc @@ -377,7 +377,7 @@ def __init__( message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer" e = InvalidExportOptionsError(message) raise InvalidExportOptionsError( - ExportOutput._from_failure(e, options.diagnostic_context), message + ONNXProgram._from_failure(e, options.diagnostic_context), message ) self.fx_tracer = options.fx_tracer self.onnx_registry = options.onnx_registry @@ -471,15 +471,15 @@ def enable_fake_mode(): ... my_nn_module = MyModel() ... arg1 = torch.randn(2, 2, 2) # positional input 1 >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context) - >>> export_output = torch.onnx.dynamo_export( + >>> onnx_program = torch.onnx.dynamo_export( ... my_nn_module, ... arg1, ... export_options=export_options ... ) >>> # Saving model WITHOUT initializers - >>> export_output.save("my_model_without_initializers.onnx") + >>> onnx_program.save("my_model_without_initializers.onnx") >>> # Saving model WITH initializers - >>> export_output.save("my_model_with_initializers.onnx", model_state_dict=MyModel().state_dict()) + >>> onnx_program.save("my_model_with_initializers.onnx", model_state_dict=MyModel().state_dict()) .. warning:: This API is experimental and is *NOT* backward-compatible. @@ -510,17 +510,17 @@ def enable_fake_mode(): @runtime_checkable -class ExportOutputSerializer(Protocol): +class ONNXProgramSerializer(Protocol): """Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf). Note that this is an advanced usage scenario.""" def serialize( - self, export_output: ExportOutput, destination: io.BufferedIOBase + self, onnx_program: ONNXProgram, destination: io.BufferedIOBase ) -> None: """Protocol method that must be implemented for serialization. Args: - export_output: Represents the in-memory exported ONNX model + onnx_program: Represents the in-memory exported ONNX model destination: A binary IO stream or pre-allocated buffer into which the serialized model should be written. @@ -542,36 +542,36 @@ def serialize( ... def forward(self, x): ... out = self.linear(x) ... return out - >>> class ProtobufExportOutputSerializer: + >>> class ProtobufONNXProgramSerializer: ... def serialize( - ... self, export_output: torch.onnx.ExportOutput, destination: io.BufferedIOBase + ... self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase ... ) -> None: - ... destination.write(export_output.model_proto.SerializeToString()) + ... destination.write(onnx_program.model_proto.SerializeToString()) >>> model = MyModel() >>> arg1 = torch.randn(2, 2, 2) # positional input 1 >>> torch.onnx.dynamo_export(model, arg1).save( ... destination="exported_model.onnx", - ... serializer=ProtobufExportOutputSerializer(), + ... serializer=ProtobufONNXProgramSerializer(), ... ) """ ... -class ProtobufExportOutputSerializer: +class ProtobufONNXProgramSerializer: """Serializes ONNX graph as Protobuf.""" @_beartype.beartype def serialize( - self, export_output: ExportOutput, destination: io.BufferedIOBase + self, onnx_program: ONNXProgram, destination: io.BufferedIOBase ) -> None: import onnx - if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined] - raise ValueError("export_output.ModelProto is not an onnx.ModelProto") - destination.write(export_output.model_proto.SerializeToString()) + if not isinstance(onnx_program.model_proto, onnx.ModelProto): # type: ignore[attr-defined] + raise ValueError("onnx_program.ModelProto is not an onnx.ModelProto") + destination.write(onnx_program.model_proto.SerializeToString()) -class LargeProtobufExportOutputSerializer: +class LargeProtobufONNXProgramSerializer: """Serializes ONNX graph as Protobuf. Fallback to serializing as Protobuf with external data for models larger than 2GB. @@ -584,25 +584,25 @@ def __init__(self, destination_path: str): @_beartype.beartype def serialize( - self, export_output: ExportOutput, destination: io.BufferedIOBase + self, onnx_program: ONNXProgram, destination: io.BufferedIOBase ) -> None: """`destination` is ignored. The model is saved to `self._destination_path` instead.""" import onnx - if export_output.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: - onnx.save_model(export_output.model_proto, self._destination_path) # type: ignore[attr-defined] + if onnx_program.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: + onnx.save_model(onnx_program.model_proto, self._destination_path) # type: ignore[attr-defined] else: # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB # Fallback to serializing the model with external data. onnx.save_model( # type: ignore[attr-defined] - export_output.model_proto, + onnx_program.model_proto, self._destination_path, save_as_external_data=True, all_tensors_to_one_file=True, ) -class ExportOutput: +class ONNXProgram: """An in-memory representation of a PyTorch model that has been exported to ONNX.""" _model_proto: Final[onnx.ModelProto] # type: ignore[name-defined] @@ -698,10 +698,10 @@ def adapt_torch_inputs_to_onnx( ... return x + y1 + y2 + y3 >>> x_dict = {"a": torch.tensor(1.)} >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.))) - >>> export_output = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple) + >>> onnx_program = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple) >>> print(x_dict, y_tuple) {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.))) - >>> print(export_output.adapt_torch_inputs_to_onnx(x_dict, y_tuple)) + >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple)) (tensor(1.), tensor(2.), tensor(3.), tensor(4.)) .. warning:: @@ -746,11 +746,11 @@ def adapt_torch_outputs_to_onnx( >>> x = torch.tensor(1.) >>> y = torch.tensor(2.) >>> z = torch.tensor(3.) - >>> export_output = torch.onnx.dynamo_export(func_returning_tuples, x, y, z) + >>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z) >>> pt_output = func_returning_tuples(x, y, z) >>> print(pt_output) (tensor(3.), (tensor(5.), tensor(8.))) - >>> print(export_output.adapt_torch_outputs_to_onnx(pt_output)) + >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output)) [tensor(3.), tensor(5.), tensor(8.)] .. warning:: @@ -765,7 +765,7 @@ def save( destination: Union[str, io.BufferedIOBase], *, model_state_dict: Optional[Union[Dict[str, Any], str]] = None, - serializer: Optional[ExportOutputSerializer] = None, + serializer: Optional[ONNXProgramSerializer] = None, ) -> None: """Saves the in-memory ONNX model to ``destination`` using specified ``serializer``. @@ -785,9 +785,9 @@ def save( if serializer is None: if isinstance(destination, str): - serializer = LargeProtobufExportOutputSerializer(destination) + serializer = LargeProtobufONNXProgramSerializer(destination) else: - serializer = ProtobufExportOutputSerializer() + serializer = ProtobufONNXProgramSerializer() # Add initializers when symbolic tracing is enabled _model_state_dict_files: List[Union[str, io.BytesIO]] = [] @@ -874,10 +874,10 @@ def _from_failure( diagnostic_context: diagnostics.DiagnosticContext, ) -> Self: """ - Creates an instance of :class:`ExportOutput` when the export process encounters a failure. + Creates an instance of :class:`ONNXProgram` when the export process encounters a failure. In case of a failed export, this method is used to encapsulate the exception - and associated diagnostic context within an :class:`ExportOutput` instance for + and associated diagnostic context within an :class:`ONNXProgram` instance for easier handling and debugging. Args: @@ -885,13 +885,13 @@ def _from_failure( diagnostic_context: The context associated with diagnostics during export. Returns: - An instance of :class:`ExportOutput` representing the failed export output. + An instance of :class:`ONNXProgram` representing the failed ONNX program. """ # Defer `import onnx` out of `import torch` path # https://github.com/pytorch/pytorch/issues/103764 import onnx - return ExportOutput( + return ONNXProgram( onnx.ModelProto(), # type: ignore[attr-defined] io_adapter.InputAdapter(), io_adapter.OutputAdapter(), @@ -973,7 +973,7 @@ def __init__( ): self._assert_fake_tensor_mode() - def export(self) -> ExportOutput: + def export(self) -> ONNXProgram: with self.options.diagnostic_context: graph_module = self.options.fx_tracer.generate_fx( self.options, self.model, self.model_args, self.model_kwargs @@ -995,7 +995,7 @@ def export(self) -> ExportOutput: # NOTE: Filter out the initializers with fake tensors when it's fake_mode exporting. # Otherwise, the ONNX exporter will fail: RuntimeError: basic_string::_M_construct null # not valid. - # Concrete data is expected to be filled for those initializers later during `ExportOutput.save`. + # Concrete data is expected to be filled for those initializers later during `ONNXProgram.save`. if self.options.fake_context is not None: initializers_with_real_tensors: Dict[str, torch.Tensor] = {} for ( @@ -1011,7 +1011,7 @@ def export(self) -> ExportOutput: self.options.onnx_registry.opset_version, ) - return torch.onnx.ExportOutput( + return torch.onnx.ONNXProgram( onnx_model, self.options.fx_tracer.input_adapter, self.options.fx_tracer.output_adapter, @@ -1072,22 +1072,22 @@ class OnnxExporterError(RuntimeError): """Raised when an ONNX exporter error occurs. This exception is thrown when there's an error during the ONNX export process. - It encapsulates the :class:`ExportOutput` object generated until the failure, allowing + It encapsulates the :class:`ONNXProgram` object generated until the failure, allowing access to the partial export results and associated metadata. """ - export_output: Final[ExportOutput] + onnx_program: Final[ONNXProgram] - def __init__(self, export_output: ExportOutput, message: str): + def __init__(self, onnx_program: ONNXProgram, message: str): """ - Initializes the OnnxExporterError with the given export output and message. + Initializes the OnnxExporterError with the given ONNX program and message. Args: - export_output (ExportOutput): The partial results of the ONNX export. + onnx_program (ONNXProgram): The partial results of the ONNX export. message (str): The error message to be displayed. """ super().__init__(message) - self.export_output = export_output + self.onnx_program = onnx_program class InvalidExportOptionsError(RuntimeError): @@ -1145,7 +1145,7 @@ def dynamo_export( *model_args, export_options: Optional[ExportOptions] = None, **model_kwargs, -) -> ExportOutput: +) -> ONNXProgram: """Export a torch.nn.Module to an ONNX graph. Args: @@ -1171,7 +1171,7 @@ def forward(self, x, bias=None): model = MyModel() kwargs = {"bias": 3.} args = (torch.randn(2, 2, 2),) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, *args, **kwargs).save("my_simple_model.onnx") @@ -1181,18 +1181,18 @@ def forward(self, x, bias=None): # The previous model can be exported with dynamic shapes export_options = torch.onnx.ExportOptions(dynamic_shapes=True) - export_output = torch.onnx.dynamo_export( + onnx_program = torch.onnx.dynamo_export( model, *args, **kwargs, export_options=export_options) - export_output.save("my_dynamic_model.onnx") + onnx_program.save("my_dynamic_model.onnx") By printing input dynamic dimensions we can see the input shape is no longer (2,2,2) :: - >>> print(export_output.model_proto.graph.input[0]) + >>> print(onnx_program.model_proto.graph.input[0]) name: "arg0" type { tensor_type { @@ -1241,7 +1241,7 @@ def forward(self, x, bias=None): f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}" ) raise OnnxExporterError( - ExportOutput._from_failure(e, resolved_export_options.diagnostic_context), + ONNXProgram._from_failure(e, resolved_export_options.diagnostic_context), message, ) from e @@ -1326,8 +1326,8 @@ def common_pre_export_passes( __all__ = [ "DiagnosticOptions", "ExportOptions", - "ExportOutput", - "ExportOutputSerializer", + "ONNXProgram", + "ONNXProgramSerializer", "InvalidExportOptionsError", "OnnxExporterError", "OnnxRegistry",