From 1e1df9e9a3b14dfa5de6480c9f90e8eacd753af2 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 7 Aug 2025 09:04:53 -0700 Subject: [PATCH] [ET-VK] Add mechanism to trigger command buffer re-encode only when necessary ## Context Dynamic shape models currently will require the command buffer to be re-encoded every inference. However, this introduces a significant overhead when running models that require dynamic shapes. The reality is that a command buffer re-encode may not be needed every frame. A command buffer re-encode will only be needed when: 1. Shader dispatch parameters change; i.e. new tensor sizes require a completely different compute shader, require new local work group sizing, or require new work group grid size (i.e. global work group size / local work group size) 2. Push constants containing tensor metadata need to be updated This diff aims to reduce the overhead of triggering tensor shape change by detecting when a command buffer re-encode is actually needed. ## Changes `ComputeGraph`: * Introduce `requires_reencode` flag to `ComputeGraph` to indicate when a command buffer re-encode is needed. 
* Introduce a `std::unordered_set` tracking which values were updated when propagating tensor sizes * "update" can be one of two things: 1) tensor sizes changed 2) symint value changed `DispatchNode`: * When propagating new tensor sizes, only execute the resize function if any of the values participating in the computation have been updated * Mark `requires_reencode` if any push constants associated with tensor metadata need to be updated `DynamicDispatchNode`: * Only recompute compute shader dispatch params if any of the values participating in the computation have been updated * Mark `requires_reencode` if 1) a new compute shader is required, 2) local work group size changed, 3) work group grid size changed Differential Revision: [D79813237](https://our.internmc.facebook.com/intern/diff/D79813237/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 8 +- .../vulkan/runtime/graph/ComputeGraph.cpp | 68 +++++++++++++- backends/vulkan/runtime/graph/ComputeGraph.h | 39 +++++++- .../graph/containers/PushConstantData.h | 17 ++++ .../vulkan/runtime/graph/ops/DispatchNode.cpp | 38 ++++++++ .../vulkan/runtime/graph/ops/DispatchNode.h | 6 ++ .../runtime/graph/ops/DynamicDispatchNode.cpp | 92 ++++++++++++++++--- .../runtime/graph/ops/DynamicDispatchNode.h | 4 +- .../vulkan/runtime/graph/ops/ExecuteNode.h | 4 +- 9 files changed, 245 insertions(+), 31 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index ceb95f3a304..b09997e477c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -582,13 +582,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } - // propagate_resize() will re-encode the command buffer so that push - // constants are updated and DynamicDispatchNode can update the compute - // shader, global workgroup size, and local workgroup size to perform the - // model inference. 
- if (should_propagate_resize || - (compute_graph->graphconfig().expect_dynamic_shapes && - compute_graph->execute_count() == 0u)) { + if (should_propagate_resize) { compute_graph->propagate_resize(); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 7bc00e128e5..6b6e9dc2da7 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -206,6 +206,44 @@ utils::StorageType ComputeGraph::suggested_storage_type() { return utils::kTexture3D; } +bool ComputeGraph::was_value_updated(const ValueRef value_ref) const { + // Check if this ValueRef itself was updated + if (updated_values_.find(value_ref) != updated_values_.end()) { + return true; + } + + // If this is a ValueList, check each ValueRef in the list + if (val_is_value_list(value_ref)) { + const auto& value_list = values_.at(value_ref).toConstValueList(); + for (const auto& nested_value_ref : value_list) { + if (was_value_updated(nested_value_ref)) { + return true; + } + } + } + + return false; +} + +bool ComputeGraph::was_value_ref_updated(const ValueRef value_ref) const { + // Check if this ValueRef itself was updated + if (updated_values_.find(value_ref) != updated_values_.end()) { + return true; + } + + // If this is a ValueList, check each ValueRef in the list + if (val_is_value_list(value_ref)) { + const auto& value_list = values_.at(value_ref).toConstValueList(); + for (const auto& nested_value_ref : value_list) { + if (was_value_ref_updated(nested_value_ref)) { + return true; + } + } + } + + return false; +} + utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout( const std::vector& sizes) { if (config_.enable_memory_layout_override) { @@ -569,7 +607,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( } void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { - get_symint(idx)->set(val); + int32_t cur_val = read_symint(idx); + if (cur_val != 
val) { + get_symint(idx)->set(val); + // Track that this ValueRef was updated + updated_values_.insert(idx); + } } int32_t ComputeGraph::read_symint(const ValueRef idx) { @@ -921,6 +964,12 @@ void ComputeGraph::execute() { } execute_count_++; + + // Clear the set of updated values at the end of inference + updated_values_.clear(); + + // Reset the re-encoding flag at the end of inference + requires_reencode_ = false; } void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) { @@ -938,21 +987,30 @@ void ComputeGraph::resize_input( const int64_t idx, const std::vector& new_sizes) { IOValueRef io_val = inputs_.at(idx); - get_tensor(io_val.value)->virtual_resize(new_sizes); + virtual_resize(io_val.value, new_sizes); + updated_values_.insert(io_val.staging); } void ComputeGraph::virtual_resize( const ValueRef idx, const std::vector& new_sizes) { - get_tensor(idx)->virtual_resize(new_sizes); + std::vector cur_sizes = sizes_of(idx); + if (cur_sizes != new_sizes) { + get_tensor(idx)->virtual_resize(new_sizes); + // Track that this ValueRef was updated + updated_values_.insert(idx); + } } void ComputeGraph::propagate_resize() { for (std::unique_ptr& node : execute_nodes_) { node->trigger_resize(this); } - // Only re-encode on resize if dynamic shapes are expected - if (config_.expect_dynamic_shapes) { + // A command buffer re-encode will be needed if: + // 1. Any push constant data (used for tensor metadata) was updated + // 2. Compute shader dispatch parameters (i.e. 
compute shader, global and + // local work group sizes) were updated + if (requires_reencode_) { clear_deferred_cmds(); } } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 34b14250314..6664155eed8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -196,6 +196,12 @@ class ComputeGraph final { // List of command buffers deferred for submission std::vector deferred_cmd_list_; + // Set to track which ValueRefs were updated during inference + std::unordered_set updated_values_; + + // Flag to indicate if re-encoding is required + bool requires_reencode_ = false; + protected: size_t values_in_use_ = 0; size_t execute_count_ = 0; @@ -419,31 +425,41 @@ class ComputeGraph final { } inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const { - return PushConstantDataInfo( + PushConstantDataInfo pc_data = PushConstantDataInfo( values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes); + pc_data.set_value(idx); + return pc_data; } inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const { - return PushConstantDataInfo( + PushConstantDataInfo pc_data = PushConstantDataInfo( values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorDimOrder); + pc_data.set_value(idx); + return pc_data; } inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const { - return PushConstantDataInfo( + PushConstantDataInfo pc_data = PushConstantDataInfo( values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorStrides); + pc_data.set_value(idx); + return pc_data; } inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const { - return PushConstantDataInfo( + PushConstantDataInfo pc_data = PushConstantDataInfo( values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorLogicalLimits); + pc_data.set_value(idx); + return pc_data; } inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const { - return 
PushConstantDataInfo( + PushConstantDataInfo pc_data = PushConstantDataInfo( values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel); + pc_data.set_value(idx); + return pc_data; } // @@ -940,6 +956,19 @@ class ComputeGraph final { void propagate_resize(); + // Check if a specific ValueRef (or ValueList) was updated, with recursive + // handling + bool was_value_updated(const ValueRef value_ref) const; + + // Check if a specific ValueRef (or ValueList) was updated, with recursive + // handling + bool was_value_ref_updated(const ValueRef value_ref) const; + + // Set the flag to indicate that re-encoding is required + inline void set_requires_reencode() { + requires_reencode_ = true; + } + // // Miscellaneous Utilities // diff --git a/backends/vulkan/runtime/graph/containers/PushConstantData.h b/backends/vulkan/runtime/graph/containers/PushConstantData.h index 39cde4722a7..c5185eafa25 100644 --- a/backends/vulkan/runtime/graph/containers/PushConstantData.h +++ b/backends/vulkan/runtime/graph/containers/PushConstantData.h @@ -10,6 +10,8 @@ #include +#include + namespace vkcompute { class ComputeGraph; @@ -33,6 +35,9 @@ class PushConstantDataInfo { }; Payload payload_; + // The value in a compute graph that this push constant data is associated + // with, if any. 
+ ValueRef value_ = kDummyValueRef; public: explicit PushConstantDataInfo( @@ -60,6 +65,18 @@ class PushConstantDataInfo { void* dst, const uint32_t dst_offset, const uint32_t max_dst_size) const; + + inline bool is_tensor_metadata() const { + return tensorUniformData != nullptr; + } + + inline void set_value(ValueRef value) { + value_ = value; + } + + inline ValueRef value() const { + return value_; + } }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index b5644cf3dcd..8a2b7904c17 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -89,4 +89,42 @@ void DispatchNode::write_push_constant_data() { } } +bool DispatchNode::trigger_resize(ComputeGraph* graph) { + bool any_value_updated = was_any_value_updated(graph); + if (resize_fn_ != nullptr && any_value_updated) { + resize_fn_(graph, args_, resize_args_); + + // If this shader uses push constants, and the tensor metadata associated + // with the push constants has changed, then the command buffer needs to be + // re-encoded since push constants cannot be updated. 
+ for (const auto& push_constant : push_constants_) { + if (push_constant.is_tensor_metadata() && + graph->was_value_ref_updated(push_constant.value())) { + graph->set_requires_reencode(); + } + } + } + return any_value_updated; +} + +bool DispatchNode::was_any_value_updated(ComputeGraph* graph) const { + // Check all ValueRefs in ArgGroups + for (const auto& arg_group : args_) { + for (const auto& value_ref : arg_group.refs) { + if (graph->was_value_ref_updated(value_ref)) { + return true; + } + } + } + + // Check all ValueRefs in resize_args + for (const auto& value_ref : resize_args_) { + if (graph->was_value_ref_updated(value_ref)) { + return true; + } + } + + return false; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index b6eb8624c26..08871c89441 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -44,6 +44,12 @@ class DispatchNode : public ExecuteNode { void encode(ComputeGraph* graph) override; + bool trigger_resize(ComputeGraph* graph) override; + + private: + // Helper function to check if any ValueRef was updated + bool was_any_value_updated(ComputeGraph* graph) const; + protected: vkapi::ShaderInfo shader_; utils::uvec3 global_workgroup_size_; diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index ea2061d3d7c..9f6562169b9 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -41,6 +41,12 @@ DynamicDispatchNode::DynamicDispatchNode( pick_global_wg_fn(&graph, shader_, args, resize_args); local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( &graph, shader_, global_workgroup_size_, args, resize_args)); + + // Calculate dispatch grid similar to Context.cpp register_shader_dispatch + wg_dispatch_grid_ = { + 
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]), + utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]), + utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])}; } DynamicDispatchNode::DynamicDispatchNode( @@ -72,21 +78,83 @@ DynamicDispatchNode::DynamicDispatchNode( pick_global_wg_fn(&graph, shader_, args, resize_args); local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( &graph, shader_, global_workgroup_size_, args, resize_args)); + // Calculate the work group grid that will be dispatched + wg_dispatch_grid_ = { + utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]), + utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]), + utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])}; } -void DynamicDispatchNode::encode(ComputeGraph* graph) { - if (pick_shader_fn_) { - shader_ = pick_shader_fn_(graph, args_, resize_args_); - } - if (pick_global_wg_fn_) { - global_workgroup_size_ = - pick_global_wg_fn_(graph, shader_, args_, resize_args_); - } - if (pick_local_wg_fn_) { - local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_( - graph, shader_, global_workgroup_size_, args_, resize_args_)); +bool DynamicDispatchNode::trigger_resize(ComputeGraph* graph) { + // DispatchNode::trigger_resize() will return true if any of the values + // participating in this operation were updated. + bool any_value_updated = DispatchNode::trigger_resize(graph); + // Indicates if the shader dispatch will have changed since the last time the + // command buffer was encoded. + bool dispatch_params_changed = false; + + // Only re-compute the shader, global workgroup size, and local workgroup size + // if any of the values participating in this operation were updated. + // Otherwise, assume that these will not have changed. 
+ if (any_value_updated) { + if (pick_shader_fn_) { + vkapi::ShaderInfo new_shader = + pick_shader_fn_(graph, args_, resize_args_); + // Compare shader kernel names as a proxy for shader equality + if (shader_.kernel_name != new_shader.kernel_name) { + shader_ = new_shader; + dispatch_params_changed = true; + } + } + if (pick_global_wg_fn_) { + utils::uvec3 new_global_wg = + pick_global_wg_fn_(graph, shader_, args_, resize_args_); + if (global_workgroup_size_[0] != new_global_wg[0] || + global_workgroup_size_[1] != new_global_wg[1] || + global_workgroup_size_[2] != new_global_wg[2]) { + global_workgroup_size_ = new_global_wg; + // Note that if global workgroup size changes, then the dispatch params + // may not actually be different. The actual value to check is the + // work group grid size that will be dispatched, which is calculated + // below. + } + } + if (pick_local_wg_fn_) { + utils::uvec3 new_local_wg_uvec3 = pick_local_wg_fn_( + graph, shader_, global_workgroup_size_, args_, resize_args_); + utils::WorkgroupSize new_local_wg = + utils::WorkgroupSize(new_local_wg_uvec3); + if (local_workgroup_size_[0] != new_local_wg[0] || + local_workgroup_size_[1] != new_local_wg[1] || + local_workgroup_size_[2] != new_local_wg[2]) { + local_workgroup_size_ = new_local_wg; + dispatch_params_changed = true; + } + } + + // Always recompute the new dispatch grid and check if it's different + utils::uvec3 new_wg_dispatch_grid = { + utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]), + utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]), + utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])}; + + // Check if the new dispatch grid is different from the old one + if (wg_dispatch_grid_[0] != new_wg_dispatch_grid[0] || + wg_dispatch_grid_[1] != new_wg_dispatch_grid[1] || + wg_dispatch_grid_[2] != new_wg_dispatch_grid[2]) { + dispatch_params_changed = true; + } + // Update the dispatch grid + wg_dispatch_grid_ = new_wg_dispatch_grid; + 
+ // If any of the dispatch params have changed, then the command buffer must + // be re-encoded. + if (dispatch_params_changed) { + graph->set_requires_reencode(); + } } - DispatchNode::encode(graph); + + return any_value_updated; } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h index 005151272c3..16777b5a92f 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -68,13 +68,15 @@ class DynamicDispatchNode final : public DispatchNode { ~DynamicDispatchNode() override = default; - void encode(ComputeGraph* graph) override; + bool trigger_resize(ComputeGraph* graph) override; protected: const PickShaderFn pick_shader_fn_; const PickGlobalFn pick_global_wg_fn_; const PickLocalFn pick_local_wg_fn_; + utils::uvec3 wg_dispatch_grid_; + public: operator bool() const { return shader_; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 6a815b246ef..c7a6b8f70a4 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -69,10 +69,12 @@ class ExecuteNode { (void)graph; } - virtual inline void trigger_resize(ComputeGraph* graph) { + virtual inline bool trigger_resize(ComputeGraph* graph) { if (resize_fn_ != nullptr) { resize_fn_(graph, args_, resize_args_); + return true; } + return false; } inline void set_node_id(uint32_t node_id) {