diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index c364e70bc9f..d966de7282e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -78,6 +78,12 @@ void main() { const int in_row_txstride = div4(in_sizes.x); + $if WEIGHT_STORAGE == "buffer": + $if QUANT_NBITS == 4: + uint qmat2_bufi = weight_txcol; + $else: + uint qmat2_bufi = out_txcol; + for (int pos = 0, txpos = 0; txpos < in_row_txstride; pos += 4, txpos += 1) { @@ -99,7 +105,6 @@ void main() { } $if WEIGHT_STORAGE == "buffer": - uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); uint encoded_weight; @@ -114,28 +119,31 @@ void main() { $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; - packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28)); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF); $else: packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - - const uvec4 tmp1 = packed_weight_tex >> 4; - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w); - - const uvec4 tmp2 = packed_weight_tex & 0x0F; - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: @@ -148,6 +156,8 @@ void main() { $for j in range(4): sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; } + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi += weight_row_txstride; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index 287b2ee9333..81824a12026 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -20,6 +20,8 @@ linear_qcsnw_tiled: SUFFIX: o4x1 - VALUE: 2 SUFFIX: o4x2 + - VALUE: 3 + SUFFIX: o4x3 - VALUE: 4 SUFFIX: o4x4 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index e4e08363c6d..18958ccc3ce 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size( std::vector mat1_sizes = graph->sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + out_tile_nrows = 3; } else if (M % 4 == 0) { out_tile_nrows = 4; - } else if (M % 1 == 0) { - out_tile_nrows = 1; + } else if (M % 2 == 0) { + out_tile_nrows = 2; } else { - out_tile_nrows = 4; + out_tile_nrows = 1; } // Number of output texels in the output tile @@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node( std::vector mat1_sizes = graph.sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - kernel_name += "_o4x2"; - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + kernel_name += "_o4x3"; + out_tile_nrows = 3; } else if (M % 4 == 0) { kernel_name += "_o4x4"; out_tile_nrows = 4; - } else if (M % 1 == 0) { + } else if (M % 2 == 0) { + kernel_name += "_o4x2"; + out_tile_nrows = 2; + } else { kernel_name += "_o4x1"; out_tile_nrows = 1; - } else { - kernel_name += "_o4x4"; - out_tile_nrows = 4; } // Number of output texels in the output tile diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 40de9b59e81..db7c5a7e88b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -285,7 +285,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( const int64_t N = qmat2_orig_sizes.at(ndim - 2); const int64_t N_div2 = N / int64_t(2); - utils::StorageType storage_type = utils::kTexture2D; + utils::StorageType storage_type = utils::kBuffer; uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); if (N_div2 > max_extent * 4 || K > max_extent) { storage_type = utils::kBuffer;