diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index bc3bf14bf14..44a8ae694ed 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -97,7 +97,10 @@ def op_node_is_compatible( # noqa: C901: Function is too complex """ target = node.target # Account for custom operators - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() @@ -218,7 +221,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 return True target = node.target - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index ab0a15ce4cf..e6f20a68c7b 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -819,7 +819,10 @@ def print_occurrences(edge_program, operator_list: List): if utils.is_torch_op_node(node): target = node.target # Handle auto_functionalized nodes - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] if hasattr(first_arg, "name"): target = first_arg.name() @@ -907,7 +910,10 @@ def op_ablation_test( # noqa: C901 if utils.is_torch_op_node(node): target = node.target # Handle auto_functionalized nodes - if node.target == torch.ops.higher_order.auto_functionalized: + if ( + node.target == torch.ops.higher_order.auto_functionalized + or node.target == torch.ops.higher_order.auto_functionalized_v2 + ): first_arg = node.args[0] if hasattr(first_arg, "name"): target = first_arg.name()