Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions exir/_serialize/_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any:


# pyre-ignore
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any:
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any: # noqa: C901
"""Initializes a dataclass given a dictionary loaded from a json,
`json_dict`, and the expected class, `cls`, by iterating through the fields
of the class and retrieving the data for each. If there is a field that is
Expand Down Expand Up @@ -139,7 +139,10 @@ class Example
# If T is an enum then lookup the value in the enum otherwise try to
# cast value to whatever type is required
if isinstance(T, enum.EnumMeta):
data[key] = T[value]
if isinstance(value, str):
data[key] = T[value]
else:
data[key] = T(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

else:
data[key] = T(value)
return cls(**data)
44 changes: 43 additions & 1 deletion exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def serialize_pte_binary(
return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
def _restore_segments(program: Program, segment_data: bytes) -> Program: # noqa: C901
"""Moves segments from `segment_data` into `program`.

This should recreate the original Program that the segments were extracted
Expand Down Expand Up @@ -641,6 +641,48 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Reconstruct named data blobs from segment data when present.
if program.named_data:
segment_to_buffer_index: Dict[int, int] = {}
named_buffers: List[BufferEntry] = []
key_to_buffer_index: Dict[str, int] = {}

for entry in program.named_data:
segment_index = entry.segment_index
if segment_index >= len(segments):
raise ValueError(
"Named data segment index "
f"{segment_index} >= num segments {len(segments)}"
)

buffer_index = segment_to_buffer_index.get(segment_index)
if buffer_index is None:
buffer_index = len(named_buffers)
segment_to_buffer_index[segment_index] = buffer_index
named_buffers.append(
BufferEntry(buffer=segments[segment_index], alignment=1)
)

key_to_buffer_index[entry.key] = buffer_index

named_data_store = NamedDataStoreOutput(
buffers=named_buffers,
pte_data=key_to_buffer_index,
external_data={},
)
# Keep a convenient mapping from key to raw bytes for callers that only
# need to read the blobs.
setattr( # noqa: B010
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this if we already have the named_data_store from line 683.

That said, I'm not sure about attaching named_data_store to the Program type.

program,
"named_data_blobs",
{
key: named_data_store.buffers[idx].buffer
for key, idx in named_data_store.pte_data.items()
},
)
setattr(program, "named_data_store", named_data_store) # noqa: B010
program.named_data = []

# Clear out the segments list since the original Program didn't have one.
program.segments = []
return program
Expand Down
32 changes: 32 additions & 0 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,23 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
self.assertEqual(program2.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
# Named data should be restored in a named data store and removed from the Program.
self.assertEqual(program2.named_data, [])
self.assertTrue(hasattr(program2, "named_data_store"))
named_store = program2.named_data_store
self.assertEqual(named_store.pte_data, pte_named_data)
# Buffers in the restored store should match the original serialized blobs.
restored = {
key: named_store.buffers[idx].buffer
for key, idx in named_store.pte_data.items()
}
original = {
key: named_data_buffers[buf_idx].buffer
for key, buf_idx in pte_named_data.items()
}
self.assertEqual(restored, original)
self.assertTrue(hasattr(program2, "named_data_blobs"))
self.assertEqual(program2.named_data_blobs, restored)

def test_named_data_segments(self) -> None:
# Set segment alignment to 12 to test the padding.
Expand Down Expand Up @@ -997,6 +1014,21 @@ def test_named_data_segments(self) -> None:
buffers[2].buffer,
)

program2 = deserialize_pte_binary(pte_data)
self.assertEqual(program2.named_data, [])
self.assertTrue(hasattr(program2, "named_data_store"))
store = program2.named_data_store
self.assertEqual(store.pte_data, pte_named_data)
restored_named_data = {
key: store.buffers[idx].buffer for key, idx in store.pte_data.items()
}
self.assertEqual(
restored_named_data,
{key: buffers[idx].buffer for key, idx in pte_named_data.items()},
)
self.assertTrue(hasattr(program2, "named_data_blobs"))
self.assertEqual(program2.named_data_blobs, restored_named_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also test the round trip, i.e. that the program can be re-serialized using program2.named_data_blobs?



# Common data for extended header tests. The two example values should produce
# the example data.
Expand Down
Loading