From 46216e003e37ce1a609cf1b632cfa977234a9d4c Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 15 Oct 2025 09:07:30 -0700 Subject: [PATCH] [ET-VK][ez] Introduce a graph config setting to force resize functions to execute Title says it all! A few months back, a mechanism was introduced where an `ExecuteNode` would not call an operator's resize function if none of the arguments were updated. However, this creates a blind spot during testing where the resize function of operators are not tested since the generated operator tests do not modify input sizes. To address this, add a way to force the resize function to be called during testing. Differential Revision: [D84716451](https://our.internmc.facebook.com/intern/diff/D84716451/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/GraphConfig.cpp | 1 + backends/vulkan/runtime/graph/GraphConfig.h | 3 +++ backends/vulkan/runtime/graph/ops/ExecuteNode.cpp | 5 +++-- backends/vulkan/test/op_tests/utils/gen_correctness_vk.py | 1 + 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index da5efbf8342..20b8f6f7c00 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -65,6 +65,7 @@ GraphConfig::GraphConfig() { local_wg_size_override = {}; expect_dynamic_shapes = false; + force_resize = false; external_adapter = nullptr; } diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index aa5cd8f8c4e..7533df3b685 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -35,6 +35,9 @@ struct GraphConfig final { // Whether or not the ComputeGraph should expect input shapes to be dynamic bool expect_dynamic_shapes; + // Used for testing/debugging only. Forces ExecuteNode to trigger the resize + // function even if none of the inputs have been updated. + bool force_resize = false; // Execution properties that determine specifics re: how command buffer // submission is handled, etc. 0 means this field is not set. diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index 953f15e7b4d..aa46ee76336 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -21,9 +21,10 @@ ExecuteNode::ExecuteNode( name_(name) {} bool ExecuteNode::trigger_resize(ComputeGraph* graph) { - const bool any_arg_updated = was_any_arg_updated(graph); - if (resize_fn_ && any_arg_updated) { + bool any_arg_updated = was_any_arg_updated(graph); + if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) { resize_fn_(graph, args_, resize_args_); + any_arg_updated = true; } return any_arg_updated; } diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index c368c23c539..08bc502f964 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); config.set_storage_type_override(default_storage_type); config.set_memory_layout_override(default_memory_layout); + config.force_resize = true; graph = new ComputeGraph(config); if (test_dtype == at::kHalf) {{