diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl index 25749afbf85..9d78b7a6a6e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl @@ -15,24 +15,12 @@ layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx; -layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { - ivec4 in_sizes; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict Params { - ivec2 kernel_size; - ivec2 stride; - ivec2 padding; - ivec2 dilation; -}; +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_idx", "int", STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -54,7 +42,7 @@ void main() { for (int y = start.y; y < end.y; y += dilation.y) { for (int x = start.x; x < end.x; x += dilation.x) { if ((x >= 0 && x < in_sizes.x) && (y >= 0 && y < in_sizes.y)) { - const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0); + const vec4 cur_texel = load_texel(t_in, ivec3(x, y, pos.z)); // Set idx if value is greatest in the pool; else, keep the existing idx. ivec4 cur_idx = ivec4(x + int(in_sizes.x) * y); @@ -66,6 +54,6 @@ void main() { } } - imageStore(image_out, pos, out_texel); - imageStore(image_idx, pos, idx_texel); + imageStore(t_out, pos, out_texel); + imageStore(t_idx, pos, idx_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml index 3be032bf85d..d8e3aa599f5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml @@ -8,6 +8,7 @@ max_pool2d: parameter_names_with_default_values: NDIM: 3 DTYPE: float + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half