Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions devtools/bundled_program/test/test_bundle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ def test_bundled_program(self) -> None:
method_test_case.expected_outputs,
)

emitter_output = executorch_program._emitter_output
self.assertEqual(
bundled_program.serialize_to_schema().program,
bytes(
_serialize_pte_binary(
pte_file=_PTEFile(program=executorch_program.executorch_program)
pte_file=_PTEFile(
program=executorch_program.executorch_program,
constant_data=emitter_output.constant_data,
mutable_data=emitter_output.mutable_data,
)
)
),
)
Expand Down Expand Up @@ -116,10 +121,18 @@ def test_bundled_program_from_pte(self) -> None:
bundled_program_ioset.expected_outputs,
method_test_case.expected_outputs,
)

emitter_output = executorch_program._emitter_output
self.assertEqual(
bundled_program.serialize_to_schema().program,
executorch_program.buffer,
bytes(
_serialize_pte_binary(
pte_file=_PTEFile(
program=executorch_program.executorch_program,
constant_data=emitter_output.constant_data,
mutable_data=emitter_output.mutable_data,
)
)
),
)

def test_bundled_miss_methods(self) -> None:
Expand Down
73 changes: 37 additions & 36 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
import re

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple

from executorch.exir._serialize._cord import Cord
Expand All @@ -33,7 +33,6 @@
from executorch.exir.schema import (
BackendDelegateDataReference,
BackendDelegateInlineData,
Buffer,
DataLocation,
DataSegment,
NamedData,
Expand All @@ -56,9 +55,10 @@ class PTEFile:
"""

program: Program
# TODO(lfq): add constant data (currently restored in the program)
# TODO(lfq): update this to List[bytes]
mutable_data: Optional[List[Buffer]] = None
# Constant data; the default single empty entry is the placeholder for non-const tensors.
constant_data: List[bytes] = field(default_factory=lambda: [b""])
# Mutable data; the default single empty entry is the placeholder for non-const tensors.
mutable_data: List[bytes] = field(default_factory=lambda: [b""])
named_data: Optional[NamedDataStoreOutput] = None


Expand Down Expand Up @@ -346,14 +346,14 @@ def _extract_delegate_segments(


def _extract_constant_segment(
constant_buffer: List[Buffer],
constant_buffer: List[bytes],
tensor_alignment: Optional[int] = None,
) -> Tuple[Cord, List[int]]:
"""Copies the tensors from the provided list into a Cord and tracks the offsets
of each tensor.

Args:
constant_buffer: list of Buffers from which to extract constants. Not modified.
constant_buffer: list of bytes objects from which to extract constants. Not modified.
tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
with this value. Defaults to ALIGNMENT.

Expand All @@ -365,8 +365,8 @@ def _extract_constant_segment(
current_offset: int = 0
for i in range(len(constant_buffer)):
buffer = constant_buffer[i]
constant_segment_data.append(buffer.storage)
buffer_length = len(buffer.storage)
constant_segment_data.append(buffer)
buffer_length = len(buffer)
pad_length = (
padding_required(buffer_length, tensor_alignment)
if tensor_alignment is not None
Expand Down Expand Up @@ -460,25 +460,24 @@ def serialize_pte_binary(
# This may be constant data, delegate data or named data.
segments: List[AlignedData] = []

constant_segment_data, constant_segment_offsets = _extract_constant_segment(
program.constant_buffer, tensor_alignment=constant_tensor_alignment
)

# If there are no constants, len(constant_segment_data) = 0. However, there may
# be non-constants, in which case len(constant_segment_offsets) = 1, containing
# the placeholder value 0. Ensure the placeholder value is put into
# program.constant_segment.offsets.
if len(constant_segment_offsets) > 0:
# Update program.constant_segment with constant subsegment offset information.
program.constant_segment = SubsegmentOffsets(
segment_index=len(segments), offsets=constant_segment_offsets
if len(pte_file.constant_data) > 1:
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
pte_file.constant_data, tensor_alignment=constant_tensor_alignment
)
# Clear the constant buffer, as constant data will be stored in segments.
program.constant_buffer = []
# Add to the aggregate segments cord.
segments.append(AlignedData(constant_segment_data))

if pte_file.mutable_data is not None:
# If there are no constants, len(constant_segment_data) = 0. However, there may
# be non-constants, in which case len(constant_segment_offsets) = 1, containing
# the placeholder value 0. Ensure the placeholder value is put into
# program.constant_segment.offsets.
if len(constant_segment_offsets) > 0:
# Update program.constant_segment with constant subsegment offset information.
program.constant_segment = SubsegmentOffsets(
segment_index=len(segments), offsets=constant_segment_offsets
)
# Add to the aggregate segments cord.
segments.append(AlignedData(constant_segment_data))

if len(pte_file.mutable_data) > 1:
mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
pte_file.mutable_data,
tensor_alignment=None, # data is copied at Method load so no need to align.
Expand Down Expand Up @@ -637,8 +636,9 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
)

# Replace constants from constant_segment into constant_buffer.
constant_data = None
if program.constant_segment and len(program.constant_segment.offsets) > 0:
constant_buffers: List[Buffer] = []
constant_buffers: List[bytes] = []
constant_segment = segments[program.constant_segment.segment_index]
for i in range(len(program.constant_segment.offsets)):
start_offset = program.constant_segment.offsets[i]
Expand All @@ -649,17 +649,15 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
if i < len(program.constant_segment.offsets) - 1
else len(constant_segment)
)
constant_buffers.append(
Buffer(storage=constant_segment[start_offset:end_offset])
)
program.constant_buffer = constant_buffers
constant_buffers.append(constant_segment[start_offset:end_offset])
constant_data = constant_buffers
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Extract mutable segments.
mutable_data = None
if program.mutable_data_segments and len(program.mutable_data_segments.offsets) > 0:
mutable_buffers: List[Buffer] = []
mutable_buffers: List[bytes] = []
mutable_segment = segments[program.mutable_segment.segment_index]
for i in range(len(program.mutable_segments.offsets)):
start_offset = program.mutable_segment.offsets[i]
Expand All @@ -670,9 +668,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
if i < len(program.mutable_segment.offsets) - 1
else len(mutable_segment)
)
mutable_buffers.append(
Buffer(storage=mutable_segment[start_offset:end_offset])
)
mutable_buffers.append(mutable_segment[start_offset:end_offset])
mutable_data = mutable_buffers
program.mutable_segment.segment_index = 0
program.mutable_segment.offsets = []
Expand All @@ -699,7 +695,12 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
named_data = named_data_store.get_named_data_store_output()
program.named_data = []
program.segments = []
return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data)
return PTEFile(
program=program,
constant_data=constant_data,
mutable_data=mutable_data,
named_data=named_data,
)


def deserialize_pte_binary(program_data: bytes) -> PTEFile:
Expand Down
1 change: 1 addition & 0 deletions exir/_serialize/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def serialize_for_executorch(
pte: Cord = serialize_pte_binary(
pte_file=PTEFile(
program=emitter_output.program,
constant_data=emitter_output.constant_data,
mutable_data=emitter_output.mutable_data,
named_data=pte_named_data,
),
Expand Down
37 changes: 15 additions & 22 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@
CONSTANT_TENSOR_ALIGNMENT: int = 16


def add_constant_data(program: Program, blobs: Sequence[bytes]) -> None:
    """Append every blob in ``blobs`` to ``program.constant_buffer``.

    Each blob is wrapped in a ``Buffer`` (as its ``storage``) and appended
    in the order given; ``program`` is mutated in place.
    """
    program.constant_buffer.extend(Buffer(storage=data) for data in blobs)


def add_delegate_data(
program: Program, plan: ExecutionPlan, blobs: Sequence[bytes]
) -> None:
Expand Down Expand Up @@ -169,12 +163,14 @@ def constant_segment_with_tensor_alignment(
self.gen_blob_data(constant_tensor_alignment, b"\x30\x33\x03"),
self.gen_blob_data(constant_tensor_alignment + 1, b"\x40\x44\x04"),
)
add_constant_data(program, blobs)

# Extract blobs into constant segment during serialization.
pte_data = bytes(
serialize_pte_binary(
PTEFile(program=program),
PTEFile(
program=program,
constant_data=blobs,
),
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=constant_tensor_alignment,
)
Expand Down Expand Up @@ -289,9 +285,7 @@ def constant_segment_with_tensor_alignment(
# during serialization.
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(
len(deserialized.program.constant_buffer), len(program.constant_buffer)
)
self.assertEqual(len(deserialized.constant_data), len(blobs))
self.assertEqual(deserialized.mutable_data, None)
self.assertEqual(deserialized.named_data, None)

Expand Down Expand Up @@ -647,12 +641,13 @@ def test_round_trip_with_segments(self) -> None:

def test_no_constants(self) -> None:
program = get_test_program()
# Insert placeholder for non-const tensors.
add_constant_data(program, [b""])

pte_data = bytes(
serialize_pte_binary(
PTEFile(program=program),
PTEFile(
program=program,
constant_data=[b""], # placeholder for non-const tensors.
),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
Expand All @@ -667,10 +662,9 @@ def test_no_constants(self) -> None:
# Constant buffer should be empty.
self.assertEqual(len(flatbuffer_program.constant_buffer), 0)

# Constant segment should contain the placeholder.
# Constant segment should also be empty.
self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0)
self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1)
self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0)
self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 0)

def test_unused_inline_delegate_blobs_with_segments(self) -> None:
# Create a program with some delegate data blobs.
Expand Down Expand Up @@ -736,7 +730,6 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"),
self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
)
add_constant_data(program, constant_blobs)
add_delegate_data(program, program.execution_plan[0], delegate_blobs)

# Create named data segment.
Expand All @@ -755,7 +748,9 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
# Extract the blobs into segments during serialization.
pte_data = bytes(
serialize_pte_binary(
PTEFile(program=program, named_data=named_data),
PTEFile(
program=program, constant_data=constant_blobs, named_data=named_data
),
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
Expand Down Expand Up @@ -933,9 +928,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
# during serialization.
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(
len(deserialized.program.constant_buffer), len(program.constant_buffer)
)
self.assertEqual(len(deserialized.constant_data), len(constant_blobs))
self.assertEqual(deserialized.mutable_data, None)
self._check_named_data_store_output(deserialized.named_data, named_data)

Expand Down
20 changes: 14 additions & 6 deletions exir/emit/_emit_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -50,14 +50,17 @@ class EmitterOutput:
# generated by each instruction.
instruction_id_to_num_outs_map: Dict[str, Dict[int, int]]

mutable_data: Optional[List[Buffer]]
# Constant data stored in the PTE file.
constant_data: List[bytes] = field(default_factory=list)
# Mutable data stored in the PTE file.
mutable_data: List[bytes] = field(default_factory=list)

# Constants are optionally stored in external files.
# Aggregate unique external constants into one buffer.
external_constant_buffer: List[bytes]
external_constant_buffer: Optional[List[bytes]] = None
# Each constant_tag groups a set of constants together.
# {constant_tag: {fqn: index into external_constant_buffer}}
external_constant_map: Optional[Dict[str, Dict[str, int]]]
external_constant_map: Optional[Dict[str, Dict[str, int]]] = None


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -198,18 +201,23 @@ def emit_program(
program=Program(
version=EXECUTORCH_SCHEMA_VERSION,
execution_plan=plans,
constant_buffer=program_state.constant_buffer,
constant_buffer=[], # Do not add constants here anymore.
backend_delegate_data=program_state.backend_delegate_data,
# Segments may be added at serialization time.
segments=[],
# Subsegment offsets may be added at serialization time.
constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
mutable_data_segments=None, # Will be filled in during serialization
),
constant_data=(
program_state.constant_buffer
if len(program_state.constant_buffer) > 1
else []
),
mutable_data=(
program_state.mutable_buffer
if len(program_state.mutable_buffer) > 1
else None
else []
),
external_constant_buffer=program_state.external_constant_buffer,
external_constant_map=program_state.external_constant_map,
Expand Down
Loading
Loading