Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ void main() {

const int in_row_txstride = div4(in_sizes.x);

$if WEIGHT_STORAGE == "buffer":
$if QUANT_NBITS == 4:
uint qmat2_bufi = weight_txcol;
$else:
uint qmat2_bufi = out_txcol;

for (int pos = 0, txpos = 0;
txpos < in_row_txstride;
pos += 4, txpos += 1) {
Expand All @@ -99,7 +105,6 @@ void main() {
}

$if WEIGHT_STORAGE == "buffer":
uint qmat2_bufi;
uint weight_row_txstride = div4(weight_sizes.x);
uint encoded_weight;

Expand All @@ -114,28 +119,31 @@ void main() {
$if QUANT_NBITS == 4:
$for c in range(0, TILE_TXCOLS, 2):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
encoded_weight = t_weight[qmat2_bufi + ${c}];
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28));

qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF);
$else:
packed_weight_tex = texelFetch(
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);

const uvec4 tmp1 = packed_weight_tex >> 4;
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x);
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y);
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z);
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w);

const uvec4 tmp2 = packed_weight_tex & 0x0F;
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x);
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y);
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z);
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w);
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);

qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
$else:
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
encoded_weight = t_weight[qmat2_bufi + ${c}];
packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
$else:
Expand All @@ -148,6 +156,8 @@ void main() {
$for j in range(4):
sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
}
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi += weight_row_txstride;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ linear_qcsnw_tiled:
SUFFIX: o4x1
- VALUE: 2
SUFFIX: o4x2
- VALUE: 3
SUFFIX: o4x3
- VALUE: 4
SUFFIX: o4x4
shader_variants:
Expand Down
28 changes: 14 additions & 14 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size(

std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1);
const int64_t M = utils::val_at(-2, mat1_sizes);
uint32_t out_tile_nrows = 4;
if (M % 6 == 0) {
out_tile_nrows = 2;
uint32_t out_tile_nrows = 1;
if (M % 3 == 0) {
out_tile_nrows = 3;
} else if (M % 4 == 0) {
out_tile_nrows = 4;
} else if (M % 1 == 0) {
out_tile_nrows = 1;
} else if (M % 2 == 0) {
out_tile_nrows = 2;
} else {
out_tile_nrows = 4;
out_tile_nrows = 1;
}

// Number of output texels in the output tile
Expand Down Expand Up @@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node(

std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
const int64_t M = utils::val_at(-2, mat1_sizes);
uint32_t out_tile_nrows = 4;
if (M % 6 == 0) {
kernel_name += "_o4x2";
out_tile_nrows = 2;
uint32_t out_tile_nrows = 1;
if (M % 3 == 0) {
kernel_name += "_o4x3";
out_tile_nrows = 3;
} else if (M % 4 == 0) {
kernel_name += "_o4x4";
out_tile_nrows = 4;
} else if (M % 1 == 0) {
} else if (M % 2 == 0) {
kernel_name += "_o4x2";
out_tile_nrows = 2;
} else {
kernel_name += "_o4x1";
out_tile_nrows = 1;
} else {
kernel_name += "_o4x4";
out_tile_nrows = 4;
}

// Number of output texels in the output tile
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
const int64_t N = qmat2_orig_sizes.at(ndim - 2);
const int64_t N_div2 = N / int64_t(2);

utils::StorageType storage_type = utils::kTexture2D;
utils::StorageType storage_type = utils::kBuffer;
uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
if (N_div2 > max_extent * 4 || K > max_extent) {
storage_type = utils::kBuffer;
Expand Down