diff --git a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp index ebfadbb05cb..3171fbeb488 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp @@ -10,6 +10,7 @@ #include +#include #include #include @@ -86,11 +87,11 @@ void add_arange_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp index dcadcf80e42..757afd06849 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -83,11 +84,11 @@ void add_native_batch_norm_node( const int32_t num_texel_per_batch = utils::div_up_4((dim_at(in_sizes))); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out_ref), - graph.create_local_wg_size(out_ref), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out_ref, vkapi::kWrite}, {{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}}, {graph.logical_limits_ubo(out_ref), diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 25b4d85be68..f5b5faa1c8b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -19,6 +20,13 @@ namespace vkcompute { +enum class Conv2dMethod : uint8_t { + Depthwise, + Pointwise, + SlidingWindow, + Transposed, +}; + void resize_conv2d_node( ComputeGraph* graph, const std::vector& args, @@ -114,13 +122,6 @@ ValueRef prepack_biases( return v; } -enum class Conv2dMethod : uint8_t { - Depthwise, - Pointwise, - SlidingWindow, - Transposed, -}; - vkapi::ShaderInfo get_conv2d_shader( ComputeGraph& graph, const ValueRef out, @@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size( } } +// Custom global workgroup size function for conv2d +utils::uvec3 conv2d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef weight_data = resize_args.at(0); + + // Determine method from shader name + Conv2dMethod method; + if (shader.kernel_name.find("conv2d_dw") != std::string::npos) { + method = Conv2dMethod::Depthwise; + } else if ( + shader.kernel_name.find("conv2d_pw") != std::string::npos || + (shader.kernel_name.find("conv2d") != std::string::npos && + shader.kernel_name.find("conv_transpose2d") == std::string::npos)) { + // Check if it's pointwise by examining weight sizes + const auto& weight_sizes = graph->get_tref(weight_data)->sizes; + if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) { + method = Conv2dMethod::Pointwise; + } else { + method = Conv2dMethod::SlidingWindow; + } + } else if (shader.kernel_name.find("conv_transpose2d") != std::string::npos) { + method = Conv2dMethod::Transposed; + } else { + method = Conv2dMethod::SlidingWindow; + } + + // Determine stride_equals_dilation from shader name + bool stride_equals_dilation = + shader.kernel_name.find("_sned") == std::string::npos; + + utils::uvec3 wg_size = create_conv2d_global_wg_size( + *graph, method, out, weight_data, stride_equals_dilation); + + if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + } + + return wg_size; +} + +// Custom local workgroup size function for conv2d +utils::uvec3 conv2d_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + // Determine method from shader name + Conv2dMethod method; + if (shader.kernel_name.find("conv2d_dw") != std::string::npos) { + method = Conv2dMethod::Depthwise; + } else if ( + shader.kernel_name.find("conv2d_pw") != std::string::npos || + (shader.kernel_name.find("conv2d") != std::string::npos && + shader.kernel_name.find("conv_transpose2d") == std::string::npos)) { + method = Conv2dMethod::Pointwise; + } else { + method = Conv2dMethod::SlidingWindow; + } + + if (method == Conv2dMethod::Pointwise) { + uint32_t local_wg_size_y = 1; + if (global_workgroup_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (global_workgroup_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (global_workgroup_size[1] % 2 == 0) { + local_wg_size_y = 2; + } + return {64 / local_wg_size_y, local_wg_size_y, 1}; + } else if (method == Conv2dMethod::Depthwise) { + return {64, 1, 1}; + } else { + return graph->create_local_wg_size(global_workgroup_size); + } +} + +// Custom global workgroup size function for conv1d +utils::uvec3 conv1d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + return {// out length + graph->size_at(-1, out), + // out channels + static_cast(graph->size_at(-2, out)), + // out batches + utils::div_up_4(graph->size_at(-3, out))}; +} + void add_conv2d_node( ComputeGraph& graph, const ValueRef in, @@ -486,11 +589,11 @@ void add_conv2d_node( }; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - wg_size, - local_wg_size, + conv2d_global_wg_size, + conv2d_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers @@ -560,15 +663,6 @@ void add_conv1d_node( const int32_t out_group_size = static_cast(out_channels / groups_val); - const utils::uvec3 global_size = { - // out length - graph.size_at(-1, out), - // out channels - static_cast(out_channels), - // out batches - utils::div_up_4(graph.size_at(-3, out))}; - const utils::uvec3 local_size = graph.create_local_wg_size(global_size); - Kernel1dParams kernel_params = { kernel_size, stride_size, @@ -587,11 +681,11 @@ void add_conv1d_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + conv1d_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index 27e8c81ba9e..bd648dbae2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -35,11 +36,11 @@ void add_copy_offset_node( auto shader = VK_KERNEL_FROM_STR(kernel_name); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index b5a2f20cf4b..475e7796b09 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -46,11 +47,11 @@ void add_embedding_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, weight}, vkapi::kRead}}, { graph.sizes_ubo(out), diff --git a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp index 6679bfe32f5..52288734704 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -15,6 +16,18 @@ namespace vkcompute { +// Custom global workgroup size function for flip +utils::uvec3 flip_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return graph->create_global_wg_size(out); +} + void check_flip_args( ComputeGraph& graph, const ValueRef in, @@ -59,11 +72,11 @@ void add_flip_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + flip_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index 2fa22312745..fe2676e91e0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -42,11 +43,11 @@ void add_full_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp index 620613fdfb8..5f39c16d405 100644 --- a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -46,11 +47,11 @@ void add_grid_priors_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); const GridPriorsParam param = {stride, offset}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { {out, vkapi::kWrite}, diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp index 86faabd48d5..576711a86f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -38,11 +39,11 @@ void add_index_select_channel_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, {graph.sizes_ubo(out), graph.sizes_ubo(in)}, // Push Constants @@ -92,11 +93,11 @@ void add_index_select_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, {graph.sizes_ubo(out), graph.create_params_buffer(params)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index a58444a7830..7ca31599cdf 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -18,6 +19,70 @@ namespace vkcompute { +// Custom global workgroup size function for addmm_naive_texture +utils::uvec3 addmm_naive_texture_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return graph->logical_limits_of(out); +} + +// Custom global workgroup size function for addmm_naive_buffer +utils::uvec3 addmm_naive_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + +// Custom global workgroup size function for addmm_optimized +utils::uvec3 addmm_optimized_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + + std::vector mat1_sizes = graph->sizes_of(mat1); + int mat1_dims = mat1_sizes.size(); + + utils::uvec3 global_size = graph->logical_limits_of(out); + + if (mat1_sizes.at(mat1_dims - 2) < 8) { + global_size = utils::divup_vec(global_size, {4, 2, 1}); + } else { + global_size = utils::divup_vec(global_size, {4, 4, 1}); + } + return global_size; +} + +// Custom local workgroup size function for addmm_optimized +utils::uvec3 addmm_optimized_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)resize_args; + return adaptive_work_group_size(global_workgroup_size); +} + void check_addmm_args( ComputeGraph& graph, const ValueRef self, @@ -109,11 +174,11 @@ void add_addmm_naive_texture_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); utils::uvec3 global_wg_size = graph.logical_limits_of(out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), + addmm_naive_texture_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -176,11 +241,11 @@ void add_addmm_naive_buffer_node( ? 1 : 0; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - graph.create_local_wg_size(global_size), + addmm_naive_buffer_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -250,31 +315,13 @@ void add_addmm_optimized_node( } else { kernel_name += "_tile_row_4"; } - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - utils::uvec3 global_size = graph.logical_limits_of(out); - - // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the - // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is - // channels packed, C does not need to be divided by 4. The "identity" of each - // thread is the (x, y, z) coordinate of the output tile it is computing, and - // this identity can be used to compute the tensor index of the top left - // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] - if (mat1_sizes.at(mat1_dims - 2) < 8) { - // Use `logical_extents` instead of `image_extents` because the workgroup - // axes need to correspond to tensor dimensions. - global_size = utils::divup_vec(global_size, {4, 2, 1}); - } else { - global_size = utils::divup_vec(global_size, {4, 4, 1}); - } - utils::uvec3 local_size = adaptive_work_group_size(global_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + addmm_optimized_global_wg_size, + addmm_optimized_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed, self}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 99f945da535..8e15b56b208 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -104,11 +105,11 @@ void add_native_layer_norm_node( add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{out_tensor, mean_tensor, rstd_tensor}, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp index a10984eac78..d225af05633 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -76,11 +77,11 @@ void add_constant_pad_nd_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index e74b9ec96a7..b3791a4f7d1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -80,9 +81,6 @@ void add_max_pool2d_node( check_pool2d_args(graph, in, out_tensor); - utils::uvec3 global_size = graph.logical_limits_of(out_tensor); - utils::uvec3 local_size = adaptive_work_group_size(global_size); - std::string kernel_name("max_pool2d"); add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); @@ -94,11 +92,11 @@ void add_max_pool2d_node( padding, dilation); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -154,9 +152,6 @@ void add_avg_pool2d_node( const ValueRef out) { check_pool2d_args(graph, in, out); - utils::uvec3 global_size = graph.logical_limits_of(out); - utils::uvec3 local_size = adaptive_work_group_size(global_size); - std::string kernel_name("avg_pool2d"); add_dtype_suffix(kernel_name, graph.dtype_of(out)); @@ -166,11 +161,11 @@ void add_avg_pool2d_node( DivisorParams divisor_params = create_divisor_params(graph, divisor_override, count_include_pad); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 05a300bee4c..89c9e847724 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -15,6 +16,99 @@ namespace vkcompute { +// Custom global workgroup size function for linear_qcs8w +utils::uvec3 linear_qcs8w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return {static_cast(graph->numel_of(out)), 1, 1}; +} + +// Custom local workgroup size function for linear_qcs8w +utils::uvec3 linear_qcs8w_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + return {64, 1, 1}; +} + +// Custom global workgroup size function for linear_qcsnw_tiled +utils::uvec3 linear_qcsnw_tiled_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + + // Determine quantization bits from shader name + int quant_nbits = 8; + if (shader.kernel_name.find("qcs4w") != std::string::npos) { + quant_nbits = 4; + } + + std::vector mat1_sizes = graph->sizes_of(mat1); + const int64_t M = utils::val_at(-2, mat1_sizes); + uint32_t out_tile_nrows = 4; + if (M % 6 == 0) { + out_tile_nrows = 2; + } else if (M % 4 == 0) { + out_tile_nrows = 4; + } else if (M % 1 == 0) { + out_tile_nrows = 1; + } else { + out_tile_nrows = 4; + } + + // Number of output texels in the output tile + uint32_t out_tile_ntxcols = 1; + if (quant_nbits == 4) { + out_tile_ntxcols = 2; + } + + utils::uvec3 out_limits = graph->logical_limits_of(out); + uint32_t global_wg_x = utils::div_up(out_limits[0], out_tile_ntxcols); + return { + global_wg_x * (utils::div_up(out_limits[1], out_tile_nrows)), + 1, + out_limits[2]}; +} + +// Custom local workgroup size function for linear_qcsnw_tiled +utils::uvec3 linear_qcsnw_tiled_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + // Check if using cooperative algorithm from shader name + bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + if (use_coop_algorithm) { + return {8, 1, 8}; + } else { + return {64, 1, 1}; + } +} + void check_linear_qcsnw_args( const ComputeGraph& graph, const int quant_nbits, @@ -138,11 +232,11 @@ void add_linear_qcs8w_node( static_cast(graph.numel_of(out_W_packed)), 1, 1}; const utils::uvec3 local_wg{64, 1, 1}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg, - local_wg, + linear_qcs8w_global_wg_size, + linear_qcs8w_local_wg_size, // Inputs and Outputs {{out_W_packed, vkapi::MemoryAccessType::WRITE}, {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, @@ -247,11 +341,11 @@ void add_linear_qcsnw_tiled_node( local_wg_size = {8, 1, 8}; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + linear_qcsnw_tiled_global_wg_size, + linear_qcsnw_tiled_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index d7a2b7a8ca2..72c1637a2c9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -92,15 +93,15 @@ void add_repeat_node( const auto shader = VK_KERNEL_FROM_STR(kernel_name); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - wg_size, - graph.create_local_wg_size(wg_size), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs { - {out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}, + {out, vkapi::kWrite}, + {in, vkapi::kRead}, }, // Parameter buffers {}, diff --git a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp index ae2aeec10bf..221d0d23f51 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -49,16 +50,11 @@ void add_repeat_interleave_node( std::string kernel_name = "repeat_interleave"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.logical_limits_of(in); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp index 307f774de5e..687b3923354 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -35,11 +36,11 @@ void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) { vkapi::ParamsBindList ubos({}); ubos.append({graph.logical_limits_ubo(out)}); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp index ed9fef61a78..6662ae367c5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -114,11 +115,11 @@ void add_upsample_nearest2d_node( } add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Var.cpp b/backends/vulkan/runtime/graph/ops/impl/Var.cpp index 106a6fd6d9a..d8fd367f18a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Var.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Var.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include @@ -14,6 +15,93 @@ namespace vkcompute { using namespace utils; +// Custom global workgroup size function for var_buffer +utils::uvec3 var_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + +// Custom local workgroup size function for var_buffer +utils::uvec3 var_buffer_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)global_workgroup_size; + const ValueRef in = args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + const uint32_t nworkers_per_group = 4; + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + return local_wg_size; +} + +// Custom global workgroup size function for var_texture +utils::uvec3 var_texture_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[reduce_dim] = 1; + return global_wg_size; +} + +// Custom local workgroup size function for var_texture +utils::uvec3 var_texture_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef in = args.at(1).refs.at(0); + const int dim = resize_args.at(0); + + const int64_t ndim = graph->dim_of(in); + int32_t reduce_dim = normalize(dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + const int other_dim_1 = (reduce_dim + 1) % 3; + const int other_dim_2 = (reduce_dim + 2) % 3; + if (global_workgroup_size[other_dim_1] > global_workgroup_size[other_dim_2]) { + local_wg_size[other_dim_1] = ngroups; + } else { + local_wg_size[other_dim_2] = ngroups; + } + return local_wg_size; +} + void resize_var_node( ComputeGraph* graph, const std::vector& args, @@ -68,11 +156,11 @@ void add_var_buffer_node( int32_t unbiased_int = static_cast(unbiased); push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + var_buffer_global_wg_size, + var_buffer_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -143,12 +231,11 @@ void add_var_texture_node( int32_t unbiased_int = static_cast(unbiased); push_constants.emplace_back(&unbiased_int, sizeof(unbiased_int)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + var_texture_global_wg_size, + var_texture_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp index 1868d3b872e..c1c482d9967 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Where.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -10,6 +10,7 @@ #include +#include #include namespace vkcompute { @@ -37,16 +38,11 @@ void add_where_texture_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, // Parameter buffers @@ -72,9 +68,6 @@ void add_where_buffer_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - vkapi::ParamsBindList ubos = { graph.numel_ubo(out), graph.strides_ubo(out), @@ -82,13 +75,11 @@ void add_where_buffer_node( graph.strides_ubo(self), graph.strides_ubo(other)}; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, // Parameter buffers