Skip to content

Commit bee7ff6

Browse files
committed
Added group-based partitioner
1 parent b74c68d commit bee7ff6

File tree

2 files changed

+2061
-0
lines changed

2 files changed

+2061
-0
lines changed
Lines changed: 389 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
# mypy: allow-untyped-defs
2+
import collections
3+
import itertools
4+
import logging
5+
from collections.abc import Sequence
6+
from typing import List, Optional
7+
8+
from torch.fx.graph_module import GraphModule
9+
from torch.fx.node import _get_qualified_name, Node
10+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
11+
from torch.fx.passes.operator_support import OperatorSupportBase
12+
13+
14+
logger = logging.getLogger(__name__)
15+
logger.setLevel(logging.WARNING)
16+
17+
18+
class _DependencyViewer:
19+
def __init__(self, graph_module: GraphModule):
20+
self.downstreams = collections.defaultdict(set)
21+
self.upstreams = collections.defaultdict(set)
22+
23+
for node in reversed(graph_module.graph.nodes):
24+
for output_node in node.users:
25+
# add output_node and output_node's downstream dependency
26+
self.downstreams[node].add(output_node)
27+
self.downstreams[node].update(self.downstreams[output_node])
28+
29+
for node in graph_module.graph.nodes:
30+
for input_node in node.all_input_nodes:
31+
self.upstreams[node].add(input_node)
32+
self.upstreams[node].update(self.upstreams[input_node])
33+
34+
def downstreams_of(self, node: Node) -> set[Node]:
35+
return self.downstreams[node]
36+
37+
def upstreams_of(self, node: Node) -> set[Node]:
38+
return self.upstreams[node]
39+
40+
41+
class GroupBasedPartitioner(CapabilityBasedPartitioner):
42+
"""
43+
A specialized partitioner that extends the CapabilityBasedPartitioner from PyTorch FX.
44+
45+
GroupBasedPartitioner allows for explicit grouping of nodes into partitions based on
46+
predefined node groups, while also supporting automatic partitioning for nodes not
47+
included in any group. Nodes are only allowed to be in one group.
48+
49+
Features:
50+
- Explicit Node Grouping: Allows specifying groups of nodes that should be kept together
51+
in the same partition.
52+
- Automatic Partitioning: Nodes not included in any explicit group are automatically
53+
partitioned based on operator support.
54+
- Cycle Prevention: Ensures that partitioning doesn't create cycles in the execution graph.
55+
- Single Node Partition Control: Options to allow or disallow single-node partitions,
56+
with exceptions for specific operations.
57+
58+
Args:
59+
graph_module: The FX GraphModule to be partitioned.
60+
operator_support: Interface to determine if a node is supported by the target backend.
61+
allows_single_node_partition: Whether to allow single-node partitions. Default: False.
62+
non_compute_ops: Operations not counted for single-node partition determination. Default: None.
63+
allowed_single_node_partition_ops: Operations allowed as single-node partitions. Default: None.
64+
node_groups: Lists of nodes to group together in partitions. Default: None.
65+
"""
66+
67+
def __init__(
68+
self,
69+
graph_module: GraphModule,
70+
operator_support: OperatorSupportBase,
71+
allows_single_node_partition: bool = False,
72+
non_compute_ops: Optional[Sequence[str]] = None,
73+
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
74+
node_groups: List[List[Node]] = None,
75+
) -> None:
76+
super().__init__(
77+
graph_module=graph_module,
78+
operator_support=operator_support,
79+
allows_single_node_partition=allows_single_node_partition,
80+
non_compute_ops=non_compute_ops,
81+
allowed_single_node_partition_ops=allowed_single_node_partition_ops,
82+
)
83+
self.dependency_viewer = _DependencyViewer(graph_module)
84+
self.node_groups = (
85+
[set(node_group) for node_group in node_groups] if node_groups else None
86+
)
87+
self.node_to_group = collections.defaultdict(int)
88+
self.all_nodes_in_groups = set()
89+
if node_groups:
90+
for i, group in enumerate(self.node_groups):
91+
for node in group:
92+
# Node is in multiple groups - not allowed
93+
if node in self.node_to_group:
94+
raise ValueError(f"Node {node} exists in multiple groups.")
95+
self.node_to_group[node] = i
96+
self.all_nodes_in_groups.add(node)
97+
98+
def _can_merge_partitions(self, p1, p2, partitions_by_id):
99+
"""Check if merging two partitions would create a cycle."""
100+
p1_nodes = set(partitions_by_id[p1].nodes.keys())
101+
p2_nodes = set(partitions_by_id[p2].nodes.keys())
102+
combined_nodes = p1_nodes.union(p2_nodes)
103+
104+
for node in combined_nodes:
105+
# Get all downstream nodes that are not in the combined partition
106+
external_downstreams = {
107+
n
108+
for n in self.dependency_viewer.downstreams_of(node)
109+
if n not in combined_nodes
110+
}
111+
112+
# Check if any external downstream nodes have downstream nodes in the combined partition
113+
for external_node in external_downstreams:
114+
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
115+
if any(n in combined_nodes for n in downstream_nodes):
116+
return False
117+
118+
return True
119+
120+
def _process_node_groups(
121+
self,
122+
new_partition_id,
123+
partitions_by_id,
124+
assignment,
125+
nodes_order,
126+
partitions_order,
127+
partition_users,
128+
partition_map,
129+
):
130+
"""Process nodes in predefined groups."""
131+
group_to_partition_id = {}
132+
133+
if not self.node_groups:
134+
return group_to_partition_id
135+
136+
for i, group in enumerate(self.node_groups):
137+
# Create a partition for each group
138+
partition_id = next(new_partition_id)
139+
partition = Partition(id=partition_id, nodes=set())
140+
partitions_by_id[partition_id] = partition
141+
partitions_order[partition_id] = partition_id
142+
group_to_partition_id[i] = partition_id
143+
144+
# Add all supported nodes from the group to the partition
145+
for node in group:
146+
if self._is_node_supported(node):
147+
partition.add_node(node)
148+
assignment[node] = partition_id
149+
nodes_order[node] = partition_id
150+
151+
# Set partition users
152+
partition_users[partition_id] = {
153+
user
154+
for node in partition.nodes
155+
for user in node.users
156+
if user not in partition.nodes
157+
}
158+
159+
# Update partition map
160+
for node in partition.nodes:
161+
for user in node.users:
162+
target_id = assignment.get(user)
163+
if target_id is not None and target_id != partition_id:
164+
partition_map[partition_id].add(target_id)
165+
partition_map[partition_id].update(partition_map[target_id])
166+
167+
return group_to_partition_id
168+
169+
def _process_remaining_nodes(
170+
self,
171+
new_partition_id,
172+
partitions_by_id,
173+
assignment,
174+
nodes_order,
175+
partitions_order,
176+
partition_users,
177+
partition_map,
178+
):
179+
"""Process nodes not in any predefined group."""
180+
for node in reversed(self.graph_module.graph.nodes):
181+
if node in assignment or not self._is_node_supported(node):
182+
continue
183+
184+
partition_id = next(new_partition_id)
185+
nodes_order[node] = partition_id
186+
partitions_order[partition_id] = partition_id
187+
partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
188+
assignment[node] = partition_id
189+
partition_users[partition_id] = set(node.users)
190+
191+
# Update partition map
192+
for user in node.users:
193+
target_id = assignment.get(user)
194+
if target_id is not None:
195+
partition_map[partition_id].add(target_id)
196+
partition_map[partition_id].update(partition_map[target_id])
197+
198+
def _merge_partitions(
199+
self,
200+
partitions_by_id,
201+
assignment,
202+
partition_users,
203+
partition_map,
204+
partitions_order,
205+
):
206+
"""Merge partitions when possible."""
207+
merged = True
208+
while merged:
209+
merged = False
210+
partition_ids = list(partitions_by_id.keys())
211+
212+
for i, p1 in enumerate(partition_ids):
213+
if p1 not in partitions_by_id:
214+
continue
215+
216+
for p2 in partition_ids[i + 1 :]:
217+
if p2 not in partitions_by_id:
218+
continue
219+
220+
# Try to merge partitions if it doesn't create cycles
221+
if self._can_merge_partitions(p1, p2, partitions_by_id):
222+
self._perform_partition_merge(
223+
p1,
224+
p2,
225+
partitions_by_id,
226+
assignment,
227+
partition_users,
228+
partition_map,
229+
partitions_order,
230+
)
231+
merged = True
232+
break
233+
234+
if merged:
235+
break
236+
237+
def _perform_partition_merge(
238+
self,
239+
p1,
240+
p2,
241+
partitions_by_id,
242+
assignment,
243+
partition_users,
244+
partition_map,
245+
partitions_order,
246+
):
247+
"""Merge partition p2 into p1."""
248+
# Merge p2 into p1
249+
partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes)
250+
for node in partitions_by_id[p2].nodes:
251+
assignment[node] = p1
252+
253+
# Update partition users
254+
all_users = partition_users[p1] | partition_users[p2]
255+
all_users.difference_update(partitions_by_id[p1].nodes)
256+
partition_users[p1] = all_users
257+
258+
# Update partition map
259+
partition_map[p1].update(partition_map[p2])
260+
261+
# Update partition order
262+
partitions_order[p1] = min(partitions_order[p1], partitions_order[p2])
263+
264+
# Remove p2
265+
del partitions_by_id[p2]
266+
del partition_users[p2]
267+
del partitions_order[p2]
268+
if p2 in partition_map:
269+
del partition_map[p2]
270+
271+
def _process_getitem_nodes(self, partitions_by_id, assignment):
272+
"""Post-process getitem nodes."""
273+
nodes_reassignment = {}
274+
275+
for node in self.graph_module.graph.nodes:
276+
# Check if all users are getitem nodes
277+
is_tuple_output = True
278+
for user in node.users:
279+
if (
280+
user.op != "call_function"
281+
or _get_qualified_name(user.target) != "_operator.getitem"
282+
):
283+
is_tuple_output = False
284+
break
285+
286+
# Node has tuple outputs, reassign all following getitem nodes into node's partition
287+
if is_tuple_output:
288+
id = assignment.get(node, None)
289+
if id is not None:
290+
for user in node.users:
291+
if user in assignment and assignment.get(user, None) != id:
292+
nodes_reassignment[user] = id
293+
294+
# Apply reassignments
295+
for node, id in nodes_reassignment.items():
296+
if node in assignment:
297+
partitions_by_id[assignment[node]].remove_node(node)
298+
299+
assignment[node] = id
300+
partitions_by_id[id].add_node(node)
301+
302+
def _filter_single_node_partitions(self, partitions_by_id):
303+
"""Filter out single node partitions if needed."""
304+
if self.allows_single_node_partition:
305+
return
306+
307+
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
308+
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops or []))
309+
partitions_to_remove = []
310+
311+
for id, partition in partitions_by_id.items():
312+
compute_node_count = 0
313+
for node in partition.nodes:
314+
if node.op == "call_function":
315+
assert callable(node.target)
316+
target_name = _get_qualified_name(node.target)
317+
318+
if target_name not in non_compute_ops:
319+
compute_node_count += 1
320+
321+
if (
322+
self.allowed_single_node_partition_ops
323+
and target_name in self.allowed_single_node_partition_ops
324+
):
325+
compute_node_count += 1
326+
327+
if compute_node_count <= 1:
328+
partitions_to_remove.append(id)
329+
330+
for id in partitions_to_remove:
331+
del partitions_by_id[id]
332+
333+
def propose_partitions(self) -> list[Partition]:
334+
"""
335+
Propose partitions for the graph module based on node groups and operator support.
336+
337+
Returns:
338+
A list of proposed partitions.
339+
"""
340+
# Initialize data structures
341+
partition_map = collections.defaultdict(
342+
set
343+
) # Maps partition IDs to reachable partition IDs
344+
assignment = {} # Maps nodes to partition IDs
345+
partitions_by_id = {} # Maps partition IDs to partitions
346+
nodes_order = {} # Maps nodes to topological order
347+
partitions_order = {} # Maps partition IDs to minimum topological order
348+
partition_users = {} # Maps partition IDs to partition users
349+
new_partition_id = itertools.count()
350+
351+
# Process nodes in predefined groups
352+
self._process_node_groups(
353+
new_partition_id,
354+
partitions_by_id,
355+
assignment,
356+
nodes_order,
357+
partitions_order,
358+
partition_users,
359+
partition_map,
360+
)
361+
362+
# Process remaining nodes
363+
self._process_remaining_nodes(
364+
new_partition_id,
365+
partitions_by_id,
366+
assignment,
367+
nodes_order,
368+
partitions_order,
369+
partition_users,
370+
partition_map,
371+
)
372+
373+
# Merge partitions when possible
374+
self._merge_partitions(
375+
partitions_by_id,
376+
assignment,
377+
partition_users,
378+
partition_map,
379+
partitions_order,
380+
)
381+
382+
# Post-process getitem nodes
383+
self._process_getitem_nodes(partitions_by_id, assignment)
384+
385+
# Filter single node partitions if needed
386+
self._filter_single_node_partitions(partitions_by_id)
387+
388+
# Return non-empty partitions
389+
return [p for p in partitions_by_id.values() if p.size() > 0]

0 commit comments

Comments
 (0)