diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py index 499bc57b735..aa56d3507e9 100644 --- a/backends/aoti/aoti_partitioner.py +++ b/backends/aoti/aoti_partitioner.py @@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: partition_tags: Dict[str, DelegationSpec] = {} tag = "tag0" + # Tag torch.cond and other control flow operations + def is_control_flow(node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + torch.ops.higher_order.cond, + torch.ops.higher_order.map_impl, + torch.ops.higher_order.while_loop, + ] + for node in exported_program.graph.nodes: - if node.op != "call_function": - continue - node.meta["delegation_tag"] = tag + if node.op == "call_function": + node.meta["delegation_tag"] = tag + # Tag get_attr nodes that are used by control flow operations + elif node.op == "get_attr": + # Check if any user is a control flow operation + for user in node.users: + if is_control_flow(user): + node.meta["delegation_tag"] = tag + break partition_tags[tag] = self.delegation_spec