diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b47a8f383a0..7672a2d891c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op(): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, supports_prepacking=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index 67d9c100f68..652453bbec7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -76,6 +76,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. @@ -96,7 +97,7 @@ void main() { // number of threads in the work group. for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); for (int comp = 0; comp < 4; comp++) { local_exp_sum += exp(in_texel[comp]); @@ -108,7 +109,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { @@ -138,11 +139,11 @@ void main() { // Now go back through each element in the row and normalize for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); VEC4_T out_texel = exp(in_texel) / local_exp_sum; store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } // First thread in the work group responsible for handling last texel if it // contains any padded elements @@ -150,7 +151,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); // Ensure that padding elements are set to 0. VEC4_T out_texel = VEC4_T(0); @@ -160,7 +161,7 @@ void main() { } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index 2900d63666b..7dec6c1697f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -118,55 +119,27 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } } @@ -205,7 +178,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml index 6a4cffcc913..d5cadc36060 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index 95c22d91b80..2892f74e05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -93,6 +93,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -129,55 +130,28 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } // Apply scale and mask @@ -196,6 +170,6 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 6aadbbc379e..7fc016cf3c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 5f408b7581d..cc60193cf18 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -120,7 +121,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -146,7 +147,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index ccebf8f7c1c..33ec2f8b322 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_out_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_coop_texture3d_texture3d - - NAME: sdpa_compute_out_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 0063ebf9d38..385ad7a921e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -75,6 +75,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -113,7 +114,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -136,7 +137,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index 7fbce29e908..eac2c6f37dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_out_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_tiled_texture3d_texture3d - - NAME: sdpa_compute_out_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index 932696fff02..028e02d1a20 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -5,6 +5,8 @@ #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} #define T ${buffer_scalar_type(DTYPE)} +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER $if INPUT_STORAGE == "buffer": #define INPUT_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 85f4ce090f8..5ec2f3e190c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -10,10 +10,14 @@ sdpa_kv_cache_update: INPUT_STORAGE: texture3d OUTPUT_STORAGE: texture3d generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: half - VALUE: float shader_variants: - - NAME: sdpa_kv_cache_update_texture3d - - NAME: sdpa_kv_cache_update_buffer - INPUT_STORAGE: buffer + - NAME: sdpa_kv_cache_update diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 8edaebd11ff..f514530f175 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -50,7 +50,7 @@ void resize_compute_attn_weights_node( std::vector out_sizes = { 1, // batch num_q_heads, - seq_len, + utils::align_up_4(seq_len), utils::align_up_4(context_len)}; graph->virtual_resize(attn_weights, out_sizes); @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node( const ValueRef projected, const ValueRef cache) { std::string kernel_name("sdpa_kv_cache_update"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(cache)); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); @@ -525,10 +526,11 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = prepack_standard( - graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); - const ValueRef v_cache = prepack_standard( - graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -546,10 +548,51 @@ void sdpa_with_kv_cache_impl( out}); } +void compute_attn_weight_with_kv_cache_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef q_projected = args[arg_idx++]; + const ValueRef k_projected = args[arg_idx++]; + const ValueRef v_projected = args[arg_idx++]; + const ValueRef k_cache_data = args[arg_idx++]; + const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef sequence_len = args[arg_idx++]; + const ValueRef attn_mask = args[arg_idx++]; + (void)attn_mask; + const ValueRef dropout_p = args[arg_idx++]; + (void)dropout_p; + const ValueRef is_causal = args[arg_idx++]; + (void)is_causal; + const ValueRef scale = args[arg_idx++]; + (void)scale; + + // Output tensors + const ValueRef out = args[arg_idx++]; + + (void)sequence_len; + + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + + update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); + update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, out); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); + VK_REGISTER_OP( + testing.compute_attn_weight_with_kv_cache.default, + compute_attn_weight_with_kv_cache_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index a94e68a53af..c3347b339a7 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -23,6 +23,24 @@ #include #include +// +// SDPA Mode Enum +// + +enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY }; + +std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) { + switch (mode) { + case SDPAMode::DECOMPOSED: + return os << "DECOMPOSED"; + case SDPAMode::FUSED: + return os << "FUSED"; + case SDPAMode::ATTN_WEIGHT_ONLY: + return os << "ATTN_WEIGHT_ONLY"; + } + return os; +} + namespace torch { namespace executor { namespace native { @@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten( const int64_t seq_len, // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional attn_mask, + const std::optional& attn_mask, const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy @@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl( at::Tensor& value_cache, const int64_t start_pos, const int64_t seq_len, - const std::optional __attn_mask_ignored, + const std::optional& __attn_mask_ignored, const double dropout_p, const bool is_causal, - const std::optional scale) { + const std::optional scale, + SDPAMode mode = SDPAMode::DECOMPOSED) { at::Tensor attn_mask = construct_attention_mask(q_projected, key_cache, start_pos); @@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl( float scale_factor = 1.0 / sqrt(q_transposed.size(-1)); at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + return attn_weight; + } + at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1); at::Tensor out = at::matmul(attn_weight_softmax, v_transposed); @@ -268,7 +291,8 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { + at::ScalarType dtype = at::kFloat, + SDPAMode mode = SDPAMode::DECOMPOSED) { // compute the max sequence length int max_seq_len = start_input_pos; for (int i = 0; i < sequence_lens.size(); ++i) { @@ -296,6 +320,9 @@ void test_vulkan_sdpa( // Get reference output at::Tensor out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len}); + } // Build Vulkan SDPA graph using namespace vkcompute; @@ -330,22 +357,87 @@ void test_vulkan_sdpa( const ValueRef r_out = graph.add_tensor( out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); - VK_GET_OP_FN("sdpa_with_kv_cache.default") - (graph, - { - r_q.value, - r_k.value, - r_v.value, - r_k_cache_data, - r_v_cache_data, - r_input_pos_symint, - kDummyValueRef, // sequence_len - kDummyValueRef, // attn_mask - kDummyValueRef, // dropout_p - kDummyValueRef, // is_causal - kDummyValueRef, // scale - r_out, - }); + switch (mode) { + case SDPAMode::DECOMPOSED: { + const ValueRef r_k_cache = graph.add_tensor( + k_cache_data.sizes().vec(), + from_at_scalartype(k_cache_data.scalar_type()), + storage_type); + const ValueRef r_v_cache = graph.add_tensor( + v_cache_data.sizes().vec(), + from_at_scalartype(v_cache_data.scalar_type()), + storage_type); + const ValueRef r_dummy_out = graph.add_tensor( + {1}, from_at_scalartype(out.scalar_type()), utils::kBuffer); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_k.value, + r_k_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_v.value, + r_v_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("llama.custom_sdpa.default") + (graph, + { + r_q.value, + r_k_cache, + r_v_cache, + r_input_pos_symint, + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + } break; + case SDPAMode::FUSED: + VK_GET_OP_FN("sdpa_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + case SDPAMode::ATTN_WEIGHT_ONLY: + VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + default: + VK_THROW("Unsupported SDPA mode"); + } ValueRef staging_out = graph.set_output_tensor(r_out); @@ -378,7 +470,7 @@ void test_vulkan_sdpa( v = at::rand_like(k); at::Tensor reference_out = sdpa_reference_impl( - q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}); + q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode); graph.set_symint(r_input_pos_symint, input_pos); graph.resize_input(0, q.sizes().vec()); @@ -393,15 +485,38 @@ void test_vulkan_sdpa( graph.execute(); - out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + const int context_len = input_pos + seq_len; + const int context_len_align_up4 = (context_len + 3) & ~3; + const int seq_len_align_up4 = (seq_len + 3) & ~3; + + out = at::empty( + {batch_size, num_heads, seq_len_align_up4, context_len_align_up4}, + q.options()); + } else { + out = at::empty_like(q); + } EXTRACT_TENSOR(out); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + // Index vk_out to only include the relevant seq_len and context_len + // dimensions + int context_len = input_pos + seq_len; + vk_out = vk_out.index( + {at::indexing::Slice(), + at::indexing::Slice(), + at::indexing::Slice(0, seq_len), + at::indexing::Slice(0, context_len)}); + } + const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { // Print only differing tensor elements side by side for easier comparison auto ref_flat = reference_out.flatten(); auto vk_flat = vk_out.flatten(); auto numel = ref_flat.numel(); + std::cout << "While testing " << mode << " mode with " << storage_type + << " storage" << std::endl; std::cout << "reference_out\tvk_out\tindex" << std::endl; int first_diff_idx = -1; auto sizes = reference_out.sizes(); @@ -466,27 +581,32 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, at::ScalarType dtype = at::kFloat) { - // Test texture - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kTexture3D, - dtype); - - // Test buffer - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kBuffer, - dtype); + for (SDPAMode mode : + {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) { + // Test texture + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kTexture3D, + dtype, + mode); + + // Test buffer + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kBuffer, + dtype, + mode); + } } TEST(VulkanSDPATest, test_sdpa_op_small_params) {