diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index da5efbf8342..20b8f6f7c00 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -65,6 +65,7 @@ GraphConfig::GraphConfig() { local_wg_size_override = {}; expect_dynamic_shapes = false; + force_resize = false; external_adapter = nullptr; } diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index aa5cd8f8c4e..7533df3b685 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -35,6 +35,9 @@ struct GraphConfig final { // Whether or not the ComputeGraph should expect input shapes to be dynamic bool expect_dynamic_shapes; + // Used for testing/debugging only. Forces ExecuteNode to trigger the resize + // function even if none of the inputs have been updated. + bool force_resize = false; // Execution properties that determine specifics re: how command buffer // submission is handled, etc. 0 means this field is not set. diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index 953f15e7b4d..aa46ee76336 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -21,9 +21,10 @@ ExecuteNode::ExecuteNode( name_(name) {} bool ExecuteNode::trigger_resize(ComputeGraph* graph) { - const bool any_arg_updated = was_any_arg_updated(graph); - if (resize_fn_ && any_arg_updated) { + bool any_arg_updated = was_any_arg_updated(graph); + if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) { resize_fn_(graph, args_, resize_args_); + any_arg_updated = true; } return any_arg_updated; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 879f59667d6..d405825fad1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -59,7 +59,11 @@ void resize_pool2d_node( if (is_max_pool2d) { const ValueRef indices = args.at(0).refs.at(1); - graph->virtual_resize(indices, new_out_sizes); + // For max_pool2d variant, indices tensor will be a 0-dim tensor - only + // resize the indices tensor if this is not the case. + if (graph->sizes_of(indices).size() > 0) { + graph->virtual_resize(indices, new_out_sizes); + } } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 0a98f6d8f43..36a8ee4c3b1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -54,12 +54,21 @@ void resize_unsqueeze_node( const ValueRef in = args.at(1).refs.at(0); const ValueRef dims_ref = extra_args.at(0); - const IntListPtr dims = graph->get_int_list(dims_ref); + std::vector dims_vec; + if (graph->is_scalar_or_none(dims_ref)) { + // Handle scalar case + int64_t dim = graph->extract_scalar(dims_ref); + dims_vec.push_back(dim); + } else { + // Handle list case + const IntListPtr dims = graph->get_int_list(dims_ref); + dims_vec.assign(dims->begin(), dims->end()); + } std::vector out_sizes = graph->sizes_of(in); // Insert singleton dimensions at the specified positions - for (auto dim : *dims) { + for (auto dim : dims_vec) { int64_t d = dim; if (d < 0) { d += static_cast(out_sizes.size()) + 1; diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index c368c23c539..08bc502f964 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); config.set_storage_type_override(default_storage_type); config.set_memory_layout_override(default_memory_layout); + config.force_resize = true; graph = new ComputeGraph(config); if (test_dtype == at::kHalf) {{