-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
fuser_utils.py
233 lines (176 loc) · 8.42 KB
/
fuser_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import copy
from queue import SimpleQueue
from typing import List, Dict, Tuple
import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
from torch.fx._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
    """
    Return `nodes` reordered so that every node comes after its producers.

    Only dependencies between members of `nodes` are considered; inputs and
    users that are not in `nodes` are ignored.

    Args:
        nodes (List[Node]): nodes to sort

    Returns:
        List[Node]: the nodes in topological order

    Raises:
        AssertionError: if the result does not contain as many entries as
            `nodes` (e.g. the dependencies among `nodes` form a cycle)
    """
    # number of not-yet-emitted producers for each node, restricted to `nodes`
    pending = dict.fromkeys(nodes, 0)

    # `ordered` doubles as the FIFO work list: nodes are appended once they
    # have no pending producers and are processed in the order they were added
    ordered: NodeList = []
    for current in nodes:
        pending[current] += sum(
            1 for producer in current.all_input_nodes if producer in pending
        )
        if pending[current] == 0:
            ordered.append(current)

    cursor = 0
    while cursor < len(ordered):
        current = ordered[cursor]
        cursor += 1
        for consumer in current.users:
            if consumer not in pending:
                # user lives outside of `nodes`, nothing to release
                continue
            pending[consumer] -= 1
            if pending[consumer] == 0:
                ordered.append(consumer)

    assert len(nodes) == len(ordered), "topological sorted nodes doesn't have same length as input nodes"
    return ordered
@compatibility(is_backward_compatible=False)
def validate_partition(partition: NodeList) -> bool:
    """
    Check that fusing `partition` would not create a dependency cycle.

    Starting from every user of a partition node that is itself outside the
    partition, follow `users` edges downstream. If that walk reaches a node
    inside the partition, a value would leave the partition and flow back into
    it, so the partition is invalid.

    Args:
        partition (List[Node]): nodes proposed for fusion

    Returns:
        bool: True for a valid partition, False if a cycle was found
    """
    partition_set = set(partition)

    # `seen` holds every node that has ever been pushed on `stack`. Marking on
    # push (instead of on pop) guarantees each node is expanded at most once,
    # even when it is reachable through many paths.
    seen: NodeSet = set()
    stack: NodeList = []

    # Seed the walk with the external users of the partition (de-duplicated).
    for node in partition_set:
        for user_node in node.users:
            if user_node not in partition_set and user_node not in seen:
                seen.add(user_node)
                stack.append(user_node)

    # Depth-first walk toward downstream nodes. The seeds are outside of the
    # partition by construction, so it is enough to test each newly discovered
    # user for partition membership.
    while stack:
        current = stack.pop()
        for user_node in current.users:
            if user_node in partition_set:
                # Left the partition through an output and came back in. Cycle!
                return False
            if user_node not in seen:
                seen.add(user_node)
                stack.append(user_node)

    return True
@compatibility(is_backward_compatible=False)
def fuse_as_graphmodule(gm: GraphModule,
                        nodes: NodeList,
                        module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
    """
    Fuse nodes in graph_module into a GraphModule.

    The nodes are copied into a fresh graph; `gm` itself is not modified here.

    Args:
        gm (GraphModule): target graph_module

        nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted

        module_name: class name for the fused GraphModule

    Returns:
        fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm`

        original_inputs (Tuple[Node, ...]): nodes of the original `gm` that are outside of `nodes`
            and feed into it, in the same order as the placeholders of `fused_gm`

        original_outputs (Tuple[Node, ...]): members of `nodes` that have at least one user
            outside of `nodes`, in the same order as the outputs of `fused_gm`

    Raises:
        AssertionError: if a node does not belong to `gm`, has been erased, or if the
            partition fails `validate_partition`
    """
    # assumption: nodes are already sorted in topo order
    for node in nodes:
        assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
        assert not node._erased, f"{node} has been removed from owning graph"
        assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}"

    # validates partition doesn't introduce dependency circles in the graph
    assert validate_partition(nodes), "Invalid partition, found dependency cycles"

    # constant-time membership tests; `nodes` itself is a list
    partition_set = set(nodes)

    subgraph = Graph()

    node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
    node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph

    # handles inputs through graph.node_copy's arg_transform functions
    def remap_inputs(x):
        # TODO: do we really need copy the get_attr node into the graph?
        # (get_attr inputs currently get no special handling)

        if x in partition_set:
            # x is inside subgraph, return the copied node
            # the node should have been copied already, as we are copying graph in the topological order
            return node_map[x]

        if x not in node_to_placeholder:
            # x is not in subgraph, create a new placeholder for subgraph
            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
            # copy all meta fields, even if some fields might be irrelevant for the placeholder node
            placeholder_node.meta = copy.copy(x.meta)
            node_to_placeholder[x] = placeholder_node

        return node_to_placeholder[x]

    # copy nodes in topological order
    for node in nodes:
        node_map[node] = subgraph.node_copy(node, remap_inputs)

    # handles outputs: a node is exposed as an output when any of its users
    # lives outside of the partition
    output_mapping: Dict[Node, Node] = {}  # mapping from old output to new outputs
    for node in nodes:
        if any(user_node not in partition_set for user_node in node.users):
            output_mapping[node] = node_map[node]

    # outs contain nodes in the new subgraph
    outs = tuple(output_mapping.values())

    # Take care of the args of FX output node. If there's a single
    # output then the output node args is like (output_single), else
    # if there're multiple outputs then the output node args is like
    # ((output_0, output_1, ...)).
    subgraph.output(outs[0] if len(outs) == 1 else outs)

    # lint to ensure correctness
    subgraph.lint()

    fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name)

    # sub_gm's input nodes in the original module
    original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())

    # sub_gm's outputs node in the original module
    original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())

    return fused_gm, original_inputs, original_outputs
@compatibility(is_backward_compatible=False)
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
    """
    Register `sub_gm` on `gm` and redirect the users of `orig_outputs` to it.

    Args:
        gm (GraphModule): graph module that receives the new submodule
        sub_gm (GraphModule): fused module; it is registered under its class name
        orig_inputs (Tuple[Node, ...]): nodes of `gm` passed as positional args
            to the new call_module node
        orig_outputs (Tuple[Node, ...]): nodes of `gm` whose uses are replaced
            by the result(s) of the new call_module node

    Returns:
        GraphModule: the same `gm` object that was passed in
    """
    # the class name of the fused module is used as the attribute name on `gm`
    target_name = sub_gm.__class__.__name__
    gm.add_submodule(target_name, sub_gm)

    # node in the main graph that invokes the fused module
    fused_call = gm.graph.call_module(target_name, args=orig_inputs, kwargs=None)

    if len(orig_outputs) == 1:
        # single output: the call node itself stands in for the original node
        orig_outputs[0].replace_all_uses_with(fused_call, propagate_meta=True)
        return gm

    for index, replaced in enumerate(orig_outputs):
        # Indexing a Proxy records a getitem node that selects output `index`.
        item_node = torch.fx.Proxy(fused_call)[index].node  # type: ignore[index]
        replaced.replace_all_uses_with(item_node, propagate_meta=True)
    return gm
@compatibility(is_backward_compatible=False)
def erase_nodes(gm: GraphModule, nodes: NodeList):
    """
    Erase `nodes` from the graph of `gm`, last element first.

    Args:
        gm (GraphModule): graph module whose graph owns the nodes
        nodes (List[Node]): nodes to erase; callers pass them in topological
            order so that erasing back-to-front visits users before producers
    """
    graph = gm.graph
    for stale in reversed(nodes):
        graph.erase_node(stale)
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
    """
    Replace every partition of nodes in `gm` with a call to a fused submodule.

    Partition number `i` is fused into a submodule whose class name (and
    attribute name on `gm`) is ``fused_<i>``.

    Args:
        gm (GraphModule): graph module to rewrite in place
        partitions (List[List[Node]]): groups of nodes of `gm` to fuse

    Returns:
        GraphModule: the same `gm` object, after fusion
    """
    for index, partition in enumerate(partitions):
        ordered = topo_sort(partition)

        fused_module, fused_inputs, fused_outputs = fuse_as_graphmodule(
            gm, ordered, f"fused_{index}"
        )
        insert_subgm(gm, fused_module, fused_inputs, fused_outputs)

        # the originals are no longer used once their uses have been redirected
        erase_nodes(gm, ordered)

    # topological sort original gm with newly created sub_gm
    legalize_graph(gm)
    return gm