diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index e74c111cc5f..b5cb2b55e99 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -224,7 +224,7 @@ void cpu_flash_attention( bool is_causal, const optional& attn_mask, const optional& scale, - bool is_with_kv_cache = false, + bool is_seq_at_dim_1 = false, const int64_t start_pos = 0) { (void)dropout_p; // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) @@ -265,7 +265,7 @@ void cpu_flash_attention( int64_t kvSize = value.size(2); int64_t num_heads_kv = key.size(1); - if (is_with_kv_cache) { + if (is_seq_at_dim_1) { num_head = query.size(2); num_heads_kv = key.size(2); qSize = query.size(1); @@ -311,7 +311,7 @@ void cpu_flash_attention( int64_t qStrideH = strides[1]; int64_t qStrideM = strides[2]; - if (is_with_kv_cache) { + if (is_seq_at_dim_1) { qStrideH = strides[2]; qStrideM = strides[1]; } @@ -321,7 +321,7 @@ void cpu_flash_attention( int64_t kStrideH = strides[1]; int64_t kStrideN = strides[2]; - if (is_with_kv_cache) { + if (is_seq_at_dim_1) { kStrideH = strides[2]; kStrideN = strides[1]; } @@ -331,7 +331,7 @@ void cpu_flash_attention( int64_t vStrideH = strides[1]; int64_t vStrideN = strides[2]; - if (is_with_kv_cache) { + if (is_seq_at_dim_1) { vStrideH = strides[2]; vStrideN = strides[1]; } @@ -341,7 +341,7 @@ void cpu_flash_attention( int64_t oStrideH = strides[1]; int64_t oStrideM = strides[2]; - if (is_with_kv_cache) { + if (is_seq_at_dim_1) { oStrideH = strides[2]; oStrideM = strides[1]; } @@ -776,7 +776,6 @@ Tensor& custom_sdpa_out( const Tensor& k, const Tensor& v, const int64_t start_pos, - const int64_t seq_len, const optional& attn_mask, const double dropout_p, const bool is_causal, @@ -792,6 +791,7 @@ Tensor& custom_sdpa_out( ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor"); + const int64_t seq_len = q.size(1); auto q_seq_len = q.size(1); // Refactor the following into create_view util perhaps using @@ -870,7 +870,7 @@ Tensor& custom_sdpa_out( is_causal, attn_mask, scale, - true, + true, /* is_seq_at_dim_1 */ start_pos); } else if (q_seq_len >= 192) { cpu_flash_attention( @@ -882,7 +882,7 @@ Tensor& custom_sdpa_out( is_causal, attn_mask, scale, - true, + true, /* is_seq_at_dim_1 */ start_pos); } else { cpu_flash_attention( @@ -894,7 +894,7 @@ Tensor& custom_sdpa_out( is_causal, attn_mask, scale, - true, + true, /* is_seq_at_dim_1 */ start_pos); } }); @@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out( key_cache, value_cache, start_pos, - seq_len, attn_mask, dropout_p, is_causal,