diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index 67d9c100f68..652453bbec7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -76,6 +76,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. @@ -96,7 +97,7 @@ void main() { // number of threads in the work group. for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); for (int comp = 0; comp < 4; comp++) { local_exp_sum += exp(in_texel[comp]); @@ -108,7 +109,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { @@ -138,11 +139,11 @@ void main() { // Now go back through each element in the row and normalize for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); VEC4_T out_texel = exp(in_texel) / local_exp_sum; store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } // First thread in the work group responsible for handling last texel if it // contains any padded elements @@ -150,7 +151,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); // Ensure that padding elements are set to 0. VEC4_T out_texel = VEC4_T(0); @@ -160,7 +161,7 @@ void main() { } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index 2900d63666b..a4bf588949b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -205,7 +206,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index 95c22d91b80..ef0c3c571c9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -93,6 +93,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -196,6 +197,6 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 5f408b7581d..cc60193cf18 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -120,7 +121,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -146,7 +147,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 0063ebf9d38..385ad7a921e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -75,6 +75,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -113,7 +114,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -136,7 +137,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 6b4da5d95f1..f514530f175 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -50,7 +50,7 @@ void resize_compute_attn_weights_node( std::vector out_sizes = { 1, // batch num_q_heads, - seq_len, + utils::align_up_4(seq_len), utils::align_up_4(context_len)}; graph->virtual_resize(attn_weights, out_sizes);