From 8e3d21ed30c776986371266a521f4231e4623c29 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Wed, 10 Sep 2025 13:43:53 -0500 Subject: [PATCH 1/3] Optimize conv2d s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 185 +++++++----------- .../graph/ops/glsl/conv2d_pw_s1p0.yaml | 2 - .../runtime/graph/ops/impl/Convolution.cpp | 4 + 3 files changed, 80 insertions(+), 111 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 9f84afeb1a1..217c727bcd6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} -#define TILE_SIZE_X uint16_t(${TILE_SIZE_X}) -#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y}) - #define op(X, A, B) ${OPERATOR} #include "indexing_utils.h" @@ -50,119 +47,89 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")} * size is only 1x1, making it easier to re-use loaded texels from t_kernel. */ void main() { - const int out_limits_scaled[2] = - {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X, - (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y}; - - const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); - const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; - const int out_pos_z = int(gl_GlobalInvocationID.y); - - // If the top left position is out of bounds, then this invocation will have - // no work to do. 
- if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { - return; - } - // Output position for TILE_SIZE = 2 - // +--------+--------+ - // | pos[0] | pos[1] | - // +--------+--------+ - // | pos[2] | pos[3] | - // +--------+--------+ - uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; - for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) { - for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) { - pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; - pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; - i++; - } - } + int inputAndOutputWidth = out_limits.x; + int inputAndOutputHeight = out_limits.y; + int outputChannel = out_limits.z*4; - // Final output array where each element is a tensor value. - // Tuple of consecutive 4 elements represents a single output texel. - float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; + // Divided by 4 because the input channels are packed + int inputChannel = in_group_size/4; - // Initialize the output array with the bias value - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) { - sum[i] = 0; - } + int threadHW = int(gl_GlobalInvocationID.x); + int gid1 = int(gl_GlobalInvocationID.y); - int z4 = 0; - // Since the kernel is 1x1, we only have to loop over the depth dimension. - for (int z = 0; z < in_group_size; z += 4, ++z4) { - // During prepacking, the weight tensor has been permuted so that the - // channel (IC) dim is along the x-axis, and the batch (OC) dim is along - // the z-axis. 
- float kernel_values[4 * 4]; // 4 channels, 4 elements per channel - - // Load kernel values from texels to array - [[unroll]] for (int i = 0; i < 4; ++i) { - const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); - kernel_values[i * 4 + 0] = k_tex.x; - kernel_values[i * 4 + 1] = k_tex.y; - kernel_values[i * 4 + 2] = k_tex.z; - kernel_values[i * 4 + 3] = k_tex.w; - } + int xIdx = threadHW % inputAndOutputWidth; + int yIdx = threadHW / inputAndOutputWidth; - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); - // Load the input texel into an array - float tex_values[4]; - tex_values[0] = in_tex.x; - tex_values[1] = in_tex.y; - tex_values[2] = in_tex.z; - tex_values[3] = in_tex.w; - - // For 2x2 tile size algorithm works as follows. - // To explain the calculations below, the contents of one in_tex and the - // group of 4 texels loaded from t_kernel are shown: - // - // in_tex t_kernel - // -x-> ---x---> - // +---+ +----+----+----+----+ - // ^ | w | ^ | D0 | D1 | D2 | D3 | - // | +---+ | +----+----+----+----+ - // | | z | | | C0 | C1 | C2 | C3 | - // z +---+ z +----+----+----+----+ - // | | y | | | B0 | B2 | B2 | B3 | - // | +---+ | +----+----+----+----+ - // | x | | A0 | A1 | A2 | A3 | - // +---+ +----+----+----+----+ - // - // In the t_kernel graphic, cells sharing the same letter are from - // the same batch/output channel index, and the number denotes a unique - // channel index. 
To calculate the output texel, the following - // calculation is performed: - // - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | - // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ - // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // - // which is what is expressed in the following calculations. This is done - // for each output position. - for (int j = 0; j < 4; ++j) { - sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; - } - } - } + if (threadHW < inputAndOutputWidth * inputAndOutputHeight && gid1 < outputChannel) { + + vec4 outputTexel = texelFetch(t_bias, ivec2(gid1, 0), 0); + + vec4 inputVec; + vec4 weight1OutputChannelPacked; + vec4 weight2OutputChannelPacked; + vec4 weight3OutputChannelPacked; + vec4 weight4OutputChannelPacked; + + // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions + // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute + for (int inputC = 0; inputC < inputChannel; inputC += 1) { + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = 
texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); - const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); - if (all(lessThan(pos_l.xy, out_limits.xy))) { - const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]); - imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max)); + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], 
weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], 
weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } + + imageStore(t_out, ivec3(xIdx, yIdx, gid1), op(outputTexel, out_min, out_max)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml index ebfee11c405..bab3c715540 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml @@ -9,8 +9,6 @@ conv2d_pw_s1p0: OPERATOR: X NDIM: 3 DTYPE: float - TILE_SIZE_X: 1 - TILE_SIZE_Y: 4 generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ded1defe973..ef4a4d514b0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -364,6 +364,10 @@ utils::uvec3 conv2d_global_wg_size( if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + + if (shader.kernel_name.find("s1p0") != std::string::npos) { + wg_size[0] *= 4; + } } return wg_size; From 41ec8b64829f6a74a4153c3a700c290c6e03838c Mon Sep 17 00:00:00 2001 From: Alex Dean Date: Thu, 11 Sep 2025 16:13:36 -0700 Subject: [PATCH 2/3] Stylistic changes to pw conv2d s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 107 +++++++++--------- 1 file changed, 54 insertions(+), 53 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 217c727bcd6..06443d6a028 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -56,80 +56,81 @@ void main() { int inputChannel = in_group_size/4; int threadHW = int(gl_GlobalInvocationID.x); - int gid1 = int(gl_GlobalInvocationID.y); + int threadOutChannel = int(gl_GlobalInvocationID.y); int xIdx = threadHW % inputAndOutputWidth; int yIdx = threadHW / inputAndOutputWidth; - if (threadHW < inputAndOutputWidth * inputAndOutputHeight && gid1 < outputChannel) { - - vec4 outputTexel = texelFetch(t_bias, ivec2(gid1, 0), 0); + if (threadHW >= inputAndOutputWidth * inputAndOutputHeight || threadOutChannel >= out_limits.z) { + return; + } - vec4 inputVec; - vec4 weight1OutputChannelPacked; - vec4 weight2OutputChannelPacked; - vec4 weight3OutputChannelPacked; - vec4 weight4OutputChannelPacked; + vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0); - // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions - // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute - for (int inputC = 0; inputC < inputChannel; inputC += 1) { + vec4 inputVec; + vec4 weight1OutputChannelPacked; + vec4 weight2OutputChannelPacked; + vec4 weight3OutputChannelPacked; + vec4 weight4OutputChannelPacked; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions + // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute + for (int inputC = 0; inputC < inputChannel; inputC += 1) { - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4
+ 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - inputC += 1; + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = 
texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - inputC += 1; + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], 
weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - inputC += 1; + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], 
weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - } + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - imageStore(t_out, ivec3(xIdx, yIdx, gid1), op(outputTexel, out_min, out_max)); + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], 
weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } + + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max)); } From c1910fe25ec3ff47f3b6faaf29bff54d6bbe1ce5 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Wed, 17 Sep 2025 16:49:55 -0500 Subject: [PATCH 3/3] Add fp16 to conv2d pw s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 06443d6a028..ef50a1aca9f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -12,7 +12,12 @@ #define PRECISION ${PRECISION} -#define VEC4_T ${texel_type(DTYPE)} +$if DTYPE == "half": + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + #define VEC4_T f16vec4 +$else: + #define VEC4_T ${texel_type(DTYPE)} + #define op(X, A, B) ${OPERATOR} @@ -65,72 +70,72 @@ void main() { return; } - vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0); + VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); - vec4 inputVec; - vec4 weight1OutputChannelPacked; - vec4 weight2OutputChannelPacked; - vec4 weight3OutputChannelPacked; - vec4 weight4OutputChannelPacked; + VEC4_T inputVec; + VEC4_T weight1OutputChannelPacked; + VEC4_T weight2OutputChannelPacked; + VEC4_T 
weight3OutputChannelPacked; + VEC4_T weight4OutputChannelPacked; // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute for (int inputC = 0; inputC < inputChannel; inputC += 1) { - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], 
weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], 
weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], 
weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, 
ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } - imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max)); + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); }