From bbf2b178f38ea7dba4b1399a13cc38b016092f0d Mon Sep 17 00:00:00 2001 From: Hakan Boyraz Date: Thu, 14 May 2026 19:32:57 -0700 Subject: [PATCH] Gate weights cache on runtime option instead of compile-time macro (#19603) Summary: Replaces the compile-time `#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE` gate in XNNCompiler.cpp with a runtime boolean plumbed from `XnnpackBackendOptions::resolve_weight_cache(context)` through `XNNPACKBackend::init` to `XNNCompiler::compileModel`. This fixes a silent-disable bug: previously, runtime opt-in via `set_option(weight_cache_option_key, true)` was silently a no-op unless the build also set `-c executorch.xnnpack_weights_cache=1`, because the cache pointer handed to `xnn_create_runtime_v4` was hardcoded to nullptr when the macro was undefined. Multimethod LoRA models re-packed the entire backbone for every method load, costing hundreds of MB of resident memory. The runtime path now keys all three cache-relevant code regions (unpacked-data load, cache pointer handoff to xnn_create_runtime_v4, and finalize_for_runtime) on `bool use_weight_cache` resolved per-init from the BackendInitContext. The `Result<std::vector<std::string>>` declaration in compileModel was reshaped to a plain `std::vector<std::string>` because `Result<>` cannot be assigned after construction, and the new runtime branch needs to fill the vector conditionally. 
Reviewed By: GregoryComer Differential Revision: D105123995 --- backends/xnnpack/runtime/XNNCompiler.cpp | 120 ++++++++++---------- backends/xnnpack/runtime/XNNCompiler.h | 3 +- backends/xnnpack/runtime/XNNPACKBackend.cpp | 3 +- 3 files changed, 67 insertions(+), 59 deletions(-) diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 103bdeb6b82..df24dc4ba1f 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -181,7 +181,8 @@ Result getConstantDataPtr( const uint8_t* constant_data_ptr, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache) { + XNNWeightsCache* weights_cache, + bool use_weight_cache) { if (buffer_idx) { if (!constant_data_ptr) { // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC @@ -230,30 +231,30 @@ Result getConstantDataPtr( InvalidProgram, "Named key is null"); const std::string& data_name = constant_data_offset->named_key()->str(); -#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE - Result data_ptr = - weights_cache->load_unpacked_data(data_name); - if (!data_ptr.ok()) { - ET_LOG(Error, "Failed to load weights from cache"); - return data_ptr.error(); - } - return data_ptr.get(); -#else - Result buffer = - named_data_map->get_data(data_name.c_str()); - if (!buffer.ok()) { - ET_LOG( - Error, - "Failed to get constant data for key %s from named_data_map. Error code: %u", - data_name.c_str(), - static_cast(buffer.error())); - return buffer.error(); + if (use_weight_cache) { + Result data_ptr = + weights_cache->load_unpacked_data(data_name); + if (!data_ptr.ok()) { + ET_LOG(Error, "Failed to load weights from cache"); + return data_ptr.error(); + } + return data_ptr.get(); + } else { + Result buffer = + named_data_map->get_data(data_name.c_str()); + if (!buffer.ok()) { + ET_LOG( + Error, + "Failed to get constant data for key %s from named_data_map. 
Error code: %u", + data_name.c_str(), + static_cast(buffer.error())); + return buffer.error(); + } + const uint8_t* data_ptr = + static_cast(buffer.get().data()); + freeable_buffers.push_back(std::move(buffer.get())); + return data_ptr; } - const uint8_t* data_ptr = - static_cast(buffer.get().data()); - freeable_buffers.push_back(std::move(buffer.get())); - return data_ptr; -#endif } } } @@ -267,14 +268,16 @@ Result getConstantDataPtr( const uint8_t* constant_data_ptr, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache) { + XNNWeightsCache* weights_cache, + bool use_weight_cache) { return getConstantDataPtr( tensor_value->constant_buffer_idx(), flatbuffer_graph, constant_data_ptr, named_data_map, freeable_buffers, - weights_cache); + weights_cache, + use_weight_cache); } /** @@ -293,7 +296,8 @@ Error defineTensor( CompileAllocator& allocator, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache) { + XNNWeightsCache* weights_cache, + bool use_weight_cache) { const fb_xnnpack::XNNTensorValue* tensor_value = nullptr; const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr; @@ -347,7 +351,8 @@ Error defineTensor( constant_data_ptr, named_data_map, freeable_buffers, - weights_cache); + weights_cache, + use_weight_cache); if (!buffer_result.ok()) { return buffer_result.error(); } @@ -502,7 +507,8 @@ Error defineTensor( constant_data_ptr, named_data_map, freeable_buffers, - weights_cache); + weights_cache, + use_weight_cache); if (!scale_result.ok()) { return scale_result.error(); } @@ -548,7 +554,8 @@ Error defineTensor( constant_data_ptr, named_data_map, freeable_buffers, - weights_cache); + weights_cache, + use_weight_cache); if (!scale_data_result.ok()) { return scale_data_result.error(); } @@ -1976,7 +1983,8 @@ ET_NODISCARD Error XNNCompiler::compileModel( XNNExecutor* executor, XNNWeightsCache* weights_cache, xnn_workspace_t workspace, - const NamedDataMap* 
named_data_map) { + const NamedDataMap* named_data_map, + bool use_weight_cache) { Result header = XNNHeader::Parse(buffer_pointer, num_bytes); const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; @@ -2086,7 +2094,8 @@ ET_NODISCARD Error XNNCompiler::compileModel( compile_allocator, named_data_map, unpacked_buffers, - weights_cache); + weights_cache, + use_weight_cache); if (err != Error::Ok) { return err; @@ -2108,19 +2117,16 @@ ET_NODISCARD Error XNNCompiler::compileModel( xnn_runtime_t runtime_ptr = nullptr; - // XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache - // just manages the unpacked weights until the runtime is created. -#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE - ET_CHECK_OR_RETURN_ERROR( - unpacked_buffers.size() == 0, - Internal, - "Weight Cache is enabled, which means unpacked buffers should be owned by the cache"); - xnn_weights_cache_t weights_cache_ptr = - weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get() - : nullptr; -#else xnn_weights_cache_t weights_cache_ptr = nullptr; -#endif + if (use_weight_cache) { + ET_CHECK_OR_RETURN_ERROR( + unpacked_buffers.size() == 0, + Internal, + "Weight Cache is enabled, which means unpacked buffers should be owned by the cache"); + weights_cache_ptr = weights_cache->get_num_unpacked_data() > 0 + ? 
weights_cache->get() + : nullptr; + } // NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to // be null @@ -2139,25 +2145,25 @@ ET_NODISCARD Error XNNCompiler::compileModel( "XNN Runtime creation failed with code: %s", xnn_status_to_string(status)); -#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE - auto packed_weights_names = weights_cache->finalize_for_runtime(); - ET_CHECK_OR_RETURN_ERROR( - packed_weights_names.ok(), - Internal, - "Failed to finalize weights cache after creating the xnn runtime") -#else - for (auto& buffer : unpacked_buffers) { - buffer.Free(); + std::vector packed_weights_names; + if (use_weight_cache) { + auto packed_weights_names_result = weights_cache->finalize_for_runtime(); + ET_CHECK_OR_RETURN_ERROR( + packed_weights_names_result.ok(), + Internal, + "Failed to finalize weights cache after creating the xnn runtime"); + packed_weights_names = std::move(packed_weights_names_result.get()); + } else { + for (auto& buffer : unpacked_buffers) { + buffer.Free(); + } } - Result> packed_weights_names = - std::vector(); -#endif err = executor->initialize( // NOLINT: runtime_ptr is non-null runtime_ptr, std::move(input_ids), std::move(output_ids), - std::move(packed_weights_names.get())); + std::move(packed_weights_names)); return err; }; diff --git a/backends/xnnpack/runtime/XNNCompiler.h b/backends/xnnpack/runtime/XNNCompiler.h index bcc87351d7d..639df0438cb 100644 --- a/backends/xnnpack/runtime/XNNCompiler.h +++ b/backends/xnnpack/runtime/XNNCompiler.h @@ -29,7 +29,8 @@ class XNNCompiler { XNNExecutor* executor, XNNWeightsCache* weights_cache, xnn_workspace_t workspace, - const NamedDataMap* named_data_map); + const NamedDataMap* named_data_map, + bool use_weight_cache); }; } // namespace delegate diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index 23a3f4c4b1f..1076e5f2e25 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp 
@@ -110,7 +110,8 @@ class XnnpackBackend final executor, weights_cache_.get(), workspace_ptr, - named_data_map); + named_data_map, + use_weight_cache); // This backend does not need its processed data after compiling the model. processed->Free();