From f16910d1680a594672a241c9dc70562401632fa1 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Thu, 6 Nov 2025 12:51:21 -0800
Subject: [PATCH] [Executorch] Use temp allocator for allocating scratch memory

This allows us to leverage the temp memory allocator; if that allocator is a
caching allocator, this reduces the allocation overhead.

Differential Revision: [D85532076](https://our.internmc.facebook.com/intern/diff/D85532076/)

[ghstack-poisoned]
---
 extension/llm/custom_ops/op_sdpa.cpp    |  6 ++++
 extension/llm/custom_ops/op_sdpa_impl.h | 39 ++++++++++++++++---------
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index c98fa1729fa..72bddce7b5b 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
   // we might consider another appraoch
   if (seq_len >= 768) {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         query,
         key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
         nullopt);
   } else if (seq_len >= 192) {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         query,
         key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
         nullopt);
   } else {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         query,
         key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
   // we might consider another appraoch
   if (seq_len >= 768) {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         q,
         k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
         num_keys_for_causal_attention);
   } else if (seq_len >= 192) {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         q,
         k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
         num_keys_for_causal_attention);
   } else {
     sdpa::impl::cpu_flash_attention(
+        ctx,
         output,
         q,
         k,
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 21acd6130eb..a418992da3f 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
 
 namespace sdpa::impl {
 
+static std::vector<char> scratch_for_quant_dequant_vec;
 struct MaybeQuantizedMatrixData {
   const void* data{nullptr};
   const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
 */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
+    RuntimeContext& ctx,
     Tensor& output,
     const Tensor& query,
     const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
   int64_t size_of_intermediate_precision = sizeof(accum_t);
   int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
       size_of_intermediate_precision;
-  std::vector<char> buf_vec(size_bytes);
-  void* buf = reinterpret_cast<void*>(buf_vec.data());
-  // Need to double check the following
-  size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
-  std::vector<char> buf_reduced_vec(size_bytes);
-  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
-  // at::Tensor buf_reduced = at::empty(
-  //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
-  //     query.options());
+  Result<void*> buff_res = ctx.allocate_temp(size_bytes);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  if (!buff_res.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = reinterpret_cast<void*>(allocated_buf.get());
+  } else {
+    buf = buff_res.get();
+  }
+  void* buf_reduced = nullptr;
   int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
   // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
   // by padding with right number of per thread elements
   constexpr int64_t kAlignment = 32;
   size_per_thread_qdq_vec =
       (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
-  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
+  int64_t size_per_thread_qdq_bytes =
+      size_per_thread_qdq_vec * size_of_intermediate_precision;
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
-  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant =
-      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  std::unique_ptr<char[]> allocated_buf_for_qdq;
+  Result<void*> scratch_for_quant_dequant_res =
+      ctx.allocate_temp(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant;
+  if (!scratch_for_quant_dequant_res.ok()) {
+    allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+  } else {
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+  }
 
   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
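
Note for reviewers: both allocation sites in the hunk above follow the same pattern, so here is a distilled, self-contained sketch of it. `TempResult` and `FakeContext` are hypothetical stand-ins for ExecuTorch's `Result<void*>` and the runtime context's `allocate_temp`, reduced so the example compiles on its own; only the control flow (try the context's temp allocator first, fall back to an owned heap buffer when it fails) mirrors the patch.

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>
#include <optional>

// Hypothetical stand-in for executorch::runtime::Result<void*>: holds a
// pointer on success, nothing on failure (the real type carries an Error).
struct TempResult {
  std::optional<void*> value;
  bool ok() const { return value.has_value(); }
  void* get() const { return *value; }
};

// Hypothetical stand-in for the kernel runtime context: hands out chunks of
// a fixed arena linearly, and fails when no arena was configured or the
// arena is exhausted, like a runtime without a temp allocator installed.
struct FakeContext {
  char* arena = nullptr;
  std::size_t capacity = 0;
  std::size_t used = 0;

  TempResult allocate_temp(std::size_t size) {
    if (arena == nullptr || used + size > capacity) {
      return TempResult{std::nullopt};
    }
    void* p = arena + used;
    used += size;
    return TempResult{p};
  }
};

// The pattern the patch applies twice: prefer the temp allocator, otherwise
// fall back to a heap buffer owned by a unique_ptr passed in by the caller.
void* get_scratch(
    FakeContext& ctx,
    std::size_t size_bytes,
    std::unique_ptr<char[]>& fallback) {
  TempResult res = ctx.allocate_temp(size_bytes);
  if (!res.ok()) {
    fallback = std::make_unique<char[]>(size_bytes);
    return fallback.get();
  }
  return res.get();
}

int main() {
  // Runtime without a temp allocator: the heap fallback kicks in.
  FakeContext no_allocator{};
  std::unique_ptr<char[]> owned;
  void* a = get_scratch(no_allocator, 1024, owned);
  std::printf("fallback heap buffer: %p\n", a);

  // Runtime with a temp allocator: the arena pointer is returned and the
  // unique_ptr stays empty.
  char arena[4096];
  FakeContext with_allocator{arena, sizeof(arena), 0};
  std::unique_ptr<char[]> unused;
  void* b = get_scratch(with_allocator, 1024, unused);
  std::printf("temp allocator buffer: %p\n", b);
  return 0;
}
```

One lifetime property worth checking in review: in the fallback path the `unique_ptr` must outlive every use of the returned pointer; the patch satisfies this by declaring `allocated_buf` and `allocated_buf_for_qdq` in the same scope as the work that uses `buf` and `scratch_for_quant_dequant`.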
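
A side reference on the scratch sizing visible in the hunk context (pre-existing code, not lines this patch changes): the conventional idiom for rounding `x` up to a multiple of a power-of-two alignment `a` is `(x + a - 1) & ~(a - 1)`, equivalently `& -a`; the mask in the context line, `(-(kAlignment - 1))`, is not that idiom. A minimal sketch of the standard form, with `round_up` as a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>

// Round x up to the next multiple of a, where a is a power of two.
// ~(a - 1) clears the low bits, truncating the sum down to a multiple of a.
constexpr std::int64_t round_up(std::int64_t x, std::int64_t a) {
  return (x + a - 1) & ~(a - 1);
}

int main() {
  static_assert(round_up(0, 32) == 0);
  static_assert(round_up(1, 32) == 32);
  static_assert(round_up(32, 32) == 32);
  static_assert(round_up(33, 32) == 64);
  assert(round_up(100, 32) == 128);
  return 0;
}
```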