From 95e3f1648e0ae2cf04dfeb0fb57323dd0738b70f Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 10 Mar 2025 14:39:37 -0700 Subject: [PATCH] [Executorch][SDPA] Fix bug in sdpa This diff fixes two bugs 1. When doing flash attention, the partial q @ k block may contain some entries that need to be masked out. This logic had a bug. Maybe this bug also exists in PT core. I will look into that to add a test and see if I can prove it. 2. Due to special handling via start_pos in SDPA it also exposed the bug in 1 when doing really long sequence prefill in a chunked manner. It is probably better to just use mask though. Code has detailed comments on the issue and fix. Differential Revision: [D70922039](https://our.internmc.facebook.com/intern/diff/D70922039/) [ghstack-poisoned] --- extension/llm/custom_ops/op_sdpa.cpp | 70 +++++++++++++------ .../llm/custom_ops/test_sdpa_with_kv_cache.py | 11 +++ 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index f0a7775e803..3eeee0c0e3c 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -382,7 +382,7 @@ void cpu_flash_attention( /* qk_sum */ qSplitSize + /* dst */ qSplitSize * headSize; - int64_t size_bytes = size_per_thread * num_thread * query.element_size(); + int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4; std::vector buf_vec(size_bytes); void* buf = reinterpret_cast(buf_vec.data()); // Need to double check the following @@ -452,6 +452,7 @@ void cpu_flash_attention( // However, lets just fix that as well. int64_t num_keys = is_causal ?
std::min(m + start_pos + qBlockSize, kvSize) : kvSize; + int64_t m_start_pos = m + start_pos; auto j_kv = j / num_reps; for (int64_t n = 0; n < num_keys; n += kvSplitSize) { int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); @@ -471,29 +472,58 @@ void cpu_flash_attention( static_cast(0), qk_data, kvBlockSize); - // Apply causal mask, fill unused, i.e. future values, with -inf - // Say you have q @ k.T size = [16, 32] - // With qblock size = 4, say you are processing - // q seq len dim = 8:11. - // Say kvSplitSize = 4 - // Then for causal mask, the entries that needs to be - // ignored are - // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31] - // Following condition says that num_keys = 8 + 4 =12 - // (num_keys - n) <= kvSplitSize - // num_keys <= n + kvSplitSize - // If n + kvSplitSize is larger than 12, then some - // entries need masked out. In our example n = 4 - // will qualify for that - if (is_causal && num_keys - n <= kvSplitSize) { + // There are 4 cases that is_causal has to cover to fill not-attendable-position with -inf + /* 1. Everything is attended to. This happens when m_start_pos > n + kvSplitSize + e.g m_pos [8:15] and n_pos [0:7]. Since you must attend to all previous tokens + matrix is full + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2. Everything is not attended to. However only some tokens at the beginning dont attend + to everything. This happens when m_start_pos <= n + kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize + m_start_pos = 8 qBlockSize = 8 + n = 4 kvSplitSize = 8 + For example m_pos [8:15] but n_pos is [4:11] + + + + + + - - - + + + + + + + - - + + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 3. In this case only last few tokens have something to attend to. 
This happens when m_start_pos < n + and m_start_pos + qBlockSize >= n + and m_start_pos + qBlockSize <= n + kvSplitSize + m_start_pos = 8 qBlockSize = 8 + n = 13 kvSplitSize = 8 + For example m_pos [8:15] but n_pos is [13:20] + - - - - - - - - + - - - - - - - - + - - - - - - - - + - - - - - - - - + - - - - - - - - + + - - - - - - - + + + - - - - - - + + + + - - - - - + 4. In this case no tokens attend to anything, but we don't really have to take care of this case because + the loop for (int64_t n = 0; n < num_keys; n += kvSplitSize) will exit before that. + */ + if (is_causal && m_start_pos <= n + kvSplitSize) { // For this fn to work k_split_size > q_split_size - for (int32_t row = 0; row < qBlockSize; ++row) { - int64_t last_col = m + (row + start_pos) - n; + for (int32_t row = 0; row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1)); ++row) { + // When last_col is 0, it means that the entire row is not attended to + // because m_pos is smaller than n_pos. So everything in n is for future. + int64_t last_col = n > (m_start_pos + row) ?
0 : row + m_start_pos + 1 - n; accum_t* row_ptr = qk_data + row * kvBlockSize; fill_stub( - row_ptr + last_col + 1, + row_ptr + last_col, -std::numeric_limits::infinity(), - kvBlockSize - last_col - 1); + kvBlockSize - last_col); } } // Update attention weights with attention mask diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py index 9c8029c7b70..a1f054a153e 100644 --- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py +++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py @@ -590,3 +590,14 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self): self._test_sdpa_common( n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len ) + + def test_sdpa_to_repro_long_seq_failure(self): + n_heads_kv = 16 + n_heads_q = 32 + head_dim = 128 + max_seq_len = 2048 + seq_len = 508 + next_iter_seq_len = 127 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len + )