From 6cd234dbf361fbf236aea0569a99cf25017e5381 Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Fri, 18 Oct 2024 17:03:56 -0700 Subject: [PATCH] update native_layer_norm to new layout gen & axis mapping (#6358) Summary: Naively using ivec4 axis mapping regresses latency by 20-30% for layer norm, due to the added overhead of another layer of index lookups over the 2 loops over the entire width dim. We can use specialization constants to move the index lookups ahead of time to the shader compilation and command buffer construction phase. Unfortunately, we can't pass vec types as specialization constants. But, we can squeeze the axis mapping into a single 32-bit int and pass that in as a specialization constant! We can unpack the int and create a const ivec4 axis map which can be folded during shader compilation. Using this method, we incur a 1% overhead instead of the 20+% we previously saw. This diff also adds a codegen function for specialization constants, along with a new accumulator `C` for constant ids (besides `B` for binding index for textures, buffers and buffer objects) Reviewed By: SS-JIA Differential Revision: D63361329 --- backends/vulkan/runtime/gen_vulkan_spv.py | 27 +++++++++- .../runtime/graph/ops/glsl/indexing_utils.h | 7 +++ .../graph/ops/glsl/native_layer_norm.glsl | 53 ++++++++++--------- .../graph/ops/glsl/native_layer_norm.yaml | 3 +- .../graph/ops/impl/NativeLayerNorm.cpp | 15 ++++-- .../graph/ops/impl/utils/TensorUtils.h | 14 +++++ 6 files changed, 87 insertions(+), 32 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index c133094dbfb..46db1e3a981 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import argparse import array import codecs @@ -42,6 +44,10 @@ # layout binding index when declaring layout bindings. Note that a container # type is used because integers are immutable in Python. "B": [0], + # C is shorthand for "constant_id". This is used to automatically increment the + # constant_id index for specialization constants. + # Note that it starts at 3, as 0-2 are reserved for local workgroup size ids. + "C": [3], } # Establishes relationships between different tensor types and different GLSL types @@ -300,7 +306,7 @@ def layout_declare_ubo( layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{ """ for type_name, var_name in var_list: - out_str += f"{type_name} {var_name};\n" + out_str += f" {type_name} {var_name};\n" out_str += "};" if isinstance(slot, list): @@ -308,6 +314,24 @@ def layout_declare_ubo( return out_str +def layout_declare_spec_const( + slot: Union[int, List[int]], + type_name: str, + var_name: str, + initial_val: Optional[str] = None, +) -> str: + assert type_name in ["int", "uint", "float", "bool"] + + out_str = f"layout(constant_id = {get_slot_val(slot)}) const {type_name} {var_name}" + if initial_val is not None: + out_str += f" = {initial_val}" + out_str += ";" + + if isinstance(slot, list): + slot[0] = slot[0] + 1 + return out_str + + def define_active_storage_type(storage_type: str): if storage_type.lower() == "buffer": return "#define USING_BUFFER" @@ -361,6 +385,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]): "layout_declare_sampler": layout_declare_sampler, "layout_declare_tensor": layout_declare_tensor, "layout_declare_ubo": layout_declare_ubo, + "layout_declare_spec_const": layout_declare_spec_const, "define_active_storage_type": define_active_storage_type, "define_required_extensions": define_required_extensions, } diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index cf4ff98c46b..26342bcd2ba 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -232,6 +232,13 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) { imageStore(im, lpos_to_pos(lpos, axis_map), texel) #endif +// Converts hashed axis mapping and packed dim to a ivec4 +// e.g. 0x000102, 2 -> ivec4(0, 1, 2, 2) +// e.g. 0x010200, 1 -> ivec4(1, 2, 0, 1) +#define UNHASH_AXIS_MAP(hash, packed_dim) \ + ivec4(hash >> 16, (hash >> 8) & 0xFF, hash & 0xFF, packed_dim) +#define DEFAULT_AXIS_MAP_HASH 0x000102 + /************************ * Deprecated Functions * ************************/ diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index 235408c0a81..03500b2d085 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -17,32 +17,32 @@ layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean; -layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd; +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_mean", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_rstd", DTYPE, STORAGE)} -layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in; -layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in; +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)} -layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "sizes")} +${layout_declare_ubo(B, "float", "epsilon")} -layout(set = 0, binding = 7) uniform PRECISION restrict Sizes { - ivec4 sizes; -}; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(set = 0, binding = 8) uniform PRECISION restrict Epsilon { - float epsilon; -}; +${layout_declare_spec_const(C, "int", "in_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")} +${layout_declare_spec_const(C, "int", "in_packed_dim", "C_DIM")} +const ivec4 in_axis_map = UNHASH_AXIS_MAP(in_axis_map_hash, in_packed_dim); -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "out_axis_map_hash", "DEFAULT_AXIS_MAP_HASH")} +${layout_declare_spec_const(C, "int", "out_packed_dim", "C_DIM")} +const ivec4 out_axis_map = UNHASH_AXIS_MAP(out_axis_map_hash, out_packed_dim); void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec3 lpos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_limits))) { + if (any(greaterThanEqual(lpos, out_limits))) { return; } @@ -55,8 +55,10 @@ void main() { // Use Welford's online algorithm to compute mean and variance in one pass // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); for (int w = 0; w < width; ++w) { - VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0); + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); delta = v - mean; mean += delta / (w + 1); delta2 = v - mean; @@ -68,14 +70,15 @@ void main() { VEC4_T offset = -rstd * mean; for (int w = 0; w < width; ++w) { - VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0); + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); // broadcasting - VEC4_T weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx; - VEC4_T bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx; + VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; + VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; VEC4_T outtex = (v * rstd + offset) * weight + bias; - imageStore(image_out, ivec3(w, pos.y, pos.z), outtex); + write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); } - imageStore(image_mean, pos, mean); - imageStore(image_rstd, pos, rstd); + write_texel(t_mean, lpos, mean); + write_texel(t_rstd, lpos, rstd); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml index 44e9b627ada..ac478599f8a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml @@ -6,9 +6,8 @@ native_layer_norm: parameter_names_with_default_values: - NDIM: 3 DTYPE: float - PACKING: C_packed + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 0e30d8a2c6e..1509f35014d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -106,11 +106,18 @@ void add_native_layer_norm_node( vkapi::MemoryAccessType::WRITE}, {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}}, // Shader params buffers - {t_out->logical_limits_ubo(), - t_out->sizes_ubo(), - graph.create_params_buffer(epsilon)}, + { + t_out->logical_limits_ubo(), + t_out->sizes_ubo(), + graph.create_params_buffer(epsilon), + }, // Specialization Constants - {}, + { + hash_axis_map(t_input->axis_map()), + t_input->packed_dim(), + hash_axis_map(t_out->axis_map()), + t_out->packed_dim(), + }, // Resizing Logic resize_native_layer_norm_node, {normalized_shape})); diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index c9eeb0efe08..508cc2538a0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -79,4 +79,18 @@ T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) { return ndim - 1 - nchw_dim; } +// +// Tensor axis map utilities +// + +// Converts ivec4 axis map to a single int32_t, to be able to pass it as a +// specialization constant instead of a ubo. This allows for the spir-v to +// bytecode compilation to perform compile-time folding on the axis map. +// Only converts the first 3 indices, as the last index is the packed dim, +// which is passed separately. +// Example: ivec4(0, 1, 2, 2) -> 0x000102 +inline int32_t hash_axis_map(const std::vector& axis_map) { + return (axis_map.at(0) << 16) + (axis_map.at(1) << 8) + axis_map.at(2); +} + } // namespace vkcompute