From ab1237c9271e3f1fb04514f3879fbe9ac9cdae89 Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Tue, 25 Mar 2025 10:47:26 -0700 Subject: [PATCH] [BE][ET-VK] update permute to use layout gen TSIA @pytorchbot label "topic: not user facing" Differential Revision: [D70435293](https://our.internmc.facebook.com/intern/diff/D70435293/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/glsl/permute.glsl | 8 ++++---- backends/vulkan/runtime/graph/ops/glsl/permute.yaml | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index d4ad736a563..8a8703becd9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -16,8 +16,8 @@ layout(std430) buffer; #include "indexing_utils.h" -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} layout(push_constant) uniform PRECISION restrict Block { ivec4 out_limits; @@ -72,7 +72,7 @@ void main() { fetch_pos[packed_dim] >>= 2; // fetch input texel - VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0)); + VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); outval[j] = inval[in_packed_dim_lane_index]; // go to next position in the input, that is mapped to the packed dim in the output @@ -81,5 +81,5 @@ void main() { pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]); - imageStore(image_out, pos, outval); + imageStore(t_out, pos, outval); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index 64ad58e6e85..f678aeedf6e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -2,6 +2,7 @@ permute: parameter_names_with_default_values: DTYPE: float NDIM: 3 + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half