diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 836a0c6ef7d..0bd8dae0b66 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -91,7 +91,7 @@ def __init__( self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits - def propose_node_storage( + def propose_node_storage( # noqa: C901 self, node: torch.fx.Node, ) -> Optional[VkStorageType]: @@ -138,15 +138,23 @@ def propose_node_storage( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): storage = utils.get_node_storage_type(arg) + # Some operators which return multiple output tensors may specify a + # different storage type for each output. In this case, the storage type + # for the first output is used. + if isinstance(storage, (list, tuple)): + storage = storage[0] if storage is not None and storage in valid_storage_types: return storage # If no storage type has been resolved yet, assume the optimal storage type of # the first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_storage(user) - if optimal_storage is not None: - return optimal_storage + storage = self.propose_node_storage(user) + # See above + if isinstance(storage, (list, tuple)): + storage = storage[0] + if storage is not None: + return storage if self.default_storage in valid_storage_types: return self.default_storage @@ -179,15 +187,23 @@ def propose_node_layout( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): layout = utils.get_node_memory_layout(arg) + # Some operators which return multiple output tensors may specify a + # different memory layout for each output. In this case, the storage + # type for the first output is used. + if isinstance(layout, (list, tuple)): + layout = layout[0] if layout is not None and layout in valid_layouts: return layout - # If no storage type has been resolved yet, assume the optimal storage type of - # the first opinionated user. This search is recursive. + # If no memory layout has been resolved yet, assume the optimal layout of the + # first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_layout(user, storage) - if optimal_storage is not None: - return optimal_storage + layout = self.propose_node_layout(user, storage) + # See above comment + if isinstance(layout, (list, tuple)): + layout = layout[0] + if layout is not None: + return layout # As a last resort, return the default storage type that should be used. 
if self.default_layout in valid_layouts: diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9333f34430e..0258aceb82b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.aten.native_group_norm.default, + ] +) +def register_native_group_norm(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims={PackedDim.CHANNELS}, + ) + features.handles_own_prepacking = True + + features.optimal_storage = [ + VkStorageType.TEXTURE_3D, + VkStorageType.BUFFER, + VkStorageType.BUFFER, + ] + + features.optimal_layout = [ + VkMemoryLayout.TENSOR_CHANNELS_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + ] + + return features + + # Ported ops that support their own prepacking. @update_features( [ diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index bed379c0c35..b63f89e299d 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { VK_THROW("Could not get dtype of value with type ", val.type()); } +bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (!is_buffer_storage(idx)) { + return false; + } + return is_contiguous(idx); +} + +bool ComputeGraph::is_standard_channels_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 2; +} + +bool ComputeGraph::is_standard_width_packed_texture_tensor( + const ValueRef idx) const { + if (!val_is_tensor(idx)) { + return false; + } + if (is_buffer_storage(idx)) { + return false; + } + return has_standard_axis_map(idx) && packed_dim_of(idx) == 0; +} + ValueRef ComputeGraph::add_tensor( const std::vector& sizes, const vkapi::ScalarType dtype, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 21d80d5843f..eac632e6d35 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -231,7 +231,7 @@ class ComputeGraph final { inline ptr_type get_##short_name(const ValueRef idx) { \ return ptr_type(this, idx); \ } \ - inline bool val_is_##short_name(const ValueRef idx) { \ + inline bool val_is_##short_name(const ValueRef idx) const { \ return values_.at(idx).is##type_name(); \ } @@ -314,6 +314,32 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().has_buffer_storage(); } + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has buffer storage + * 3. The buffer backed tensor at `idx` has a contiguous memory layout + */ + bool is_contiguous_buffer_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. The texture backed tensor at `idx` is channels packed + */ + bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const; + + /* + * Checks that the following is true: + * 1. 
The value at `idx` is a tensor + * 2. The tensor at `idx` has texture storage + * 3. The texture backed tensor at `idx` has a standard axis mapping + * 4. The texture backed tensor at `idx` is width packed + */ + bool is_standard_width_packed_texture_tensor(const ValueRef idx) const; + inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base) const { return values_.at(maybe_view) @@ -354,7 +380,7 @@ class ComputeGraph final { return values_.at(idx).toTensor().numel_ubo(); } - inline bool has_standard_axis_map(const ValueRef idx) { + inline bool has_standard_axis_map(const ValueRef idx) const { return values_.at(idx).toTensor().has_standard_axis_map(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl new file mode 100644 index 00000000000..70fdf2bae17 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.glsl @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +#define BUF_T ${buffer_scalar_type(DTYPE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "ivec4", "mean_strides")} +${layout_declare_ubo(B, "int", "mean_numel")} +${layout_declare_ubo(B, "ivec3", "in_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")} +const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout); + +#define LOCAL_WORK_GROUP_SIZE 64 +shared float shared_sum[LOCAL_WORK_GROUP_SIZE]; +shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE]; + +/* + * Computes the mean and standard deviation of one group of channels of the + * input tensor for the group normalization operator. + * + * Given a tensor of shape [W, H, C, N] the mean and standard deviation tensors + * will have a shape of [G, N] where G = C / group. + * + * The input tensor is assumed to be a channels-packed texture tensor with the + * standard axis mapping. The output tensors are assumed to be contiguous buffer + * tensors. + * + * Algorithm: + * 1. Each shader invocation corresponds to one group in one batch + * 2. The local work group cooperatively reduces over all spatial locations (H×W) + * and all channels within the group (C/group channels) + * 3. Uses shared memory for efficient parallel reduction + * 4. Main thread (local ID 0) writes the final mean and rstd to buffer + * + * Global work group size: {N, 1, 1} + * N is the number of elements in the tensor buffer; each thread computes one + * output element. + * + * Local work group size: {1, float, 1} + * float should be a power of 2, recommended 64 or 128 threads. This allows + * efficient tree-based reduction in shared memory. Each local group will + * cooperate to compute the output element. 
+ * + * Each shader invocation will compute the mean and standard deviation for one + * channel group in the input, and write out the corresponding result. + */ +void group_norm_reduce_C_packed() { + const int global_idx = int(gl_GlobalInvocationID.x); + const int local_idx = int(gl_LocalInvocationID.y); + + // Calculate group dimensions + const int D = in_sizes.z / group; // channels per group + const int HxW = in_sizes.y * in_sizes.x; // spatial size + const int group_size = D * HxW; // total elements per group + + // Convert global index to (group_idx, batch_idx) + const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order); + + // Initialize local sums + float local_sum = 0.0; + float local_sum_sq = 0.0; + int local_count = 0; + + // Calculate the range of channels for this group + const int group_start_channel = mean_tidx.x * D; + const int group_end_channel = group_start_channel + D; + + // Calculate the range of texels that contain channels from this group + const int start_texel_idx = group_start_channel / 4; + const int end_texel_idx = divup4(group_end_channel); + const int texels_in_group = end_texel_idx - start_texel_idx; + + // Total texels to process across all spatial locations + const int total_texels = texels_in_group * HxW; + + // Each thread processes a subset of texels + const int texels_per_thread = (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE; + const int start_texel = local_idx * texels_per_thread; + const int end_texel = min(start_texel + texels_per_thread, total_texels); + + // Process assigned texels + for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) { + // Convert texel index to spatial and channel coordinates + const int spatial_idx = texel_idx / texels_in_group; + const int texel_in_group = texel_idx % texels_in_group; + + // Convert to spatial coordinates + const int w = spatial_idx % in_sizes.x; + const int h = spatial_idx / in_sizes.x; + + // Calculate the global texel index + const int global_texel_idx = start_texel_idx + texel_in_group; + + // Convert to texture position using default axis mapping + ivec3 tex_pos = ivec3(w, h, global_texel_idx); + + // Adjust for batch dimension if needed + if (in_sizes.w > 1) { + // default axis mapping means channels is the batch concat dim + tex_pos.z += mean_tidx.y * divup4(in_sizes.z); + } + + // Check bounds and load texel + if (all(lessThan(tex_pos, in_limits))) { + const vec4 texel_val = load_texel(t_in, tex_pos); + + // Process all components of the texel that belong to this group + const int texel_start_channel = global_texel_idx * 4; + for (int comp = 0; comp < 4; comp++) { + const int current_channel = texel_start_channel + comp; + + // Check if this component belongs to the current group + if (current_channel >= group_start_channel && current_channel < group_end_channel) { + const float val = texel_val[comp]; + local_sum += val; + local_sum_sq += val * val; + local_count++; + } + } + } + } + + // Store local results in shared memory + shared_sum[local_idx] = local_sum; + shared_sum_sq[local_idx] = local_sum_sq; + + // Synchronize threads + memoryBarrierShared(); + barrier(); + + // Perform tree-based reduction in shared memory + for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) { + if (local_idx < stride) { + shared_sum[local_idx] += shared_sum[local_idx + stride]; + shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride]; + } + memoryBarrierShared(); + barrier(); + } + + // Main thread writes the result + if (local_idx == 0 
&& global_idx < mean_numel) { + const float total_sum = shared_sum[0]; + const float total_sum_sq = shared_sum_sq[0]; + const float count = float(group_size); + + // Calculate mean and reciprocal standard deviation + const float mean_val = total_sum / count; + const float variance = (total_sum_sq / count) - (mean_val * mean_val); + const float rstd_val = 1.0 / sqrt(variance + epsilon); + + // Write to buffer-backed tensors + t_mean[global_idx] = BUF_T(mean_val); + t_rstd[global_idx] = BUF_T(rstd_val); + } +} + +void main() { + group_norm_reduce_C_packed(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml new file mode 100644 index 00000000000..00c357a1d6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_reduce_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +group_norm_reduce_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_reduce_texture diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl new file mode 100644 index 00000000000..8440481963a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.glsl @@ -0,0 +1,129 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_mean", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_rstd", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec3", "weight_limits")} +${layout_declare_ubo(B, "ivec4", "mean_strides")} + +layout(push_constant) uniform PRECISION restrict Block { + int group; + float epsilon; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Applies group normalization to t_in, and write the results to t_out. The mean + * and rstd of the input tensor are precomputed and passed in as t_mean and + * t_rstd. + * + * Given an input tensor t_in of shape [N, C, H, W], the mean and rstd will have + * shape [N, C / ngroup], and the output will have the same shape as t_in. The + * weight and bias tensor will have a shape of [C]. + * + * In this implementation, the input and output tensors are assumed to be + * channels packed textures with standard axis mapping. + * + * The weight and bias tensors are assumed to be width packed textures with + * standard axis mapping. + * + * The mean and rstd tensors are assumed to be contiguous buffer-backed tensors. 
+ */ +void apply_group_norm() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Check bounds + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + // Convert texture position to tensor coordinates using default axis mapping + // and channels packing + ivec4 out_tidx = ivec4(pos.x, pos.y, mul4(pos.z), 0); + + // Handle batch dimension if batches > 1 + if (out_sizes.w > 1) { + const int C_aligned = alignup4(out_sizes.z); + // default axis mapping means channels is the batch concatenation dim + const int batch_idx = out_tidx.z / C_aligned; + out_tidx.w = batch_idx; + out_tidx.z = out_tidx.z % C_aligned; + } + + // Load input texel (contains 4 consecutive channels) + const vec4 input_texel = load_texel(t_in, pos); + + // Load weight and bias texels, which are width-packed; each element along the + // width dim corresponds to a channel in the input tensor. + const ivec3 weight_pos = ivec3(out_tidx.z / 4, 0, 0); + const vec4 weight_texel = load_texel(t_weight, weight_pos); + const vec4 bias_texel = load_texel(t_bias, weight_pos); + + // Calculate which channels this texel represents + // For channels-packed layout: texel at position z contains channels [z, z+1, z+2, z+3] + const int base_channel = out_tidx.z; + + // Calculate buffer indices for mean/rstd lookup + // Mean/rstd tensors have shape [G, N] where G = C/group + const int batch_idx = out_tidx.w; + const int channels_per_group = out_sizes.z / group; + + vec4 bias; + // Process each element of the output texel individually, since each element + // may belong to a different channel group + for (int i = 0; i < 4; ++i) { + const int channel_idx = base_channel + i; + // Handle case where padding channels are added + if (channel_idx >= out_sizes.z) { + bias[i] = input_texel[i]; + continue; + } + + // Calculate group index for this channel + const int group_idx = channel_idx / channels_per_group; + + // Create tensor index for mean/rstd buffer access + const ivec4 mean_tidx = ivec4(group_idx, batch_idx, 0, 0); + const int mean_bufi = tidx_to_bufi(mean_tidx, mean_strides); + + // Load mean and rstd values for this channel + const float mean_val = t_mean[mean_bufi]; + const float rstd_val = t_rstd[mean_bufi]; + + // Apply group normalization with weight and bias: ((input - mean) * rstd) * weight + bias + const float normalized = (input_texel[i] - mean_val) * rstd_val; + bias[i] = normalized * weight_texel[i] + bias_texel[i]; + } + + // Write result to output texture + write_texel(t_out, pos, bias); +} + +void main() { + apply_group_norm(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml new file mode 100644 index 00000000000..b50853be3b0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/group_norm_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +group_norm_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: group_norm_texture diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl deleted file mode 100644 index 716c42e8ede..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform PRECISION restrict Block { - ivec4 out_limits; - ivec4 in_sizes; - // output dims - ivec4 out_ndims; - // x = output channels aligned to 4, y = input channels aligned to 4 - ivec2 channel_info; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(constant_id = 3) const int packed_dim = C_DIM; - -#extension GL_EXT_control_flow_attributes : require - -void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits.xyz))) { - return; - } - - VEC4_T outval = VEC4_T(0.0); - - // scale up output position's packed dim - pos[packed_dim] <<= 2; - - // index of packed dim in bchw format - const int in_packed_dim_bchw_index = 3 - packed_dim; - - // determine input position based on output position and permute map - // out_ndims is in BCHW format - ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w - in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x); - in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x); - in_bchw_pos[out_ndims[2]] = pos.y; - in_bchw_pos[out_ndims[3]] = pos.x; - - const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]; - - [[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) { - // terminate the loop if trying to access input texture out of bounds - if (bchw_index >= in_packed_dim_size) { - break; - } - // go to position in the input, that is mapped to the packed dim in the output - in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index; - - ivec3 fetch_pos; - - fetch_pos.xy = in_bchw_pos.wz; - // calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively - fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y; - - // input tensor's packed dim lane corresponding to output tensor's pos - const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3; - - // scale down input tensor's packed dim pos to perform fetch - fetch_pos[packed_dim] >>= 2; - - // fetch input texel - VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); - outval[j] = inval[in_packed_dim_lane_index]; - } - - pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]); - - imageStore(t_out, pos, outval); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl new file mode 100644 index 00000000000..55b9e3dc9ea --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta 
Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "int", "out_numel")} + +layout(push_constant) uniform restrict Block { + ivec4 in_strides; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +void main() { + const int out_bufi = ivec3(gl_GlobalInvocationID).x; + if (out_bufi >= out_numel) { + return; + } + + // Convert buffer index to tensor index for output + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); + + // Convert output tensor index to input tensor index using permutation + const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + + // Convert input tensor index back to buffer index + const int in_bufi = tidx_to_bufi(in_tidx, in_strides); + + // Copy data from input to output + t_out[out_bufi] = t_in[in_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml similarity index 73% rename from backends/vulkan/runtime/graph/ops/glsl/permute.yaml rename to backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml index a90ddcb41ce..81675ae8917 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml @@ -1,12 +1,10 @@ -permute: +permute_buffer: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - VALUE: int32 shader_variants: - - NAME: permute + - NAME: permute_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl new file mode 100644 index 00000000000..274077f4181 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); +const lowp int out_packed_dim = unhash_packed_dim(out_layout); + +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); +const lowp int in_packed_dim = unhash_packed_dim(in_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +// Check if we can use the fast path where texels from the input tensor can be +// copied directly into the output tensor. This occurs when the packed dimension +// is preserved in the permutation, i.e. reading a texel from the output tensor +// produces 4 texels along the same dimension as reading a texel from the input +// tensor. +bool can_use_fast_path() { + // Fast path is possible when the packed dimension is preserved in the permutation + // This means permute_dims[out_packed_dim] == in_packed_dim + return permute_dims[out_packed_dim] == in_packed_dim; +} + +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + + if (any(greaterThanEqual(out_tidx, out_sizes))) { + return; + } + + if (can_use_fast_path()) { + // Fast path: packed dimension is preserved, so we can copy texels directly + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + + write_texel_lpos(t_out, lpos, in_texel, out_axis_map); + } + else { + // Slow path: packed dimension is not preserved, so each element of the + // output texel may be "sourced" from a different texel in the input tensor. + // Therefore each output texel element is processed individually. 
+ VEC4_T out_texel = VEC4_T(0); + + for (int texel_i = 0; texel_i < 4; ++texel_i) { + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + int element_idx = in_tidx[in_packed_dim] % 4; + + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + T selected_value = T(in_texel[element_idx]); + + out_texel[texel_i] = selected_value; + + out_tidx[out_packed_dim]++; + } + + write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml new file mode 100644 index 00000000000..f68b8dcdd3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml @@ -0,0 +1,10 @@ +permute_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: permute_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp new file mode 100644 index 00000000000..8d2a848b0c4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +std::vector calc_group_norm_mean_sizes( + api::vTensor& self, + const int64_t group) { + const std::vector& input_sizes = self.sizes(); + const int64_t N = input_sizes.at(0); + return {N, group}; +} + +utils::uvec3 group_norm_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + return {1, 64, 1}; +} + +void resize_group_norm_texture_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + // Extract tensor references from args + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const ValueRef mean = args.at(1).refs.at(3); + const ValueRef rstd = args.at(1).refs.at(4); + + // Extract group from resize args + const int64_t group_val = graph->extract_scalar(resize_args.at(0)); + + // Get input tensor sizes using ComputeGraph APIs + const std::vector in_sizes = graph->sizes_of(in); + + // Output tensor should have the same size as input + graph->virtual_resize(out, in_sizes); + + // Mean and rstd tensors should have size {num_batches, num_groups} + const int64_t N = in_sizes.at(0); // batch dimension + const std::vector mean_rstd_sizes = {N, group_val}; + + // Resize mean and rstd tensors + graph->virtual_resize(mean, mean_rstd_sizes); + graph->virtual_resize(rstd, mean_rstd_sizes); +} + +void add_native_group_norm_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias_data, + const ValueRef N, + const ValueRef C, + const ValueRef HxW, + const ValueRef group, + const ValueRef eps, + const ValueRef out, + const ValueRef mean, + const ValueRef rstd) { + (void)N; + (void)C; + (void)HxW; + + const ValueRef arg_weight = prepack_standard( + graph, + weight_data, + graph.storage_type_of(in), + 
utils::kWidthPacked, + false); + const ValueRef arg_bias = prepack_standard( + graph, bias_data, graph.storage_type_of(in), utils::kWidthPacked, false); + + const int64_t group_val = graph.extract_scalar(group); + const float epsilon = graph.extract_scalar(eps); + + const std::vector in_sizes = graph.sizes_of(in); + + std::string kernel_name("group_norm_reduce_texture"); + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const struct { + int32_t group; + float epsilon; + } params_uniform = {static_cast(group_val), epsilon}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + group_norm_local_wg_size, + // Inputs and Outputs + {{{mean, rstd}, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + { + graph.strides_ubo(mean), + graph.numel_ubo(mean), + graph.logical_limits_ubo(in), + graph.sizes_ubo(in), + }, + // Push Constants + { + PushConstantDataInfo(¶ms_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(mean), + }, + // Resize Args + {group}, + // Resizing Logic + nullptr)); + + // Compute element-wise normalization, now that mean and rstd have been + // computed. + std::string norm_kernel_name("group_norm_texture"); + norm_kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(norm_kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(norm_kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, + {{in, arg_weight, arg_bias, mean, rstd}, vkapi::kRead}}, + // Shader params buffers + { + graph.logical_limits_ubo(out), + graph.sizes_ubo(out), + graph.logical_limits_ubo(arg_weight), + graph.strides_ubo(mean), + }, + // Push Constants + { + PushConstantDataInfo(¶ms_uniform, sizeof(params_uniform)), + }, + // Specialization Constants + { + graph.hashed_layout_of(in), + }, + // Resize Args + {group}, + // Resizing Logic + resize_group_norm_texture_node)); +} + +void native_group_norm(ComputeGraph& graph, const std::vector& args) { + // Assign each element of the args vector to const ValueRef variables + const ValueRef in = args.at(0); + const ValueRef weight_data = args.at(1); + const ValueRef bias_data = args.at(2); + const ValueRef N = args.at(3); + const ValueRef C = args.at(4); + const ValueRef HxW = args.at(5); + const ValueRef group = args.at(6); + const ValueRef eps = args.at(7); + const ValueRef out_tuple_ref = args.at(8); + + ValueRef out = kDummyValueRef; + ValueRef mean = kDummyValueRef; + ValueRef rstd = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + out = out_tuple->at(0); + mean = out_tuple->at(1); + rstd = out_tuple->at(2); + } + + VK_CHECK_COND(graph.val_is_tref(weight_data)); + VK_CHECK_COND(graph.val_is_tref(bias_data)); + + // Check expected storage types and memory layouts for tensor variables + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(in)); + VK_CHECK_COND(graph.is_standard_channels_packed_texture_tensor(out)); + + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(mean)); + VK_CHECK_COND(graph.is_contiguous_buffer_tensor(rstd)); + + return add_native_group_norm_node( + graph, + in, + weight_data, + bias_data, + N, + C, + HxW, + group, + eps, + out, + mean, + rstd); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.native_group_norm.default, native_group_norm); +} + +} 
// namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index fba3f03467b..6e6a6fa3bf2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -10,6 +10,7 @@ #include +#include #include #include #include @@ -100,54 +101,76 @@ void add_permute_node( const ValueRef out) { check_args(graph, in, permute_dims, out); - ivec4 out_dims{0, 1, 2, 3}; - - // Special cases of squeeze/unsqueeze. Because the input dim size can be - // different with output dim size. So pick graph.dim_of(in) if squeeze, and - // graph.dim_of(out) if unsqueeze to create parameter for permute. - const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out)); - std::vector seen(out_ndim); + // Convert the permute dims to WHCN dimension order, which is the standard in + // our compute shaders. The following transformations are applied. + // 1. Change dimension index values from NCHW order valueto WHCN order value + // 2. Reverse the order of the permute array from NCHW order to WHCN order + ivec4 whcn_permute_dims{0, 1, 2, 3}; { IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); - for (int i = 0; i < out_ndim; i++) { - int64_t permute_dim = permute_dims_ptr->at(i); - VK_CHECK_COND( - !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); - seen[permute_dim] = true; + const int32_t permute_ndim = + utils::safe_downcast(permute_dims_ptr->size()); + + for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0; + nchw_i--, whcn_i++) { + const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i); + const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw; - out_dims[(4u - out_ndim) + i] = - utils::safe_downcast(permute_dim + (4 - out_ndim)); + whcn_permute_dims[whcn_i] = permute_dim_whcn; } } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const int32_t out_channels = dim_at(graph.sizes_of(out)); - const int32_t in_channels = dim_at(graph.sizes_of(in)); + vkapi::ParamsBindList param_buffers; + std::vector push_constants; + vkapi::SpecVarList spec_vars; - const int32_t packed_dim = graph.packed_dim_of(in); - ivec2 channel_info = {out_channels, in_channels}; - if (packed_dim == WHCN::kChannelsDim) { - channel_info[0] = utils::align_up_4(channel_info[0]); - channel_info[1] = utils::align_up_4(channel_info[1]); - } + if (graph.is_buffer_storage(out)) { + param_buffers.append(graph.sizes_ubo(in)); + param_buffers.append(graph.strides_ubo(out)); + param_buffers.append(graph.numel_ubo(out)); + + // Buffer storage - use permute_buffer shader + push_constants = { + graph.strides_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)), + }; + + spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}; + } else { + // Texture storage - use permute_texture shader + const int32_t out_channels = dim_at(graph.sizes_of(out)); + const int32_t in_channels = dim_at(graph.sizes_of(in)); + + const int32_t packed_dim = graph.packed_dim_of(in); + ivec2 channel_info = {out_channels, in_channels}; + if (packed_dim == WHCN::kChannelsDim) { + channel_info[0] = utils::align_up_4(channel_info[0]); + channel_info[1] = utils::align_up_4(channel_info[1]); + } + + push_constants = { + graph.sizes_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, 
sizeof(whcn_permute_dims))}; - const vkapi::SpecVarList spec_vars = {packed_dim}; + spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}; + } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {in, vkapi::kRead}}, - {}, + // Parameter buffers + param_buffers, // Push Constants - {{graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo(&out_dims, sizeof(out_dims)), - PushConstantDataInfo(&channel_info, sizeof(channel_info))}}, + push_constants, // Specialization Constants spec_vars, // Resize Args diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 306a79fb8b8..c4de5d88f30 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -26,6 +26,9 @@ void add_unsqueeze_node( in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); int64_t dim = graph.extract_scalar(dim_ref); + if (dim < 0) { + dim += out_dim; + } std::vector permute_dims(out_dim); for (int i = 1; i <= dim; i++) { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 813807445f0..0fd5ef4f002 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -646,6 +646,45 @@ def get_native_layer_norm_inputs(): return test_suite +@register_test_suite("aten.native_group_norm.default") +def get_native_group_norm_inputs(): + test_suite = VkTestSuite( + [ + # (input_shape, weight_shape, bias_shape, N, C, HxW, group, eps) + # General test cases + ((1, 8, 4, 4), (8), (8), 1, 8, 16, 2, 0.001), + ((2, 8, 3, 3), (8), (8), 2, 8, 9, 4, 0.001), + ((1, 12, 2, 2), (12), (12), 1, 12, 4, 3, 0.001), + ((3, 16, 5, 5), (16), (16), 3, 16, 25, 8, 0.001), + ((3, 16, 13, 17), (16), (16), 3, 16, 13 * 17, 4, 0.001), + ((1, 4, 7, 7), (4), (4), 1, 4, 49, 2, 0.001), + ((2, 6, 1, 8), (6), (6), 2, 6, 8, 3, 0.001), + # Single group and prime number sizes + ((3, 7, 13, 11), (7), (7), 3, 7, 13 * 11, 1, 0.001), + # Each channel is it's own group and prime number sizes + ((1, 7, 13, 11), (7), (7), 1, 7, 13 * 11, 7, 0.001), + ] + ) + test_suite.layouts = [ + "utils::kChannelsPacked", + ] + test_suite.storage_types = [ + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + "at::kHalf", + ] + test_suite.arg_storage_types = { + "out": [None, "utils::kBuffer", "utils::kBuffer"], + } + + test_suite.prepacked_args = ["weight", "bias"] + test_suite.requires_prepack = True + + return test_suite + + def get_upsample_inputs(): inputs_list = [ # (input tensor shape, output 2D image size (H, W), output scaling factors) @@ -752,6 +791,13 @@ def get_permute_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + ] return test_suite @@ -990,9 +1036,11 @@ def get_unsqueeze_inputs(): ((9, 9), 2), ((9,), 0), ((9,), 1), + ((1, 10), -1), ] ) test_suite.layouts = [ + "utils::kWidthPacked", "utils::kChannelsPacked", ] test_suite.data_gen = "make_seq_tensor" diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index b24879f660a..38a3ee93627 100644 
--- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -58,6 +58,8 @@ class ValueRef: src_cpp_type: str is_in: bool = False is_out: bool = False + fixed_storage_type: Optional[str] = None + fixed_memory_layout: Optional[str] = None requires_prepack: bool = False supports_prepack: bool = False # When is_dynamic_size is true, the underlying object size is not known @@ -137,20 +139,43 @@ def __init__( if arg.name in self.suite_def.prepacked_args: supports_prepack = True + fixed_storage_type = None + if arg.name in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types[arg.name] + + fixed_memory_layout = None + if arg.name in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts[arg.name] + self.refs[arg.name] = ValueRef( name=f"{arg.name}_ref", src_cpp_name=arg.name, src_cpp_type=cpp_type, is_in=(cpp_type in InableCppType), + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, requires_prepack=requires_prepack, supports_prepack=supports_prepack, ) ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type() self.out = ATenArg(name="out", cpp_type=ret_type, default=None) + + fixed_storage_type = None + if "out" in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types["out"] + fixed_memory_layout = None + if "out" in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts["out"] + if ret_type == AT_TENSOR: self.refs["out"] = ValueRef( - name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True + name="out_ref", + src_cpp_name="out", + src_cpp_type=ret_type, + is_out=True, + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, ) elif ret_type == TWO_TENSOR_TUPLE: self.refs["out"] = [ @@ -159,12 +184,24 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -180,18 +217,36 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_third", src_cpp_name="std::get<2>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[2] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[2] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -302,7 +357,12 @@ def create_value_for( # noqa: C901 ret_str += f"{self.graph}{self.dot}" ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += 
f"{ref.src_cpp_name}->sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()" + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += "));\n" elif prepack: ret_str += f"{self.graph}{self.dot}" ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), " @@ -385,7 +445,12 @@ def create_value_for( # noqa: C901 elif ref.src_cpp_type == AT_TENSOR and not prepack: ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += f"{ref.src_cpp_name}.sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())" + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += ");\n" elif ref.src_cpp_type == AT_TENSOR and prepack: ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), " ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), " diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 5be4ddba6bf..250edf333bc 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -140,7 +140,13 @@ def call_data_gen_fn(self, arg: Argument, data: Any, terminate: bool = True) -> else self.suite_def.arg_data_range[arg.name] ) - ret_str = f"{self.suite_def.data_gen}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" + data_gen_fn = ( + self.suite_def.data_gen + if arg.name not in self.suite_def.arg_data_gen_fn + else self.suite_def.arg_data_gen_fn[arg.name] + ) + + ret_str = f"{data_gen_fn}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" if terminate: ret_str += ";" @@ -288,13 +294,29 @@ def generate_suite_cpp(self) -> str: if (dtype == at::kBool) return at::rand(sizes, at::device(at::kCPU)) > 0.5; - + if (high == 1.0 && low == 0.0) return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low; }} +at::Tensor make_zeros_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::zeros(sizes, at::device(at::kCPU).dtype(dtype)); +}} + +at::Tensor make_ones_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::ones(sizes, at::device(at::kCPU).dtype(dtype)); +}} + at::Tensor make_seq_tensor( std::vector sizes, at::ScalarType dtype = at::kFloat, diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index e7cf5ba92a5..c368c23c539 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -29,7 +29,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple void SetUp() override {{ GraphConfig config; - config.expect_dynamic_shapes = true; utils::StorageType default_storage_type; utils::GPUMemoryLayout default_memory_layout; std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); diff --git a/backends/vulkan/test/op_tests/utils/test_suite.py 
b/backends/vulkan/test/op_tests/utils/test_suite.py index 72ba457b5af..427864b0d5d 100644 --- a/backends/vulkan/test/op_tests/utils/test_suite.py +++ b/backends/vulkan/test/op_tests/utils/test_suite.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional ################################### ## Generic Test Suite definition ## @@ -23,6 +23,7 @@ def __init__(self, input_cases: List[Any]): self.data_range = (0, 1) self.arg_dtype = {} + self.arg_data_gen_fn: Dict[str, str] = {} self.arg_data_range = {} self.atol: str = "1e-5" @@ -48,3 +49,5 @@ def __init__(self, input_cases: List[Any]): self.layouts: List[str] = ["utils::kChannelsPacked"] self.data_gen: str = "make_rand_tensor" self.force_io: bool = True + self.arg_storage_types: Dict[str, str] = {} + self.arg_memory_layouts: Dict[str, str] = {} diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 0096834f3c6..04adf183e55 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1898,3 +1898,69 @@ def forward(self, x): dynamic_shapes=dynamic_shapes, test_inputs=test_inputs, ) + + def test_vulkan_backend_group_norm(self): + class ConvGroupNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Conv2d: 3 input channels -> 16 output channels + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + bias=True, + ) + # GroupNorm: 4 groups for 16 channels (16 % 4 == 0) + self.group_norm = torch.nn.GroupNorm( + num_groups=4, + num_channels=16, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + x = self.conv(x) + x = self.group_norm(x) + return x + + # Create sample inputs: [batch, channels, height, width] + sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),) + + # Test with static shapes first + self.lower_module_and_test_output( + ConvGroupNormModule(), + sample_inputs, + ) + + def test_vulkan_backend_group_norm_different_groups(self): + class GroupNormModule(torch.nn.Module): + def __init__(self, num_groups, num_channels): + super().__init__() + self.group_norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=num_channels, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + return self.group_norm(x) + + # Test different group configurations + test_configs = [ + (2, 8), # 2 groups, 8 channels + (4, 16), # 4 groups, 16 channels + (8, 32), # 8 groups, 32 channels + ] + + for num_groups, num_channels in test_configs: + with self.subTest(num_groups=num_groups, num_channels=num_channels): + sample_inputs = ( + torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32), + ) + + self.lower_module_and_test_output( + GroupNormModule(num_groups, num_channels), + sample_inputs, + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 642f7c5f495..5d57ce1e7be 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value): if isinstance(spec, TensorSpec): setattr(spec, attr, value) elif isinstance(spec, (list, tuple)): - for s in spec: - assert isinstance(s, TensorSpec) - setattr(s, attr, value) + # Special case if value is a list/tuple of the same length as the + # collection of tensors in the node. 
In this case, treat the value list
+        # as a list of values to set individually for each tensor in the node
+        if isinstance(value, (list, tuple)) and len(spec) == len(value):
+            assert len(spec) == len(value)
+            for s, v in zip(spec, value):
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, v)
+        # Otherwise, set the attribute to value for all tensors in the list
+        else:
+            for s in spec:
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, value)
     else:
         raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
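
Illustrative sketch, not part of the diff: the per-(batch, group) statistics that the new group_norm_reduce_texture shader produces can be written in plain PyTorch, which is useful as a quick cross-check of the Vulkan output. The shader computes mean and rstd of shape [N, num_groups], with variance formed as E[x^2] - mean^2 and rstd = 1/sqrt(var + eps); the helper name below is made up for illustration.

import torch

def group_norm_mean_rstd(x: torch.Tensor, num_groups: int, eps: float = 1e-5):
    # x has shape [N, C, H, W]; each group covers C // num_groups channels.
    N, C, H, W = x.shape
    grouped = x.reshape(N, num_groups, (C // num_groups) * H * W)
    mean = grouped.mean(dim=-1)
    # Population variance, matching the shader's E[x^2] - mean^2 formulation.
    var = grouped.square().mean(dim=-1) - mean.square()
    rstd = 1.0 / torch.sqrt(var + eps)
    return mean, rstd  # both of shape [N, num_groups]

# Cross-check against ATen's reference implementation:
x = torch.randn(2, 8, 3, 3)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, torch.ones(8), torch.zeros(8), 2, 8, 9, 4, 1e-5
)
m, r = group_norm_mean_rstd(x, num_groups=4, eps=1e-5)
assert torch.allclose(mean, m, atol=1e-4) and torch.allclose(rstd, r, atol=1e-4)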
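
Also illustrative and not part of the diff: the NCHW-to-WHCN conversion of the permutation array performed in add_permute_node (Permute.cpp) amounts to reversing the array and remapping each dim index d to ndim - 1 - d; a compact sketch of that transformation, with a hypothetical helper name:

def nchw_to_whcn_permute(permute_dims):
    # permute_dims is given in NCHW dim order (as in aten.permute); the
    # shaders expect WHCN order, so reverse the array and remap every
    # index d -> ndim - 1 - d, mirroring the loop in add_permute_node.
    ndim = len(permute_dims)
    return [ndim - 1 - d for d in reversed(permute_dims)]

# Example: the NCHW permutation (0, 2, 3, 1), i.e. NCHW -> NHWC,
# becomes [2, 0, 1, 3] in WHCN order.
assert nchw_to_whcn_permute([0, 2, 3, 1]) == [2, 0, 1, 3]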