diff --git a/extension/flat_tensor/serialize/serialize.cpp b/extension/flat_tensor/serialize/serialize.cpp index 06b9f7f0d24..2b91de271bb 100644 --- a/extension/flat_tensor/serialize/serialize.cpp +++ b/extension/flat_tensor/serialize/serialize.cpp @@ -109,7 +109,8 @@ runtime::Error save_ptd( tensor_alignment, builder.CreateVector(tensors), builder.CreateVector(buffers)); - builder.Finish(flat_tensor); // Our flatbuffer is created now. + builder.Finish(flat_tensor, ::flat_tensor_flatbuffer::FlatTensorIdentifier()); + // Our flatbuffer is created now. // Calculate flatbuffer padding. auto padded_flatbufer_size = @@ -117,6 +118,30 @@ runtime::Error save_ptd( auto padded_header_size = aligned_size(FlatTensorHeader::kHeaderExpectedLength, tensor_alignment); + // The general structure of the file is: + // [flatbuffer offset to root table][flatbuffer file identifier] + // [FlatTensorHeader][padding][flatbuffer contents][padding] + // [segment data]. + // This means we first serialize the first 8 bytes of the flatbuffer, + // updating the offset to the root table, then the header, then the + // flatbuffer. We are embedding the header inside the flatbuffer doing + // this which allows us to continue using flatbuffer tools directly on the + // .ptd file. + + // Calculate new offset to root table. + uint32_t current_offset = + *reinterpret_cast(builder.GetBufferPointer()); + uint32_t new_offset = current_offset + padded_header_size; + + // Write flatbuffer offset to root table + out.write(reinterpret_cast(&new_offset), sizeof(new_offset)); + + // Write flatbuffer magic bytes + out.write( + reinterpret_cast(builder.GetBufferPointer()) + + sizeof(new_offset), + 4); // This is the file identifier from flat_tensor.fbs. 
+ // Write header out.write(FlatTensorHeader::kMagic, sizeof(FlatTensorHeader::kMagic)); out.write( @@ -149,10 +174,11 @@ runtime::Error save_ptd( padding_required( FlatTensorHeader::kHeaderExpectedLength, tensor_alignment)); - // Write flatbuffer + // Write flatbuffer, offset by 8 bytes (4-byte root table offset + 4-byte + // file identifier) since we wrote those before the FlatTensorHeader. out.write( - reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()); + reinterpret_cast(builder.GetBufferPointer()) + 8, + builder.GetSize() - 8); // Write flatbuffer padding write_nulls(out, padding_required(builder.GetSize(), tensor_alignment)); diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index 22374345cf7..abdccc17dec 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -17,6 +17,7 @@ from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile +from executorch.exir._serialize._program import _insert_flatbuffer_header from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required @@ -197,6 +198,17 @@ def to_bytes(self) -> bytes: return data +def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]: + """Returns the extended header of the flat_tensor data, if present and valid.""" + try: + eh = FlatTensorHeader.from_bytes(flat_tensor_data[8:]) + if eh.is_valid(): + return eh + except ValueError: + pass + return None + + class FlatTensorSerializer(DataSerializer): """A concrete implementation of the DataSerializer interface that serializes and deserializes data to/from the FlatTensor format. @@ -299,14 +311,29 @@ def serialize( # Pad header and payload to segment alignment. 
header_data = pad_to(header_data, padded_header_length) + original_flatbuffer_payload_size = len(flatbuffer_payload) flatbuffer_payload.append( b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload)) ) + injected_flatbuffer_data: bytes = _insert_flatbuffer_header( + flatbuffer_data=flatbuffer_payload.__bytes__(), + magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]", + header_data=header_data, + ) + + eh = _get_extended_header(injected_flatbuffer_data) + assert eh is not None + assert eh.flatbuffer_size == original_flatbuffer_payload_size + assert eh.segment_base_offset == segment_base_offset + assert eh.flatbuffer_offset == padded_header_length + assert eh.segment_data_size == len(flat_tensor_data) + + del header_data + del flatbuffer_payload # Place everything into one segment. payload = Cord() - payload.append(header_data) - payload.append(flatbuffer_payload) + payload.append(injected_flatbuffer_data) payload.append(flat_tensor_data) return payload diff --git a/extension/flat_tensor/test/test_serialize.cpp b/extension/flat_tensor/test/test_serialize.cpp index ddb25857d59..a4091a5b9d3 100644 --- a/extension/flat_tensor/test/test_serialize.cpp +++ b/extension/flat_tensor/test/test_serialize.cpp @@ -53,35 +53,46 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) { auto x = buf.str(); const char* byte_buffer = x.c_str(); - // Check Magic - EXPECT_EQ(byte_buffer[0], 'F'); - EXPECT_EQ(byte_buffer[1], 'H'); - EXPECT_EQ(byte_buffer[2], '0'); - EXPECT_EQ(byte_buffer[3], '1'); + // First 4 bytes are an offset to the flatbuffer root table. + + // Check magic ids. 
+ EXPECT_EQ(byte_buffer[4], 'F'); + EXPECT_EQ(byte_buffer[5], 'T'); + ASSERT_EQ(byte_buffer[6], '0'); + ASSERT_EQ(byte_buffer[7], '1'); + + ASSERT_EQ(byte_buffer[8], 'F'); + ASSERT_EQ(byte_buffer[9], 'H'); + EXPECT_EQ(byte_buffer[10], '0'); + EXPECT_EQ(byte_buffer[11], '1'); // Check Header - EXPECT_EQ( // Header length - *(uint32_t*)(byte_buffer + 4), + auto header_buffer = byte_buffer + 8; + EXPECT_EQ( // Check expected length + *(uint32_t*)(header_buffer + 4), executorch::extension::FlatTensorHeader::kHeaderExpectedLength); + EXPECT_EQ( - *(uint64_t*)(byte_buffer + 8), - 48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding today, - // and then the flatbuffer starts. + *(uint64_t*)(header_buffer + 8), + 48); // Flatbuffer offset, header is 40 bytes + 8 bytes of padding + // today, and then the flatbuffer starts. + EXPECT_EQ( - *(uint64_t*)(byte_buffer + 16), - 224); // Flatbuffer size, This is fragile, and depends on the schema, the - // builder, and the padding needed. - const uint64_t segment_offset = 48 + - 224; // Segment offset, depends on the padded header and flatbuffer sizes. - EXPECT_EQ(*(uint64_t*)(byte_buffer + 24), segment_offset); + *(uint64_t*)(header_buffer + 16), + 232); // Flatbuffer size. This is fragile, and depends on the schema, + // the builder, and the padding needed. + + // Segment offset, depends on the padded header and flatbuffer sizes. + const uint64_t segment_offset = 48 + 232 + 8; // 8 is padding. + EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset); EXPECT_EQ( - *(uint64_t*)(byte_buffer + 32), + *(uint64_t*)(header_buffer + 32), 20); // Segment total size, 8 bytes of data (2 floats), 24 bytes of // padding. 
// Check Flatbuffer - auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer + 48); + auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer); EXPECT_EQ( flat_tensor->version(), diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 57dbdb8c192..d32eac1a72c 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -80,7 +80,7 @@ def test_serialize(self) -> None: # Check header. header = FlatTensorHeader.from_bytes( - serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH] + serialized_data[8 : FlatTensorHeader.EXPECTED_LENGTH + 8] ) self.assertTrue(header.is_valid()) @@ -107,15 +107,13 @@ def test_serialize(self) -> None: # Confirm the flatbuffer magic is present. self.assertEqual( - serialized_data[ - header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8 - ], + serialized_data[4:8], b"FT01", ) # Check flat tensor data. flat_tensor_bytes = serialized_data[ - header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size + 0 : header.flatbuffer_offset + header.flatbuffer_size ] flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)