Skip to content

Commit

Permalink
[ONNX] Enable 'ExportOutput.save' for models larger than 2GB
Browse files Browse the repository at this point in the history
ghstack-source-id: e92a457a3648d78f268649c6a256f31a70b87fa5
Pull Request resolved: #107904
  • Loading branch information
BowenBao committed Aug 24, 2023
1 parent 1ef4bd1 commit a95e654
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
30 changes: 30 additions & 0 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.exporter import (
ExportOutputSerializer,
LargeProtobufExportOutputSerializer,
ProtobufExportOutputSerializer,
ResolvedExportOptions,
)
Expand All @@ -24,6 +25,16 @@ def forward(self, x):
return (y, z)


class _LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.randn(2**28)) # 1GB
self.param2 = torch.nn.Parameter(torch.randn(2**28)) # 1GB

def forward(self, x):
return self.param + self.param2 + x


class TestExportOptionsAPI(common_utils.TestCase):
def test_raise_on_invalid_argument_type(self):
expected_exception_type = roar.BeartypeException
Expand Down Expand Up @@ -110,6 +121,10 @@ def serialize(
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)

def test_save_succeeds_when_model_greater_than_2gb(self):
    """Saving a >2GB export to a file path must complete without raising."""
    with common_utils.TemporaryFileName() as path:
        large_model = _LargeModel()
        example_input = torch.randn(1)
        export_output = dynamo_export(large_model, example_input)
        export_output.save(path)

def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName(suffix=".sarif") as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
Expand Down Expand Up @@ -182,6 +197,21 @@ def test_raise_on_invalid_argument_type(self):
serializer = ProtobufExportOutputSerializer()
serializer.serialize(None, None) # type: ignore[arg-type]

def test_serialize_raises_when_model_greater_than_2gb(self):
    """The plain protobuf serializer must reject a model above the 2GB limit."""
    large_export = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
    protobuf_serializer = ProtobufExportOutputSerializer()
    expected_message = "exceeds maximum protobuf size of 2GB"
    with self.assertRaisesRegex(ValueError, expected_message):
        protobuf_serializer.serialize(large_export, io.BytesIO())


class TestLargeProtobufExportOutputSerializerAPI(common_utils.TestCase):
    """Tests for ``LargeProtobufExportOutputSerializer``."""

    def test_serialize_succeeds_when_model_greater_than_2gb(self):
        """Serializing a >2GB export to a file path must complete without raising."""
        large_export = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
        with common_utils.TemporaryFileName() as path:
            large_serializer = LargeProtobufExportOutputSerializer(path)
            # The serializer writes to `path`; the stream argument is never
            # used and is passed only to satisfy the Protocol interface.
            placeholder_stream = io.BytesIO()
            large_serializer.serialize(large_export, placeholder_stream)


# Run this module's tests through the shared test runner when executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()
42 changes: 40 additions & 2 deletions torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,36 @@ def serialize(
destination.write(export_output.model_proto.SerializeToString())


class LargeProtobufExportOutputSerializer:
    """Serializes an ONNX graph as Protobuf, written to a file path.

    If the model cannot be saved as a single Protobuf file (models larger
    than 2GB), falls back to saving it as Protobuf with external data.
    """

    # File path the model is written to; fixed at construction time.
    _destination_path: Final[str]

    def __init__(self, destination_path: str):
        self._destination_path = destination_path

    @_beartype.beartype
    def serialize(
        self, export_output: ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        """Save the exported model to ``self._destination_path``.

        Args:
            export_output: The export result whose ``model_proto`` is saved.
            destination: Ignored. The model is always saved to
                ``self._destination_path`` instead.
        """
        import onnx

        def _save_model(**save_options) -> None:
            # Single call site for both attempts; attributes are re-read on
            # every call, exactly as two separate `onnx.save` calls would.
            onnx.save(
                export_output.model_proto, self._destination_path, **save_options
            )

        try:
            _save_model()
        except ValueError:
            # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
            # Retry, this time storing the model with external data.
            _save_model(save_as_external_data=True)


class ExportOutput:
"""An in-memory representation of a PyTorch model that has been exported to ONNX."""

Expand Down Expand Up @@ -724,7 +754,10 @@ def save(
"""

if serializer is None:
serializer = ProtobufExportOutputSerializer()
if isinstance(destination, str):
serializer = LargeProtobufExportOutputSerializer(destination)
else:
serializer = ProtobufExportOutputSerializer()

# Add initializers when symbolic tracing is enabled
_model_state_dict_files: List[Union[str, io.BytesIO]] = []
Expand Down Expand Up @@ -779,7 +812,12 @@ def save(
with open(destination, "wb") as f:
serializer.serialize(self, f)
else:
serializer.serialize(self, destination)
try:
serializer.serialize(self, destination)
except ValueError:
raise ValueError(
"'destination' must be a string when saving model larger than 2GB."
)

@_beartype.beartype
def save_diagnostics(self, destination: str) -> None:
Expand Down

0 comments on commit a95e654

Please sign in to comment.