21 changes: 10 additions & 11 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -224,7 +224,7 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    bool is_with_kv_cache = false,
+    bool is_seq_at_dim_1 = false,
     const int64_t start_pos = 0) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@@ -265,7 +265,7 @@ void cpu_flash_attention(
   int64_t kvSize = value.size(2);
   int64_t num_heads_kv = key.size(1);

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     num_head = query.size(2);
     num_heads_kv = key.size(2);
     qSize = query.size(1);
@@ -311,7 +311,7 @@ void cpu_flash_attention(
   int64_t qStrideH = strides[1];
   int64_t qStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     qStrideH = strides[2];
     qStrideM = strides[1];
   }
@@ -321,7 +321,7 @@ void cpu_flash_attention(
   int64_t kStrideH = strides[1];
   int64_t kStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     kStrideH = strides[2];
     kStrideN = strides[1];
   }
@@ -331,7 +331,7 @@ void cpu_flash_attention(
   int64_t vStrideH = strides[1];
   int64_t vStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     vStrideH = strides[2];
     vStrideN = strides[1];
   }
@@ -341,7 +341,7 @@ void cpu_flash_attention(
   int64_t oStrideH = strides[1];
   int64_t oStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     oStrideH = strides[2];
     oStrideM = strides[1];
   }
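
Taken together, the four stride hunks above apply one rule: when the sequence dimension sits at dim 1 and heads at dim 2, the head and sequence strides trade places. The rename from `is_with_kv_cache` to `is_seq_at_dim_1` makes this precise: the flag describes the tensor layout, not whether a KV cache is present. Below is a minimal standalone sketch of that selection logic, assuming the two layouts (Batch, Heads, Seq, Dim) and (Batch, Seq, Heads, Dim); the names are illustrative, not ExecuTorch API:

```cpp
#include <array>
#include <cstdint>

// Strides of one attention tensor, in elements.
struct AttnStrides {
  int64_t strideB; // batch
  int64_t strideH; // head
  int64_t strideM; // sequence position
};

// strides[d] is the element stride of dimension d, mirroring what
// Tensor::strides() returns in the kernel above.
inline AttnStrides pick_strides(const std::array<int64_t, 4>& strides,
                                bool is_seq_at_dim_1) {
  AttnStrides s{strides[0], strides[1], strides[2]};
  if (is_seq_at_dim_1) {
    // Layout is (Batch, Seq, Heads, Dim): heads live at dim 2 and the
    // sequence at dim 1, so the two strides swap roles.
    s.strideH = strides[2];
    s.strideM = strides[1];
  }
  return s;
}
```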
@@ -776,7 +776,6 @@ Tensor& custom_sdpa_out(
     const Tensor& k,
     const Tensor& v,
     const int64_t start_pos,
-    const int64_t seq_len,
     const optional<Tensor>& attn_mask,
     const double dropout_p,
     const bool is_causal,
@@ -792,6 +791,7 @@ Tensor& custom_sdpa_out(

   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

+  const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);

   // Refactor the following into create_view util perhaps using
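
With the parameter gone, `seq_len` is now derived inside `custom_sdpa_out` from the query itself, so callers can no longer pass a value that disagrees with `q`. A hedged sketch of the new contract, using a stand-in tensor type (illustrative, not the real `Tensor` API):

```cpp
#include <cstdint>
#include <vector>

// Stand-in for a 4-D tensor; only sizes matter for this sketch.
struct TensorLike {
  std::vector<int64_t> dims;
  int64_t size(size_t d) const { return dims[d]; }
};

// After this PR: seq_len is no longer a parameter. It is read off q,
// whose layout on this path is (Batch, Seq_len, Num_heads, Dim_per_head).
int64_t derived_seq_len(const TensorLike& q) {
  return q.size(1);
}
```

Before the change, a caller-supplied `seq_len` and the actual `q.size(1)` could in principle disagree; deriving it inside the kernel removes that failure mode.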
@@ -870,7 +870,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<CTYPE, 64, 512>(
@@ -882,7 +882,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else {
       cpu_flash_attention<CTYPE, 32, 512>(
@@ -894,7 +894,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     }
   });
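
The three call sites above differ only in the q-block template argument, chosen from `q_seq_len`; each now documents the bare `true` literal as `/* is_seq_at_dim_1 */`. Below is a self-contained sketch of that dispatch shape using a stub kernel. The 192 threshold and the `<64, 512>` and `<32, 512>` instantiations are read off this diff; the largest branch's threshold and arguments fall outside the visible hunks and are omitted rather than guessed:

```cpp
#include <cstdint>
#include <cstdio>

// Stub standing in for cpu_flash_attention<CTYPE, q_block, kv_block>.
template <typename CTYPE, int64_t q_block, int64_t kv_block>
void flash_attention_stub(int64_t q_seq_len) {
  std::printf("q_seq_len=%lld -> q_block=%lld, kv_block=%lld\n",
              static_cast<long long>(q_seq_len),
              static_cast<long long>(q_block),
              static_cast<long long>(kv_block));
}

// Longer queries get a larger query tile; the kv tile stays at 512.
void dispatch(int64_t q_seq_len) {
  if (q_seq_len >= 192) {
    flash_attention_stub<float, 64, 512>(q_seq_len);
  } else {
    flash_attention_stub<float, 32, 512>(q_seq_len);
  }
}

int main() {
  dispatch(256); // -> q_block = 64
  dispatch(64);  // -> q_block = 32
}
```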
@@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out(
       key_cache,
       value_cache,
       start_pos,
-      seq_len,
       attn_mask,
       dropout_p,
       is_causal,