104 changes: 52 additions & 52 deletions exir/_serialize/_named_data_store.py
@@ -7,42 +7,32 @@
# pyre-strict

import hashlib
import math
from dataclasses import dataclass

from typing import Dict, List, Optional


@dataclass
class BufferEntry:
"""A class to hold the buffer entries for serialization.

Attributes:
buffer: The buffer bytes.
alignment: The alignment of the buffer.
"""

buffer: bytes
alignment: int
from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.tensor_layout import TensorLayout


@dataclass
class NamedDataStoreOutput:
"""
Holds named data for serialization.
Holds named data for serialization. Note: a DataEntry contains the index into
`buffers`, the alignment and a tensor layout, if applicable.

Attributes:
buffers: A list of unique buffer entries.
pte_data: Contains data that is stored inside the PTE file. A mapping from
{key: buffer_index}.
{key: DataEntry}.
external_data: Contains data that is stored external to the PTE. A mapping
from {filename: {key: buffer_index}}.
from {filename: {key: DataEntry}}.
"""

buffers: List[BufferEntry]
pte_data: Dict[str, int]
external_data: Dict[str, Dict[str, int]]
buffers: List[bytes]
pte_data: Dict[str, DataEntry]
external_data: Dict[str, Dict[str, DataEntry]]


class NamedDataStore:
@@ -61,12 +51,12 @@ class NamedDataStore:
"""

# List of unique blobs.
buffers: List[BufferEntry]
# Named data stored inside the PTE file. Map of {key: buffer_index}.
pte_data: Dict[str, int]
buffers: List[bytes]
# Named data stored inside the PTE file. Map of {key: DataEntry}.
pte_data: Dict[str, DataEntry]
# Named data stored outside of the PTE file.
# Map of {filename: {key: buffer_index}}.
external_data: Dict[str, Dict[str, int]]
# Map of {filename: {key: DataEntry}}.
external_data: Dict[str, Dict[str, DataEntry]]

# Cache of the data hash for deduplication.
# Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
key: str,
data: bytes,
alignment: int,
local_key_to_buffer_idx: Dict[str, int],
local_key_to_buffer_idx: Dict[str, DataEntry],
tensor_layout: Optional[TensorLayout] = None,
) -> None:
"""
Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@ def _add_named_data_to_map(

# Check if the key exists.
buffer_idx = self.key_to_buffer_idx.get(key, -1)
if buffer_idx != -1:
# If the key exists, the corresponding data must be identical.
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
raise ValueError(
f"Duplicate key {key} with different data. "
f"Existing data: {self.buffers[buffer_idx].buffer}. "
f"New data: {data}."
)
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
# If the key exists, the corresponding data must be identical.
if (
buffer_idx != -1
and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
):
raise ValueError(
f"Duplicate key {key} with different data. "
f"Existing data: {self.buffers[buffer_idx]}. "
f"New data: {data}."
)
else:
# Key doesn't exist; check if the data exists.
buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
if buffer_idx != -1:
# The data exists; update the alignment.
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
)
else:
if buffer_idx == -1:
# The data doesn't exist; add it to the data store.
buffer_idx = len(self.buffers)
self.buffers.append(BufferEntry(data, alignment))
self.buffers.append(data)
self.data_hash_to_buffer_idx[hashed] = buffer_idx

# Add key to the map and the key cache.
local_key_to_buffer_idx[key] = buffer_idx
local_key_to_buffer_idx[key] = DataEntry(
buffer_index=buffer_idx,
alignment=alignment,
tensor_layout=tensor_layout,
)
self.key_to_buffer_idx[key] = buffer_idx

def add_named_data(
@@ -151,6 +140,7 @@ def add_named_data(
data: bytes,
alignment: Optional[int] = 1,
external_tag: Optional[str] = None,
tensor_layout: Optional[TensorLayout] = None,
) -> None:
"""
Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@ def add_named_data(
data (bytes): Bytes being requested to be serialized.
alignment (int): alignment for bytes to be serialized with.
external_tag (Optional[str]): the external filename that this data is saved to.
tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
Raises:
ValueError: when the key exists in the store, and corresponding data
is different.
@@ -171,10 +162,16 @@ def add_named_data(
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

if external_tag is None:
self._add_named_data_to_map(key, data, alignment, self.pte_data)
self._add_named_data_to_map(
key, data, alignment, self.pte_data, tensor_layout
)
else:
self._add_named_data_to_map(
key, data, alignment, self.external_data.setdefault(external_tag, {})
key,
data,
alignment,
self.external_data.setdefault(external_tag, {}),
tensor_layout,
)

def get_named_data_store_output(self) -> NamedDataStoreOutput:
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
data is different between them.
"""
# Merge the pte_data.
for key, buffer_idx in other.pte_data.items():
for key, data_entry in other.pte_data.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
other.buffers[data_entry.buffer_index],
data_entry.alignment,
external_tag=None,
tensor_layout=data_entry.tensor_layout,
)

# Merge the external_data.
for filename, key_to_buffer_idx in other.external_data.items():
for key, buffer_idx in key_to_buffer_idx.items():
for filename, key_to_data_entry in other.external_data.items():
for key, data_entry in key_to_data_entry.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
other.buffers[data_entry.buffer_index],
data_entry.alignment,
external_tag=filename,
tensor_layout=data_entry.tensor_layout,
)
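
For orientation, here is a minimal usage sketch (not part of the diff; the key names and the model.ptd filename are illustrative) of the reworked store: buffers now holds raw bytes deduplicated via the sha256 cache, and each key maps to a DataEntry that records its own alignment and optional tensor layout, instead of folding alignments together with an LCM on the shared buffer.

from executorch.exir._serialize._named_data_store import NamedDataStore

store = NamedDataStore()
# Two keys, identical bytes: the sha256 cache dedupes them into one buffer.
store.add_named_data("weight", b"\x00" * 16, alignment=16, external_tag="model.ptd")
store.add_named_data("weight_fp32", b"\x00" * 16, alignment=4, external_tag="model.ptd")

out = store.get_named_data_store_output()
assert len(out.buffers) == 1  # raw bytes, stored once
entries = out.external_data["model.ptd"]
assert entries["weight"].buffer_index == entries["weight_fp32"].buffer_index
assert entries["weight"].alignment == 16  # per-key alignment, no LCM folding
assert entries["weight_fp32"].alignment == 4
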
23 changes: 11 additions & 12 deletions exir/_serialize/_program.py
@@ -12,7 +12,7 @@
import re

from dataclasses import dataclass
from typing import ClassVar, Dict, List, Literal, Optional, Tuple
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,10 +21,9 @@
_program_flatbuffer_to_json,
_program_json_to_flatbuffer,
)
from executorch.exir._serialize._named_data_store import (
BufferEntry,
NamedDataStoreOutput,
)
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

from executorch.exir._serialize.data_serializer import DataEntry

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

@@ -368,8 +367,8 @@ def _extract_constant_segment(
def _extract_named_data(
program: Program,
segments: List[AlignedData],
buffers: List[BufferEntry],
name_to_buffer_idx: Dict[str, int],
buffers: Sequence[bytes],
name_to_data_entry: Dict[str, DataEntry],
) -> None:
"""Modifies the program in-place to add references to the named data
segments.
@@ -379,7 +378,7 @@ def _extract_named_data(
segments: A list of buffers to append extracted segments to. Modified in-place.
buffers: A list of unique buffers and the information required to
serialize them. Not modified.
name_to_buffer_idx: A map from the name of a blob to the index in buffers.
name_to_data_entry: A map from the blob name to DataEntry.
Not modified.
"""
if program.named_data is not None and len(program.named_data) > 0:
@@ -389,14 +388,14 @@ def _extract_named_data(
segment_index_map: Dict[int, int] = {}

named_data: List[NamedData] = []
for name, buffer_idx in name_to_buffer_idx.items():
segment_index = segment_index_map.get(buffer_idx, None)
for name, data_entry in name_to_data_entry.items():
segment_index = segment_index_map.get(data_entry.buffer_index, None)
if segment_index is None:
segment_index = len(segments)
segment_index_map[buffer_idx] = segment_index
segment_index_map[data_entry.buffer_index] = segment_index
segments.append(
AlignedData(
Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
Cord(buffers[data_entry.buffer_index]), data_entry.alignment
)
)
named_data.append(NamedData(key=name, segment_index=segment_index))
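
The segment-index mapping above is easiest to see in isolation. A self-contained sketch of the same dedup pattern (illustrative names, not from the PR): every unique buffer_index gets exactly one segment, and keys pointing at the same buffer share it.

from typing import Dict, List, Tuple

def assign_segments(name_to_buffer_index: Dict[str, int]) -> Tuple[List[int], Dict[str, int]]:
    """Emit one segment per unique buffer; keys sharing a buffer share it."""
    segment_of_buffer: Dict[int, int] = {}  # buffer_index -> segment_index
    emitted: List[int] = []                 # buffer_index of each emitted segment
    name_to_segment: Dict[str, int] = {}
    for name, buffer_index in name_to_buffer_index.items():
        seg = segment_of_buffer.get(buffer_index)
        if seg is None:
            seg = len(emitted)
            segment_of_buffer[buffer_index] = seg
            emitted.append(buffer_index)
        name_to_segment[name] = seg
    return emitted, name_to_segment

# "a" and "c" point at buffer 0, so they share segment 0.
assert assign_segments({"a": 0, "b": 1, "c": 0}) == ([0, 1], {"a": 0, "b": 1, "c": 0})
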
14 changes: 7 additions & 7 deletions exir/_serialize/_serialize.py
@@ -117,18 +117,18 @@ def serialize_for_executorch(
)
buffers.append(emitter_output.external_constant_buffer[index])

# Extract external data.
# Extract external data from named_data_store.
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
key_to_buffer_index = named_data_store.external_data.get(tag, {})
for key, index in key_to_buffer_index.items():
blob_to_data_entry = named_data_store.external_data.get(tag, {})
for key, data_entry in blob_to_data_entry.items():
assert key not in key_to_data_entry # key must be unique
key_to_data_entry[key] = DataEntry(
buffer_index=len(buffers),
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
alignment=named_data_store.buffers[index].alignment,
tensor_layout=None,
alignment=data_entry.alignment,
tensor_layout=data_entry.tensor_layout,
)
buffers.append(named_data_store.buffers[index].buffer)
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
buffers.append(named_data_store.buffers[data_entry.buffer_index])

# Serialize into PTD file.
ptd_files[tag] = data_serializer.serialize(
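
The re-indexing step above is subtle: the store-wide buffer_index is discarded and each blob gets a fresh index into the PTD file's local buffer list, while alignment and tensor_layout are carried over from the store's DataEntry. A simplified standalone sketch (reindex_for_ptd is a hypothetical helper name, assuming the DataEntry dataclass from the diff):

from typing import Dict, List, Tuple

from executorch.exir._serialize.data_serializer import DataEntry

def reindex_for_ptd(
    store_buffers: List[bytes],
    tagged_entries: Dict[str, DataEntry],
) -> Tuple[Dict[str, DataEntry], List[bytes]]:
    """Rebase entries tagged for one PTD file onto a local buffer list."""
    local_buffers: List[bytes] = []
    rebased: Dict[str, DataEntry] = {}
    for key, entry in tagged_entries.items():
        rebased[key] = DataEntry(
            buffer_index=len(local_buffers),  # index into the local list
            alignment=entry.alignment,        # carried over per key
            tensor_layout=entry.tensor_layout,  # carried over, may be None
        )
        # Like the quoted code, one local buffer per key (keys are unique).
        local_buffers.append(store_buffers[entry.buffer_index])
    return rebased, local_buffers
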
53 changes: 36 additions & 17 deletions exir/_serialize/test/test_named_data_store.py
@@ -8,7 +8,10 @@

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.scalar_type import ScalarType
from executorch.exir.tensor_layout import TensorLayout


class TestNamedDataStore(unittest.TestCase):
@@ -21,17 +24,17 @@ def test_add(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 3)
self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))
self.assertEqual(output.buffers[0], b"data1")
self.assertEqual(output.buffers[1], b"data2")
self.assertEqual(output.buffers[2], b"data3")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key1"], 0)
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, None))

self.assertEqual(len(output.external_data), 1)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(output.external_data["file1"]["key2"], 1)
self.assertEqual(output.external_data["file1"]["key3"], 2)
self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))

def test_add_duplicate_name_and_data(self) -> None:
store = NamedDataStore()
@@ -41,10 +44,10 @@ def test_add_duplicate_name_and_data(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))

self.assertEqual(len(output.external_data), 0)

@@ -56,12 +59,11 @@ def test_add_same_data_with_different_alignment(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
# Check that we take the LCM of the two alignments (3, 4) = 12
self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 2)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key1"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 3, None))
self.assertEqual(output.pte_data["key1"], DataEntry(0, 4, None))

self.assertEqual(len(output.external_data), 0)

@@ -78,15 +80,30 @@ def test_add_duplicate_key_fail(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
self.assertEqual(len(output.external_data), 0)

def test_add_same_data_with_different_tensor_layout(self) -> None:
store = NamedDataStore()
tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
tensor_layout2 = TensorLayout(ScalarType.FLOAT, [2, 1], [0, 1])
store.add_named_data("key", b"data", None, None, tensor_layout1)
store.add_named_data("key1", b"data", None, None, tensor_layout2)

output = store.get_named_data_store_output()
self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(output.pte_data["key"], DataEntry(0, 1, tensor_layout1))
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout2))

def test_merge(self) -> None:
store1 = NamedDataStore()
store1.add_named_data("key1", b"data1", None, None)
tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
store1.add_named_data("key1", b"data1", None, None, tensor_layout1)
store1.add_named_data("key2", b"data2", 16, "file1")

# Check items in the store1.
@@ -97,7 +114,7 @@ def test_merge(self) -> None:
self.assertEqual(len(output.external_data["file1"]), 1)

store2 = NamedDataStore()
store2.add_named_data("key1", b"data1", None, None)
store2.add_named_data("key1", b"data1", None, None, tensor_layout1)
store2.add_named_data("key3", b"data3", None, None)
store2.add_named_data("key4", b"data4", 16, "file1")
store2.add_named_data("key5", b"data5", 16, "file2")
@@ -118,6 +135,8 @@ def test_merge(self) -> None:
# key1, data1 exist in both store1 and store2, so we only have one copy of it.
self.assertEqual(len(output.buffers), 5)
self.assertEqual(len(output.pte_data), 2)
# Confirm DataEntry is correct.
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout1))
self.assertEqual(len(output.external_data), 2)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(len(output.external_data["file2"]), 1)
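
As the merge test above exercises, merging is just re-adding: each entry in the incoming NamedDataStoreOutput is pushed back through add_named_data, so buffer indices are reassigned in the destination store and byte-identical blobs collapse to one copy. A condensed sketch (key names and bytes are illustrative):

from executorch.exir._serialize._named_data_store import NamedDataStore

store_a = NamedDataStore()
store_a.add_named_data("shared", b"blob", None, None)

store_b = NamedDataStore()
store_b.add_named_data("shared", b"blob", None, None)  # same key + same bytes is legal
store_b.add_named_data("extra", b"other", 16, "file1")

# Re-adds every entry; "shared" dedupes against store_a's existing buffer.
store_a.merge_named_data_store(store_b.get_named_data_store_output())
merged = store_a.get_named_data_store_output()
assert len(merged.buffers) == 2  # b"blob" stored once across both stores
assert merged.external_data["file1"]["extra"].alignment == 16
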