From 5a043e70f316b5c3c796aada386e15d894d30b05 Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Sun, 23 Mar 2025 21:33:52 -0700
Subject: [PATCH] [ET-VK] Adding all tensor packing support for native layer
 norm.

This diff updates Executorch Vulkan backend's `native layer norm` operation to
support width, height and channel packed tensors, and adds new test cases to
the cases.py file to test the operation.

Differential Revision: [D71663678](https://our.internmc.facebook.com/intern/diff/D71663678/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                |  15 ++-
 .../graph/ops/glsl/native_layer_norm.glsl     | 126 +++++++++++++-----
 .../graph/ops/impl/NativeLayerNorm.cpp        |   7 +-
 backends/vulkan/test/op_tests/cases.py        |   5 +
 4 files changed, 114 insertions(+), 39 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 26f461c062f..54b7b8651bc 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -576,7 +576,6 @@ def register_ported_op_all_packed_dims(features: OpFeatures):
     [
         exir_ops.edge.aten.embedding.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-        exir_ops.edge.aten.native_layer_norm.default,
     ]
 )
 def register_ported_ops_with_prepacking(features: OpFeatures):
@@ -587,6 +586,20 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features
 
 
+# Ported ops that support their own prepacking.
+@update_features( + [ + exir_ops.edge.aten.native_layer_norm.default, + ] +) +def register_ported_ops_with_prepacking_all_dims(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims=all_packed_dims, + ) + features.handles_own_prepacking = True + return features + + ####################### ## Utility functions ## ####################### diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index f984821600b..f518e838750 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -15,6 +15,8 @@ #define VEC4_T ${texel_type(DTYPE)} +#define T ${texel_component_type(DTYPE)} + layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} @@ -48,37 +50,97 @@ void main() { const int width = int(sizes.x); - VEC4_T mean = VEC4_T(0); - VEC4_T delta = VEC4_T(0); - VEC4_T delta2 = VEC4_T(0); - VEC4_T M2 = VEC4_T(0); - - // Use Welford's online algorithm to compute mean and variance in one pass - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm - ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); - for (int w = 0; w < width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - delta = v - mean; - mean += delta / (w + 1); - delta2 = v - mean; - M2 += delta * delta2; - } - - VEC4_T var = M2 / width; - VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); - VEC4_T offset = -rstd * mean; - - for (int w = 0; w < width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - // broadcasting - VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; - VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; - VEC4_T outtex = (v * rstd + offset) * weight + bias; - write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + if (in_packed_dim != W_DIM) { + VEC4_T mean = 
VEC4_T(0); + VEC4_T delta = VEC4_T(0); + VEC4_T delta2 = VEC4_T(0); + VEC4_T M2 = VEC4_T(0); + + // Use Welford's online algorithm to compute mean and variance in one pass + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); + for (int w = 0; w < width; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + delta = v - mean; + mean += delta / (w + 1); + delta2 = v - mean; + M2 += delta * delta2; + } + + VEC4_T var = M2 / width; + VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); + VEC4_T offset = -rstd * mean; + + for (int w = 0; w < width; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + // broadcasting + VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; + VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; + VEC4_T outtex = (v * rstd + offset) * weight + bias; + write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + } + + write_texel(t_mean, lpos, mean); + write_texel(t_rstd, lpos, rstd); + } else { + const int packed_width = divup4(width); + + T mean = T(0); + T delta = T(0); + T delta2 = T(0); + T M2 = T(0); + // Use Welford's online algorithm to compute mean and variance in one pass + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); + T width_counter = T(1); + + const bool has_unaligned_width = (width & 0x3) != 0; + const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width); + + // iterate through texels that are fully packed ie. 
has 4 components + for (int w = 0; w < fully_packed_4_comp_count; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + for (int i=0; i<4; i++) { + delta = v[i] - mean; + mean += delta / width_counter; + delta2 = v[i] - mean; + M2 += delta * delta2; + width_counter++; + } + } + + // handle last texel if its not 4 aligned + if (has_unaligned_width) { + in_pos[in_axis_map.x] = fully_packed_4_comp_count; + const int remaining_width = width & 0x3; + + VEC4_T v = load_texel(t_in, in_pos); + for (int i=0; ivirtual_resize(mean_size); } -void check_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); -} - void add_native_layer_norm_node( ComputeGraph& graph, const ValueRef in, @@ -84,7 +79,7 @@ void add_native_layer_norm_node( vTensorPtr t_input = graph.get_tensor(in); float epsilon = graph.extract_scalar(eps); - check_args(*t_input, *t_out); + VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out)); std::vector in_sizes = t_input->sizes(); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 2918bbbd7d5..418ef9cd208 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -383,6 +383,11 @@ def get_native_layer_norm_inputs(): ((S, XL, M1, M2), [M2], (M2), (M2), 0.001), ] ) + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kHeightPacked", + "utils::kChannelsPacked", + ] return test_suite