From a0aa66344890c9a54e1b00185f5fb3d5c628df0c Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 21 Feb 2025 21:35:37 -0800 Subject: [PATCH] [EK-VT] Replacing the use of uvec3 with WorkgroupSize class to reduce memory usage and improve processing speed This diff replaces the use of `uvec3` with `WorkgroupSize` class to reduce memory usage and improve processing speed in the Vulkan backend of Executorch. Differential Revision: [D70021032](https://our.internmc.facebook.com/intern/diff/D70021032/) [ghstack-poisoned] --- backends/vulkan/runtime/api/Context.cpp | 16 +++++----------- backends/vulkan/runtime/api/Context.h | 14 +++++++++----- backends/vulkan/runtime/graph/ops/BlitNode.cpp | 2 +- .../vulkan/runtime/graph/ops/DispatchNode.h | 2 +- .../vulkan/runtime/graph/ops/PrepackNode.cpp | 4 ++-- .../vulkan/runtime/graph/ops/PrepackNode.h | 2 +- backends/vulkan/runtime/vk_api/Command.cpp | 2 +- backends/vulkan/runtime/vk_api/Command.h | 6 +++--- backends/vulkan/runtime/vk_api/Pipeline.cpp | 18 +++++++++++++----- backends/vulkan/runtime/vk_api/Pipeline.h | 4 ++++ backends/vulkan/runtime/vk_api/Shader.h | 2 +- 11 files changed, 41 insertions(+), 31 deletions(-) diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 8178ada3a45..64f32e50f4e 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -74,7 +74,7 @@ void Context::cmd_reset_querypool() { void Context::report_shader_dispatch_start( const std::string& shader_name, const utils::uvec3& global_wg_size, - const utils::uvec3& local_wg_size, + const utils::WorkgroupSize& local_wg_size, const uint32_t dispatch_id) { if (querypool_) { querypool_.shader_profile_begin( @@ -82,7 +82,7 @@ void Context::report_shader_dispatch_start( dispatch_id, shader_name, vkapi::create_extent3d(global_wg_size), - vkapi::create_extent3d(local_wg_size)); + vkapi::create_extent3d((utils::uvec3)local_wg_size)); } } @@ -115,7 +115,7 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { vkapi::DescriptorSet Context::get_descriptor_set( const vkapi::ShaderInfo& shader_descriptor, - const utils::uvec3& local_workgroup_size, + const utils::WorkgroupSize& local_workgroup_size, const vkapi::SpecVarList& additional_constants, const uint32_t push_constants_size) { VkDescriptorSetLayout shader_layout = @@ -124,17 +124,11 @@ vkapi::DescriptorSet Context::get_descriptor_set( VkPipelineLayout pipeline_layout = pipeline_layout_cache().retrieve(shader_layout, push_constants_size); - vkapi::SpecVarList spec_constants = { - SV(local_workgroup_size[0u]), - SV(local_workgroup_size[1u]), - SV(local_workgroup_size[2u])}; - - spec_constants.append(additional_constants); - VkPipeline pipeline = pipeline_cache().retrieve( {pipeline_layout_cache().retrieve(shader_layout, push_constants_size), shader_cache().retrieve(shader_descriptor), - spec_constants}); + additional_constants, + local_workgroup_size}); cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size); diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 8bbcf79b45c..6cfbc64f141 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -11,6 +11,7 @@ // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName #include +#include #include #include @@ -150,7 +151,7 @@ class Context final { void report_shader_dispatch_start( const std::string& shader_name, const utils::uvec3& global_wg_size, - const utils::uvec3& local_wg_size, + const utils::WorkgroupSize& local_wg_size, const uint32_t dispatch_id = UINT32_MAX); /* @@ -189,13 +190,13 @@ class Context final { vkapi::DescriptorSet get_descriptor_set( const vkapi::ShaderInfo&, - const utils::uvec3&, + const utils::WorkgroupSize&, const vkapi::SpecVarList&, const uint32_t push_constants_size); inline vkapi::DescriptorSet get_descriptor_set( const vkapi::ShaderInfo& shader_descriptor, - const utils::uvec3& local_work_group_size) { + const utils::WorkgroupSize& local_work_group_size) { return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u); } @@ -362,14 +363,17 @@ inline bool Context::submit_compute_job( report_shader_dispatch_start( shader.kernel_name, global_work_group, - local_work_group_size, + utils::WorkgroupSize(local_work_group_size), dispatch_id); // Factor out template parameter independent code to minimize code bloat. // Note that push constants are not exposed yet via this API, therefore the // push constants size is assumed to be 0. vkapi::DescriptorSet descriptor_set = get_descriptor_set( - shader, local_work_group_size, specialization_constants, 0u); + shader, + utils::WorkgroupSize(local_work_group_size), + specialization_constants, + 0u); detail::bind( descriptor_set, diff --git a/backends/vulkan/runtime/graph/ops/BlitNode.cpp b/backends/vulkan/runtime/graph/ops/BlitNode.cpp index 463a2d19c36..03ee4caa51a 100644 --- a/backends/vulkan/runtime/graph/ops/BlitNode.cpp +++ b/backends/vulkan/runtime/graph/ops/BlitNode.cpp @@ -46,7 +46,7 @@ void BlitNode::encode(ComputeGraph* graph) { kernel_name += vkapi::to_string(dst_tensor->dtype()); context->report_shader_dispatch_start( - kernel_name, utils::uvec3(), utils::uvec3(), node_id_); + kernel_name, utils::uvec3(), utils::WorkgroupSize(), node_id_); context->register_blit( pipeline_barrier, diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index 7d04f7714e9..4661b5bf9cf 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -92,7 +92,7 @@ class DispatchNode final : public ExecuteNode { protected: const vkapi::ShaderInfo shader_; const utils::uvec3 global_workgroup_size_; - const utils::uvec3 local_workgroup_size_; + const utils::WorkgroupSize local_workgroup_size_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index bf501296b1b..0507b679e13 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -100,8 +100,8 @@ void PrepackNode::encode(ComputeGraph* graph) { // bound with the correct image layout. { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::DescriptorSet descriptor_set = - context->get_descriptor_set(noop_shader_, {1, 1, 1}); + vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( + noop_shader_, utils::WorkgroupSize(1, 1, 1)); bind_tensor_to_descriptor_set( *packed, diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 3e713303c3d..2d194e7f6a0 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -49,7 +49,7 @@ class PrepackNode final { const vkapi::ShaderInfo shader_; vkapi::ShaderInfo noop_shader_; const utils::uvec3 global_workgroup_size_; - const utils::uvec3 local_workgroup_size_; + const utils::WorkgroupSize local_workgroup_size_; const ValueRef tref_; const ValueRef packed_; const vkapi::ParamsBindList params_; diff --git a/backends/vulkan/runtime/vk_api/Command.cpp b/backends/vulkan/runtime/vk_api/Command.cpp index 3be790b53cf..3a5041f9500 100644 --- a/backends/vulkan/runtime/vk_api/Command.cpp +++ b/backends/vulkan/runtime/vk_api/Command.cpp @@ -81,7 +81,7 @@ void CommandBuffer::end() { void CommandBuffer::bind_pipeline( VkPipeline pipeline, VkPipelineLayout pipeline_layout, - const utils::uvec3 local_workgroup_size) { + const utils::WorkgroupSize local_workgroup_size) { VK_CHECK_COND( state_ == CommandBuffer::State::RECORDING, "Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state " diff --git a/backends/vulkan/runtime/vk_api/Command.h b/backends/vulkan/runtime/vk_api/Command.h index 99cd5d17c99..ff1e5934a5c 100644 --- a/backends/vulkan/runtime/vk_api/Command.h +++ b/backends/vulkan/runtime/vk_api/Command.h @@ -51,7 +51,7 @@ class CommandBuffer final { struct Bound { VkPipeline pipeline; VkPipelineLayout pipeline_layout; - utils::uvec3 local_workgroup_size; + utils::WorkgroupSize local_workgroup_size; VkDescriptorSet descriptors; explicit Bound() @@ -63,7 +63,7 @@ class CommandBuffer final { inline void reset() { pipeline = VK_NULL_HANDLE; pipeline_layout = VK_NULL_HANDLE; - local_workgroup_size = {0u, 0u, 0u}; + local_workgroup_size = utils::WorkgroupSize{0u, 0u, 0u}; descriptors = VK_NULL_HANDLE; } }; @@ -87,7 +87,7 @@ class CommandBuffer final { void begin(); void end(); - void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3); + void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::WorkgroupSize); void bind_descriptors(VkDescriptorSet); void set_push_constants(VkPipelineLayout, const void*, uint32_t); diff --git a/backends/vulkan/runtime/vk_api/Pipeline.cpp b/backends/vulkan/runtime/vk_api/Pipeline.cpp index 51b59ed4d1f..0de177c6dd5 100644 --- a/backends/vulkan/runtime/vk_api/Pipeline.cpp +++ b/backends/vulkan/runtime/vk_api/Pipeline.cpp @@ -275,14 +275,22 @@ ComputePipeline::ComputePipeline( const ComputePipeline::Descriptor& descriptor, VkPipelineCache pipeline_cache) : device_(device), handle_{VK_NULL_HANDLE} { - std::vector map_entries = - descriptor.specialization_constants.generate_map_entries(); + SpecVarList specialization_constants; + + specialization_constants.reserve(3 + descriptor.specialization_constants.size()); + specialization_constants.append(descriptor.local_wg_size[0]); + specialization_constants.append(descriptor.local_wg_size[1]); + specialization_constants.append(descriptor.local_wg_size[2]); + + specialization_constants.append(descriptor.specialization_constants); + const std::vector map_entries = + specialization_constants.generate_map_entries(); const VkSpecializationInfo specialization_info{ - descriptor.specialization_constants.size(), // mapEntryCount + specialization_constants.size(), // mapEntryCount map_entries.data(), // pMapEntries - descriptor.specialization_constants.data_nbytes(), // dataSize - descriptor.specialization_constants.data(), // pData + specialization_constants.data_nbytes(), // dataSize + specialization_constants.data(), // pData }; const VkPipelineShaderStageCreateInfo shader_stage_create_info{ diff --git a/backends/vulkan/runtime/vk_api/Pipeline.h b/backends/vulkan/runtime/vk_api/Pipeline.h index b9f4e3d2a35..3248051d12a 100644 --- a/backends/vulkan/runtime/vk_api/Pipeline.h +++ b/backends/vulkan/runtime/vk_api/Pipeline.h @@ -156,6 +156,7 @@ class ComputePipeline final { VkPipelineLayout pipeline_layout; VkShaderModule shader_module; SpecVarList specialization_constants; + utils::WorkgroupSize local_wg_size; }; explicit ComputePipeline( @@ -273,6 +274,9 @@ class ComputePipelineCache final { seed = utils::hash_combine(seed, new_seed); } + seed = utils::hash_combine( + seed, std::hash()((uint32_t)descriptor.local_wg_size)); + return seed; } }; diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index d9fec65febc..7d0fa7b7476 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -61,7 +61,7 @@ struct ShaderInfo final { ShaderLayout::Signature kernel_layout{}; // Shader Metadata - utils::uvec3 out_tile_size{1u, 1u, 1u}; + utils::WorkgroupSize out_tile_size{1u, 1u, 1u}; bool requires_shader_int16 = false; bool requires_16bit_storage = false; bool requires_8bit_storage = false;