Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op():
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_with_kv_cache_op():
return OpFeatures(
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
inputs_storage=utils.CONTIGUOUS_ANY,
supports_resize=True,
supports_prepacking=True,
)
Expand Down
50 changes: 46 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,11 @@ void sdpa_with_kv_cache_impl(

(void)sequence_len;

const ValueRef k_cache = prepack_standard(
graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked);
const ValueRef v_cache = prepack_standard(
graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked);
utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const ValueRef k_cache =
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
const ValueRef v_cache =
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
Expand All @@ -547,10 +548,51 @@ void sdpa_with_kv_cache_impl(
out});
}

void compute_attn_weight_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
(void)attn_mask;
const ValueRef dropout_p = args[arg_idx++];
(void)dropout_p;
const ValueRef is_causal = args[arg_idx++];
(void)is_causal;
const ValueRef scale = args[arg_idx++];
(void)scale;

// Output tensors
const ValueRef out = args[arg_idx++];

(void)sequence_len;

utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const ValueRef k_cache =
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
const ValueRef v_cache =
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

add_sdpa_compute_attn_weights_node(
graph, q_projected, k_cache, input_pos_symint, out);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
VK_REGISTER_OP(update_cache.default, update_cache_impl);
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
VK_REGISTER_OP(
testing.compute_attn_weight_with_kv_cache.default,
compute_attn_weight_with_kv_cache_impl);
}

} // namespace vkcompute
206 changes: 163 additions & 43 deletions backends/vulkan/test/op_tests/sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
#include <cassert>
#include <iostream>

//
// SDPA Mode Enum
//

enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY };

std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) {
switch (mode) {
case SDPAMode::DECOMPOSED:
return os << "DECOMPOSED";
case SDPAMode::FUSED:
return os << "FUSED";
case SDPAMode::ATTN_WEIGHT_ONLY:
return os << "ATTN_WEIGHT_ONLY";
}
return os;
}

namespace torch {
namespace executor {
namespace native {
Expand Down Expand Up @@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten(
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<at::Tensor> attn_mask,
const std::optional<at::Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
Expand Down Expand Up @@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl(
at::Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
const std::optional<at::Tensor> __attn_mask_ignored,
const std::optional<at::Tensor>& __attn_mask_ignored,
const double dropout_p,
const bool is_causal,
const std::optional<double> scale) {
const std::optional<double> scale,
SDPAMode mode = SDPAMode::DECOMPOSED) {
at::Tensor attn_mask =
construct_attention_mask(q_projected, key_cache, start_pos);

Expand Down Expand Up @@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl(
float scale_factor = 1.0 / sqrt(q_transposed.size(-1));
at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;

if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
return attn_weight;
}

at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);

Expand Down Expand Up @@ -268,7 +291,8 @@ void test_vulkan_sdpa(
const int num_kv_heads,
const int batch_size,
vkcompute::utils::StorageType storage_type,
at::ScalarType dtype = at::kFloat) {
at::ScalarType dtype = at::kFloat,
SDPAMode mode = SDPAMode::DECOMPOSED) {
// compute the max sequence length
int max_seq_len = start_input_pos;
for (int i = 0; i < sequence_lens.size(); ++i) {
Expand Down Expand Up @@ -296,6 +320,9 @@ void test_vulkan_sdpa(

// Get reference output
at::Tensor out = at::empty_like(q);
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len});
}

// Build Vulkan SDPA graph
using namespace vkcompute;
Expand Down Expand Up @@ -330,22 +357,87 @@ void test_vulkan_sdpa(
const ValueRef r_out = graph.add_tensor(
out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type);

VK_GET_OP_FN("sdpa_with_kv_cache.default")
(graph,
{
r_q.value,
r_k.value,
r_v.value,
r_k_cache_data,
r_v_cache_data,
r_input_pos_symint,
kDummyValueRef, // sequence_len
kDummyValueRef, // attn_mask
kDummyValueRef, // dropout_p
kDummyValueRef, // is_causal
kDummyValueRef, // scale
r_out,
});
switch (mode) {
case SDPAMode::DECOMPOSED: {
const ValueRef r_k_cache = graph.add_tensor(
k_cache_data.sizes().vec(),
from_at_scalartype(k_cache_data.scalar_type()),
storage_type);
const ValueRef r_v_cache = graph.add_tensor(
v_cache_data.sizes().vec(),
from_at_scalartype(v_cache_data.scalar_type()),
storage_type);
const ValueRef r_dummy_out = graph.add_tensor(
{1}, from_at_scalartype(out.scalar_type()), utils::kBuffer);
VK_GET_OP_FN("update_cache.default")
(graph,
{
r_k.value,
r_k_cache,
r_input_pos_symint,
r_dummy_out,
});
VK_GET_OP_FN("update_cache.default")
(graph,
{
r_v.value,
r_v_cache,
r_input_pos_symint,
r_dummy_out,
});
VK_GET_OP_FN("llama.custom_sdpa.default")
(graph,
{
r_q.value,
r_k_cache,
r_v_cache,
r_input_pos_symint,
kDummyValueRef, // attn_mask
kDummyValueRef, // dropout_p
kDummyValueRef, // is_causal
kDummyValueRef, // scale
r_out,
});
} break;
case SDPAMode::FUSED:
VK_GET_OP_FN("sdpa_with_kv_cache.default")
(graph,
{
r_q.value,
r_k.value,
r_v.value,
r_k_cache_data,
r_v_cache_data,
r_input_pos_symint,
kDummyValueRef, // sequence_len
kDummyValueRef, // attn_mask
kDummyValueRef, // dropout_p
kDummyValueRef, // is_causal
kDummyValueRef, // scale
r_out,
});
break;
case SDPAMode::ATTN_WEIGHT_ONLY:
VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default")
(graph,
{
r_q.value,
r_k.value,
r_v.value,
r_k_cache_data,
r_v_cache_data,
r_input_pos_symint,
kDummyValueRef, // sequence_len
kDummyValueRef, // attn_mask
kDummyValueRef, // dropout_p
kDummyValueRef, // is_causal
kDummyValueRef, // scale
r_out,
});
break;
default:
VK_THROW("Unsupported SDPA mode");
}

ValueRef staging_out = graph.set_output_tensor(r_out);

Expand Down Expand Up @@ -378,7 +470,7 @@ void test_vulkan_sdpa(
v = at::rand_like(k);

at::Tensor reference_out = sdpa_reference_impl(
q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {});
q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode);

graph.set_symint(r_input_pos_symint, input_pos);
graph.resize_input(0, q.sizes().vec());
Expand All @@ -393,15 +485,38 @@ void test_vulkan_sdpa(

graph.execute();

out = at::empty_like(q);
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
const int context_len = input_pos + seq_len;
const int context_len_align_up4 = (context_len + 3) & ~3;
const int seq_len_align_up4 = (seq_len + 3) & ~3;

out = at::empty(
{batch_size, num_heads, seq_len_align_up4, context_len_align_up4},
q.options());
} else {
out = at::empty_like(q);
}
EXTRACT_TENSOR(out);

if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
// Index vk_out to only include the relevant seq_len and context_len
// dimensions
int context_len = input_pos + seq_len;
vk_out = vk_out.index(
{at::indexing::Slice(),
at::indexing::Slice(),
at::indexing::Slice(0, seq_len),
at::indexing::Slice(0, context_len)});
}

const bool output_correct = at::allclose(reference_out, vk_out);
if (!output_correct) {
// Print only differing tensor elements side by side for easier comparison
auto ref_flat = reference_out.flatten();
auto vk_flat = vk_out.flatten();
auto numel = ref_flat.numel();
std::cout << "While testing " << mode << " mode with " << storage_type
<< " storage" << std::endl;
std::cout << "reference_out\tvk_out\tindex" << std::endl;
int first_diff_idx = -1;
auto sizes = reference_out.sizes();
Expand Down Expand Up @@ -466,27 +581,32 @@ void test_vulkan_sdpa(
const int num_kv_heads,
const int batch_size,
at::ScalarType dtype = at::kFloat) {
// Test texture
test_vulkan_sdpa(
start_input_pos,
sequence_lens,
head_dim,
num_heads,
num_kv_heads,
batch_size,
vkcompute::utils::kTexture3D,
dtype);

// Test buffer
test_vulkan_sdpa(
start_input_pos,
sequence_lens,
head_dim,
num_heads,
num_kv_heads,
batch_size,
vkcompute::utils::kBuffer,
dtype);
for (SDPAMode mode :
{SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) {
// Test texture
test_vulkan_sdpa(
start_input_pos,
sequence_lens,
head_dim,
num_heads,
num_kv_heads,
batch_size,
vkcompute::utils::kTexture3D,
dtype,
mode);

// Test buffer
test_vulkan_sdpa(
start_input_pos,
sequence_lens,
head_dim,
num_heads,
num_kv_heads,
batch_size,
vkcompute::utils::kBuffer,
dtype,
mode);
}
}

TEST(VulkanSDPATest, test_sdpa_op_small_params) {
Expand Down
Loading