From df798288fb46591013adba193d0cd4c96ff9f023 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 7 Aug 2025 09:05:01 -0700 Subject: [PATCH] [ET-VK] Better work group sizes for matmul ## Context Currently `default_pick_local_wg_size()` (which internally calls `ComputeGraph::create_local_wg_size`) is used to select the local work group size for matrix multiplication ops. However, these functions currently bias the size of the local work group towards the largest dim of the global work group producing local wg sizes like ``` shader globalwg size localwg size =========== ===================== ==================== ============= linear_qga4w_tiled_texture3d_texture3d_texture2d_float {256, 29, 1} {32, 2, 1} 1487 matmul_naive_texture3d_float {29, 115, 32} {4, 2, 8} 712 ``` for matrix multiplication shaders. This behaviour was introduced in D64418632 / https://github.com/pytorch/executorch/pull/6409. However, through experimental testing a "square" work group size of `{8, 8, 1}` works a lot better for matrix multiplication shaders. The theoretical analysis for this behaviour is that the local work group size determines the memory locations that need to be loaded to compute the overall work group. For a work group with size `{W, H, 1}` the data required to compute the output would be `W * OUTPUT_TILE_W` columns of the weight tensor and `H * OUTPUT_TILE_H` rows of the input tensor. Note that all work group items in the same W index will be requesting the same columns from the weight tensor, and all work group items in the same H index will be requesting the same rows from the input tensor. If `H==W`, then that "balances" the amount of data needed to loaded from each input tensor and may result in better data sharing behaviour among all work group items. Assuming `OUTPUT_TILE_W == OUTPUT_TILE_H == 1`, a local work group of size `{64, 1, 1}` would require 1 unique row from the input tensor an 64 unique columns to be loaded from the weight tensor, resulting in `(1 + 64) * K = 65K` elements to be loaded in total, where K is the size of the shared reduction dim. Conversely, a local work group of size `{8, 8, 1}` would require 8 unique rows / 8 unique columns resulting in only `(8 + 8) * K = 16K` unique elements to be loaded. This highlights the need to use dedicated logic to compute work group sizes for matrix multiplication shaders. ## Changes * Introduce `pick_hw_square_wg_size` * Use the new local work group size determination function for Quantized Linear, Matmul, and Linear Differential Revision: [D79813236](https://our.internmc.facebook.com/intern/diff/D79813236/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/impl/Common.cpp | 22 +++++++++++++++++++ .../vulkan/runtime/graph/ops/impl/Common.h | 18 +++++++++++++++ .../vulkan/runtime/graph/ops/impl/Linear.cpp | 4 ++-- .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 6 ++--- .../graph/ops/impl/QuantizedLinearQGANW.cpp | 3 ++- 5 files changed, 47 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 4c3c16417b5..a52d2fafdf1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -33,4 +33,26 @@ utils::uvec3 default_pick_local_wg_size( return graph->create_local_wg_size(global_workgroup_size); } +utils::uvec3 pick_hw_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)resize_args; + // Some inactive invocations are okay; set 6 as the threshold to use the + // a square wg size. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[1u] >= 6) { + return {8u, 8u, 1u}; + } + // If width dim is sufficiently small, then bias towards height dim to reduce + // the number of inactive invocations. + else if (global_workgroup_size[0u] < 6u) { + return {4u, 16u, 1u}; + } + return {16u, 4u, 1u}; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index 662fb07095a..1831ab2a845 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -36,4 +36,22 @@ utils::uvec3 default_pick_local_wg_size( const std::vector& args, const std::vector& resize_args); +/** + * Constructs a local work group size with the shape {W, H, 1}. The function + * will try to set W == H == sqrt(num_invocations), where num_invocations is + * typically 64. This configuration is good for ops like matrix multiplication + * as it reduces the total volume of unique data that the entire work group + * will need to read from input tensors in order to produce the output data. + * To compute an output tile of {W, H, 1}, the work group will need to read + * H unique rows = H * K unique elements from the input tensor and W unique cols + * = W * K elements from the weight tensor, resulting in (W + H) * K unique + * elements in total. + */ +utils::uvec3 pick_hw_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 7ca31599cdf..38d70271f4f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -178,7 +178,7 @@ void add_addmm_naive_texture_node( graph, VK_KERNEL_FROM_STR(kernel_name), addmm_naive_texture_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -245,7 +245,7 @@ void add_addmm_naive_buffer_node( graph, VK_KERNEL_FROM_STR(kernel_name), addmm_naive_buffer_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 0f5556060a2..47ecf5f18d2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -102,7 +102,7 @@ void add_matmul_naive_buffer_node( graph, VK_KERNEL_FROM_STR(kernel_name), matmul_naive_buffer_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -158,7 +158,7 @@ void add_matmul_naive_texture3d_node( graph, pick_matmul_naive_texture3d_shader, default_pick_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -273,7 +273,7 @@ void add_matmul_optimized_node( graph, pick_matmul_optimized_shader, matmul_optimized_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index 8c7c6b0cdf9..52cf75e28b5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -158,7 +158,8 @@ utils::uvec3 linear_qga4w_local_wg_size( if (use_coop_algorithm) { return {64, 1, 1}; } else { - return graph->create_local_wg_size(global_workgroup_size); + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } }