Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 48 additions & 96 deletions exir/backend/canonical_partitioners/group_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):

return True

def _process_node_groups(
def _process_all_nodes(
self,
new_partition_id,
partitions_by_id,
Expand All @@ -133,97 +133,60 @@ def _process_node_groups(
partition_users,
partition_map,
):
"""Process nodes in predefined groups."""
group_to_partition_id = {}

if not self.node_groups:
return group_to_partition_id

processed_nodes = set()

# We have to create the partitions in reverse topological order
# so we find the groups as we traverse backwards in the graph
# this likely needs to be combined with the process_remaining_nodes
# TODO: this currently doesn't work with _process_remaining_nodes so
# if a user provides grouped nodes with operatorsupport, then this will
# faile
"""Process nodes into a partition."""
for node in reversed(self.graph_module.graph.nodes):
if node not in self.node_to_group:
if node in assignment or not self._is_node_supported(node):
continue

if node in processed_nodes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about this check? The purpose was that when traversing nodes in the graph module, we could get two nodes from the same group, and we could potentially add the group twice?

continue
if node in self.all_nodes_in_groups:
group_idx = self.node_to_group[node]
group = self.node_groups[group_idx]

group_idx = self.node_to_group[node]
group = self.node_groups[group_idx]

# Create a partition for group
partition_id = next(new_partition_id)
partition = Partition(id=partition_id, nodes=set())
partitions_by_id[partition_id] = partition
partitions_order[partition_id] = partition_id
group_to_partition_id[group_idx] = partition_id

# Add all supported nodes from the group to the partition
for node in group:
if self._is_node_supported(node):
partition.add_node(node)
assignment[node] = partition_id
nodes_order[node] = partition_id

# Set partition users
partition_users[partition_id] = {
user
for node in partition.nodes
for user in node.users
if user not in partition.nodes
}

# Update partition map
for node in partition.nodes:
# Create a partition for group
partition_id = next(new_partition_id)
partition = Partition(id=partition_id, nodes=set())
partitions_by_id[partition_id] = partition
partitions_order[partition_id] = partition_id

# Add all supported nodes from the group to the partition
for node in group:
if self._is_node_supported(node):
partition.add_node(node)
assignment[node] = partition_id
nodes_order[node] = partition_id

# Set partition users
partition_users[partition_id] = {
user
for node in partition.nodes
for user in node.users
if user not in partition.nodes
}

# Update partition map
for node in partition.nodes:
for user in node.users:
target_id = assignment.get(user, None)
if target_id is not None and target_id != partition_id:
partition_map[partition_id].add(target_id)
partition_map[partition_id].update(partition_map[target_id])
else:
partition_id = next(new_partition_id)
nodes_order[node] = partition_id
partitions_order[partition_id] = partition_id
partitions_by_id[partition_id] = Partition(
id=partition_id, nodes=[node]
)
assignment[node] = partition_id
partition_users[partition_id] = set(node.users)

# Update partition map
for user in node.users:
target_id = assignment.get(user)
if target_id is not None and target_id != partition_id:
if target_id is not None:
partition_map[partition_id].add(target_id)
partition_map[partition_id].update(partition_map[target_id])

# all the nodes in the group have now been processed
# so skip if we encoutner them again in our rev topo
# iteration
for node in group:
processed_nodes.add(node)

return group_to_partition_id

def _process_remaining_nodes(
self,
new_partition_id,
partitions_by_id,
assignment,
nodes_order,
partitions_order,
partition_users,
partition_map,
):
"""Process nodes not in any predefined group."""
for node in reversed(self.graph_module.graph.nodes):
if node in assignment or not self._is_node_supported(node):
continue

partition_id = next(new_partition_id)
nodes_order[node] = partition_id
partitions_order[partition_id] = partition_id
partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
assignment[node] = partition_id
partition_users[partition_id] = set(node.users)

# Update partition map
for user in node.users:
target_id = assignment.get(user)
if target_id is not None:
partition_map[partition_id].add(target_id)
partition_map[partition_id].update(partition_map[target_id])

def _merge_partitions(
self,
partitions_by_id,
Expand Down Expand Up @@ -378,19 +341,8 @@ def propose_partitions(self) -> list[Partition]:
partition_users = {} # Maps partition IDs to partition users
new_partition_id = itertools.count()

# Process nodes in predefined groups
self._process_node_groups(
new_partition_id,
partitions_by_id,
assignment,
nodes_order,
partitions_order,
partition_users,
partition_map,
)

# Process remaining nodes
self._process_remaining_nodes(
# Process all nodes into partitions
self._process_all_nodes(
new_partition_id,
partitions_by_id,
assignment,
Expand Down
Loading