diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b47a8f383a0..7672a2d891c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op(): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, supports_prepacking=True, ) diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 92b14c3b724..6b4da5d95f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -526,10 +526,11 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = prepack_standard( - graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); - const ValueRef v_cache = prepack_standard( - graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -547,10 +548,51 @@ void sdpa_with_kv_cache_impl( out}); } +void compute_attn_weight_with_kv_cache_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef q_projected = args[arg_idx++]; + const ValueRef k_projected = args[arg_idx++]; + const ValueRef v_projected = args[arg_idx++]; + const ValueRef k_cache_data = args[arg_idx++]; + const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef sequence_len = args[arg_idx++]; + const ValueRef attn_mask = args[arg_idx++]; + (void)attn_mask; + const ValueRef dropout_p = args[arg_idx++]; + (void)dropout_p; + const ValueRef is_causal = args[arg_idx++]; + (void)is_causal; + const ValueRef scale = args[arg_idx++]; + (void)scale; + + // Output tensors + const ValueRef out = args[arg_idx++]; + + (void)sequence_len; + + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + + update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); + update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, out); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); + VK_REGISTER_OP( + testing.compute_attn_weight_with_kv_cache.default, + compute_attn_weight_with_kv_cache_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index a94e68a53af..c3347b339a7 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -23,6 +23,24 @@ #include #include +// +// SDPA Mode Enum +// + +enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY }; + +std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) { + switch (mode) { + case SDPAMode::DECOMPOSED: + return os << "DECOMPOSED"; + case SDPAMode::FUSED: + return os << "FUSED"; + case SDPAMode::ATTN_WEIGHT_ONLY: + return os << "ATTN_WEIGHT_ONLY"; + } + return os; +} + namespace torch { namespace executor { namespace native { @@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten( const int64_t seq_len, // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional attn_mask, + const std::optional& attn_mask, const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy @@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl( at::Tensor& value_cache, const int64_t start_pos, const int64_t seq_len, - const std::optional __attn_mask_ignored, + const std::optional& __attn_mask_ignored, const double dropout_p, const bool is_causal, - const std::optional scale) { + const std::optional scale, + SDPAMode mode = SDPAMode::DECOMPOSED) { at::Tensor attn_mask = construct_attention_mask(q_projected, key_cache, start_pos); @@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl( float scale_factor = 1.0 / sqrt(q_transposed.size(-1)); at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + return attn_weight; + } + at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1); at::Tensor out = at::matmul(attn_weight_softmax, v_transposed); @@ -268,7 +291,8 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { + at::ScalarType dtype = at::kFloat, + SDPAMode mode = SDPAMode::DECOMPOSED) { // compute the max sequence length int max_seq_len = start_input_pos; for (int i = 0; i < sequence_lens.size(); ++i) { @@ -296,6 +320,9 @@ void test_vulkan_sdpa( // Get reference output at::Tensor out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len}); + } // Build Vulkan SDPA graph using namespace vkcompute; @@ -330,22 +357,87 @@ void test_vulkan_sdpa( const ValueRef r_out = graph.add_tensor( out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); - VK_GET_OP_FN("sdpa_with_kv_cache.default") - (graph, - { - r_q.value, - r_k.value, - r_v.value, - r_k_cache_data, - r_v_cache_data, - r_input_pos_symint, - kDummyValueRef, // sequence_len - kDummyValueRef, // attn_mask - kDummyValueRef, // dropout_p - kDummyValueRef, // is_causal - kDummyValueRef, // scale - r_out, - }); + switch (mode) { + case SDPAMode::DECOMPOSED: { + const ValueRef r_k_cache = graph.add_tensor( + k_cache_data.sizes().vec(), + from_at_scalartype(k_cache_data.scalar_type()), + storage_type); + const ValueRef r_v_cache = graph.add_tensor( + v_cache_data.sizes().vec(), + from_at_scalartype(v_cache_data.scalar_type()), + storage_type); + const ValueRef r_dummy_out = graph.add_tensor( + {1}, from_at_scalartype(out.scalar_type()), utils::kBuffer); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_k.value, + r_k_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_v.value, + r_v_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("llama.custom_sdpa.default") + (graph, + { + r_q.value, + r_k_cache, + r_v_cache, + r_input_pos_symint, + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + } break; + case SDPAMode::FUSED: + VK_GET_OP_FN("sdpa_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + case SDPAMode::ATTN_WEIGHT_ONLY: + VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + default: + VK_THROW("Unsupported SDPA mode"); + } ValueRef staging_out = graph.set_output_tensor(r_out); @@ -378,7 +470,7 @@ void test_vulkan_sdpa( v = at::rand_like(k); at::Tensor reference_out = sdpa_reference_impl( - q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}); + q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode); graph.set_symint(r_input_pos_symint, input_pos); graph.resize_input(0, q.sizes().vec()); @@ -393,15 +485,38 @@ void test_vulkan_sdpa( graph.execute(); - out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + const int context_len = input_pos + seq_len; + const int context_len_align_up4 = (context_len + 3) & ~3; + const int seq_len_align_up4 = (seq_len + 3) & ~3; + + out = at::empty( + {batch_size, num_heads, seq_len_align_up4, context_len_align_up4}, + q.options()); + } else { + out = at::empty_like(q); + } EXTRACT_TENSOR(out); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + // Index vk_out to only include the relevant seq_len and context_len + // dimensions + int context_len = input_pos + seq_len; + vk_out = vk_out.index( + {at::indexing::Slice(), + at::indexing::Slice(), + at::indexing::Slice(0, seq_len), + at::indexing::Slice(0, context_len)}); + } + const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { // Print only differing tensor elements side by side for easier comparison auto ref_flat = reference_out.flatten(); auto vk_flat = vk_out.flatten(); auto numel = ref_flat.numel(); + std::cout << "While testing " << mode << " mode with " << storage_type + << " storage" << std::endl; std::cout << "reference_out\tvk_out\tindex" << std::endl; int first_diff_idx = -1; auto sizes = reference_out.sizes(); @@ -466,27 +581,32 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, at::ScalarType dtype = at::kFloat) { - // Test texture - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kTexture3D, - dtype); - - // Test buffer - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kBuffer, - dtype); + for (SDPAMode mode : + {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) { + // Test texture + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kTexture3D, + dtype, + mode); + + // Test buffer + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kBuffer, + dtype, + mode); + } } TEST(VulkanSDPATest, test_sdpa_op_small_params) {