From 9bfb6c912ad55534ad7b3aaa14a76e4d83562a63 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 24 Apr 2026 09:46:33 -0700 Subject: [PATCH] [ET-VK] Update fused SDPA operator to support ViT attention This diff extends the ET-VK fused SDPA operator so it can be used for the ViT attention blocks in the EdgeTAM ViT-S encoder. The main correctness problem is that Q@K^T dot products in ViT attention can exceed the fp16 max (65504), so fp32 accumulation is required. **fp16 overflow fix**: The intermediate `attn_weights` buffer is now always fp32 regardless of input dtype. Previously the QK shader accumulated in fp32 but stored to an fp16 buffer, causing overflow. The softmax shader reads fp32 attention weights and writes fp16 softmax output (safe since values are in [0, 1]). **Texture support**: The QK and AV shaders support both buffer and texture3d storage for Q/K/V/output. The intermediate `attn_weights` and `attn_weights_softmax` tensors now inherit the storage type of the input/output (q_projected for the LLM path, out for the fused path), so the entire fused SDPA pipeline runs in a uniform storage type and no SDPA-internal layout transitions are needed. Differential Revision: [D102360200](https://our.internmc.facebook.com/intern/diff/D102360200/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 28 + backends/vulkan/op_registry.py | 14 + .../graph/ops/glsl/linear_fp_input_tile.glslh | 9 +- .../ops/glsl/linear_fp_input_tile_load.glslh | 8 +- .../ops/glsl/linear_fp_output_tile.glslh | 14 +- .../linear_fp_output_tile_fp_compute.glslh | 22 +- .../glsl/linear_fp_output_tile_store.glslh | 6 +- .../linear_fp_packed_weight_tile_load.glslh | 8 +- .../ops/glsl/linear_fp_weight_tile.glslh | 10 +- .../ops/glsl/sdpa_attn_weights_softmax.glsl | 133 ++-- .../ops/glsl/sdpa_attn_weights_softmax.yaml | 24 +- .../glsl/sdpa_compute_attn_weights_coop.glsl | 30 +- .../glsl/sdpa_compute_attn_weights_tiled.glsl | 162 +++-- .../glsl/sdpa_compute_attn_weights_tiled.yaml | 42 +- .../graph/ops/glsl/sdpa_compute_out_coop.glsl | 19 +- .../ops/glsl/sdpa_compute_out_tiled.glsl | 148 +++-- .../ops/glsl/sdpa_compute_out_tiled.yaml | 13 + .../glsl/sdpa_fp_attn_weight_tile_load.glslh | 24 +- .../glsl/sdpa_fp_attn_weight_tile_store.glslh | 90 ++- .../ops/glsl/sdpa_fp_k_cache_tile_load.glslh | 70 +- .../ops/glsl/sdpa_fp_out_tile_store.glslh | 41 +- .../glsl/sdpa_fp_q_projected_tile_load.glslh | 71 +- .../ops/glsl/sdpa_fp_v_cache_tile_load.glslh | 57 +- .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 615 +++++++++++++----- backends/vulkan/test/op_tests/sdpa_test.cpp | 236 +++++++ 25 files changed, 1425 insertions(+), 469 deletions(-) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index a59b150e7ae..64c6d3e46d9 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -960,6 +960,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): lib.impl(name, select_as_symint_impl, "Meta") select_as_symint_op = getattr(getattr(torch.ops, namespace), name) +########## +## sdpa ## +########## + + +def sdpa_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + scale: Optional[float] = None, +): + if scale is None: + scale = 1.0 / (q.size(-1) ** 0.5) + attn = torch.matmul(q, k.transpose(-2, -1)) * scale + if attn_mask is not None: + attn = attn + attn_mask + attn = torch.softmax(attn, dim=-1) + return torch.matmul(attn, v) + + +name = "sdpa" +lib.define( + f"{name}(Tensor q, Tensor k, Tensor v, Tensor? 
attn_mask = None, float? scale = None) -> Tensor" +) +lib.impl(name, sdpa_impl, "CompositeExplicitAutograd") +sdpa_op = getattr(getattr(torch.ops, namespace), name) + ################ ## rms_norm ## ################ diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index ff056d76c3a..2e313f0f91b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops(): ) +# ============================================================================= +# SDPA.cpp (fused SDPA entry point) +# ============================================================================= + + +@update_features("et_vk::sdpa") +def register_general_sdpa(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + inputs_dtypes=utils.FP_T, + supports_resize=True, + ) + + # ============================================================================= # RotaryEmbedding.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh index 72b5bdb812e..581b05072df 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -13,12 +13,19 @@ * Macro Settings: * - TILE_M * - TILE_K4 + * + * Optional: + * - LINEAR_FP_INPUT_TILE_VEC4_T — input tile vec4 type (default: VEC4_T). */ #extension GL_EXT_control_flow_attributes : require +#ifndef LINEAR_FP_INPUT_TILE_VEC4_T +#define LINEAR_FP_INPUT_TILE_VEC4_T VEC4_T +#endif + struct FPInputTile { - VEC4_T data[TILE_M][TILE_K4]; + LINEAR_FP_INPUT_TILE_VEC4_T data[TILE_M][TILE_K4]; }; #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh index 358379b3efd..84bf2e07bea 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -21,11 +21,11 @@ #include "linear_fp_input_tile.glslh" -VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { +LINEAR_FP_INPUT_TILE_VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { #ifdef INPUT_BUFFER - return t_input[(m * ntexels_k) + k4]; + return LINEAR_FP_INPUT_TILE_VEC4_T(t_input[(m * ntexels_k) + k4]); #else - return texelFetch(t_input, ivec3(k4, m, 0), 0); + return LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_input, ivec3(k4, m, 0), 0)); #endif } @@ -53,7 +53,7 @@ void load_input_tile_with_checks( if (m_start + m < M && k4_start + k4 < K4) { in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); } else { - in_tile.data[m][k4] = VEC4_T(0.0); + in_tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh index ca466447084..b6fc31951f5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -10,6 +10,11 @@ * Macro Settings: * - TILE_M * - TILE_N4 + * + * Optional: + * - LINEAR_FP_OUTPUT_TILE_VEC4_T — accumulator vec4 type (default: VEC4_T). + * Set this to `vec4` to force fp32 accumulation regardless of DTYPE; used + * by fused SDPA QK to avoid fp16 overflow in Q@K^T. 
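+ *   For example (illustrative magnitudes): with head_dim = 64 and fp16
+ *   Q/K entries around +/-32, a single dot product can reach
+ *   64 * 32 * 32 = 65536, past the fp16 max of 65504.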
*/ #ifndef LINEAR_FP_OUTPUT_TILE_GLSLH @@ -17,14 +22,19 @@ #extension GL_EXT_control_flow_attributes : require +#ifndef LINEAR_FP_OUTPUT_TILE_VEC4_T +#define LINEAR_FP_OUTPUT_TILE_VEC4_T VEC4_T +#define LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT +#endif + struct FPOutTile { - VEC4_T data[TILE_M][TILE_N4]; + LINEAR_FP_OUTPUT_TILE_VEC4_T data[TILE_M][TILE_N4]; }; void initialize(out FPOutTile out_tile) { [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - out_tile.data[m][n4] = VEC4_T(0); + out_tile.data[m][n4] = LINEAR_FP_OUTPUT_TILE_VEC4_T(0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index 60a19ca9fc9..73faf9074ac 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -21,6 +21,12 @@ #include "linear_fp_per_out_channel_params.glslh" #include "linear_fp_weight_tile.glslh" +#if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT) +#define MAYBE_CAST_WVEC4(x) (x) +#else +#define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x) +#endif + void fp_accumulate_with_fp_weight( inout FPOutTile accum, FPInputTile in_tile, @@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight( [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][0]), - w_tile.data[mul_4(k4)][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][0]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4)][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][1]), - w_tile.data[mul_4(k4) + 1][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][1]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 1][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][2]), - w_tile.data[mul_4(k4) + 2][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][2]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 2][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][3]), - w_tile.data[mul_4(k4) + 3][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][3]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 3][n4]), accum.data[m][n4]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh index 6fb399ff99b..9ee5f004cf5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -25,14 +25,14 @@ #include "linear_fp_output_tile.glslh" void write_output_x4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int n4, const int m, const int N4) { #ifdef OUTPUT_BUFFER - t_output[m * N4 + n4] = out_texel; + t_output[m * N4 + n4] = VEC4_T(out_texel); #else - imageStore(t_output, ivec3(n4, m, 0), out_texel); + imageStore(t_output, ivec3(n4, m, 0), VEC4_T(out_texel)); #endif } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh index 36b2a7296ef..5592042f6f7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh +++ 
b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh @@ -23,12 +23,12 @@ #include "linear_fp_weight_tile.glslh" -VEC4_T load_packed_weight_x4( +LINEAR_FP_WEIGHT_TILE_VEC4_T load_packed_weight_x4( const int n4, const int dk, const int k4, const int b, const int K4, const int N4) { #ifdef WEIGHT_BUFFER - return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]; + return LINEAR_FP_WEIGHT_TILE_VEC4_T(t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]); #else - return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0)); + return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0)); #endif } @@ -65,7 +65,7 @@ void load_packed_weight_tile_with_checks( if (k4 < K4 && n4_start + n4 < N4) { tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4); } else { - tile.data[k][n4] = VEC4_T(0); + tile.data[k][n4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh index 5e010442540..c57c5e72f0d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -10,6 +10,9 @@ * Macro Settings: * - TILE_K * - TILE_N4 + * + * Optional: + * - LINEAR_FP_WEIGHT_TILE_VEC4_T — weight tile vec4 type (default: VEC4_T). */ #ifndef LINEAR_FP_WEIGHT_TILE_GLSLH @@ -19,8 +22,13 @@ #include "common.glslh" +#ifndef LINEAR_FP_WEIGHT_TILE_VEC4_T +#define LINEAR_FP_WEIGHT_TILE_VEC4_T VEC4_T +#define LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT +#endif + struct FPWeightTile { - VEC4_T data[TILE_K][TILE_N4]; + LINEAR_FP_WEIGHT_TILE_VEC4_T data[TILE_K][TILE_N4]; }; #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index e6c118b6ab2..6c095e66255 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -9,14 +9,22 @@ #version 450 core #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} -#define T ${texel_load_component_type(DTYPE, STORAGE)} #define NUM_WORKERS_PER_WG 64 +$if MODE == "llm": + #define HAS_INPUT_POS + +#define IN_DTYPE ${IN_DTYPE} +#define OUT_DTYPE ${OUT_DTYPE} +#define SOFTMAX_IN_VEC4_T ${texel_load_type(IN_DTYPE, STORAGE)} +#define SOFTMAX_ACC_T ${texel_load_component_type(IN_DTYPE, STORAGE)} +#define VEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)} +#define T ${texel_load_component_type(OUT_DTYPE, STORAGE)} + ${define_active_storage_type(STORAGE)} -${define_required_extensions(STORAGE, DTYPE)} +${define_required_extensions(STORAGE, [IN_DTYPE, OUT_DTYPE])} #extension GL_EXT_control_flow_attributes : require @@ -24,19 +32,22 @@ layout(std430) buffer; #include "common.glslh" -${layout_declare_tensor(B, "w", "t_attn_weights_softmax", DTYPE, STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_attn_weights_softmax", OUT_DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", IN_DTYPE, STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", 
"k_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// Shared memory for cooperative max finding and exp sum reduction -shared T shared_max[NUM_WORKERS_PER_WG]; -shared T shared_exp_sum[NUM_WORKERS_PER_WG]; +// Shared memory for cooperative max finding and exp sum reduction. +// For fused SDPA, reductions happen in fp32 for numerical stability. +shared SOFTMAX_ACC_T shared_max[NUM_WORKERS_PER_WG]; +shared SOFTMAX_ACC_T shared_exp_sum[NUM_WORKERS_PER_WG]; -VEC4_T load_attn_weights_c4( +SOFTMAX_IN_VEC4_T load_attn_weights_c4( const int c4, const int s, const int q_h, @@ -46,7 +57,7 @@ VEC4_T load_attn_weights_c4( #ifdef USING_BUFFER return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; #else - return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); + return SOFTMAX_IN_VEC4_T(texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0)); #endif } @@ -65,26 +76,61 @@ void store_attn_weights_softmax_c4( #endif } +/* + * 3-pass numerically stable softmax over the context_len dimension of + * attention weights. + * + * LLM SDPA (HAS_INPUT_POS): + * reads VEC4_T (input dtype), reduces in T, writes VEC4_T. + * attn_weights S dim is padded to S_aligned. + * current context_len = input_pos + S. + * + * Fused SDPA (!HAS_INPUT_POS): + * reads vec4 (fp32 from QK), reduces in fp32, writes VEC4_T (input dtype). + * attn_weights S dim is not padded. + * context_len = k_sizes.y. + * + * Dispatch: (1, S, H * B) — for LLM (batch=1), H * B == H_q. + */ void main() { const int worker_id = int(gl_LocalInvocationID.x); // Index along attention weight's sequence_len dim const int s = int(gl_GlobalInvocationID.y); - // idx along attention weight's num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. const int q_h = int(gl_GlobalInvocationID.z); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int Q_H = q_sizes.y; + const int S = q_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int Q_H = q_sizes.z; + const int S = q_sizes.y; +#endif const int S_aligned = align_up_4(S); + +#ifdef HAS_INPUT_POS // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. const int context_len = input_pos + S; +#else + const int context_len = k_sizes.y; +#endif const int context_texel_len = div_up_4(context_len); - if (s >= S || q_h >= Q_H) { + // LLM: attn_weights S dim is padded to S_aligned; fused: not padded. +#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -96,25 +142,25 @@ void main() { // Without this, exp(x) can overflow float32 when x > ~88.7. 
// ========================================================================= - T local_max = T(-1.0 / 0.0); // -infinity + SOFTMAX_ACC_T local_max = SOFTMAX_ACC_T(-1.0 / 0.0); // -infinity for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); for (int comp = 0; comp < 4; comp++) { - local_max = max(local_max, in_texel[comp]); + local_max = max(local_max, SOFTMAX_ACC_T(in_texel[comp])); } } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - local_max = max(local_max, in_texel[comp]); + local_max = max(local_max, SOFTMAX_ACC_T(in_texel[comp])); } } } @@ -135,31 +181,31 @@ void main() { barrier(); } - const T global_max = shared_max[0]; + const SOFTMAX_ACC_T global_max = shared_max[0]; // ========================================================================= // Pass 2: Compute sum(exp(x - max)) using the global max for stability // ========================================================================= - T local_exp_sum = T(0); + SOFTMAX_ACC_T local_exp_sum = SOFTMAX_ACC_T(0); for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); for (int comp = 0; comp < 4; comp++) { - local_exp_sum += exp(in_texel[comp] - global_max); + local_exp_sum += exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max); } } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - local_exp_sum += exp(in_texel[comp] - global_max); + local_exp_sum += exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max); } } } @@ -187,27 +233,32 @@ void main() { // ========================================================================= for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); - VEC4_T out_texel = exp(in_texel - global_max) / local_exp_sum; + VEC4_T out_texel; + [[unroll]] for (int comp = 0; comp < 4; comp++) { + out_texel[comp] = T( + exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max) / local_exp_sum); + } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); + out_texel, c4, s, q_h, context_texel_len, attn_S, Q_H); } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); VEC4_T out_texel = 
VEC4_T(0); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - out_texel[comp] = exp(in_texel[comp] - global_max) / local_exp_sum; + out_texel[comp] = T( + exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max) / local_exp_sum); } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); + out_texel, c4, s, q_h, context_texel_len, attn_S, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml index 66ec030680e..d46e301e203 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml @@ -6,14 +6,30 @@ sdpa_attn_weights_softmax: parameter_names_with_default_values: - DTYPE: float + IN_DTYPE: float + OUT_DTYPE: float STORAGE: texture3d + MODE: llm generate_variant_forall: STORAGE: - VALUE: texture3d - VALUE: buffer - DTYPE: - - VALUE: float - - VALUE: half + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [float, float] + suffix: float + - parameter_values: [half, half] + suffix: half shader_variants: - NAME: sdpa_attn_weights_softmax + - NAME: fused_sdpa_softmax + MODE: fused + IN_DTYPE: float + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + OUT_DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index e50ca0612fd..b7f14f435fa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -15,9 +15,13 @@ $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER #define INPUT_BUFFER + #define ATTN_WEIGHTS_BUFFER $if K_CACHE_STORAGE == "buffer": #define K_CACHE_BUFFER +#define Q_LAYOUT DHSB +#define K_LAYOUT DHSB + #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -34,11 +38,11 @@ layout(std430) buffer; #include "common.glslh" ${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "k_sizes")} ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -75,18 +79,20 @@ void main() { // 1. 
const int s = 0; + // head dimension + const int D = q_sizes.x; // texel size of head_dim, over which the dot product is accumulated - const int D4 = div_up_4(q_projected_sizes.x); + const int D4 = div_up_4(D); // number of Q heads - const int Q_H = q_projected_sizes.y; + const int Q_H = q_sizes.y; // sequence length - const int S = q_projected_sizes.z; + const int S = q_sizes.z; const int S_aligned = align_up_4(S); // number of K/V heads - const int KV_H = k_cache_sizes.y; + const int KV_H = k_sizes.y; // Max context length - const int C = k_cache_sizes.z; + const int C = k_sizes.z; const int C4 = div_up_4(C); int kv_h = q_h; @@ -126,8 +132,9 @@ void main() { s, q_h, D4, - Q_H, - S); + D, + S, + Q_H); load_k_cache_tile_with_checks( w_tile, @@ -135,6 +142,7 @@ void main() { c, kv_h, D4, + D, context_len, C, KV_H); diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index a6703437c41..66d2ebb24c1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -9,8 +9,14 @@ #version 450 core #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} -#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +#define IN_DTYPE ${IN_DTYPE} +#define OUT_DTYPE ${OUT_DTYPE} + +#define VEC4_T ${texel_load_type(IN_DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(IN_DTYPE, IO_STORAGE)} + +#define LINEAR_FP_OUTPUT_TILE_VEC4_T ${texel_load_type(OUT_DTYPE, IO_STORAGE)} $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER @@ -18,6 +24,19 @@ $if IO_STORAGE == "buffer": $if K_CACHE_STORAGE == "buffer": #define K_CACHE_BUFFER +$if MODE == "llm": + #define HAS_INPUT_POS + #define HAS_GQA + #define Q_LAYOUT DHSB + #define K_LAYOUT DHSB +$else: + #define SDPA_PAD_D + #define Q_LAYOUT DSHB + #define K_LAYOUT DSHB + +$if HAS_BIAS: + #define HAS_BIAS + #define TILE_M4 ${TILE_M4} #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -26,19 +45,24 @@ $if K_CACHE_STORAGE == "buffer": #define TILE_K ${TILE_K4 * 4} #define TILE_N ${TILE_N4 * 4} -${define_required_extensions(IO_STORAGE, DTYPE)} +${define_required_extensions(IO_STORAGE, [IN_DTYPE, OUT_DTYPE])} layout(std430) buffer; #include "common.glslh" -${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_attn_weights", OUT_DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q", IN_DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k", IN_DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", IN_DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "k_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -50,33 +74,28 @@ ${layout_declare_spec_const(C, "float", 
"inv_scale", "1.0")} #include "sdpa_fp_attn_weight_tile_store.glslh" /* - * Compute attention weights given the q_projected and k_cache tensors. - * q_projected has shape (batches, seq_len, num_q_heads, head_dim) - * k_cache has shape (batches, max_context_len, num_kv_heads, head_dim) - * output has shape (batches, num_q_heads, seq_len, context_len) - * - * This shader also applies scales and masking to the computed attention - * weights. + * Compute attention weights (Q @ K^T) given the Q and K tensors. * - * The scale applied is 1.0 / sqrt(head_dim_length). + * LLM SDPA (HAS_INPUT_POS, HAS_GQA): + * q: [B, S, H_q, D] (DHSB layout) + * k (k_cache): [B, C_max, H_kv, D] (DHSB layout) + * attn_weights: [B, H_q, S, context_len] in input dtype + * current context_len = input_pos + S + * Applies combined scale + causal mask. * - * The mask applied is a bit more complicated. Imagine you create a square - * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the - * lower triangular section of the matrix to -inf. Then, slice the matrix along - * the row dimension starting from input_pos to input_pos + seq_len. You end up - * with a partial mask with size (seq_len, input_pos + seq_len). This is the - * mask that is applied to the attention weight. - * - * In the shader, instead of generating the mask, the index of the elment is - * inspected to determine if it would have been masked. Given an element at - * tensor index (n, c, h, w), it would be masked if w < h + input_pos. + * Fused SDPA: + * q: [B, H, S, D] (DSHB layout) + * k: [B, H, L, D] (DSHB layout) + * attn_weights: [B, H, S, L] in fp32 to prevent fp16 overflow in Q@K^T + * Applies scalar scale, optionally adds bias. * + * Dispatch: (context_tiles, S_tiles, H * B) — for LLM (batch=1), H * B == H_q. */ void main() { const int tile_idx_x = int(gl_GlobalInvocationID.x); const int tile_idx_y = int(gl_GlobalInvocationID.y); - // idx along output num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. 
const int q_h = int(gl_GlobalInvocationID.z); // idx along the output context_len dim @@ -85,32 +104,48 @@ void main() { // idx along the output seq_len dim const int s = tile_idx_y * TILE_M; - const int s4 = div_4(s); - - // texel size of head_dim, over which the dot product is accumulated - const int D4 = div_up_4(q_projected_sizes.x); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; + +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int D = q_sizes.x; + const int Q_H = q_sizes.y; + const int S = q_sizes.z; + // k_sizes is WHCN {D, H_kv, C_max, B} + const int KV_H = k_sizes.y; + const int C = k_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int D = q_sizes.x; + const int S = q_sizes.y; + const int Q_H = q_sizes.z; + // k_sizes is WHCN {D, L, H, B} + const int KV_H = k_sizes.z; + const int C = k_sizes.y; +#endif + const int D4 = div_up_4(D); const int S_aligned = align_up_4(S); - // number of K/V heads - const int KV_H = k_cache_sizes.y; - // Max context length - const int C = k_cache_sizes.z; - const int C4 = div_up_4(C); +#ifdef HAS_INPUT_POS + // current context length for LLM decode/prefill + const int context_len = input_pos + S; +#else + // fused: full key sequence length from k_sizes + const int context_len = k_sizes.y; +#endif + const int context_texel_len = div_up_4(context_len); +#ifdef HAS_GQA int kv_h = q_h; if (KV_H < Q_H) { kv_h = q_h / (Q_H / KV_H); } +#else + const int kv_h = q_h; +#endif - const int context_len = input_pos + S; - const int context_texel_len = div_up_4(context_len); - - // bounds check - if (c >= context_len || s >= S || q_h >= Q_H) { + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (c >= context_len || s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -120,6 +155,16 @@ void main() { FPInputTile q_tile; FPWeightTile w_tile; + // The LLM attn_weights tensor is padded to S_aligned in its S dim, while + // fused attn_weights is not padded. The store/bias helpers bound-check + // against this. +#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + +#ifdef HAS_INPUT_POS // If the tile is completely inside the mask region, then there is no need to // compute the output tile. All the elements in the output tile can be set to // negative infinity. 
@@ -127,9 +172,9 @@ void main() { if (tile_in_mask_region) { const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); set_out_tile_to_vec(out_tile, negative_infinity_vec); - } - // Otherwise, need to actually compute output tile - else { + } else +#endif + { for (int d4 = 0; d4 < D4; d4++) { load_q_projected_tile_with_checks( q_tile, @@ -137,8 +182,9 @@ void main() { s, q_h, D4, - Q_H, - S); + D, + S, + Q_H); load_k_cache_tile_with_checks( w_tile, @@ -146,15 +192,16 @@ void main() { c, kv_h, D4, + D, context_len, C, KV_H); - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } - // Apply scale and mask +#ifdef HAS_INPUT_POS + // LLM: combined scale + causal mask VEC4_T inv_scale_vec = VEC4_T(inv_scale); apply_scale_and_mask( out_tile, @@ -162,6 +209,13 @@ void main() { input_pos, c, s); +#else + // Fused: scalar scale, optional bias + apply_scale(out_tile, inv_scale); + #ifdef HAS_BIAS + apply_bias(out_tile, c4, s, q_h, context_texel_len, attn_S); + #endif +#endif } store_attn_weight_tile_with_checks( @@ -170,6 +224,6 @@ void main() { s, q_h, context_texel_len, - S_aligned, + attn_S, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 7fc016cf3c3..24494b408fa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -6,9 +6,12 @@ sdpa_compute_attn_weights_tiled: parameter_names_with_default_values: - DTYPE: float + IN_DTYPE: float + OUT_DTYPE: float IO_STORAGE: texture3d K_CACHE_STORAGE: texture3d + MODE: llm + HAS_BIAS: false TILE_M4: 1 TILE_K4: 1 TILE_N4: 1 @@ -19,8 +22,39 @@ sdpa_compute_attn_weights_tiled: - parameter_values: [texture3d, texture3d] - parameter_values: [buffer, texture3d] - parameter_values: [buffer, buffer] - DTYPE: - - VALUE: float - - VALUE: half + combination1: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [float, float] + suffix: float + - parameter_values: [half, half] + suffix: half shader_variants: - NAME: sdpa_compute_attn_weights_tiled + - NAME: fused_sdpa_qk_tiled + MODE: fused + OUT_DTYPE: float + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + IN_DTYPE: + - VALUE: float + - VALUE: half + - NAME: fused_sdpa_qk_tiled_bias + MODE: fused + OUT_DTYPE: float + HAS_BIAS: true + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + IN_DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 2e5fda18e14..cd2c689ebc8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -15,9 +15,14 @@ $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER #define INPUT_BUFFER + #define ATTN_WEIGHTS_BUFFER $if V_CACHE_STORAGE == "buffer": #define V_CACHE_BUFFER +#define V_LAYOUT DHSB +#define OUT_LAYOUT DHSB +#define SDPA_V_BUF t_v_cache + #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -37,8 +42,8 @@ 
${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=F ${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "v_sizes")} ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -76,17 +81,17 @@ void main() { const int s = 0; // texel size of head_dim - const int D4 = div_up_4(q_projected_sizes.x); + const int D4 = div_up_4(q_sizes.x); // number of Q heads - const int Q_H = q_projected_sizes.y; + const int Q_H = q_sizes.y; // sequence length - const int S = q_projected_sizes.z; + const int S = q_sizes.z; const int S_aligned = align_up_4(S); // number of K/V heads - const int KV_H = v_cache_sizes.y; + const int KV_H = v_sizes.y; // Max context length - const int C = v_cache_sizes.z; + const int C = v_sizes.z; const int C4 = div_up_4(C); int kv_h = q_h; diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 2027a9908a9..9f8f2dbc231 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -18,6 +18,15 @@ $if IO_STORAGE == "buffer": $if V_CACHE_STORAGE == "buffer": #define V_CACHE_BUFFER +$if MODE == "llm": + #define HAS_INPUT_POS + #define HAS_GQA + #define V_LAYOUT DHSB + #define OUT_LAYOUT DHSB +$else: + #define V_LAYOUT DSHB + #define OUT_LAYOUT DSHB + #define TILE_M4 ${TILE_M4} // Equvalent to K4 in matrix multiplication #define TILE_K4 ${TILE_K4} @@ -36,11 +45,12 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "v_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -50,16 +60,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "sdpa_fp_out_tile_store.glslh" /* - * Compute SDPA output given the attention weights and v_cache tensors. - * attention weights has shape (batches, num_q_heads, seq_len, context_len) - * v_cache has shape (batches, max_context_len, num_kv_heads, head_dim) - * output has shape (batches, seq_len, num_q_heads, head_dim) + * Compute SDPA output given the attention weights and V tensors. 
+ * + * LLM SDPA (HAS_INPUT_POS, HAS_GQA): + * attn_weights: [B, H_q, S, context_len] + * v (v_cache): [B, C_max, H_kv, D] (DHSB layout) + * output: [B, S, H_q, D] (DHSB layout) + * current context_len = input_pos + S + * GQA: Q heads may be > KV heads; kv_h = q_h / (H_q / H_kv) + * + * Fused SDPA: + * attn_weights: [B, H, S, context_len] + * v: [B, H, context_len, D] (DSHB layout) + * output: [B, H, S, D] (DSHB layout) + * + * Dispatch: (D_tiles, S_tiles, H * B) — for LLM (batch=1), H * B == H_q. */ void main() { const int tile_idx_x = int(gl_GlobalInvocationID.x); const int tile_idx_y = int(gl_GlobalInvocationID.y); - // idx along output num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. const int q_h = int(gl_GlobalInvocationID.z); // idx along the output head_dim dim @@ -69,31 +90,47 @@ void main() { // idx along the output seq_len dim const int s = tile_idx_y * TILE_M; - // texel size of head_dim - const int D4 = div_up_4(q_projected_sizes.x); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int D = q_sizes.x; + const int Q_H = q_sizes.y; + const int S = q_sizes.z; + // v_sizes is WHCN {D, H_kv, C_max, B} + const int KV_H = v_sizes.y; + const int C = v_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int D = q_sizes.x; + const int S = q_sizes.y; + const int Q_H = q_sizes.z; + // v_sizes is WHCN {D, context_len, H, B} + const int KV_H = v_sizes.z; + const int C = v_sizes.y; +#endif + const int D4 = div_up_4(D); const int S_aligned = align_up_4(S); - // number of K/V heads - const int KV_H = v_cache_sizes.y; - // Max context length - const int C = v_cache_sizes.z; - const int C4 = div_up_4(C); +#ifdef HAS_INPUT_POS + // current context length for LLM decode/prefill + const int context_len = input_pos + S; +#else + // fused: full key sequence length from v_sizes (DSHB: {D, L, H, B}) + const int context_len = v_sizes.y; +#endif + const int context_texel_len = div_up_4(context_len); +#ifdef HAS_GQA int kv_h = q_h; if (KV_H < Q_H) { kv_h = q_h / (Q_H / KV_H); } +#else + const int kv_h = q_h; +#endif - // current context length - const int context_len = input_pos + S; - const int context_texel_len = div_up_4(context_len); - - // bounds check - if (d4 >= D4 || s >= S || q_h >= Q_H) { + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (d4 >= D4 || s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -103,62 +140,33 @@ void main() { FPInputTile attn_weight_tile; FPWeightTile w_tile; + // For LLM, the attn_weights tensor has seq_len padded up to a multiple of 4 + // (S_aligned). The loader accesses (head * attn_S * C4 + s * C4 + c4), so + // pass S_aligned in LLM mode and S in fused mode. 
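+  // For example, S = 5 gives attn_S = 8 in LLM mode (align_up_4) and
+  // attn_S = 5 in fused mode.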
+#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + + // Split loop into aligned + tail for efficiency const int context_len_aligned_down = context_len - mod_4(context_len); const int C4_limit = div_4(context_len_aligned_down); for (int c4 = 0; c4 < C4_limit; c4++) { const int c = mul_4(c4); load_attn_weight_tile_no_checks( - attn_weight_tile, - c4, - s, - q_h, - context_texel_len, - S_aligned, - Q_H); - - load_v_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - + attn_weight_tile, c4, s, q_h, context_texel_len, attn_S, Q_H); + load_v_cache_tile_no_checks(w_tile, d4, c, kv_h, D4, context_len, C, KV_H); fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); } for (int c4 = C4_limit; c4 < context_texel_len; c4++) { const int c = mul_4(c4); load_attn_weight_tile_with_checks( - attn_weight_tile, - c4, - s, - q_h, - context_texel_len, - S_aligned, - Q_H); - - load_v_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - + attn_weight_tile, c4, s, q_h, context_texel_len, attn_S, Q_H); + load_v_cache_tile_with_checks(w_tile, d4, c, kv_h, D4, context_len, C, KV_H); fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); } - store_sdpa_out_tile_with_checks( - out_tile, - d4, - s, - q_h, - D4, - S, - Q_H); + store_sdpa_out_tile_with_checks(out_tile, d4, s, q_h, D4, S, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index eac2c6f37dd..ba91114ae92 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -9,6 +9,7 @@ sdpa_compute_out_tiled: DTYPE: float IO_STORAGE: texture3d V_CACHE_STORAGE: texture3d + MODE: llm TILE_M4: 1 TILE_K4: 1 TILE_N4: 1 @@ -24,3 +25,15 @@ sdpa_compute_out_tiled: - VALUE: half shader_variants: - NAME: sdpa_compute_out_tiled + - NAME: fused_sdpa_av_tiled + MODE: fused + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh index 12b2292fa45..829f03beb60 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh @@ -7,11 +7,19 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared attention weight tile load for both LLM SDPA and fused SDPA + * (used in the AV shader to read softmax output). * - * Macro Settings: - * - INPUT_BUFFER + * The attn_weights tensor layout is [head, S, L] in both cases: + * index = (head * S * L4) + (s * L4) + l4 + * + * No layout switch needed — both variants use the same index formula. + * + * Optional macros: + * INPUT_BUFFER — use buffer path; otherwise texture. Set at the shader + * level when IO_STORAGE == "buffer" and applies to all + * IO tensors uniformly (attn_weights now follows the + * output's storage type). 
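+ *
+ * For example, head = 2, S = 8, L4 = 4, s = 3 gives
+ * index = (2 * 8 * 4) + (3 * 4) + l4 = 76 + l4.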
*/ #ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH @@ -21,7 +29,7 @@ #include "linear_fp_input_tile.glslh" -VEC4_T load_attn_weight_c4( +LINEAR_FP_INPUT_TILE_VEC4_T load_attn_weight_c4( const int c4, const int s, const int q_h, @@ -44,7 +52,7 @@ void load_attn_weight_tile_no_checks( const int S, const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { - [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + [[unroll]] for (int c4 = 0; c4 < TILE_K4; ++c4) { tile.data[s][c4] = load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); } @@ -60,12 +68,12 @@ void load_attn_weight_tile_with_checks( const int S, const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { - [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + [[unroll]] for (int c4 = 0; c4 < TILE_K4; ++c4) { if (c4_start + c4 < C4 && s_start + s < S) { tile.data[s][c4] = load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); } else { - tile.data[s][c4] = VEC4_T(0.0); + tile.data[s][c4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh index c64d9af8cfb..c28bc2f3ac2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh @@ -7,24 +7,39 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared attention weight tile store for both LLM SDPA and fused SDPA. * - * Macro Settings: - * - OUTPUT_BUFFER + * The attn_weights tensor layout is [head, S, L] in both cases: + * index = (head * S * L4) + (s * L4) + l4 + * + * Tile precision is controlled by the caller via LINEAR_FP_OUTPUT_TILE_VEC4_T + * (fused SDPA sets it to vec4 for fp32 accumulation; LLM SDPA leaves it as + * VEC4_T for input-dtype accumulation). All helper functions below are + * available at all times and compile correctly in both modes. The only + * gated helper is apply_bias, which requires t_bias/bias_sizes and is + * therefore guarded by HAS_BIAS. + * + * Required macros/variables: + * t_attn_weights — output buffer/texture + * OUTPUT_BUFFER — buffer mode (otherwise texture). Set at the shader + * level when IO_STORAGE == "buffer" and applies to all + * IO tensors uniformly (attn_weights now follows the + * output's storage type). 
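+ *
+ * (The LLM path also uses set_out_tile_to_vec below to fill tiles that lie
+ * entirely inside the causal-mask region with -inf, skipping the matmul.)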
*/ -#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH -#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH #extension GL_EXT_control_flow_attributes : require #include "linear_fp_output_tile.glslh" -T negative_infinity_val = T(-1.0 / 0.0); +// ============================================================ +// Shared store helpers (buffer or texture; fp32 or input-dtype) +// ============================================================ void store_attn_weight_c4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int c4, const int s, const int q_h, @@ -72,7 +87,60 @@ void store_attn_weight_tile_with_checks( } } -void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { +// ============================================================ +// Tile transform helpers (scale, bias, mask, set) +// ============================================================ + +T negative_infinity_val = T(-1.0 / 0.0); + +void apply_scale(inout FPOutTile tile, const float scale) { + const LINEAR_FP_OUTPUT_TILE_VEC4_T scale_vec = + LINEAR_FP_OUTPUT_TILE_VEC4_T(scale); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scale_vec; + } + } +} + +#ifdef HAS_BIAS +void apply_bias( + inout FPOutTile tile, + const int c4_start, + const int s_start, + const int bh, + const int C4, + const int S) { + const int bias_C4 = div_up_4(bias_sizes.x); + const int bias_S = bias_sizes.y; + const int bias_BH = bias_sizes.z * bias_sizes.w; + + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + const int bias_bh = bh < bias_BH ? bh : 0; + const int bias_s = (s_start + s) < bias_S ? (s_start + s) : 0; + const int bias_c4 = c4_start + c4; + if (bias_c4 < bias_C4) { +#ifdef INPUT_BUFFER + const LINEAR_FP_OUTPUT_TILE_VEC4_T bias_val = + LINEAR_FP_OUTPUT_TILE_VEC4_T(VEC4_T( + t_bias[(bias_bh * bias_S * bias_C4) + (bias_s * bias_C4) + + bias_c4])); +#else + const LINEAR_FP_OUTPUT_TILE_VEC4_T bias_val = + LINEAR_FP_OUTPUT_TILE_VEC4_T(VEC4_T( + texelFetch(t_bias, ivec3(bias_c4, bias_s, bias_bh), 0))); +#endif + tile.data[s][c4] = tile.data[s][c4] + bias_val; + } + } + } + } +} +#endif // HAS_BIAS + +void set_out_tile_to_vec(out FPOutTile tile, const LINEAR_FP_OUTPUT_TILE_VEC4_T vec) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { tile.data[s][c4] = vec; } } @@ -80,7 +148,7 @@ void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { void apply_scale_and_mask( inout FPOutTile tile, - const VEC4_T inv_scale_vec, + const LINEAR_FP_OUTPUT_TILE_VEC4_T inv_scale_vec, const int input_pos, const int c_idx_start, const int s_idx_start) { @@ -102,4 +170,4 @@ void apply_scale_and_mask( } } -#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#endif // SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh index 1880397181d..65d08755528 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -7,32 +7,74 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_k_cache + * Shared K transposed tile load for both LLM SDPA and fused SDPA. 
+ * Loads K[head, l, d4] and transposes in-place so the tile represents K^T. * - * Macro Settings: - * - K_CACHE_BUFFER + * Layout selection (caller must #define K_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, L, D] → WHCN {D, L, H, B} + * index = (head * L + l) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, L, H, D] → WHCN {D, H, L, B} + * index = (l * H + head) * D4 + d4 + * + * Optional macros: + * K_CACHE_BUFFER / K_BUFFER — use buffer path; otherwise texture + * SDPA_PAD_D — zero-pad last d4 texel when D % 4 != 0 */ #ifndef SDPA_FP_K_CACHE_TILE_LOAD_GLSLH #define SDPA_FP_K_CACHE_TILE_LOAD_GLSLH +#ifndef DSHB +#define DSHB 0 +#define DHSB 1 +#endif + #extension GL_EXT_control_flow_attributes : require #include "linear_fp_weight_tile.glslh" -VEC4_T load_k_cache_d4( +// Determine whether buffer mode is active. Both K_CACHE_BUFFER (LLM) and +// K_BUFFER (fused) activate the buffer path. +#if defined(K_CACHE_BUFFER) || defined(K_BUFFER) +#define _SDPA_K_USE_BUFFER +#endif + +LINEAR_FP_WEIGHT_TILE_VEC4_T load_k_cache_d4( const int d4, const int c, const int kv_h, const int D4, + const int D, const int C, const int KV_H) { -#ifdef K_CACHE_BUFFER - return VEC4_T(t_k_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); -#else - return VEC4_T(texelFetch(t_k_cache, ivec3(d4, kv_h, c), 0)); + LINEAR_FP_WEIGHT_TILE_VEC4_T val; + +#ifdef _SDPA_K_USE_BUFFER + #if K_LAYOUT == DSHB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(t_k[(kv_h * C * D4) + (c * D4) + d4]); + #elif K_LAYOUT == DHSB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(t_k[(c * KV_H * D4) + (kv_h * D4) + d4]); + #endif +#else // texture + #if K_LAYOUT == DSHB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_k, ivec3(d4, c, kv_h), 0)); + #elif K_LAYOUT == DHSB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_k, ivec3(d4, kv_h, c), 0)); + #endif #endif + +#ifdef SDPA_PAD_D + if (d4 == D4 - 1) { + const int valid = D - mul_4(d4); + [[unroll]] for (int i = 0; i < 4; ++i) { + if (i >= valid) { + val[i] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0)[i]; + } + } + } +#endif + + return val; } void load_k_cache_tile_no_checks( @@ -41,6 +83,7 @@ void load_k_cache_tile_no_checks( const int c_start, const int kv_h, const int D4, + const int D, const int context_len, const int C, const int KV_H) { @@ -48,8 +91,8 @@ void load_k_cache_tile_no_checks( const int c4 = div_4(c); const int c4i = mod_4(c); [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { - VEC4_T d4_row = - load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + LINEAR_FP_WEIGHT_TILE_VEC4_T d4_row = + load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, D, C, KV_H); // Transpose in-place const int d_base = mul_4(d4); @@ -67,6 +110,7 @@ void load_k_cache_tile_with_checks( const int c_start, const int kv_h, const int D4, + const int D, const int context_len, const int C, const int KV_H) { @@ -74,9 +118,9 @@ void load_k_cache_tile_with_checks( const int c4 = div_4(c); const int c4i = mod_4(c); [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { - VEC4_T d4_row = VEC4_T(0.0); + LINEAR_FP_WEIGHT_TILE_VEC4_T d4_row = LINEAR_FP_WEIGHT_TILE_VEC4_T(0.0); if (d4_start + d4 < D4 && c_start + c < context_len) { - d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, D, C, KV_H); } // Transpose in-place diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh index 17e0988a6a4..382747bf9a4 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh @@ -7,22 +7,33 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared output tile store for both LLM SDPA and fused SDPA. * - * Macro Settings: - * - OUTPUT_BUFFER + * Layout selection (caller must #define OUT_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, S, D] → WHCN {D, S, H, B} + * index = (head * S + s) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, S, H, D] → WHCN {D, H, S, B} + * index = (s * H + head) * D4 + d4 + * + * Required macros/variables: + * t_output — output tensor binding + * OUTPUT_BUFFER — buffer mode (otherwise texture) */ -#ifndef SDPA_FP_OUT_TILE_LOAD_GLSLH -#define SDPA_FP_OUT_TILE_LOAD_GLSLH +#ifndef SDPA_FP_OUT_TILE_STORE_GLSLH +#define SDPA_FP_OUT_TILE_STORE_GLSLH + +#ifndef DSHB +#define DSHB 0 +#define DHSB 1 +#endif #extension GL_EXT_control_flow_attributes : require #include "linear_fp_output_tile.glslh" void store_out_d4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int d4, const int q_h, const int s, @@ -30,9 +41,17 @@ void store_out_d4( const int Q_H, const int S) { #ifdef OUTPUT_BUFFER - t_output[(s * Q_H * D4) + (q_h * D4) + d4] = out_texel; -#else - imageStore(t_output, ivec3(d4, q_h, s), out_texel); + #if OUT_LAYOUT == DSHB + t_output[(q_h * S * D4) + (s * D4) + d4] = VEC4_T(out_texel); + #elif OUT_LAYOUT == DHSB + t_output[(s * Q_H * D4) + (q_h * D4) + d4] = VEC4_T(out_texel); + #endif +#else // texture + #if OUT_LAYOUT == DSHB + imageStore(t_output, ivec3(d4, s, q_h), VEC4_T(out_texel)); + #elif OUT_LAYOUT == DHSB + imageStore(t_output, ivec3(d4, q_h, s), VEC4_T(out_texel)); + #endif #endif } @@ -54,4 +73,4 @@ void store_sdpa_out_tile_with_checks( } } -#endif // SDPA_FP_OUT_TILE_LOAD_GLSLH +#endif // SDPA_FP_OUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh index a304e5019e9..752762b623d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh @@ -7,32 +7,65 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_input + * Shared Q tile load for both LLM SDPA and fused SDPA. 
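+ * Tile rows are consecutive sequence positions; the tile feeds the input
+ * (M) side of the Q @ K^T matmul.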
* - * Macro Settings: - * - INPUT_BUFFER + * Layout selection (caller must #define Q_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, S, D] → WHCN {D, S, H, B} + * index = (head * S + s) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, S, H, D] → WHCN {D, H, S, B} + * index = (s * H + head) * D4 + d4 + * + * Optional macros: + * INPUT_BUFFER — use buffer path; otherwise texture + * SDPA_PAD_D — zero-pad last d4 texel when D % 4 != 0 */ #ifndef SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH #define SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH +#define DSHB 0 +#define DHSB 1 + #extension GL_EXT_control_flow_attributes : require #include "linear_fp_input_tile.glslh" -VEC4_T load_q_projected_d4( +LINEAR_FP_INPUT_TILE_VEC4_T load_q_projected_d4( const int d4, - const int q_h, const int s, + const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { + LINEAR_FP_INPUT_TILE_VEC4_T val; + #ifdef INPUT_BUFFER - return t_q_projected[(s * Q_H * D4) + (q_h * D4) + d4]; -#else - return texelFetch(t_q_projected, ivec3(d4, q_h, s), 0); + #if Q_LAYOUT == DSHB + val = LINEAR_FP_INPUT_TILE_VEC4_T(t_q[(q_h * S * D4) + (s * D4) + d4]); + #elif Q_LAYOUT == DHSB + val = LINEAR_FP_INPUT_TILE_VEC4_T(t_q[(s * Q_H * D4) + (q_h * D4) + d4]); + #endif +#else // texture + #if Q_LAYOUT == DSHB + val = LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_q, ivec3(d4, s, q_h), 0)); + #elif Q_LAYOUT == DHSB + val = LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_q, ivec3(d4, q_h, s), 0)); + #endif #endif + +#ifdef SDPA_PAD_D + if (d4 == D4 - 1) { + const int valid = D - mul_4(d4); + [[unroll]] for (int i = 0; i < 4; ++i) { + if (i >= valid) { + val[i] = T(0); + } + } + } +#endif + + return val; } void load_q_projected_tile_no_checks( @@ -41,12 +74,13 @@ void load_q_projected_tile_no_checks( const int s_start, const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { tile.data[s][d4] = - load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + load_q_projected_d4(d4_start + d4, s_start + s, q_h, D4, D, S, Q_H); } } } @@ -57,15 +91,16 @@ void load_q_projected_tile_with_checks( const int s_start, const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { if (d4_start + d4 < D4 && s_start + s < S) { tile.data[s][d4] = - load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + load_q_projected_d4(d4_start + d4, s_start + s, q_h, D4, D, S, Q_H); } else { - tile.data[s][d4] = VEC4_T(0.0); + tile.data[s][d4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh index bf94b251c43..98516744b44 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh @@ -7,31 +7,60 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_v_cache + * Shared V tile load for both LLM SDPA and fused SDPA (no transpose). 
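+ * V rows are consumed directly as the weight (N) side of attn_weights @ V,
+ * which is why no transpose is needed.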
  *
- * Macro Settings:
- * - V_CACHE_BUFFER
+ * Layout selection (caller must #define V_LAYOUT before including):
+ *   DSHB (0) — fused SDPA: [B, H, L, D] → WHCN {D, L, H, B}
+ *     index = (head * L + l) * D4 + d4
+ *   DHSB (1) — LLM SDPA: [B, L, H, D] → WHCN {D, H, L, B}
+ *     index = (l * H + head) * D4 + d4
+ *
+ * Optional macros:
+ *   SDPA_V_BUF — tensor name (default: t_v)
+ *   V_CACHE_BUFFER / V_BUFFER — use buffer path; otherwise texture
  */
 
 #ifndef SDPA_FP_V_CACHE_TILE_LOAD_GLSLH
 #define SDPA_FP_V_CACHE_TILE_LOAD_GLSLH
 
+#ifndef DSHB
+#define DSHB 0
+#define DHSB 1
+#endif
+
 #extension GL_EXT_control_flow_attributes : require
 
 #include "linear_fp_weight_tile.glslh"
 
-VEC4_T load_v_cache_d4(
+#ifndef SDPA_V_BUF
+#define SDPA_V_BUF t_v
+#endif
+
+// Determine whether buffer mode is active. Both V_CACHE_BUFFER (LLM) and
+// V_BUFFER (fused) activate the buffer path.
+#if defined(V_CACHE_BUFFER) || defined(V_BUFFER)
+#define _SDPA_V_USE_BUFFER
+#endif
+
+LINEAR_FP_WEIGHT_TILE_VEC4_T load_v_cache_d4(
     const int d4,
     const int c,
     const int kv_h,
     const int D4,
     const int C,
     const int KV_H) {
-#ifdef V_CACHE_BUFFER
-  return VEC4_T(t_v_cache[(c * KV_H * D4) + (kv_h * D4) + d4]);
-#else
-  return VEC4_T(texelFetch(t_v_cache, ivec3(d4, kv_h, c), 0));
+#ifdef _SDPA_V_USE_BUFFER
+  #if V_LAYOUT == DSHB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(SDPA_V_BUF[(kv_h * C * D4) + (c * D4) + d4]);
+  #elif V_LAYOUT == DHSB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(SDPA_V_BUF[(c * KV_H * D4) + (kv_h * D4) + d4]);
+  #endif
+#else // texture
+  #if V_LAYOUT == DSHB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(SDPA_V_BUF, ivec3(d4, c, kv_h), 0));
+  #elif V_LAYOUT == DHSB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(SDPA_V_BUF, ivec3(d4, kv_h, c), 0));
+  #endif
 #endif
 }
 
@@ -44,8 +73,8 @@ void load_v_cache_tile_no_checks(
     const int context_len,
     const int C,
     const int KV_H) {
-  [[unroll]] for (int c = 0; c < TILE_N; ++c) {
-    [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) {
+  [[unroll]] for (int c = 0; c < TILE_K; ++c) {
+    [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) {
       tile.data[c][d4] =
           load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H);
     }
@@ -61,13 +90,13 @@ void load_v_cache_tile_with_checks(
     const int context_len,
     const int C,
     const int KV_H) {
-  [[unroll]] for (int c = 0; c < TILE_N; ++c) {
-    [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) {
+  [[unroll]] for (int c = 0; c < TILE_K; ++c) {
+    [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) {
       if (d4_start + d4 < D4 && c_start + c < context_len) {
         tile.data[c][d4] =
             load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H);
       } else {
-        tile.data[c][d4] = VEC4_T(0.0);
+        tile.data[c][d4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0.0);
       }
     }
   }
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 65dae5e25cb..3eae85cd8aa 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -21,8 +21,67 @@
 #include 
 #include 
 
+#include 
+
 namespace vkcompute {
 
+//
+// SDPA mode: distinguishes the two dispatch families sharing this file.
+// LLM   — Llama-style KV-cache SDPA. Q layout [B=1, S, H, D] (DHSB).
+//         Separate k_cache/v_cache inputs + input_pos_symint for dynamic
+//         context_len. attn_weights are padded to multiples of 4 in the
+//         S/context_len dims and carry the input dtype. A coop (GEMV)
+//         shader variant is selected for single-token decode.
+// FUSED — General SDPA fused op. Q layout [B, H, S, D] (DSHB). No cache,
+//         optional additive attn_mask, optional scale arg. attn_weights
+//         are unpadded and always fp32. Tiled shader variant only.
+//
+enum class SDPAMode { LLM, FUSED };
+
+//
+// Common dimension helper: folds the axis-swap for LLM vs fused Q layouts.
+// `input_pos_symint` is used only for LLM (context_len = S + input_pos);
+// pass kDummyValueRef for FUSED.
+//
+struct SDPADims {
+  int64_t B = 1;
+  int64_t H = 0;
+  int64_t S = 0;
+  int64_t D = 0;
+  int64_t context_len = 0; // LLM: S + input_pos_val; FUSED: size_at(-2, k)
+  int64_t max_context_len = 0; // LLM: size_at(-3, k); FUSED: size_at(-2, k)
+};
+
+SDPADims compute_sdpa_dims(
+    ComputeGraph& graph,
+    const ValueRef q,
+    const ValueRef k,
+    const ValueRef input_pos_symint,
+    const SDPAMode mode) {
+  SDPADims d;
+  d.D = graph.size_at<int64_t>(-1, q);
+  if (mode == SDPAMode::LLM) {
+    // Q: [B=1, S, H, D] (DHSB), K: [B=1, C_max, H_kv, D]
+    // `k` may be kDummyValueRef in dispatch pickers that don't need it;
+    // max_context_len is only read when k is valid.
+    d.B = 1;
+    d.H = graph.size_at<int64_t>(-2, q);
+    d.S = graph.size_at<int64_t>(-3, q);
+    d.max_context_len = is_valid(k) ? graph.size_at<int64_t>(-3, k) : 0;
+    const int32_t input_pos_val =
+        is_valid(input_pos_symint) ? graph.read_symint(input_pos_symint) : 0;
+    d.context_len = d.S + input_pos_val;
+  } else {
+    // Q: [B, H, S, D] (DSHB), K: [B, H_kv, L, D]
+    d.B = graph.size_at<int64_t>(-4, q);
+    d.H = graph.size_at<int64_t>(-3, q);
+    d.S = graph.size_at<int64_t>(-2, q);
+    d.context_len = graph.size_at<int64_t>(-2, k);
+    d.max_context_len = d.context_len;
+  }
+  return d;
+}
+
 bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) {
   return graph->size_at<int64_t>(-3, q_projected) == 1;
 }
 
@@ -31,30 +90,42 @@ bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) {
 // Resize functions
 //
 
-void resize_compute_attn_weights_node(
+// Unified attn_weights resize. In LLM mode the shape is padded to multiples of
+// 4 in the S/context_len dims (to match the tiled shader's iteration space);
+// in fused mode it's the unpadded [B, H, S, L].
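+// Worked example (illustrative sizes): seq_len = 3 at input_pos = 5 gives
+// context_len = 8, so LLM mode resizes attn_weights to [1, H, 4, 8] since
+// align_up_4(3) = 4; fused mode with S = 3, L = 8 keeps the exact [B, H, 3, 8].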
+// resize_args layout: [q, k, input_pos_symint_or_dummy, mode_as_int]
+void resize_sdpa_attn_weights_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
   const ValueRef attn_weights = args.at(0).refs.at(0);
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef input_pos_symint = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  const int32_t input_pos_val = graph->read_symint(input_pos_symint);
-
-  const uint32_t context_len = seq_len + input_pos_val;
-
-  std::vector<int64_t> out_sizes = {
-      1, // batch
-      num_q_heads,
-      utils::align_up_4(seq_len),
-      utils::align_up_4(context_len)};
-
+  const ValueRef q = resize_args.at(0);
+  const ValueRef k = resize_args.at(1);
+  const ValueRef input_pos_symint = resize_args.at(2);
+  const SDPAMode mode = static_cast<SDPAMode>(resize_args.at(3));
+
+  std::vector<int64_t> out_sizes;
+  if (mode == SDPAMode::LLM) {
+    const int64_t num_q_heads = graph->size_at<int64_t>(-2, q);
+    const int64_t seq_len = graph->size_at<int64_t>(-3, q);
+    const int32_t input_pos_val = graph->read_symint(input_pos_symint);
+    const int64_t context_len = seq_len + input_pos_val;
+    out_sizes = {
+        1,
+        num_q_heads,
+        static_cast<int64_t>(utils::align_up_4(seq_len)),
+        static_cast<int64_t>(utils::align_up_4(context_len))};
+  } else {
+    const int64_t B = graph->size_at<int64_t>(-4, q);
+    const int64_t H = graph->size_at<int64_t>(-3, q);
+    const int64_t S = graph->size_at<int64_t>(-2, q);
+    const int64_t L = graph->size_at<int64_t>(-2, k);
+    out_sizes = {B, H, S, L};
+  }
   graph->virtual_resize(attn_weights, out_sizes);
 }
 
+// Softmax preserves attn_weights shape exactly; identical across modes.
 void resize_sdpa_attn_weights_softmax_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
@@ -65,26 +136,15 @@ void resize_sdpa_attn_weights_softmax_node(
   graph->virtual_resize(attn_weights_softmax, graph->sizes_of(attn_weights));
 }
 
-void resize_sdpa_compute_out_node(
+// Out matches Q's shape in both modes. resize_args[0] = q.
+void resize_sdpa_out_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
   const ValueRef out = args.at(0).refs.at(0);
-  const ValueRef q_projected = resize_args.at(0);
+  const ValueRef q = resize_args.at(0);
 
-  graph->virtual_resize(out, graph->sizes_of(q_projected));
-}
-
-void resize_sdpa_out(
-    ComputeGraph* graph,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)args;
-
-  int arg_idx = 0;
-  const ValueRef q_projected = extra_args[arg_idx++];
-  const ValueRef out = extra_args[arg_idx++];
-  graph->virtual_resize(out, graph->sizes_of(q_projected));
+  graph->virtual_resize(out, graph->sizes_of(q));
 }
 
 //
@@ -108,167 +168,182 @@ utils::uvec3 kv_cache_update_global_wg_size(
   return {utils::div_up_4(head_dim_size), seq_len, num_heads};
 }
 
-utils::uvec3 attn_weight_scale_and_mask_global_wg_size(
-    ComputeGraph* graph,
-    const vkapi::ShaderInfo& shader,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& resize_args) {
-  (void)shader;
-  (void)resize_args;
-
-  const ValueRef attn_weight = args.at(0).refs.at(0);
-
-  if (graph->is_buffer_storage(attn_weight)) {
-    return {
-        graph->size_at<uint32_t>(-1, attn_weight),
-        graph->size_at<uint32_t>(-2, attn_weight),
-        graph->size_at<uint32_t>(-3, attn_weight),
-    };
-  } else {
-    return graph->logical_limits_of(attn_weight);
-  }
+// resize_args layout for SDPA dispatch pickers mirrors the node creation
+// helper: [q, k, input_pos_symint_or_dummy, mode_as_int].
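+// The mode travels through resize_args as a plain integer (LLM = 0, FUSED = 1,
+// following the enum declaration order), so every picker can recover it
+// without widening its signature.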
+static inline SDPAMode mode_of(const std::vector<ValueRef>& resize_args) {
+  return static_cast<SDPAMode>(resize_args.at(3));
 }
 
-vkapi::ShaderInfo pick_sdpa_compute_attn_weights_shader(
+vkapi::ShaderInfo pick_sdpa_qk_shader(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef k_cache = args.at(1).refs.at(1);
-
-  const bool is_gemv = is_single_token(graph, q_projected);
-
-  std::string shader_name = "sdpa_compute_attn_weights";
-  if (is_gemv) {
-    shader_name += "_coop";
+  const SDPAMode mode = mode_of(resize_args);
+  if (mode == SDPAMode::LLM) {
+    const ValueRef q_projected = args.at(1).refs.at(0);
+    const ValueRef k_cache = args.at(1).refs.at(1);
+    const bool is_gemv = is_single_token(graph, q_projected);
+
+    std::string shader_name = "sdpa_compute_attn_weights";
+    shader_name += is_gemv ? "_coop" : "_tiled";
+    add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected));
+    add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache));
+    add_dtype_suffix(shader_name, graph->dtype_of(q_projected));
+    return VK_KERNEL_FROM_STR(shader_name);
   } else {
-    shader_name += "_tiled";
+    const ValueRef q = args.at(1).refs.at(0);
+    const ValueRef k = args.at(1).refs.at(1);
+    // Fused path uses bias variant iff attn_mask was provided (signalled via
+    // 3 inputs in the read group: q, k, attn_mask).
+    const bool has_bias = args.at(1).refs.size() >= 3;
+    std::string shader_name =
+        has_bias ? "fused_sdpa_qk_tiled_bias" : "fused_sdpa_qk_tiled";
+    add_storage_type_suffix(shader_name, graph->storage_type_of(q));
+    add_storage_type_suffix(shader_name, graph->storage_type_of(k));
+    add_dtype_suffix(shader_name, graph->dtype_of(q));
+    return VK_KERNEL_FROM_STR(shader_name);
   }
-
-  add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected));
-  add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache));
-  add_dtype_suffix(shader_name, graph->dtype_of(q_projected));
-
-  return VK_KERNEL_FROM_STR(shader_name);
 }
 
-utils::uvec3 pick_sdpa_compute_attn_weights_global_wg_size(
+utils::uvec3 pick_sdpa_qk_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef input_pos_symint = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  const int32_t input_pos_val = graph->read_symint(input_pos_symint);
-
-  const uint32_t context_len = seq_len + input_pos_val;
-
-  const uint32_t N4 = utils::div_up_4(context_len);
-  const uint32_t M4 = utils::div_up_4(seq_len);
-
-  return {N4, M4, num_q_heads};
+  (void)shader;
+  const SDPAMode mode = mode_of(resize_args);
+  const ValueRef q = resize_args.at(0);
+  const ValueRef k = resize_args.at(1);
+  const ValueRef input_pos_symint = resize_args.at(2);
+  const SDPADims d = compute_sdpa_dims(*graph, q, k, input_pos_symint, mode);
+
+  // Dispatch grid: (context_len tiles, S tiles, H * B).
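+  // E.g. (illustrative sizes) context_len = 196, S = 196, H = 6, B = 1 yields
+  // a {49, 49, 6} grid, one 4x4 attn_weights tile per invocation.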
+  const uint32_t N4 = utils::div_up_4(static_cast<uint32_t>(d.context_len));
+  const uint32_t M4 = utils::div_up_4(static_cast<uint32_t>(d.S));
+  return {N4, M4, static_cast<uint32_t>(d.H * d.B)};
 }
 
-utils::uvec3 pick_sdpa_compute_attn_weights_local_wg_size(
+utils::uvec3 pick_sdpa_qk_local_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const utils::uvec3& global_workgroup_size,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const bool use_coop_algorithm =
-      shader.kernel_name.find("_coop") != std::string::npos;
-
-  if (use_coop_algorithm) {
-    return {1, 64, 1};
-  } else {
+  const SDPAMode mode = mode_of(resize_args);
+  if (mode == SDPAMode::LLM) {
+    const bool use_coop_algorithm =
+        shader.kernel_name.find("_coop") != std::string::npos;
+    if (use_coop_algorithm) {
+      return {1, 64, 1};
+    }
     return pick_hw_square_wg_size(
         graph, shader, global_workgroup_size, args, resize_args);
   }
+  return default_pick_local_wg_size(
+      graph, shader, global_workgroup_size, args, resize_args);
 }
 
-utils::uvec3 pick_sdpa_attn_weights_softmax_global_wg_size(
+utils::uvec3 pick_sdpa_softmax_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  return {1, seq_len, num_q_heads};
+  (void)shader;
+  const SDPAMode mode = mode_of(resize_args);
+  const ValueRef q = resize_args.at(0);
+  // LLM reads H from axis -2, fused from axis -3 (the same axis swap that
+  // compute_sdpa_dims applies).
+  const int64_t num_q_heads = (mode == SDPAMode::LLM)
+      ? graph->size_at<int64_t>(-2, q)
+      : graph->size_at<int64_t>(-3, q);
+  const int64_t seq_len = (mode == SDPAMode::LLM)
+      ? graph->size_at<int64_t>(-3, q)
+      : graph->size_at<int64_t>(-2, q);
+  const int64_t B =
+      (mode == SDPAMode::LLM) ? 1 : graph->size_at<int64_t>(-4, q);
+  return {
+      1,
+      static_cast<uint32_t>(seq_len),
+      static_cast<uint32_t>(num_q_heads * B)};
 }
 
-utils::uvec3 pick_sdpa_attn_weights_softmax_local_wg_size(
+utils::uvec3 pick_sdpa_softmax_local_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const utils::uvec3& global_workgroup_size,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
   return {64, 1, 1};
 }
 
-vkapi::ShaderInfo pick_sdpa_compute_out_shader(
+vkapi::ShaderInfo pick_sdpa_av_shader(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef out = args.at(0).refs.at(0);
-  const ValueRef v_cache = args.at(1).refs.at(1);
-
-  const ValueRef q_projected = resize_args.at(0);
-
-  const bool is_gemv = is_single_token(graph, q_projected);
-
-  std::string shader_name = "sdpa_compute_out";
-  if (is_gemv) {
-    shader_name += "_coop";
+  const SDPAMode mode = mode_of(resize_args);
+  if (mode == SDPAMode::LLM) {
+    const ValueRef out = args.at(0).refs.at(0);
+    const ValueRef v_cache = args.at(1).refs.at(1);
+    const ValueRef q_projected = resize_args.at(0);
+    const bool is_gemv = is_single_token(graph, q_projected);
+
+    std::string shader_name = "sdpa_compute_out";
+    shader_name += is_gemv ? "_coop" : "_tiled";
"_coop" : "_tiled"; + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(shader_name); } else { - shader_name += "_tiled"; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef v = args.at(1).refs.at(1); + std::string shader_name = "fused_sdpa_av_tiled"; + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(shader_name); } - - add_storage_type_suffix(shader_name, graph->storage_type_of(out)); - add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); - add_dtype_suffix(shader_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(shader_name); } -utils::uvec3 pick_sdpa_compute_out_global_wg_size( +utils::uvec3 pick_sdpa_av_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { - const ValueRef q_projected = resize_args.at(0); - - const uint32_t head_dim = graph->size_at(-1, q_projected); - const uint32_t num_q_heads = graph->size_at(-2, q_projected); - const uint32_t seq_len = graph->size_at(-3, q_projected); - - const uint32_t N4 = utils::div_up_4(head_dim); - const uint32_t M4 = utils::div_up_4(seq_len); - - return {N4, M4, num_q_heads}; + (void)shader; + const SDPAMode mode = mode_of(resize_args); + const ValueRef q = resize_args.at(0); + const ValueRef k = resize_args.at(1); + const ValueRef input_pos_symint = resize_args.at(2); + const SDPADims d = compute_sdpa_dims(*graph, q, k, input_pos_symint, mode); + + const uint32_t N4 = utils::div_up_4(static_cast(d.D)); + const uint32_t M4 = utils::div_up_4(static_cast(d.S)); + return {N4, M4, static_cast(d.H * d.B)}; } -utils::uvec3 pick_sdpa_compute_out_local_wg_size( +utils::uvec3 pick_sdpa_av_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - if (use_coop_algorithm) { - return {1, 64, 1}; - } else { + const SDPAMode mode = mode_of(resize_args); + if (mode == SDPAMode::LLM) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + if (use_coop_algorithm) { + return {1, 64, 1}; + } return pick_hw_square_wg_size( graph, shader, global_workgroup_size, args, resize_args); } + return default_pick_local_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } // @@ -309,59 +384,93 @@ void add_sdpa_kv_cache_update_node( nullptr)); } +// Unified QK node (attn_weights = scale * Q @ K^T [+ bias]). +// LLM: pass input_pos_symint (real symint), attn_mask = kDummyValueRef. +// FUSED: pass input_pos_symint = kDummyValueRef, attn_mask = valid ref or +// kDummyValueRef to indicate no bias. scale_val is always passed as +// a spec const; the LLM path computes it per head_dim and FUSED may +// inherit from the caller-supplied scale. 
 void add_sdpa_compute_attn_weights_node(
     ComputeGraph& graph,
-    const ValueRef q_projected,
-    const ValueRef k_cache,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef attn_weights) {
-  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
-  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
-
+    const ValueRef attn_mask,
+    const float scale_val,
+    const ValueRef attn_weights,
+    const SDPAMode mode) {
   vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.sizes_ubo(k_cache),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+      graph.sizes_ubo(q),
+      graph.sizes_ubo(k),
+  };
+  std::vector<ValueRef> read_inputs = {q, k};
+
+  if (mode == SDPAMode::LLM) {
+    param_ubos.append(graph.get_or_create_int_param_buffer(input_pos_symint));
+  } else if (is_valid(attn_mask)) {
+    param_ubos.append(graph.sizes_ubo(attn_mask));
+    read_inputs.push_back(attn_mask);
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      pick_sdpa_compute_attn_weights_shader,
-      pick_sdpa_compute_attn_weights_global_wg_size,
-      pick_sdpa_compute_attn_weights_local_wg_size,
+      pick_sdpa_qk_shader,
+      pick_sdpa_qk_global_wg_size,
+      pick_sdpa_qk_local_wg_size,
       // Inputs and Outputs
-      {{attn_weights, vkapi::kWrite}, {{q_projected, k_cache}, vkapi::kRead}},
+      {{attn_weights, vkapi::kWrite}, {read_inputs, vkapi::kRead}},
       // Shader param buffers
       param_ubos,
       // Push Constants
      {},
      // Specialization Constants
      {scale_val},
-      // Resize Args
-      {input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
-      resize_compute_attn_weights_node));
+      resize_sdpa_attn_weights_node));
 }
 
 void add_sdpa_attn_weights_softmax_node(
     ComputeGraph& graph,
     const ValueRef attn_weights,
-    const ValueRef q_projected,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef attn_weights_softmax) {
-  std::string shader_name = "sdpa_attn_weights_softmax";
-  add_storage_type_suffix(
-      shader_name, graph.storage_type_of(attn_weights_softmax));
-  add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+    const ValueRef attn_weights_softmax,
+    const SDPAMode mode) {
+  std::string shader_name;
+  if (mode == SDPAMode::LLM) {
+    shader_name = "sdpa_attn_weights_softmax";
+    add_storage_type_suffix(
+        shader_name, graph.storage_type_of(attn_weights_softmax));
+    add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+  } else {
+    shader_name = "fused_sdpa_softmax";
+    add_storage_type_suffix(
+        shader_name, graph.storage_type_of(attn_weights_softmax));
+    add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+  }
 
-  vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+  vkapi::ParamsBindList param_ubos;
+  if (mode == SDPAMode::LLM) {
+    param_ubos = {
+        graph.sizes_ubo(q),
+        graph.sizes_ubo(k),
+        graph.get_or_create_int_param_buffer(input_pos_symint)};
+  } else {
+    param_ubos = {graph.sizes_ubo(q), graph.sizes_ubo(k)};
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
      VK_KERNEL_FROM_STR(shader_name),
-      pick_sdpa_attn_weights_softmax_global_wg_size,
-      pick_sdpa_attn_weights_softmax_local_wg_size,
+      pick_sdpa_softmax_global_wg_size,
+      pick_sdpa_softmax_local_wg_size,
      // Inputs and Outputs
      {{attn_weights_softmax, vkapi::kWrite}, {attn_weights, vkapi::kRead}},
      // Shader param buffers
@@ -370,8 +479,8 @@ void add_sdpa_attn_weights_softmax_node(
      {},
      // Specialization Constants
      {},
-      // Resize Args
-      {q_projected, input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
      resize_sdpa_attn_weights_softmax_node));
 }
 
@@ -379,32 +488,41 @@ void add_sdpa_attn_weights_softmax_node(
 void add_sdpa_compute_out_node(
     ComputeGraph& graph,
     const ValueRef attn_weights_softmax,
-    const ValueRef v_cache,
-    const ValueRef q_projected,
+    const ValueRef v,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef out) {
-  vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.sizes_ubo(v_cache),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+    const ValueRef out,
+    const SDPAMode mode) {
+  vkapi::ParamsBindList param_ubos;
+  if (mode == SDPAMode::LLM) {
+    param_ubos = {
+        graph.sizes_ubo(q),
+        graph.sizes_ubo(v),
+        graph.get_or_create_int_param_buffer(input_pos_symint)};
+  } else {
+    param_ubos = {graph.sizes_ubo(q), graph.sizes_ubo(k)};
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
-      pick_sdpa_compute_out_shader,
-      pick_sdpa_compute_out_global_wg_size,
-      pick_sdpa_compute_out_local_wg_size,
+      pick_sdpa_av_shader,
+      pick_sdpa_av_global_wg_size,
+      pick_sdpa_av_local_wg_size,
      // Inputs and Outputs
-      {{out, vkapi::kWrite}, {{attn_weights_softmax, v_cache}, vkapi::kRead}},
+      {{out, vkapi::kWrite}, {{attn_weights_softmax, v}, vkapi::kRead}},
      // Shader param buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
-      // Resize Args
-      {q_projected, input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
-      resize_sdpa_compute_out_node));
+      resize_sdpa_out_node));
 }
 
 //
@@ -515,14 +633,37 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       attn_weights_storage,
       utils::kWidthPacked);
 
+  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
+  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
+
   add_sdpa_compute_attn_weights_node(
-      graph, q_projected, k_cache, input_pos_symint, attn_weights);
+      graph,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      /*attn_mask=*/kDummyValueRef,
+      scale_val,
+      attn_weights,
+      SDPAMode::LLM);
 
   add_sdpa_attn_weights_softmax_node(
-      graph, attn_weights, q_projected, input_pos_symint, attn_weights_softmax);
+      graph,
+      attn_weights,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      attn_weights_softmax,
+      SDPAMode::LLM);
 
   add_sdpa_compute_out_node(
-      graph, attn_weights_softmax, v_cache, q_projected, input_pos_symint, out);
+      graph,
+      attn_weights_softmax,
+      v_cache,
+      q_projected,
+      /*k=*/kDummyValueRef,
+      input_pos_symint,
+      out,
+      SDPAMode::LLM);
 }
 
 void sdpa_with_kv_cache_impl(
@@ -602,8 +743,121 @@ void compute_attn_weight_with_kv_cache_impl(
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
 
+  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
+  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
+
   add_sdpa_compute_attn_weights_node(
-      graph, q_projected, k_cache, input_pos_symint, out);
+      graph,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      /*attn_mask=*/kDummyValueRef,
+      scale_val,
+      out,
+      SDPAMode::LLM);
+}
+
+//
+// Fused SDPA entry point (et_vk.sdpa.default).
+//
+// Accepts pre-reshaped [B, H, S, D] tensors (DSHB) plus optional additive
+// attn_mask and optional scale scalar. No KV cache; this is the general SDPA
+// fused op used by non-LLM models.
+//
+void fused_sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q = args[arg_idx++];
+  const ValueRef k = args[arg_idx++];
+  const ValueRef v = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef scale_ref = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Validate inputs
+  VK_CHECK_COND(graph.dim_of(q) == 4);
+  VK_CHECK_COND(graph.dim_of(k) == 4);
+  VK_CHECK_COND(graph.dim_of(v) == 4);
+  // Head dim must match between Q and K
+  VK_CHECK_COND(graph.size_at<int64_t>(-1, q) == graph.size_at<int64_t>(-1, k));
+  // K and V must have same sequence length
+  VK_CHECK_COND(graph.size_at<int64_t>(-2, k) == graph.size_at<int64_t>(-2, v));
+  // All tensors must be width-packed
+  VK_CHECK_COND(graph.packed_dim_of(q) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v) == WHCN::kWidthDim);
+
+  // Compute scale
+  const int32_t head_dim = graph.size_at<int32_t>(-1, q);
+  float scale_val;
+  if (graph.val_is_none(scale_ref)) {
+    scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim));
+  } else {
+    scale_val = graph.extract_scalar<float>(scale_ref);
+  }
+
+  // Resolve attn_mask: a None value is normalized to kDummyValueRef so the
+  // unified helpers can branch with a single `is_valid()` check.
+  const ValueRef attn_mask_ref =
+      graph.val_is_none(attn_mask) ? kDummyValueRef : attn_mask;
+
+  // Get dimensions for intermediate allocation
+  const int64_t B = graph.size_at<int64_t>(-4, q);
+  const int64_t H = graph.size_at<int64_t>(-3, q);
+  const int64_t S = graph.size_at<int64_t>(-2, q);
+  const int64_t L = graph.size_at<int64_t>(-2, k);
+
+  std::vector<int64_t> attn_weight_sizes = {B, H, S, L};
+
+  // attn_weights and attn_weights_softmax follow the output's storage so the
+  // entire fused SDPA pipeline uses a uniform storage type. attn_weights stays
+  // in fp32 for numerical stability of the Q@K^T accumulation.
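+  // (Illustrative magnitude check: a 64-term dot product of values around 32
+  // sums to about 64 * 32 * 32 = 65536, already past the fp16 max of 65504,
+  // so raw attn_weights cannot safely be stored in fp16.)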
+  const utils::StorageType attn_storage = graph.storage_type_of(out);
+
+  TmpTensor attn_weights(
+      &graph,
+      attn_weight_sizes,
+      vkapi::ScalarType::Float,
+      attn_storage,
+      utils::kWidthPacked);
+
+  TmpTensor attn_weights_softmax(
+      &graph,
+      attn_weight_sizes,
+      graph.dtype_of(q),
+      attn_storage,
+      utils::kWidthPacked);
+
+  // Phase 1: Q @ K^T with fp32 accumulation, apply scale and optional bias
+  add_sdpa_compute_attn_weights_node(
+      graph,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      attn_mask_ref,
+      scale_val,
+      attn_weights,
+      SDPAMode::FUSED);
+
+  // Phase 2: Softmax in fp32, output in input dtype
+  add_sdpa_attn_weights_softmax_node(
+      graph,
+      attn_weights,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      attn_weights_softmax,
+      SDPAMode::FUSED);
+
+  // Phase 3: attn_weights_softmax @ V
+  add_sdpa_compute_out_node(
+      graph,
+      attn_weights_softmax,
+      v,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      out,
+      SDPAMode::FUSED);
+}
 
 REGISTER_OPERATORS {
@@ -613,6 +867,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       testing.compute_attn_weight_with_kv_cache.default,
       compute_attn_weight_with_kv_cache_impl);
+  VK_REGISTER_OP(et_vk.sdpa.default, fused_sdpa_impl);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp
index d881ce7a7f4..fd2afc7408e 100644
--- a/backends/vulkan/test/op_tests/sdpa_test.cpp
+++ b/backends/vulkan/test/op_tests/sdpa_test.cpp
@@ -616,6 +616,242 @@ void test_vulkan_sdpa(
   }
 }
 
+//
+// General-purpose fused SDPA tests (et_vk.sdpa)
+//
+
+/*
+ * Reference implementation of general SDPA: softmax(Q @ K^T * scale + bias) @ V
+ * Q: [B, H, S, D], K: [B, H, L, D], V: [B, H, L, D]
+ * Returns: [B, H, S, D]
+ */
+at::Tensor general_sdpa_reference_impl(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const std::optional<at::Tensor>& attn_mask = std::nullopt,
+    const std::optional<float> scale = std::nullopt) {
+  float scale_val =
+      scale.has_value() ? scale.value() : (1.0 / sqrt(q.size(-1)));
+  at::Tensor attn = at::matmul(q, k.transpose(-2, -1)) * scale_val;
+  if (attn_mask.has_value()) {
+    attn = attn + attn_mask.value();
+  }
+  attn = at::softmax(attn, -1);
+  return at::matmul(attn, v);
+}
+
+void test_vulkan_general_sdpa(
+    const int batch_size,
+    const int num_heads,
+    const int q_seq_len,
+    const int kv_seq_len,
+    const int head_dim,
+    const bool has_bias,
+    at::ScalarType dtype = at::kFloat) {
+  torch::manual_seed(42);
+
+  // Generate random inputs in [B, H, S, D] layout
+  at::Tensor q = at::rand(
+      {batch_size, num_heads, q_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor k = at::rand(
+      {batch_size, num_heads, kv_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor v = at::rand(
+      {batch_size, num_heads, kv_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+
+  std::optional<at::Tensor> bias = std::nullopt;
+  if (has_bias) {
+    // Broadcastable bias: [B, 1, 1, kv_seq_len]
+    bias = at::rand(
+               {batch_size, 1, 1, kv_seq_len},
+               at::device(at::kCPU).dtype(at::kFloat)) *
+            2.0 -
+        1.0;
+  }
+
+  // Compute reference output in fp32
+  at::Tensor reference_out = general_sdpa_reference_impl(q, k, v, bias);
+
+  // Cast to test dtype for Vulkan
+  q = q.to(dtype);
+  k = k.to(dtype);
+  v = v.to(dtype);
+  if (bias.has_value()) {
+    bias = bias.value().to(dtype);
+  }
+
+  // Build Vulkan compute graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  ComputeGraph graph(config);
+
+  IOValueRef r_q = graph.add_input_tensor(
+      q.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+  IOValueRef r_k = graph.add_input_tensor(
+      k.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+  IOValueRef r_v = graph.add_input_tensor(
+      v.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+
+  ValueRef r_bias = kDummyValueRef;
+  IOValueRef r_bias_io = {};
+  if (has_bias) {
+    r_bias_io = graph.add_input_tensor(
+        bias.value().sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+    r_bias = r_bias_io.value;
+  }
+
+  const ValueRef r_out = graph.add_tensor(
+      {batch_size, num_heads, q_seq_len, head_dim},
+      from_at_scalartype(dtype),
+      utils::kBuffer);
+
+  VK_GET_OP_FN("et_vk.sdpa.default")
+  (graph,
+   {
+       r_q.value,
+       r_k.value,
+       r_v.value,
+       r_bias,
+       kDummyValueRef, // scale (None -> 1/sqrt(head_dim))
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.prepack();
+
+  // Copy inputs
+  graph.maybe_cast_and_copy_into_staging(
+      r_q.staging, q.const_data_ptr(), q.numel(), from_at_scalartype(dtype));
+  graph.maybe_cast_and_copy_into_staging(
+      r_k.staging, k.const_data_ptr(), k.numel(), from_at_scalartype(dtype));
+  graph.maybe_cast_and_copy_into_staging(
+      r_v.staging, v.const_data_ptr(), v.numel(), from_at_scalartype(dtype));
+  if (has_bias) {
+    graph.maybe_cast_and_copy_into_staging(
+        r_bias_io.staging,
+        bias.value().const_data_ptr(),
+        bias.value().numel(),
+        from_at_scalartype(dtype));
+  }
+
+  graph.execute();
+
+  // Extract output
+  at::Tensor vk_out = at::zeros(
+                          {batch_size, num_heads, q_seq_len, head_dim},
+                          at::device(at::kCPU).dtype(dtype))
+                          .contiguous();
+  graph.maybe_cast_and_copy_from_staging(
+      staging_out,
+      vk_out.mutable_data_ptr(),
+      vk_out.numel(),
+      from_at_scalartype(dtype));
+
+  // Compare in fp32
+  vk_out = vk_out.to(at::kFloat);
+
+  // Use appropriate tolerance based on dtype
+  double atol = dtype == at::kHalf ? 1e-2 : 1e-4;
+  double rtol = dtype == at::kHalf ? 1e-2 : 1e-5;
+
+  const bool output_correct = at::allclose(reference_out, vk_out, rtol, atol);
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_out - vk_out);
+    std::cout << "General SDPA test failed:" << " B=" << batch_size
+              << " H=" << num_heads << " S=" << q_seq_len << " L=" << kv_seq_len
+              << " D=" << head_dim << " bias=" << has_bias << " dtype=" << dtype
+              << std::endl;
+    std::cout << "Max diff: " << at::max(diffs).item<float>() << std::endl;
+    std::cout << "Max value: "
+              << at::max(at::abs(at::cat({reference_out, vk_out}, -1)))
+                     .item<float>()
+              << std::endl;
+
+    // Print all elements for small tensors
+    if (reference_out.numel() <= 64) {
+      auto ref_flat = reference_out.flatten();
+      auto vk_flat = vk_out.flatten();
+      std::cout << "Reference vs Vulkan:" << std::endl;
+      for (int i = 0; i < ref_flat.numel(); ++i) {
+        std::cout << "  [" << i << "] ref=" << ref_flat[i].item<float>()
+                  << " vk=" << vk_flat[i].item<float>() << " diff="
+                  << std::abs(
+                         ref_flat[i].item<float>() - vk_flat[i].item<float>())
+                  << std::endl;
+      }
+    }
+  }
+  ASSERT_TRUE(output_correct);
+}
+
+// Basic correctness: small sizes, no bias, fp32
+TEST(VulkanGeneralSDPATest, test_general_sdpa_small_no_bias) {
+  test_vulkan_general_sdpa(1, 2, 4, 4, 8, false);
+}
+
+// With additive bias mask
+TEST(VulkanGeneralSDPATest, test_general_sdpa_small_with_bias) {
+  test_vulkan_general_sdpa(1, 2, 4, 8, 8, true);
+}
+
+// Cross-attention: Q and K have different sequence lengths
+TEST(VulkanGeneralSDPATest, test_general_sdpa_cross_attention) {
+  test_vulkan_general_sdpa(1, 4, 4, 16, 16, false);
+}
+
+// Batch size > 1
+TEST(VulkanGeneralSDPATest, test_general_sdpa_batched) {
+  test_vulkan_general_sdpa(2, 4, 8, 8, 16, false);
+}
+
+// Larger head_dim with bias (EdgeTAM-like)
+TEST(VulkanGeneralSDPATest, test_general_sdpa_large_head_dim) {
+  test_vulkan_general_sdpa(1, 8, 4, 4, 32, true);
+}
+
+// Non-aligned S (S is height dim, not width — no padding issue)
+TEST(VulkanGeneralSDPATest, test_general_sdpa_non_aligned_s) {
+  test_vulkan_general_sdpa(1, 2, 5, 4, 32, false);
+}
+
+// Large number of heads
+TEST(VulkanGeneralSDPATest, test_general_sdpa_many_heads) {
+  test_vulkan_general_sdpa(1, 8, 4, 8, 32, false);
+}
+
+// fp16 — validates fp32 internal accumulation
+TEST(VulkanGeneralSDPATest, test_general_sdpa_fp16) {
+  test_vulkan_general_sdpa(
+      /*batch_size=*/1,
+      /*num_heads=*/4,
+      /*q_seq_len=*/8,
+      /*kv_seq_len=*/8,
+      /*head_dim=*/16,
+      /*has_bias=*/false,
+      /*dtype=*/at::kHalf);
+}
+
+// fp16 with bias
+TEST(VulkanGeneralSDPATest, test_general_sdpa_fp16_with_bias) {
+  test_vulkan_general_sdpa(
+      /*batch_size=*/1,
+      /*num_heads=*/4,
+      /*q_seq_len=*/8,
+      /*kv_seq_len=*/16,
+      /*head_dim=*/16,
+      /*has_bias=*/true,
+      /*dtype=*/at::kHalf);
+}
+
+//
+// Existing KV-cache SDPA tests
+//
+
 TEST(VulkanSDPATest, test_sdpa_op_small_params) {
   const int base_sequence_len = 3;
   const int num_heads = 8;