Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,24 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
location=DataLocation.INLINE, index=data_index
)

# Replace constants from constant_segment into constant_buffer.
if program.constant_segment and len(program.constant_segment.offsets) > 0:
buffers: List[Buffer] = []
constant_segment = segments[program.constant_segment.segment_index]
for i in range(len(program.constant_segment.offsets)):
start_offset = program.constant_segment.offsets[i]
# Note: this is the original end offset plus any padding between
# it and the next start offset.
end_offset = (
program.constant_segment.offsets[i + 1]
if i < len(program.constant_segment.offsets) - 1
else len(constant_segment)
)
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
program.constant_buffer = buffers
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Clear out the segments list since the original Program didn't have one.
program.segments = []
return program
Expand Down
19 changes: 18 additions & 1 deletion exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,15 @@ def constant_segment_with_tensor_alignment(
f"{segment_table}",
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))

def test_canonicalize_delegate_indices(self) -> None:
def make_execution_plan(
name: str, delegates: List[BackendDelegate]
Expand Down Expand Up @@ -462,7 +471,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes:
assert len(ret) == size
return ret

@unittest.skip("TODO(T181362263): Update restore segments to restore cords")
def test_round_trip_with_segments(self) -> None:
# Create a program with some delegate data blobs.
program = get_test_program()
Expand Down Expand Up @@ -803,6 +811,15 @@ def test_constant_segment_and_delegate_segment(self) -> None:
+ b"\x40\x44\x44",
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))


# Common data for extended header tests. The two example values should produce
# the example data.
Expand Down