From 7ab09517cb7c9adb31b71e56834d7fb81c468e77 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Wed, 26 Nov 2025 13:17:05 -0800 Subject: [PATCH 1/2] Use 4x3 tiled shader for linear mat mul which performs slightly better. (#15988) Summary: This diff optimizes the performance of the quantized linear matrix multiplication operation by using a 4x3 tiled shader, which performs slightly better than the previous implementation. Reviewed By: yipjustin Differential Revision: D87902847 --- .../graph/ops/glsl/linear_qcsnw_tiled.glsl | 20 ++++++------- .../graph/ops/glsl/linear_qcsnw_tiled.yaml | 2 ++ .../graph/ops/impl/QuantizedLinearQCSNW.cpp | 28 +++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index c364e70bc9f..1e5de21cffc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -121,17 +121,15 @@ void main() { packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - const uvec4 tmp1 = packed_weight_tex >> 4; - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w); - - const uvec4 tmp2 = packed_weight_tex & 0x0F; - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index 287b2ee9333..81824a12026 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -20,6 +20,8 @@ linear_qcsnw_tiled: SUFFIX: o4x1 - VALUE: 2 SUFFIX: o4x2 + - VALUE: 3 + SUFFIX: o4x3 - VALUE: 4 SUFFIX: o4x4 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index e4e08363c6d..18958ccc3ce 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size( std::vector mat1_sizes = graph->sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + out_tile_nrows = 3; } else if (M % 4 == 0) { out_tile_nrows = 4; - } else if (M % 1 == 0) { - out_tile_nrows = 1; + } else if (M % 2 == 0) { + out_tile_nrows = 2; } else { - out_tile_nrows = 4; + out_tile_nrows = 1; } // Number of output texels in the output tile @@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node( std::vector mat1_sizes = graph.sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - kernel_name += "_o4x2"; - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + kernel_name += "_o4x3"; + out_tile_nrows = 3; } else if (M % 4 == 0) { kernel_name += "_o4x4"; out_tile_nrows = 4; - } else if (M % 1 == 0) { + } else if (M % 2 == 0) { + kernel_name += "_o4x2"; + out_tile_nrows = 2; + } else { kernel_name += "_o4x1"; out_tile_nrows = 1; - } else { - kernel_name += "_o4x4"; - out_tile_nrows = 4; } // Number of output texels in the output tile From 79e115ec267434fcdda082393173221669b3ffdf Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Wed, 26 Nov 2025 13:17:05 -0800 Subject: [PATCH 2/2] Some minor performance improvements to buffer 4b mat mul. (#15989) Summary: The code change in this diff aims to improve the performance of buffer 4b matrix multiplication by reducing unnecessary computations and by spreading operations to allow better latency hiding. Reviewed By: yipjustin Differential Revision: D87910988 --- .../graph/ops/glsl/linear_qcsnw_tiled.glsl | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 1e5de21cffc..d966de7282e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -78,6 +78,12 @@ void main() { const int in_row_txstride = div4(in_sizes.x); + $if WEIGHT_STORAGE == "buffer": + $if QUANT_NBITS == 4: + uint qmat2_bufi = weight_txcol; + $else: + uint qmat2_bufi = out_txcol; + for (int pos = 0, txpos = 0; txpos < in_row_txstride; pos += 4, txpos += 1) { @@ -99,7 +105,6 @@ void main() { } $if WEIGHT_STORAGE == "buffer": - uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); uint encoded_weight; @@ -114,26 +119,31 @@ void main() { $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; - packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28)); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF); $else: packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); - - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: @@ -146,6 +156,8 @@ void main() { $for j in range(4): sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; } + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi += weight_row_txstride; } }