From 373dc7cfad149a895962f35576b59495bef46949 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Sun, 19 Oct 2025 05:43:48 -0700 Subject: [PATCH] Converting all uint16 to int in quantized mat mul shader to improve perf. (#15193) Summary: ## This Diff This diff improves the performance of the quantized matrix multiplication shader in the ExecuTorch Vulkan runtime by converting all `uint16_t` scalars and `u16vec2`/`u16vec3` texel coordinates to `int`, `ivec2`, and `ivec3` in the shader code. It also removes the `GL_EXT_control_flow_attributes` extension requirement from the shader. Reviewed By: SS-JIA Differential Revision: D84777696 --- .../graph/ops/glsl/linear_qcsnw_tiled.glsl | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 936fd641a9b..88b054e2cb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -21,8 +21,6 @@ ${define_required_extensions(DTYPE)} $if WEIGHT_STORAGE == "buffer": ${define_required_extensions("int8")} -#extension GL_EXT_control_flow_attributes : require - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} @@ -49,20 +47,18 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. 
$if TILE_TXCOLS > 1: - const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS)); - const uint16_t out_txcol = uint16_t( - (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS); + const int out_txcol = (int(gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS; $else: - const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x)); - const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x); + const int global_wg_x = divup4(out_sizes.x); + const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x; - const uint16_t out_row = uint16_t( - (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + const int out_row = (int(gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS; $if QUANT_NBITS == 4: - const uint16_t weight_txcol = uint16_t(out_txcol / 2); + const int weight_txcol = out_txcol / 2; - if (out_row >= uint16_t(out_sizes.y)) { + if (out_row >= int(out_sizes.y)) { return; } @@ -73,9 +69,9 @@ void main() { sums[r][${c}] = VEC4_T(0.0); } - for (uint16_t pos = uint16_t(0), txpos = uint16_t(0); - pos < uint16_t(in_sizes.x); - pos += uint16_t(4), txpos += uint16_t(1)) { + for (int pos = 0, txpos = 0; + pos < in_sizes.x; + pos += 4, txpos += 1) { T mat1[TILE_ROWS][4]; @@ -91,7 +87,7 @@ void main() { mat1[i][2] = tmp.z; mat1[i][3] = tmp.w; $else: - VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); + VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); mat1[i][0] = tmp.x; mat1[i][1] = tmp.y; mat1[i][2] = tmp.z; @@ -117,7 +113,7 @@ void main() { packed_weight_tex = t_weight[qmat2_bufi + ${c}] $else: packed_weight_tex = texelFetch( - t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); + t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0); qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); @@ -128,7 +124,7 @@ void main() { qmat2[${c}] = t_weight[qmat2_bufi + ${c}]; $else: 
qmat2[${c}] = VEC4_T( - texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0)); + texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); for (int tr = 0; tr < TILE_ROWS; ++tr) { $for c in range(TILE_TXCOLS): @@ -143,7 +139,7 @@ void main() { scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); $else: scales[${c}] = VEC4_T( - texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0)); + texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0)); // Store to output tensor $if OUT_STORAGE == "buffer":