diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 371fcf38a24..89634574027 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -382,7 +382,7 @@ void cpu_flash_attention( /* qk_sum */ qSplitSize + /* dst */ qSplitSize * headSize; - int64_t size_bytes = size_per_thread * num_thread * query.element_size(); + int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4; std::vector<char> buf_vec(size_bytes); void* buf = reinterpret_cast<void*>(buf_vec.data()); // Need to double check the following @@ -452,6 +452,7 @@ void cpu_flash_attention( // However, lets just fix that as well. int64_t num_keys = is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize; + int64_t m_start_pos = m + start_pos; auto j_kv = j / num_reps; for (int64_t n = 0; n < num_keys; n += kvSplitSize) { int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); @@ -471,29 +472,62 @@ void cpu_flash_attention( static_cast<accum_t>(0), qk_data, kvBlockSize); - // Apply causal mask, fill unused, i.e. future values, with -inf - // Say you have q @ k.T size = [16, 32] - // With qblock size = 4, say you are processing - // q seq len dim = 8:11. - // Say kvSplitSize = 4 - // Then for causal mask, the entries that needs to be - // ignored are - // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31] - // Following condition says that num_keys = 8 + 4 =12 - // (num_keys - n) <= kvSplitSize - // num_keys <= n + kvSplitSize - // If n + kvSplitSize is larger than 12, then some - // entries need masked out. In our example n = 4 - // will qualify for that - if (is_causal && num_keys - n <= kvSplitSize) { + // There are 4 cases that is_causal has to cover to fill + // not-attendable-position with -inf + /* 1. Everything is attended to. This happens when m_start_pos > n + + kvSplitSize e.g. m_pos [8:15] and n_pos [0:7]. 
Since you must attend to + all previous tokens matrix is full + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2. Everything is not attended to. However only some tokens at the + beginning dont attend to everything. This happens when m_start_pos <= n + + kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize m_start_pos + = 8 qBlockSize = 8 n = 4 kvSplitSize = 8 For example m_pos [8:15] but + n_pos is [4:11] + + + + + + - - - + + + + + + + - - + + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 3. In this case only last few tokens have something to attend to. + This happens when m_start_pos < n and m_start_pos + qBlockSize >= n and + m_start_pos + qBlockSize <= n + kvSplitSize m_start_pos = 8 qBlockSize = + 8 n = 13 kvSplitSize = 8 For example m_pos [8:15] but n_pos is [13:20] + - - - - - - - - + - - - - - - - - + - - - - - - - - + - - - - - - - - + - - - - - - - - + + - - - - - - - + + + - - - - - - + + + + - - - - - + 4. In this no tokens attend to anything, but we dont really have to + take care of this case because the loop for (int64_t n = 0; n < + num_keys; n += kvSplitSize) will exit before that. + */ + if (is_causal && m_start_pos <= n + kvSplitSize) { // For this fn to work k_split_size > q_split_size - for (int32_t row = 0; row < qBlockSize; ++row) { - int64_t last_col = m + (row + start_pos) - n; + for (int32_t row = 0; + row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1)); + ++row) { + // When last_col is 0, it means that the entire row is not attended + // to because m_pos is smaller than n_pos. So everything in n is for + // future. + int64_t last_col = + n > (m_start_pos + row) ? 
0 : row + m_start_pos + 1 - n; accum_t* row_ptr = qk_data + row * kvBlockSize; fill_stub( - row_ptr + last_col + 1, + row_ptr + last_col, -std::numeric_limits<accum_t>::infinity(), - kvBlockSize - last_col - 1); + kvBlockSize - last_col); } } // Update attention weights with attention mask diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py index 9c8029c7b70..a1f054a153e 100644 --- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py +++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py @@ -590,3 +590,14 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self): self._test_sdpa_common( n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len ) + + def test_sdpa_to_repro_long_seq_failure(self): + n_heads_kv = 16 + n_heads_q = 32 + head_dim = 128 + max_seq_len = 2048 + seq_len = 508 + next_iter_seq_len = 127 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len + )