Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op():
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_with_kv_cache_op():
return OpFeatures(
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
inputs_storage=utils.CONTIGUOUS_ANY,
supports_resize=True,
supports_prepacking=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void main() {
const int Q_H = q_projected_sizes.y;
// sequence length
const int S = q_projected_sizes.z;
const int S_aligned = align_up_4(S);
// manually determine size of the context_len dim of the attention weight.
// The "actual" tensor sizes may have been aligned to a multiple of 4 to allow
// memory loads to be aligned to texel boundaries.
Expand All @@ -96,7 +97,7 @@ void main() {
// number of threads in the work group.
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
VEC4_T in_texel = load_attn_weights_c4(
c4, s, q_h, context_texel_len, S, Q_H);
c4, s, q_h, context_texel_len, S_aligned, Q_H);

for (int comp = 0; comp < 4; comp++) {
local_exp_sum += exp(in_texel[comp]);
Expand All @@ -108,7 +109,7 @@ void main() {
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
const int c_base = mul_4(c4);
VEC4_T in_texel = load_attn_weights_c4(
c4, s, q_h, context_texel_len, S, Q_H);
c4, s, q_h, context_texel_len, S_aligned, Q_H);

[[unroll]] for (int comp = 0; comp < 4; comp++) {
if (c_base + comp < context_len) {
Expand Down Expand Up @@ -138,19 +139,19 @@ void main() {
// Now go back through each element in the row and normalize
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
VEC4_T in_texel = load_attn_weights_c4(
c4, s, q_h, context_texel_len, S, Q_H);
c4, s, q_h, context_texel_len, S_aligned, Q_H);

VEC4_T out_texel = exp(in_texel) / local_exp_sum;
store_attn_weights_softmax_c4(
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
}
// First thread in the work group responsible for handling last texel if it
// contains any padded elements
if (worker_id == 0) {
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
const int c_base = mul_4(c4);
VEC4_T in_texel = load_attn_weights_c4(
c4, s, q_h, context_texel_len, S, Q_H);
c4, s, q_h, context_texel_len, S_aligned, Q_H);

// Ensure that padding elements are set to 0.
VEC4_T out_texel = VEC4_T(0);
Expand All @@ -160,7 +161,7 @@ void main() {
}
}
store_attn_weights_softmax_c4(
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void main() {
const int Q_H = q_projected_sizes.y;
// sequence length
const int S = q_projected_sizes.z;
const int S_aligned = align_up_4(S);

// number of K/V heads
const int KV_H = k_cache_sizes.y;
Expand Down Expand Up @@ -205,7 +206,7 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop:
TILE_K4: 1
TILE_N4: 1
generate_variant_forall:
combination:
parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
combos:
- parameter_values: [texture3d, texture3d]
- parameter_values: [buffer, texture3d]
- parameter_values: [buffer, buffer]
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d
- NAME: sdpa_compute_attn_weights_coop_buffer_texture3d
IO_STORAGE: buffer
- NAME: sdpa_compute_attn_weights_coop
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void main() {
const int Q_H = q_projected_sizes.y;
// sequence length
const int S = q_projected_sizes.z;
const int S_aligned = align_up_4(S);

// number of K/V heads
const int KV_H = k_cache_sizes.y;
Expand Down Expand Up @@ -196,6 +197,6 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled:
TILE_K4: 1
TILE_N4: 1
generate_variant_forall:
combination:
parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
combos:
- parameter_values: [texture3d, texture3d]
- parameter_values: [buffer, texture3d]
- parameter_values: [buffer, buffer]
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d
- NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d
IO_STORAGE: buffer
- NAME: sdpa_compute_attn_weights_tiled
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void main() {
const int Q_H = q_projected_sizes.y;
// sequence length
const int S = q_projected_sizes.z;
const int S_aligned = align_up_4(S);

// number of K/V heads
const int KV_H = v_cache_sizes.y;
Expand Down Expand Up @@ -120,7 +121,7 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);

load_v_cache_tile_no_checks(
Expand All @@ -146,7 +147,7 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);

load_v_cache_tile_with_checks(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ sdpa_compute_out_coop:
TILE_K4: 1
TILE_N4: 1
generate_variant_forall:
combination:
parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
combos:
- parameter_values: [texture3d, texture3d]
- parameter_values: [buffer, texture3d]
- parameter_values: [buffer, buffer]
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: sdpa_compute_out_coop_texture3d_texture3d
- NAME: sdpa_compute_out_coop_buffer_texture3d
IO_STORAGE: buffer
- NAME: sdpa_compute_out_coop
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ void main() {
const int Q_H = q_projected_sizes.y;
// sequence length
const int S = q_projected_sizes.z;
const int S_aligned = align_up_4(S);

// number of K/V heads
const int KV_H = v_cache_sizes.y;
Expand Down Expand Up @@ -113,7 +114,7 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);

load_v_cache_tile_no_checks(
Expand All @@ -136,7 +137,7 @@ void main() {
s,
q_h,
context_texel_len,
S,
S_aligned,
Q_H);

load_v_cache_tile_with_checks(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ sdpa_compute_out_tiled:
TILE_K4: 1
TILE_N4: 1
generate_variant_forall:
combination:
parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
combos:
- parameter_values: [texture3d, texture3d]
- parameter_values: [buffer, texture3d]
- parameter_values: [buffer, buffer]
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: sdpa_compute_out_tiled_texture3d_texture3d
- NAME: sdpa_compute_out_tiled_buffer_texture3d
IO_STORAGE: buffer
- NAME: sdpa_compute_out_tiled
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
#define T ${buffer_scalar_type(DTYPE)}

$if OUTPUT_STORAGE == "buffer":
#define OUTPUT_BUFFER
$if INPUT_STORAGE == "buffer":
#define INPUT_BUFFER

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ sdpa_kv_cache_update:
INPUT_STORAGE: texture3d
OUTPUT_STORAGE: texture3d
generate_variant_forall:
combination:
parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
combos:
- parameter_values: [texture3d, texture3d]
- parameter_values: [texture3d, buffer]
- parameter_values: [buffer, buffer]
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: sdpa_kv_cache_update_texture3d
- NAME: sdpa_kv_cache_update_buffer
INPUT_STORAGE: buffer
- NAME: sdpa_kv_cache_update
53 changes: 48 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void resize_compute_attn_weights_node(
std::vector<int64_t> out_sizes = {
1, // batch
num_q_heads,
seq_len,
utils::align_up_4(seq_len),
utils::align_up_4(context_len)};

graph->virtual_resize(attn_weights, out_sizes);
Expand Down Expand Up @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node(
const ValueRef projected,
const ValueRef cache) {
std::string kernel_name("sdpa_kv_cache_update");
add_storage_type_suffix(kernel_name, graph.storage_type_of(cache));
add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
add_dtype_suffix(kernel_name, graph.dtype_of(projected));

Expand Down Expand Up @@ -525,10 +526,11 @@ void sdpa_with_kv_cache_impl(

(void)sequence_len;

const ValueRef k_cache = prepack_standard(
graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked);
const ValueRef v_cache = prepack_standard(
graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked);
utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const ValueRef k_cache =
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
const ValueRef v_cache =
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
Expand All @@ -546,10 +548,51 @@ void sdpa_with_kv_cache_impl(
out});
}

// Testing-only entry point (registered below as
// testing.compute_attn_weight_with_kv_cache.default) that runs only the
// first stage of SDPA: update the K/V caches with the newly projected
// tokens, then compute the attention weights for q_projected against the
// K cache. Unlike sdpa_with_kv_cache_impl, it stops after the attention
// weight computation so that stage can be validated in isolation.
//
// Expected args layout (same prefix as sdpa_with_kv_cache):
//   q_projected, k_projected, v_projected, k_cache_data, v_cache_data,
//   input_pos_symint, sequence_len, attn_mask, dropout_p, is_causal,
//   scale, out
// attn_mask / dropout_p / is_causal / scale / sequence_len are accepted
// for signature compatibility but unused here.
void compute_attn_weight_with_kv_cache_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  int arg_idx = 0;
  const ValueRef q_projected = args[arg_idx++];
  const ValueRef k_projected = args[arg_idx++];
  const ValueRef v_projected = args[arg_idx++];
  const ValueRef k_cache_data = args[arg_idx++];
  const ValueRef v_cache_data = args[arg_idx++];
  const ValueRef input_pos_symint = args[arg_idx++];
  const ValueRef sequence_len = args[arg_idx++];
  const ValueRef attn_mask = args[arg_idx++];
  (void)attn_mask;
  const ValueRef dropout_p = args[arg_idx++];
  (void)dropout_p;
  const ValueRef is_causal = args[arg_idx++];
  (void)is_causal;
  const ValueRef scale = args[arg_idx++];
  (void)scale;

  // Output tensors
  const ValueRef out = args[arg_idx++];

  (void)sequence_len;

  // Match the cache storage type (buffer vs. texture) to the storage type
  // of the projected query tensor so the shader variants line up.
  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
  const ValueRef k_cache =
      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
  const ValueRef v_cache =
      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);

  // Append the new K/V projections into the caches at input_pos.
  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

  // Emit only the attention-weight computation; `out` receives the
  // attention weights rather than the final SDPA output.
  add_sdpa_compute_attn_weights_node(
      graph, q_projected, k_cache, input_pos_symint, out);
}

// Register the Vulkan implementations for the SDPA-related custom ops,
// including the testing-only attention-weight op used to validate the
// first stage of SDPA independently.
REGISTER_OPERATORS {
  VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
  VK_REGISTER_OP(update_cache.default, update_cache_impl);
  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
  VK_REGISTER_OP(
      testing.compute_attn_weight_with_kv_cache.default,
      compute_attn_weight_with_kv_cache_impl);
}

} // namespace vkcompute
Loading
Loading