From 48fa7b0e861deee142199e17599372d40d469ba8 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 25 Sep 2025 08:52:21 -0700 Subject: [PATCH] [ET-VK] Improve q8 matmul by increasing TILE_N4 Title says it all! I found that the latency of executing int8 matmul can be improved by increasing the output tile's N4 dimension to 2. The improvement is about 20-25% on Samsung Galaxy S25. Differential Revision: [D83253129](https://our.internmc.facebook.com/intern/diff/D83253129/) [ghstack-poisoned] --- .../ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh | 2 +- .../runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml | 2 +- backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh index ca25e406ac1..850dc7943c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum( input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; out_tile.data[m][n4] = fma(VEC4_T(accum_adjusted), - VEC4_T(input_q_scale * weight_scales.data[0]), + VEC4_T(input_q_scale * weight_scales.data[n4]), out_tile.data[m][n4]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml index aa1de3077fc..989729f2d7f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -11,7 +11,7 @@ linear_q8ta_q8csw_tiled: PACKED_INT8_INPUT_STORAGE: buffer WEIGHT_STORAGE: texture2d TILE_M4: 1 - TILE_N4: 1 + TILE_N4: 2 TILE_K4: 1 generate_variant_forall: DTYPE: diff --git 
a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 6c841732d9c..97566038501 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -77,6 +77,10 @@ utils::uvec3 quantized_linear_global_wg_size( M_per_tile = 1; } + if (shader.kernel_name.find("q8ta_q8csw_tiled") != std::string::npos) { + N_per_tile = 8; + } + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);