From 5edbffda4124904a4a1fd28f6732f7682a311a50 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 1 Oct 2025 23:56:52 -0700 Subject: [PATCH] Deserialize with named data store As titled --- exir/_serialize/_dataclass.py | 7 +++-- exir/_serialize/_program.py | 44 +++++++++++++++++++++++++++- exir/_serialize/test/test_program.py | 32 ++++++++++++++++++++ 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/exir/_serialize/_dataclass.py b/exir/_serialize/_dataclass.py index 013d733bcda..bd4c08c7746 100644 --- a/exir/_serialize/_dataclass.py +++ b/exir/_serialize/_dataclass.py @@ -57,7 +57,7 @@ def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any: # pyre-ignore -def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any: +def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any: # noqa: C901 """Initializes a dataclass given a dictionary loaded from a json, `json_dict`, and the expected class, `cls`, by iterating through the fields of the class and retrieving the data for each. If there is a field that is @@ -139,7 +139,10 @@ class Example # If T is an enum then lookup the value in the enum otherwise try to # cast value to whatever type is required if isinstance(T, enum.EnumMeta): - data[key] = T[value] + if isinstance(value, str): + data[key] = T[value] + else: + data[key] = T(value) else: data[key] = T(value) return cls(**data) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 35a452c22ed..b9c5782ba1a 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -576,7 +576,7 @@ def serialize_pte_binary( return pte_data -def _restore_segments(program: Program, segment_data: bytes) -> Program: +def _restore_segments(program: Program, segment_data: bytes) -> Program: # noqa: C901 """Moves segments from `segment_data` into `program`. This should recreate the original Program that the segments were extracted @@ -641,6 +641,48 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] + # Reconstruct named data blobs from segment data when present. + if program.named_data: + segment_to_buffer_index: Dict[int, int] = {} + named_buffers: List[BufferEntry] = [] + key_to_buffer_index: Dict[str, int] = {} + + for entry in program.named_data: + segment_index = entry.segment_index + if segment_index >= len(segments): + raise ValueError( + "Named data segment index " + f"{segment_index} >= num segments {len(segments)}" + ) + + buffer_index = segment_to_buffer_index.get(segment_index) + if buffer_index is None: + buffer_index = len(named_buffers) + segment_to_buffer_index[segment_index] = buffer_index + named_buffers.append( + BufferEntry(buffer=segments[segment_index], alignment=1) + ) + + key_to_buffer_index[entry.key] = buffer_index + + named_data_store = NamedDataStoreOutput( + buffers=named_buffers, + pte_data=key_to_buffer_index, + external_data={}, + ) + # Keep a convenient mapping from key to raw bytes for callers that only + # need to read the blobs. + setattr( # noqa: B010 + program, + "named_data_blobs", + { + key: named_data_store.buffers[idx].buffer + for key, idx in named_data_store.pte_data.items() + }, + ) + setattr(program, "named_data_store", named_data_store) # noqa: B010 + program.named_data = [] + # Clear out the segments list since the original Program didn't have one. program.segments = [] return program diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 7ed83569169..5da9a8767b4 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -893,6 +893,23 @@ def test_constant_delegate_and_named_data_segments(self) -> None: self.assertEqual(program2.execution_plan, program.execution_plan) # Number of constant tensors should be the same. self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + # Named data should be restored in a named data store and removed from the Program. + self.assertEqual(program2.named_data, []) + self.assertTrue(hasattr(program2, "named_data_store")) + named_store = program2.named_data_store + self.assertEqual(named_store.pte_data, pte_named_data) + # Buffers in the restored store should match the original serialized blobs. + restored = { + key: named_store.buffers[idx].buffer + for key, idx in named_store.pte_data.items() + } + original = { + key: named_data_buffers[buf_idx].buffer + for key, buf_idx in pte_named_data.items() + } + self.assertEqual(restored, original) + self.assertTrue(hasattr(program2, "named_data_blobs")) + self.assertEqual(program2.named_data_blobs, restored) def test_named_data_segments(self) -> None: # Set segment alignment to 12 to test the padding. @@ -997,6 +1014,21 @@ def test_named_data_segments(self) -> None: buffers[2].buffer, ) + program2 = deserialize_pte_binary(pte_data) + self.assertEqual(program2.named_data, []) + self.assertTrue(hasattr(program2, "named_data_store")) + store = program2.named_data_store + self.assertEqual(store.pte_data, pte_named_data) + restored_named_data = { + key: store.buffers[idx].buffer for key, idx in store.pte_data.items() + } + self.assertEqual( + restored_named_data, + {key: buffers[idx].buffer for key, idx in pte_named_data.items()}, + ) + self.assertTrue(hasattr(program2, "named_data_blobs")) + self.assertEqual(program2.named_data_blobs, restored_named_data) + # Common data for extended header tests. The two example values should produce # the example data.