Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions devtools/bundled_program/test/test_bundle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from executorch.devtools.bundled_program.util.test_util import (
get_common_executorch_program,
)
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize import _PTEFile, _serialize_pte_binary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why `_PTEFile` rather than `PTEFile`?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is PTEFile supposed to be private only?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly following the naming of `_serialize_pte_binary` and `_deserialize_pte_binary` ...

I think those are fairly public though haha



class TestBundle(unittest.TestCase):
Expand Down Expand Up @@ -72,7 +72,11 @@ def test_bundled_program(self) -> None:

self.assertEqual(
bundled_program.serialize_to_schema().program,
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
bytes(
_serialize_pte_binary(
pte_file=_PTEFile(program=executorch_program.executorch_program)
)
),
)

def test_bundled_program_from_pte(self) -> None:
Expand Down
20 changes: 9 additions & 11 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,19 +419,17 @@ def _extract_named_data(


def serialize_pte_binary(
program: Program,
pte_file: PTEFile,
*,
mutable_data: Optional[List[Buffer]] = None,
extract_delegate_segments: bool = False,
segment_alignment: int = 128,
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
named_data: Optional[NamedDataStoreOutput] = None,
) -> Cord:
"""Returns the runtime binary representation of the given Program.

Args:
program: The Program to serialize.
pte_file: PTEFile class containing the program and segments.
extract_delegate_segments: Whether to move delegate data blobs from the
Program into separate segments, rather than encoding those blobs
in the flatbuffer data. When true, will also:
Expand All @@ -446,8 +444,6 @@ def serialize_pte_binary(
delegate_alignment: If provided, the minimum alignment of delegate data
in the program. Must be a power of 2. If not provided, uses the
value in the schema file.
named_data: If provided, named blobs to be stored in segments
after the PTE file.
Returns:
The serialized form of the Program, ready for execution by the runtime.
"""
Expand All @@ -458,7 +454,7 @@ def serialize_pte_binary(
# Don't modify the original program.
# TODO(T144120904): Could avoid yet more huge copies with a more shallow
# copy, reusing the actual data blobs.
program = copy.deepcopy(program)
program = copy.deepcopy(pte_file.program)

# Store extracted segment data, with any buffer-specific alignment.
# This may be constant data, delegate data or named data.
Expand All @@ -482,9 +478,9 @@ def serialize_pte_binary(
# Add to the aggregate segments cord.
segments.append(AlignedData(constant_segment_data))

if mutable_data is not None:
if pte_file.mutable_data is not None:
mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
mutable_data,
pte_file.mutable_data,
tensor_alignment=None, # data is copied at Method load so no need to align.
)
if len(mutable_segment_data) > 0:
Expand All @@ -499,8 +495,10 @@ def serialize_pte_binary(

if extract_delegate_segments:
_extract_delegate_segments(program, segments)
if named_data is not None:
_extract_named_data(program, segments, named_data.buffers, named_data.pte_data)
if pte_file.named_data is not None:
_extract_named_data(
program, segments, pte_file.named_data.buffers, pte_file.named_data.pte_data
)

# Append all segments into a single Cord, adding any necessary padding to ensure that
# each segment begins at the required alignment.
Expand Down
14 changes: 8 additions & 6 deletions exir/_serialize/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from typing import Dict, Optional, Set, Tuple

from executorch.exir._serialize import _serialize_pte_binary

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

from executorch.exir._serialize._program import PTEFile, serialize_pte_binary
from executorch.exir._serialize.data_serializer import (
DataEntry,
DataPayload,
Expand Down Expand Up @@ -46,14 +46,16 @@ def serialize_for_executorch(
pte_data=named_data_store.pte_data,
external_data={},
)
pte: Cord = _serialize_pte_binary(
program=emitter_output.program,
mutable_data=emitter_output.mutable_data,
pte: Cord = serialize_pte_binary(
pte_file=PTEFile(
program=emitter_output.program,
mutable_data=emitter_output.mutable_data,
named_data=pte_named_data,
),
extract_delegate_segments=config.extract_delegate_segments,
segment_alignment=config.segment_alignment,
constant_tensor_alignment=config.constant_tensor_alignment,
delegate_alignment=config.delegate_alignment,
named_data=pte_named_data,
)

# Serialize PTD files.
Expand Down
32 changes: 17 additions & 15 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_json_to_program,
_program_to_json,
deserialize_pte_binary,
PTEFile,
serialize_pte_binary,
)
from executorch.exir._serialize.data_serializer import DataEntry
Expand Down Expand Up @@ -173,7 +174,7 @@ def constant_segment_with_tensor_alignment(
# Extract blobs into constant segment during serialization.
pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program),
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=constant_tensor_alignment,
)
Expand Down Expand Up @@ -446,7 +447,7 @@ def test_round_trip_no_header_no_segments(self) -> None:
deserializing.
"""
program = get_test_program()
pte_data = bytes(serialize_pte_binary(program))
pte_data = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
self.assertGreater(len(pte_data), 16)

# File magic should be present at the expected offset.
Expand All @@ -471,7 +472,7 @@ def test_round_trip_large_buffer_sizes(self) -> None:
"""
program = get_test_program()
program.execution_plan[0].non_const_buffer_sizes = [0, 2**48]
flatbuffer_from_py = bytes(serialize_pte_binary(program))
flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
self.assert_programs_equal(
program, deserialize_pte_binary(flatbuffer_from_py).program
)
Expand All @@ -483,7 +484,11 @@ def test_round_trip_no_segments_and_no_header(self) -> None:
the same after serializing and deserializing.
"""
program = get_test_program()
pte_data = bytes(serialize_pte_binary(program, extract_delegate_segments=True))
pte_data = bytes(
serialize_pte_binary(
pte_file=PTEFile(program), extract_delegate_segments=True
)
)
self.assertGreater(len(pte_data), 16)

# File magic should be present at the expected offset.
Expand Down Expand Up @@ -533,7 +538,7 @@ def test_round_trip_with_segments(self) -> None:
# Extract the blobs into segments during serialization.
pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
)
Expand Down Expand Up @@ -647,7 +652,7 @@ def test_no_constants(self) -> None:

pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
Expand Down Expand Up @@ -679,7 +684,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None:
# Extracting the blobs into segments should succeed.
pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
)
Expand All @@ -694,7 +699,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None:
# Should cause serialization to fail.
with self.assertRaises(ValueError):
serialize_pte_binary(
program,
PTEFile(program=program),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
)
Expand All @@ -715,7 +720,7 @@ def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None:
# Expect failure as tensor alignment 14 is not a power of 2.
with self.assertRaises(ValueError):
serialize_pte_binary(
program,
PTEFile(program=program),
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=constant_tensor_alignment,
)
Expand Down Expand Up @@ -750,11 +755,10 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
# Extract the blobs into segments during serialization.
pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program, named_data=named_data),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
named_data=named_data,
)
)

Expand Down Expand Up @@ -961,11 +965,10 @@ def test_named_data_segments(self) -> None:
# Serialize the program with named data segments.
pte_data = bytes(
serialize_pte_binary(
program,
PTEFile(program=program, named_data=named_data),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
named_data=named_data,
)
)

Expand Down Expand Up @@ -1046,11 +1049,10 @@ def test_named_data_segments(self) -> None:

# Test re-serialize
pte_data2 = serialize_pte_binary(
deserialized.program,
PTEFile(program=deserialized.program, named_data=deserialized.named_data),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
named_data=deserialized.named_data,
)
# pte_data2 is not going to be the same as pte_data due to alignment;
# directly test the deserialized one.
Expand Down
6 changes: 3 additions & 3 deletions exir/backend/test/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from executorch.exir import to_edge
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize import _PTEFile, _serialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
AllNodePartitioner,
Expand Down Expand Up @@ -58,7 +58,7 @@ def forward(self, x):
# Generate the .pte file with the wrong version.
buff = bytes(
_serialize_pte_binary(
program=prog,
pte_file=_PTEFile(program=prog),
)
)

Expand Down Expand Up @@ -105,7 +105,7 @@ def forward(self, x):
# Generate the .pte file with the wrong version.
buff = bytes(
_serialize_pte_binary(
program=prog,
pte_file=_PTEFile(program=prog),
)
)

Expand Down
8 changes: 5 additions & 3 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize import _PTEFile, _serialize_pte_binary
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
Expand Down Expand Up @@ -164,12 +164,14 @@ def buffer(
# TODO(T181463742): avoid calling bytes(..) which incurs large copies.
out = bytes(
_serialize_pte_binary(
program=self.program(memory_planning=memory_planning),
pte_file=_PTEFile(
program=self.program(memory_planning=memory_planning),
named_data=self.named_data_store_output,
),
extract_delegate_segments=extract_delegate_segments,
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
named_data=self.named_data_store_output,
)
)
return out
Expand Down
Loading