Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ ${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("int8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
Expand All @@ -49,20 +47,18 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
void main() {
// txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
$if TILE_TXCOLS > 1:
const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS));
const uint16_t out_txcol = uint16_t(
(gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS);
const int out_txcol = (int(gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS;
$else:
const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x));
const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x);
const int global_wg_x = divup4(out_sizes.x);
const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x;

const uint16_t out_row = uint16_t(
(gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
const int out_row = (int(gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS;

$if QUANT_NBITS == 4:
const uint16_t weight_txcol = uint16_t(out_txcol / 2);
const int weight_txcol = out_txcol / 2;

if (out_row >= uint16_t(out_sizes.y)) {
if (out_row >= int(out_sizes.y)) {
return;
}

Expand All @@ -73,9 +69,9 @@ void main() {
sums[r][${c}] = VEC4_T(0.0);
}

for (uint16_t pos = uint16_t(0), txpos = uint16_t(0);
pos < uint16_t(in_sizes.x);
pos += uint16_t(4), txpos += uint16_t(1)) {
for (int pos = 0, txpos = 0;
pos < in_sizes.x;
pos += 4, txpos += 1) {

T mat1[TILE_ROWS][4];

Expand All @@ -91,7 +87,7 @@ void main() {
mat1[i][2] = tmp.z;
mat1[i][3] = tmp.w;
$else:
VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
mat1[i][0] = tmp.x;
mat1[i][1] = tmp.y;
mat1[i][2] = tmp.z;
Expand All @@ -117,7 +113,7 @@ void main() {
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
$else:
packed_weight_tex = texelFetch(
t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0);
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);

qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0);
qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
Expand All @@ -128,7 +124,7 @@ void main() {
qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
$else:
qmat2[${c}] = VEC4_T(
texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0));
texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));

for (int tr = 0; tr < TILE_ROWS; ++tr) {
$for c in range(TILE_TXCOLS):
Expand All @@ -143,7 +139,7 @@ void main() {
scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
$else:
scales[${c}] = VEC4_T(
texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0));
texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));

// Store to output tensor
$if OUT_STORAGE == "buffer":
Expand Down
Loading