From 2dabba8affc4257ee88e5dada395020802bef62f Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Tue, 5 Aug 2025 14:13:01 -0700
Subject: [PATCH] [ET-VK][BE] Move all ops to use `DynamicDispatchNode`

## Changes

Update (almost) all operators to use `DynamicDispatchNode` instead of `DispatchNode`.

## Context

`DynamicDispatchNode` was introduced to give operators a way to adjust the following based on the current input and output shapes:

1. which compute shader to dispatch
2. what global work group size to use
3. what local work group size to use

This is useful for ensuring that the optimal compute shader is selected for the current tensor sizes and for minimizing the number of inactive shader invocations. An illustrative sketch of the new call pattern is included after the diff.

Differential Revision: [D79564595](https://our.internmc.facebook.com/intern/diff/D79564595/)

[ghstack-poisoned]
---
 .../vulkan/runtime/graph/ops/impl/Arange.cpp | 7 +-
 .../runtime/graph/ops/impl/BatchNorm.cpp | 7 +-
 .../runtime/graph/ops/impl/Convolution.cpp | 138 +++++++++++++++---
 .../vulkan/runtime/graph/ops/impl/Copy.cpp | 7 +-
 .../runtime/graph/ops/impl/Embedding.cpp | 7 +-
 .../vulkan/runtime/graph/ops/impl/Flip.cpp | 19 ++-
 .../vulkan/runtime/graph/ops/impl/Full.cpp | 7 +-
 .../runtime/graph/ops/impl/GridPriors.cpp | 7 +-
 .../runtime/graph/ops/impl/IndexSelect.cpp | 13 +-
 .../vulkan/runtime/graph/ops/impl/Linear.cpp | 101 +++++++++----
 .../graph/ops/impl/NativeLayerNorm.cpp | 7 +-
 .../vulkan/runtime/graph/ops/impl/Pad.cpp | 7 +-
 .../vulkan/runtime/graph/ops/impl/Pool.cpp | 19 +--
 .../graph/ops/impl/QuantizedLinearQCSNW.cpp | 106 +++++++++++++-
 .../vulkan/runtime/graph/ops/impl/Repeat.cpp | 11 +-
 .../graph/ops/impl/RepeatInterleave.cpp | 12 +-
 .../vulkan/runtime/graph/ops/impl/Tan.cpp | 7 +-
 .../runtime/graph/ops/impl/Upsample.cpp | 7 +-
 .../vulkan/runtime/graph/ops/impl/Var.cpp | 101 ++++++++++++-
 .../vulkan/runtime/graph/ops/impl/Where.cpp | 23 +--
 20 files changed, 471 insertions(+), 142 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp index ebfadbb05cb..3171fbeb488 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp @@ -10,6 +10,7 @@ #include +#include #include #include @@ -86,11 +87,11 @@ void add_arange_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp index dcadcf80e42..757afd06849 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -83,11 +84,11 @@ void add_native_batch_norm_node( const int32_t num_texel_per_batch = utils::div_up_4((dim_at(in_sizes))); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out_ref), - graph.create_local_wg_size(out_ref), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out_ref, vkapi::kWrite}, {{in_ref,
arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}}, {graph.logical_limits_ubo(out_ref), diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 25b4d85be68..f5b5faa1c8b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -19,6 +20,13 @@ namespace vkcompute { +enum class Conv2dMethod : uint8_t { + Depthwise, + Pointwise, + SlidingWindow, + Transposed, +}; + void resize_conv2d_node( ComputeGraph* graph, const std::vector& args, @@ -114,13 +122,6 @@ ValueRef prepack_biases( return v; } -enum class Conv2dMethod : uint8_t { - Depthwise, - Pointwise, - SlidingWindow, - Transposed, -}; - vkapi::ShaderInfo get_conv2d_shader( ComputeGraph& graph, const ValueRef out, @@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size( } } +// Custom global workgroup size function for conv2d +utils::uvec3 conv2d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef weight_data = resize_args.at(0); + + // Determine method from shader name + Conv2dMethod method; + if (shader.kernel_name.find("conv2d_dw") != std::string::npos) { + method = Conv2dMethod::Depthwise; + } else if ( + shader.kernel_name.find("conv2d_pw") != std::string::npos || + (shader.kernel_name.find("conv2d") != std::string::npos && + shader.kernel_name.find("conv_transpose2d") == std::string::npos)) { + // Check if it's pointwise by examining weight sizes + const auto& weight_sizes = graph->get_tref(weight_data)->sizes; + if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) { + method = Conv2dMethod::Pointwise; + } else { + method = Conv2dMethod::SlidingWindow; + } + } else if (shader.kernel_name.find("conv_transpose2d") != std::string::npos) { + method = Conv2dMethod::Transposed; + } else { + method = Conv2dMethod::SlidingWindow; + } + + // Determine stride_equals_dilation from shader name + bool stride_equals_dilation = + shader.kernel_name.find("_sned") == std::string::npos; + + utils::uvec3 wg_size = create_conv2d_global_wg_size( + *graph, method, out, weight_data, stride_equals_dilation); + + if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + } + + return wg_size; +} + +// Custom local workgroup size function for conv2d +utils::uvec3 conv2d_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + // Determine method from shader name + Conv2dMethod method; + if (shader.kernel_name.find("conv2d_dw") != std::string::npos) { + method = Conv2dMethod::Depthwise; + } else if ( + shader.kernel_name.find("conv2d_pw") != std::string::npos || + (shader.kernel_name.find("conv2d") != std::string::npos && + shader.kernel_name.find("conv_transpose2d") == std::string::npos)) { + method = Conv2dMethod::Pointwise; + } else { + method = Conv2dMethod::SlidingWindow; + } + + if (method == Conv2dMethod::Pointwise) { + uint32_t local_wg_size_y = 1; + if (global_workgroup_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (global_workgroup_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (global_workgroup_size[1] % 2 == 0) { + local_wg_size_y = 2; + } 
+ return {64 / local_wg_size_y, local_wg_size_y, 1}; + } else if (method == Conv2dMethod::Depthwise) { + return {64, 1, 1}; + } else { + return graph->create_local_wg_size(global_workgroup_size); + } +} + +// Custom global workgroup size function for conv1d +utils::uvec3 conv1d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + return {// out length + graph->size_at(-1, out), + // out channels + static_cast(graph->size_at(-2, out)), + // out batches + utils::div_up_4(graph->size_at(-3, out))}; +} + void add_conv2d_node( ComputeGraph& graph, const ValueRef in, @@ -486,11 +589,11 @@ void add_conv2d_node( }; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - wg_size, - local_wg_size, + conv2d_global_wg_size, + conv2d_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers @@ -560,15 +663,6 @@ void add_conv1d_node( const int32_t out_group_size = static_cast(out_channels / groups_val); - const utils::uvec3 global_size = { - // out length - graph.size_at(-1, out), - // out channels - static_cast(out_channels), - // out batches - utils::div_up_4(graph.size_at(-3, out))}; - const utils::uvec3 local_size = graph.create_local_wg_size(global_size); - Kernel1dParams kernel_params = { kernel_size, stride_size, @@ -587,11 +681,11 @@ void add_conv1d_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + conv1d_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index 27e8c81ba9e..bd648dbae2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -35,11 +36,11 @@ void add_copy_offset_node( auto shader = VK_KERNEL_FROM_STR(kernel_name); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index b5a2f20cf4b..475e7796b09 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -46,11 +47,11 @@ void add_embedding_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, weight}, 
vkapi::kRead}}, { graph.sizes_ubo(out), diff --git a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp index 6679bfe32f5..52288734704 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -15,6 +16,18 @@ namespace vkcompute { +// Custom global workgroup size function for flip +utils::uvec3 flip_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return graph->create_global_wg_size(out); +} + void check_flip_args( ComputeGraph& graph, const ValueRef in, @@ -59,11 +72,11 @@ void add_flip_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + flip_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index 2fa22312745..fe2676e91e0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -42,11 +43,11 @@ void add_full_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp index 620613fdfb8..5f39c16d405 100644 --- a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -46,11 +47,11 @@ void add_grid_priors_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); const GridPriorsParam param = {stride, offset}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp index 86faabd48d5..576711a86f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -38,11 +39,11 @@ void add_index_select_channel_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), 
+ default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, {graph.sizes_ubo(out), graph.sizes_ubo(in)}, // Push Constants @@ -92,11 +93,11 @@ void add_index_select_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, {graph.sizes_ubo(out), graph.create_params_buffer(params)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index a58444a7830..7ca31599cdf 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -18,6 +19,70 @@ namespace vkcompute { +// Custom global workgroup size function for addmm_naive_texture +utils::uvec3 addmm_naive_texture_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return graph->logical_limits_of(out); +} + +// Custom global workgroup size function for addmm_naive_buffer +utils::uvec3 addmm_naive_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + +// Custom global workgroup size function for addmm_optimized +utils::uvec3 addmm_optimized_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + + std::vector mat1_sizes = graph->sizes_of(mat1); + int mat1_dims = mat1_sizes.size(); + + utils::uvec3 global_size = graph->logical_limits_of(out); + + if (mat1_sizes.at(mat1_dims - 2) < 8) { + global_size = utils::divup_vec(global_size, {4, 2, 1}); + } else { + global_size = utils::divup_vec(global_size, {4, 4, 1}); + } + return global_size; +} + +// Custom local workgroup size function for addmm_optimized +utils::uvec3 addmm_optimized_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)resize_args; + return adaptive_work_group_size(global_workgroup_size); +} + void check_addmm_args( ComputeGraph& graph, const ValueRef self, @@ -109,11 +174,11 @@ void add_addmm_naive_texture_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); utils::uvec3 global_wg_size = graph.logical_limits_of(out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), + addmm_naive_texture_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, 
vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -176,11 +241,11 @@ void add_addmm_naive_buffer_node( ? 1 : 0; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - graph.create_local_wg_size(global_size), + addmm_naive_buffer_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -250,31 +315,13 @@ void add_addmm_optimized_node( } else { kernel_name += "_tile_row_4"; } - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - utils::uvec3 global_size = graph.logical_limits_of(out); - - // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the - // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is - // channels packed, C does not need to be divided by 4. The "identity" of each - // thread is the (x, y, z) coordinate of the output tile it is computing, and - // this identity can be used to compute the tensor index of the top left - // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] - if (mat1_sizes.at(mat1_dims - 2) < 8) { - // Use `logical_extents` instead of `image_extents` because the workgroup - // axes need to correspond to tensor dimensions. - global_size = utils::divup_vec(global_size, {4, 2, 1}); - } else { - global_size = utils::divup_vec(global_size, {4, 4, 1}); - } - utils::uvec3 local_size = adaptive_work_group_size(global_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + addmm_optimized_global_wg_size, + addmm_optimized_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed, self}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 99f945da535..8e15b56b208 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -104,11 +105,11 @@ void add_native_layer_norm_node( add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{out_tensor, mean_tensor, rstd_tensor}, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp index a10984eac78..d225af05633 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -76,11 +77,11 @@ void add_constant_pad_nd_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader 
params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index e74b9ec96a7..b3791a4f7d1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -80,9 +81,6 @@ void add_max_pool2d_node( check_pool2d_args(graph, in, out_tensor); - utils::uvec3 global_size = graph.logical_limits_of(out_tensor); - utils::uvec3 local_size = adaptive_work_group_size(global_size); - std::string kernel_name("max_pool2d"); add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); @@ -94,11 +92,11 @@ void add_max_pool2d_node( padding, dilation); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -154,9 +152,6 @@ void add_avg_pool2d_node( const ValueRef out) { check_pool2d_args(graph, in, out); - utils::uvec3 global_size = graph.logical_limits_of(out); - utils::uvec3 local_size = adaptive_work_group_size(global_size); - std::string kernel_name("avg_pool2d"); add_dtype_suffix(kernel_name, graph.dtype_of(out)); @@ -166,11 +161,11 @@ void add_avg_pool2d_node( DivisorParams divisor_params = create_divisor_params(graph, divisor_override, count_include_pad); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 05a300bee4c..89c9e847724 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -15,6 +16,99 @@ namespace vkcompute { +// Custom global workgroup size function for linear_qcs8w +utils::uvec3 linear_qcs8w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return {static_cast(graph->numel_of(out)), 1, 1}; +} + +// Custom local workgroup size function for linear_qcs8w +utils::uvec3 linear_qcs8w_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + return {64, 1, 1}; +} + +// Custom global workgroup size function for linear_qcsnw_tiled +utils::uvec3 linear_qcsnw_tiled_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + + // Determine quantization bits from shader name + int quant_nbits = 8; + if (shader.kernel_name.find("qcs4w") != std::string::npos) { + quant_nbits = 
4; + } + + std::vector mat1_sizes = graph->sizes_of(mat1); + const int64_t M = utils::val_at(-2, mat1_sizes); + uint32_t out_tile_nrows = 4; + if (M % 6 == 0) { + out_tile_nrows = 2; + } else if (M % 4 == 0) { + out_tile_nrows = 4; + } else if (M % 1 == 0) { + out_tile_nrows = 1; + } else { + out_tile_nrows = 4; + } + + // Number of output texels in the output tile + uint32_t out_tile_ntxcols = 1; + if (quant_nbits == 4) { + out_tile_ntxcols = 2; + } + + utils::uvec3 out_limits = graph->logical_limits_of(out); + uint32_t global_wg_x = utils::div_up(out_limits[0], out_tile_ntxcols); + return { + global_wg_x * (utils::div_up(out_limits[1], out_tile_nrows)), + 1, + out_limits[2]}; +} + +// Custom local workgroup size function for linear_qcsnw_tiled +utils::uvec3 linear_qcsnw_tiled_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + // Check if using cooperative algorithm from shader name + bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + if (use_coop_algorithm) { + return {8, 1, 8}; + } else { + return {64, 1, 1}; + } +} + void check_linear_qcsnw_args( const ComputeGraph& graph, const int quant_nbits, @@ -138,11 +232,11 @@ void add_linear_qcs8w_node( static_cast(graph.numel_of(out_W_packed)), 1, 1}; const utils::uvec3 local_wg{64, 1, 1}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg, - local_wg, + linear_qcs8w_global_wg_size, + linear_qcs8w_local_wg_size, // Inputs and Outputs {{out_W_packed, vkapi::MemoryAccessType::WRITE}, {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, @@ -247,11 +341,11 @@ void add_linear_qcsnw_tiled_node( local_wg_size = {8, 1, 8}; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + linear_qcsnw_tiled_global_wg_size, + linear_qcsnw_tiled_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index d7a2b7a8ca2..72c1637a2c9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -92,15 +93,15 @@ void add_repeat_node( const auto shader = VK_KERNEL_FROM_STR(kernel_name); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - wg_size, - graph.create_local_wg_size(wg_size), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { - {out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}, + {out, vkapi::kWrite}, + {in, vkapi::kRead}, }, // Parameter buffers {}, diff --git a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp index ae2aeec10bf..221d0d23f51 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp @@ -8,6 +8,7 @@ #include 
+#include #include #include @@ -49,16 +50,11 @@ void add_repeat_interleave_node( std::string kernel_name = "repeat_interleave"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.logical_limits_of(in); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp index 307f774de5e..687b3923354 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -35,11 +36,11 @@ void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) { vkapi::ParamsBindList ubos({}); ubos.append({graph.logical_limits_ubo(out)}); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp index ed9fef61a78..6662ae367c5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -114,11 +115,11 @@ void add_upsample_nearest2d_node( } add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Var.cpp b/backends/vulkan/runtime/graph/ops/impl/Var.cpp index 106a6fd6d9a..d8fd367f18a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Var.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Var.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include @@ -14,6 +15,93 @@ namespace vkcompute { using namespace utils; +// Custom global workgroup size function for var_buffer +utils::uvec3 var_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + +// Custom local workgroup size function for var_buffer +utils::uvec3 var_buffer_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)global_workgroup_size; + const ValueRef in = 
args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + const uint32_t nworkers_per_group = 4; + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + return local_wg_size; +} + +// Custom global workgroup size function for var_texture +utils::uvec3 var_texture_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[reduce_dim] = 1; + return global_wg_size; +} + +// Custom local workgroup size function for var_texture +utils::uvec3 var_texture_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef in = args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + const int other_dim_1 = (reduce_dim + 1) % 3; + const int other_dim_2 = (reduce_dim + 2) % 3; + if (global_workgroup_size[other_dim_1] > global_workgroup_size[other_dim_2]) { + local_wg_size[other_dim_1] = ngroups; + } else { + local_wg_size[other_dim_2] = ngroups; + } + return local_wg_size; +} + void resize_var_node( ComputeGraph* graph, const std::vector& args, @@ -68,11 +156,11 @@ void add_var_buffer_node( int32_t unbiased_int = static_cast(unbiased); push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + var_buffer_global_wg_size, + var_buffer_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -143,12 +231,11 @@ void add_var_texture_node( int32_t unbiased_int = static_cast(unbiased); push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + var_texture_global_wg_size, + var_texture_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp index 1868d3b872e..c1c482d9967 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Where.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -10,6 +10,7 @@ #include +#include #include namespace vkcompute { @@ -37,16 +38,11 @@ void add_where_texture_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, 
graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, // Parameter buffers @@ -72,9 +68,6 @@ void add_where_buffer_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - vkapi::ParamsBindList ubos = { graph.numel_ubo(out), graph.strides_ubo(out), @@ -82,13 +75,11 @@ void add_where_buffer_node( graph.strides_ubo(self), graph.strides_ubo(other)}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, // Parameter buffers
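
For readers unfamiliar with the new node type, below is a minimal sketch of the pattern this patch applies. It is illustrative only: `my_op` and `my_op_global_wg_size` are hypothetical, the header path and the `ArgGroup`/`ValueRef` vector element types are assumptions (the angle-bracketed text is collapsed in the diff above), and the trailing `DynamicDispatchNode` constructor arguments, which are unchanged from the `DispatchNode` form, are elided.

```cpp
// Sketch only: hypothetical my_op; assumed header path and template arguments.
#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>

namespace vkcompute {

// A shape-dependent picker. Instead of a work group size computed once when
// the graph is built, DynamicDispatchNode stores this function and
// re-evaluates it, so the size tracks the current output shape.
utils::uvec3 my_op_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;
  const ValueRef out = args.at(0).refs.at(0);
  // One invocation per output texel, following the tensor's logical extents.
  return graph->create_global_wg_size(out);
}

// Call-site change applied throughout this patch; the remaining constructor
// arguments (shader params buffers, push constants, etc.) are passed exactly
// as they were to DispatchNode and are omitted here.
//
//   - graph.execute_nodes().emplace_back(new DispatchNode(
//   -     graph,
//   -     VK_KERNEL_FROM_STR(kernel_name),
//   -     graph.create_global_wg_size(out),  // computed once, at build time
//   -     graph.create_local_wg_size(out),
//   -     ...));
//   + graph.execute_nodes().emplace_back(new DynamicDispatchNode(
//   +     graph,
//   +     VK_KERNEL_FROM_STR(kernel_name),
//   +     my_op_global_wg_size,              // re-picked from current shapes
//   +     default_pick_local_wg_size,        // default heuristic
//   +     ...));

} // namespace vkcompute
```

Operators with no shape-specific needs simply pass `default_pick_global_wg_size` and `default_pick_local_wg_size`, as most of the conversions in this patch do; ops such as conv2d, addmm, and var supply custom pickers like the one sketched above.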