Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ void cpu_flash_attention(
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * headSize;

int64_t size_bytes = size_per_thread * num_thread * query.element_size();
int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is *4 for fp32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, really good catch. That was left over from local debugging — it needs to be removed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please finalize so that @spalatinate has the correct version

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh totally forgot about this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the fix: PR #9492.

std::vector<char> buf_vec(size_bytes);
void* buf = reinterpret_cast<void*>(buf_vec.data());
// Need to double check the following
Expand Down Expand Up @@ -452,6 +452,7 @@ void cpu_flash_attention(
// However, lets just fix that as well.
int64_t num_keys =
is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
int64_t m_start_pos = m + start_pos;
auto j_kv = j / num_reps;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
Expand All @@ -471,29 +472,62 @@ void cpu_flash_attention(
static_cast<accum_t>(0),
qk_data,
kvBlockSize);
// Apply causal mask, fill unused, i.e. future values, with -inf
// Say you have q @ k.T size = [16, 32]
// With qblock size = 4, say you are processing
// q seq len dim = 8:11.
// Say kvSplitSize = 4
// Then for causal mask, the entries that needs to be
// ignored are
// [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
// Following condition says that num_keys = 8 + 4 =12
// (num_keys - n) <= kvSplitSize
// num_keys <= n + kvSplitSize
// If n + kvSplitSize is larger than 12, then some
// entries need masked out. In our example n = 4
// will qualify for that
if (is_causal && num_keys - n <= kvSplitSize) {
// There are 4 cases that is_causal has to cover to fill
// not-attendable-position with -inf
/* 1. Everything is attended to. This happens when m_start_pos > n +
kvSplitSize e.g m_pos [8:15] and n_pos [0:7]. Since you must attend to
all previous tokens matrix is full
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
2. Everything is not attended to. However only some tokens at the
beginning dont attend to everything. This happens when m_start_pos <= n
+ kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize m_start_pos
= 8 qBlockSize = 8 n = 4 kvSplitSize = 8 For example m_pos [8:15] but
n_pos is [4:11]
+ + + + + - - -
+ + + + + + - -
+ + + + + + + -
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
+ + + + + + + +
3. In this case only last few tokens have something to attend to.
This happens when m_start_pos < n and m_start_pos + qBlockSize >= n and
m_start_pos + qBlockSize <= n + kvSplitSize m_start_pos = 8 qBlockSize =
8 n = 13 kvSplitSize = 8 For example m_pos [8:15] but n_pos is [13:20]
- - - - - - - -
- - - - - - - -
- - - - - - - -
- - - - - - - -
- - - - - - - -
+ - - - - - - -
+ + - - - - - -
+ + + - - - - -
4. In this no tokens attend to anything, but we dont really have to
take care of this case because the loop for (int64_t n = 0; n <
num_keys; n += kvSplitSize) will exit before that.
*/
if (is_causal && m_start_pos <= n + kvSplitSize) {
// For this fn to work k_split_size > q_split_size
for (int32_t row = 0; row < qBlockSize; ++row) {
int64_t last_col = m + (row + start_pos) - n;
for (int32_t row = 0;
row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1));
++row) {
// When last_col is 0, it means that the entire row is not attended
// to because m_pos is smaller than n_pos. So everything in n is for
// future.
int64_t last_col =
n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
accum_t* row_ptr = qk_data + row * kvBlockSize;
fill_stub(
row_ptr + last_col + 1,
row_ptr + last_col,
-std::numeric_limits<accum_t>::infinity(),
kvBlockSize - last_col - 1);
kvBlockSize - last_col);
}
}
// Update attention weights with attention mask
Expand Down
11 changes: 11 additions & 0 deletions extension/llm/custom_ops/test_sdpa_with_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,14 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
self._test_sdpa_common(
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
)

def test_sdpa_to_repro_long_seq_failure(self):
    # Regression repro: a long prefill (508 tokens) followed by a second
    # decode iteration of 127 tokens — presumably the case that exposed
    # the causal-mask bug fixed in op_sdpa.cpp (TODO: confirm against PR).
    kv_heads = 16
    q_heads = 32
    dim_per_head = 128
    max_context = 2048
    prompt_len = 508
    followup_len = 127
    self._test_sdpa_common(
        kv_heads,
        q_heads,
        dim_per_head,
        max_context,
        prompt_len,
        followup_len,
    )
Loading