From de587244bc39f716fd4c8db335dfcdaa6d86c36e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 24 Jun 2025 22:00:58 -0700 Subject: [PATCH 1/4] [ET-VK] New Implementation of `permute' operator Pull Request resolved: https://github.com/pytorch/executorch/pull/11825 ## Changes * Introduce `permute_buffer.glsl` and `permute_texture.glsl` compute shader templates to implement the permute operator ## Motivation The existing implementation of permute produced incorrect outputs for width packed textures. Furthermore, there was no buffer implementation for the permute operator. My goal with this diff is to introduce a more flexible implementation of permute that could work for any tensor representation. ## Performance impact None expected. ghstack-source-id: 292530157 @exported-using-ghexport Differential Revision: [D76483755](https://our.internmc.facebook.com/intern/diff/D76483755/) --- .../runtime/graph/ops/glsl/permute.glsl | 89 --------------- .../graph/ops/glsl/permute_buffer.glsl | 72 ++++++++++++ .../{permute.yaml => permute_buffer.yaml} | 6 +- .../graph/ops/glsl/permute_texture.glsl | 103 ++++++++++++++++++ .../graph/ops/glsl/permute_texture.yaml | 10 ++ .../vulkan/runtime/graph/ops/impl/Permute.cpp | 85 +++++++++------ .../runtime/graph/ops/impl/Unsqueeze.cpp | 3 + backends/vulkan/test/op_tests/cases.py | 9 ++ 8 files changed, 253 insertions(+), 124 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/permute.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl rename backends/vulkan/runtime/graph/ops/glsl/{permute.yaml => permute_buffer.yaml} (73%) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl deleted file mode 100644 index 716c42e8ede..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform PRECISION restrict Block { - ivec4 out_limits; - ivec4 in_sizes; - // output dims - ivec4 out_ndims; - // x = output channels aligned to 4, y = input channels aligned to 4 - ivec2 channel_info; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(constant_id = 3) const int packed_dim = C_DIM; - -#extension GL_EXT_control_flow_attributes : require - -void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits.xyz))) { - return; - } - - VEC4_T outval = VEC4_T(0.0); - - // scale up output position's packed dim - pos[packed_dim] <<= 2; - - // index of packed dim in bchw format - const int in_packed_dim_bchw_index = 3 - packed_dim; - - // determine input position based on output position and permute map - // out_ndims is in BCHW format - ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w - in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x); - in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x); - in_bchw_pos[out_ndims[2]] = pos.y; - in_bchw_pos[out_ndims[3]] = pos.x; - - const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]; - - [[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) { - // terminate the loop if trying to access input texture out of bounds - if (bchw_index >= in_packed_dim_size) { - break; - } - // go to position in the input, that is mapped to the packed dim in the output - in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index; - - ivec3 fetch_pos; - - fetch_pos.xy = in_bchw_pos.wz; - // calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively - fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y; - - // input tensor's packed dim lane corresponding to output tensor's pos - const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3; - - // scale down input tensor's packed dim pos to perform fetch - fetch_pos[packed_dim] >>= 2; - - // fetch input texel - VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); - outval[j] = inval[in_packed_dim_lane_index]; - } - - pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]); - - imageStore(t_out, pos, outval); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl new file mode 100644 index 00000000000..55b9e3dc9ea --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "int", "out_numel")} + +layout(push_constant) uniform restrict Block { + ivec4 in_strides; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +void main() { + const int out_bufi = ivec3(gl_GlobalInvocationID).x; + if (out_bufi >= out_numel) { + return; + } + + // Convert buffer index to tensor index for output + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); + + // Convert output tensor index to input tensor index using permutation + const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + + // Convert input tensor index back to buffer index + const int in_bufi = tidx_to_bufi(in_tidx, in_strides); + + // Copy data from input to output + t_out[out_bufi] = t_in[in_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml similarity index 73% rename from backends/vulkan/runtime/graph/ops/glsl/permute.yaml rename to backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml index a90ddcb41ce..81675ae8917 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml @@ -1,12 +1,10 @@ -permute: +permute_buffer: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - VALUE: int32 shader_variants: - - NAME: permute + - NAME: permute_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl new file mode 100644 index 00000000000..274077f4181 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); +const lowp int out_packed_dim = unhash_packed_dim(out_layout); + +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); +const lowp int in_packed_dim = unhash_packed_dim(in_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +// Check if we can use the fast path where texels from the input tensor can be +// copied directly into the output tensor. This occurs when the packed dimension +// is preserved in the permutation, i.e. reading a texel from the output tensor +// produces 4 texels along the same dimension as reading a texel from the input +// tensor. +bool can_use_fast_path() { + // Fast path is possible when the packed dimension is preserved in the permutation + // This means permute_dims[out_packed_dim] == in_packed_dim + return permute_dims[out_packed_dim] == in_packed_dim; +} + +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + + if (any(greaterThanEqual(out_tidx, out_sizes))) { + return; + } + + if (can_use_fast_path()) { + // Fast path: packed dimension is preserved, so we can copy texels directly + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + + write_texel_lpos(t_out, lpos, in_texel, out_axis_map); + } + else { + // Slow path: packed dimension is not preserved, so each element of the + // output texel may be "sourced" from a different texel in the input tensor. + // Therefore each output texel element is processed individually. + VEC4_T out_texel = VEC4_T(0); + + for (int texel_i = 0; texel_i < 4; ++texel_i) { + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + int element_idx = in_tidx[in_packed_dim] % 4; + + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + T selected_value = T(in_texel[element_idx]); + + out_texel[texel_i] = selected_value; + + out_tidx[out_packed_dim]++; + } + + write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml new file mode 100644 index 00000000000..f68b8dcdd3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml @@ -0,0 +1,10 @@ +permute_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: permute_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index fba3f03467b..6e6a6fa3bf2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -10,6 +10,7 @@ #include +#include #include #include #include @@ -100,54 +101,76 @@ void add_permute_node( const ValueRef out) { check_args(graph, in, permute_dims, out); - ivec4 out_dims{0, 1, 2, 3}; - - // Special cases of squeeze/unsqueeze. Because the input dim size can be - // different with output dim size. So pick graph.dim_of(in) if squeeze, and - // graph.dim_of(out) if unsqueeze to create parameter for permute. - const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out)); - std::vector seen(out_ndim); + // Convert the permute dims to WHCN dimension order, which is the standard in + // our compute shaders. The following transformations are applied. + // 1. Change dimension index values from NCHW order valueto WHCN order value + // 2. Reverse the order of the permute array from NCHW order to WHCN order + ivec4 whcn_permute_dims{0, 1, 2, 3}; { IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); - for (int i = 0; i < out_ndim; i++) { - int64_t permute_dim = permute_dims_ptr->at(i); - VK_CHECK_COND( - !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); - seen[permute_dim] = true; + const int32_t permute_ndim = + utils::safe_downcast(permute_dims_ptr->size()); + + for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0; + nchw_i--, whcn_i++) { + const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i); + const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw; - out_dims[(4u - out_ndim) + i] = - utils::safe_downcast(permute_dim + (4 - out_ndim)); + whcn_permute_dims[whcn_i] = permute_dim_whcn; } } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const int32_t out_channels = dim_at(graph.sizes_of(out)); - const int32_t in_channels = dim_at(graph.sizes_of(in)); + vkapi::ParamsBindList param_buffers; + std::vector push_constants; + vkapi::SpecVarList spec_vars; - const int32_t packed_dim = graph.packed_dim_of(in); - ivec2 channel_info = {out_channels, in_channels}; - if (packed_dim == WHCN::kChannelsDim) { - channel_info[0] = utils::align_up_4(channel_info[0]); - channel_info[1] = utils::align_up_4(channel_info[1]); - } + if (graph.is_buffer_storage(out)) { + param_buffers.append(graph.sizes_ubo(in)); + param_buffers.append(graph.strides_ubo(out)); + param_buffers.append(graph.numel_ubo(out)); + + // Buffer storage - use permute_buffer shader + push_constants = { + graph.strides_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)), + }; + + spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}; + } else { + // Texture storage - use permute_texture shader + const int32_t out_channels = dim_at(graph.sizes_of(out)); + const int32_t in_channels = dim_at(graph.sizes_of(in)); + + const int32_t packed_dim = graph.packed_dim_of(in); + ivec2 channel_info = {out_channels, in_channels}; + if (packed_dim == WHCN::kChannelsDim) { + channel_info[0] = utils::align_up_4(channel_info[0]); + channel_info[1] = utils::align_up_4(channel_info[1]); + } + + push_constants = { + graph.sizes_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))}; - const vkapi::SpecVarList spec_vars = {packed_dim}; + spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}; + } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {in, vkapi::kRead}}, - {}, + // Parameter buffers + param_buffers, // Push Constants - {{graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo(&out_dims, sizeof(out_dims)), - PushConstantDataInfo(&channel_info, sizeof(channel_info))}}, + push_constants, // Specialization Constants spec_vars, // Resize Args diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 306a79fb8b8..c4de5d88f30 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -26,6 +26,9 @@ void add_unsqueeze_node( in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); int64_t dim = graph.extract_scalar(dim_ref); + if (dim < 0) { + dim += out_dim; + } std::vector permute_dims(out_dim); for (int i = 1; i <= dim; i++) { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 813807445f0..92f73268ebf 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -752,6 +752,13 @@ def get_permute_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + ] return test_suite @@ -990,9 +997,11 @@ def get_unsqueeze_inputs(): ((9, 9), 2), ((9,), 0), ((9,), 1), + ((1, 10), -1), ] ) test_suite.layouts = [ + "utils::kWidthPacked", "utils::kChannelsPacked", ] test_suite.data_gen = "make_seq_tensor" From 405c96ad4bc816d7e2bedb2a86f024f0430bc202 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 24 Jun 2025 22:01:05 -0700 Subject: [PATCH 2/4] [ET-VK][ez][testing] Improvement to operator test codegen system Pull Request resolved: https://github.com/pytorch/executorch/pull/11826 ## Changes * Allow test cases to specify storage types / memory layouts for individual args * Allow test cases to specify different data generation functions for individual args ## Motivation > Allow test cases to specify storage types / memory layouts for individual args Make it possible to test args that require specific storage types for certain input/output tensors. > Allow test cases to specify different data generation functions for individual args Useful for debugging operators during development. ghstack-source-id: 292530160 @exported-using-ghexport Differential Revision: [D77038777](https://our.internmc.facebook.com/intern/diff/D77038777/) --- .../test/op_tests/utils/gen_computegraph.py | 71 ++++++++++++++++++- .../op_tests/utils/gen_correctness_base.py | 26 ++++++- .../test/op_tests/utils/gen_correctness_vk.py | 1 - .../vulkan/test/op_tests/utils/test_suite.py | 5 +- 4 files changed, 96 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index b24879f660a..38a3ee93627 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -58,6 +58,8 @@ class ValueRef: src_cpp_type: str is_in: bool = False is_out: bool = False + fixed_storage_type: Optional[str] = None + fixed_memory_layout: Optional[str] = None requires_prepack: bool = False supports_prepack: bool = False # When is_dynamic_size is true, the underlying object size is not known @@ -137,20 +139,43 @@ def __init__( if arg.name in self.suite_def.prepacked_args: supports_prepack = True + fixed_storage_type = None + if arg.name in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types[arg.name] + + fixed_memory_layout = None + if arg.name in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts[arg.name] + self.refs[arg.name] = ValueRef( name=f"{arg.name}_ref", src_cpp_name=arg.name, src_cpp_type=cpp_type, is_in=(cpp_type in InableCppType), + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, requires_prepack=requires_prepack, supports_prepack=supports_prepack, ) ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type() self.out = ATenArg(name="out", cpp_type=ret_type, default=None) + + fixed_storage_type = None + if "out" in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types["out"] + fixed_memory_layout = None + if "out" in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts["out"] + if ret_type == AT_TENSOR: self.refs["out"] = ValueRef( - name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True + name="out_ref", + src_cpp_name="out", + src_cpp_type=ret_type, + is_out=True, + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, ) elif ret_type == TWO_TENSOR_TUPLE: self.refs["out"] = [ @@ -159,12 +184,24 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -180,18 +217,36 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_third", src_cpp_name="std::get<2>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[2] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[2] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -302,7 +357,12 @@ def create_value_for( # noqa: C901 ret_str += f"{self.graph}{self.dot}" ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += f"{ref.src_cpp_name}->sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()" + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += "));\n" elif prepack: ret_str += f"{self.graph}{self.dot}" ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), " @@ -385,7 +445,12 @@ def create_value_for( # noqa: C901 elif ref.src_cpp_type == AT_TENSOR and not prepack: ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += f"{ref.src_cpp_name}.sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())" + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += ");\n" elif ref.src_cpp_type == AT_TENSOR and prepack: ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), " ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), " diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 5be4ddba6bf..250edf333bc 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -140,7 +140,13 @@ def call_data_gen_fn(self, arg: Argument, data: Any, terminate: bool = True) -> else self.suite_def.arg_data_range[arg.name] ) - ret_str = f"{self.suite_def.data_gen}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" + data_gen_fn = ( + self.suite_def.data_gen + if arg.name not in self.suite_def.arg_data_gen_fn + else self.suite_def.arg_data_gen_fn[arg.name] + ) + + ret_str = f"{data_gen_fn}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" if terminate: ret_str += ";" @@ -288,13 +294,29 @@ def generate_suite_cpp(self) -> str: if (dtype == at::kBool) return at::rand(sizes, at::device(at::kCPU)) > 0.5; - + if (high == 1.0 && low == 0.0) return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low; }} +at::Tensor make_zeros_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::zeros(sizes, at::device(at::kCPU).dtype(dtype)); +}} + +at::Tensor make_ones_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::ones(sizes, at::device(at::kCPU).dtype(dtype)); +}} + at::Tensor make_seq_tensor( std::vector sizes, at::ScalarType dtype = at::kFloat, diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index e7cf5ba92a5..c368c23c539 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -29,7 +29,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple void SetUp() override {{ GraphConfig config; - config.expect_dynamic_shapes = true; utils::StorageType default_storage_type; utils::GPUMemoryLayout default_memory_layout; std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); diff --git a/backends/vulkan/test/op_tests/utils/test_suite.py b/backends/vulkan/test/op_tests/utils/test_suite.py index 72ba457b5af..427864b0d5d 100644 --- a/backends/vulkan/test/op_tests/utils/test_suite.py +++ b/backends/vulkan/test/op_tests/utils/test_suite.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional ################################### ## Generic Test Suite definition ## @@ -23,6 +23,7 @@ def __init__(self, input_cases: List[Any]): self.data_range = (0, 1) self.arg_dtype = {} + self.arg_data_gen_fn: Dict[str, str] = {} self.arg_data_range = {} self.atol: str = "1e-5" @@ -48,3 +49,5 @@ def __init__(self, input_cases: List[Any]): self.layouts: List[str] = ["utils::kChannelsPacked"] self.data_gen: str = "make_rand_tensor" self.force_io: bool = True + self.arg_storage_types: Dict[str, str] = {} + self.arg_memory_layouts: Dict[str, str] = {} From b1fed6b18a4a97d37626d8b4071a69c41dd5933b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 24 Jun 2025 22:01:11 -0700 Subject: [PATCH 3/4] [ET-VK] Implement `native_group_norm Pull Request resolved: https://github.com/pytorch/executorch/pull/11827 ## Changes * Add implementation for the group norm operator. The operator is implemented via a 2 stage implementation. First, a reduction operator is executed to calculate the mean and standard deviation of each channel group. Then, the normalization is applied in an elementwise fashion. ghstack-source-id: 292530158 @exported-using-ghexport Differential Revision: [D77038778](https://our.internmc.facebook.com/intern/diff/D77038778/) --- .../vulkan/runtime/graph/ComputeGraph.cpp | 32 +++ backends/vulkan/runtime/graph/ComputeGraph.h | 30 ++- .../ops/glsl/group_norm_reduce_texture.glsl | 189 +++++++++++++++ .../ops/glsl/group_norm_reduce_texture.yaml | 15 ++ .../graph/ops/glsl/group_norm_texture.glsl | 129 ++++++++++ .../graph/ops/glsl/group_norm_texture.yaml | 15 ++ .../runtime/graph/ops/impl/GroupNorm.cpp | 225 ++++++++++++++++++ backends/vulkan/test/op_tests/cases.py | 39 +++ 8 files changed, 672 insertions(+), 2 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index bed379c0c35..b63f89e299d 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { VK_THROW("Could not get dtype of value with type ", val.type()); } +bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (!is_buffer_storage(idx)) { + return false; + } + return is_contiguous(idx); +} + +bool ComputeGraph::is_standard_channels_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 2; +} + +bool ComputeGraph::is_standard_width_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 0; +} + ValueRef ComputeGraph::add_tensor( const std::vector& sizes, const vkapi::ScalarType dtype, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 21d80d5843f..eac632e6d35 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -231,7 +231,7 @@ class ComputeGraph final { inline ptr_type get_##short_name(const ValueRef idx) { \ return ptr_type(this, idx); \ } \ - inline bool val_is_##short_name(const ValueRef idx) { \ + inline bool val_is_##short_name(const ValueRef idx) const { \ return values_.at(idx).is##type_name(); \ } @@ -314,6 +314,32 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().has_buffer_storage(); } + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has buffer storage + * 3. The buffer backed tensor at `idx` has a contiguous memory layout + */ + bool is_contiguous_buffer_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. The texture backed tensor at `idx` is channels packed + */ + bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. The texture backed tensor at `idx` is width packed + */ + bool is_standard_width_packed_texture_tensor(const ValueRef idx) const; + inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base) const { return values_.at(maybe_view) @@ -354,7 +380,7 @@ class ComputeGraph final { return values_.at(idx).toTensor().numel_ubo(); } - inline bool has_standard_axis_map(const ValueRef idx) { + inline bool has_standard_axis_map(const ValueRef idx) const { return values_.at(idx).toTensor().has_standard_axis_map(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl new file mode 100644 index 00000000000..70fdf2bae17 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +#define BUF_T ${buffer_scalar_type(DTYPE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "ivec4", "mean_strides")} +${layout_declare_ubo(B, "int", "mean_numel")} +${layout_declare_ubo(B, "ivec3", "in_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")} +const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout); + +#define LOCAL_WORK_GROUP_SIZE 64 +shared float shared_sum[LOCAL_WORK_GROUP_SIZE]; +shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE]; + +/* + * Computes the mean and standard deviation of one group of channels of the + * input tensor for the group normalization operator. + * + * Given a tensor of shape [W, H, C, N] the mean and standard deviation tensors + * will have a shape of [G, N] where G = C / group. + * + * The input tensor is assumed to be a channels-packed texture tensor with the + * standard axis mapping. The output tensors are assumed to be contiguous buffer + * tensors. + * + * Algorithm: + * 1. Each shader invocation corresponds to one group in one batch + * 2. The local work group cooperatively reduces over all spatial locations (H×W) + * and all channels within the group (C/group channels) + * 3. Uses shared memory for efficient parallel reduction + * 4. Main thread (local ID 0) writes the final mean and rstd to buffer + * + * Global work group size: {N, 1, 1} + * N is the number of elements in the tensor buffer; each thread computes one + * output element. + * + * Local work group size: {1, float, 1} + * float should be a power of 2, recommended 64 or 128 threads. This allows + * efficient tree-based reduction in shared memory. Each local group will + * cooperate to compute the output element. + * + * Each shader invocation will compute the mean and standard deviation for one + * channel group in the input, and write out the corresponding result. + */ +void group_norm_reduce_C_packed() { + const int global_idx = int(gl_GlobalInvocationID.x); + const int local_idx = int(gl_LocalInvocationID.y); + + // Calculate group dimensions + const int D = in_sizes.z / group; // channels per group + const int HxW = in_sizes.y * in_sizes.x; // spatial size + const int group_size = D * HxW; // total elements per group + + // Convert global index to (group_idx, batch_idx) + const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order); + + // Initialize local sums + float local_sum = 0.0; + float local_sum_sq = 0.0; + int local_count = 0; + + // Calculate the range of channels for this group + const int group_start_channel = mean_tidx.x * D; + const int group_end_channel = group_start_channel + D; + + // Calculate the range of texels that contain channels from this group + const int start_texel_idx = group_start_channel / 4; + const int end_texel_idx = divup4(group_end_channel); + const int texels_in_group = end_texel_idx - start_texel_idx; + + // Total texels to process across all spatial locations + const int total_texels = texels_in_group * HxW; + + // Each thread processes a subset of texels + const int texels_per_thread = (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE; + const int start_texel = local_idx * texels_per_thread; + const int end_texel = min(start_texel + texels_per_thread, total_texels); + + // Process assigned texels + for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) { + // Convert texel index to spatial and channel coordinates + const int spatial_idx = texel_idx / texels_in_group; + const int texel_in_group = texel_idx % texels_in_group; + + // Convert to spatial coordinates + const int w = spatial_idx % in_sizes.x; + const int h = spatial_idx / in_sizes.x; + + // Calculate the global texel index + const int global_texel_idx = start_texel_idx + texel_in_group; + + // Convert to texture position using default axis mapping + ivec3 tex_pos = ivec3(w, h, global_texel_idx); + + // Adjust for batch dimension if needed + if (in_sizes.w > 1) { + // default axis mapping means channels is the batch concat dim + tex_pos.z += mean_tidx.y * divup4(in_sizes.z); + } + + // Check bounds and load texel + if (all(lessThan(tex_pos, in_limits))) { + const vec4 texel_val = load_texel(t_in, tex_pos); + + // Process all components of the texel that belong to this group + const int texel_start_channel = global_texel_idx * 4; + for (int comp = 0; comp < 4; comp++) { + const int current_channel = texel_start_channel + comp; + + // Check if this component belongs to the current group + if (current_channel >= group_start_channel && current_channel < group_end_channel) { + const float val = texel_val[comp]; + local_sum += val; + local_sum_sq += val * val; + local_count++; + } + } + } + } + + // Store local results in shared memory + shared_sum[local_idx] = local_sum; + shared_sum_sq[local_idx] = local_sum_sq; + + // Synchronize threads + memoryBarrierShared(); + barrier(); + + // Perform tree-based reduction in shared memory + for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) { + if (local_idx < stride) { + shared_sum[local_idx] += shared_sum[local_idx + stride]; + shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride]; + } + memoryBarrierShared(); + barrier(); + } + + // Main thread writes the result + if (local_idx == 0 && global_idx < mean_numel) { + const float total_sum = shared_sum[0]; + const float total_sum_sq = shared_sum_sq[0]; + const float count = float(group_size); + + // Calculate mean and reciprocal standard deviation + const float mean_val = total_sum / count; + const float variance = (total_sum_sq / count) - (mean_val * mean_val); + const float rstd_val = 1.0 / sqrt(variance + epsilon); + + // Write to buffer-backed tensors + t_mean[global_idx] = BUF_T(mean_val); + t_rstd[global_idx] = BUF_T(rstd_val); + } +} + +void main() { + group_norm_reduce_C_packed(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml new file mode 100644 index 00000000000..00c357a1d6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +group_norm_reduce_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_reduce_texture diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl new file mode 100644 index 00000000000..8440481963a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl @@ -0,0 +1,129 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_rstd", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec3", "weight_limits")} +${layout_declare_ubo(B, "ivec4", "mean_strides")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Applies group normalization to t_in, and write the results to t_out. The mean + * and rstd of the input tensor are precomputed and passed in as t_mean and + * t_rstd. + * + * Given an input tensor t_in of shape [N, C, H, W], the mean and rstd will have + * shape [N, C / ngroup], and the output will have the same shape as t_in. The + * weight and bias tensor will have a shape of [C]. + * + * In this implementation, the input and output tensors are assumed to be + * channels packed textures with standard axis mapping. + * + * The weight and bias tensors are assumed to be width packed textures with + * standard axis mapping. + * + * The mean and rstd tensors are assumed to be contiguous buffer-backed tensors. + */ +void apply_group_norm() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Check bounds + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + // Convert texture position to tensor coordinates using default axis mapping + // and channels packing + ivec4 out_tidx = ivec4(pos.x, pos.y, mul4(pos.z), 0); + + // Handle batch dimension if batches > 1 + if (out_sizes.w > 1) { + const int C_aligned = alignup4(out_sizes.z); + // default axis mapping means channels is the batch concatenation dim + const int batch_idx = out_tidx.z / C_aligned; + out_tidx.w = batch_idx; + out_tidx.z = out_tidx.z % C_aligned; + } + + // Load input texel (contains 4 consecutive channels) + const vec4 input_texel = load_texel(t_in, pos); + + // Load weight and bias texels, which are width-packed; each element along the + // width dim corresponds to a channel in the input tensor. + const ivec3 weight_pos = ivec3(out_tidx.z / 4, 0, 0); + const vec4 weight_texel = load_texel(t_weight, weight_pos); + const vec4 bias_texel = load_texel(t_bias, weight_pos); + + // Calculate which channels this texel represents + // For channels-packed layout: texel at position z contains channels [z, z+1, z+2, z+3] + const int base_channel = out_tidx.z; + + // Calculate buffer indices for mean/rstd lookup + // Mean/rstd tensors have shape [G, N] where G = C/group + const int batch_idx = out_tidx.w; + const int channels_per_group = out_sizes.z / group; + + vec4 bias; + // Process each element of the output texel individually, since each element + // may belong to a different channel group + for (int i = 0; i < 4; ++i) { + const int channel_idx = base_channel + i; + // Handle case where padding channels are added + if (channel_idx >= out_sizes.z) { + bias[i] = input_texel[i]; + continue; + } + + // Calculate group index for this channel + const int group_idx = channel_idx / channels_per_group; + + // Create tensor index for mean/rstd buffer access + const ivec4 mean_tidx = ivec4(group_idx, batch_idx, 0, 0); + const int mean_bufi = tidx_to_bufi(mean_tidx, mean_strides); + + // Load mean and rstd values for this channel + const float mean_val = t_mean[mean_bufi]; + const float rstd_val = t_rstd[mean_bufi]; + + // Apply group normalization with weight and bias: ((input - mean) * rstd) * weight + bias + const float normalized = (input_texel[i] - mean_val) * rstd_val; + bias[i] = normalized * weight_texel[i] + bias_texel[i]; + } + + // Write result to output texture + write_texel(t_out, pos, bias); +} + +void main() { + apply_group_norm(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml new file mode 100644 index 00000000000..b50853be3b0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +group_norm_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp new file mode 100644 index 00000000000..8d2a848b0c4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +std::vector calc_group_norm_mean_sizes( + api::vTensor& self, + const int64_t group) { + const std::vector& input_sizes = self.sizes(); + const int64_t N = input_sizes.at(0); + return {N, group}; +} + +utils::uvec3 group_norm_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + return {1, 64, 1}; +} + +void resize_group_norm_texture_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + // Extract tensor references from args + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const ValueRef mean = args.at(1).refs.at(3); + const ValueRef rstd = args.at(1).refs.at(4); + + // Extract group from resize args + const int64_t group_val = graph->extract_scalar(resize_args.at(0)); + + // Get input tensor sizes using ComputeGraph APIs + const std::vector in_sizes = graph->sizes_of(in); + + // Output tensor should have the same size as input + graph->virtual_resize(out, in_sizes); + + // Mean and rstd tensors should have size {num_batches, num_groups} + const int64_t N = in_sizes.at(0); // batch dimension + const std::vector mean_rstd_sizes = {N, group_val}; + + // Resize mean and rstd tensors + graph->virtual_resize(mean, mean_rstd_sizes); + graph->virtual_resize(rstd, mean_rstd_sizes); +} + +void add_native_group_norm_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias_data, + const ValueRef N, + const ValueRef C, + const ValueRef HxW, + const ValueRef group, + const ValueRef eps, + const ValueRef out, + const ValueRef mean, + const ValueRef rstd) { + (void)N; + (void)C; + (void)HxW; + + const ValueRef arg_weight = prepack_standard( + graph, + weight_data, + graph.storage_type_of(in), + utils::kWidthPacked, + false); + const ValueRef arg_bias = prepack_standard( + graph, bias_data, graph.storage_type_of(in), utils::kWidthPacked, false); + + const int64_t group_val = graph.extract_scalar(group); + const float epsilon = graph.extract_scalar(eps); + + const std::vector in_sizes = graph.sizes_of(in); + + std::string kernel_name("group_norm_reduce_texture"); + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const struct { + int32_t group; + float epsilon; + } params_uniform = {static_cast(group_val), epsilon}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + group_norm_local_wg_size, + // Inputs and Outputs + {{{mean, rstd}, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + { + graph.strides_ubo(mean), + graph.numel_ubo(mean), + graph.logical_limits_ubo(in), + graph.sizes_ubo(in), + }, + // Push Constants + { + PushConstantDataInfo(¶ms_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(mean), + }, + // Resize Args + {group}, + // Resizing Logic + nullptr)); + + // Compute element-wise normalization, now that mean and rstd have been + // computed. + std::string norm_kernel_name("group_norm_texture"); + norm_kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(norm_kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(norm_kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, + {{in, arg_weight, arg_bias, mean, rstd}, vkapi::kRead}}, + // Shader params buffers + { + graph.logical_limits_ubo(out), + graph.sizes_ubo(out), + graph.logical_limits_ubo(arg_weight), + graph.strides_ubo(mean), + }, + // Push Constants + { + PushConstantDataInfo(¶ms_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(in), + }, + // Resize Args + {group}, + // Resizing Logic + resize_group_norm_texture_node)); +} + +void native_group_norm(ComputeGraph& graph, const std::vector& args) { + // Assign each element of the args vector to const ValueRef variables + const ValueRef in = args.at(0); + const ValueRef weight_data = args.at(1); + const ValueRef bias_data = args.at(2); + const ValueRef N = args.at(3); + const ValueRef C = args.at(4); + const ValueRef HxW = args.at(5); + const ValueRef group = args.at(6); + const ValueRef eps = args.at(7); + const ValueRef out_tuple_ref = args.at(8); + + ValueRef out = kDummyValueRef; + ValueRef mean = kDummyValueRef; + ValueRef rstd = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + out = out_tuple->at(0); + mean = out_tuple->at(1); + rstd = out_tuple->at(2); + } + + VK_CHECK_COND(graph.val_is_tref(weight_data)); + VK_CHECK_COND(graph.val_is_tref(bias_data)); + + // Check expected storage types and memory layouts for tensor variables + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(in)); + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(out)); + + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(mean)); + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(rstd)); + + return add_native_group_norm_node( + graph, + in, + weight_data, + bias_data, + N, + C, + HxW, + group, + eps, + out, + mean, + rstd); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.native_group_norm.default, native_group_norm); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 92f73268ebf..0fd5ef4f002 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -646,6 +646,45 @@ def get_native_layer_norm_inputs(): return test_suite +@register_test_suite("aten.native_group_norm.default") +def get_native_group_norm_inputs(): + test_suite = VkTestSuite( + [ + # (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps) + # General test cases + ((1, 8, 4, 4), (8), (8), 1, 8, 16, 2, 0.001), + ((2, 8, 3, 3), (8), (8), 2, 8, 9, 4, 0.001), + ((1, 12, 2, 2), (12), (12), 1, 12, 4, 3, 0.001), + ((3, 16, 5, 5), (16), (16), 3, 16, 25, 8, 0.001), + ((3, 16, 13, 17), (16), (16), 3, 16, 13 * 17, 4, 0.001), + ((1, 4, 7, 7), (4), (4), 1, 4, 49, 2, 0.001), + ((2, 6, 1, 8), (6), (6), 2, 6, 8, 3, 0.001), + # Single group and prime number sizes + ((3, 7, 13, 11), (7), (7), 3, 7, 13 * 11, 1, 0.001), + # Each channel is it's own group and prime number sizes + ((1, 7, 13, 11), (7), (7), 1, 7, 13 * 11, 7, 0.001), + ] + ) + test_suite.layouts = [ + "utils::kChannelsPacked", + ] + test_suite.storage_types = [ + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + "at::kHalf", + ] + test_suite.arg_storage_types = { + "out": [None, "utils::kBuffer", "utils::kBuffer"], + } + + test_suite.prepacked_args = ["weight", "bias"] + test_suite.requires_prepack = True + + return test_suite + + def get_upsample_inputs(): inputs_list = [ # (input tensor shape, output 2D image size (H, W), output scaling factors) From f3efc6c372fcdc97b52e523eaa68c601a9c93017 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 24 Jun 2025 22:01:18 -0700 Subject: [PATCH 4/4] [ET-VK][ez] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator Pull Request resolved: https://github.com/pytorch/executorch/pull/11828 ## Changes * Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output. ## Motivation Required for the group norm operator. ## Future Work Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) will have the same storage type and memory layout. As more operators are being added, there are more exceptions to this rule. The pass may need an update in the near future to make it possible to specify required storage types and memory layouts on a more granular level. ghstack-source-id: 292530159 @exported-using-ghexport Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/) --- .../vulkan/_passes/tag_memory_meta_pass.py | 34 +++++++--- backends/vulkan/op_registry.py | 26 ++++++++ backends/vulkan/test/test_vulkan_delegate.py | 66 +++++++++++++++++++ backends/vulkan/utils.py | 16 ++++- 4 files changed, 130 insertions(+), 12 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 836a0c6ef7d..0bd8dae0b66 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -91,7 +91,7 @@ def __init__( self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits - def propose_node_storage( + def propose_node_storage( # noqa: C901 self, node: torch.fx.Node, ) -> Optional[VkStorageType]: @@ -138,15 +138,23 @@ def propose_node_storage( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): storage = utils.get_node_storage_type(arg) + # Some operators which return multiple output tensors may specify a + # different storage type for each output. In this case, the storage type + # for the first output is used. + if isinstance(storage, (list, tuple)): + storage = storage[0] if storage is not None and storage in valid_storage_types: return storage # If no storage type has been resolved yet, assume the optimal storage type of # the first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_storage(user) - if optimal_storage is not None: - return optimal_storage + storage = self.propose_node_storage(user) + # See above + if isinstance(storage, (list, tuple)): + storage = storage[0] + if storage is not None: + return storage if self.default_storage in valid_storage_types: return self.default_storage @@ -179,15 +187,23 @@ def propose_node_layout( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): layout = utils.get_node_memory_layout(arg) + # Some operators which return multiple output tensors may specify a + # different memory layout for each output. In this case, the storage + # type for the first output is used. + if isinstance(layout, (list, tuple)): + layout = layout[0] if layout is not None and layout in valid_layouts: return layout - # If no storage type has been resolved yet, assume the optimal storage type of - # the first opinionated user. This search is recursive. + # If no memory layout has been resolved yet, assume the optimal layout of the + # first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_layout(user, storage) - if optimal_storage is not None: - return optimal_storage + layout = self.propose_node_layout(user, storage) + # See above comment + if isinstance(layout, (list, tuple)): + layout = layout[0] + if layout is not None: + return layout # As a last resort, return the default storage type that should be used. if self.default_layout in valid_layouts: diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9333f34430e..0258aceb82b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.aten.native_group_norm.default, + ] +) +def register_native_group_norm(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims={PackedDim.CHANNELS}, + ) + features.handles_own_prepacking = True + + features.optimal_storage = [ + VkStorageType.TEXTURE_3D, + VkStorageType.BUFFER, + VkStorageType.BUFFER, + ] + + features.optimal_layout = [ + VkMemoryLayout.TENSOR_CHANNELS_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + ] + + return features + + # Ported ops that support their own prepacking. @update_features( [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 0096834f3c6..04adf183e55 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1898,3 +1898,69 @@ def forward(self, x): dynamic_shapes=dynamic_shapes, test_inputs=test_inputs, ) + + def test_vulkan_backend_group_norm(self): + class ConvGroupNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Conv2d: 3 input channels -> 16 output channels + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + bias=True, + ) + # GroupNorm: 4 groups for 16 channels (16 % 4 == 0) + self.group_norm = torch.nn.GroupNorm( + num_groups=4, + num_channels=16, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + x = self.conv(x) + x = self.group_norm(x) + return x + + # Create sample inputs: [batch, channels, height, width] + sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),) + + # Test with static shapes first + self.lower_module_and_test_output( + ConvGroupNormModule(), + sample_inputs, + ) + + def test_vulkan_backend_group_norm_different_groups(self): + class GroupNormModule(torch.nn.Module): + def __init__(self, num_groups, num_channels): + super().__init__() + self.group_norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=num_channels, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + return self.group_norm(x) + + # Test different group configurations + test_configs = [ + (2, 8), # 2 groups, 8 channels + (4, 16), # 4 groups, 16 channels + (8, 32), # 8 groups, 32 channels + ] + + for num_groups, num_channels in test_configs: + with self.subTest(num_groups=num_groups, num_channels=num_channels): + sample_inputs = ( + torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32), + ) + + self.lower_module_and_test_output( + GroupNormModule(num_groups, num_channels), + sample_inputs, + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 642f7c5f495..5d57ce1e7be 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value): if isinstance(spec, TensorSpec): setattr(spec, attr, value) elif isinstance(spec, (list, tuple)): - for s in spec: - assert isinstance(s, TensorSpec) - setattr(s, attr, value) + # Special case if value is a list/tuple of the same length as the + # collection of tensors in the node. In this case, treat the value list + # as a list of values to set indivudually for each tensor in the node + if isinstance(value, (list, tuple)) and len(spec) == len(value): + assert len(spec) == len(value) + for s, v in zip(spec, value): + assert isinstance(s, TensorSpec) + setattr(s, attr, v) + # Otherwise, set the attribute to value for all tensors in the list + else: + for s in spec: + assert isinstance(s, TensorSpec) + setattr(s, attr, value) else: raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")