-
Notifications
You must be signed in to change notification settings - Fork 4
/
transforms.py
203 lines (170 loc) · 6.86 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
HiddenLayer
Transforms that apply to and modify graph nodes.
Written by Waleed Abdulla
Licensed under the MIT License
"""
import re
import copy
from .graph import Node
from . import ge
###########################################################################
# Transforms
###########################################################################
class Fold:
    """Collapse every occurrence of a node pattern into a single node.

    ``pattern`` is a pattern string parsed once with ``ge.GEParser``. ``op``
    selects how each match is replaced:

    * ``"__first__"`` -- keep only the first matched node;
    * ``"__last__"``  -- keep only the last matched node;
    * anything else   -- build a fresh ``Node`` whose op is ``op`` (or the
      parsed pattern object when ``op`` is falsy) and whose name is ``name``
      (or the matched titles joined with " > " when ``name`` is falsy).
    """

    def __init__(self, pattern, op, name=None):
        # TODO: validate that op and name are valid
        self.pattern = ge.GEParser(pattern).parse()
        self.op = op
        self.name = name

    def _merged_node(self, graph, matches):
        """Return the node that replaces the matched run ``matches``."""
        if self.op == "__first__":
            return matches[0]
        if self.op == "__last__":
            return matches[-1]
        merged = Node(uid=graph.sequence_id(matches),
                      name=self.name or " > ".join(m.title for m in matches),
                      op=self.op or self.pattern,
                      output_shape=matches[-1].output_shape)
        # Join only the non-empty captions of the folded nodes.
        merged._caption = "/".join(
            caption for caption in (m.caption for m in matches) if caption)
        return merged

    def apply(self, graph):
        """Return a copy of ``graph`` with every match of the pattern folded.

        The input graph is not modified. Searching restarts after each
        replacement because the replacement changes the graph.
        """
        folded = copy.deepcopy(graph)
        matches, _ = folded.search(self.pattern)
        while matches:
            folded.replace(matches, self._merged_node(folded, matches))
            matches, _ = folded.search(self.pattern)
        return folded
class FoldId:
    """Fold nodes whose ids match a regex into one node per captured key.

    ``id_regex`` must contain at least one capture group. Every node whose id
    matches (via ``re.match``, i.e. anchored at the start) is grouped by the
    value of the first capture group, and each group is replaced by a single
    ``Node`` with ``uid`` set to that value and the given ``op``/``name``.
    """

    def __init__(self, id_regex, op, name=None):
        # TODO: validate op and name are valid
        self.id_regex = re.compile(id_regex)
        self.op = op
        self.name = name

    def apply(self, graph):
        """Return a copy of ``graph`` with each id group folded into one node.

        The input graph is not modified.

        Raises:
            AssertionError: if a node id matches but the regex has no group.
        """
        result = copy.deepcopy(graph)

        # Bucket matching nodes by their first capture group, preserving
        # the order in which keys are first seen.
        buckets = {}
        for candidate in result.nodes.values():
            found = self.id_regex.match(candidate.id)
            if found:
                assert found.groups(), "Regular expression must have a matching group to avoid folding unrelated nodes."
                buckets.setdefault(found.group(1), []).append(candidate)

        for uid, members in buckets.items():
            # TODO: Find last node in the sub-graph and get the output shape from it
            result.replace(members, Node(uid=uid, name=self.name, op=self.op))
        return result
class Prune:
    """Delete every node matched by a pattern, repeating until none match.

    ``pattern`` is a pattern string parsed once with ``ge.GEParser``.
    """

    def __init__(self, pattern):
        self.pattern = ge.GEParser(pattern).parse()

    def apply(self, graph):
        """Return a copy of ``graph`` with all matches removed.

        The input graph is not modified. The search is re-run after each
        removal because removing nodes can create or destroy matches.
        """
        pruned = copy.deepcopy(graph)
        matches, _ = pruned.search(self.pattern)
        while matches:
            pruned.remove(matches)
            matches, _ = pruned.search(self.pattern)
        return pruned
class PruneBranch:
    """Remove nodes matching a pattern together with the branches feeding them.

    Every matched node is removed. Each of its incoming nodes is removed as
    well, but only if *all* of that node's outgoing nodes are being removed,
    so an upstream node that still feeds a surviving part of the graph is
    kept. The rule is applied recursively up the graph.

    ``pattern`` is a pattern string parsed once with ``ge.GEParser``.
    """

    def __init__(self, pattern):
        self.pattern = ge.GEParser(pattern).parse()

    def tag(self, node, tag, graph, conditional=False):
        """Tag ``node`` with ``tag`` and propagate the tag to its inputs.

        Args:
            node: node to tag; the tag is stored in its ``__tag__`` attribute.
            tag: tag value to assign.
            graph: graph providing ``incoming(node)`` and ``outgoing(node)``.
            conditional: if True, tag ``node`` only when every one of its
                outgoing nodes already carries ``tag``.
        """
        # Stop at nodes that already carry this tag. Besides skipping
        # redundant work, this keeps a shared ancestor from being re-walked
        # once for every path that reaches it.
        if hasattr(node, "__tag__") and node.__tag__ == tag:
            return
        # If conditional, then tag the node if and only if all its
        # outgoing nodes already have the same tag.
        if conditional:
            outgoing = graph.outgoing(node)
            if not all(hasattr(n, "__tag__") and n.__tag__ == tag
                       for n in outgoing):
                # Not all outgoing are tagged
                return
        node.__tag__ = tag
        # Inputs are tagged conditionally: only if this branch was their
        # sole consumer.
        for n in graph.incoming(node):
            self.tag(n, tag, graph, conditional=True)

    def apply(self, graph):
        """Return a copy of ``graph`` with matched nodes and the upstream
        branches that feed only them removed.

        The input graph is not modified.
        """
        # Copy the graph. Don't change the original.
        graph = copy.deepcopy(graph)
        while True:
            matches, _ = graph.search(self.pattern)
            if not matches:
                break
            # Tag found nodes and their incoming branches
            for n in matches:
                self.tag(n, "delete", graph)
            # Find all tagged nodes and delete them
            tagged = [n for n in graph.nodes.values()
                      if hasattr(n, "__tag__") and n.__tag__ == "delete"]
            graph.remove(tagged)
        return graph
class FoldDuplicates:
    """Fold adjacent nodes that share the same op into a single node.

    For each node, a two-step serial pattern of that node's op is matched
    starting at the node. Each match is replaced by one node that takes its
    op, name and caption from the first node, its ``output_shape`` from the
    last matched node, and a ``repeat`` equal to the sum of the matched
    nodes' ``repeat`` values. This repeats until no node starts a match.
    """

    def apply(self, graph):
        """Return a copy of ``graph`` with all duplicate runs folded.

        The input graph is not modified. An empty graph is returned as an
        (empty) copy.
        """
        # Copy the graph. Don't change the original.
        graph = copy.deepcopy(graph)
        while True:
            for node in graph.nodes.values():
                pattern = ge.SerialPattern([ge.NodePattern(node.op), ge.NodePattern(node.op)])
                matches, _ = pattern.match(graph, node)
                if matches:
                    # Use op and name from the first node, and output_shape from the last
                    combo = Node(uid=graph.sequence_id(matches),
                                 name=node.name,
                                 op=node.op,
                                 output_shape=matches[-1].output_shape)
                    combo._caption = node.caption
                    combo.repeat = sum(n.repeat for n in matches)
                    graph.replace(matches, combo)
                    # graph.nodes was just modified; restart the scan rather
                    # than keep iterating a changed dict.
                    break
            else:
                # Every node was scanned (or there were none) without a match.
                break
        return graph
class Rename:
    """Rewrite node ops or node names with a regular-expression substitution.

    Exactly one of ``op`` or ``name`` must be given. It is compiled as a
    regex, and every match in the corresponding node attribute is replaced
    with ``to`` via ``re.sub`` (so backreferences in ``to`` are honoured).
    Nodes whose selected attribute is not a string (e.g. ``None``) are left
    unchanged.
    """

    def __init__(self, op=None, name=None, to=None):
        assert op or name, "Either op or name must be provided"
        assert not (op and name), "Either op or name should be provided, but not both"
        assert bool(to), "The to parameter is required"
        self.to = to
        self.op = re.compile(op) if op else None
        self.name = re.compile(name) if name else None

    def apply(self, graph):
        """Return a copy of ``graph`` with the substitution applied.

        The input graph is not modified.
        """
        # Copy the graph. Don't change the original.
        graph = copy.deepcopy(graph)
        for node in graph.nodes.values():
            # Nodes can carry a None name (e.g. FoldId's default) or a
            # non-string op (Fold falls back to the parsed pattern object);
            # re.sub would raise TypeError on those, so skip them.
            if self.op is not None and isinstance(node.op, str):
                node.op = self.op.sub(self.to, node.op)
            if self.name is not None and isinstance(node.name, str):
                node.name = self.name.sub(self.to, node.name)
        return graph
# Transforms to simplify graphs by folding layers that are frequently
# used together, such as Conv/BN/Relu.
# These transforms are used AFTER the framework-specific transforms
# that map TF and PyTorch graphs to a common representation.
# NOTE(review): the rules are listed longest-pattern-first; presumably the
# caller applies them in list order, so e.g. "Conv > BatchNorm > Relu" is
# folded before "Conv > BatchNorm" can claim part of it. Confirm before
# reordering.
SIMPLICITY_TRANSFORMS = [
    Fold(pattern, op)
    for pattern, op in (
        ("Conv > Conv > BatchNorm > Relu", "ConvConvBnRelu"),
        ("Conv > BatchNorm > Relu", "ConvBnRelu"),
        ("Conv > BatchNorm", "ConvBn"),
        ("Conv > Relu", "ConvRelu"),
        ("Linear > Relu", "LinearRelu"),
        # ("ConvBnRelu > MaxPool", "ConvBnReluMaxpool"),
        # ("ConvRelu > MaxPool", "ConvReluMaxpool"),
    )
] + [FoldDuplicates()]