Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions docs/source/pte-file-format.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ Optional extended header:
| byte offset zero above. I.e., it includes these headers.
| [24..31] uint64_t offset (from byte offset zero above) to the start of the
| first segment, or zero if there are no segments.
| [31..?] Any zero-padding necessary to preserve the alignment of the data
| [32..39] uint64_t size of the segment data, i.e., the size from the segment_base_offset
| to the end of the segments. Note, the last segment should not have any
| trailing padding.
| [40..?] Any zero-padding necessary to preserve the alignment of the data
| that follows.
End of optional extended header.
```
Expand All @@ -81,13 +84,16 @@ Example:
Offset to flatbuffer root (0x38)
| File magic ("ET??")
| | Extended header magic ("eh??")
| | | Extended header size (0x18)
| | | Extended header size (0x20)
vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv
0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 18 00 00 00
0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 20 00 00 00
0x0010 F0 02 00 00 00 00 00 00 00 10 00 00 00 00 00 00
0x0020 20 00 00 00 00 00 00 00
^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^
| Offset to segments (0x1000)
Size of program flatbuffer data (0x2f0)
|
Segment data size (0x20)
```

## Program data
Expand Down
13 changes: 12 additions & 1 deletion exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class _ExtendedHeader:
+ 8
# Segment base offset
+ 8
# Segment data size
+ 8
)

# Instance attributes. @dataclass will turn these into ctor args.
Expand All @@ -155,6 +157,9 @@ class _ExtendedHeader:
# Offset to the start of the first segment, or zero if there
# are no segments.
segment_base_offset: int
# Size of the segment data, in bytes, or zero if there are no segments, or
# if this field isn't populated in the PTE file.
segment_data_size: int

# The magic bytes read from or to be written to the binary header.
magic: bytes = EXPECTED_MAGIC
Expand Down Expand Up @@ -189,6 +194,7 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
segment_base_offset=int.from_bytes(
data[16:24], byteorder=_HEADER_BYTEORDER
),
segment_data_size=int.from_bytes(data[24:32], byteorder=_HEADER_BYTEORDER),
)

def is_valid(self) -> bool:
Expand Down Expand Up @@ -220,6 +226,9 @@ def to_bytes(self) -> bytes:
# uint64_t: Offset to the start of the first segment, or zero if
# there are no segments.
+ self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
# uint64_t: size of the segment data, or zero if there are no
# segments.
+ self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
)
return data

Expand Down Expand Up @@ -512,7 +521,9 @@ def serialize_pte_binary(

# Construct and pad the extended header.
header_data: bytes = _ExtendedHeader(
program_size=program_size, segment_base_offset=segment_base_offset
program_size=program_size,
segment_base_offset=segment_base_offset,
segment_data_size=len(segments_data),
).to_bytes()
header_data = pad_to(header_data, padded_header_length)

Expand Down
41 changes: 37 additions & 4 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def constant_segment_with_tensor_alignment(
# the end of the file.
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
self.assertLess(eh.segment_base_offset, len(pte_data))
# Segment data size should be non-zero since there are segments.
self.assertGreater(eh.segment_data_size, 0)

# Peek inside the actual flatbuffer data to see the segments.
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
Expand Down Expand Up @@ -232,6 +234,8 @@ def constant_segment_with_tensor_alignment(
# Check segment data.
offsets = subsegment_offsets.offsets
segment_data: bytes = pte_data[eh.segment_base_offset :]
# Check segment data size.
self.assertEqual(len(segment_data), eh.segment_data_size)

# tensor[1]: padding.
self.assertEqual(
Expand Down Expand Up @@ -514,6 +518,8 @@ def test_round_trip_with_segments(self) -> None:
# the end of the file.
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
self.assertLess(eh.segment_base_offset, len(pte_data))
# Segment data size should be non-zero since there are segments.
self.assertGreater(eh.segment_data_size, 0)

# Peek inside the actual flatbuffer data to see the segments. Note that
# this also implicitly tests the case where we try parsing the entire
Expand Down Expand Up @@ -566,6 +572,8 @@ def test_round_trip_with_segments(self) -> None:
# Now that we've shown that the base offset is correct, slice off the
# front so that all segment offsets are relative to zero.
segment_data: bytes = pte_data[segment_base_offset:]
# Check segment data size.
self.assertEqual(len(segment_data), eh.segment_data_size)

# End of the first segment. It's much smaller than the alignment,
# so we know that it's followed by zeros.
Expand Down Expand Up @@ -729,6 +737,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
# the end of the file.
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
self.assertLess(eh.segment_base_offset, len(pte_data))
# Segment data size should be non-zero since there are segments.
self.assertGreater(eh.segment_data_size, 0)

# Peek inside the actual flatbuffer data to see the segments.
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
Expand Down Expand Up @@ -811,6 +821,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
# Now that we've shown that the base offset is correct, slice off the
# front so that all segment offsets are relative to zero.
segment_data: bytes = pte_data[segment_base_offset:]
# Check segment data size.
self.assertEqual(len(segment_data), eh.segment_data_size)

# Check segment[0] for constants.
offsets = subsegment_offsets.offsets
Expand Down Expand Up @@ -925,6 +937,8 @@ def test_named_data_segments(self) -> None:
# the end of the file.
self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
self.assertLess(eh.segment_base_offset, len(pte_data))
# Segment data size should be non-zero since there are segments.
self.assertGreater(eh.segment_data_size, 0)

# Peek inside the actual flatbuffer data to see the named data segments.
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
Expand Down Expand Up @@ -958,6 +972,9 @@ def test_named_data_segments(self) -> None:

# Check the pte data for buffer values.
segment_data: bytes = pte_data[eh.segment_base_offset :]
# Check segment data size.
self.assertEqual(len(segment_data), eh.segment_data_size)

self.assertEqual(
segment_data[
segment_table[0].offset : segment_table[0].offset
Expand Down Expand Up @@ -985,18 +1002,21 @@ def test_named_data_segments(self) -> None:
# the example data.
EXAMPLE_PROGRAM_SIZE: int = 0x1122112233443344
EXAMPLE_SEGMENT_BASE_OFFSET: int = 0x5566556677887788
EXAMPLE_SEGMENT_DATA_SIZE: int = 0x5544554433223322
# This data is intentionally fragile. If the header layout or magic changes,
# this test must change too. The layout of the header is a contract, not an
# implementation detail.
EXAMPLE_HEADER_DATA: bytes = (
# Magic bytes
b"eh00"
# uint32_t header size (little endian)
+ b"\x18\x00\x00\x00"
+ b"\x20\x00\x00\x00"
# uint64_t program size
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
# uint64_t segment base offset
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
# uint64_t segment data size
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
)


Expand All @@ -1005,6 +1025,7 @@ def test_to_bytes(self) -> None:
eh = _ExtendedHeader(
program_size=EXAMPLE_PROGRAM_SIZE,
segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE,
)
self.assertTrue(eh.is_valid())
self.assertEqual(eh.to_bytes(), EXAMPLE_HEADER_DATA)
Expand All @@ -1013,6 +1034,7 @@ def test_to_bytes_with_non_defaults(self) -> None:
eh = _ExtendedHeader(
program_size=EXAMPLE_PROGRAM_SIZE,
segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE,
# Override the default magic and length, to demonstrate that this
# does not affect the serialized header.
magic=b"ABCD",
Expand All @@ -1036,6 +1058,7 @@ def test_from_bytes_valid(self) -> None:
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)

def test_from_bytes_with_more_data_than_necessary(self) -> None:
# Pass in more data than necessary to parse the header.
Expand All @@ -1049,6 +1072,7 @@ def test_from_bytes_with_more_data_than_necessary(self) -> None:
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)

def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
# Simulate a backwards-compatibility situation. Parse a header
Expand All @@ -1059,11 +1083,13 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
# Magic bytes
b"eh00"
# uint32_t header size (little endian)
+ b"\x1c\x00\x00\x00" # Longer than expected
+ b"\x21\x00\x00\x00" # Longer than expected
# uint64_t program size
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
# uint64_t segment base offset
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
# uint64_t segment data size
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
# uint32_t new field (ignored)
+ b"\xff\xee\xff\xee"
)
Expand All @@ -1075,9 +1101,10 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
self.assertTrue(eh.is_valid())

self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC)
self.assertEqual(eh.length, 28)
self.assertEqual(eh.length, 33)
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)

def test_from_bytes_not_enough_data_fails(self) -> None:
# Parsing a truncated prefix should fail.
Expand All @@ -1090,11 +1117,13 @@ def test_from_bytes_invalid_magic(self) -> None:
# Magic bytes
b"ABCD" # Invalid
# uint32_t header size (little endian)
+ b"\x18\x00\x00\x00"
+ b"\x20\x00\x00\x00"
# uint64_t program size
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
# uint64_t segment base offset
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
# uint64_t segment data size
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
)

# Parse the serialized extended header.
Expand All @@ -1109,6 +1138,7 @@ def test_from_bytes_invalid_magic(self) -> None:
self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH)
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)

def test_from_bytes_invalid_length(self) -> None:
# An invalid serialized header
Expand All @@ -1121,6 +1151,8 @@ def test_from_bytes_invalid_length(self) -> None:
+ b"\x44\x33\x44\x33\x22\x11\x22\x11"
# uint64_t segment base offset
+ b"\x88\x77\x88\x77\x66\x55\x66\x55"
# uint64_t segment data size
+ b"\x22\x33\x22\x33\x44\x55\x44\x55"
)

# Parse the serialized extended header.
Expand All @@ -1135,3 +1167,4 @@ def test_from_bytes_invalid_length(self) -> None:
self.assertEqual(eh.length, 16)
self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE)
self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)
self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE)
15 changes: 15 additions & 0 deletions schema/extended_header.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ static constexpr size_t kHeaderSegmentBaseOffsetOffset =
static constexpr size_t kMinimumHeaderLength =
kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t);

/// The expected location of the segment_data_size field relative to the
/// beginning of the header.
static constexpr size_t kHeaderSegmentDataSizeOffset =
kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t);

/// The expected length of the header, including the segment_data_size field.
static constexpr size_t kHeaderLengthWithSegmentDataSize =
kHeaderSegmentDataSizeOffset + sizeof(uint64_t);

/// Interprets the 4 bytes at `data` as a little-endian uint32_t.
uint32_t GetUInt32LE(const uint8_t* data) {
return (uint32_t)data[0] | ((uint32_t)data[1] << 8) |
Expand Down Expand Up @@ -83,11 +92,17 @@ uint64_t GetUInt64LE(const uint8_t* data) {
return Error::InvalidProgram;
}

uint64_t segment_data_size = 0;
if (header_length >= kHeaderLengthWithSegmentDataSize) {
segment_data_size = GetUInt64LE(header + kHeaderSegmentDataSizeOffset);
}

// The header is present and apparently valid.
return ExtendedHeader{
/*program_size=*/GetUInt64LE(header + kHeaderProgramSizeOffset),
/*segment_base_offset=*/
GetUInt64LE(header + kHeaderSegmentBaseOffsetOffset),
/*segment_data_size=*/segment_data_size,
};
}

Expand Down
8 changes: 8 additions & 0 deletions schema/extended_header.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ struct ExtendedHeader {
* is present.
*/
uint64_t segment_base_offset;

/**
* The size of all the segment data, in bytes. Zero if:
* - no segment is present
* - the segment_data_size field doesn't exist in the header - the case for
* older PTE files.
*/
uint64_t segment_data_size;
};

} // namespace runtime
Expand Down
Loading
Loading