Skip to content

Commit 3fd79be

Browse files
titaiwangmsCopilot
authored and committed
Support common subexpression elimination pass (CSE) (microsoft#2304)
Fix microsoft#2105. For the logic, this PR follows https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/dialect/common/cse_pass.py. Essentially, this PR traverses the original graph and checks whether each value or node is a duplicate. If it is not, the value or node is saved in the mappings and added to the new graph. If it is a duplicate, the value or node is replaced with the previously mapped/saved value or node. (FunctionalPass) CSE of subgraphs is not supported: microsoft#2345. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 19b7f6a commit 3fd79be

File tree

3 files changed

+461
-0
lines changed

3 files changed

+461
-0
lines changed

onnxscript/ir/passes/common/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"AddInitializersToInputsPass",
66
"CheckerPass",
77
"ClearMetadataAndDocStringPass",
8+
"CommonSubexpressionEliminationPass",
89
"InlinePass",
910
"LiftConstantsToInitializersPass",
1011
"LiftSubgraphInitializersToMainGraphPass",
@@ -30,3 +31,7 @@
3031
ShapeInferencePass,
3132
TopologicalSortPass,
3233
)
34+
35+
from onnxscript.ir.passes.common.common_subexpression_elimination import (
36+
CommonSubexpressionEliminationPass,
37+
)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Eliminate common subexpression in ONNX graphs."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"CommonSubexpressionEliminationPass",
9+
]
10+
11+
import logging
12+
from typing import Sequence
13+
14+
from onnxscript import ir
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that removes duplicated computations from an ONNX graph."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE to ``model.graph`` and report whether anything changed.

        Args:
            model: The model whose main graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``model`` object; its
            ``modified`` flag is whatever the elimination helper returns.
        """
        # The helper takes the running "modified" flag as input; start from False.
        changed = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=changed)
33+
34+
35+
# Standard-domain operators that draw random numbers. Two such nodes with
# identical inputs and attributes are still independent draws, so merging them
# would change the model's semantics.
_NON_DETERMINISTIC_OP_TYPES = frozenset(
    {
        "RandomUniform",
        "RandomNormal",
        "RandomUniformLike",
        "RandomNormalLike",
        "Multinomial",
        "Bernoulli",
    }
)


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpressions among the nodes of ``graph``.

    Nodes are visited in graph order. A node is treated as a duplicate of an
    earlier node when both have the same operator identifier, the same number
    of outputs, the very same input value objects (compared by ``id``) and
    equal attributes. Each duplicate is removed and its outputs are redirected
    to the outputs of the first matching node.

    Nodes that are never merged:
      * nodes carrying a GRAPH/GRAPHS attribute (control flow such as ``If``
        and ``Loop``), because subgraph comparison is not implemented;
      * random-number generating operators of the default ONNX domain, because
        separate nodes represent separate draws.

    Args:
        graph: The graph to rewrite in place.
        modified: Running "was anything changed" flag supplied by the caller.

    Returns:
        ``True`` if at least one node was eliminated, otherwise the incoming
        value of ``modified``.
    """
    # Maps (operator identifier, number of outputs, input ids, attributes)
    # to the first node seen with that signature.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        if (
            node.domain in ("", "ai.onnx")
            and node.op_type in _NON_DETERMINISTIC_OP_TYPES
        ):
            # Merging random generators would make independent samples identical.
            logger.debug("Skipping non-deterministic op %s", node)
            continue

        # Skip control flow ops like Loop and If.
        control_flow_op = False
        # Attribute values are compared by equality, so they must be hashable.
        attributes = {}
        for k, v in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If won't be CSEd
            # because attribute: graph won't match.
            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                control_flow_op = True
                logger.debug("Skipping control flow op %s", node)
                # One graph attribute is enough to rule the node out; stopping
                # here also avoids logging once per graph attribute.
                break
            # The attribute value could be directly taken from the original
            # protobuf, so we need to make a copy of it.
            value = v.value
            if v.type in (
                ir.AttributeType.INTS,
                ir.AttributeType.FLOATS,
                ir.AttributeType.STRINGS,
            ):
                # Sequence-valued attributes are converted to tuples so that
                # they are hashable.
                value = tuple(value)
            attributes[k] = value

        if control_flow_op:
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(input_value) for input_value in node.inputs),
            # Sorted so that attribute declaration order does not matter.
            tuple(sorted(attributes.items())),
        )
        existing_node = existing_node_info_to_the_node.get(node_info)
        if existing_node is None:
            # First node with this signature: remember it for later lookups.
            existing_node_info_to_the_node[node_info] = node
            continue

        # An earlier node has the same operator, number of outputs, inputs
        # and attributes: drop this node and reuse the earlier node's outputs.
        modified = True
        _remove_node_and_replace_values(
            graph,
            remove_node=node,
            remove_values=node.outputs,
            new_values=existing_node.outputs,
        )
        logger.debug("Reusing node %s", existing_node)
    return modified
102+
103+
104+
def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Redirect users of ``remove_values`` to ``new_values`` and drop a node.

    Args:
        graph: The graph (or function) that owns ``remove_node``.
        remove_node: The node to remove once its outputs are no longer used.
        remove_values: The values being retired, paired by position with
            ``new_values``.
        new_values: The values that take over for ``remove_values``.
    """
    # Point every consumer of the retired values at their replacements.
    ir.convenience.replace_all_uses_with(remove_values, new_values)

    # Graph outputs are handled separately: each slot holding a retired value
    # is reassigned below.
    if any(old_value.is_graph_output() for old_value in remove_values):
        old_to_new = dict(zip(remove_values, new_values))
        for position, current_output in enumerate(graph.outputs):
            if current_output not in old_to_new:
                continue

            substitute = old_to_new[current_output]
            if not substitute.is_graph_output():
                # The replacement is not yet a graph output, so it can simply
                # take over the retired output's name and slot.
                substitute.name = current_output.name
                graph.outputs[position] = substitute
                continue

            # The replacement already is a graph output. Emit an Identity node
            # whose output reuses the retired output's name, type and shape.
            preserved_output = ir.Value(
                name=current_output.name,
                type=current_output.type,
                shape=current_output.shape,
            )
            identity_node = ir.node(
                "Identity",
                inputs=[substitute],
                outputs=[preserved_output],
            )
            graph.outputs[position] = identity_node.outputs[0]
            graph.insert_before(remove_node, identity_node)

    graph.remove(remove_node, safe=True)

0 commit comments

Comments
 (0)