From 3c7ff301acee7b516f08b27e1e6d2e9f0606bfd7 Mon Sep 17 00:00:00 2001 From: Sicheng Stephen Jia Date: Thu, 22 Jan 2026 14:12:29 -0500 Subject: [PATCH] [ET-VK][ez] Add support for auto_functionalized_v2 (#16790) Summary: Update the Vulkan partitioner to recognize torch.ops.higher_order.auto_functionalized_v2 in addition to the existing auto_functionalized op when determining operator support and extracting target names for custom operators. (cherry picked from commit 9f6f8285347addabff62882c4202dd6951238047) --- backends/vulkan/partitioner/vulkan_partitioner.py | 10 ++++++++-- backends/vulkan/test/utils.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index bc3bf14bf14..44a8ae694ed 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -97,7 +97,10 @@ def op_node_is_compatible( # noqa: C901: Function is too complex """ target = node.target # Account for custom operators - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() @@ -218,7 +221,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 return True target = node.target - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index ab0a15ce4cf..e6f20a68c7b 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -819,7 +819,10 @@ def print_occurrences(edge_program, operator_list: List): if utils.is_torch_op_node(node): target = node.target # Handle auto_functionalized nodes - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] if hasattr(first_arg, "name"): target = first_arg.name() @@ -907,7 +910,10 @@ def op_ablation_test( # noqa: C901 if utils.is_torch_op_node(node): target = node.target # Handle auto_functionalized nodes - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] if hasattr(first_arg, "name"): target = first_arg.name()