diff --git a/docs/source/pte-file-format.md b/docs/source/pte-file-format.md index 9f5757ed250..66b27fe8a4c 100644 --- a/docs/source/pte-file-format.md +++ b/docs/source/pte-file-format.md @@ -71,7 +71,10 @@ Optional extended header: | byte offset zero above. I.e., it includes these headers. | [24..31] uint64_t offset (from byte offset zero above) to the start of the | first segment, or zero if there are no segments. -| [31..?] Any zero-padding necessary to preserve the alignment of the data +| [32..39] uint64_t size of the segment data, i.e., the size from the segment_base_offset +| to the end of the segments. Note: the last segment should not have any +| trailing padding. +| [40..?] Any zero-padding necessary to preserve the alignment of the data | that follows. End of optional extended header. ``` @@ -81,13 +84,16 @@ Example: Offset to flatbuffer root (0x38) | File magic ("ET??") | | Extended header magic ("eh??") - | | | Extended header size (0x18) + | | | Extended header size (0x20) vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv vvvvvvvvvvv -0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 18 00 00 00 +0x0000 38 00 00 00 45 54 3F 3F 65 68 3F 3F 20 00 00 00 0x0010 F0 02 00 00 00 00 00 00 00 10 00 00 00 00 00 00 +0x0020 20 ^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^ | Offset to segments (0x1000) Size of program flatbuffer data (0x2f0) + | + Segment data size (0x20) ``` ## Program data diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 0994156ae50..448a3afb90c 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -146,6 +146,8 @@ class _ExtendedHeader: + 8 # Segment base offset + 8 + # Segment data size + + 8 ) # Instance attributes. @dataclass will turn these into ctor args. @@ -155,6 +157,9 @@ class _ExtendedHeader: # Offset to the start of the first segment, or zero if there # are no segments. 
segment_base_offset: int + # Size of the segment data, in bytes, or zero if there are no segments, or + # if this field isn't populated in the PTE file. + segment_data_size: int # The magic bytes read from or to be written to the binary header. magic: bytes = EXPECTED_MAGIC @@ -189,6 +194,7 @@ def from_bytes(data: bytes) -> "_ExtendedHeader": segment_base_offset=int.from_bytes( data[16:24], byteorder=_HEADER_BYTEORDER ), + segment_data_size=int.from_bytes(data[24:32], byteorder=_HEADER_BYTEORDER), ) def is_valid(self) -> bool: @@ -220,6 +226,9 @@ def to_bytes(self) -> bytes: # uint64_t: Offset to the start of the first segment, or zero if # there are no segments. + self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER) + # uint64_t: size of the segment data, or zero if there are no + # segments. + + self.segment_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER) ) return data @@ -512,7 +521,9 @@ def serialize_pte_binary( # Construct and pad the extended header. header_data: bytes = _ExtendedHeader( - program_size=program_size, segment_base_offset=segment_base_offset + program_size=program_size, + segment_base_offset=segment_base_offset, + segment_data_size=len(segments_data), ).to_bytes() header_data = pad_to(header_data, padded_header_length) diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index c67849dd28d..18803da05b6 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -191,6 +191,8 @@ def constant_segment_with_tensor_alignment( # the end of the file. self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) self.assertLess(eh.segment_base_offset, len(pte_data)) + # Segment data size should be non-zero since there are segments. + self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. 
program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) @@ -232,6 +234,8 @@ def constant_segment_with_tensor_alignment( # Check segment data. offsets = subsegment_offsets.offsets segment_data: bytes = pte_data[eh.segment_base_offset :] + # Check segment data size. + self.assertEqual(len(segment_data), eh.segment_data_size) # tensor[1]: padding. self.assertEqual( @@ -514,6 +518,8 @@ def test_round_trip_with_segments(self) -> None: # the end of the file. self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) self.assertLess(eh.segment_base_offset, len(pte_data)) + # Segment data size should be non-zero since there are segments. + self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. Note that # this also implicity tests the case where we try parsing the entire @@ -566,6 +572,8 @@ def test_round_trip_with_segments(self) -> None: # Now that we've shown that the base offset is correct, slice off the # front so that all segment offsets are relative to zero. segment_data: bytes = pte_data[segment_base_offset:] + # Check segment data size. + self.assertEqual(len(segment_data), eh.segment_data_size) # End of the first segment. It's much smaller than the alignment, # so we know that it's followed by zeros. @@ -729,6 +737,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None: # the end of the file. self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) self.assertLess(eh.segment_base_offset, len(pte_data)) + # Segment data size should be non-zero since there are segments. + self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) @@ -811,6 +821,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None: # Now that we've shown that the base offset is correct, slice off the # front so that all segment offsets are relative to zero. 
segment_data: bytes = pte_data[segment_base_offset:] + # Check segment data size. + self.assertEqual(len(segment_data), eh.segment_data_size) # Check segment[0] for constants. offsets = subsegment_offsets.offsets @@ -925,6 +937,8 @@ def test_named_data_segments(self) -> None: # the end of the file. self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) self.assertLess(eh.segment_base_offset, len(pte_data)) + # Segment data size should be non-zero since there are segments. + self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the named data segments. program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) @@ -958,6 +972,9 @@ def test_named_data_segments(self) -> None: # Check the pte data for buffer values. segment_data: bytes = pte_data[eh.segment_base_offset :] + # Check segment data size. + self.assertEqual(len(segment_data), eh.segment_data_size) + self.assertEqual( segment_data[ segment_table[0].offset : segment_table[0].offset @@ -985,6 +1002,7 @@ def test_named_data_segments(self) -> None: # the example data. EXAMPLE_PROGRAM_SIZE: int = 0x1122112233443344 EXAMPLE_SEGMENT_BASE_OFFSET: int = 0x5566556677887788 +EXAMPLE_SEGMENT_DATA_SIZE: int = 0x5544554433223322 # This data is intentionally fragile. If the header layout or magic changes, # this test must change too. The layout of the header is a contract, not an # implementation detail. 
@@ -992,11 +1010,13 @@ def test_named_data_segments(self) -> None: # Magic bytes b"eh00" # uint32_t header size (little endian) - + b"\x18\x00\x00\x00" + + b"\x20\x00\x00\x00" # uint64_t program size + b"\x44\x33\x44\x33\x22\x11\x22\x11" # uint64_t segment base offset + b"\x88\x77\x88\x77\x66\x55\x66\x55" + # uint64_t segment data size + + b"\x22\x33\x22\x33\x44\x55\x44\x55" ) @@ -1005,6 +1025,7 @@ def test_to_bytes(self) -> None: eh = _ExtendedHeader( program_size=EXAMPLE_PROGRAM_SIZE, segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET, + segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE, ) self.assertTrue(eh.is_valid()) self.assertEqual(eh.to_bytes(), EXAMPLE_HEADER_DATA) @@ -1013,6 +1034,7 @@ def test_to_bytes_with_non_defaults(self) -> None: eh = _ExtendedHeader( program_size=EXAMPLE_PROGRAM_SIZE, segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET, + segment_data_size=EXAMPLE_SEGMENT_DATA_SIZE, # Override the default magic and length, to demonstrate that this # does not affect the serialized header. magic=b"ABCD", @@ -1036,6 +1058,7 @@ def test_from_bytes_valid(self) -> None: self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) + self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE) def test_from_bytes_with_more_data_than_necessary(self) -> None: # Pass in more data than necessary to parse the header. @@ -1049,6 +1072,7 @@ def test_from_bytes_with_more_data_than_necessary(self) -> None: self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) + self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE) def test_from_bytes_larger_than_needed_header_size_field(self) -> None: # Simulate a backwards-compatibility situation. 
Parse a header @@ -1059,11 +1083,13 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None: # Magic bytes b"eh00" # uint32_t header size (little endian) - + b"\x1c\x00\x00\x00" # Longer than expected + + b"\x21\x00\x00\x00" # Longer than expected # uint64_t program size + b"\x44\x33\x44\x33\x22\x11\x22\x11" # uint64_t segment base offset + b"\x88\x77\x88\x77\x66\x55\x66\x55" + # uint64_t segment data size + + b"\x22\x33\x22\x33\x44\x55\x44\x55" # uint32_t new field (ignored) + b"\xff\xee\xff\xee" ) @@ -1075,9 +1101,10 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None: self.assertTrue(eh.is_valid()) self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC) - self.assertEqual(eh.length, 28) + self.assertEqual(eh.length, 33) self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) + self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE) def test_from_bytes_not_enough_data_fails(self) -> None: # Parsing a truncated prefix should fail. @@ -1090,11 +1117,13 @@ def test_from_bytes_invalid_magic(self) -> None: # Magic bytes b"ABCD" # Invalid # uint32_t header size (little endian) - + b"\x18\x00\x00\x00" + + b"\x20\x00\x00\x00" # uint64_t program size + b"\x44\x33\x44\x33\x22\x11\x22\x11" # uint64_t segment base offset + b"\x88\x77\x88\x77\x66\x55\x66\x55" + # uint64_t segment data size + + b"\x22\x33\x22\x33\x44\x55\x44\x55" ) # Parse the serialized extended header. 
@@ -1109,6 +1138,7 @@ def test_from_bytes_invalid_magic(self) -> None: self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) + self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE) def test_from_bytes_invalid_length(self) -> None: # An invalid serialized header @@ -1121,6 +1151,8 @@ def test_from_bytes_invalid_length(self) -> None: + b"\x44\x33\x44\x33\x22\x11\x22\x11" # uint64_t segment base offset + b"\x88\x77\x88\x77\x66\x55\x66\x55" + # uint64_t segment data size + + b"\x22\x33\x22\x33\x44\x55\x44\x55" ) # Parse the serialized extended header. @@ -1135,3 +1167,4 @@ def test_from_bytes_invalid_length(self) -> None: self.assertEqual(eh.length, 16) self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) + self.assertEqual(eh.segment_data_size, EXAMPLE_SEGMENT_DATA_SIZE) diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index e58c8a96aa7..344e3c7177a 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -67,6 +67,7 @@ Result get_execution_plan( // See if the program size is in the header. size_t program_size = 0; size_t segment_base_offset = 0; + size_t segment_data_size = 0; { EXECUTORCH_SCOPE_PROF("Program::check_header"); Result header = loader->load( @@ -82,6 +83,24 @@ Result get_execution_plan( // The header has the program size. program_size = eh->program_size; segment_base_offset = eh->segment_base_offset; + segment_data_size = eh->segment_data_size; + + // segment_data_size was added in ET 1.0 release. For BC, only check the + // expected file size when there are no segments or when segment_data_size + // is positive (0-value may indicate no segments) + if ((segment_data_size == 0 && segment_base_offset == 0) || + segment_data_size > 0) { + size_t expected = segment_base_offset == 0 + ? 
program_size + : segment_base_offset + segment_data_size; + size_t actual = loader->size().get(); + ET_CHECK_OR_RETURN_ERROR( + expected <= actual, + InvalidProgram, + "File size is too small. Expected file size from extended header is %zu, actual file size from data loader is %zu", + expected, + actual); + } } else if (eh.error() == Error::NotFound) { // No header; the program consumes the whole file, and there are no // segments. diff --git a/runtime/executor/test/program_test.cpp b/runtime/executor/test/program_test.cpp index 962bf8f548a..3afb71b3565 100644 --- a/runtime/executor/test/program_test.cpp +++ b/runtime/executor/test/program_test.cpp @@ -574,3 +574,22 @@ TEST_F(ProgramTest, LoadFromMutableSegment) { &program.get(), 500, 1, 1, buffer); EXPECT_NE(err, Error::Ok); } + +TEST_F(ProgramTest, LoadAndCheckPTESize) { + // Load the serialized ModuleAddMul data, with constants in the segment. + const char* linear_path = std::getenv("ET_MODULE_ADD_MUL_PATH"); + Result linear_loader = FileDataLoader::from(linear_path); + ASSERT_EQ(linear_loader.error(), Error::Ok); + Result program = Program::load(&linear_loader.get()); + ASSERT_EQ(program.error(), Error::Ok); + + // Create a truncated file. 
+ Result truncated_file = linear_loader->load( + 0, 200, DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(truncated_file.error(), Error::Ok); + + BufferDataLoader truncated_loader = + BufferDataLoader(truncated_file->data(), 200); + Result truncated_program = Program::load(&truncated_loader); + ASSERT_EQ(truncated_program.error(), Error::InvalidProgram); +} diff --git a/schema/extended_header.cpp b/schema/extended_header.cpp index 3ccf122524a..bd923bbbab7 100644 --- a/schema/extended_header.cpp +++ b/schema/extended_header.cpp @@ -41,6 +41,15 @@ static constexpr size_t kHeaderSegmentBaseOffsetOffset = static constexpr size_t kMinimumHeaderLength = kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t); +/// The expected location of the segment_data_size field relative to the +/// beginning of the header. +static constexpr size_t kHeaderSegmentDataSizeOffset = + kHeaderSegmentBaseOffsetOffset + sizeof(uint64_t); + +/// The expected length of the header, including the segment_data_size field. +static constexpr size_t kHeaderLengthWithSegmentDataSize = + kHeaderSegmentDataSizeOffset + sizeof(uint64_t); + /// Interprets the 4 bytes at `data` as a little-endian uint32_t. uint32_t GetUInt32LE(const uint8_t* data) { return (uint32_t)data[0] | ((uint32_t)data[1] << 8) | @@ -83,11 +92,17 @@ uint64_t GetUInt64LE(const uint8_t* data) { return Error::InvalidProgram; } + uint64_t segment_data_size = 0; + if (header_length >= kHeaderLengthWithSegmentDataSize) { + segment_data_size = GetUInt64LE(header + kHeaderSegmentDataSizeOffset); + } + // The header is present and apparently valid. 
return ExtendedHeader{ /*program_size=*/GetUInt64LE(header + kHeaderProgramSizeOffset), /*segment_base_offset=*/ GetUInt64LE(header + kHeaderSegmentBaseOffsetOffset), + /*segment_data_size=*/segment_data_size, }; } diff --git a/schema/extended_header.h b/schema/extended_header.h index e6fe578368e..7b37dc3df49 100644 --- a/schema/extended_header.h +++ b/schema/extended_header.h @@ -70,6 +70,14 @@ struct ExtendedHeader { * is present. */ uint64_t segment_base_offset; + + /** + * The size of all the segment data, in bytes. Zero if: + * - no segment is present + * - the segment_data_size field doesn't exist in the header - the case for + * older PTE files. + */ + uint64_t segment_data_size; }; } // namespace runtime diff --git a/schema/test/extended_header_test.cpp b/schema/test/extended_header_test.cpp index 6099d6762a6..98f2ca2966e 100644 --- a/schema/test/extended_header_test.cpp +++ b/schema/test/extended_header_test.cpp @@ -35,6 +35,8 @@ class ExtendedHeaderTest : public ::testing::Test { * an implementation detail. */ // clang-format off + +// The minimum header. constexpr char kExampleHeaderData[] = { // Magic bytes 'e', 'h', '0', '0', @@ -45,6 +47,20 @@ constexpr char kExampleHeaderData[] = { // uint64_t segment base offset 0x72, 0x62, 0x52, 0x42, 0x32, 0x22, 0x12, 0x02, }; + +// Contains segment data size. +constexpr char kExampleHeaderDataExtended[] = { + // Magic bytes + 'e', 'h', '0', '0', + // uint32_t header size (little endian) + 0x20, 0x00, 0x00, 0x00, + // uint64_t program size + 0x71, 0x61, 0x51, 0x41, 0x31, 0x21, 0x11, 0x01, + // uint64_t segment base offset + 0x72, 0x62, 0x52, 0x42, 0x32, 0x22, 0x12, 0x02, + // uint64_t segment data size + 0x73, 0x63, 0x53, 0x43, 0x33, 0x23, 0x13, 0x03, +}; // clang-format on /// The program_size field encoded in kExampleHeaderData. Each byte is unique @@ -55,6 +71,9 @@ constexpr uint64_t kExampleProgramSize = 0x0111213141516171; /// unique within the header data. 
constexpr uint64_t kExampleSegmentBaseOffset = 0x0212223242526272; +/// The segment_data_size field encoded in kExampleHeaderDataExtended. Each +/// byte is unique within the header data. +constexpr uint64_t kExampleSegmentDataSize = 0x0313233343536373; /// The offset to the header's length field, which is in the 4 bytes after the /// magic. constexpr size_t kHeaderLengthOffset = @@ -64,22 +83,43 @@ constexpr size_t kHeaderLengthOffset = * Returns fake serialized Program head data that contains kExampleHeaderData at * the expected offset. */ -std::vector CreateExampleProgramHead() { +std::vector CreateExampleProgramHead( + const char* example, + size_t size) { // Allocate memory representing the head of the serialized Program. std::vector ret(ExtendedHeader::kNumHeadBytes); // Write non-zeros into it to make it more obvious if we read outside the // header. memset(ret.data(), 0x55, ret.size()); // Copy the example header into the right offset. - memcpy( - ret.data() + ExtendedHeader::kHeaderOffset, - kExampleHeaderData, - sizeof(kExampleHeaderData)); + memcpy(ret.data() + ExtendedHeader::kHeaderOffset, example, size); return ret; } TEST_F(ExtendedHeaderTest, ValidHeaderParsesCorrectly) { - std::vector program = CreateExampleProgramHead(); + std::vector program = + CreateExampleProgramHead(kExampleHeaderData, sizeof(kExampleHeaderData)); + + Result header = + ExtendedHeader::Parse(program.data(), program.size()); + + // The header should be present. + ASSERT_EQ(header.error(), Error::Ok); + + // Expect this header has size 24. + EXPECT_EQ(program[kHeaderLengthOffset], 0x18); + + // Since each byte of these fields is unique, success demonstrates that the + // endian-to-int conversion is correct and looks at the expected bytes of the + // header. 
+ EXPECT_EQ(header->program_size, kExampleProgramSize); + EXPECT_EQ(header->segment_base_offset, kExampleSegmentBaseOffset); + EXPECT_EQ(header->segment_data_size, 0); +} + +TEST_F(ExtendedHeaderTest, ValidHeaderParsesCorrectly_ExtendedExample) { + std::vector program = CreateExampleProgramHead( + kExampleHeaderDataExtended, sizeof(kExampleHeaderDataExtended)); Result header = ExtendedHeader::Parse(program.data(), program.size()); @@ -87,15 +127,20 @@ TEST_F(ExtendedHeaderTest, ValidHeaderParsesCorrectly) { // The header should be present. ASSERT_EQ(header.error(), Error::Ok); + // Expect this header has size 32. + EXPECT_EQ(program[kHeaderLengthOffset], 0x20); + // Since each byte of these fields is unique, success demonstrates that the // endian-to-int conversion is correct and looks at the expected bytes of the // header. EXPECT_EQ(header->program_size, kExampleProgramSize); EXPECT_EQ(header->segment_base_offset, kExampleSegmentBaseOffset); + EXPECT_EQ(header->segment_data_size, kExampleSegmentDataSize); } TEST_F(ExtendedHeaderTest, ShortDataFails) { - std::vector program = CreateExampleProgramHead(); + std::vector program = + CreateExampleProgramHead(kExampleHeaderData, sizeof(kExampleHeaderData)); // Try parsing a smaller-than-required part of the data. ASSERT_GE(program.size(), ExtendedHeader::kNumHeadBytes); @@ -119,7 +164,8 @@ TEST_F(ExtendedHeaderTest, MissingHeaderNotFound) { TEST_F(ExtendedHeaderTest, BadMagicTreatedAsMissing) { // Get a valid header. - std::vector program = CreateExampleProgramHead(); + std::vector program = + CreateExampleProgramHead(kExampleHeaderData, sizeof(kExampleHeaderData)); // Should be present. { @@ -141,7 +187,8 @@ TEST_F(ExtendedHeaderTest, BadMagicTreatedAsMissing) { TEST_F(ExtendedHeaderTest, ShorterHeaderLengthFails) { // Get a valid header. - std::vector program = CreateExampleProgramHead(); + std::vector program = + CreateExampleProgramHead(kExampleHeaderData, sizeof(kExampleHeaderData)); // Should be present. 
{ @@ -165,7 +212,8 @@ TEST_F(ExtendedHeaderTest, ShorterHeaderLengthFails) { TEST_F(ExtendedHeaderTest, LongerHeaderLengthSucceeds) { // Get a valid header. - std::vector program = CreateExampleProgramHead(); + std::vector program = + CreateExampleProgramHead(kExampleHeaderData, sizeof(kExampleHeaderData)); // Make the header length larger. // First demonstrate that we're looking in the right place.