From f53ba044b730b51b58ef25046a9e22b4a59ba84e Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 12 Nov 2025 18:23:51 -0800 Subject: [PATCH] Use PTEFile class in serialize_pte_binary Take in a PTE file for serialization instead of program, mutable_segments, named_data segments separately. Differential Revision: [D86908241](https://our.internmc.facebook.com/intern/diff/D86908241/) [ghstack-poisoned] --- .../bundled_program/test/test_bundle_data.py | 8 +++-- exir/_serialize/_program.py | 20 ++++++------ exir/_serialize/_serialize.py | 14 ++++---- exir/_serialize/test/test_program.py | 32 ++++++++++--------- exir/backend/test/test_compatibility.py | 4 +-- exir/lowered_backend_module.py | 8 +++-- 6 files changed, 47 insertions(+), 39 deletions(-) diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py index a587a8672e9..9fdeb4a776d 100644 --- a/devtools/bundled_program/test/test_bundle_data.py +++ b/devtools/bundled_program/test/test_bundle_data.py @@ -18,7 +18,7 @@ from executorch.devtools.bundled_program.util.test_util import ( get_common_executorch_program, ) -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary class TestBundle(unittest.TestCase): @@ -72,7 +72,11 @@ def test_bundled_program(self) -> None: self.assertEqual( bundled_program.serialize_to_schema().program, - bytes(_serialize_pte_binary(executorch_program.executorch_program)), + bytes( + _serialize_pte_binary( + pte_file=_PTEFile(program=executorch_program.executorch_program) + ) + ), ) def test_bundled_program_from_pte(self) -> None: diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index e4af45c08ce..80374409a71 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -419,19 +419,17 @@ def _extract_named_data( def serialize_pte_binary( - program: Program, + pte_file: PTEFile, *, - mutable_data: Optional[List[Buffer]] = None, extract_delegate_segments: bool = False, segment_alignment: int = 128, constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, - named_data: Optional[NamedDataStoreOutput] = None, ) -> Cord: """Returns the runtime binary representation of the given Program. Args: - program: The Program to serialize. + pte_file: PTEFile class containing the program and segments. extract_delegate_segments: Whether to move delegate data blobs from the Program into separate segments, rather than encoding those blobs in the flatbuffer data. When true, will also: @@ -446,8 +444,6 @@ def serialize_pte_binary( delegate_alignment: If provided, the minimum alignment of delegate data in the program. Must be a power of 2. If not provided, uses the value in the schema file. - named_data: If provided, named blobs to be stored in segments - after the PTE file. Returns: The serialized form of the Program, ready for execution by the runtime. """ @@ -458,7 +454,7 @@ def serialize_pte_binary( # Don't modify the original program. # TODO(T144120904): Could avoid yet more huge copies with a more shallow # copy, reusing the actual data blobs. - program = copy.deepcopy(program) + program = copy.deepcopy(pte_file.program) # Store extracted segment data, with any buffer-specific alignment. # This may be constant data, delegate data or named data. @@ -482,9 +478,9 @@ def serialize_pte_binary( # Add to the aggregate segments cord. segments.append(AlignedData(constant_segment_data)) - if mutable_data is not None: + if pte_file.mutable_data is not None: mutable_segment_data, mutable_segment_offsets = _extract_constant_segment( - mutable_data, + pte_file.mutable_data, tensor_alignment=None, # data is copied at Method load so no need to align. ) if len(mutable_segment_data) > 0: @@ -499,8 +495,10 @@ def serialize_pte_binary( if extract_delegate_segments: _extract_delegate_segments(program, segments) - if named_data is not None: - _extract_named_data(program, segments, named_data.buffers, named_data.pte_data) + if pte_file.named_data is not None: + _extract_named_data( + program, segments, pte_file.named_data.buffers, pte_file.named_data.pte_data + ) # Append all segments into a single Cord, adding any necessary padding to ensure that # each segment begins at the required alignment. diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index 789ae89b190..60b6079f4a8 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -8,10 +8,10 @@ from typing import Dict, Optional, Set, Tuple -from executorch.exir._serialize import _serialize_pte_binary - from executorch.exir._serialize._cord import Cord from executorch.exir._serialize._named_data_store import NamedDataStoreOutput + +from executorch.exir._serialize._program import PTEFile, serialize_pte_binary from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, @@ -46,14 +46,16 @@ def serialize_for_executorch( pte_data=named_data_store.pte_data, external_data={}, ) - pte: Cord = _serialize_pte_binary( - program=emitter_output.program, - mutable_data=emitter_output.mutable_data, + pte: Cord = serialize_pte_binary( + pte_file=PTEFile( + program=emitter_output.program, + mutable_data=emitter_output.mutable_data, + named_data=pte_named_data, + ), extract_delegate_segments=config.extract_delegate_segments, segment_alignment=config.segment_alignment, constant_tensor_alignment=config.constant_tensor_alignment, delegate_alignment=config.delegate_alignment, - named_data=pte_named_data, ) # Serialize PTD files. diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index b2a22694245..46e8f020a0b 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -24,6 +24,7 @@ _json_to_program, _program_to_json, deserialize_pte_binary, + PTEFile, serialize_pte_binary, ) from executorch.exir._serialize.data_serializer import DataEntry @@ -173,7 +174,7 @@ def constant_segment_with_tensor_alignment( # Extract blobs into constant segment during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=constant_tensor_alignment, ) @@ -446,7 +447,7 @@ def test_round_trip_no_header_no_segments(self) -> None: deserializing. """ program = get_test_program() - pte_data = bytes(serialize_pte_binary(program)) + pte_data = bytes(serialize_pte_binary(pte_file=PTEFile(program))) self.assertGreater(len(pte_data), 16) # File magic should be present at the expected offset. @@ -471,7 +472,7 @@ def test_round_trip_large_buffer_sizes(self) -> None: """ program = get_test_program() program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] - flatbuffer_from_py = bytes(serialize_pte_binary(program)) + flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program))) self.assert_programs_equal( program, deserialize_pte_binary(flatbuffer_from_py).program ) @@ -483,7 +484,11 @@ def test_round_trip_no_segments_and_no_header(self) -> None: the same after serializing and deserializing. """ program = get_test_program() - pte_data = bytes(serialize_pte_binary(program, extract_delegate_segments=True)) + pte_data = bytes( + serialize_pte_binary( + pte_file=PTEFile(program), extract_delegate_segments=True + ) + ) self.assertGreater(len(pte_data), 16) # File magic should be present at the expected offset. @@ -533,7 +538,7 @@ def test_round_trip_with_segments(self) -> None: # Extract the blobs into segments during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -647,7 +652,7 @@ def test_no_constants(self) -> None: pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, @@ -679,7 +684,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None: # Extract the blobs into segments should succeeed. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -694,7 +699,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None: # Should cause serialization to fail. with self.assertRaises(ValueError): serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -715,7 +720,7 @@ def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None: # Expect failure as tensor alignment 14 is not a power of 2. with self.assertRaises(ValueError): serialize_pte_binary( - program, + PTEFile(program=program), segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=constant_tensor_alignment, ) @@ -750,11 +755,10 @@ def test_constant_delegate_and_named_data_segments(self) -> None: # Extract the blobs into segments during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program, named_data=named_data), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, - named_data=named_data, ) ) @@ -961,11 +965,10 @@ def test_named_data_segments(self) -> None: # Serialize the program with named data segments. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program, named_data=named_data), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, - named_data=named_data, ) ) @@ -1046,11 +1049,10 @@ def test_named_data_segments(self) -> None: # Test re-serialize pte_data2 = serialize_pte_binary( - deserialized.program, + PTEFile(program=deserialized.program, named_data=deserialized.named_data), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, - named_data=deserialized.named_data, ) # pte_data2 is not going to be the same as pte_data due to alignment; # directly test the deserialized one. diff --git a/exir/backend/test/test_compatibility.py b/exir/backend/test/test_compatibility.py index 4bde3d40b2c..437982820c8 100644 --- a/exir/backend/test/test_compatibility.py +++ b/exir/backend/test/test_compatibility.py @@ -8,7 +8,7 @@ import torch from executorch.exir import to_edge -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( AllNodePartitioner, @@ -58,7 +58,7 @@ def forward(self, x): # Generate the .pte file with the wrong version. buff = bytes( _serialize_pte_binary( - program=prog, + pte_file=_PTEFile(program=prog), ) ) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 61414990703..c0ff61242df 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -13,7 +13,7 @@ import torch import torch.utils._pytree as pytree -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name @@ -164,12 +164,14 @@ def buffer( # TODO(T181463742): avoid calling bytes(..) which incurs large copies. out = bytes( _serialize_pte_binary( - program=self.program(memory_planning=memory_planning), + pte_file=_PTEFile( + program=self.program(memory_planning=memory_planning), + named_data=self.named_data_store_output, + ), extract_delegate_segments=extract_delegate_segments, segment_alignment=segment_alignment, constant_tensor_alignment=constant_tensor_alignment, delegate_alignment=delegate_alignment, - named_data=self.named_data_store_output, ) ) return out