From b980d5a86b0a696af65c448e77469ccb715b986b Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 10 Sep 2025 17:08:55 -0700
Subject: [PATCH] [ET-VK] Misc code cleanup to recent Quantized Linear + SDPA implementations

Pull Request resolved: https://github.com/pytorch/executorch/pull/14150

As title. Introduce some minor fixes and code cleanup to the recently added dqlinear and sdpa implementations.

ghstack-source-id: 308939584
@exported-using-ghexport

Differential Revision: [D82120825](https://our.internmc.facebook.com/intern/diff/D82120825/)
---
 .../ops/glsl/choose_qparams_per_row.glsl      |  2 --
 .../vulkan/runtime/graph/ops/glsl/im2col.glsl |  2 --
 .../ops/glsl/linear_dq8ca_q4gsw_tiled.glsl    |  9 +-----
 ...ear_fp_output_tile_int8_int4_compute.glslh |  9 ++++--
 .../graph/ops/glsl/linear_q4gsw_coop.glsl     |  4 +--
 .../glsl/sdpa_compute_attn_weights_coop.glsl  |  2 --
 .../glsl/sdpa_compute_attn_weights_tiled.glsl |  2 --
 .../graph/ops/glsl/sdpa_compute_out_coop.glsl |  2 --
 .../ops/glsl/sdpa_compute_out_tiled.glsl      |  2 --
 .../graph/ops/impl/QuantizedLinear.cpp        | 30 +++++++++----------
 .../vulkan/test/custom_ops/q4gsw_linear.cpp   |  6 ++--
 11 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
index c95bb66f164..639fe312148 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
@@ -40,8 +40,6 @@ layout(push_constant) uniform PushConstants {
   int quant_max;
 };
 
-#extension GL_EXT_debug_printf : enable
-
 // Shared memory for cooperative min/max finding
 shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
 shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl
index f045d4e9702..f006ec993fe 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl
@@ -8,8 +8,6 @@
 
 #version 450 core
 
-#extension GL_EXT_debug_printf : enable
-
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
 #define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl
index c7df1b429c7..d918e4c9c7f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl
@@ -110,6 +110,7 @@ void main() {
   IntPerInChannelParams int8_input_sums_tile;
 
   const int num_groups = K4 / K4_per_group;
+  const int group_size = mul_4(K4_per_group);
 
   for (int group_i = 0; group_i < num_groups; ++group_i) {
     // Reset int accumulator
@@ -119,7 +120,6 @@
 
       load_int8_input_tile(int8_in_tile, k4, m4, K4);
       load_int4_weight_tile(int4_weight_tile, k4, n8, K4);
-      // load_int4_weight_tile(int4_weight_tile, n8, k4, N8);
 
       int_accumulate_with_int4_weight(
           out_accum, int8_in_tile, int4_weight_tile);
@@ -129,13 +129,6 @@
     load_weight_sums_tile_for_group(weight_sums_tile, n4, group_i, N4);
     load_int8_input_sums_tile_for_group(int8_input_sums_tile, m4, group_i, M4);
 
-    const int group_size = mul_4(K4_per_group);
-
-    // // Update output tile with accumulated values
-    // accumulate_out_tile_with_int_accum_from_int4_weights_test(
-    //     out_tile,
-    //     out_accum);
-
     accumulate_out_tile_with_int_accum_from_int4_weights(
         out_tile,
         out_accum,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
index ac886f78bfb..7bc7071ab1f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
@@ -63,9 +63,12 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(
     const FPPerOutChannelParams weight_scales,
     const int group_size) {
   [[unroll]] for (int m = 0; m < TILE_M; ++m) {
-    float input_scale_m = input_scales.data[0][m];
-    int input_zp_m = input_zps.data[0][m];
-    int input_sum_m = input_sums.data[0][m];
+    const int m4 = div_4(m);
+    const int m4i = mod_4(m);
+
+    float input_scale_m = input_scales.data[m4][m4i];
+    int input_zp_m = input_zps.data[m4][m4i];
+    int input_sum_m = input_sums.data[m4][m4i];
 
     [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
       ivec4 accum_adjusted = accum.data[m][n4] -
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl
index c9b82425865..72c77eec704 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl
@@ -71,8 +71,8 @@ ${layout_declare_spec_const(C, "int", "K4_per_group", "0")}
 shared FPOutTile partial_sums[WGS];
 
 void main() {
-  const int lid = int(gl_LocalInvocationID.x);
-  const int n8 = int(gl_GlobalInvocationID.y);
+  const int lid = int(gl_LocalInvocationID.z);
+  const int n8 = int(gl_GlobalInvocationID.x);
 
   // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
   // 8 output elements, so each thread will write to 8 elements starting at the
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
index 4b7e3e0ddd2..2900d63666b 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl
@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
  * the entire work group co-operates to compute one reduction output.
 */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int worker_id = int(gl_LocalInvocationID.y);
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
index 577d7dea749..95c22d91b80 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl
@@ -73,8 +73,6 @@ ${layout_declare_spec_const(C, "float", "inv_scale", "1.0")}
 *
 */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int tile_idx_x = int(gl_GlobalInvocationID.x);
   const int tile_idx_y = int(gl_GlobalInvocationID.y);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
index 1fdd803d02b..5f408b7581d 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl
@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
 * the entire work group co-operates to compute one reduction output.
 */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int worker_id = int(gl_LocalInvocationID.y);
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
index fb4eaded826..0063ebf9d38 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl
@@ -56,8 +56,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 * output has shape (batches, seq_len, num_q_heads, head_dim)
 */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int tile_idx_x = int(gl_GlobalInvocationID.x);
   const int tile_idx_y = int(gl_GlobalInvocationID.y);
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 6a50f81830c..7fbfcee5cb1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -61,29 +61,27 @@ utils::uvec3 quantized_linear_global_wg_size(
   const ValueRef out = args.at(0).refs.at(0);
 
   std::vector<int64_t> out_sizes = graph->sizes_of(out);
-  // height
-  const uint32_t M = utils::val_at(-2, out_sizes);
   // width
   const uint32_t N = utils::val_at(-1, out_sizes);
+  // height
+  const uint32_t M = utils::val_at(-2, out_sizes);
 
-  const uint32_t M4 = utils::div_up(M, 4u);
-  const uint32_t N4 = utils::div_up(N, 4u);
+  uint32_t N_per_tile = 4;
+  uint32_t M_per_tile = 4;
 
-  // For 4-bit weights, each output tile contains 8 columns and 4 rows
+  // For 4-bit weights, each output tile contains 8 columns
   if (shader.kernel_name.find("q4") != std::string::npos) {
-    const uint32_t N8 = utils::div_up(N, 8u);
-
-    const bool using_coop_algorithm =
-        shader.kernel_name.find("_coop") != std::string::npos;
-    // TODO: explain
-    if (using_coop_algorithm) {
-      return {64, N8, M};
-    }
-    return {N8, M4, 1};
+    N_per_tile = 8;
+  }
+  if (shader.kernel_name.find("coop") != std::string::npos) {
+    M_per_tile = 1;
   }
 
+  const uint32_t num_N_tiles = utils::div_up(N, N_per_tile);
+  const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);
+
   // Otherwise, each output tile contains 4 columns and 4 rows
-  return {N4, M4, 1};
+  return {num_N_tiles, num_M_tiles, 1};
 }
 
 utils::uvec3 quantized_linear_local_wg_size(
@@ -96,7 +94,7 @@ utils::uvec3 quantized_linear_local_wg_size(
       shader.kernel_name.find("_coop") != std::string::npos;
 
   if (use_coop_algorithm) {
-    return {64, 1, 1};
+    return {1, 1, 64};
   } else {
     return pick_hw_square_wg_size(
         graph, shader, global_workgroup_size, args, resize_args);
diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp
index 1c09fdd471f..59d9d694c2c 100644
--- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp
+++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp
@@ -242,9 +242,9 @@ std::vector generate_quantized_linear_test_cases() {
       {32, 256, 128, 64, false},
       // Performance test cases
       {1, 2048, 2048, 128},
-      {128, 2048, 2048, 256},
-      {256, 2048, 2048, 256},
-      {1024, 2048, 2048, 256},
+      {128, 2048, 2048, 128},
+      {256, 2048, 2048, 128},
+      {1024, 2048, 2048, 128},
   };
 
   // Test with different storage types and data types
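
Note (illustration, not part of the patch): the QuantizedLinear.cpp hunk above replaces the special-cased q4/coop dispatch with a single tile-size rule. The sketch below restates that rule as standalone C++ so the dispatch math is easy to check; the helper name `sketch_global_wg_size`, the `WgSize` struct, and the example sizes are made up for illustration, while the tile-size constants and the {1, 1, 64} cooperative local size come from the diff itself.

```cpp
#include <cstdint>
#include <string>

// Illustrative sketch only: mirrors the tile-count logic that
// quantized_linear_global_wg_size uses after this cleanup.
struct WgSize {
  uint32_t x, y, z;
};

inline uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

WgSize sketch_global_wg_size(
    const std::string& kernel_name, uint32_t N, uint32_t M) {
  uint32_t N_per_tile = 4; // default: 4 output columns per tile
  uint32_t M_per_tile = 4; // default: 4 output rows per tile
  if (kernel_name.find("q4") != std::string::npos) {
    N_per_tile = 8; // 4-bit weight shaders pack 8 columns per tile
  }
  if (kernel_name.find("coop") != std::string::npos) {
    M_per_tile = 1; // cooperative shaders reduce one output row per work group
  }
  // One work group is dispatched per output tile
  return {div_up(N, N_per_tile), div_up(M, M_per_tile), 1};
}

// Example: a q4 coop shader with N = 2048, M = 1 dispatches {256, 1, 1}
// work groups, each paired with the {1, 1, 64} local size from the patch,
// i.e. 64 workers along Z cooperating on one 8-column output tile.
```

This also matches the linear_q4gsw_coop.glsl change above, where the worker id now comes from `gl_LocalInvocationID.z` and the output tile index from `gl_GlobalInvocationID.x`.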