Skip to content

Commit 6676a6c

Browse files
Lee, Kyunggeun (quic-kyunggeu)
authored and committed
Set priority among supergroups
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent beb7a4b commit 6676a6c

File tree

14 files changed

+300
-139
lines changed

14 files changed

+300
-139
lines changed

TrainingExtensions/common/src/python/aimet_common/graph_searcher.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
# =============================================================================
3737
"""Main class for pattern match based graph searcher"""
3838

39-
from typing import Callable, Optional
39+
from typing import Callable, Optional, List, Set
40+
from aimet_common.connected_graph.connectedgraph import ConnectedGraph
41+
from aimet_common.graph_pattern_matcher import PatternType
4042
from aimet_common.utils import AimetLogger
4143
from aimet_common.connected_graph.operation import Op
4244

@@ -110,7 +112,9 @@ class GraphSearcher:
110112
It uses SlidingWindow to maintain the search window and PatternMatcher to match sub graph patterns.
111113
"""
112114

113-
def __init__(self, conn_graph, patterns_with_callback):
115+
def __init__(
116+
self, conn_graph: ConnectedGraph, patterns_with_callback: List[PatternType]
117+
):
114118
"""
115119
initializes params required for pattern matching
116120
:param patterns_with_callback: patterns with corresponding call back functions
@@ -124,18 +128,22 @@ def __init__(self, conn_graph, patterns_with_callback):
124128
else:
125129
self.type_to_op_dict[op.type] = [op]
126130

131+
self._already_matched: Set[Op] = set()
132+
127133
# pylint: disable=too-many-nested-blocks
128134
def find_all_patterns_in_graph_apply_actions(
129135
self,
130136
ignore: Optional[Op] = None,
131137
op_pattern_to_reject: Callable[[Op], bool] = None,
138+
disjoint: bool = False,
132139
):
133140
"""
134141
Find corresponding op sequences and apply actions.
135142
:param ignore: List of operations to ignore during searching
136143
:param op_pattern_to_reject: Callable to perform additional checks on Op to reject pattern match.
137144
This is useful to express intent on patterns that should not be matched.
138145
Since GraphSearcher performs only high-level pattern matching, this lets callers override it and aggressively reject a match for a given op config.
146+
:param disjoint: If True, ensures that matched patterns do not share any ops.
139147
"""
140148

141149
if ignore is None:
@@ -151,19 +159,29 @@ def find_all_patterns_in_graph_apply_actions(
151159
matched_ops = self._match_pattern(
152160
op, pattern_type.pattern, ignore, op_pattern_to_reject
153161
)
154-
if matched_ops:
155-
for matched_ops_list in matched_ops:
156-
pattern_type.action(pattern_type, matched_ops_list)
157-
logger.debug("found match: %s", matched_ops_list)
162+
if not matched_ops:
163+
continue
164+
for matched_ops_list in matched_ops:
165+
if disjoint and any(
166+
matched_op in self._already_matched
167+
for matched_op in matched_ops_list
168+
):
169+
# At least one op in this match was already claimed by an earlier match; skip it to keep matches disjoint
170+
continue
171+
else:
172+
self._already_matched |= set(matched_ops_list)
173+
174+
pattern_type.action(pattern_type, matched_ops_list)
175+
logger.debug("found match: %s", matched_ops_list)
158176

159177
# pylint: disable=too-many-branches, too-many-return-statements
160178
def _match_pattern(
161179
self,
162-
op,
163-
pattern,
164-
ignored_ops,
165-
op_pattern_to_reject: Callable[[Op], bool] = None,
166-
):
180+
op: Op,
181+
pattern: List[str],
182+
ignored_ops: List[Op],
183+
op_pattern_to_reject: Optional[Callable[[Op], bool]] = None,
184+
) -> Optional[List[List[Op]]]:
167185
if not pattern:
168186
return []
169187

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/graph_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class GraphPass:
5151
"""
5252

5353
@abstractmethod
54-
def match_pattern(self, op: Op, model: ModelProto):
54+
def match_pattern(self, op: Op, model: ModelProto) -> List[Op]:
5555
"""
5656
Pattern match and collect ops starting from given Op.
5757
"""
@@ -126,7 +126,7 @@ def __init__(self):
126126
self.disable_quantizers: List[str] = []
127127

128128
@abstractmethod
129-
def match_pattern(self, op: Op, model: ModelProto):
129+
def match_pattern(self, op: Op, model: ModelProto) -> List[Op]:
130130
"""
131131
Pattern match and collect ops starting from given Op.
132132
"""

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/pass_registry.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from aimet_onnx.meta.connectedgraph import ConnectedGraph
4242
from aimet_onnx.qc_quantize_op import QcQuantizeOp
4343
from aimet_onnx.utils import ModelProto
44+
from ..meta.operations import Op
4445

4546

4647
class PassRegistry:
@@ -126,3 +127,33 @@ def apply_graph_passes(
126127
PASS_REGISTRY[p](model, connected_graph, op_to_quantizers)
127128
else:
128129
raise ValueError(f"Graph pass requested but not found: {p}")
130+
131+
132+
def find_all_matches(
133+
model: ModelProto,
134+
connected_graph: ConnectedGraph,
135+
passes_to_run: List[str],
136+
) -> List[List[Op]]:
137+
"""
138+
Collects the ops matched by each of the given graph passes on the input ConnectedGraph
139+
140+
Args:
141+
connected_graph (ConnectedGraph): Input graph to run graph passes on
142+
passes_to_run (List[str]): List of graph passes to run.
143+
144+
Raises:
145+
ValueError: If a requested GraphPass does not exist.
146+
"""
147+
matches = []
148+
149+
for p in passes_to_run:
150+
for op in connected_graph.ordered_ops:
151+
if p in PASS_REGISTRY:
152+
graph_pass = PASS_REGISTRY[p]
153+
match = graph_pass.match_pattern(op, model)
154+
if match:
155+
matches.append(match)
156+
else:
157+
raise ValueError(f"Graph pass requested but not found: {p}")
158+
159+
return matches

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/passes/common_patterns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# pylint: disable=missing-docstring
55

6+
from typing import List
67
from aimet_common.connected_graph.operation import Op
78
from aimet_onnx.utils import ModelProto
89

@@ -13,7 +14,7 @@
1314
)
1415

1516

16-
def match_rms_norm_pattern(op: Op, model: ModelProto):
17+
def match_rms_norm_pattern(op: Op, model: ModelProto) -> List[Op]:
1718
"""Common pattern for RMSNormalization which can be re-used"""
1819
# Match Mul(x, x) or Pow(x, 2)
1920
match = match_pow_2_pattern(op, model)

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/passes/decoder_block.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ def apply_on_op(self, op: Op, model: ModelProto, _: Dict[str, QcQuantizeOp]):
3838
self.pattern_last_op = None
3939

4040
# pylint: disable=too-many-branches, too-many-return-statements
41-
def match_pattern(self, op: Op, model: ModelProto):
41+
def match_pattern(self, op: Op, model: ModelProto) -> List[Op]:
4242
"""
4343
Match RMSNorm pattern and collect ops to disable output quantizers
4444
"""
4545
all_ops = match_rms_norm_pattern(op, model)
4646
if not all_ops:
47-
return False
47+
return []
4848

4949
# Check if weights are present
5050
self.pattern_last_op = all_ops[0]
5151

52-
return True
52+
return all_ops

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/passes/layernorm.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
# =============================================================================
3737
# pylint: disable=missing-module-docstring
3838

39+
from typing import List
3940
from aimet_common.connected_graph.operation import Op
4041
from aimet_onnx.graph_passes.graph_pass import SupergroupGraphPass
4142
from aimet_onnx.graph_passes.pass_registry import register_pass
@@ -80,7 +81,7 @@ class LayerNormalization(SupergroupGraphPass):
8081
"""
8182

8283
# pylint: disable=too-many-branches, too-many-return-statements
83-
def match_pattern(self, op: Op, _: ModelProto):
84+
def match_pattern(self, op: Op, _: ModelProto) -> List[Op]:
8485
"""
8586
Match LayerNormalization pattern and collect ops to disable output quantizers
8687
"""
@@ -94,26 +95,26 @@ def match_pattern(self, op: Op, _: ModelProto):
9495
or len(sub_1.output_ops) != 2
9596
or sub_1.inputs[0] != op.inputs[0]
9697
):
97-
return False
98+
return []
9899

99100
pow_1 = get_op_from_outputs(sub_1, "Pow")
100101
div_1 = get_op_from_outputs(sub_1, "Div")
101102
if pow_1 is None or div_1 is None:
102-
return False
103+
return []
103104

104105
# Sqrt(Var(x) + ε)
105106
match, denominator_ops = check_consecutive_ops(
106107
pow_1, ["Pow", "ReduceMean", "Add", "Sqrt"]
107108
)
108109
if not match:
109-
return False
110+
return []
110111

111112
# (x - E(x)) / Sqrt(Var(x) + ε)
112113
if (
113114
div_1.inputs[0].producer != sub_1
114115
or div_1.inputs[1].producer != denominator_ops[-1]
115116
):
116-
return False
117+
return []
117118

118119
# Collect quantizers to disable.
119120
all_ops = [op, sub_1] + denominator_ops + [div_1]
@@ -124,8 +125,10 @@ def match_pattern(self, op: Op, _: ModelProto):
124125
# Check if affine_transform is set.
125126
# (x - E(x)) / Sqrt(Var(x) + ε) * γ
126127
match, div_mul_ops = check_consecutive_ops(div_1, ["Div", "Mul"])
127-
if not match:
128-
return True
128+
if match:
129+
all_ops += div_mul_ops[1:]
130+
else:
131+
return all_ops
129132

130133
# NOTE: keep weights quantized
131134
self.disable_output_quantizers(op_list=[div_1])
@@ -135,10 +138,12 @@ def match_pattern(self, op: Op, _: ModelProto):
135138
match, mul_add_ops = check_consecutive_ops(
136139
div_mul_ops[-1], ["Mul", "Add"], validate_last_op_consumers=False
137140
)
138-
if not match:
139-
return True
141+
if match:
142+
all_ops += mul_add_ops[1:]
143+
else:
144+
return all_ops
140145

141146
# NOTE: skip bias quantization
142147
self.disable_output_quantizers(op_list=mul_add_ops[:1])
143148
self.disable_const_quantizers(op_list=mul_add_ops[-1:])
144-
return True
149+
return all_ops

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/passes/matmul_add.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ def match_pattern(self, op: Op, model: ModelProto):
3030
Match MatMul + Add (bias) pattern and collect the matched ops
3131
"""
3232
if op.type != "MatMul":
33-
return False
33+
return []
3434

3535
if _get_matmul_add_bias_idx(op, model) is None:
36-
return False
36+
return []
3737

38-
return True
38+
matmul_op: Op = op
39+
add_op: Op = op.output_ops[0]
40+
41+
return [matmul_op, add_op]
3942

4043
def apply_on_op(
4144
self, op: Op, model: ModelProto, op_quantizers: Dict[str, QcQuantizeOp]

TrainingExtensions/onnx/src/python/aimet_onnx/graph_passes/passes/rmsnorm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33
# pylint: disable=missing-module-docstring
44

5+
from typing import List
56
from aimet_common.connected_graph.operation import Op
67
from aimet_onnx.graph_passes.graph_pass import SupergroupGraphPass
78
from aimet_onnx.graph_passes.pass_registry import register_pass
@@ -55,13 +56,13 @@ class RMSNormalization(SupergroupGraphPass):
5556
"""
5657

5758
# pylint: disable=too-many-branches, too-many-return-statements
58-
def match_pattern(self, op: Op, model: ModelProto):
59+
def match_pattern(self, op: Op, model: ModelProto) -> List[Op]:
5960
"""
6061
Match RMSNormalization pattern and collect ops to disable output quantizers
6162
"""
6263
all_ops = match_rms_norm_pattern(op, model)
6364
if not all_ops:
64-
return False
65+
return []
6566

6667
# Check if weights are present
6768
elementwise_affine = False
@@ -74,4 +75,4 @@ def match_pattern(self, op: Op, model: ModelProto):
7475
self.disable_output_quantizers(all_ops[:-1])
7576
# Disable all constant quantizers except weights
7677
self.disable_const_quantizers(all_ops[:-1] if elementwise_affine else all_ops)
77-
return True
78+
return all_ops

TrainingExtensions/onnx/src/python/aimet_onnx/quantsim_config/quantsim_config.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from aimet_onnx.meta.connectedgraph import ConnectedGraph, CONSTANT_TYPE
7373
from aimet_onnx.utils import get_product_name_from_quantized_name
7474
from aimet_onnx.qc_quantize_op import OpMode, QcQuantizeOp
75-
from aimet_onnx.graph_passes.pass_registry import apply_graph_passes
75+
from aimet_onnx.graph_passes.pass_registry import apply_graph_passes, find_all_matches
7676

7777
# pylint: disable=no-name-in-module, ungrouped-imports
7878
if version.parse(onnx.__version__) >= version.parse("1.14.0"):
@@ -337,6 +337,15 @@ def _set_supergroup_configs(self, supergroups_configs: List[SupergroupType]):
337337
Set supergroup specific configurations (fourth level of specificity in configuration file)
338338
:param supergroups_configs: Configurations for supergroups
339339
"""
340+
matched_by_graph_pass = find_all_matches(
341+
self._model.model,
342+
self._conn_graph,
343+
self._get_supergroup_pass_list(),
344+
)
345+
matched_by_graph_pass = set(
346+
node_name for match in matched_by_graph_pass for node_name in match
347+
)
348+
340349
patterns_with_callbacks = []
341350
for supergroup_config in supergroups_configs:
342351
callback = SupergroupConfigCallback(self._model, self._op_to_quantizers)
@@ -347,10 +356,19 @@ def _set_supergroup_configs(self, supergroups_configs: List[SupergroupType]):
347356
for pattern in patterns:
348357
patterns_with_callbacks.append(pattern)
349358

359+
def exclude_from_supergroup(op: NodeProto) -> bool:
360+
return (
361+
_check_if_conv3d_or_depthwise_conv(op)
362+
# Don't apply "supergroups" config if the op will be handled by graph passes
363+
# since "supergroup_pass_list" is a higher priority than "supergroups"
364+
or op in matched_by_graph_pass
365+
)
366+
350367
if patterns_with_callbacks:
351368
graph_searcher = GraphSearcher(self._conn_graph, patterns_with_callbacks)
352369
graph_searcher.find_all_patterns_in_graph_apply_actions(
353-
op_pattern_to_reject=_check_if_conv3d_or_depthwise_conv
370+
op_pattern_to_reject=exclude_from_supergroup,
371+
disjoint=True,
354372
)
355373

356374
def _set_model_input_configs(self, model_input_configs: ConfigType):

TrainingExtensions/onnx/test/python/test_graph_passes/test_graph_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def match_pattern(self, op: Op, _: ModelProto):
8383
self.disable_quantizers = get_const_input_names(
8484
op_list=[op]
8585
) + get_output_names(op_list=[op])
86-
return True
86+
return [op]
8787

8888

8989
def test_register_and_apply_graph_pass():

0 commit comments

Comments
 (0)