From 24cfc84b0dfc6fc1f51b1e4015c797170144ea7c Mon Sep 17 00:00:00 2001 From: Michael Adragna Date: Fri, 25 Jul 2025 15:48:15 -0700 Subject: [PATCH] Change node group partitioning to be with all nodes, to keep partition ids in top sort order --- .../group_partitioner.py | 144 ++++++------------ 1 file changed, 48 insertions(+), 96 deletions(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 63bedad3b42..2594bbe05c4 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): return True - def _process_node_groups( + def _process_all_nodes( self, new_partition_id, partitions_by_id, @@ -133,97 +133,60 @@ def _process_node_groups( partition_users, partition_map, ): - """Process nodes in predefined groups.""" - group_to_partition_id = {} - - if not self.node_groups: - return group_to_partition_id - - processed_nodes = set() - - # We have to create the partitions in reverse topological order - # so we find the groups as we traverse backwards in the graph - # this likely needs to be combined with the process_remaining_nodes - # TODO: this currently doesn't work with _process_remaining_nodes so - # if a user provides grouped nodes with operatorsupport, then this will - # faile + """Process nodes into a partition.""" for node in reversed(self.graph_module.graph.nodes): - if node not in self.node_to_group: + if node in assignment or not self._is_node_supported(node): continue - if node in processed_nodes: - continue + if node in self.all_nodes_in_groups: + group_idx = self.node_to_group[node] + group = self.node_groups[group_idx] - group_idx = self.node_to_group[node] - group = self.node_groups[group_idx] - - # Create a partition for group - partition_id = next(new_partition_id) - partition = Partition(id=partition_id, nodes=set()) - partitions_by_id[partition_id] = partition - partitions_order[partition_id] = partition_id - group_to_partition_id[group_idx] = partition_id - - # Add all supported nodes from the group to the partition - for node in group: - if self._is_node_supported(node): - partition.add_node(node) - assignment[node] = partition_id - nodes_order[node] = partition_id - - # Set partition users - partition_users[partition_id] = { - user - for node in partition.nodes - for user in node.users - if user not in partition.nodes - } - - # Update partition map - for node in partition.nodes: + # Create a partition for group + partition_id = next(new_partition_id) + partition = Partition(id=partition_id, nodes=set()) + partitions_by_id[partition_id] = partition + partitions_order[partition_id] = partition_id + + # Add all supported nodes from the group to the partition + for node in group: + if self._is_node_supported(node): + partition.add_node(node) + assignment[node] = partition_id + nodes_order[node] = partition_id + + # Set partition users + partition_users[partition_id] = { + user + for node in partition.nodes + for user in node.users + if user not in partition.nodes + } + + # Update partition map + for node in partition.nodes: + for user in node.users: + target_id = assignment.get(user, None) + if target_id is not None and target_id != partition_id: + partition_map[partition_id].add(target_id) + partition_map[partition_id].update(partition_map[target_id]) + else: + partition_id = next(new_partition_id) + nodes_order[node] = partition_id + partitions_order[partition_id] = partition_id + partitions_by_id[partition_id] = Partition( + id=partition_id, nodes=[node] + ) + assignment[node] = partition_id + partition_users[partition_id] = set(node.users) + + # Update partition map for user in node.users: target_id = assignment.get(user) - if target_id is not None and target_id != partition_id: + if target_id is not None: partition_map[partition_id].add(target_id) partition_map[partition_id].update(partition_map[target_id]) - # all the nodes in the group have now been processed - # so skip if we encoutner them again in our rev topo - # iteration - for node in group: - processed_nodes.add(node) - - return group_to_partition_id - - def _process_remaining_nodes( - self, - new_partition_id, - partitions_by_id, - assignment, - nodes_order, - partitions_order, - partition_users, - partition_map, - ): - """Process nodes not in any predefined group.""" - for node in reversed(self.graph_module.graph.nodes): - if node in assignment or not self._is_node_supported(node): - continue - - partition_id = next(new_partition_id) - nodes_order[node] = partition_id - partitions_order[partition_id] = partition_id - partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node]) - assignment[node] = partition_id - partition_users[partition_id] = set(node.users) - - # Update partition map - for user in node.users: - target_id = assignment.get(user) - if target_id is not None: - partition_map[partition_id].add(target_id) - partition_map[partition_id].update(partition_map[target_id]) - def _merge_partitions( self, partitions_by_id, @@ -378,19 +341,8 @@ def propose_partitions(self) -> list[Partition]: partition_users = {} # Maps partition IDs to partition users new_partition_id = itertools.count() - # Process nodes in predefined groups - self._process_node_groups( - new_partition_id, - partitions_by_id, - assignment, - nodes_order, - partitions_order, - partition_users, - partition_map, - ) - - # Process remaining nodes - self._process_remaining_nodes( + # Process all nodes into partitions + self._process_all_nodes( new_partition_id, partitions_by_id, assignment,