// Patch summary (commentary only; this text precedes the `diff --git` header).
//
// Target: backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
// Hunk 1 covers old lines 91-117 / new lines 91-113; hunk 2 starts at old
// line 120 / new line 116. Old lines 118-119 lie between the hunks and are
// not shown here.
//
// What the patch changes in q_8w_linear():
//   * Signature goes from (const ivec3 out_pos, const int K) to
//     (const u16vec3 out_pos, const uint16_t K).
//   * The mutable mat1_pos / qmat2_pos vectors, and their per-iteration
//     `.x++` updates, are removed.
//   * qmat2_pos_y = out_pos.x * 4 is computed once before the loop.
//   * The loop now carries two uint16_t counters: `i` advances by 4 until it
//     reaches K, and `x` advances by 1 per iteration.
//   * Each iteration loads one texel from t_mat1 at (x, out_pos.y, out_pos.z)
//     and four texels from t_qmat2 at (x, qmat2_pos_y + 0..3, 0). The four
//     dot products form `sums`, which is added to outtex.
//   * Unchanged context: scales is loaded from t_scales at (out_pos.x, 0, 0)
//     and multiplied into outtex after the loop.
//
// What the patch changes in main():
//   * out_pos is built as u16vec3(gl_GlobalInvocationID) instead of ivec3.
//   * mat1_sizes.x is cast to uint16_t before being passed as K.
//
// NOTE(review): the uint16_t conversions narrow. If mat1_sizes.x or a
// gl_GlobalInvocationID component can exceed 65535 the value would
// presumably wrap, and `i += uint16_t(4)` could wrap past K when K is close
// to 65535. Confirm the size limits guaranteed by the caller.
// NOTE(review): out_pos (u16vec3) is compared against out_limits, whose type
// is not visible in this patch. Confirm the comparison types agree.
// NOTE(review): this patch text appears to have lost its original line
// breaks; they would need to be restored before it can be applied.
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 36b9c24317e..ecfb44d431e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -91,27 +91,23 @@ void main() { #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -VEC4_T q_8w_linear(const ivec3 out_pos, const int K) { - u16vec3 mat1_pos = u16vec3(0, out_pos.yz); - u16vec3 qmat2_pos = u16vec3(0, out_pos.x * 4, 0); +VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) { + const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4); VEC4_T outtex = VEC4_T(0); const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0); const VEC4_T scales = load_texel(t_scales, scales_pos); - for (int i = 0; i < K; i += 4) { - const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); + for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) { + const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz)); const VEC4_T sums = VEC4_T( - dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)), - dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))), - dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))), - dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0)))); + dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))), + dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))), + dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))), + dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0)))); outtex += sums; - - mat1_pos.x++; - qmat2_pos.x++; } outtex *= scales; @@ -120,12 +116,12 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) { } void main() { - const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + const u16vec3 out_pos = u16vec3(gl_GlobalInvocationID); if (any(greaterThanEqual(out_pos, out_limits))) { return; } - VEC4_T outtex = 
q_8w_linear(out_pos, mat1_sizes.x); + VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x)); write_texel(t_out, out_pos, outtex); }