From 5d722213135474882bb3deaf99cf79cd948fda5c Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 26 Aug 2025 08:33:45 -0700
Subject: [PATCH] [ET-VK] Add high dim support for permute

Pull Request resolved: https://github.com/pytorch/executorch/pull/13642

Title says it all! This adds high-dimensional tensor support to permute by
using the new `BufferMetadata` struct in the permute shader.

ghstack-source-id: 305694102
@exported-using-ghexport

Differential Revision: [D80962719](https://our.internmc.facebook.com/intern/diff/D80962719/)
---
 backends/vulkan/op_registry.py                |   2 +-
 .../runtime/graph/ops/glsl/indexing.glslh     |   9 ++
 .../graph/ops/glsl/permute_buffer.glsl        |  52 ++------
 .../vulkan/runtime/graph/ops/impl/Permute.cpp | 120 +++++++++++++-----
 4 files changed, 114 insertions(+), 69 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index a711f81b738..79448beda65 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -490,7 +490,6 @@ def register_rotary_emb_op():
 @update_features(
     [
         exir_ops.edge.aten.permute.default,
-        exir_ops.edge.aten.permute_copy.default,
     ]
 )
 def register_view_ops():
@@ -506,6 +505,7 @@ def register_view_ops():
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.permute_copy.default,
     ]
 )
 def register_view_ops_with_buffer_meta():
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
index 7155b4616e3..81783422ab4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -98,6 +98,15 @@ uint idx_at(const TensorIndex tidx, const int dim) {
   return tidx.data[div_4(dim)][mod_4(dim)];
 }
 
+void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) {
+  TensorIndex new_tidx = tidx;
+  for (int d = 0; d < DIMLIMIT; ++d) {
+    int src_dim = permute_order[div_4(d)][mod_4(d)];
+    new_tidx.data[div_4(d)][mod_4(d)] = idx_at(tidx, src_dim);
+  }
+  tidx = new_tidx;
+}
+
 //
 // Index Conversions
 //
diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl
index 55b9e3dc9ea..3447ab07552 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl
@@ -18,55 +18,31 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#include "indexing_utils.h"
+#include "indexing.glslh"
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
-${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+${layout_declare_tensor(B, "w", "t_outp", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_inp", DTYPE, "buffer")}
 
-${layout_declare_ubo(B, "ivec4", "in_sizes")}
-${layout_declare_ubo(B, "ivec4", "out_strides")}
-${layout_declare_ubo(B, "int", "out_numel")}
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
-layout(push_constant) uniform restrict Block {
-  ivec4 in_strides;
-  ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j
-};
-
-${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
-${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
-
-const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
+${layout_declare_ubo(B, "ivec4[DIMLIMIT_DIV4]", "permute_order")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-// Convert output tensor index to input tensor index based on permutation
-ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
-  ivec4 in_tidx;
-
-  // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i]
-  in_tidx[permute_dims.x] = out_tidx.x;
-  in_tidx[permute_dims.y] = out_tidx.y;
-  in_tidx[permute_dims.z] = out_tidx.z;
-  in_tidx[permute_dims.w] = out_tidx.w;
-
-  return in_tidx;
-}
-
 void main() {
-  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
-  if (out_bufi >= out_numel) {
+  const uint inp_bufi = gl_GlobalInvocationID.x;
+  if (inp_bufi >= numel(inp)) {
     return;
   }
 
-  // Convert buffer index to tensor index for output
-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
-
-  // Convert output tensor index to input tensor index using permutation
-  const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx);
+  TensorIndex inp_tidx;
+  linear_idx_to_tensor_idx(inp, inp_bufi, inp_tidx);
 
-  // Convert input tensor index back to buffer index
-  const int in_bufi = tidx_to_bufi(in_tidx, in_strides);
+  TensorIndex outp_tidx = inp_tidx;
+  permute(outp_tidx, permute_order);
+  const uint outp_bufi = tensor_idx_to_linear_idx(outp, outp_tidx);
 
   // Copy data from input to output
-  t_out[out_bufi] = t_in[in_bufi];
+  t_outp[outp_bufi] = t_inp[inp_bufi];
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
index 6e6a6fa3bf2..9ac4c963bc3 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -129,37 +129,22 @@ void add_permute_node(
   std::vector<PushConstantDataInfo> push_constants;
   vkapi::SpecVarList spec_vars;
 
-  if (graph.is_buffer_storage(out)) {
-    param_buffers.append(graph.sizes_ubo(in));
-    param_buffers.append(graph.strides_ubo(out));
-    param_buffers.append(graph.numel_ubo(out));
-
-    // Buffer storage - use permute_buffer shader
-    push_constants = {
-        graph.strides_pc_of(in),
-        PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)),
-    };
-
-    spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
-  } else {
-    // Texture storage - use permute_texture shader
-    const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
-    const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
-
-    const int32_t packed_dim = graph.packed_dim_of(in);
-    ivec2 channel_info = {out_channels, in_channels};
-    if (packed_dim == WHCN::kChannelsDim) {
-      channel_info[0] = utils::align_up_4(channel_info[0]);
-      channel_info[1] = utils::align_up_4(channel_info[1]);
-    }
+  const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
+
+  const int32_t packed_dim = graph.packed_dim_of(in);
+  ivec2 channel_info = {out_channels, in_channels};
+  if (packed_dim == WHCN::kChannelsDim) {
+    channel_info[0] = utils::align_up_4(channel_info[0]);
+    channel_info[1] = utils::align_up_4(channel_info[1]);
+  }
 
-    push_constants = {
-        graph.sizes_pc_of(out),
-        graph.sizes_pc_of(in),
-        PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};
+  push_constants = {
+      graph.sizes_pc_of(out),
+      graph.sizes_pc_of(in),
+      PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};
 
-    spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
-  }
+  spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
@@ -179,8 +164,83 @@ void add_permute_node(
       resize_permute_node));
 }
 
+struct WHCNPermuteDims {
+  int32_t whcn_permute_dims[api::kTensorDimLimit];
+
+  void initialize(const std::vector<int64_t>& permute_dims) {
+    const int32_t permute_ndim = permute_dims.size();
+    for (int32_t whcn_i = 0; whcn_i < permute_ndim; whcn_i++) {
+      const int32_t nchw_i = permute_ndim - 1 - whcn_i;
+      int64_t index_val = permute_dims.at(nchw_i);
+      if (index_val < 0) {
+        index_val += permute_ndim;
+      }
+      const int32_t permute_dim_whcn = permute_ndim - 1 - index_val;
+      whcn_permute_dims[whcn_i] = permute_dim_whcn;
+    }
+    for (int32_t whcn_i = permute_ndim; whcn_i < api::kTensorDimLimit;
+         whcn_i++) {
+      whcn_permute_dims[whcn_i] = whcn_i;
+    }
+  }
+};
+
+void add_permute_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);
+
+  WHCNPermuteDims whcn_permute_dims;
+  // Convert the permute dims to WHCN dimension order, which is the standard in
+  // our compute shaders. The following transformations are applied:
+  // 1. Change dimension index values from NCHW order to WHCN order
+  // 2. Extend the permute array to kTensorDimLimit
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    whcn_permute_dims.initialize(*permute_dims_ptr);
+  }
+
+  std::string kernel_name = "permute";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  vkapi::ParamsBindList param_buffers = {
+      graph.buffer_meta_ubo(out),
+      graph.buffer_meta_ubo(in),
+      graph.create_params_buffer(whcn_permute_dims),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter buffers
+      param_buffers,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {},
+      // Resize Args
+      {permute_dims},
+      // Resizing Logic
+      resize_permute_node));
+}
+
 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_permute_node(graph, args[0], args[1], args[2]);
+  int idx = 0;
+  const ValueRef in = args.at(idx++);
+  const ValueRef permute_dims = args.at(idx++);
+  const ValueRef out = args.at(idx++);
+
+  if (graph.is_buffer_storage(args[2])) {
+    return add_permute_buffer_node(graph, in, permute_dims, out);
+  }
+  return add_permute_node(graph, in, permute_dims, out);
 }
 
 REGISTER_OPERATORS {
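
---

Reviewer note (not part of the patch): the NCHW-to-WHCN conversion implemented by `WHCNPermuteDims::initialize` above can be hard to follow in isolation, so here is a minimal standalone sketch that mirrors its arithmetic and prints the converted array for one example. The constant `kDimLimit = 8` and the helper name `to_whcn_permute` are assumptions made for illustration only; they stand in for `api::kTensorDimLimit` and the struct method in the patch.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-in for api::kTensorDimLimit (assumed to be 8 here).
constexpr int32_t kDimLimit = 8;

// Mirrors WHCNPermuteDims::initialize from the patch: convert an NCHW-order
// permutation into WHCN order and pad it out to kDimLimit with identity dims.
std::vector<int32_t> to_whcn_permute(const std::vector<int64_t>& permute_dims) {
  const int32_t ndim = static_cast<int32_t>(permute_dims.size());
  std::vector<int32_t> whcn(kDimLimit);
  for (int32_t whcn_i = 0; whcn_i < ndim; ++whcn_i) {
    // Walk the NCHW permutation from the back, since WHCN reverses dim order.
    const int32_t nchw_i = ndim - 1 - whcn_i;
    int64_t index_val = permute_dims.at(nchw_i);
    if (index_val < 0) {
      index_val += ndim;  // normalize negative dims
    }
    // Re-express the source dim index in WHCN order as well.
    whcn[whcn_i] = ndim - 1 - static_cast<int32_t>(index_val);
  }
  // Dims beyond the tensor's rank map to themselves.
  for (int32_t whcn_i = ndim; whcn_i < kDimLimit; ++whcn_i) {
    whcn[whcn_i] = whcn_i;
  }
  return whcn;
}

int main() {
  // NCHW -> NHWC permutation of a 4D tensor: output dim i reads input dim
  // permute_dims[i].
  const std::vector<int64_t> permute_dims = {0, 2, 3, 1};
  const std::vector<int32_t> whcn = to_whcn_permute(permute_dims);
  for (int32_t v : whcn) {
    std::printf("%d ", v);  // expected: 2 0 1 3 4 5 6 7
  }
  std::printf("\n");
  return 0;
}
```

With the padded WHCN array uploaded as the `permute_order` UBO, the shader can remap a `TensorIndex` of any rank up to the dim limit, which is what enables the high-dimension support this patch adds.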