From 540e91e0a6ff0345fb1b3501892e013965caae53 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 24 Sep 2024 21:42:18 -0700
Subject: [PATCH] [Executorch][llama] Add custom_sdpa and use that instead of
 sdpa_with_kv_cache

sdpa_with_kv_cache updates the kv cache inside the op. With a quantized kv
cache, the cache update happens separately: the quantized cache is updated
first and then dequantized. Calling sdpa_with_kv_cache after that copies k and
v into the dequantized cache again, which is unnecessary because the quantized
cache is the real cache. For very large context lengths this adds a
significant amount of data copying.

Subsequent diffs will deprecate the sdpa_with_kv_cache op and decompose it
into a) an update_cache op and b) a custom_sdpa op.

Differential Revision: [D62623241](https://our.internmc.facebook.com/intern/diff/D62623241/)

[ghstack-poisoned]
---
 .../llama2/source_transformation/sdpa.py      |  41 ++---
 extension/llm/custom_ops/op_sdpa.cpp          | 141 +++++++++---------
 extension/llm/custom_ops/op_sdpa.h            |  13 ++
 extension/llm/custom_ops/op_sdpa_aot.cpp      |  62 +++++++-
 .../llm/custom_ops/sdpa_with_kv_cache.py      |  29 ++++
 5 files changed, 195 insertions(+), 91 deletions(-)

diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 263a98a66b3..54184f89dbc 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -46,25 +46,28 @@ def forward(
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
             k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-        # Note that this path will still inplace mutate the k_cache, v_cache.
-        # WHen we are not using quantized kv cache, this will just mutate
-        # the original kv cache.
-        # When we aer using quantized kv cache, this will mutate
-        # k_cache, v_cache that is returned from cache update operation.
-        # This operation just dequantized thee cache and returns that.
-        # Future diffs will optimize this
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            k_cache,
-            v_cache,
-            input_pos[-1].item(),
-            seqlen,
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                k_cache,
+                v_cache,
+                input_pos[0].item(),
+                seqlen,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim)
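For reference, here is a rough eager-mode sketch of what the custom_sdpa call above is expected to compute. It assumes the [bsz, seqlen, n_heads, head_dim] layout used in this file, the same head count for queries and caches (no GQA), no dropout, and the default 1/sqrt(head_dim) scale; custom_sdpa_reference is a hypothetical helper, not part of this diff. With is_causal=True and a nonzero start_pos, query token i attends to cache positions 0..start_pos + i.

```python
import torch
import torch.nn.functional as F


def custom_sdpa_reference(q, k_cache, v_cache, start_pos):
    # q: [bsz, seqlen, n_heads, head_dim]; caches: [bsz, max_seq_len, n_heads, head_dim]
    seqlen = q.size(1)
    cache_len = start_pos + seqlen
    # Only the populated prefix of the cache participates in attention.
    k = k_cache[:, :cache_len]
    v = v_cache[:, :cache_len]
    # Move heads to dim 1 for scaled_dot_product_attention: [bsz, n_heads, seq, head_dim].
    q_t, k_t, v_t = (x.transpose(1, 2) for x in (q, k, v))
    # Causal mask shifted by start_pos: True means "may attend".
    q_pos = start_pos + torch.arange(seqlen, device=q.device)
    kv_pos = torch.arange(cache_len, device=q.device)
    mask = kv_pos[None, :] <= q_pos[:, None]  # [seqlen, cache_len] bool
    out = F.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=mask)
    return out.transpose(1, 2)  # back to [bsz, seqlen, n_heads, head_dim]
```

The quantized path only needs this attention step because QuantizedKVCache.update has already written the new k/v into the cache; sdpa_with_kv_cache would repeat that write on the dequantized copy, which is exactly the copy the commit message wants to avoid.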
diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index b5cb2b55e99..4316a68afa4 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -754,6 +754,74 @@ void update_cache(
   }
 }
 
+} // anonymous namespace
+
+Tensor& flash_attention_kernel_out(
+    RuntimeContext& ctx,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  (void)ctx;
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_flash_attention_args(query, key, value, attn_mask),
+      InvalidArgument,
+      output);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, query.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  auto q_seq_len = query.size(2);
+
+  ET_SWITCH_FLOAT_TYPES(
+      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another appraoch
+        if (q_seq_len >= 768) {
+          cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else if (q_seq_len >= 192) {
+          cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        } else {
+          cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              query,
+              key,
+              value,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale);
+        }
+      });
+  return output;
+}
+
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
@@ -900,74 +968,6 @@ Tensor& custom_sdpa_out(
       });
   return output;
 }
-} // anonymous namespace
-
-Tensor& flash_attention_kernel_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const optional<Tensor>& attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output) {
-  (void)ctx;
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_flash_attention_args(query, key, value, attn_mask),
-      InvalidArgument,
-      output);
-
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(output, query.sizes()) == Error::Ok,
-      InvalidArgument,
-      output);
-
-  auto q_seq_len = query.size(2);
-
-  ET_SWITCH_FLOAT_TYPES(
-      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (q_seq_len >= 768) {
-          cpu_flash_attention<CTYPE, 256, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention<CTYPE, 64, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        } else {
-          cpu_flash_attention<CTYPE, 32, 512>(
-              output,
-              query,
-              key,
-              value,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale);
-        }
-      });
-  return output;
-}
-
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
@@ -1033,3 +1033,8 @@ EXECUTORCH_LIBRARY(
     llama,
     "sdpa_with_kv_cache.out",
     torch::executor::native::sdpa_with_kv_cache_out);
+
+EXECUTORCH_LIBRARY(
+    llama,
+    "custom_sdpa.out",
+    torch::executor::native::custom_sdpa_out);
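The commit message plans to deprecate sdpa_with_kv_cache and rebuild it from a cache-update op plus the custom_sdpa op registered above. A sketch of that decomposition, assuming the [bsz, seqlen, n_heads, head_dim] layout: sdpa_with_kv_cache_decomposed is a hypothetical helper, a plain slice assignment stands in for the future update_cache op, and the custom ops library must be loaded for torch.ops.llama.custom_sdpa to resolve.

```python
import torch


def sdpa_with_kv_cache_decomposed(q, k, v, k_cache, v_cache, start_pos, seq_len):
    # a) cache update: write the new k/v chunk into the caches
    #    (a dedicated update_cache op is planned in follow-up diffs).
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # b) attention over the cache, with no extra k/v copy inside the op.
    return torch.ops.llama.custom_sdpa(
        q,
        k_cache,
        v_cache,
        start_pos,
        None,  # Attention mask
        0.0,  # dropout probability. Ignored by the code
        True,  # is_causal
    )
```

For float caches the result should match torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, start_pos, seq_len, None, 0, True), which performs the same cache write internally before attending.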
diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h
index ce969b013d2..bc2202b9bd8 100644
--- a/extension/llm/custom_ops/op_sdpa.h
+++ b/extension/llm/custom_ops/op_sdpa.h
@@ -31,6 +31,19 @@ Tensor& sdpa_with_kv_cache_out(
     const optional<double> scale,
     Tensor& output);
 
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output);
+
 Tensor& flash_attention_kernel_out(
     KernelRuntimeContext& ctx,
     const Tensor& query,
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index f3674088fd7..c182903aa54 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -82,6 +82,51 @@ at::Tensor sdpa_with_kv_cache_aten(
   return output;
 }
 
+Tensor& custom_sdpa_out_no_context(
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::custom_sdpa_out(
+      context,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output);
+}
+
+at::Tensor custom_sdpa_aten(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<double> scale) {
+  auto output = at::empty_like(q);
+  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
+  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+  return output;
+}
+
 Tensor& update_quantized_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -115,6 +160,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None) -> Tensor");
+  m.def(
+      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_quantized_cache(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos) -> Tensor");
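Once the shared library built from this file is loaded, both overloads defined in the fragment above are callable from Python (the CompositeExplicitAutograd implementations are registered in the next hunk). A hedged smoke-test sketch; the import path and shapes are illustrative and assumed to pass the kernel's argument checks.

```python
import torch

# Assumed import path; importing this module loads the custom ops library.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

bsz, seqlen, max_ctx, n_heads, head_dim = 1, 4, 128, 8, 64
start_pos = 16
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k_cache = torch.zeros(bsz, max_ctx, n_heads, head_dim)
v_cache = torch.zeros(bsz, max_ctx, n_heads, head_dim)

# Functional overload: allocates and returns the output (custom_sdpa_aten).
out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, start_pos, None, 0.0, True)

# Out-variant overload: writes into a preallocated tensor (custom_sdpa_out_no_context).
buf = torch.empty_like(q)
torch.ops.llama.custom_sdpa.out(
    q, k_cache, v_cache, start_pos, None, 0.0, True, None, out=buf
)
torch.testing.assert_close(out, buf)
```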
@@ -123,6 +176,7 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
+// TODO: Rename this file to op_custom_ops_aot.cpp
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
@@ -130,10 +184,10 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       "sdpa_with_kv_cache.out",
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
-}
-
-// TODO: Rename this file to op_custom_ops_aot.cpp
-TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
+  m.impl(
+      "custom_sdpa.out",
+      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
   m.impl(
       "update_quantized_cache",
       torch::executor::native::update_quantized_cache_aten);
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/sdpa_with_kv_cache.py
index d6c7fbab6f4..85021266b59 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -141,6 +141,35 @@ def fast_hadamard_transform_meta(mat):
     return torch.empty_like(mat)
 
 
+@impl(custom_ops_lib, "custom_sdpa", "Meta")
+def custom_sdpa(
+    query,
+    key_cache,
+    value_cache,
+    start_pos,
+    attn_mask=None,
+    drpout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    seq_len = query.size(1)
+    _validate_params(
+        query,
+        key_cache,
+        value_cache,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+
+    return torch.empty_like(query)
+
+
 def _validate_update_cache_params(
     value,
     cache,
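The Meta kernel above only propagates shapes, which is what export-time tracing needs. A rough sanity check, assuming the import path below and that the illustrative shapes satisfy _validate_params:

```python
import torch

# Assumed import path; importing registers the llama::custom_sdpa schema and its Meta impl.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

# Meta-device tensors carry no data, so only the shape function runs.
q = torch.empty(1, 32, 8, 64, device="meta")
k_cache = torch.empty(1, 512, 8, 64, device="meta")
v_cache = torch.empty(1, 512, 8, 64, device="meta")

out = torch.ops.llama.custom_sdpa(q, k_cache, v_cache, 64, None, 0.0, True)
assert out.shape == q.shape  # the Meta kernel returns empty_like(query)
assert out.device.type == "meta"
```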