-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
fuse.py
69 lines (56 loc) · 2.35 KB
/
fuse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from torch.fx import (
GraphModule,
map_arg
)
from torch.fx.graph import Graph
from .pattern_utils import (
is_match,
get_default_fusion_patterns,
)
from .fusion_patterns import * # noqa: F401
import copy
class Fuser:
    """Rewrite an FX graph so that matched op sequences are replaced by fused ops.

    Matching is driven by an ordered dict of fusion patterns (the defaults from
    ``get_default_fusion_patterns()`` plus optional user-supplied ones). Each
    pattern maps to a handler class; a handler is instantiated per match as
    ``handler_cls(fuser, root_node)`` and its ``fuse(fuser, load_arg)`` method
    emits the replacement into ``self.fused_graph``.
    """

    def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
        """Return a new ``GraphModule`` with fusion patterns applied.

        Args:
            model: module to fuse; it must expose ``.graph`` and
                ``named_modules()`` (presumably an FX ``GraphModule`` from
                symbolic tracing -- confirm with callers).
            inplace: when False, ``model`` is deep-copied first and the result
                is built from the copy. Even when True the return value is a
                NEW ``GraphModule``; the input's graph is not replaced.
            fuse_custom_config_dict: optional config dict. Only the
                ``"additional_quant_pattern"`` key is read here: a mapping of
                extra pattern -> handler class, layered over the defaults
                (an entry with the same key overrides the default handler).

        Returns:
            ``GraphModule(input_root, fused_graph)``.
        """
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        if not inplace:
            model = copy.deepcopy(model)
        input_root = model
        input_graph = model.graph
        # Stored on the instance because handlers receive ``self`` both when
        # constructed and when fusing (presumably to look up submodules).
        self.modules = dict(input_root.named_modules())

        # NOTE(review): the key says "quant" although it carries fusion
        # patterns; kept unchanged so existing configs keep working -- confirm
        # against the documented fuse_custom_config_dict schema.
        additional_fusion_patterns = fuse_custom_config_dict.get(
            "additional_quant_pattern", {})
        fusion_patterns = self._merge_fusion_patterns(
            get_default_fusion_patterns(), additional_fusion_patterns)

        # node name -> (root_node, handler) for every node claimed by a pattern.
        fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        # original node name -> value produced for it in the fused graph
        env = {}

        def load_arg(a):
            # Replace every original-graph Node inside ``a`` (which may be a
            # nested structure) with its counterpart already in the new graph.
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                # Root of a matched pattern: its handler emits the fused op.
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                # Not part of any pattern: copy it unchanged.
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # Otherwise the node is a non-root member of a matched pattern and
            # is dropped here; the pattern's handler accounts for it at the root.

        return GraphModule(input_root, self.fused_graph)

    @staticmethod
    def _merge_fusion_patterns(default_patterns, additional_patterns):
        """Return a new dict with ``additional_patterns`` layered over the defaults.

        Pattern keys may be tuples or module types rather than strings, so
        ``dict(a, **b)`` cannot be used: ``**`` expansion requires string keys
        and raises ``TypeError`` for any non-empty custom pattern dict.
        Neither input is mutated. Default keys keep their original order (and
        therefore their matching precedence); new keys are appended after them.
        """
        merged = dict(default_patterns)
        merged.update(additional_patterns)
        return merged

    def _find_matches(self, root, graph, patterns):
        """Assign graph nodes to fusion-pattern matches.

        Args:
            root: module owning ``graph``; its ``named_modules()`` are passed
                to ``is_match``.
            graph: graph whose nodes are scanned.
            patterns: ordered mapping pattern -> handler class. A pattern is a
                single element or a tuple ``(root_elem, *arg_patterns)`` where
                each ``arg_pattern`` is applied to the corresponding positional
                entry of ``node.args``.

        Returns:
            dict mapping node name -> ``(root_node, handler)`` for every node
            claimed by a match, where ``handler`` is ``patterns[p](self,
            root_node)`` for the first pattern ``p`` matching ``root_node``.
        """
        modules = dict(root.named_modules())
        match_map = {}  # node name -> (root_node, handler instance)

        def apply_match(pattern, node, match):
            # Claim ``node`` (and, recursively, its args) for ``match``.
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                # A node already claimed by an earlier match keeps that claim.
                if node.name not in match_map:
                    match_map[node.name] = match

        # Visit consumers before producers, so a pattern is matched at its root
        # (the node whose args hold the rest of the pattern) before any of its
        # inputs can be tried as a root themselves.
        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))
                        # Only the first matching pattern is used for this
                        # root. Continuing would construct unused handlers and
                        # let later patterns claim extra input nodes under this
                        # root; ``fuse`` would then drop those nodes although
                        # the handler actually used was never matched to them.
                        break

        return match_map