From f05f44ee49f986c896f16d860233993ac66ab37c Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Thu, 11 Apr 2024 04:15:19 -0700 Subject: [PATCH] When no weight cache is provided to XNNPack, create one to share packed weights between operations. PiperOrigin-RevId: 623781016 --- tensorflow/lite/core/c/common.h | 2 + tensorflow/lite/core/interpreter_builder.cc | 3 +- tensorflow/lite/core/subgraph.cc | 6 +- tensorflow/lite/core/subgraph.h | 17 +- tensorflow/lite/delegates/xnnpack/BUILD | 48 +- .../lite/delegates/xnnpack/weight_cache.cc | 491 ++++++++++++ .../lite/delegates/xnnpack/weight_cache.h | 306 ++++++++ .../delegates/xnnpack/weight_cache_schema.fbs | 52 ++ .../xnnpack/weight_cache_schema_generated.h | 422 +++++++++++ .../delegates/xnnpack/weight_cache_test.cc | 708 ++++++++++++++++++ .../delegates/xnnpack/xnnpack_delegate.cc | 52 +- .../lite/delegates/xnnpack/xnnpack_delegate.h | 11 + tensorflow/lite/tflite_with_xnnpack.cc | 4 + tensorflow/opensource_only.files | 1 + 14 files changed, 2108 insertions(+), 15 deletions(-) create mode 100644 tensorflow/lite/delegates/xnnpack/weight_cache.cc create mode 100644 tensorflow/lite/delegates/xnnpack/weight_cache.h create mode 100644 tensorflow/lite/delegates/xnnpack/weight_cache_schema.fbs create mode 100755 tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h create mode 100644 tensorflow/lite/delegates/xnnpack/weight_cache_test.cc diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index ea54be9490ef01..96f19f12336bc4 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -472,6 +472,8 @@ typedef enum TfLiteCustomAllocationFlags { kTfLiteCustomAllocationFlagsSkipAlignCheck = 1, } TfLiteCustomAllocationFlags; +enum { kTfLiteNoBufferIdentifier = SIZE_MAX }; + /// A tensor in the interpreter system which is a wrapper around a buffer of /// data including a dimensionality (or NULL if not currently defined). #ifndef TF_LITE_STATIC_MEMORY diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index 41e62cfd675340..d8c6d181ebdd1a 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -691,7 +691,8 @@ TfLiteStatus InterpreterBuilder::ParseTensors( if (subgraph->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, - buffer_size, allocation_, sparsity) != kTfLiteOk) { + buffer_size, allocation_, sparsity, + /*buffer_identifier=*/tensor->buffer()) != kTfLiteOk) { TF_LITE_REPORT_ERROR(error_reporter_, "Tensor %d is invalidly specified in schema.\n", i); diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 26ba2037342405..ce3622105b1ce5 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1856,7 +1856,8 @@ TfLiteStatus Subgraph::GetNodeAndRegistration( TfLiteStatus Subgraph::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t ndims, const int* dims, TfLiteQuantization quantization, const char* buffer, - size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity) { + size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity, + const size_t buffer_identifier) { // Ensure quantization cleanup on failure. 
ScopedTfLiteQuantization scoped_quantization(&quantization); ScopedTfLiteSparsity scoped_sparsity(sparsity); @@ -1904,6 +1905,9 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( tensor.quantization = *scoped_quantization.release(); tensor.sparsity = scoped_sparsity.release(); } + if (buffer_identifier != kTfLiteNoBufferIdentifier) { + tensor_buffer_identifiers_[tensor_index] = buffer_identifier; + } return kTfLiteOk; } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 5940bfbb232ca3..281ac04adc2096 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -132,16 +133,18 @@ class Subgraph { int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr, - TfLiteSparsity* sparsity = nullptr) { + TfLiteSparsity* sparsity = nullptr, + size_t buffer_identifier = kTfLiteNoBufferIdentifier) { return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), dims.data(), quantization, buffer, bytes, - allocation, sparsity); + allocation, sparsity, buffer_identifier); } TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t ndims, const int* dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr, - TfLiteSparsity* sparsity = nullptr); + TfLiteSparsity* sparsity = nullptr, + size_t buffer_identifier = kTfLiteNoBufferIdentifier); // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size @@ -589,6 +592,10 @@ class Subgraph { // Returns true if the subgraph has been fully delegated. bool IsFullyDelegated() const; + const std::unordered_map& GetTensorBufferIdentifiers() { + return tensor_buffer_identifiers_; + } + private: #ifndef DOXYGEN_SKIP friend class tflite::impl::InterpreterBuilder; @@ -1153,6 +1160,10 @@ class Subgraph { /// The allocator used for holding memory of the model. Note that this will /// be null if the client provides a tflite::Model directly. const Allocation* allocation_ = nullptr; + + // Maps tensor constant buffers used in the subgraph to a model-wide + // identifiers. 
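+  // The identifier used here is the index of the tensor's backing buffer in the
+  // TFLite model, as passed down by InterpreterBuilder::ParseTensors through
+  // `buffer_identifier`.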
+ std::unordered_map tensor_buffer_identifiers_; }; } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index c4f748280d70ec..ae8e9bc9dacdf5 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -1,3 +1,4 @@ +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist", "tflite_portable_test_suite_combined") @@ -246,11 +247,7 @@ cc_library( linkstatic = True, deps = [ ":quantization_util", - ":tflite_with_xnnpack_dynamic_fully_connected", - ":tflite_with_xnnpack_logging", - ":tflite_with_xnnpack_qs8", - ":tflite_with_xnnpack_qu8", - ":tflite_with_xnnpack_transient_indirection_buffer", + ":weight_cache", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:c_api_types", @@ -267,7 +264,6 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK", "@XNNPACK//:experiments_config", - "@XNNPACK//:logging", ], ) @@ -289,6 +285,7 @@ cc_library( linkstatic = True, deps = [ ":quantization_util", + ":weight_cache", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:c_api_types", @@ -305,7 +302,6 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK//:XNNPACK_test_mode", "@XNNPACK//:experiments_config", - "@XNNPACK//:logging", ], ) @@ -323,6 +319,30 @@ cc_library( ], ) +flatbuffer_cc_library( + name = "weight_cache_schema", + srcs = ["weight_cache_schema.fbs"], + compatible_with = get_compatible_with_portable(), + flatc_args = [ + "--gen-mutable", + "--gen-object-api", + ], +) + +cc_library( + name = "weight_cache", + srcs = ["weight_cache.cc"], + hdrs = ["weight_cache.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":weight_cache_schema", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:common", + "@XNNPACK", + "@flatbuffers//:runtime_cc", + ], +) + ################################ Tester classes ################################ cc_library( @@ -2828,4 +2848,18 @@ cc_test( ], ) +cc_test( + name = "weight_cache_test", + srcs = ["weight_cache_test.cc"], + deps = [ + ":test_main", + ":weight_cache", + ":weight_cache_schema", + "//tensorflow/lite/c:common", + "@XNNPACK", + "@com_google_googletest//:gtest", + "@flatbuffers//:runtime_cc", + ], +) + tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.cc b/tensorflow/lite/delegates/xnnpack/weight_cache.cc new file mode 100644 index 00000000000000..cb71086b9d0f31 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.cc @@ -0,0 +1,491 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/delegates/xnnpack/weight_cache.h" + +#include +#include + +#if defined(_MSC_VER) +#include +#else +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xnnpack.h" // from @XNNPACK +#include "flatbuffers/base.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h" +#include "tensorflow/lite/logger.h" +#include "tensorflow/lite/minimal_logging.h" + +#define XNNPACK_ABORT_CHECK(TEST, ...) \ + if (!(TEST)) { \ + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, __VA_ARGS__); \ + std::abort(); \ + } + +namespace tflite::xnnpack { + +namespace { +constexpr size_t kMinAlignment = 64; + +template +class ScopeGuard { + public: + explicit ScopeGuard(F&& callback) : callback_(std::forward(callback)) {} + ~ScopeGuard() { + if (active_) { + callback_(); + } + } + + void Deactivate() { active_ = false; } + + private: + F callback_; + bool active_ = true; +}; + +template +ScopeGuard(F&&) -> ScopeGuard; + +} // namespace + +void swap(MMapHandle& a, MMapHandle& b) { + using std::swap; + swap(a.size_, b.size_); + swap(a.data_, b.data_); +} + +MMapHandle::~MMapHandle() { UnMap(); } + +MMapHandle::MMapHandle(MMapHandle&& other) { swap(*this, other); } + +MMapHandle& MMapHandle::operator=(MMapHandle&& other) { + swap(*this, other); + return *this; +} + +bool MMapHandle::Map(const char* path) { + this->UnMap(); + + const int fd = open(path, O_RDONLY); + if (fd == -1) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Could not open file to mmap: %s (%s).", strerror(errno), + path) + return false; + } + + const ScopeGuard close_fd_on_return([&fd] { + if (fd >= 0) { + close(fd); + } + }); + + struct stat file_stats; + if (fstat(fd, &file_stats)) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Could not access file stats to get size: %s (%s).", + strerror(errno), path) + return false; + } + + size_ = file_stats.st_size; +#if defined(_MSC_VER) + data_ = new uint8_t[size_]; + { + const uint8_t* data_reader = data; + size_t remaining_bytes = size_; + while (remaining_bytes > 0) { + const ssize_t bytes = read(fd, data_reader, remaining_bytes); + if (bytes == -1) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Could not read file ('%s'): %s.", path, + strerror(errno)) + UnMap(); + return false; + } + remaining_bytes -= bytes; + data_reader += bytes; + } + } +#else + data_ = static_cast( + mmap(/*addr=*/nullptr, size_, PROT_READ, MAP_SHARED, fd, /*offset=*/0)); + if (data_ == MAP_FAILED) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, "Could not mmap file: %s (%s).", + strerror(errno), path) + data_ = nullptr; + size_ = 0; + return false; + } +#endif + + return true; +} + +void MMapHandle::UnMap() { + if (data_) { +#if defined(_MSC_VER) + delete[] data_; +#else + munmap(data_, size_); +#endif + data_ = nullptr; + size_ = 0; + } +} + +void* WeightCacheBuilder::Reserve(size_t size) { + size_t offset = buffer_data_.size(); + const size_t misalign = offset % kMinAlignment; + if (misalign) { + size += kMinAlignment - misalign; + offset += kMinAlignment - misalign; + } + buffer_data_.resize(buffer_data_.size() + size); + return buffer_data_.data() + offset; +} + +bool WeightCacheBuilder::SpanIsWithinBuffer(const void* ptr, + 
uint64_t size) const { + const uintptr_t buf_begin = reinterpret_cast(buffer_data_.data()); + const uintptr_t buf_end = buf_begin + buffer_data_.size(); + const uintptr_t ptr_begin = reinterpret_cast(ptr); + const uintptr_t ptr_end = ptr_begin + size; + return ptr_begin >= buf_begin && ptr_begin <= buf_end && + ptr_end >= buf_begin && ptr_end <= buf_end; +} + +BufferLocation WeightCacheBuilder::Append(PackIdentifier pack_id, + const void* data, uint64_t size) { + const void* append_data = data; + if (!SpanIsWithinBuffer(data, size)) { + void* reserved_data = Reserve(size); + std::memcpy(reserved_data, data, size); + append_data = reserved_data; + } + BufferLocation loc{.offset = reinterpret_cast(append_data) - + reinterpret_cast(buffer_data_.data()), + .size = size}; + schema_.buffers.push_back(std::make_unique( + cache::schema::BufferT{.packing_algorithm_id = pack_id.pack_algorithm_id, + .weights_id = pack_id.weights_id, + .bias_id = pack_id.bias_id, + .offset = loc.offset, + .size = loc.size})); + return loc; +} + +bool WeightCacheBuilder::ShouldWrite() const { return !buffer_data_.empty(); } + +namespace { + +bool WriteData(const int fd, const uint8_t* data, size_t size, + const char* const file_path, const char* step_description) { + for (size_t bytes = 0; bytes < size;) { + const ssize_t written_bytes = write(fd, data + bytes, size - bytes); + if (written_bytes == -1) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Cache file write incomplete (%s). %s: %s", file_path, + step_description, strerror(errno)) + } + bytes += written_bytes; + } + + return true; +} + +} // namespace + +bool WeightCacheBuilder::Write(const char* path) { + const int fd = open(path, O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (fd == -1) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Could not open cache file ('%s') for writing: %s", path, + strerror(errno)) + return false; + } + + const ScopeGuard close_fd_on_return([&fd] { + if (fd >= 0) { + close(fd); + } + }); + + flatbuffers::FlatBufferBuilder builder; + // Add a fake size and the base offset to mutate them afterwards. Otherwise + // space for it won't be added to the flatbuffer. + schema_.flatbuffer_size = 1; + schema_.base_offset = 1; + FinishPackedWeightsBuffer( + builder, cache::schema::PackedWeights::Pack(builder, &schema_)); + + // Mutate the flatbuffer size and base offset fields. + auto* mutable_packed_weights = + cache::schema::GetMutablePackedWeights(builder.GetBufferPointer()); + mutable_packed_weights->mutate_flatbuffer_size(builder.GetSize()); + const size_t misalign = builder.GetSize() % kMinAlignment; + const size_t alignment_offset = misalign ? kMinAlignment - misalign : 0; + mutable_packed_weights->mutate_base_offset(builder.GetSize() + + alignment_offset); + + // Write the flatbuffer which serves as a header to index the following data. + if (!WriteData(fd, builder.GetBufferPointer(), builder.GetSize(), path, + "Header")) { + return false; + } + // Add some padding so that the cache file can be mmaped and the buffers + // stay aligned correctly. + const uint8_t fill[kMinAlignment] = {0}; + if (!WriteData(fd, fill, alignment_offset, path, "Alignment padding")) { + return false; + } + // Write the actual buffer data. 
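+  // The padding above keeps every buffer offset in the file a multiple of
+  // kMinAlignment, matching the alignment Reserve() used in memory.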
+ if (!WriteData(fd, buffer_data_.data(), buffer_data_.size(), path, + "Buffer data")) { + return false; + } + return true; +} + +MMapWeightCacheProvider::MMapWeightCacheProvider( + MMapWeightCacheProvider&& other) { + *this = std::move(other); +} + +MMapWeightCacheProvider& MMapWeightCacheProvider::operator=( + MMapWeightCacheProvider&& other) { + using std::swap; + swap(cache_provider_, other.cache_provider_); + // The contexts need to keep pointing to their owning object. + cache_provider_.context = this; + other.cache_provider_.context = &other; + swap(file_path_, other.file_path_); + swap(buffer_address_to_identifier_, other.buffer_address_to_identifier_); + swap(cache_key_to_offset_, other.cache_key_to_offset_); + swap(mmap_handle_, other.mmap_handle_); + swap(mmap_buffer_base_offset_, other.mmap_buffer_base_offset_); + swap(builder_, other.builder_); + return *this; +} + +void MMapWeightCacheProvider::SetFilePath(const char* path) { + XNNPACK_ABORT_CHECK( + !IsFinalized(), + "Cannot change the path of a cache that has already been loaded."); + file_path_ = path; +} + +bool MMapWeightCacheProvider::Load(const std::string& path) { + file_path_ = path; + if (mmap_handle_.Map(path.c_str())) { + return Load(std::move(mmap_handle_)); + } + return false; +} + +bool MMapWeightCacheProvider::Load(MMapHandle&& handle) { + swap(mmap_handle_, handle); + // Verifiy the flabuffer part of the file. + const size_t verifier_size = + std::min(mmap_handle_.size(), + static_cast(FLATBUFFERS_MAX_BUFFER_SIZE - 1)); + flatbuffers::Verifier verifier(mmap_handle_.data(), verifier_size); + if (!cache::schema::VerifyPackedWeightsBuffer(verifier)) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Packed weights buffer validation failed."); + return false; + } + + // Load flatbuffer. 
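+  // The flatbuffer acts as a header that indexes the packed buffer data
+  // appended after it in the cache file (see WeightCacheBuilder::Write).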
+ const cache::schema::PackedWeights* packed_weights = + cache::schema::GetPackedWeights(mmap_handle_.data()); + if (!packed_weights) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Could not get packed weights from flatbuffer."); + return false; + } + mmap_buffer_base_offset_ = packed_weights->base_offset(); + if (const auto buffers = packed_weights->buffers(); buffers) { + for (auto* buffer : *buffers) { + if (!buffer) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "Invalid buffer address in buffer list."); + return false; + } + cache_key_to_offset_.emplace( + PackIdentifier{.pack_algorithm_id = buffer->packing_algorithm_id(), + .weights_id = buffer->weights_id(), + .bias_id = buffer->bias_id()}, + BufferLocation{.offset = buffer->offset(), .size = buffer->size()}); + } + } + return true; +} + +void MMapWeightCacheProvider::MapTensorIdentifiers( + const TfLiteTensor* tensors, const size_t size, + const std::unordered_map& tensor_index_to_identifier) { + for (const auto [index, identifier] : tensor_index_to_identifier) { + XNNPACK_ABORT_CHECK(index < size, + "Tensor index corresponds to a non existing tensor."); + buffer_address_to_identifier_[tensors[index].data.data] = identifier; + } +} + +size_t MMapWeightCacheProvider::LookUp( + const xnn_weights_cache_look_up_key* cache_key) { + if (!cache_key) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, "A null cache key was provided."); + return SIZE_MAX; + } + const PackIdentifier pack_id = BuildPackIdentifier(*cache_key); + if (auto offset_it = cache_key_to_offset_.find(pack_id); + offset_it != cache_key_to_offset_.end()) { + return offset_it->second.offset; + } + return SIZE_MAX; +} + +void* MMapWeightCacheProvider::ReserveSpace(size_t size) { + XNNPACK_ABORT_CHECK(!IsFinalized(), + "Cannot reserve space in a finalized cache."); + return builder_.Reserve(size); +} + +size_t MMapWeightCacheProvider::LookUpOrInsert( + const xnn_weights_cache_look_up_key* cache_key, void* ptr, size_t size) { + XNNPACK_ABORT_CHECK(cache_key, "A null cache key was provided."); + + const PackIdentifier pack_id = BuildPackIdentifier(*cache_key); + if (auto offset_it = cache_key_to_offset_.find(pack_id); + offset_it != cache_key_to_offset_.end()) { + return offset_it->second.offset; + } + + XNNPACK_ABORT_CHECK(!IsFinalized(), + "Cannot insert a buffer in a finalized cache."); + + const BufferLocation location = builder_.Append(pack_id, ptr, size); + cache_key_to_offset_.emplace(pack_id, location); + return location.offset; +} + +void* MMapWeightCacheProvider::OffsetToAddr(const size_t offset) { + // While the cache is being built, the buffer could grow and need to be + // reallocated so we cannot ensure pointer stability. + XNNPACK_ABORT_CHECK( + IsFinalized(), + "Cannot get the address of a buffer in a non finalized cache."); + return mmap_handle_.data() + mmap_buffer_base_offset_ + offset; +} + +void MMapWeightCacheProvider::Reset() { + MMapWeightCacheProvider empty; + std::swap(*this, empty); +} + +bool MMapWeightCacheProvider::Finalize() { + if (IsFinalized()) { + return true; + } + if (file_path_.empty()) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, + "File path wasn't set. Cannot finalize the cache."); + return false; + } + std::string file_path = file_path_; + if (!builder_.Write(file_path.c_str())) { + return false; + } + // The buffer mapping needs to be kept. We save it and restore it after the + // Reset. 
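+  // Reset() puts the provider back into a default-constructed state, which
+  // would also drop the map filled in by MapTensorIdentifiers.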
+ std::unordered_map + buffer_address_to_identifier_backup = + std::move(buffer_address_to_identifier_); + Reset(); + buffer_address_to_identifier_ = + std::move(buffer_address_to_identifier_backup); + return Load(file_path); +} + +bool MMapWeightCacheProvider::IsFinalized() const { + return mmap_handle_.IsMapped(); +} + +size_t MMapWeightCacheProvider::look_up( + void* context, const xnn_weights_cache_look_up_key* cache_key) { + return reinterpret_cast(context)->LookUp(cache_key); +} + +void* MMapWeightCacheProvider::reserve_space(void* context, size_t n) { + return reinterpret_cast(context)->ReserveSpace(n); +} + +size_t MMapWeightCacheProvider::look_up_or_insert( + void* context, const xnn_weights_cache_look_up_key* cache_key, void* ptr, + size_t size) { + return reinterpret_cast(context)->LookUpOrInsert( + cache_key, ptr, size); +} + +bool MMapWeightCacheProvider::is_finalized(void* context) { + return reinterpret_cast(context)->IsFinalized(); +} + +void* MMapWeightCacheProvider::offset_to_addr(void* context, size_t offset) { + return reinterpret_cast(context)->OffsetToAddr( + offset); +} + +enum xnn_status MMapWeightCacheProvider::delete_cache(void* context) { + reinterpret_cast(context)->Reset(); + return xnn_status_success; +} + +PackIdentifier MMapWeightCacheProvider::BuildPackIdentifier( + const xnn_weights_cache_look_up_key& key) { + const auto get_buffer_id = [&](const void* buffer) -> size_t { + if (buffer) { + const auto identifier_it = buffer_address_to_identifier_.find(buffer); + XNNPACK_ABORT_CHECK(identifier_it != buffer_address_to_identifier_.end(), + "Unknown constant buffer passed to HashCacheKey."); + return identifier_it->second; + } + return PackIdentifier::kNoId; + }; + return PackIdentifier{.pack_algorithm_id = key.seed, + .weights_id = get_buffer_id(key.kernel), + .bias_id = get_buffer_id(key.bias)}; +} + +} // namespace tflite::xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.h b/tensorflow/lite/delegates/xnnpack/weight_cache.h new file mode 100644 index 00000000000000..8942ea3beb3e9e --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.h @@ -0,0 +1,306 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_ +#define TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_ + +#include +#include +#include +#include +#include +#include + +#include "xnnpack.h" // from @XNNPACK +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h" + +// WARNING: the interface in this file is still under experimentation and WILL +// CHANGE. Do not rely on it. + +// TFLite doesn't use absl hashing utilities. 
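+
+// Rough usage sketch (the file path and argument names below are placeholders,
+// not part of the API):
+//
+//   MMapWeightCacheProvider provider;
+//   provider.SetFilePath("/data/local/tmp/model.xnn_weights");
+//   provider.MapTensorIdentifiers(tensors, num_tensors, index_to_identifier);
+//   // Hand provider.GetCacheProvider() to XNNPack so that weight packing goes
+//   // through the callbacks declared below, then call provider.Finalize()
+//   // before the first run so the packed data is written out and mmapped back.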
+ +namespace tflite { +namespace xnnpack { + +struct PackIdentifier { + enum { kNoId = SIZE_MAX }; + uint64_t pack_algorithm_id = kNoId; + uint64_t weights_id = kNoId; + uint64_t bias_id = kNoId; + + friend bool operator==(const PackIdentifier& a, const PackIdentifier& b) { + return a.pack_algorithm_id == b.pack_algorithm_id && + a.weights_id == b.weights_id && a.bias_id == b.bias_id; + } + + struct Hash { + size_t operator()(const PackIdentifier& p) const { + std::hash hasher; + return hasher(p.pack_algorithm_id) ^ hasher(p.weights_id) ^ + hasher(p.bias_id); + } + }; +}; + +struct BufferLocation { + uint64_t offset; + uint64_t size; +}; + +// Handles MMap allocations lifetime. +// +// When mapped, provides a view over the allocation for convenience. +// +// WARNING: the interface in this file is still under experimentation and WILL +// CHANGE. Do not rely on it. +class MMapHandle { + public: + using value_type = uint8_t; + + MMapHandle() = default; + ~MMapHandle(); + MMapHandle(const MMapHandle&) = delete; + MMapHandle& operator=(const MMapHandle&) = delete; + MMapHandle(MMapHandle&&); + MMapHandle& operator=(MMapHandle&&); + + // Maps the file at the given path. + [[nodiscard /*Mapping a file can fail.*/]] + bool Map(const char* path); + + // Unmaps an existing mapping. + void UnMap(); + + // Returns true if a mapping exists. + bool IsMapped() const { return data_ != nullptr; } + + // Returns the mapping buffer. + uint8_t* data() { return data_; } + + // Returns the mapping buffer. + const uint8_t* data() const { return data_; } + + // Returns the mapping size in bytes. + size_t size() const { return size_; } + + uint8_t* begin() { return data(); } + + const uint8_t* begin() const { return data(); } + + uint8_t* end() { return data() + size(); } + + const uint8_t* end() const { return data() + size(); } + + friend void swap(MMapHandle& a, MMapHandle& b); + + private: + size_t size_ = 0; + uint8_t* data_ = nullptr; +}; + +// Provides storage to write the packed buffers to and saves those to disk. +// +// WARNING: the interface in this file is still under experimentation and WILL +// CHANGE. Do not rely on it. +class WeightCacheBuilder { + public: + // Reserves space in the data buffer for the required size in bytes and + // returns the address of that space. + // + // Sets `last_reserve` to the offset from `buffer_data_`'s start and `n`. + // + // A call to `Reserve` should alway be followed by a call to `Append`. + [[nodiscard /*The pointer to reserved space should be used.*/]] + void* Reserve(size_t size); + + // Adds a buffer to the cache. + // + // The buffer space must have been reserved before using `Reserve`. If not, a + // new call to `Reserve` will be done and the data will be copied over. + [[nodiscard /*The location to the appended data should be saved.*/]] + BufferLocation Append(PackIdentifier pack_id, const void* data, + uint64_t size); + + // Checks whether this builder has data that needs to be written to disk. + bool ShouldWrite() const; + + // Writes the flatbuffer to disk. + [[nodiscard /*Writing the weight cache can fail.*/]] + bool Write(const char* path); + + // Helper for testing. + // + // WARNING: this exposes class implementation details for testing purposes and + // may be removed at any time. 
+ const std::vector& BufferData() const { return buffer_data_; } + + private: + bool SpanIsWithinBuffer(const void* ptr, uint64_t size) const; + + cache::schema::PackedWeightsT schema_; + std::vector buffer_data_; +}; + +// Allows XNNPack to directly load packed weights from disk instead of having to +// repack them every time. +// +// XNNPack kernels do not have knowledge of the TFLite context. The only thing +// they can access is the buffers address. We rely on the fact that the address +// provided by TFLite is unique in order to find out the buffer identifier. +// +// To use the cache you need to: +// +// - Map the buffer addresses to their identifier with `MapTensorIdentifiers` +// - Load the cache file. +// - Finalize the cache before calling the run functions of XNNPack (setup and +// reshape are ok). +class MMapWeightCacheProvider { + public: + MMapWeightCacheProvider() = default; + MMapWeightCacheProvider(const MMapWeightCacheProvider&) = delete; + MMapWeightCacheProvider& operator=(const MMapWeightCacheProvider&) = delete; + MMapWeightCacheProvider(MMapWeightCacheProvider&&); + MMapWeightCacheProvider& operator=(MMapWeightCacheProvider&&); + + // Changes the file path to save the cache to. + // + // WARNING: Can only be called if the cache isn't finalized. + void SetFilePath(const char* file_path); + + // Loads a flatbuffer following the layout in weight_cache_schema.fbs and set + // the file path. + [[nodiscard /*Loading a cache file may fail.*/]] + bool Load(const std::string& path); + + // Loads an MMap allocation following the layout in weight_cache_schema.fbs. + [[nodiscard /*Loading cache data may fail.*/]] + bool Load(MMapHandle&& mmap_handle); + + // Creates the tensor map. + void MapTensorIdentifiers( + const TfLiteTensor* tensors, size_t size, + const std::unordered_map& tensor_index_to_identifier); + + // Returns the offset of the buffer identified by `cache_key`. + // + // If the buffer isn't found, return SIZE_MAX. + [[nodiscard]] + size_t LookUp(const xnn_weights_cache_look_up_key* cache_key); + + // Reserves space for a buffer of given size and returns a pointer to it. + // + // The buffer data should be filled and `LookUpOrInsert` should be immediately + // called. + [[nodiscard]] + void* ReserveSpace(size_t size); + + // Returns the offset of the buffer identified by `cache_key`. If the lookup + // fails, inserts the span `[ptr, ptr+size)`. + // + // This should be called after ReserveSpace and `ptr` should be the result of + // that call with the given `size`. + // + // WARNING: The cache key cannot be null. + [[nodiscard]] + size_t LookUpOrInsert(const xnn_weights_cache_look_up_key* cache_key, + void* ptr, size_t size); + + // Gets the pointer to the buffer at the given offset. + // + // WARNING: This requires the buffer to be finalized. + // WARNING: This does not check the validity of the passed offset. + void* OffsetToAddr(size_t offset); + + // Resets the weight cache provider as if it had been default constructed. + void Reset(); + + // Ensures that the cache is ready. + // + // If the cache file already exists, this is a no-op. Otherwise, this writes + // the file to disk and reloads it. + [[nodiscard /*Writing the cache file may fail.*/]] + bool Finalize(); + + // Checks whether the cache is ready to be used. + bool IsFinalized() const; + + // Returns true if any weights have been added to the underlying builder. + bool IsBuilding() const { return !IsFinalized() && !file_path_.empty(); }; + + // Returns true if a file is mapped or a file path is set. 
+ bool IsActive() const { return IsFinalized() || !file_path_.empty(); }; + + // Returns the cache provider expected by XNNPack. + xnn_weights_cache_provider& GetCacheProvider() { return cache_provider_; } + + // C interface: `xnn_weights_cache_provider` callback. + static size_t look_up(void* context, + const xnn_weights_cache_look_up_key* cache_key); + + // C interface: `xnn_weights_cache_provider` callback. + static void* reserve_space(void* context, size_t n); + + // C interface: `xnn_weights_cache_provider` callback. + static size_t look_up_or_insert( + void* context, const xnn_weights_cache_look_up_key* cache_key, void* ptr, + size_t size); + + // C interface: `xnn_weights_cache_provider` callback. + static bool is_finalized(void* context); + + // C interface: `xnn_weights_cache_provider` callback. + static void* offset_to_addr(void* context, size_t offset); + + // C interface: `xnn_weights_cache_provider` callback. + static enum xnn_status delete_cache(void* context); + + private: + // Hashes a cache key to lookup in `cache_key_to_identifier_`. + PackIdentifier BuildPackIdentifier(const xnn_weights_cache_look_up_key& key); + + // Cache provider implementation for XNNPack. + xnn_weights_cache_provider cache_provider_{ + .context = this, + .look_up = MMapWeightCacheProvider::look_up, + .reserve_space = MMapWeightCacheProvider::reserve_space, + .look_up_or_insert = MMapWeightCacheProvider::look_up_or_insert, + .is_finalized = MMapWeightCacheProvider::is_finalized, + .offset_to_addr = MMapWeightCacheProvider::offset_to_addr, + .delete_cache = MMapWeightCacheProvider::delete_cache}; + + // Path to the cache file. + std::string file_path_; + + // Maps buffer addresses to buffer identifiers. + std::unordered_map buffer_address_to_identifier_; + + // Maps cache request hashes to the buffer identifier. + std::unordered_multimap + cache_key_to_offset_; + + // MMap allocation handler. + MMapHandle mmap_handle_; + + // The offset to the first buffer data in the MMap allocation. + size_t mmap_buffer_base_offset_; + + // Used to build the cache. + WeightCacheBuilder builder_; +}; + +} // namespace xnnpack +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_ diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache_schema.fbs b/tensorflow/lite/delegates/xnnpack/weight_cache_schema.fbs new file mode 100644 index 00000000000000..0658054f21c07e --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/weight_cache_schema.fbs @@ -0,0 +1,52 @@ +// Copyright 2024 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This is a list of buffers with identifiers, to host the CPU-specific cache on disk. +namespace tflite.xnnpack.cache.schema; + +// Schema version. +file_identifier "V001"; +// File extension of written files. +file_extension "xnn_weights"; + +table Buffer { + // To uniquely identify a packed buffer we need to keep track of the packing + // algorithm and of the buffers that were used to generate it. 
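+  // `weights_id` and `bias_id` below are the model-wide buffer identifiers the
+  // interpreter records for the corresponding constant tensors; the packing
+  // algorithm id is the seed provided by XNNPack in the look-up key.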
+ packing_algorithm_id: uint64; + weights_id: uint64; + bias_id: uint64; + + /// The buffer data is appended after the flatbuffer to bypass 2GB file size + /// limitation. The offset is calculated relative to the base offset. + /// (i.e. beginning of the file + base_offset). + offset: uint64; + + /// Size of the buffer in bytes. + size: uint64; +} + +table PackedWeights { + /// A list of buffers. + buffers: [Buffer]; + + /// The serialized file is `flatbuffer_size` of bytes representing + /// `NamedBuffers` appended with a blob representing the buffer content. + flatbuffer_size: uint64; + + /// Defines the base offset for the data appended to the file. That offset + /// may be needed to guarantee data alignment. + base_offset:uint64; +} + +root_type PackedWeights; diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h b/tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h new file mode 100755 index 00000000000000..fa5d30a4cdae65 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h @@ -0,0 +1,422 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_WEIGHTCACHESCHEMA_TFLITE_XNNPACK_CACHE_SCHEMA_H_ +#define FLATBUFFERS_GENERATED_WEIGHTCACHESCHEMA_TFLITE_XNNPACK_CACHE_SCHEMA_H_ + +#include "flatbuffers/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. 
+static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 25, + "Non-compatible flatbuffers version included"); + +namespace tflite { +namespace xnnpack { +namespace cache { +namespace schema { + +struct Buffer; +struct BufferBuilder; +struct BufferT; + +struct PackedWeights; +struct PackedWeightsBuilder; +struct PackedWeightsT; + +struct BufferT : public ::flatbuffers::NativeTable { + typedef Buffer TableType; + uint64_t packing_algorithm_id = 0; + uint64_t weights_id = 0; + uint64_t bias_id = 0; + uint64_t offset = 0; + uint64_t size = 0; +}; + +struct Buffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BufferT NativeTableType; + typedef BufferBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PACKING_ALGORITHM_ID = 4, + VT_WEIGHTS_ID = 6, + VT_BIAS_ID = 8, + VT_OFFSET = 10, + VT_SIZE = 12 + }; + uint64_t packing_algorithm_id() const { + return GetField(VT_PACKING_ALGORITHM_ID, 0); + } + bool mutate_packing_algorithm_id(uint64_t _packing_algorithm_id = 0) { + return SetField(VT_PACKING_ALGORITHM_ID, _packing_algorithm_id, 0); + } + uint64_t weights_id() const { + return GetField(VT_WEIGHTS_ID, 0); + } + bool mutate_weights_id(uint64_t _weights_id = 0) { + return SetField(VT_WEIGHTS_ID, _weights_id, 0); + } + uint64_t bias_id() const { + return GetField(VT_BIAS_ID, 0); + } + bool mutate_bias_id(uint64_t _bias_id = 0) { + return SetField(VT_BIAS_ID, _bias_id, 0); + } + /// The buffer data is appended after the flatbuffer to bypass 2GB file size + /// limitation. The offset is calculated relative to the base offset. + /// (i.e. beginning of the file + base_offset). + uint64_t offset() const { + return GetField(VT_OFFSET, 0); + } + bool mutate_offset(uint64_t _offset = 0) { + return SetField(VT_OFFSET, _offset, 0); + } + /// Size of the buffer in bytes. 
+ uint64_t size() const { + return GetField(VT_SIZE, 0); + } + bool mutate_size(uint64_t _size = 0) { + return SetField(VT_SIZE, _size, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PACKING_ALGORITHM_ID, 8) && + VerifyField(verifier, VT_WEIGHTS_ID, 8) && + VerifyField(verifier, VT_BIAS_ID, 8) && + VerifyField(verifier, VT_OFFSET, 8) && + VerifyField(verifier, VT_SIZE, 8) && + verifier.EndTable(); + } + BufferT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BufferBuilder { + typedef Buffer Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_packing_algorithm_id(uint64_t packing_algorithm_id) { + fbb_.AddElement(Buffer::VT_PACKING_ALGORITHM_ID, packing_algorithm_id, 0); + } + void add_weights_id(uint64_t weights_id) { + fbb_.AddElement(Buffer::VT_WEIGHTS_ID, weights_id, 0); + } + void add_bias_id(uint64_t bias_id) { + fbb_.AddElement(Buffer::VT_BIAS_ID, bias_id, 0); + } + void add_offset(uint64_t offset) { + fbb_.AddElement(Buffer::VT_OFFSET, offset, 0); + } + void add_size(uint64_t size) { + fbb_.AddElement(Buffer::VT_SIZE, size, 0); + } + explicit BufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBuffer( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint64_t packing_algorithm_id = 0, + uint64_t weights_id = 0, + uint64_t bias_id = 0, + uint64_t offset = 0, + uint64_t size = 0) { + BufferBuilder builder_(_fbb); + builder_.add_size(size); + builder_.add_offset(offset); + builder_.add_bias_id(bias_id); + builder_.add_weights_id(weights_id); + builder_.add_packing_algorithm_id(packing_algorithm_id); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PackedWeightsT : public ::flatbuffers::NativeTable { + typedef PackedWeights TableType; + std::vector> buffers{}; + uint64_t flatbuffer_size = 0; + uint64_t base_offset = 0; + PackedWeightsT() = default; + PackedWeightsT(const PackedWeightsT &o); + PackedWeightsT(PackedWeightsT&&) FLATBUFFERS_NOEXCEPT = default; + PackedWeightsT &operator=(PackedWeightsT o) FLATBUFFERS_NOEXCEPT; +}; + +struct PackedWeights FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef PackedWeightsT NativeTableType; + typedef PackedWeightsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BUFFERS = 4, + VT_FLATBUFFER_SIZE = 6, + VT_BASE_OFFSET = 8 + }; + /// A list of buffers. + const ::flatbuffers::Vector<::flatbuffers::Offset> *buffers() const { + return GetPointer> *>(VT_BUFFERS); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_buffers() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_BUFFERS); + } + /// The serialized file is `flatbuffer_size` of bytes representing + /// `NamedBuffers` appended with a blob representing the buffer content. 
+ uint64_t flatbuffer_size() const { + return GetField(VT_FLATBUFFER_SIZE, 0); + } + bool mutate_flatbuffer_size(uint64_t _flatbuffer_size = 0) { + return SetField(VT_FLATBUFFER_SIZE, _flatbuffer_size, 0); + } + /// Defines the base offset for the data appended to the file. That offset + /// may be needed to guarantee data alignment. + uint64_t base_offset() const { + return GetField(VT_BASE_OFFSET, 0); + } + bool mutate_base_offset(uint64_t _base_offset = 0) { + return SetField(VT_BASE_OFFSET, _base_offset, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BUFFERS) && + verifier.VerifyVector(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && + VerifyField(verifier, VT_FLATBUFFER_SIZE, 8) && + VerifyField(verifier, VT_BASE_OFFSET, 8) && + verifier.EndTable(); + } + PackedWeightsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PackedWeightsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PackedWeightsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PackedWeightsBuilder { + typedef PackedWeights Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_buffers(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers) { + fbb_.AddOffset(PackedWeights::VT_BUFFERS, buffers); + } + void add_flatbuffer_size(uint64_t flatbuffer_size) { + fbb_.AddElement(PackedWeights::VT_FLATBUFFER_SIZE, flatbuffer_size, 0); + } + void add_base_offset(uint64_t base_offset) { + fbb_.AddElement(PackedWeights::VT_BASE_OFFSET, base_offset, 0); + } + explicit PackedWeightsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreatePackedWeights( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers = 0, + uint64_t flatbuffer_size = 0, + uint64_t base_offset = 0) { + PackedWeightsBuilder builder_(_fbb); + builder_.add_base_offset(base_offset); + builder_.add_flatbuffer_size(flatbuffer_size); + builder_.add_buffers(buffers); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreatePackedWeightsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset> *buffers = nullptr, + uint64_t flatbuffer_size = 0, + uint64_t base_offset = 0) { + auto buffers__ = buffers ? 
_fbb.CreateVector<::flatbuffers::Offset>(*buffers) : 0; + return tflite::xnnpack::cache::schema::CreatePackedWeights( + _fbb, + buffers__, + flatbuffer_size, + base_offset); +} + +::flatbuffers::Offset CreatePackedWeights(::flatbuffers::FlatBufferBuilder &_fbb, const PackedWeightsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +inline BufferT *Buffer::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BufferT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Buffer::UnPackTo(BufferT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = packing_algorithm_id(); _o->packing_algorithm_id = _e; } + { auto _e = weights_id(); _o->weights_id = _e; } + { auto _e = bias_id(); _o->bias_id = _e; } + { auto _e = offset(); _o->offset = _e; } + { auto _e = size(); _o->size = _e; } +} + +inline ::flatbuffers::Offset Buffer::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBuffer(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _packing_algorithm_id = _o->packing_algorithm_id; + auto _weights_id = _o->weights_id; + auto _bias_id = _o->bias_id; + auto _offset = _o->offset; + auto _size = _o->size; + return tflite::xnnpack::cache::schema::CreateBuffer( + _fbb, + _packing_algorithm_id, + _weights_id, + _bias_id, + _offset, + _size); +} + +inline PackedWeightsT::PackedWeightsT(const PackedWeightsT &o) + : flatbuffer_size(o.flatbuffer_size), + base_offset(o.base_offset) { + buffers.reserve(o.buffers.size()); + for (const auto &buffers_ : o.buffers) { buffers.emplace_back((buffers_) ? 
new tflite::xnnpack::cache::schema::BufferT(*buffers_) : nullptr); } +} + +inline PackedWeightsT &PackedWeightsT::operator=(PackedWeightsT o) FLATBUFFERS_NOEXCEPT { + std::swap(buffers, o.buffers); + std::swap(flatbuffer_size, o.flatbuffer_size); + std::swap(base_offset, o.base_offset); + return *this; +} + +inline PackedWeightsT *PackedWeights::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new PackedWeightsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void PackedWeights::UnPackTo(PackedWeightsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->buffers.resize(0); } } + { auto _e = flatbuffer_size(); _o->flatbuffer_size = _e; } + { auto _e = base_offset(); _o->base_offset = _e; } +} + +inline ::flatbuffers::Offset PackedWeights::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PackedWeightsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePackedWeights(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePackedWeights(::flatbuffers::FlatBufferBuilder &_fbb, const PackedWeightsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const PackedWeightsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _buffers = _o->buffers.size() ? 
_fbb.CreateVector<::flatbuffers::Offset> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _flatbuffer_size = _o->flatbuffer_size; + auto _base_offset = _o->base_offset; + return tflite::xnnpack::cache::schema::CreatePackedWeights( + _fbb, + _buffers, + _flatbuffer_size, + _base_offset); +} + +inline const tflite::xnnpack::cache::schema::PackedWeights *GetPackedWeights(const void *buf) { + return ::flatbuffers::GetRoot(buf); +} + +inline const tflite::xnnpack::cache::schema::PackedWeights *GetSizePrefixedPackedWeights(const void *buf) { + return ::flatbuffers::GetSizePrefixedRoot(buf); +} + +inline PackedWeights *GetMutablePackedWeights(void *buf) { + return ::flatbuffers::GetMutableRoot(buf); +} + +inline tflite::xnnpack::cache::schema::PackedWeights *GetMutableSizePrefixedPackedWeights(void *buf) { + return ::flatbuffers::GetMutableSizePrefixedRoot(buf); +} + +inline const char *PackedWeightsIdentifier() { + return "V001"; +} + +inline bool PackedWeightsBufferHasIdentifier(const void *buf) { + return ::flatbuffers::BufferHasIdentifier( + buf, PackedWeightsIdentifier()); +} + +inline bool SizePrefixedPackedWeightsBufferHasIdentifier(const void *buf) { + return ::flatbuffers::BufferHasIdentifier( + buf, PackedWeightsIdentifier(), true); +} + +inline bool VerifyPackedWeightsBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(PackedWeightsIdentifier()); +} + +inline bool VerifySizePrefixedPackedWeightsBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(PackedWeightsIdentifier()); +} + +inline const char *PackedWeightsExtension() { + return "xnn_weights"; +} + +inline void FinishPackedWeightsBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.Finish(root, PackedWeightsIdentifier()); +} + +inline void FinishSizePrefixedPackedWeightsBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, PackedWeightsIdentifier()); +} + +inline std::unique_ptr UnPackPackedWeights( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetPackedWeights(buf)->UnPack(res)); +} + +inline std::unique_ptr UnPackSizePrefixedPackedWeights( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetSizePrefixedPackedWeights(buf)->UnPack(res)); +} + +} // namespace schema +} // namespace cache +} // namespace xnnpack +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_WEIGHTCACHESCHEMA_TFLITE_XNNPACK_CACHE_SCHEMA_H_ diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc new file mode 100644 index 00000000000000..850c0d82537037 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc @@ -0,0 +1,708 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/xnnpack/weight_cache.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "xnnpack.h" // from @XNNPACK +#include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h" + +namespace tflite::xnnpack { + +std::ostream& operator<<(std::ostream& os, const PackIdentifier& p) { + return os << "PackIdentifier{pack_algo: " << p.pack_algorithm_id + << ", weights_id: " << p.weights_id << ", bias_id: " << p.bias_id + << "}"; +} + +namespace { + +using testing::ElementsAreArray; +using testing::Ge; + +#ifndef XNN_TEST_WEIGHT_CACHE_TEMP_FILE_TEMPATE +#define XNN_TEST_WEIGHT_CACHE_TEMP_FILE_TEMPATE \ + "/tmp/weight_cache_test_file.XXXXXX" +#endif +constexpr const char kTempFileTemplate[] = + XNN_TEST_WEIGHT_CACHE_TEMP_FILE_TEMPATE; + +// Wraps a call to `mkstemp` to create temporary files. +class TempFileDesc { + public: + static constexpr struct AutoClose { + } kAutoCLose; + + TempFileDesc() : fd_(mkstemp(path_.data())) { + if (GetFd() < 0) { + perror("Could not create temporary file"); + } + } + + explicit TempFileDesc(AutoClose) : TempFileDesc() { Close(); } + + TempFileDesc(const TempFileDesc&) = delete; + TempFileDesc& operator=(const TempFileDesc&) = delete; + + friend void swap(TempFileDesc& a, TempFileDesc& b) { + std::swap(a.path_, b.path_); + std::swap(a.fd_, b.fd_); + } + + TempFileDesc(TempFileDesc&& other) { swap(*this, other); } + TempFileDesc& operator=(TempFileDesc&& other) { + swap(*this, other); + return *this; + } + + ~TempFileDesc() { Close(); } + + void Close() { + if (GetFd() >= 0) { + close(fd_); + fd_ = -1; + } + } + + const std::string& GetPath() const { return path_; } + + const char* GetCPath() const { return path_.c_str(); } + + int GetFd() const { return fd_; } + + bool IsOpen() const { return fd_ >= 0; } + + private: + std::string path_ = kTempFileTemplate; + int fd_ = -1; +}; + +TEST(MMapHandleTest, DefaultConstructs) { + MMapHandle handle; + EXPECT_FALSE(handle.IsMapped()); + EXPECT_EQ(handle.data(), nullptr); + EXPECT_EQ(handle.size(), 0); +} + +TEST(MMapHandleTest, MapNonExitxingFileFails) { + // I hope this path doesn't exist... 
+ const char* file_path = "sdbgfd"; + MMapHandle handle; + EXPECT_FALSE(handle.Map(file_path)); +} + +TEST(MMapHandleTest, MapExistingFileWorks) { + using std::size; + + const std::string payload = "This is some data in the file."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + write(tmp_file.GetFd(), payload.c_str(), size(payload)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + EXPECT_TRUE(handle.IsMapped()); + EXPECT_NE(handle.data(), nullptr); + EXPECT_THAT(handle.size(), Ge(size(payload))); + EXPECT_THAT(handle, ElementsAreArray(payload)); + + handle.UnMap(); + EXPECT_FALSE(handle.IsMapped()); + EXPECT_EQ(handle.data(), nullptr); + EXPECT_EQ(handle.size(), 0); +} + +TEST(MMapHandleTest, MoveConstructs) { + const std::string payload = "This is some data in the file."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + write(tmp_file.GetFd(), payload.c_str(), size(payload)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + MMapHandle handle2(std::move(handle)); + + // We are checking that the moved from handle has lost control over the data. + // NOLINTBEGIN(bugprone-use-after-move) + EXPECT_FALSE(handle.IsMapped()); + EXPECT_EQ(handle.data(), nullptr); + EXPECT_EQ(handle.size(), 0); + // NOLINTEND(bugprone-use-after-move) + + EXPECT_TRUE(handle2.IsMapped()); + EXPECT_NE(handle2.data(), nullptr); + EXPECT_THAT(handle2.size(), Ge(size(payload))); + EXPECT_THAT(handle2, ElementsAreArray(payload)); +} + +TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { + using std::size; + + const std::string payload = "This is some data in the file."; + const PackIdentifier dummy_id{1, 2, 3}; + + WeightCacheBuilder builder; + + const size_t payload_size = size(payload); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload.c_str(), payload_size); + auto loc = builder.Append(dummy_id, buffer, payload_size); + + EXPECT_EQ(loc.size, payload_size); + EXPECT_EQ(builder.BufferData().size(), payload_size); + EXPECT_TRUE(builder.ShouldWrite()); + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + tmp_file.Close(); + + ASSERT_TRUE(builder.Write(tmp_file.GetCPath())); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + const cache::schema::PackedWeights* const packed_weights = + cache::schema::GetPackedWeights(handle.data()); + ASSERT_NE(packed_weights, nullptr); + EXPECT_LE(packed_weights->flatbuffer_size(), size(handle) - size(payload)); + ASSERT_NE(packed_weights->buffers(), nullptr); + ASSERT_EQ(packed_weights->buffers()->size(), 1); + ASSERT_NE(packed_weights->buffers()->Get(0), nullptr); + ASSERT_EQ(packed_weights->buffers()->Get(0)->size(), size(payload)); + EXPECT_EQ(packed_weights->buffers()->Get(0)->offset(), 0); + ASSERT_EQ(packed_weights->buffers()->Get(0)->packing_algorithm_id(), + dummy_id.pack_algorithm_id); + ASSERT_EQ(packed_weights->buffers()->Get(0)->weights_id(), + dummy_id.weights_id); + ASSERT_EQ(packed_weights->buffers()->Get(0)->bias_id(), dummy_id.bias_id); + + flatbuffers::Verifier verifier(handle.data(), handle.size()); + EXPECT_TRUE(cache::schema::VerifyPackedWeightsBuffer(verifier)) + << packed_weights->flatbuffer_size() << " " << handle.size() << " " + << packed_weights->buffers()->size() << "\n" + << tmp_file.GetPath(); +} + +TEST(WeightCacheBuilderTest, AppendWithoutReserveWriteWorks) { + using std::size; + + const std::string payload = "This is some data in the file."; + const PackIdentifier 
dummy_id{1, 2, 3};
+
+  WeightCacheBuilder builder;
+
+  const size_t payload_size = size(payload);
+  auto loc = builder.Append(dummy_id, payload.c_str(), payload_size);
+
+  EXPECT_EQ(loc.size, payload_size);
+  EXPECT_EQ(builder.BufferData().size(), payload_size);
+  EXPECT_TRUE(builder.ShouldWrite());
+
+  TempFileDesc tmp_file;
+  ASSERT_TRUE(tmp_file.IsOpen());
+  tmp_file.Close();
+
+  ASSERT_TRUE(builder.Write(tmp_file.GetCPath()));
+
+  MMapHandle handle;
+  ASSERT_TRUE(handle.Map(tmp_file.GetCPath()));
+
+  const cache::schema::PackedWeights* const packed_weights =
+      cache::schema::GetPackedWeights(handle.data());
+  ASSERT_NE(packed_weights, nullptr);
+  EXPECT_LE(packed_weights->flatbuffer_size(), size(handle) - size(payload));
+  ASSERT_NE(packed_weights->buffers(), nullptr);
+  ASSERT_EQ(packed_weights->buffers()->size(), 1);
+  ASSERT_NE(packed_weights->buffers()->Get(0), nullptr);
+  ASSERT_EQ(packed_weights->buffers()->Get(0)->size(), size(payload));
+  EXPECT_EQ(packed_weights->buffers()->Get(0)->offset(), 0);
+  ASSERT_EQ(packed_weights->buffers()->Get(0)->packing_algorithm_id(),
+            dummy_id.pack_algorithm_id);
+  ASSERT_EQ(packed_weights->buffers()->Get(0)->weights_id(),
+            dummy_id.weights_id);
+  ASSERT_EQ(packed_weights->buffers()->Get(0)->bias_id(), dummy_id.bias_id);
+
+  flatbuffers::Verifier verifier(handle.data(), handle.size());
+  EXPECT_TRUE(cache::schema::VerifyPackedWeightsBuffer(verifier))
+      << packed_weights->flatbuffer_size() << " " << handle.size() << " "
+      << packed_weights->buffers()->size() << "\n"
+      << tmp_file.GetPath();
+}
+
+TEST(WeightCacheBuilderTest, NonExistingPathFails) {
+  using std::size;
+
+  const std::string payload = "This is some data in the file.";
+  const PackIdentifier dummy_id{1, 2, 3};
+
+  WeightCacheBuilder builder;
+
+  const size_t payload_size = size(payload);
+  auto loc = builder.Append(dummy_id, payload.c_str(), payload_size);
+
+  EXPECT_EQ(loc.size, payload_size);
+  EXPECT_EQ(builder.BufferData().size(), payload_size);
+  EXPECT_TRUE(builder.ShouldWrite());
+
+  EXPECT_FALSE(builder.Write(""));
+  EXPECT_FALSE(builder.Write("/selktjdsljf"));
+}
+
+struct FakeContext {
+  // Adds a new tensor and its backing buffer to the context.
+  //
+  // The tensor `data` will not be set until `FinalizeTensors` is called.
+  void AddTensor(int buffer_identifier, size_t size) {
+    buffers.emplace_back(size, buffer_identifier);
+    tensors.push_back({});
+    tensors.back().allocation_type = kTfLiteMmapRo;
+    tensor_buffer_identifiers[tensors.size() - 1] = buffer_identifier;
+  }
+
+  // Updates the tensor data mappings.
+  //
+  // This needs to be called every time the context `tensors` list is
+  // reallocated (mainly because of insertions).
+  void FinalizeTensors() {
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      tensors[i].data.data = buffers[i].data();
+      tensors[i].bytes = buffers[i].size();
+    }
+  }
+
+  // Creates a look up key for the XNNPack weight provider C interface.
+  xnn_weights_cache_look_up_key LookUpKey(const uint32_t algorithm_seed,
+                                          const int weights_index) const {
+    return {.seed = algorithm_seed,
+            .kernel = buffers[weights_index].data(),
+            .bias = nullptr};
+  }
+
+  // Creates a look up key for the XNNPack weight provider C interface.
+  xnn_weights_cache_look_up_key LookUpKey(const uint32_t algorithm_seed,
+                                          const int weights_index,
+                                          const int bias_index) const {
+    return {.seed = algorithm_seed,
+            .kernel = buffers[weights_index].data(),
+            .bias = buffers[bias_index].data()};
+  }
+
+  // Helps create fake packed data.
+  void AddTensorToPack(std::vector<char>& pack_buffer, int index) {
+    const std::vector<char>& buffer = buffers[index];
+    pack_buffer.resize(std::max(size(pack_buffer), size(buffer)));
+    for (size_t i = 0; i < size(buffer); ++i) {
+      pack_buffer[i] ^= buffer[i];
+    }
+  }
+
+  // Packs the referenced tensors into one buffer.
+  //
+  // Returns the pack id to retrieve the packed reference data from
+  // `packed_buffers`.
+  template <class... Ids>
+  PackIdentifier PackTensors(xnn_weights_cache_t weight_cache,
+                             const uint32_t algorithm_seed,
+                             const Ids... tensor_indices) {
+    // Create fake packed data and save the result for later lookup tests.
+
+    PackIdentifier pack_id{algorithm_seed,
+                           tensor_buffer_identifiers[tensor_indices]...};
+    PackedBuffer& packed =
+        packed_buffers.emplace(pack_id, PackedBuffer{})->second;
+    (AddTensorToPack(packed.buffer, tensor_indices), ...);
+
+    // Add the packed buffer to the XNNPack cache. Normally you would pack
+    // in place in the reserved space.
+    xnn_weights_cache_look_up_key look_up_key =
+        LookUpKey(algorithm_seed, tensor_indices...);
+    packed.offset = weight_cache->look_up_or_insert(
+        weight_cache->context, &look_up_key, packed.buffer.data(),
+        packed.buffer.size());
+    return pack_id;
+  }
+
+  struct PackedBuffer {
+    size_t offset;
+    std::vector<char> buffer;
+  };
+
+  std::vector<TfLiteTensor> tensors;
+  std::vector<std::vector<char>> buffers;
+  std::unordered_multimap<PackIdentifier, PackedBuffer, PackIdentifier::Hash>
+      packed_buffers;
+  std::unordered_map<size_t, size_t> tensor_buffer_identifiers;
+};
+
+struct BuildMMapWeightCacheProviderTest : testing::Test {
+  enum { kAlgoSeed1, kAlgoSeed2, kAlgoSeed3 };
+  enum { kBufferId1, kBufferId2, kBufferId3, kBufferId4 };
+
+  void SetUp() override {
+    AddTensors();
+    EndSetup();
+  }
+
+  void AddTensors() {
+    ctx.AddTensor(/*buffer_identifier=*/kBufferId1, /*size=*/12);
+    ctx.AddTensor(/*buffer_identifier=*/kBufferId2, /*size=*/43);
+    ctx.AddTensor(/*buffer_identifier=*/kBufferId3, /*size=*/64);
+    ctx.AddTensor(/*buffer_identifier=*/kBufferId4, /*size=*/8);
+  }
+
+  void EndSetup() {
+    ctx.FinalizeTensors();
+    cache_provider.MapTensorIdentifiers(ctx.tensors.data(), ctx.tensors.size(),
+                                        ctx.tensor_buffer_identifiers);
+  }
+
+  FakeContext ctx;
+  MMapWeightCacheProvider cache_provider;
+};
+
+TEST_F(BuildMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) {
+  xnn_weights_cache_look_up_key look_up_key{};
+  EXPECT_EQ(cache_provider.LookUp(&look_up_key), SIZE_MAX);
+}
+
+TEST_F(BuildMMapWeightCacheProviderTest, LookUpSucceeds) {
+  enum { kWeightIndex, kBiasIndex };
+  const auto pack_id = ctx.PackTensors(&cache_provider.GetCacheProvider(),
+                                       kAlgoSeed1, kWeightIndex, kBiasIndex);
+  const xnn_weights_cache_look_up_key look_up_key =
+      ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex);
+
+  EXPECT_EQ(cache_provider.LookUp(&look_up_key),
+            ctx.packed_buffers.find(pack_id)->second.offset);
+}
+
+TEST_F(BuildMMapWeightCacheProviderTest,
+       DifferentAlgoSeedsSameTensorsDontConflict) {
+  enum { kWeightIndex, kBiasIndex };
+  const auto pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(),
+                                         kAlgoSeed1, kWeightIndex, kBiasIndex);
+  const auto pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(),
+                                         kAlgoSeed2, kWeightIndex, kBiasIndex);
+
+  const xnn_weights_cache_look_up_key look_up_key_1 =
+      ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex);
+  const xnn_weights_cache_look_up_key look_up_key_2 =
+      ctx.LookUpKey(kAlgoSeed2, kWeightIndex, kBiasIndex);
+
+  EXPECT_EQ(cache_provider.LookUp(&look_up_key_1),
+            ctx.packed_buffers.find(pack_id_1)->second.offset);
+  EXPECT_EQ(cache_provider.LookUp(&look_up_key_2),
+            
ctx.packed_buffers.find(pack_id_2)->second.offset); + EXPECT_NE(cache_provider.LookUp(&look_up_key_1), + cache_provider.LookUp(&look_up_key_2)); +} + +TEST_F(BuildMMapWeightCacheProviderTest, + SameAlgoSeedDifferentTensorsDontConflict) { + enum { kWeightIndex1, kWeightIndex2, kBiasIndex1, kBiasIndex2 }; + const auto pack_id_1 = + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, + kWeightIndex1, kBiasIndex1); + const auto pack_id_2 = + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, + kWeightIndex2, kBiasIndex1); + const auto pack_id_3 = + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, + kWeightIndex1, kBiasIndex2); + const auto pack_id_4 = + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, + kWeightIndex2, kBiasIndex2); + + const xnn_weights_cache_look_up_key look_up_key_1 = + ctx.LookUpKey(kAlgoSeed1, kWeightIndex1, kBiasIndex1); + const xnn_weights_cache_look_up_key look_up_key_2 = + ctx.LookUpKey(kAlgoSeed1, kWeightIndex2, kBiasIndex1); + const xnn_weights_cache_look_up_key look_up_key_3 = + ctx.LookUpKey(kAlgoSeed1, kWeightIndex1, kBiasIndex2); + const xnn_weights_cache_look_up_key look_up_key_4 = + ctx.LookUpKey(kAlgoSeed1, kWeightIndex2, kBiasIndex2); + + EXPECT_EQ(cache_provider.LookUp(&look_up_key_1), + ctx.packed_buffers.find(pack_id_1)->second.offset); + EXPECT_EQ(cache_provider.LookUp(&look_up_key_2), + ctx.packed_buffers.find(pack_id_2)->second.offset); + EXPECT_EQ(cache_provider.LookUp(&look_up_key_3), + ctx.packed_buffers.find(pack_id_3)->second.offset); + EXPECT_EQ(cache_provider.LookUp(&look_up_key_4), + ctx.packed_buffers.find(pack_id_4)->second.offset); + EXPECT_NE(cache_provider.LookUp(&look_up_key_1), + cache_provider.LookUp(&look_up_key_2)); + EXPECT_NE(cache_provider.LookUp(&look_up_key_1), + cache_provider.LookUp(&look_up_key_3)); + EXPECT_NE(cache_provider.LookUp(&look_up_key_1), + cache_provider.LookUp(&look_up_key_4)) + << pack_id_1 << " " << pack_id_4; + EXPECT_NE(cache_provider.LookUp(&look_up_key_2), + cache_provider.LookUp(&look_up_key_3)); + EXPECT_NE(cache_provider.LookUp(&look_up_key_2), + cache_provider.LookUp(&look_up_key_4)); + EXPECT_NE(cache_provider.LookUp(&look_up_key_3), + cache_provider.LookUp(&look_up_key_4)); +} + +TEST_F(BuildMMapWeightCacheProviderTest, FinalizeWorks) { + enum { kWeightIndex1, kBiasIndex, kWeightIndex2 }; + TempFileDesc tmp_file; + + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, + kBiasIndex); + ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, + kWeightIndex2); + + EXPECT_FALSE(cache_provider.Finalize()); + cache_provider.SetFilePath(tmp_file.GetCPath()); + + EXPECT_TRUE(cache_provider.IsActive()); + EXPECT_TRUE(cache_provider.IsBuilding()); + ASSERT_TRUE(cache_provider.Finalize()); + + ASSERT_TRUE(cache_provider.IsFinalized()); +} + +struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { + enum { kWeightIndex1, kBiasIndex, kWeightIndex2 }; + + void SetUp() override { + BuildMMapWeightCacheProviderTest::SetUp(); + cache_provider.SetFilePath(tmp_file.GetCPath()); + + pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, + kWeightIndex1, kBiasIndex); + pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, + kWeightIndex2); + + ASSERT_TRUE(cache_provider.Finalize()); + ASSERT_TRUE(cache_provider.IsFinalized()); + } + + xnn_weights_cache_look_up_key LookUpKey1() const { + return ctx.LookUpKey(kAlgoSeed1, kWeightIndex1, kBiasIndex); + } + + 
xnn_weights_cache_look_up_key LookUpKey2() const {
+    return ctx.LookUpKey(kAlgoSeed2, kWeightIndex2);
+  }
+
+  TempFileDesc tmp_file;
+  PackIdentifier pack_id_1;
+  PackIdentifier pack_id_2;
+};
+
+TEST_F(LoadMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) {
+  xnn_weights_cache_look_up_key look_up_key{};
+  EXPECT_EQ(cache_provider.LookUp(&look_up_key), SIZE_MAX);
+}
+
+template <class T>
+class LightSpan {
+ public:
+  using value_type = T;
+
+  LightSpan(const void* data, const size_t size)
+      : ptr_(reinterpret_cast<T*>(data)), size_(size) {}
+
+  const T* begin() const { return ptr_; }
+  const T* end() const { return ptr_ + size_; }
+
+ private:
+  T* ptr_;
+  size_t size_;
+};
+
+TEST_F(LoadMMapWeightCacheProviderTest, LookUpSucceeds) {
+  const auto& reference_1 = ctx.packed_buffers.find(pack_id_1)->second;
+  const auto& reference_2 = ctx.packed_buffers.find(pack_id_2)->second;
+
+  const xnn_weights_cache_look_up_key look_up_key_1 = LookUpKey1();
+  const xnn_weights_cache_look_up_key look_up_key_2 = LookUpKey2();
+
+  const uint64_t offset_1 = cache_provider.LookUp(&look_up_key_1);
+  const uint64_t offset_2 = cache_provider.LookUp(&look_up_key_2);
+
+  ASSERT_EQ(offset_1, reference_1.offset);
+  ASSERT_EQ(offset_2, reference_2.offset);
+
+  const void* const addr_1 = cache_provider.OffsetToAddr(offset_1);
+  const void* const addr_2 = cache_provider.OffsetToAddr(offset_2);
+
+  ASSERT_NE(addr_1, nullptr);
+  ASSERT_NE(addr_2, nullptr);
+
+  EXPECT_THAT(LightSpan<const char>(addr_1, reference_1.buffer.size()),
+              ElementsAreArray(reference_1.buffer));
+  EXPECT_THAT(LightSpan<const char>(addr_2, reference_2.buffer.size()),
+              ElementsAreArray(reference_2.buffer));
+}
+
+TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) {
+  using std::size;
+  TempFileDesc temp_fd(TempFileDesc::kAutoClose);
+  const int32_t fake_packing_algo_seed = 0xBA0BAB;
+  const char packed_data_ref_1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  const char packed_data_ref_2[] = {26, 32, 43, 59, 34, 65, 80, 101};
+  auto bytes = [](const auto& array) { return size(array) * sizeof(array[0]); };
+
+  constexpr int kBufferCount = 10;
+  // We are going to feed dummy packed data. We only need a valid pointer
+  // address to map to a buffer identifier.
+  char fake_buffer_pointer[kBufferCount] = {0};
+
+  {  // Build and reload scenario.
+    TfLiteTensor tensors[kBufferCount];
+    std::unordered_map<size_t, size_t> tensor_buffer_identifiers;
+    for (int i = 0; i < kBufferCount; ++i) {
+      tensors[i].data.data = (void*)(fake_buffer_pointer + i);
+      tensor_buffer_identifiers[i] = i + 1;
+    }
+
+    MMapWeightCacheProvider cache_provider;
+    cache_provider.SetFilePath(temp_fd.GetCPath());
+
+    xnn_weights_cache_t cache = &cache_provider.GetCacheProvider();
+    cache_provider.MapTensorIdentifiers(tensors, size(tensors),
+                                        tensor_buffer_identifiers);
+
+    const xnn_weights_cache_look_up_key look_up_key_1{
+        .seed = fake_packing_algo_seed,
+        .kernel = tensors[0].data.data,
+        .bias = tensors[1].data.data};
+
+    // Lookup non-packed tensor.
+    ASSERT_EQ(cache->look_up(cache, &look_up_key_1), SIZE_MAX);
+    // Reserve space, write data and add packed data.
+    void* const reserved_ptr =
+        cache->reserve_space(cache, bytes(packed_data_ref_1));
+    ASSERT_NE(reserved_ptr, nullptr);
+    std::memcpy(reserved_ptr, packed_data_ref_1, bytes(packed_data_ref_1));
+    const size_t build_offset_1 = cache->look_up_or_insert(
+        cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1));
+
+    // Check that a second insertion with the same key returns the same offset.
+    const size_t build_offset_redundant = cache->look_up_or_insert(
+        cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1));
+    EXPECT_EQ(build_offset_1, build_offset_redundant);
+
+    // Lookup newly packed tensor.
+    ASSERT_EQ(cache->look_up(cache, &look_up_key_1), build_offset_1);
+
+    // Add a tensor without reserving before.
+    const xnn_weights_cache_look_up_key look_up_key_2{
+        .seed = fake_packing_algo_seed,
+        .kernel = tensors[2].data.data,
+        .bias = tensors[3].data.data};
+    const size_t build_offset_2 = cache->look_up_or_insert(
+        cache, &look_up_key_2, (void*)packed_data_ref_2,
+        bytes(packed_data_ref_2));
+
+    // Save the cache to disk and reload.
+    ASSERT_TRUE(cache_provider.Finalize());
+
+    ASSERT_TRUE(cache->is_finalized(cache));
+
+    const size_t reload_offset_1 = cache->look_up(cache, &look_up_key_1);
+    ASSERT_EQ(reload_offset_1, build_offset_1);
+
+    const void* const loaded_packed_data_1 =
+        cache->offset_to_addr(cache, reload_offset_1);
+    ASSERT_NE(loaded_packed_data_1, nullptr);
+    EXPECT_THAT(
+        LightSpan<const char>(loaded_packed_data_1, size(packed_data_ref_1)),
+        ElementsAreArray(packed_data_ref_1));
+
+    const size_t reload_offset_2 = cache->look_up(cache, &look_up_key_2);
+    ASSERT_EQ(reload_offset_2, build_offset_2);
+
+    const void* const loaded_packed_data_2 =
+        cache->offset_to_addr(cache, reload_offset_2);
+    ASSERT_NE(loaded_packed_data_2, nullptr);
+    EXPECT_THAT(
+        LightSpan<const char>(loaded_packed_data_2, size(packed_data_ref_2)),
+        ElementsAreArray(packed_data_ref_2));
+  }
+
+  {  // Load existing cache scenario.
+    TfLiteTensor tensors[kBufferCount];
+    std::unordered_map<size_t, size_t> tensor_buffer_identifiers;
+    for (int i = 0; i < kBufferCount; ++i) {
+      tensors[i].data.data = (void*)(fake_buffer_pointer + i);
+      tensor_buffer_identifiers[i] = i + 1;
+    }
+
+    MMapWeightCacheProvider cache_provider;
+    ASSERT_TRUE(cache_provider.Load(temp_fd.GetCPath()));
+
+    xnn_weights_cache_t cache = &cache_provider.GetCacheProvider();
+    cache_provider.MapTensorIdentifiers(tensors, size(tensors),
+                                        tensor_buffer_identifiers);
+
+    const xnn_weights_cache_look_up_key look_up_key_1{
+        .seed = fake_packing_algo_seed,
+        .kernel = tensors[0].data.data,
+        .bias = tensors[1].data.data};
+
+    const xnn_weights_cache_look_up_key look_up_key_2{
+        .seed = fake_packing_algo_seed,
+        .kernel = tensors[2].data.data,
+        .bias = tensors[3].data.data};
+
+    ASSERT_TRUE(cache->is_finalized(cache));
+
+    const size_t offset_1 = cache->look_up(cache, &look_up_key_1);
+    const void* const loaded_packed_data_1 =
+        cache->offset_to_addr(cache, offset_1);
+    ASSERT_NE(loaded_packed_data_1, nullptr);
+    EXPECT_THAT(
+        LightSpan<const char>(loaded_packed_data_1, size(packed_data_ref_1)),
+        ElementsAreArray(packed_data_ref_1));
+
+    const size_t offset_2 = cache->look_up(cache, &look_up_key_2);
+    const void* const loaded_packed_data_2 =
+        cache->offset_to_addr(cache, offset_2);
+    ASSERT_NE(loaded_packed_data_2, nullptr);
+    EXPECT_THAT(
+        LightSpan<const char>(loaded_packed_data_2, size(packed_data_ref_2)),
+        ElementsAreArray(packed_data_ref_2));
+  }
+}
+
+}  // namespace
+}  // namespace tflite::xnnpack
diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
index 76cc6dba209ab9..ab0c5f613a7d0b 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
@@ -38,7 +38,9 @@ limitations under the License.
#include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/xnnpack/quantization_util.h" +#include "tensorflow/lite/delegates/xnnpack/weight_cache.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -525,6 +527,29 @@ class Delegate { options != nullptr ? *options : TfLiteXNNPackDelegateOptionsDefault(); delegate_.flags = GetXNNPackDelegateFlags(); workspace_.reset(workspace); + + // If no weight cache is provided, add one when requested. + if (!options_.weights_cache) { + if (options_.experimental_weight_cache_file_path) { + if (weight_cache_provider_.Load( + options_.experimental_weight_cache_file_path)) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, + "XNNPack weight cache loaded from '%s'.", + options_.experimental_weight_cache_file_path); + } else { + TFLITE_LOG_PROD( + tflite::TFLITE_LOG_INFO, + "XNNPack weight cache not found at '%s', building it.", + options_.experimental_weight_cache_file_path); + } + options_.weights_cache = + reinterpret_cast( + weight_cache_provider_.GetCacheProvider().context); + } else { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, + "XNNPack weight cache not enabled."); + } + } } TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context); @@ -711,6 +736,10 @@ class Delegate { TfLiteXNNPackDelegateOptions options_{}; VariableHolder variable_holder_; std::mutex workspace_mutex_; + + // If no weight cache is provided and a cache is set in the delegate options, + // this will be used as a weight cache. + MMapWeightCacheProvider weight_cache_provider_; }; class Subgraph { @@ -781,6 +810,13 @@ class Subgraph { static Subgraph* Create(TfLiteContext* context, const TfLiteDelegateParams* params, Delegate& delegate) { + // Map tensors identifiers before packing anything. + if (delegate.weight_cache_provider_.IsActive()) { + delegate.weight_cache_provider_.MapTensorIdentifiers( + context->tensors, context->tensors_size, + reinterpret_cast(context->impl_) + ->GetTensorBufferIdentifiers()); + } // Convert subgraph inputs and outputs to hash sets for faster lookup. const std::unordered_set inputs( ¶ms->input_tensors->data[0], @@ -1121,6 +1157,18 @@ class Subgraph { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node, bool enable_subgraph_reshaping, Delegate* delegate) { std::lock_guard lock(delegate->workspace_mutex_); + + // The weights cache needs to be finalized only once. Prepare will be called + // for each partition after all the partitions have been created (therefore + // all the weights are known and have been packed). 
+    if (delegate->weight_cache_provider_.IsActive()) {
+      if (!delegate->weight_cache_provider_.Finalize()) {
+        TF_LITE_KERNEL_LOG(context,
+                           "XNNPack delegate failed to finalize cache.");
+        return kTfLiteError;
+      }
+    }
+
     if (enable_subgraph_reshaping) {
       xnn_status status = xnn_status_invalid_state;
       for (int i = 0; i < inputs_.size(); ++i) {
@@ -1170,10 +1218,8 @@ class Subgraph {
           return kTfLiteError;
        }
      }
-      return kTfLiteOk;
-    } else {
-      return kTfLiteOk;
    }
+    return kTfLiteOk;
  }
 
  TfLiteStatus Invoke(TfLiteContext* context, bool enable_subgraph_reshaping,
diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
index aa11998dc0fc49..16ade69a1967a7 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
@@ -70,6 +70,10 @@ typedef struct {
   bool handle_variable_ops;
   // Enable adaptive optimization for AVX CPUs.
   bool experimental_adaptive_avx_optimization;
+  // Path to the weight cache to load if `weights_cache` is undefined.
+  //
+  // WARNING: this is an experimental flag.
+  const char* experimental_weight_cache_file_path;
 } TfLiteXNNPackDelegateOptions;
 
 // Returns a structure with the default XNNPack delegate options.
@@ -111,11 +115,13 @@ TFL_CAPI_EXPORT void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate);
 // reduce memory bandwidth.
 TFL_CAPI_EXPORT struct TfLiteXNNPackDelegateWeightsCache*
 TfLiteXNNPackDelegateWeightsCacheCreate();
+
 // Creates a new weights cache with a specified initial size that can be shared
 // with multiple delegate instances. The weights cache can hold up to size bytes
 // without growing.
 TFL_CAPI_EXPORT struct TfLiteXNNPackDelegateWeightsCache*
 TfLiteXNNPackDelegateWeightsCacheCreateWithSize(size_t size);
+
 // Soft-finalize a weights cache. Extra space will be left in the weights cache
 // to allow for cache "insertion" only if it is a cache hit. This has memory
 // overhead compared to TfLiteXNNPackDelegateWeightsCacheFinalizeHard. Use this
@@ -124,6 +130,7 @@ TfLiteXNNPackDelegateWeightsCacheCreateWithSize(size_t size);
 // Returns true on success, false on error.
 TFL_CAPI_EXPORT bool TfLiteXNNPackDelegateWeightsCacheFinalizeSoft(
     struct TfLiteXNNPackDelegateWeightsCache* cache);
+
 // Hard-finalize a weights cache, cache is effectively frozen and no more cache
 // operations are allowed. Memory is resized to smallest possible. Use this if
 // the number of interpreter instances using XNNPACK delegate can be fixed and
@@ -132,6 +139,10 @@ TFL_CAPI_EXPORT bool TfLiteXNNPackDelegateWeightsCacheFinalizeSoft(
 // Returns true on success, false on error.
 TFL_CAPI_EXPORT bool TfLiteXNNPackDelegateWeightsCacheFinalizeHard(
     struct TfLiteXNNPackDelegateWeightsCache* cache);
+
+TFL_CAPI_EXPORT bool TfLiteXNNPackDelegateWeightsCacheIsFinalized(
+    struct TfLiteXNNPackDelegateWeightsCache* cache);
+
 // Destroys a weights cache created with
 // `TfLiteXNNPackDelegateWeightsCacheCreate` call.
 TFL_CAPI_EXPORT void TfLiteXNNPackDelegateWeightsCacheDelete(
diff --git a/tensorflow/lite/tflite_with_xnnpack.cc b/tensorflow/lite/tflite_with_xnnpack.cc
index 22e8617ec74e21..d443d404c21f05 100644
--- a/tensorflow/lite/tflite_with_xnnpack.cc
+++ b/tensorflow/lite/tflite_with_xnnpack.cc
@@ -23,6 +23,10 @@ namespace tflite {
 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
 AcquireXNNPACKDelegate() {
   auto opts = TfLiteXNNPackDelegateOptionsDefault();
+#ifdef TFLITE_XNNPACK_DELEGATE_EXPERIMENTAL_WEIGHT_CACHE_FILE_PATH
+  opts.experimental_weight_cache_file_path =
+      TFLITE_XNNPACK_DELEGATE_EXPERIMENTAL_WEIGHT_CACHE_FILE_PATH;
+#endif
   return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
       TfLiteXNNPackDelegateCreate(&opts), TfLiteXNNPackDelegateDelete);
 }
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index b2645a331739e3..b83a48134d02e1 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -101,6 +101,7 @@ tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD:
 tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc:
 tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h:
 tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc:
+tf_staging/tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h:
 tf_staging/tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h:
 tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h:
 tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg.h:
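
Usage sketch (not part of the patch): the new `experimental_weight_cache_file_path` option can be exercised directly from client code. The snippet below is a minimal illustration that only assumes the public entry points visible in this patch (`TfLiteXNNPackDelegateOptionsDefault`, `TfLiteXNNPackDelegateCreate`, `TfLiteXNNPackDelegateDelete`, and the new option field); the helper name and cache path are placeholders.

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

// Creates an XNNPack delegate that reloads (or, on the first run, builds) a
// file-backed weight cache so packed weights are shared between operations.
TfLiteDelegate* CreateXNNPackDelegateWithFileCache(const char* cache_path) {
  TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
  // Leave `weights_cache` null: the delegate then manages its own
  // MMapWeightCacheProvider keyed by this file path.
  opts.experimental_weight_cache_file_path = cache_path;  // e.g. "/tmp/model.xnnpack_cache" (placeholder)
  return TfLiteXNNPackDelegateCreate(&opts);
}

// The returned delegate is released with TfLiteXNNPackDelegateDelete().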