Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ONNX] Enable 'ExportOutput.save' for models larger than 2GB #107904

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.exporter import (
ExportOutputSerializer,
LargeProtobufExportOutputSerializer,
ProtobufExportOutputSerializer,
ResolvedExportOptions,
)
Expand All @@ -24,6 +25,16 @@ def forward(self, x):
return (y, z)


class _LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.randn(2**28)) # 1GB
self.param2 = torch.nn.Parameter(torch.randn(2**28)) # 1GB

def forward(self, x):
return self.param + self.param2 + x


class TestExportOptionsAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
expected_exception_type = roar.BeartypeException
Expand Down Expand Up @@ -110,6 +121,17 @@ def serialize(
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)

def test_save_succeeds_when_model_greater_than_2gb_and_destination_is_str(self):
with common_utils.TemporaryFileName() as path:
dynamo_export(_LargeModel(), torch.randn(1)).save(path)

def test_save_raises_when_model_greater_than_2gb_and_destination_is_not_str(self):
with self.assertRaisesRegex(
ValueError,
"'destination' must be a string when saving model larger than 2GB.",
):
dynamo_export(_LargeModel(), torch.randn(1)).save(io.BytesIO())

def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName(suffix=".sarif") as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
Expand Down Expand Up @@ -182,6 +204,21 @@ def test_raise_on_invalid_argument_type(self):
serializer = ProtobufExportOutputSerializer()
serializer.serialize(None, None) # type: ignore[arg-type]

def test_serialize_raises_when_model_greater_than_2gb(self):
export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
serializer = ProtobufExportOutputSerializer()
with self.assertRaisesRegex(ValueError, "exceeds maximum protobuf size of 2GB"):
serializer.serialize(export_output, io.BytesIO())


class TestLargeProtobufExportOutputSerializerAPI(common_utils.TestCase):
    def test_serialize_succeeds_when_model_greater_than_2gb(self):
        # This serializer writes to the path given at construction time and
        # falls back to external data for oversized models, so no error expected.
        export_output = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
        with common_utils.TemporaryFileName() as path:
            serializer = LargeProtobufExportOutputSerializer(path)
            # `io.BytesIO()` is unused, but required by the Protocol interface.
            serializer.serialize(export_output, io.BytesIO())


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    common_utils.run_tests()
42 changes: 40 additions & 2 deletions torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,36 @@
destination.write(export_output.model_proto.SerializeToString())


class LargeProtobufExportOutputSerializer:
    """Serializes ONNX graph as Protobuf.

    Fallback to serializing as Protobuf with external data for models larger than 2GB.
    """

    # Path the model is written to; `serialize`'s `destination` argument is ignored.
    _destination_path: Final[str]

    def __init__(self, destination_path: str):
        self._destination_path = destination_path

    @_beartype.beartype
    def serialize(
        self, export_output: ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        """`destination` is ignored. The model is saved to `self._destination_path` instead."""
        # Imported lazily so `onnx` is only required when serialization happens.
        import onnx

        try:
            # `onnx` ships without complete type stubs, hence the mypy suppressions.
            onnx.save_model(  # type: ignore[attr-defined]
                export_output.model_proto, self._destination_path
            )
        except ValueError:
            # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
            # Fallback to serializing the model with external data.
            onnx.save_model(  # type: ignore[attr-defined]
                export_output.model_proto,
                self._destination_path,
                save_as_external_data=True,
            )


class ExportOutput:
"""An in-memory representation of a PyTorch model that has been exported to ONNX."""

Expand Down Expand Up @@ -724,7 +754,10 @@
"""

if serializer is None:
serializer = ProtobufExportOutputSerializer()
if isinstance(destination, str):
serializer = LargeProtobufExportOutputSerializer(destination)
else:
serializer = ProtobufExportOutputSerializer()

# Add initializers when symbolic tracing is enabled
_model_state_dict_files: List[Union[str, io.BytesIO]] = []
Expand Down Expand Up @@ -779,7 +812,12 @@
with open(destination, "wb") as f:
serializer.serialize(self, f)
else:
serializer.serialize(self, destination)
try:
serializer.serialize(self, destination)
except ValueError:
raise ValueError(
"'destination' must be a string when saving model larger than 2GB."
                )

@_beartype.beartype
def save_diagnostics(self, destination: str) -> None:
Expand Down