Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 33 additions & 43 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,22 @@

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};

layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
ivec4 in_sizes;
};

layout(set = 0, binding = 6) uniform PRECISION restrict Params {
int kernel_size;
int stride;
int padding;
int dilation;
int in_group_size;
int out_group_size;
};

layout(set = 0, binding = 7) uniform PRECISION restrict OutputParams {
float out_min;
float out_max;
};
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

${layout_declare_ubo(B, "ivec4", "out_axis_map")}
${layout_declare_ubo(B, "ivec4", "in_axis_map")}
${layout_declare_ubo(B, "ivec4", "kernel_axis_map")}
${layout_declare_ubo(B, "ivec4", "bias_axis_map")}

${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")}

${layout_declare_ubo(B, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand All @@ -67,9 +57,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
// shader invocations, where each invocation computes 1 result. But that
// performs worse.
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec3 lpos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
if (any(greaterThanEqual(lpos, out_limits))) {
return;
}

Expand All @@ -78,8 +68,8 @@ void main() {

// "out_c" is the output's channel index where we write our result.
// Across shader invocations, this is the only value that varies.
int out_c = pos.y;
vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
int out_c = lpos.y;
VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);

// "in_c" tracks the input's channel start index.
// We iterate over the input group that corresponds to the output group.
Expand All @@ -98,7 +88,7 @@ void main() {
int out_l = 0;

for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
vec4 sum = vec4(0);
VEC4_T sum = VEC4_T(0);

for (int in_c = c_start; in_c < c_end; ++in_c) {
// "k" tracks the kernel's index for our input-kernel computation.
Expand All @@ -107,25 +97,25 @@ void main() {
for (int k = 0; k < kernel_size; k += 4) {
// Since the weight tensor is width-packed, which is along the length
// dimension, we can batch-read four elements at a time.
const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c);
const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);

const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);

const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
in_pos[in_axis_map.x] += dilation;
sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);

const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
in_pos[in_axis_map.x] += dilation;
sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);

const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
in_pos[in_axis_map.x] += dilation;
sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
}
}

ivec3 out_pos = ivec3(out_l, out_c, n / 4);
imageStore(image_out, out_pos, op(sum + bias.x, out_min, out_max));
const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
}
}
}
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
conv1d:
parameter_names_with_default_values:
OPERATOR: X
NDIM: 3
DTYPE: float
PACKING: C_packed
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void add_conv1d_node(
int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);

utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
utils::uvec3 local_size = {1, 1, 1};
utils::uvec3 local_size = {1, 64, 1};

Kernel1dParams kernel_params = {
kernel_size,
Expand Down Expand Up @@ -476,6 +476,10 @@ void add_conv1d_node(
{
t_out->logical_limits_ubo(),
t_in->sizes_ubo(),
t_out->axis_map_ubo(),
t_in->axis_map_ubo(),
t_weight->axis_map_ubo(),
t_bias->axis_map_ubo(),
graph.create_params_buffer(kernel_params),
graph.create_params_buffer(out_params),
},
Expand Down
Loading