From c442d508b64d6866ea39e352c689b81dc932a40e Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Wed, 26 Mar 2025 10:49:38 -0700 Subject: [PATCH] [BE][ET-VK] update max_pool2d to use layout gen Pull Request resolved: https://github.com/pytorch/executorch/pull/9591 TSIA @pytorchbot label "topic: not user facing" Differential Revision: [D71825476](https://our.internmc.facebook.com/intern/diff/D71825476/) ghstack-source-id: 274222178 --- .../runtime/graph/ops/glsl/max_pool2d.glsl | 30 ++++++------------- .../runtime/graph/ops/glsl/max_pool2d.yaml | 1 + 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl index 25749afbf85..9d78b7a6a6e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl @@ -15,24 +15,12 @@ layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx; -layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { - ivec4 in_sizes; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict Params { - ivec2 kernel_size; - ivec2 stride; - ivec2 padding; - ivec2 dilation; -}; +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_idx", "int", STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -54,7 +42,7 @@ void main() { for (int y = start.y; y < end.y; y += dilation.y) { for (int x = start.x; x < end.x; x += dilation.x) { if ((x >= 0 && x < in_sizes.x) && (y >= 0 && y < in_sizes.y)) { - const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0); + const vec4 cur_texel = load_texel(t_in, ivec3(x, y, pos.z)); // Set idx if value is greatest in the pool; else, keep the existing idx. ivec4 cur_idx = ivec4(x + int(in_sizes.x) * y); @@ -66,6 +54,6 @@ void main() { } } - imageStore(image_out, pos, out_texel); - imageStore(image_idx, pos, idx_texel); + imageStore(t_out, pos, out_texel); + imageStore(t_idx, pos, idx_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml index 3be032bf85d..d8e3aa599f5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml @@ -8,6 +8,7 @@ max_pool2d: parameter_names_with_default_values: NDIM: 3 DTYPE: float + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half