From 4ec0731a035933fec0a2e3989e5658673a3520d1 Mon Sep 17 00:00:00 2001 From: lucylq Date: Tue, 10 Dec 2024 09:26:54 -0800 Subject: [PATCH] [executorch][flat_tensor] Serialize flat tensor tests More comprehensive testing for flat tensor serialization. Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/) [ghstack-poisoned] --- extension/flat_tensor/serialize/serialize.py | 29 ++++++- extension/flat_tensor/test/test_serialize.py | 84 ++++++++++++++++++-- 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index ff7a5123961..db6c06c5e6d 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -6,9 +6,9 @@ import pkg_resources from executorch.exir._serialize._cord import Cord -from executorch.exir._serialize._dataclass import _DataclassEncoder +from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass -from executorch.exir._serialize._flatbuffer import _flatc_compile +from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile from executorch.exir._serialize.data_serializer import DataSerializer, SerializationInfo from executorch.exir._serialize.utils import ( @@ -48,6 +48,31 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: return Cord(output_file.read()) +def _convert_to_flat_tensor(flatbuffer: bytes) -> FlatTensor: + with tempfile.TemporaryDirectory() as d: + schema_path = os.path.join(d, "flat_tensor.fbs") + with open(schema_path, "wb") as schema_file: + schema_file.write( + pkg_resources.resource_string(__name__, "flat_tensor.fbs") + ) + + scalar_type_path = os.path.join(d, "scalar_type.fbs") + with open(scalar_type_path, "wb") as scalar_type_file: + scalar_type_file.write( + pkg_resources.resource_string(__name__, "scalar_type.fbs") + ) + + bin_path = os.path.join(d, "flat_tensor.bin") + with open(bin_path, "wb") as bin_file: + bin_file.write(flatbuffer) + + _flatc_decompile(d, schema_path, bin_path, ["--raw-binary"]) + + json_path = os.path.join(d, "flat_tensor.json") + with open(json_path, "rb") as output_file: + return _json_to_dataclass(json.load(output_file), cls=FlatTensor) + + @dataclass class FlatTensorConfig: tensor_alignment: int = 16 diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index ef8c6456921..e5e339e0f38 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -15,17 +15,21 @@ ) from executorch.exir.schema import ScalarType +from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata from executorch.extension.flat_tensor.serialize.serialize import ( + _convert_to_flat_tensor, + FlatTensorConfig, FlatTensorHeader, FlatTensorSerializer, ) # Test artifacts -TEST_TENSOR_BUFFER = [b"tensor"] +TEST_TENSOR_BUFFER = [b"\x11"*4, b"\x22"*32] TEST_TENSOR_MAP = { "fqn1": 0, "fqn2": 0, + "fqn3": 1, } TEST_TENSOR_LAYOUT = { @@ -39,12 +43,25 @@ dim_sizes=[1, 1, 1], dim_order=typing.cast(List[bytes], [0, 1, 2]), ), + "fqn3": TensorLayout( + scalar_type=ScalarType.INT, + dim_sizes=[2, 2, 2], + dim_order=typing.cast(List[bytes], [0, 1]), + ), } class TestSerialize(unittest.TestCase): + def check_tensor_metadata( + self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata + ) -> None: + self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type) + self.assertEqual(tensor_layout.dim_sizes, tensor_metadata.dim_sizes) + self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order) + def test_serialize(self) -> None: - serializer: DataSerializer = FlatTensorSerializer() + config = FlatTensorConfig() + serializer: DataSerializer = FlatTensorSerializer(config) data = bytes( serializer.serialize_tensors( @@ -54,14 +71,71 @@ def test_serialize(self) -> None: ) ) + # Check header. header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH]) self.assertTrue(header.is_valid()) self.assertEqual(header.flatbuffer_offset, 48) - self.assertEqual(header.flatbuffer_size, 200) - self.assertEqual(header.segment_base_offset, 256) - self.assertEqual(header.data_size, 16) + self.assertEqual(header.flatbuffer_size, 288) + self.assertEqual(header.segment_base_offset, 336) + self.assertEqual(header.data_size, 48) self.assertEqual( data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01" ) + + # Check flat tensor data. + flat_tensor_bytes = data[ + header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size + ] + + flat_tensor = _convert_to_flat_tensor(flat_tensor_bytes) + + self.assertEqual(flat_tensor.version, 0) + self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment) + + tensors = flat_tensor.tensors + self.assertEqual(len(tensors), 3) + self.assertEqual(tensors[0].fully_qualified_name, "fqn1") + self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn1"], tensors[0]) + self.assertEqual(tensors[0].segment_index, 0) + self.assertEqual(tensors[0].offset, 0) + + self.assertEqual(tensors[1].fully_qualified_name, "fqn2") + self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn2"], tensors[1]) + self.assertEqual(tensors[1].segment_index, 0) + self.assertEqual(tensors[1].offset, 0) + + self.assertEqual(tensors[2].fully_qualified_name, "fqn3") + self.check_tensor_metadata(TEST_TENSOR_LAYOUT["fqn3"], tensors[2]) + self.assertEqual(tensors[2].segment_index, 0) + self.assertEqual(tensors[2].offset, config.tensor_alignment) + + segments = flat_tensor.segments + self.assertEqual(len(segments), 1) + self.assertEqual(segments[0].offset, 0) + self.assertEqual(segments[0].size, config.tensor_alignment * 3) + + # Check segment data. + segment_data = data[ + header.segment_base_offset : header.segment_base_offset + segments[0].size + ] + + t0_start = 0 + t0_len = len(TEST_TENSOR_BUFFER[0]) + t0_end = config.tensor_alignment + self.assertEqual( + segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0] + ) + padding = b"\x00" * (t0_end - t0_len) + self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding) + + t1_start = config.tensor_alignment + t1_len = len(TEST_TENSOR_BUFFER[1]) + t1_end = config.tensor_alignment * 3 + self.assertEqual( + segment_data[t1_start : t1_start + t1_len], + TEST_TENSOR_BUFFER[1], + ) + padding = b"\x00" * (t1_end - (t1_len + t1_start)) + self.assertEqual(segment_data[t1_start + t1_len : t1_end], padding)