From 7b097a37d9ab36be434567aabd6d3889ccbe734f Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 4 Nov 2025 11:52:44 -0800
Subject: [PATCH] [ET-VK][ez] Align SDPA attention weights S dim to the next
 multiple of 4

Title says it all!

Why?
* Technically, this should not be needed, but the SDPA op was producing
  incorrect output on the Samsung S24 with buffer input tensors. The exact
  root cause is unclear, but it appears to be an issue specific to the
  Adreno 750, since it does not reproduce on any other GPU. The best guess
  at the moment is that we need to ensure that no two threads can write to
  the same memory location.

Differential Revision: [D86226134](https://our.internmc.facebook.com/intern/diff/D86226134/)

[ghstack-poisoned]
---
 .../graph/ops/glsl/sdpa_attn_weights_softmax.glsl  | 13 +++++++------
 .../ops/glsl/sdpa_compute_attn_weights_coop.glsl   |  3 ++-
 .../ops/glsl/sdpa_compute_attn_weights_tiled.glsl  |  3 ++-
 .../graph/ops/glsl/sdpa_compute_out_coop.glsl      |  5 +++--
 .../graph/ops/glsl/sdpa_compute_out_tiled.glsl     |  5 +++--
 backends/vulkan/runtime/graph/ops/impl/SDPA.cpp    |  2 +-
 6 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
index 67d9c100f68..652453bbec7 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
@@ -76,6 +76,7 @@ void main() {
   const int Q_H = q_projected_sizes.y;
   // sequence length
   const int S = q_projected_sizes.z;
+  const int S_aligned = align_up_4(S);
   // manually determine size of the context_len dim of the attention weight.
   // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow
   // memory loads to be aligned to texel boundaries.
@@ -96,7 +97,7 @@ void main() {
   // number of threads in the work group.
   for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
     VEC4_T in_texel = load_attn_weights_c4(
-        c4, s, q_h, context_texel_len, S, Q_H);
+        c4, s, q_h, context_texel_len, S_aligned, Q_H);

     for (int comp = 0; comp < 4; comp++) {
       local_exp_sum += exp(in_texel[comp]);
@@ -108,7 +109,7 @@ void main() {
     for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
       const int c_base = mul_4(c4);
       VEC4_T in_texel = load_attn_weights_c4(
-          c4, s, q_h, context_texel_len, S, Q_H);
+          c4, s, q_h, context_texel_len, S_aligned, Q_H);

       [[unroll]] for (int comp = 0; comp < 4; comp++) {
         if (c_base + comp < context_len) {
@@ -138,11 +139,11 @@ void main() {
   // Now go back through each element in the row and normalize
   for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
     VEC4_T in_texel = load_attn_weights_c4(
-        c4, s, q_h, context_texel_len, S, Q_H);
+        c4, s, q_h, context_texel_len, S_aligned, Q_H);

     VEC4_T out_texel = exp(in_texel) / local_exp_sum;
     store_attn_weights_softmax_c4(
-        out_texel, c4, s, q_h, context_texel_len, S, Q_H);
+        out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
   }
   // First thread in the work group responsible for handling last texel if it
   // contains any padded elements
@@ -150,7 +151,7 @@ void main() {
     for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
       const int c_base = mul_4(c4);
       VEC4_T in_texel = load_attn_weights_c4(
-          c4, s, q_h, context_texel_len, S, Q_H);
+          c4, s, q_h, context_texel_len, S_aligned, Q_H);

       // Ensure that padding elements are set to 0.
      VEC4_T out_texel = VEC4_T(0);
@@ -160,7 +161,7 @@ void main() {
         }
       }
       store_attn_weights_softmax_c4(
-          out_texel, c4, s, q_h, context_texel_len, S, Q_H);
+          out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
     }
   }
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
index 2900d63666b..a4bf588949b 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
@@ -81,6 +81,7 @@ void main() {
   const int Q_H = q_projected_sizes.y;
   // sequence length
   const int S = q_projected_sizes.z;
+  const int S_aligned = align_up_4(S);
   // number of K/V heads
   const int KV_H = k_cache_sizes.y;

@@ -205,7 +206,7 @@ void main() {
         s,
         q_h,
         context_texel_len,
-        S,
+        S_aligned,
         Q_H);
   }
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
index 95c22d91b80..ef0c3c571c9 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
@@ -93,6 +93,7 @@ void main() {
   const int Q_H = q_projected_sizes.y;
   // sequence length
   const int S = q_projected_sizes.z;
+  const int S_aligned = align_up_4(S);
   // number of K/V heads
   const int KV_H = k_cache_sizes.y;

@@ -196,6 +197,6 @@ void main() {
       s,
       q_h,
       context_texel_len,
-      S,
+      S_aligned,
       Q_H);
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
index 5f408b7581d..cc60193cf18 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
@@ -81,6 +81,7 @@ void main() {
   const int Q_H = q_projected_sizes.y;
   // sequence length
   const int S = q_projected_sizes.z;
+  const int S_aligned = align_up_4(S);
   // number of K/V heads
   const int KV_H = v_cache_sizes.y;

@@ -120,7 +121,7 @@ void main() {
         s,
         q_h,
         context_texel_len,
-        S,
+        S_aligned,
         Q_H);

     load_v_cache_tile_no_checks(
@@ -146,7 +147,7 @@ void main() {
         s,
         q_h,
         context_texel_len,
-        S,
+        S_aligned,
         Q_H);

     load_v_cache_tile_with_checks(
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
index 0063ebf9d38..385ad7a921e 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
@@ -75,6 +75,7 @@ void main() {
   const int Q_H = q_projected_sizes.y;
   // sequence length
   const int S = q_projected_sizes.z;
+  const int S_aligned = align_up_4(S);
   // number of K/V heads
   const int KV_H = v_cache_sizes.y;

@@ -113,7 +114,7 @@ void main() {
         s,
         q_h,
         context_texel_len,
-        S,
+        S_aligned,
         Q_H);

     load_v_cache_tile_no_checks(
@@ -136,7 +137,7 @@ void main() {
         s,
         q_h,
         context_texel_len,
-        S,
+        S_aligned,
         Q_H);

     load_v_cache_tile_with_checks(
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 6b4da5d95f1..f514530f175 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -50,7 +50,7 @@
   std::vector<int64_t> out_sizes = {
       1, // batch
       num_q_heads,
-      seq_len,
+      utils::align_up_4(seq_len),
       utils::align_up_4(context_len)};

   graph->virtual_resize(attn_weights, out_sizes);
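
Note (not part of the patch): a minimal C++ sketch of the alignment arithmetic
this change relies on. align_up_4 below is a local stand-in for the
utils::align_up_4 / shader align_up_4 helpers and is assumed to round an
extent up to the next multiple of 4; the actual implementation in the
codebase may differ.

    #include <cstdint>
    #include <iostream>

    // Assumed behavior of align_up_4: round n up to the next multiple of 4.
    int64_t align_up_4(int64_t n) {
      return (n + 3) & ~int64_t{3};
    }

    int main() {
      // With the S dim of the attention weights padded this way, each 4-wide
      // texel along S belongs to exactly one sequence position, so writes from
      // different invocations never land in the same texel.
      for (int64_t s : {1, 4, 5, 7, 8}) {
        std::cout << s << " -> " << align_up_4(s) << "\n"; // 4, 4, 8, 8, 8
      }
      return 0;
    }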