From dee7beb56b2f0254462118d270677326bfe9a197 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 15 Oct 2025 09:07:42 -0700
Subject: [PATCH] [ET-VK] Introduce specialized implementation for per-row
 reduction

Title says it all! Reductions along the last (width) dim of a contiguous
buffer tensor now use a specialized shader in which each workgroup computes
one output element, with 64 invocations cooperatively reducing one row
through shared memory.

This diff also adds support for argmin and argmax, but only for per-row
reduction.

Differential Revision: [D84716454](https://our.internmc.facebook.com/intern/diff/D84716454/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                |  10 ++
 .../graph/ops/glsl/reduce_op_defs.glslh       |  94 ++++++++++++++
 .../graph/ops/glsl/reduce_per_row_buffer.glsl | 122 ++++++++++++++++++
 .../graph/ops/glsl/reduce_per_row_buffer.yaml |  42 ++++++
 .../runtime/graph/ops/impl/ArgReduce.cpp      |  56 ++++++++
 .../vulkan/runtime/graph/ops/impl/Reduce.cpp  |  84 ++++++++++++
 .../vulkan/runtime/graph/ops/impl/Reduce.h    |  23 ++++
 backends/vulkan/test/op_tests/cases.py        |  68 ++++++++++
 8 files changed, 499 insertions(+)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml
 create mode 100644 backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp
 create mode 100644 backends/vulkan/runtime/graph/ops/impl/Reduce.h

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b47a8f383a0..7b18d8326a3 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -449,6 +449,7 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
         return False
 
     keepdim = try_find_keepdim_arg(node)
+    # keepdim = False is not supported yet
     if isinstance(keepdim, bool) and not keepdim:
         return False
 
@@ -461,6 +462,15 @@ def pick_io_storage_for_reduce(node: torch.fx.Node):
     input_tensor = node.args[0]
     ndim = input_tensor.meta["val"].ndim
     dim_list = node.args[1]
+
+    # For 1D reductions, a special case is implemented for reducing the width dim
+    if isinstance(dim_list, list) and len(dim_list) == 1:
+        if dim_list[0] == -1:
+            inputs_storage = utils.ANY_TEXTURE.make_union(utils.CONTIGUOUS_BUFFER)
+            outputs_storage = inputs_storage
+            return inputs_storage, outputs_storage
+
+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
     if isinstance(dim_list, list) and len(dim_list) == 2:
         reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
         reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh
new file mode 100644
index 00000000000..bbfb991808f
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef REDUCE_OP_DEFS_GLSLH
+#define REDUCE_OP_DEFS_GLSLH
+
+struct Accum {
+  T val; // running reduction value
+  uint idx; // index of the current best value (used by argmin/argmax)
+  uint count; // number of elements accumulated (used by mean)
+};
+
+void init_accum(out Accum accum, T val, uint idx) {
+  accum.val = val;
+  accum.idx = idx;
+  accum.count = 1;
+}
+
+void init_accum_zero(out Accum accum) {
+  accum.val = T(0);
+  accum.idx = 0;
+  accum.count = 0;
+}
+
+// Sum / Mean
+
+void update_accum_sum(inout Accum accum, T val, uint idx) {
+  accum.val += val;
+  accum.count += 1;
+}
+
+void merge_accum_sum(inout Accum accum, const Accum other) {
+  accum.val += other.val;
+  accum.count += other.count;
+}
+
+void postprocess_accum_mean(inout Accum accum) {
+  accum.val /= T(accum.count);
+}
+
+// Amax (maximum value)
+
+void update_accum_amax(inout Accum accum, T val, uint idx) {
+  if (val > accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // Break ties by selecting the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amax(inout Accum accum, const Accum other) {
+  if (other.val > accum.val) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // Break ties by selecting the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+// Amin (minimum value)
+
+void update_accum_amin(inout Accum accum, T val, uint idx) {
+  if (val < accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // Break ties by selecting the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amin(inout Accum accum, const Accum other) {
+  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // Break ties by selecting the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+#endif // REDUCE_OP_DEFS_GLSLH
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl
new file mode 100644
index 00000000000..d1574a67a62
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${texel_load_component_type(DTYPE, "buffer")}
+
+#define NUM_OUTPUTS_PER_WG 1
+#define NUM_WORKERS_PER_OUTPUT 64
+
+${define_active_storage_type("buffer")}
+${define_required_extensions(DTYPE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+#include "reduce_op_defs.glslh"
+
+$if OUTPUT_IS_INDICES:
+  ${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
+$else:
+  ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+
+${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// Shared memory for cooperative reduction
+shared Accum shared_values[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
+
+#define init_fn ${INIT_ACCUM_FN}
+#define update_fn ${UPDATE_ACCUM_FN}
+#define merge_fn ${MERGE_ACCUM_FN}
+
+$if POSTPROCESS_ACCUM_FN != "none":
+  #define postprocess_fn ${POSTPROCESS_ACCUM_FN}
+
+$if OOB_INIT_MODE == "zero":
+  #define OOB_INIT_MODE 0
+$else:
+  #define OOB_INIT_MODE 1
+
+$if OUTPUT_IS_INDICES:
+  #define OUTPUT_IS_INDICES
+
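+// Algorithm overview: each workgroup computes one output element, i.e. the
+// reduction of one row. Each of the NUM_WORKERS_PER_OUTPUT invocations
+// strides across the row accumulating a private partial result; the
+// partial results are then combined via a tree merge in shared memory.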
+void main() {
+  const uint out_bufi = gl_GlobalInvocationID.y;
+
+  if (out_of_bounds(out_bufi, outp)) {
+    return;
+  }
+
+  // Local indices
+  const uint worker_id = gl_LocalInvocationID.x;
+  const uint output_id = gl_LocalInvocationID.y;
+
+  const uint in_bufi_base = out_bufi * width(inp);
+
+  Accum local_accum;
+  // Initialize the accumulator with the first element being processed
+  if (worker_id < width(inp)) {
+    const uint in_bufi = in_bufi_base + worker_id;
+    init_fn(local_accum, t_in[in_bufi], worker_id);
+  }
+  // For the out-of-bounds case, initialization depends on the reduction op
+  else {
+#if OOB_INIT_MODE == 0
+    // Init with a zero value
+    init_accum_zero(local_accum);
+#else
+    // Init with the first value (e.g. amin, amax)
+    init_fn(local_accum, t_in[in_bufi_base], 0);
+#endif
+  }
+
+  for (uint x = worker_id + NUM_WORKERS_PER_OUTPUT; x < width(inp);
+       x += NUM_WORKERS_PER_OUTPUT) {
+    update_fn(local_accum, t_in[in_bufi_base + x], x);
+  }
+
+  shared_values[output_id][worker_id] = local_accum;
+
+  memoryBarrierShared();
+  barrier();
+
+  for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
+    if (worker_id < i) {
+      merge_fn(
+          shared_values[output_id][worker_id],
+          shared_values[output_id][worker_id + i]);
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+
+  if (worker_id == 0) {
+    local_accum = shared_values[output_id][0];
+#ifdef postprocess_fn
+    postprocess_fn(local_accum);
+#endif
+
+#ifdef OUTPUT_IS_INDICES
+    t_out[out_bufi] = int(local_accum.idx);
+#else
+    t_out[out_bufi] = local_accum.val;
+#endif
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml
new file mode 100644
index 00000000000..e5a94165b96
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
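+
+# Note: OOB_INIT_MODE selects how a worker whose first assigned column is
+# out of bounds seeds its accumulator: sum/mean use the zero identity,
+# while amin/amax/argmin/argmax reuse the row's first element so that
+# comparisons against the seed remain valid.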
+reduce_per_row_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    INIT_ACCUM_FN: init_accum
+    UPDATE_ACCUM_FN: update_accum_sum
+    MERGE_ACCUM_FN: merge_accum_sum
+    POSTPROCESS_ACCUM_FN: none
+    OOB_INIT_MODE: zero
+    OUTPUT_IS_INDICES: false
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+      - VALUE: half
+      - VALUE: int32
+  shader_variants:
+    - NAME: sum_per_row_buffer
+    - NAME: mean_per_row_buffer
+      POSTPROCESS_ACCUM_FN: postprocess_accum_mean
+    - NAME: amax_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amax
+      MERGE_ACCUM_FN: merge_accum_amax
+      OOB_INIT_MODE: first_element
+    - NAME: amin_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amin
+      MERGE_ACCUM_FN: merge_accum_amin
+      OOB_INIT_MODE: first_element
+    - NAME: argmax_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amax
+      MERGE_ACCUM_FN: merge_accum_amax
+      OOB_INIT_MODE: first_element
+      OUTPUT_IS_INDICES: true
+    - NAME: argmin_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amin
+      MERGE_ACCUM_FN: merge_accum_amin
+      OOB_INIT_MODE: first_element
+      OUTPUT_IS_INDICES: true
diff --git a/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp
new file mode 100644
index 00000000000..cadaf6e5e27
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Reduce.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void arg_reduce_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args,
+    const std::string& op_name) {
+  int arg_idx = 0;
+  const ValueRef in = args.at(arg_idx++);
+  const ValueRef dim = args.at(arg_idx++);
+  const ValueRef keepdim = args.at(arg_idx++);
+  const ValueRef out = args.at(arg_idx++);
+  // keepdim is not inspected; the resize logic keeps the reduced dim with
+  // size 1 (keepdim = True semantics)
+  (void)keepdim;
+
+  VK_CHECK_COND(graph.is_buffer_storage(in));
+
+  int64_t dim_val = 0;
+  if (graph.val_is_not_none(dim)) {
+    dim_val = graph.extract_scalar<int64_t>(dim);
+  }
+  const int64_t ndim = graph.dim_of(in);
+  const int64_t normalized_dim = normalize(dim_val, ndim);
+
+  VK_CHECK_COND(normalized_dim == ndim - 1);
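+
+  // Only reduction along the last (width) dim is supported by the per-row
+  // path; other dims are rejected above rather than falling back to the
+  // generic reduce implementation.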
+  add_reduce_per_row_node(graph, in, out, op_name);
+}
+
+void argmin(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  arg_reduce_impl(graph, args, "argmin");
+}
+
+void argmax(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  arg_reduce_impl(graph, args, "argmax");
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.argmin.default, argmin);
+  VK_REGISTER_OP(aten.argmax.default, argmax);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
index 6ad1d7f371d..21e5dcd36f3 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -52,6 +52,20 @@ void resize_reduce2d_node(
   graph->virtual_resize(out, new_sizes);
 }
 
+void resize_reduce_per_row_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+
+  std::vector<int64_t> new_sizes = graph->sizes_of(in);
+  // Per-row reduction always reduces along the last dimension (width); the
+  // reduced dim is kept with size 1 (keepdim semantics)
+  new_sizes.back() = 1;
+  graph->virtual_resize(out, new_sizes);
+}
+
 utils::uvec3 reduce_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -237,6 +251,70 @@ void add_reduce2d_node(
       resize_reduce2d_node));
 }
 
+utils::uvec3 reduce_per_row_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  // Dispatch one workgroup per output element (i.e. per reduced row)
+  const ValueRef out = args.at(0).refs.at(0);
+  return {1u, utils::safe_downcast<uint32_t>(graph->numel_of(out)), 1u};
+}
+
+utils::uvec3 reduce_per_row_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
+
+  // These must match NUM_OUTPUTS_PER_WG and NUM_WORKERS_PER_OUTPUT in
+  // reduce_per_row_buffer.glsl
+  uint32_t outputs_per_wg = 1u;
+  uint32_t workers_per_output = 64u;
+
+  return {workers_per_output, outputs_per_wg, 1u};
+}
+
+void add_reduce_per_row_node(
+    ComputeGraph& graph,
+    const ValueRef input,
+    const ValueRef output,
+    const std::string& op_name) {
+  std::string kernel_name = op_name + "_per_row";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+
+  vkapi::ParamsBindList param_ubos = {
+      graph.meta_ubo(output),
+      graph.meta_ubo(input),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      // Global workgroup size function
+      reduce_per_row_global_wg_size,
+      // Local workgroup size function
+      reduce_per_row_local_wg_size,
+      // Inputs and Outputs
+      {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {},
+      // Resize Args
+      {},
+      // Resizing Logic
+      resize_reduce_per_row_node));
+}
+
 #define DEFINE_REDUCE_FN(op_name, out_arg_idx)                              \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {    \
     const std::vector<int64_t> dims_list =                                  \
@@ -244,6 +322,12 @@ void add_reduce2d_node(
     if (dims_list.size() == 1) {                                            \
       const int64_t dim_val = dims_list.at(0);                              \
       const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);     \
+                                                                            \
+      if (dim_val == -1 && graph.is_buffer_storage(args[0])) {              \
+        return add_reduce_per_row_node(                                     \
+            graph, args[0], args[out_arg_idx], #op_name);                   \
+      }                                                                     \
+                                                                            \
       return add_reduce_node(                                               \
           graph, args[0], dim_ref, args[out_arg_idx], #op_name);            \
     }                                                                       \
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.h b/backends/vulkan/runtime/graph/ops/impl/Reduce.h
new file mode 100644
index 00000000000..453fc165e4e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace vkcompute {
+
+void add_reduce_per_row_node(
+    ComputeGraph& graph,
+    const ValueRef input,
+    const ValueRef output,
+    const std::string& op_name);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 76e5dc1bc0e..0f48b4be4b6 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1557,6 +1557,20 @@ def get_reduce_inputs(is_softmax: bool = False):
     ]
 
 
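+# Each tuple is (input shape, reduction dim, keepdim); the per-row path
+# only handles dim = -1, so every case reduces along the last (width) dim.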
+def get_reduce_per_row_inputs():
+    inputs = [
+        ((5, 16), -1, True),
+        ((5, 16), -1, False),
+        ((7, 21), -1, True),
+        ((7, 21), -1, False),
+        ((3, 7, 280), -1, True),
+        ((3, 7, 280), -1, False),
+        ((3, 17, 77), -1, True),
+        ((3, 17, 77), -1, False),
+    ]
+    return inputs
+
+
 @register_test_suite(["aten._softmax.default", "aten._log_softmax.default"])
 def get_softmax_inputs():
     test_suite = VkTestSuite(get_reduce_inputs(is_softmax=True))
@@ -1576,6 +1590,20 @@ def get_reduce_op_inputs():
         "utils::kChannelsPacked",
         "utils::kWidthPacked",
     ]
+
+    per_row_suite = VkTestSuite(get_reduce_per_row_inputs())
+    per_row_suite.layouts = ["utils::kWidthPacked"]
+    per_row_suite.storage_types = ["utils::kBuffer"]
+    per_row_suite.test_name_suffix = "per_row"
+    return [test_suite, per_row_suite]
+
+
+@register_test_suite(["aten.argmin.default", "aten.argmax.default"])
+def get_reduce_arg_op_inputs():
+    test_suite = VkTestSuite(get_reduce_per_row_inputs())
+    test_suite.layouts = ["utils::kWidthPacked"]
+    test_suite.storage_types = ["utils::kBuffer"]
+    test_suite.dtypes = ["at::kFloat"]
     return test_suite