diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh index ca25e406ac1..850dc7943c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum( input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; out_tile.data[m][n4] = fma(VEC4_T(accum_adjusted), - VEC4_T(input_q_scale * weight_scales.data[0]), + VEC4_T(input_q_scale * weight_scales.data[n4]), out_tile.data[m][n4]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml index aa1de3077fc..989729f2d7f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -11,7 +11,7 @@ linear_q8ta_q8csw_tiled: PACKED_INT8_INPUT_STORAGE: buffer WEIGHT_STORAGE: texture2d TILE_M4: 1 - TILE_N4: 1 + TILE_N4: 2 TILE_K4: 1 generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 6c841732d9c..97566038501 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -77,6 +77,10 @@ utils::uvec3 quantized_linear_global_wg_size( M_per_tile = 1; } + if (shader.kernel_name.find("q8ta_q8csw_tiled") != std::string::npos) { + N_per_tile = 8; + } + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);