From 2a3ac1f48a2d444735066a88757b222dc0b19bde Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Mon, 9 Jun 2025 15:27:08 -0700 Subject: [PATCH] Refactor FlatTensor (#11499) Summary: This diff is the stack at D73209986 combined. Landing as one diff/PR as it contains BC-breaking changes throughout the stack. Reviewed By: JacobSzwejbka Differential Revision: D76285590 --- exir/_serialize/_serialize.py | 49 ++-- exir/_serialize/data_serializer.py | 64 ++--- .../flat_tensor/flat_tensor_data_map.cpp | 236 +++++++----------- extension/flat_tensor/flat_tensor_data_map.h | 16 +- .../flat_tensor/serialize/flat_tensor.fbs | 25 +- .../serialize/flat_tensor_schema.py | 11 +- extension/flat_tensor/serialize/serialize.cpp | 63 +++-- extension/flat_tensor/serialize/serialize.py | 96 ++----- .../test/flat_tensor_data_map_test.cpp | 12 +- extension/flat_tensor/test/test_serialize.cpp | 59 +++-- extension/flat_tensor/test/test_serialize.py | 175 +++++++------ extension/training/module/state_dict_util.cpp | 2 +- runtime/core/named_data_map.h | 26 +- runtime/executor/method.cpp | 2 +- runtime/executor/pte_data_map.cpp | 27 +- runtime/executor/pte_data_map.h | 13 +- runtime/executor/tensor_parser_exec_aten.cpp | 2 +- runtime/executor/test/pte_data_map_test.cpp | 4 +- 18 files changed, 368 insertions(+), 514 deletions(-) diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index 1b36dac1743..e2147458545 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -16,7 +16,6 @@ DataEntry, DataPayload, DataSerializer, - TensorEntry, TensorLayout, ) @@ -29,22 +28,22 @@ def serialize_for_executorch( emitter_output: EmitterOutput, config: ExecutorchBackendConfig, data_serializer: DataSerializer, - named_data: Optional[NamedDataStoreOutput] = None, + named_data_store: Optional[NamedDataStoreOutput] = None, ) -> Tuple[Cord, Dict[str, Cord]]: """Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files.""" # Serialize PTE file. pte_named_data = None if ( - named_data is not None - and len(named_data.buffers) > 0 - and len(named_data.pte_data) > 0 + named_data_store is not None + and len(named_data_store.buffers) > 0 + and len(named_data_store.pte_data) > 0 ): # Create a separate NamedDataStoreOutput with only pte_data; exclude # external_data, which shouldn't be serialized with the PTE file. pte_named_data = NamedDataStoreOutput( - buffers=named_data.buffers, - pte_data=named_data.pte_data, + buffers=named_data_store.buffers, + pte_data=named_data_store.pte_data, external_data={}, ) pte: Cord = _serialize_pte_binary( @@ -72,22 +71,23 @@ def serialize_for_executorch( and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL ): fqn_to_tensor_layout[ + # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name` tensor.extra_tensor_info.fully_qualified_name ] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order) if len(fqn_to_tensor_layout) == 0 and ( - named_data is None or len(named_data.external_data) == 0 + named_data_store is None or len(named_data_store.external_data) == 0 ): return pte, ptd_files # Consolidate tensors and opaque data with the same external tag so they # can be saved to the same PTD. all_external_tags: Set[str] = set() - if named_data is not None and len(named_data.external_data) > 0: + if named_data_store is not None and len(named_data_store.external_data) > 0: assert ( - len(named_data.buffers) > 0 + len(named_data_store.buffers) > 0 ), "External data exists, but there are no buffers provided." - all_external_tags = set(named_data.external_data.keys()) + all_external_tags = set(named_data_store.external_data.keys()) if len(fqn_to_tensor_layout) > 0: # emitter_output.external_constant_map contains the mapping from @@ -103,35 +103,38 @@ def serialize_for_executorch( for tag in all_external_tags: buffers = [] - fqn_to_tensor_entry: Dict[str, TensorEntry] = {} + key_to_data_entry: Dict[str, DataEntry] = {} # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`. fqn_to_index = emitter_output.external_constant_map.get(tag, {}) - # Create a TensorEntry for each external tensor. + # Create a DataEntry for each external tensor. for fqn, index in fqn_to_index.items(): assert fqn in fqn_to_tensor_layout - fqn_to_tensor_entry[fqn] = TensorEntry( + assert fqn not in key_to_data_entry # fqn must be unique + key_to_data_entry[fqn] = DataEntry( buffer_index=len(buffers), - layout=fqn_to_tensor_layout[fqn], + alignment=config.constant_tensor_alignment, + tensor_layout=fqn_to_tensor_layout[fqn], ) buffers.append(emitter_output.external_constant_buffer[index]) # Extract external data. - key_to_data: Dict[str, DataEntry] = {} # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`. - key_to_buffer_index = named_data.external_data.get(tag, {}) + key_to_buffer_index = named_data_store.external_data.get(tag, {}) for key, index in key_to_buffer_index.items(): - # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`. - key_to_data[key] = DataEntry( - len(buffers), named_data.buffers[index].alignment + assert key not in key_to_data_entry # key must be unique + key_to_data_entry[key] = DataEntry( + buffer_index=len(buffers), + # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`. + alignment=named_data_store.buffers[index].alignment, + tensor_layout=None, ) - buffers.append(named_data.buffers[index].buffer) + buffers.append(named_data_store.buffers[index].buffer) # Serialize into PTD file. ptd_files[tag] = data_serializer.serialize( DataPayload( buffers=buffers, - fqn_to_tensor=fqn_to_tensor_entry, - key_to_data=key_to_data, + named_data=key_to_data_entry, ) ) diff --git a/exir/_serialize/data_serializer.py b/exir/_serialize/data_serializer.py index e30fd2546d7..e828b4d0ae3 100644 --- a/exir/_serialize/data_serializer.py +++ b/exir/_serialize/data_serializer.py @@ -1,41 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Sequence +from typing import Dict, Optional, Sequence from executorch.exir._serialize._cord import Cord - -from executorch.exir.schema import ScalarType - - -@dataclass -class TensorLayout: - """Tensor layout information for externally-serialized tensors. - - Attributes: - scalar_type: type of the elements in the tensor. - sizes: size of each dim in the tensor. - dim_order: specifies the order the dimensions are laid out in memory, - from outer to inner. - """ - - scalar_type: ScalarType - sizes: List[int] - dim_order: List[int] - - -@dataclass -class TensorEntry: - """Represents a single tensor in `DataPayload`, specifying its location - and metadata. - - Attributes: - buffer_index: The index inside `DataPayload.buffers` that this - TensorEntry refers to. - layout: Metadata about the tensor. - """ - - buffer_index: int - layout: TensorLayout +from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout @dataclass @@ -47,10 +15,12 @@ class DataEntry: buffer_index: The index inside `DataPayload.buffers` that this DataEntry refers to. alignment: The alignment of the data. + tensor_layout: If this is a tensor, the tensor layout information. """ buffer_index: int alignment: int + tensor_layout: Optional[TensorLayout] @dataclass @@ -58,22 +28,20 @@ class DataPayload: """Contains the data and metadata required for serialization. Having an index-based arrangement instead of embedding the buffers in - TensorEntry allows the caller to deduplicate buffers and point multiple - fully qualified names (FQNs) to the same entry. + DataEntry allows the caller to deduplicate buffers and point multiple + keys to the same entry. Attributes: - buffers: a sequence of tensor buffers. - fqn_to_tensor: a map from fully qualified names to serializable tensors. - key_to_data: a map from unique keys to serializable opaque data. + buffers: a sequence of byte buffers. + key_to_data: a map from unique keys to serializable data. """ buffers: Sequence[bytes] - fqn_to_tensor: Dict[str, TensorEntry] - key_to_data: Dict[str, DataEntry] + named_data: Dict[str, DataEntry] class DataSerializer(ABC): - """Serializes and deserializes FQN-tagged tensor data. + """Serializes and deserializes data. Data can be referenced by a unique key. This base class enables serialization into different formats. See executorch/extension/flat_tensor/ for an example. @@ -85,11 +53,11 @@ def serialize( data: DataPayload, ) -> Cord: """ - Serializes a list of tensors emitted by ExecuTorch into a binary blob. + Serializes a list of bytes emitted by ExecuTorch into a binary blob. Args: - data: the tensor buffers and tensor layout information required for - serialization. + data: buffers and corresponding metadata used for serialization. + Returns: A binary blob that contains the serialized data. @@ -99,14 +67,14 @@ def serialize( @abstractmethod def deserialize(self, blob: Cord) -> DataPayload: """ - Deserializes a blob into a list of tensors. Reverses the effect of + Deserializes a blob into a DataPayload. Reverses the effect of serialize. Args: blob: A binary blob that contains the serialized data. Returns: - DataPayload: tensor buffers and tensor layout information - deserialized from `blob`. + DataPayload: buffers and corresponding metadata deserialized + from `blob`. """ raise NotImplementedError("deserialize_data") diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index c5590cb61b1..3a69dc8b92c 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -45,171 +45,121 @@ bool is_aligned(const void* data) { } Result get_named_data( - const char* key, + executorch::aten::string_view key, const flatbuffers::Vector< - flatbuffers::Offset>* named_data) { + flatbuffers::Offset>* named_data, + const flatbuffers::Vector< + flatbuffers::Offset>* segments, + size_t segment_end_offset) { // Linear search by name. if (named_data == nullptr) { return Error::NotFound; } for (int i = 0; i < named_data->size(); i++) { - if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) { - const auto* metadata = named_data->Get(i); + if (std::strncmp( + named_data->Get(i)->key()->c_str(), + key.data(), + named_data->Get(i)->key()->size()) == 0) { + const auto* found = named_data->Get(i); + // Validate the named_data. + size_t segment_index = found->segment_index(); ET_CHECK_OR_RETURN_ERROR( - metadata->segment_index() >= 0, + segment_index >= 0 && segment_index < segments->size(), InvalidExternalData, - "Invalid segment_index %d; malformed PTD file.", - metadata->segment_index()); - return metadata; - } - } - return Error::NotFound; -} - -Result get_flat_tensor_metadata( - const char* key, - const flatbuffers::Vector< - flatbuffers::Offset>* tensors) { - // Linear search by name. - for (int i = 0; i < tensors->size(); i++) { - if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) == - 0) { - const auto* metadata = tensors->Get(i); + "Segment index %zu for key %.*s is out of bounds for segment size %d. Malformed PTD file.", + segment_index, + static_cast(key.size()), + key.data(), + segments->size()); + // Validate the segment. ET_CHECK_OR_RETURN_ERROR( - metadata->segment_index() >= 0 && metadata->offset() >= 0, + segments->Get(segment_index)->offset() < segment_end_offset, InvalidExternalData, - "Invalid segment_index %d or offset %" PRIu64 "; malformed PTD file.", - metadata->segment_index(), - metadata->offset()); - return metadata; + "Invalid segment offset %" PRIu64 + " is larger than the segment_base_offset + segment_data_size %" PRIu64 + "; malformed PTD file.", + segments->Get(segment_index)->offset(), + static_cast(segment_end_offset)); + return found; } } return Error::NotFound; } Result create_tensor_layout( - const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) { + const flat_tensor_flatbuffer::TensorLayout* tensor_layout) { ScalarType scalar_type = - static_cast(tensor_metadata->scalar_type()); - const int dim = tensor_metadata->sizes()->size(); - const auto serialized_sizes = tensor_metadata->sizes()->data(); - const auto serialized_dim_order = tensor_metadata->dim_order()->data(); + static_cast(tensor_layout->scalar_type()); + const int dim = tensor_layout->sizes()->size(); + const auto serialized_sizes = tensor_layout->sizes()->data(); + const auto serialized_dim_order = tensor_layout->dim_order()->data(); return TensorLayout::create( Span(serialized_sizes, dim), Span(serialized_dim_order, dim), scalar_type); } -Result get_and_check_segment_offset( - const flatbuffers::Vector< - flatbuffers::Offset>* segments, - const flat_tensor_flatbuffer::TensorMetadata* metadata) { - ET_CHECK_OR_RETURN_ERROR( - segments != nullptr, - InvalidExternalData, - "No segments in external data flatbuffer."); - - ET_CHECK_OR_RETURN_ERROR( - metadata->segment_index() < segments->size(), - InvalidExternalData, - "Invalid segment_index %d; malformed PTD file.", - metadata->segment_index()); - return segments->Get(metadata->segment_index())->offset(); -} - } // namespace -ET_NODISCARD Result FlatTensorDataMap::get_metadata( - const char* key) const { - Result metadata_res = - get_flat_tensor_metadata(key, flat_tensor_->tensors()); - if (!metadata_res.ok()) { - return metadata_res.error(); +ET_NODISCARD Result FlatTensorDataMap::get_tensor_layout( + executorch::aten::string_view key) const { + Result named_data = get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (!named_data.ok()) { + return named_data.error(); } - return create_tensor_layout(metadata_res.get()); + return create_tensor_layout(named_data.get()->tensor_layout()); } ET_NODISCARD Result FlatTensorDataMap::get_data( - const char* key) const { - // TODO(lfq): consolidate named_data and tensors. - // Check named data. - Result named_data = - get_named_data(key, flat_tensor_->named_data()); - if (named_data.ok()) { - size_t segment_index = named_data.get()->segment_index(); - ET_CHECK_OR_RETURN_ERROR( - segment_index < flat_tensor_->segments()->size(), - InvalidExternalData, - "Invalid segment_index %zu; malformed PTD file.", - segment_index); - - size_t segment_offset = - flat_tensor_->segments()->Get(segment_index)->offset(); - size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size(); - ET_CHECK_OR_RETURN_ERROR( - segment_offset < - header_.segment_base_offset + header_.segment_data_size, - InvalidExternalData, - "Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64 - "; malformed PTD file.", - segment_offset, - header_.segment_base_offset + header_.segment_data_size); - return loader_->load( - /*offset=*/header_.segment_base_offset + segment_offset, - segment_size, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); - } - if (named_data.error() != Error::NotFound) { + executorch::aten::string_view key) const { + Result named_data = get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (!named_data.ok()) { return named_data.error(); } - // Check tensors, if named data is not found. - Result metadata = - get_flat_tensor_metadata(key, flat_tensor_->tensors()); - if (!metadata.ok()) { - return metadata.error(); - } - Result tensor_layout = - create_tensor_layout(metadata.get()); - if (!tensor_layout.ok()) { - return tensor_layout.error(); - } - Result segment_offset = - get_and_check_segment_offset(flat_tensor_->segments(), metadata.get()); - if (!segment_offset.ok()) { - return segment_offset.error(); - } + uint32_t segment_index = named_data.get()->segment_index(); + uint64_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); + uint64_t segment_size = flat_tensor_->segments()->Get(segment_index)->size(); - // Load constant data. - ET_CHECK_OR_RETURN_ERROR( - segment_offset.get() < - header_.segment_base_offset + header_.segment_data_size, - InvalidExternalData, - "Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %" PRIu64 - "; malformed PTD file.", - segment_offset.get(), - header_.segment_base_offset + header_.segment_data_size); return loader_->load( - header_.segment_base_offset + segment_offset.get() + - metadata.get()->offset(), - tensor_layout.get().nbytes(), + /*offset=*/header_.segment_base_offset + segment_offset, + segment_size, DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); } ET_NODISCARD Error FlatTensorDataMap::load_data_into( - ET_UNUSED const char* key, + ET_UNUSED executorch::aten::string_view key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const { - Result metadata = - get_flat_tensor_metadata(key, flat_tensor_->tensors()); - if (!metadata.ok()) { - return metadata.error(); + Result named_data = get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (!named_data.ok()) { + return named_data.error(); } + + uint32_t segment_index = named_data.get()->segment_index(); + uint64_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); + Result tensor_layout = - create_tensor_layout(metadata.get()); + create_tensor_layout(named_data.get()->tensor_layout()); + if (!tensor_layout.ok()) { return tensor_layout.error(); } + ET_CHECK_OR_RETURN_ERROR( size <= tensor_layout.get().nbytes(), InvalidArgument, @@ -217,51 +167,30 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into( size, tensor_layout.get().nbytes()); - Result segment_offset = - get_and_check_segment_offset(flat_tensor_->segments(), metadata.get()); - if (!segment_offset.ok()) { - return segment_offset.error(); - } // Load mutable data. DataLoader::SegmentInfo info = DataLoader::SegmentInfo( DataLoader::SegmentInfo::Type::Mutable, 0, nullptr); return loader_->load_into( - header_.segment_base_offset + segment_offset.get() + - metadata.get()->offset(), + header_.segment_base_offset + segment_offset, tensor_layout.get().nbytes(), info, buffer); } -ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { - // TODO(lfq): consolidate named_data and tensors. - if (flat_tensor_->named_data() == nullptr) { - return flat_tensor_->tensors()->size(); - } - return flat_tensor_->named_data()->size() + flat_tensor_->tensors()->size(); +ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { + return flat_tensor_->named_data()->size(); } ET_NODISCARD Result FlatTensorDataMap::get_key( - size_t index) const { - // TODO(lfq): consolidate named_data and tensors. - // For now, iterate over named_data and then flat_tensor. - size_t num_keys = get_num_keys().get(); + uint32_t index) const { + uint32_t num_keys = get_num_keys().get(); ET_CHECK_OR_RETURN_ERROR( index >= 0 && index < num_keys, InvalidArgument, - "Index %zu out of range of size %zu", + "Index %u out of range of size %u", index, num_keys); - - if (flat_tensor_->named_data() != nullptr && - index < flat_tensor_->named_data()->size()) { - return flat_tensor_->named_data()->Get(index)->key()->c_str(); - } else { - if (flat_tensor_->named_data() != nullptr) { - index = index - flat_tensor_->named_data()->size(); - } - return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); - } + return flat_tensor_->named_data()->Get(index)->key()->c_str(); } /* static */ Result FlatTensorDataMap::load( @@ -321,6 +250,17 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( const flat_tensor_flatbuffer::FlatTensor* flat_tensor = flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data()); + // Validate flat_tensor. + ET_CHECK_OR_RETURN_ERROR( + flat_tensor->named_data() != nullptr, + InvalidExternalData, + "FlatTensor named_data is nullptr, malformed PTD file."); + + ET_CHECK_OR_RETURN_ERROR( + flat_tensor->segments() != nullptr, + InvalidExternalData, + "FlatTensor segments is nullptr, malformed PTD file."); + return FlatTensorDataMap( fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader); } diff --git a/extension/flat_tensor/flat_tensor_data_map.h b/extension/flat_tensor/flat_tensor_data_map.h index 0e7aee8ffc8..751e312f7ef 100644 --- a/extension/flat_tensor/flat_tensor_data_map.h +++ b/extension/flat_tensor/flat_tensor_data_map.h @@ -45,7 +45,7 @@ class FlatTensorDataMap final executorch::runtime::DataLoader* loader); /** - * Retrieve the metadata for the specified key. + * Retrieve the tensor_layout for the specified key. * * @param[in] key The name of the tensor to get metadata on. * @@ -54,7 +54,7 @@ class FlatTensorDataMap final ET_NODISCARD executorch::runtime::Result< const executorch::ET_RUNTIME_NAMESPACE::TensorLayout> - get_metadata(const char* key) const override; + get_tensor_layout(executorch::aten::string_view key) const override; /** * Retrieve read-only data for the specified key. @@ -65,7 +65,7 @@ class FlatTensorDataMap final */ ET_NODISCARD executorch::runtime::Result get_data( - const char* key) const override; + executorch::aten::string_view key) const override; /** * Loads the data of the specified tensor into the provided buffer. @@ -77,20 +77,22 @@ class FlatTensorDataMap final * * @returns an Error indicating if the load was successful. */ - ET_NODISCARD executorch::runtime::Error - load_data_into(const char* key, void* buffer, size_t size) const override; + ET_NODISCARD executorch::runtime::Error load_data_into( + executorch::aten::string_view key, + void* buffer, + size_t size) const override; /** * @returns The number of keys in the map. */ - ET_NODISCARD executorch::runtime::Result get_num_keys() + ET_NODISCARD executorch::runtime::Result get_num_keys() const override; /** * @returns The key at the specified index, error if index out of bounds. */ ET_NODISCARD executorch::runtime::Result get_key( - size_t index) const override; + uint32_t index) const override; FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default; diff --git a/extension/flat_tensor/serialize/flat_tensor.fbs b/extension/flat_tensor/serialize/flat_tensor.fbs index 7d47a61d2b0..abf331697d6 100644 --- a/extension/flat_tensor/serialize/flat_tensor.fbs +++ b/extension/flat_tensor/serialize/flat_tensor.fbs @@ -7,9 +7,7 @@ namespace flat_tensor_flatbuffer; file_identifier "FT01"; file_extension "ptd"; -table TensorMetadata { - // The unique id used to connect the data and program. - fully_qualified_name: string; +table TensorLayout { scalar_type: executorch_flatbuffer.ScalarType; // Size of each dimension. @@ -27,17 +25,6 @@ table TensorMetadata { // the innermost dimension, then comes "batch", and the outermost dimension // is "row". dim_order: [uint8]; - - // FlatTensor.segments index that the tensor data is stored in. - segment_index: uint32; - - // Tensor offsets are relative to each TensorSegment. - // To retrieve a given tensor: - // 1. segment_base_offset: from the file header. - // 2. segment_offset: segments[segment_index].offset - // 3. tensor_offset: the offset within the segment. If there is only one item - // in the segment, offset=0. - offset: uint64; } // Describes a contiguous piece of data that lives outside of the flatbuffer data, @@ -62,6 +49,9 @@ table NamedData { // Index of the segment in FlatTensor.segments. segment_index: uint32; + + // Optional: if the underlying data is a tensor, store layout information. + tensor_layout: TensorLayout; } // FlatTensor is a flatbuffer-based format for storing and loading tensors. @@ -69,13 +59,6 @@ table FlatTensor { // Schema version. version: uint32; - // Alignment for each tensor in bytes. Offsets of the tensor provided - // in TensorMetadata.offset are aligned to tensor_alignment. - tensor_alignment: uint32; - - // Tensor information, including metadata and offsets to the raw tensor data. - tensors: [TensorMetadata]; - // List of data segments that follow the FlatTensor data in this file, sorted by // offset. Elements in this schema can refer to these segments by index. segments: [DataSegment]; diff --git a/extension/flat_tensor/serialize/flat_tensor_schema.py b/extension/flat_tensor/serialize/flat_tensor_schema.py index 9581442c2d8..53b0fe98ea9 100644 --- a/extension/flat_tensor/serialize/flat_tensor_schema.py +++ b/extension/flat_tensor/serialize/flat_tensor_schema.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import List +from typing import List, Optional from executorch.exir.scalar_type import ScalarType @@ -15,15 +15,11 @@ @dataclass -class TensorMetadata: - fully_qualified_name: str +class TensorLayout: scalar_type: ScalarType sizes: List[int] dim_order: List[int] - segment_index: int - offset: int - @dataclass class DataSegment: @@ -35,12 +31,11 @@ class DataSegment: class NamedData: key: str segment_index: int + tensor_layout: Optional[TensorLayout] = None @dataclass class FlatTensor: version: int - tensor_alignment: int - tensors: List[TensorMetadata] segments: List[DataSegment] named_data: List[NamedData] diff --git a/extension/flat_tensor/serialize/serialize.cpp b/extension/flat_tensor/serialize/serialize.cpp index 720e104ab7f..9930de6bba6 100644 --- a/extension/flat_tensor/serialize/serialize.cpp +++ b/extension/flat_tensor/serialize/serialize.cpp @@ -68,47 +68,46 @@ runtime::Error save_ptd( // Create flatbuffer flatbuffers::FlatBufferBuilder builder; - std::vector> - tensors; + std::vector> + named_data; std::vector> - buffers; + segments; // Write the tensors. size_t total_segment_size = 0; - size_t i = tensor_map.size(); + uint32_t i = 0; for (const auto& [name, tensor] : tensor_map) { - auto name_offset = builder.CreateString(name); - // Write the tensor metadata. - auto tensor_metadata = ::flat_tensor_flatbuffer::CreateTensorMetadata( - builder, - name_offset, + auto key = builder.CreateString(name); + // Write the tensor layouts. + auto tensor_layout = ::flat_tensor_flatbuffer::CreateTensorLayout( + /*_fbb*=*/builder, + /*scalar_type=*/ static_cast(tensor.scalar_type()), + /*sizes=*/ builder.CreateVector(tensor.sizes().data(), tensor.sizes().size()), + /*dim_order=*/ builder.CreateVector( - tensor.dim_order().data(), tensor.dim_order().size()), - 0, // segment index - total_segment_size); - - tensors.push_back(tensor_metadata); - // Don't pad last entry. - if (i != 1) { - // Precalculate the size of the data blob. - total_segment_size += aligned_size(tensor.nbytes(), tensor_alignment); - } else { - total_segment_size += tensor.nbytes(); - } - i--; + tensor.dim_order().data(), tensor.dim_order().size())); + + named_data.push_back(::flat_tensor_flatbuffer::CreateNamedData( + /*_fbb=*/builder, + /*key=*/key, + /*segment_index=*/i, + /*tensor_layout=*/tensor_layout)); + + segments.push_back(::flat_tensor_flatbuffer::CreateDataSegment( + /*_fbb=*/builder, + /*offset=*/total_segment_size, + /*size=*/tensor.nbytes())); + total_segment_size += aligned_size(tensor.nbytes(), tensor_alignment); + i++; } - // Only have one segment - buffers.push_back(::flat_tensor_flatbuffer::CreateDataSegment( - builder, 0, total_segment_size)); auto flat_tensor = CreateFlatTensor( - builder, - kSchemaVersion, - tensor_alignment, - builder.CreateVector(tensors), - builder.CreateVector(buffers)); + /*_fbb=*/builder, + /*version=*/kSchemaVersion, + /*segments=*/builder.CreateVector(segments), + /*named_data=*/builder.CreateVector(named_data)); builder.Finish(flat_tensor, ::flat_tensor_flatbuffer::FlatTensorIdentifier()); // Our flatbuffer is created now. @@ -133,10 +132,10 @@ runtime::Error save_ptd( *reinterpret_cast(builder.GetBufferPointer()); uint32_t new_offset = current_offset + padded_header_size; - // Write flatbuffer offset to root table + // Write flatbuffer offset to root table. out.write(reinterpret_cast(&new_offset), sizeof(new_offset)); - // Write flatbuffer magic bytes + // Write flatbuffer magic bytes. out.write( reinterpret_cast(builder.GetBufferPointer()) + sizeof(new_offset), diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index d23ffda2ae9..5b29d7ccacd 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -11,7 +11,7 @@ import os import tempfile from dataclasses import dataclass -from typing import ClassVar, Dict, List, Literal, Optional, Sequence +from typing import ClassVar, Dict, List, Literal, Optional import pkg_resources from executorch.exir._serialize._cord import Cord @@ -19,12 +19,7 @@ from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile from executorch.exir._serialize._program import _insert_flatbuffer_header -from executorch.exir._serialize.data_serializer import ( - DataEntry, - DataPayload, - DataSerializer, - TensorEntry, -) +from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required @@ -32,7 +27,6 @@ DataSegment, FlatTensor, NamedData, - TensorMetadata, ) # Byte order of numbers written to flat tensor headers. Always little-endian @@ -234,65 +228,8 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]: return None -def _extract_tensors( - fqn_to_tensor: Dict[str, TensorEntry], - buffers: Sequence[bytes], - segments: List[AlignedData], - tensor_alignment: int, -) -> List[TensorMetadata]: - """Places tensors into a single segment, aligned to tensor_alignment within - the segment. - - Args: - fqn_to_tensor: A map from fully qualified names to tensor entries. - buffers: A sequence of tensor buffers. - segments: A list of segments to append the tensor data to. Modified in-place. - tensor_alignment: The alignment of the tensor data. - - Returns: - A list of TensorMetadata, which describes the tensors in the segment. - """ - tensor_data: Cord = Cord() - tensors: List[TensorMetadata] = [] - # {idx, offset} - saved_offsets: Dict[int, int] = {} - for fqn, tensor_entry in fqn_to_tensor.items(): - assert tensor_entry.layout is not None - # Check index into the tensor buffers is valid. - assert tensor_entry.buffer_index < len( - buffers - ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)}." - - # Check if the tensor has already been appended to the flat_tensor_data. - offset = saved_offsets.get(tensor_entry.buffer_index, -1) - if offset == -1: - if len(tensor_data) > 0: - # Add padding to round off the previous tensor offset. - pad_length = padding_required(len(tensor_data), tensor_alignment) - tensor_data.append(b"\x00" * pad_length) - # Add to saved offsets. - offset = len(tensor_data) - saved_offsets[tensor_entry.buffer_index] = offset - # Append to flat_tensor_data at the offset. - tensor_data.append(buffers[tensor_entry.buffer_index]) - - tensors.append( - TensorMetadata( - fully_qualified_name=fqn, - scalar_type=tensor_entry.layout.scalar_type, - sizes=tensor_entry.layout.sizes, - dim_order=tensor_entry.layout.dim_order, - segment_index=len(segments), - offset=offset, - ) - ) - segments.append(AlignedData(tensor_data)) - return tensors - - def _extract_named_data( - key_to_data: Dict[str, DataEntry], - buffers: Sequence[bytes], + data_payload: DataPayload, segments: List[AlignedData], ) -> List[NamedData]: """Places named data into segments and record the alignment for each. @@ -310,16 +247,25 @@ def _extract_named_data( segment_index_map: Dict[int, int] = {} named_data: List[NamedData] = [] - for key, data_entry in key_to_data.items(): + for key, data_entry in data_payload.named_data.items(): buffer_idx = data_entry.buffer_index segment_index = segment_index_map.get(buffer_idx, None) if segment_index is None: segment_index = len(segments) segment_index_map[buffer_idx] = segment_index segments.append( - AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment) + AlignedData( + Cord(data_payload.buffers[buffer_idx]), data_entry.alignment + ) ) - named_data.append(NamedData(key=key, segment_index=segment_index)) + named_data.append( + NamedData( + key=key, + segment_index=segment_index, + # pyre-ignore Incompatible parameter type [6] + tensor_layout=data_entry.tensor_layout, + ) + ) return named_data @@ -344,13 +290,9 @@ def serialize( """Serializes a list of tensors and named data into a blob.""" segments: List[AlignedData] = [] - tensors = _extract_tensors( - data.fqn_to_tensor, - data.buffers, - segments, - self.config.tensor_alignment, - ) - named_data = _extract_named_data(data.key_to_data, data.buffers, segments) + + # Add a config to place tensors in a single segment. + named_data = _extract_named_data(data, segments) data_segments: List[DataSegment] = [] aggregated_segment_data = Cord() @@ -379,8 +321,6 @@ def serialize( # points to all the data segments. It will be serialized to flatbuffer. flat_tensor = FlatTensor( version=0, # Keep in sync with c++ version number in serialize.h - tensor_alignment=self.config.tensor_alignment, - tensors=tensors, segments=data_segments, named_data=named_data, ) diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp index 0d8bf9659bb..5a94b47b954 100644 --- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp +++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp @@ -60,7 +60,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { // From //executorch/test/models/linear_model.py, we have the tensors // self.a = 3 * torch.ones(2, 2, dtype=torch.float) // self.b = 2 * torch.ones(2, 2, dtype=torch.float) - Result const_a_res = data_map->get_metadata("a"); + Result const_a_res = data_map->get_tensor_layout("a"); ASSERT_EQ(Error::Ok, const_a_res.error()); const TensorLayout const_a = const_a_res.get(); @@ -74,7 +74,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { EXPECT_EQ(dim_order_a[0], 0); EXPECT_EQ(dim_order_a[1], 1); - Result const_b_res = data_map->get_metadata("b"); + Result const_b_res = data_map->get_tensor_layout("b"); ASSERT_EQ(Error::Ok, const_b_res.error()); const TensorLayout const_b = const_b_res.get(); @@ -88,8 +88,8 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { EXPECT_EQ(dim_order_b[0], 0); EXPECT_EQ(dim_order_b[1], 1); - // Check get_metadata fails when key is not found. - Result const_c_res = data_map->get_metadata("c"); + // Check get_tensor_layout fails when key is not found. + Result const_c_res = data_map->get_tensor_layout("c"); EXPECT_EQ(const_c_res.error(), Error::NotFound); } @@ -120,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { EXPECT_EQ(data_map.error(), Error::Ok); // Check num tensors is 2. - Result num_tensors_res = data_map->get_num_keys(); + Result num_tensors_res = data_map->get_num_keys(); ASSERT_EQ(Error::Ok, num_tensors_res.error()); EXPECT_EQ(num_tensors_res.get(), 2); @@ -144,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { EXPECT_EQ(data_map.error(), Error::Ok); // get the metadata - auto meta_data_res = data_map->get_metadata("a"); + auto meta_data_res = data_map->get_tensor_layout("a"); ASSERT_EQ(meta_data_res.error(), Error::Ok); // get data blob diff --git a/extension/flat_tensor/test/test_serialize.cpp b/extension/flat_tensor/test/test_serialize.cpp index 82f47684c71..57a0253485b 100644 --- a/extension/flat_tensor/test/test_serialize.cpp +++ b/extension/flat_tensor/test/test_serialize.cpp @@ -79,17 +79,16 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) { EXPECT_EQ( *(uint64_t*)(header_buffer + 16), - 232); // Flatbuffer size. This is fragile, and depends on the schema, + 280); // Flatbuffer size. This is fragile, and depends on the schema, // the builder, and the padding needed. // Segment offset, depends on the padded header and flatbuffer sizes. - const uint64_t segment_offset = 48 + 232 + 8; // 8 is padding. + const uint64_t segment_offset = 48 + 280 + 8; // 8 is padding. EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset); - EXPECT_EQ( - *(uint64_t*)(header_buffer + 32), - 20); // Segment total size, 8 bytes of data (2 floats), 24 bytes of - // padding. + // Segment total size, 8 bytes of data (2 floats), 24 bytes of padding. + const uint64_t segment_size = 32; + EXPECT_EQ(*(uint64_t*)(header_buffer + 32), segment_size); // Check Flatbuffer auto flat_tensor = ::flat_tensor_flatbuffer::GetFlatTensor(byte_buffer); @@ -97,30 +96,38 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) { EXPECT_EQ( flat_tensor->version(), executorch::extension::flat_tensor::kSchemaVersion); - EXPECT_EQ(flat_tensor->tensor_alignment(), 16); - EXPECT_EQ(flat_tensor->tensors()->size(), 2); - EXPECT_EQ(flat_tensor->segments()->size(), 1); - - auto tensor0 = flat_tensor->tensors()->Get(0); - EXPECT_EQ(strcmp(tensor0->fully_qualified_name()->c_str(), "linear.bias"), 0); - EXPECT_EQ(tensor0->scalar_type(), executorch_flatbuffer::ScalarType::FLOAT); - EXPECT_EQ(tensor0->sizes()->size(), 1); - EXPECT_EQ(tensor0->segment_index(), 0); - EXPECT_EQ(tensor0->offset(), 0); - - auto tensor1 = flat_tensor->tensors()->Get(1); + EXPECT_EQ(flat_tensor->named_data()->size(), 2); + EXPECT_EQ(flat_tensor->segments()->size(), 2); + + auto tensor0 = flat_tensor->named_data()->Get(0); + EXPECT_EQ(strcmp(tensor0->key()->c_str(), "linear.bias"), 0); EXPECT_EQ( - strcmp(tensor1->fully_qualified_name()->c_str(), "linear.weight"), 0); - EXPECT_EQ(tensor1->scalar_type(), executorch_flatbuffer::ScalarType::FLOAT); - EXPECT_EQ(tensor1->sizes()->size(), 1); - EXPECT_EQ(tensor1->segment_index(), 0); - EXPECT_EQ(tensor1->offset(), 16); + tensor0->tensor_layout()->scalar_type(), + executorch_flatbuffer::ScalarType::FLOAT); + EXPECT_EQ(tensor0->tensor_layout()->sizes()->size(), 1); + EXPECT_EQ(tensor0->tensor_layout()->sizes()->Get(0), 1); + EXPECT_EQ(tensor0->tensor_layout()->dim_order()->size(), 1); + EXPECT_EQ(tensor0->tensor_layout()->dim_order()->Get(0), 0); + + auto tensor1 = flat_tensor->named_data()->Get(1); + EXPECT_EQ(strcmp(tensor1->key()->c_str(), "linear.weight"), 0); + EXPECT_EQ( + tensor1->tensor_layout()->scalar_type(), + executorch_flatbuffer::ScalarType::FLOAT); + EXPECT_EQ(tensor1->tensor_layout()->sizes()->size(), 1); + EXPECT_EQ(tensor1->tensor_layout()->sizes()->Get(0), 1); + EXPECT_EQ(tensor1->tensor_layout()->dim_order()->size(), 1); + EXPECT_EQ(tensor1->tensor_layout()->dim_order()->Get(0), 0); // Test Segments - auto segment = flat_tensor->segments()->Get(0); + auto segment0 = flat_tensor->segments()->Get(0); + EXPECT_EQ(segment0->offset(), 0); + EXPECT_EQ(segment0->size(), 4); + + auto segment1 = flat_tensor->segments()->Get(1); + EXPECT_EQ(segment1->offset(), kTensorAlignment); + EXPECT_EQ(segment1->size(), 4); - EXPECT_EQ(segment->offset(), 0); - EXPECT_EQ(segment->size(), 20); uint8_t* data = (uint8_t*)(byte_buffer + segment_offset); EXPECT_EQ(*(float*)(data + 0), linear_bias); EXPECT_EQ(*(float*)(data + 16), linear_weight); diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 570fe9ae97f..80ee59ae974 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -4,25 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict +# pyre-unsafe import math import unittest -from typing import List +from typing import List, Optional from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, DataSerializer, - TensorEntry, - TensorLayout, ) from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType -from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata +from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout from executorch.extension.flat_tensor.serialize.serialize import ( _deserialize_to_flat_tensor, @@ -31,65 +29,71 @@ FlatTensorSerializer, ) -# Test artifacts. +# The raw data stored in the serialized file segments. TEST_BUFFER: List[bytes] = [b"\x11" * 4, b"\x22" * 32, b"\x33" * 17] -TEST_TENSOR_MAP = { - "fqn1": TensorEntry( + +# Items serialized into FlatTensor.named_data. +# fqn1 and fqn2 are tensors that point to the same buffer index. +# fqn3 is a single tensor. +# key0 is a named_data entry. +TEST_NAMED_DATA = { + "fqn1": DataEntry( buffer_index=0, - layout=TensorLayout( + alignment=0, + tensor_layout=TensorLayout( scalar_type=ScalarType.FLOAT, sizes=[1, 1, 1], dim_order=[0, 1, 2], ), ), - "fqn2": TensorEntry( + "fqn2": DataEntry( buffer_index=0, - layout=TensorLayout( + alignment=0, + tensor_layout=TensorLayout( scalar_type=ScalarType.FLOAT, sizes=[1, 1, 1], dim_order=[0, 1, 2], ), ), - "fqn3": TensorEntry( + "fqn3": DataEntry( buffer_index=1, - layout=TensorLayout( + alignment=0, + tensor_layout=TensorLayout( scalar_type=ScalarType.INT, sizes=[2, 2, 2], dim_order=[0, 1], ), ), -} - -TEST_DATA_ENTRY = { "key0": DataEntry( buffer_index=2, alignment=64, - ) + tensor_layout=None, + ), } TEST_DATA_PAYLOAD = DataPayload( buffers=TEST_BUFFER, - fqn_to_tensor=TEST_TENSOR_MAP, - key_to_data=TEST_DATA_ENTRY, + named_data=TEST_NAMED_DATA, ) class TestSerialize(unittest.TestCase): # TODO(T211851359): improve test coverage. - def check_tensor_metadata( - self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata + def check_tensor_layout( + self, expected: Optional[TensorLayout], actual: Optional[TensorLayout] ) -> None: - self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type) - self.assertEqual(tensor_layout.sizes, tensor_metadata.sizes) - self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order) + self.assertIsNotNone(expected) + self.assertIsNotNone(actual) + self.assertEqual(expected.scalar_type, actual.scalar_type) + self.assertEqual(expected.sizes, actual.sizes) + self.assertEqual(expected.dim_order, actual.dim_order) def test_serialize(self) -> None: config = FlatTensorConfig() serializer: DataSerializer = FlatTensorSerializer(config) - serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD)) - # Check header. + # Ensure valid header. header = FlatTensorHeader.from_bytes( serialized_data[8 : FlatTensorHeader.EXPECTED_LENGTH + 8] ) @@ -110,67 +114,68 @@ def test_serialize(self) -> None: ) self.assertTrue(header.segment_base_offset, expected_segment_base_offset) - # TEST_BUFFER is aligned to config.segment_alignment. - tensor1_size = aligned_size(len(TEST_BUFFER[0]), config.tensor_alignment) - tensor2_size = aligned_size(len(TEST_BUFFER[1]), config.tensor_alignment) - tensor_segment_size = aligned_size( - tensor1_size + tensor2_size, - math.lcm(config.segment_alignment, TEST_DATA_ENTRY["key0"].alignment), - ) - data_segment_size = len(TEST_BUFFER[2]) - expected_segment_data_size = tensor_segment_size + data_segment_size - self.assertEqual(header.segment_data_size, expected_segment_data_size) - # Confirm the flatbuffer magic is present. self.assertEqual( serialized_data[4:8], b"FT01", ) - # Check flat tensor data. + # Extract the flatbuffer. flat_tensor_bytes = serialized_data[ 0 : header.flatbuffer_offset + header.flatbuffer_size ] - flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) + # Check FlatTensor.version. self.assertEqual(flat_tensor.version, 0) - self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment) - tensors = flat_tensor.tensors - self.assertEqual(len(tensors), 3) - self.assertEqual(tensors[0].fully_qualified_name, "fqn1") - self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0]) - self.assertEqual(tensors[0].segment_index, 0) - self.assertEqual(tensors[0].offset, 0) + # Check FlatTensor.named_data; key, segment_index, tensor_layout. + named_data = flat_tensor.named_data + self.assertEqual(len(named_data), 4) - self.assertEqual(tensors[1].fully_qualified_name, "fqn2") - self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1]) - self.assertEqual(tensors[1].segment_index, 0) - self.assertEqual(tensors[1].offset, 0) + self.assertEqual(named_data[0].key, "fqn1") + self.assertEqual(named_data[0].segment_index, 0) + self.check_tensor_layout( + TEST_NAMED_DATA["fqn1"].tensor_layout, named_data[0].tensor_layout + ) - self.assertEqual(tensors[2].fully_qualified_name, "fqn3") - self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2]) - self.assertEqual(tensors[2].segment_index, 0) - self.assertEqual(tensors[2].offset, config.tensor_alignment) + self.assertEqual(named_data[1].key, "fqn2") + self.assertEqual(named_data[1].segment_index, 0) + self.check_tensor_layout( + TEST_NAMED_DATA["fqn2"].tensor_layout, named_data[1].tensor_layout + ) - named_data = flat_tensor.named_data - self.assertEqual(len(named_data), 1) - self.assertEqual(named_data[0].key, "key0") - self.assertEqual(named_data[0].segment_index, 1) + self.assertEqual(named_data[2].key, "fqn3") + self.assertEqual(named_data[2].segment_index, 1) + self.check_tensor_layout( + TEST_NAMED_DATA["fqn3"].tensor_layout, named_data[2].tensor_layout + ) + + self.assertEqual(named_data[3].key, "key0") + self.assertEqual(named_data[3].segment_index, 2) + self.assertEqual(named_data[3].tensor_layout, None) + # Check FlatTensor.segments. segments = flat_tensor.segments - self.assertEqual(len(segments), 2) + self.assertEqual(len(segments), 3) + + # Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.tensor_alignment. self.assertEqual(segments[0].offset, 0) - self.assertEqual(segments[0].size, config.tensor_alignment * 3) + self.assertEqual(segments[0].size, len(TEST_BUFFER[0])) + + # Segment 1 contains fqn3; 32 bytes, aligned to config.tensor_alignment. + self.assertEqual(segments[1].offset, config.tensor_alignment) + self.assertEqual(segments[1].size, len(TEST_BUFFER[1])) + + # Segment 2 contains key0; 17 bytes, aligned to 64. + custom_alignment = math.lcm( + config.segment_alignment, TEST_NAMED_DATA["key0"].alignment + ) self.assertEqual( - segments[1].offset, - aligned_size( - config.tensor_alignment * 3, - math.lcm(config.segment_alignment, TEST_DATA_ENTRY["key0"].alignment), - ), + segments[2].offset, + aligned_size(config.tensor_alignment * 3, custom_alignment), ) - self.assertEqual(segments[1].size, len(TEST_BUFFER[2])) + self.assertEqual(segments[2].size, len(TEST_BUFFER[2])) # Length of serialized_data matches segment_base_offset + segment_data_size. self.assertEqual( @@ -186,31 +191,35 @@ def test_serialize(self) -> None: ] # Tensor: b"\x11" * 4 - t0_start = 0 - t0_len = len(TEST_BUFFER[0]) - t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment) - self.assertEqual(segment_data[t0_start : t0_start + t0_len], TEST_BUFFER[0]) - padding = b"\x00" * (t0_end - t0_len) - self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding) + self.assertEqual( + segment_data[segments[0].offset : segments[0].offset + segments[0].size], + TEST_BUFFER[0], + ) # Tensor: b"\x22" * 32 - t1_start = t0_end - t1_len = len(TEST_BUFFER[1]) - t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment) + padding = b"\x00" * ( + segments[1].offset - (segments[0].offset + segments[0].size) + ) self.assertEqual( - segment_data[t1_start : t1_start + t1_len], - TEST_BUFFER[1], + segment_data[segments[0].offset + segments[0].size : segments[1].offset], + padding, ) - padding = b"\x00" * (t1_end - (t1_len + t1_start)) - self.assertEqual(segment_data[t1_start + t1_len : t1_end], padding) - - # Check length of the segment is expected. self.assertEqual( - segments[0].size, aligned_size(t1_end, config.segment_alignment) + segment_data[segments[1].offset : segments[1].offset + segments[1].size], + TEST_BUFFER[1], ) # Named data: b"\x33" * 17 + padding = b"\x00" * ( + segments[2].offset - (segments[1].offset + segments[1].size) + ) self.assertEqual( - segment_data[segments[1].offset : segments[1].offset + len(TEST_BUFFER[2])], + segment_data[segments[1].offset + segments[1].size : segments[2].offset], + padding, + ) + self.assertEqual( + segment_data[segments[2].offset : segments[2].offset + segments[2].size], TEST_BUFFER[2], ) + + self.assertEqual(segments[2].offset + segments[2].size, len(segment_data)) diff --git a/extension/training/module/state_dict_util.cpp b/extension/training/module/state_dict_util.cpp index 7c742d11c08..aee863854c5 100644 --- a/extension/training/module/state_dict_util.cpp +++ b/extension/training/module/state_dict_util.cpp @@ -27,7 +27,7 @@ load_state_dict(const runtime::NamedDataMap& data_map) { } // get the metadata - auto metadata_res = data_map.get_metadata(key_res.get()); + auto metadata_res = data_map.get_tensor_layout(key_res.get()); if (!metadata_res.ok()) { return metadata_res.error(); } diff --git a/runtime/core/named_data_map.h b/runtime/core/named_data_map.h index 14179d22795..7503f0b2979 100644 --- a/runtime/core/named_data_map.h +++ b/runtime/core/named_data_map.h @@ -31,41 +31,43 @@ class ET_EXPERIMENTAL NamedDataMap { public: virtual ~NamedDataMap() = default; /** - * Get metadata by key. + * Get tensor_layout by key. * * @param key The name of the tensor. - * @return Result containing TensorLayout with tensor metadata. + * @return Result containing TensorLayout. */ - ET_NODISCARD virtual Result get_metadata( - const char* key) const = 0; + ET_NODISCARD virtual Result get_tensor_layout( + executorch::aten::string_view key) const = 0; /** * Get data by key. * * @param key Name of the data. - * @return Result containing a FreeableBuffer with the tensor data. + * @return Result containing a FreeableBuffer. */ ET_NODISCARD virtual Result get_data( - const char* key) const = 0; + executorch::aten::string_view key) const = 0; /** * Loads data corresponding to the key into the provided buffer. * * @param key The name of the data. - * @param size The number of bytes to load. Use `get_metadata` to retrieve the - * size of the data for a given key. + * @param size The number of bytes to load. Use `get_tensor_layout` to + * retrieve the size of the data for a given key. * @param buffer The buffer to load the data into. Must point to at least * `size` bytes of memory. * @returns an Error indicating if the load was successful. */ - ET_NODISCARD virtual Error - load_data_into(const char* key, void* buffer, size_t size) const = 0; + ET_NODISCARD virtual Error load_data_into( + executorch::aten::string_view key, + void* buffer, + size_t size) const = 0; /** * Get the number of keys in the NamedDataMap. * * @return Result containing the number of keys. */ - ET_NODISCARD virtual Result get_num_keys() const = 0; + ET_NODISCARD virtual Result get_num_keys() const = 0; /** * Get the key at the given index. @@ -74,7 +76,7 @@ class ET_EXPERIMENTAL NamedDataMap { * @return Result containing the key at the given index. Note: the returned * pointer is only valid for the lifetime of the DataMap. */ - ET_NODISCARD virtual Result get_key(size_t index) const = 0; + ET_NODISCARD virtual Result get_key(uint32_t index) const = 0; }; } // namespace ET_RUNTIME_NAMESPACE diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index af40f4a7bfd..7f4836a9e76 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -372,7 +372,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) { continue; } Result tensor_layout = - named_data_map->get_metadata(key); + named_data_map->get_tensor_layout(key); if (!tensor_layout.ok()) { ET_LOG(Info, "Failed to get metadata for key %s", key); return tensor_layout.error(); diff --git a/runtime/executor/pte_data_map.cpp b/runtime/executor/pte_data_map.cpp index fd064cb8256..e9b1c3460be 100644 --- a/runtime/executor/pte_data_map.cpp +++ b/runtime/executor/pte_data_map.cpp @@ -27,15 +27,19 @@ namespace internal { ET_NODISCARD executorch::runtime::Result -PteDataMap::get_data(const char* key) const { - for (size_t i = 0; i < named_data_->size(); i++) { +PteDataMap::get_data(executorch::aten::string_view key) const { + for (uint32_t i = 0; i < named_data_->size(); i++) { ET_CHECK_OR_RETURN_ERROR( named_data_->Get(i) != nullptr && named_data_->Get(i)->key() != nullptr, InvalidArgument, - "Searching for key %s: NamedData at index %zu is null", - key, + "Searching for key %.*s: NamedData at index %d is null", + static_cast(key.size()), + key.data(), i); - if (strcmp(named_data_->Get(i)->key()->c_str(), key) == 0) { + if (strncmp( + named_data_->Get(i)->key()->c_str(), + key.data(), + named_data_->Get(i)->key()->size()) == 0) { // Get the segment index. size_t segment_index = named_data_->Get(i)->segment_index(); @@ -43,9 +47,10 @@ PteDataMap::get_data(const char* key) const { ET_CHECK_OR_RETURN_ERROR( segment_index < segments_->size(), InvalidArgument, - "Segment index %zu for key %s is out of range for segments size %u", + "Segment index %zu for key %.*s is out of range for segments size %u", segment_index, - key, + static_cast(key.size()), + key.data(), segments_->size()); size_t segment_offset = segments_->Get(segment_index)->offset(); size_t segment_size = segments_->Get(segment_index)->size(); @@ -59,17 +64,17 @@ PteDataMap::get_data(const char* key) const { return Error::NotFound; } -ET_NODISCARD executorch::runtime::Result PteDataMap::get_num_keys() +ET_NODISCARD executorch::runtime::Result PteDataMap::get_num_keys() const { return named_data_->size(); } ET_NODISCARD executorch::runtime::Result PteDataMap::get_key( - size_t index) const { + uint32_t index) const { ET_CHECK_OR_RETURN_ERROR( index < named_data_->size(), InvalidArgument, - "Index out of range: named_data size is %u, received index %zu", + "Index out of range: named_data size is %u, received index %u", named_data_->size(), index); @@ -77,7 +82,7 @@ ET_NODISCARD executorch::runtime::Result PteDataMap::get_key( named_data_->Get(index) != nullptr && named_data_->Get(index)->key() != nullptr, InvalidArgument, - "NamedData at index %zu is null", + "NamedData at index %u is null", index); return named_data_->Get(index)->key()->c_str(); } diff --git a/runtime/executor/pte_data_map.h b/runtime/executor/pte_data_map.h index b26c0ac42f9..b4b46a6b541 100644 --- a/runtime/executor/pte_data_map.h +++ b/runtime/executor/pte_data_map.h @@ -78,8 +78,8 @@ class PteDataMap final : public NamedDataMap { * tensor-specific metadata. */ ET_NODISCARD - Result get_metadata( - ET_UNUSED const char* key) const override { + Result get_tensor_layout( + ET_UNUSED executorch::aten::string_view key) const override { return Error::NotImplemented; } @@ -91,13 +91,14 @@ class PteDataMap final : public NamedDataMap { * @return error if the key is not present or data cannot be loaded. */ ET_NODISCARD - Result get_data(const char* key) const override; + Result get_data( + executorch::aten::string_view key) const override; /** * The PteDataMap currently does not implement load_into. */ ET_NODISCARD Error load_data_into( - ET_UNUSED const char* key, + ET_UNUSED executorch::aten::string_view key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const override { return Error::NotImplemented; @@ -106,12 +107,12 @@ class PteDataMap final : public NamedDataMap { /** * @returns The number of keys in the map. */ - ET_NODISCARD Result get_num_keys() const override; + ET_NODISCARD Result get_num_keys() const override; /** * @returns The key at the specified index, error if index out of bounds. */ - ET_NODISCARD Result get_key(size_t index) const override; + ET_NODISCARD Result get_key(uint32_t index) const override; // Moveable, to be compatible with Result. PteDataMap(PteDataMap&&) noexcept = default; diff --git a/runtime/executor/tensor_parser_exec_aten.cpp b/runtime/executor/tensor_parser_exec_aten.cpp index aa27bbf929d..45ce16b4e6b 100644 --- a/runtime/executor/tensor_parser_exec_aten.cpp +++ b/runtime/executor/tensor_parser_exec_aten.cpp @@ -195,7 +195,7 @@ ET_NODISCARD Result getTensorDataPtr( // Mutable value. // Look up tensor in named data map. Result tensor_layout_res = - named_data_map->get_metadata(fqn); + named_data_map->get_tensor_layout(fqn); if (!tensor_layout_res.ok()) { return tensor_layout_res.error(); } diff --git a/runtime/executor/test/pte_data_map_test.cpp b/runtime/executor/test/pte_data_map_test.cpp index b5312eb4a88..5b13191aa3f 100644 --- a/runtime/executor/test/pte_data_map_test.cpp +++ b/runtime/executor/test/pte_data_map_test.cpp @@ -125,8 +125,8 @@ TEST_F(PteDataMapTest, UnimplementedMethods) { data_map_loader_.get(), 0, program_->named_data(), program_->segments()); ; - // Check get_metadata is not implemented. - auto result = data_map->get_metadata("sample_key"); + // Check get_tensor_layout is not implemented. + auto result = data_map->get_tensor_layout("sample_key"); EXPECT_EQ(result.error(), Error::NotImplemented); // Check load_data_into is not implemented.