From dab26ccf6ee7e206c37b3ebdab413aded6cee4fe Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Thu, 20 Feb 2025 11:20:23 -0800 Subject: [PATCH] update batch norm to use layout gen Summary: TSIA Differential Revision: D69937208 --- .../runtime/graph/ops/glsl/batchnorm.glsl | 38 ++++++++----------- .../runtime/graph/ops/glsl/batchnorm.yaml | 1 + 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl index deb03192af0..c2fc5a56754 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl @@ -13,24 +13,18 @@ layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION sampler3D weight_in; -layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in; -layout(set = 0, binding = 4) uniform PRECISION sampler3D mean_in; -layout(set = 0, binding = 5) uniform PRECISION sampler3D var_in; +#include "indexing_utils.h" -layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "weight_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "mean_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "var_in", DTYPE, STORAGE)} -layout(set = 0, binding = 7) uniform PRECISION restrict Params { - float eps; -}; - -layout(set = 0, binding = 8) uniform PRECISION restrict Params2 { - int num_texel_per_batch; -}; +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "float", "eps")} +${layout_declare_ubo(B, "int", "num_texel_per_batch")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -40,16 +34,16 @@ void main() { return; } - VEC4_T v = VEC4_T(texelFetch(image_in, pos, 0)); + VEC4_T v = VEC4_T(load_texel(t_in, pos)); ivec3 param_pos = ivec3(pos.z % num_texel_per_batch, 0, 0); - VEC4_T weight = VEC4_T(texelFetch(weight_in, param_pos, 0)); - VEC4_T bias = VEC4_T(texelFetch(bias_in, param_pos, 0)); - VEC4_T mean = VEC4_T(texelFetch(mean_in, param_pos, 0)); - VEC4_T var = VEC4_T(texelFetch(var_in, param_pos, 0)); + VEC4_T weight = VEC4_T(load_texel(weight_in, param_pos)); + VEC4_T bias = VEC4_T(load_texel(bias_in, param_pos)); + VEC4_T mean = VEC4_T(load_texel(mean_in, param_pos)); + VEC4_T var = VEC4_T(load_texel(var_in, param_pos)); v = ((v - mean) / sqrt(var + eps)) * weight + bias; - imageStore(image_out, pos, v); + write_texel(t_out, pos, v); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml index a92e44f636b..116773c816a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml @@ -2,6 +2,7 @@ batchnorm: parameter_names_with_default_values: DTYPE: float NDIM: 3 + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half