From 8080cab171664f8413d2374f02474700dae05580 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 4 Nov 2025 11:52:39 -0800 Subject: [PATCH 1/4] [ET-VK][ez] Refactor yaml configs for SDPA shaders Title says it all! Use the new combos codegen API which makes it easier to express generating storage type combinations. Differential Revision: [D86226138](https://our.internmc.facebook.com/intern/diff/D86226138/) ghstack-source-id: 320850476 Pull Request resolved: https://github.com/pytorch/executorch/pull/15576 --- .../graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml | 10 +++++++--- .../ops/glsl/sdpa_compute_attn_weights_tiled.yaml | 10 +++++++--- .../runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml | 10 +++++++--- .../runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml | 10 +++++++--- .../runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl | 2 ++ .../runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml | 10 +++++++--- backends/vulkan/runtime/graph/ops/impl/SDPA.cpp | 1 + 7 files changed, 38 insertions(+), 15 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml index 6a4cffcc913..d5cadc36060 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_coop diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 6aadbbc379e..7fc016cf3c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index ccebf8f7c1c..33ec2f8b322 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_out_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_coop_texture3d_texture3d - - NAME: sdpa_compute_out_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index 7fbce29e908..eac2c6f37dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ 
b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_out_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_tiled_texture3d_texture3d - - NAME: sdpa_compute_out_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index 932696fff02..028e02d1a20 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -5,6 +5,8 @@ #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} #define T ${buffer_scalar_type(DTYPE)} +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER $if INPUT_STORAGE == "buffer": #define INPUT_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 85f4ce090f8..5ec2f3e190c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -10,10 +10,14 @@ sdpa_kv_cache_update: INPUT_STORAGE: texture3d OUTPUT_STORAGE: texture3d generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: half - VALUE: float shader_variants: - - NAME: sdpa_kv_cache_update_texture3d - - NAME: sdpa_kv_cache_update_buffer - INPUT_STORAGE: buffer + - NAME: sdpa_kv_cache_update diff --git 
a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 8edaebd11ff..92b14c3b724 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node( const ValueRef projected, const ValueRef cache) { std::string kernel_name("sdpa_kv_cache_update"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(cache)); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); From cf49a75b3bb0aadc056299a38bfde8ec41225f47 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 4 Nov 2025 11:52:43 -0800 Subject: [PATCH 2/4] [ET-VK][ez] Update SDPA test to be able to test different SDPA modes Title says it all! The purpose of this diff is twofold: 1. Test SDPA as both a fused operator (sdpa_with_kv_cache) and as decomposed update_cache and custom_sdpa ops, in order to detect possible regressions in support for older models. 2. Make it easier to debug issues with SDPA by exposing a mode that tests only the attention weight computation. Additionally, update the SDPA op to use buffer storage for cache tensors if projected tensors are buffer. Also included is a small change to ensure that cache tensors use the same storage type as input tensors. 
Differential Revision: [D86226135](https://our.internmc.facebook.com/intern/diff/D86226135/) ghstack-source-id: 320850473 Pull Request resolved: https://github.com/pytorch/executorch/pull/15577 --- backends/vulkan/op_registry.py | 2 +- .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 50 ++++- backends/vulkan/test/op_tests/sdpa_test.cpp | 206 ++++++++++++++---- 3 files changed, 210 insertions(+), 48 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b47a8f383a0..7672a2d891c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op(): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, supports_prepacking=True, ) diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 92b14c3b724..6b4da5d95f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -526,10 +526,11 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = prepack_standard( - graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); - const ValueRef v_cache = prepack_standard( - graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -547,10 +548,51 @@ void sdpa_with_kv_cache_impl( out}); } +void compute_attn_weight_with_kv_cache_impl( + ComputeGraph& graph, 
+ const std::vector& args) { + int arg_idx = 0; + const ValueRef q_projected = args[arg_idx++]; + const ValueRef k_projected = args[arg_idx++]; + const ValueRef v_projected = args[arg_idx++]; + const ValueRef k_cache_data = args[arg_idx++]; + const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef sequence_len = args[arg_idx++]; + const ValueRef attn_mask = args[arg_idx++]; + (void)attn_mask; + const ValueRef dropout_p = args[arg_idx++]; + (void)dropout_p; + const ValueRef is_causal = args[arg_idx++]; + (void)is_causal; + const ValueRef scale = args[arg_idx++]; + (void)scale; + + // Output tensors + const ValueRef out = args[arg_idx++]; + + (void)sequence_len; + + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + + update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); + update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, out); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); + VK_REGISTER_OP( + testing.compute_attn_weight_with_kv_cache.default, + compute_attn_weight_with_kv_cache_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index a94e68a53af..c3347b339a7 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -23,6 +23,24 @@ #include #include +// +// SDPA Mode Enum +// + +enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY }; + +std::ostream& 
operator<<(std::ostream& os, const SDPAMode& mode) { + switch (mode) { + case SDPAMode::DECOMPOSED: + return os << "DECOMPOSED"; + case SDPAMode::FUSED: + return os << "FUSED"; + case SDPAMode::ATTN_WEIGHT_ONLY: + return os << "ATTN_WEIGHT_ONLY"; + } + return os; +} + namespace torch { namespace executor { namespace native { @@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten( const int64_t seq_len, // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional attn_mask, + const std::optional& attn_mask, const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy @@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl( at::Tensor& value_cache, const int64_t start_pos, const int64_t seq_len, - const std::optional __attn_mask_ignored, + const std::optional& __attn_mask_ignored, const double dropout_p, const bool is_causal, - const std::optional scale) { + const std::optional scale, + SDPAMode mode = SDPAMode::DECOMPOSED) { at::Tensor attn_mask = construct_attention_mask(q_projected, key_cache, start_pos); @@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl( float scale_factor = 1.0 / sqrt(q_transposed.size(-1)); at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + return attn_weight; + } + at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1); at::Tensor out = at::matmul(attn_weight_softmax, v_transposed); @@ -268,7 +291,8 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { + at::ScalarType dtype = at::kFloat, + SDPAMode mode = SDPAMode::DECOMPOSED) { // compute the max sequence length int max_seq_len = start_input_pos; for (int i = 0; i < sequence_lens.size(); ++i) { @@ -296,6 +320,9 @@ void test_vulkan_sdpa( // Get reference output at::Tensor out = 
at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len}); + } // Build Vulkan SDPA graph using namespace vkcompute; @@ -330,22 +357,87 @@ void test_vulkan_sdpa( const ValueRef r_out = graph.add_tensor( out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); - VK_GET_OP_FN("sdpa_with_kv_cache.default") - (graph, - { - r_q.value, - r_k.value, - r_v.value, - r_k_cache_data, - r_v_cache_data, - r_input_pos_symint, - kDummyValueRef, // sequence_len - kDummyValueRef, // attn_mask - kDummyValueRef, // dropout_p - kDummyValueRef, // is_causal - kDummyValueRef, // scale - r_out, - }); + switch (mode) { + case SDPAMode::DECOMPOSED: { + const ValueRef r_k_cache = graph.add_tensor( + k_cache_data.sizes().vec(), + from_at_scalartype(k_cache_data.scalar_type()), + storage_type); + const ValueRef r_v_cache = graph.add_tensor( + v_cache_data.sizes().vec(), + from_at_scalartype(v_cache_data.scalar_type()), + storage_type); + const ValueRef r_dummy_out = graph.add_tensor( + {1}, from_at_scalartype(out.scalar_type()), utils::kBuffer); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_k.value, + r_k_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_v.value, + r_v_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("llama.custom_sdpa.default") + (graph, + { + r_q.value, + r_k_cache, + r_v_cache, + r_input_pos_symint, + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + } break; + case SDPAMode::FUSED: + VK_GET_OP_FN("sdpa_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + case 
SDPAMode::ATTN_WEIGHT_ONLY: + VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + default: + VK_THROW("Unsupported SDPA mode"); + } ValueRef staging_out = graph.set_output_tensor(r_out); @@ -378,7 +470,7 @@ void test_vulkan_sdpa( v = at::rand_like(k); at::Tensor reference_out = sdpa_reference_impl( - q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}); + q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode); graph.set_symint(r_input_pos_symint, input_pos); graph.resize_input(0, q.sizes().vec()); @@ -393,15 +485,38 @@ void test_vulkan_sdpa( graph.execute(); - out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + const int context_len = input_pos + seq_len; + const int context_len_align_up4 = (context_len + 3) & ~3; + const int seq_len_align_up4 = (seq_len + 3) & ~3; + + out = at::empty( + {batch_size, num_heads, seq_len_align_up4, context_len_align_up4}, + q.options()); + } else { + out = at::empty_like(q); + } EXTRACT_TENSOR(out); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + // Index vk_out to only include the relevant seq_len and context_len + // dimensions + int context_len = input_pos + seq_len; + vk_out = vk_out.index( + {at::indexing::Slice(), + at::indexing::Slice(), + at::indexing::Slice(0, seq_len), + at::indexing::Slice(0, context_len)}); + } + const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { // Print only differing tensor elements side by side for easier comparison auto ref_flat = reference_out.flatten(); auto vk_flat = vk_out.flatten(); auto numel = ref_flat.numel(); + std::cout << "While testing " << mode << " mode with " << storage_type + << " storage" << std::endl; 
std::cout << "reference_out\tvk_out\tindex" << std::endl; int first_diff_idx = -1; auto sizes = reference_out.sizes(); @@ -466,27 +581,32 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, at::ScalarType dtype = at::kFloat) { - // Test texture - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kTexture3D, - dtype); - - // Test buffer - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kBuffer, - dtype); + for (SDPAMode mode : + {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) { + // Test texture + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kTexture3D, + dtype, + mode); + + // Test buffer + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kBuffer, + dtype, + mode); + } } TEST(VulkanSDPATest, test_sdpa_op_small_params) { From accce77cb0e9dd52c5bdc14546ddee581544cef2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 4 Nov 2025 11:52:46 -0800 Subject: [PATCH 3/4] [ET-VK][ez] Align SDPA attention weights S dim to the next multiple of 4 Title says it all! Why? * Technically, this should not be needed, but the SDPA op was producing incorrect output on Samsung S24 with buffer input tensors. The exact root cause is unclear, but it appears to be an issue specific to the Adreno 750 since it does not reproduce on any other GPU. The best guess at the moment is that we need to ensure that there is no possibility of multiple threads writing to the same memory location. 
Differential Revision: [D86226134](https://our.internmc.facebook.com/intern/diff/D86226134/) ghstack-source-id: 320850475 Pull Request resolved: https://github.com/pytorch/executorch/pull/15578 --- .../graph/ops/glsl/sdpa_attn_weights_softmax.glsl | 13 +++++++------ .../ops/glsl/sdpa_compute_attn_weights_coop.glsl | 3 ++- .../ops/glsl/sdpa_compute_attn_weights_tiled.glsl | 3 ++- .../graph/ops/glsl/sdpa_compute_out_coop.glsl | 5 +++-- .../graph/ops/glsl/sdpa_compute_out_tiled.glsl | 5 +++-- backends/vulkan/runtime/graph/ops/impl/SDPA.cpp | 2 +- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index 67d9c100f68..652453bbec7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -76,6 +76,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. @@ -96,7 +97,7 @@ void main() { // number of threads in the work group. 
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); for (int comp = 0; comp < 4; comp++) { local_exp_sum += exp(in_texel[comp]); @@ -108,7 +109,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { @@ -138,11 +139,11 @@ void main() { // Now go back through each element in the row and normalize for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); VEC4_T out_texel = exp(in_texel) / local_exp_sum; store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } // First thread in the work group responsible for handling last texel if it // contains any padded elements @@ -150,7 +151,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); // Ensure that padding elements are set to 0. 
VEC4_T out_texel = VEC4_T(0); @@ -160,7 +161,7 @@ void main() { } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index 2900d63666b..a4bf588949b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -205,7 +206,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index 95c22d91b80..ef0c3c571c9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -93,6 +93,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -196,6 +197,6 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 5f408b7581d..cc60193cf18 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = 
q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -120,7 +121,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -146,7 +147,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 0063ebf9d38..385ad7a921e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -75,6 +75,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -113,7 +114,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -136,7 +137,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 6b4da5d95f1..f514530f175 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -50,7 +50,7 @@ void resize_compute_attn_weights_node( std::vector out_sizes = { 1, // batch num_q_heads, - seq_len, + utils::align_up_4(seq_len), utils::align_up_4(context_len)}; graph->virtual_resize(attn_weights, out_sizes); From 037e9af79c7218d67017af6058e4d47640698294 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 4 Nov 2025 11:52:50 -0800 Subject: [PATCH 4/4] [ET-VK][ez] SDPA don't branch based on whether bounds check needed Title says it all! Why? * The branching path is causing incorrect output on Samsung S24. 
It's unclear what the exact underlying issue is but the problem is not reproducible on other GPUs and appears to be an issue specific to Adreno 750 architecture. To be safe, always use bounds checking. Differential Revision: [D86226136](https://our.internmc.facebook.com/intern/diff/D86226136/) ghstack-source-id: 320850474 Pull Request resolved: https://github.com/pytorch/executorch/pull/15579 --- .../glsl/sdpa_compute_attn_weights_coop.glsl | 70 ++++++------------ .../glsl/sdpa_compute_attn_weights_tiled.glsl | 71 ++++++------------- 2 files changed, 43 insertions(+), 98 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index a4bf588949b..7dec6c1697f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -119,55 +119,27 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + 
c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index ef0c3c571c9..2892f74e05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -130,55 +130,28 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } // Apply scale and mask