diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml index 6a4cffcc913..d5cadc36060 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 6aadbbc379e..7fc016cf3c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index ccebf8f7c1c..33ec2f8b322 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_out_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_coop_texture3d_texture3d - - NAME: sdpa_compute_out_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index 7fbce29e908..eac2c6f37dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_out_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_tiled_texture3d_texture3d - - NAME: sdpa_compute_out_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index 932696fff02..028e02d1a20 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -5,6 +5,8 @@ #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} #define T ${buffer_scalar_type(DTYPE)} +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER $if INPUT_STORAGE == "buffer": #define INPUT_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 85f4ce090f8..5ec2f3e190c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -10,10 +10,14 @@ sdpa_kv_cache_update: INPUT_STORAGE: texture3d OUTPUT_STORAGE: texture3d generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: half - VALUE: float shader_variants: - - NAME: sdpa_kv_cache_update_texture3d - - NAME: sdpa_kv_cache_update_buffer - INPUT_STORAGE: buffer + - NAME: sdpa_kv_cache_update diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 8edaebd11ff..92b14c3b724 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node( const ValueRef projected, const ValueRef cache) { std::string kernel_name("sdpa_kv_cache_update"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(cache)); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected));