
Support common subexpression elimination pass (CSE) #2304


Merged
merged 23 commits into from
May 30, 2025
Commits
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still failing
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
5 changes: 5 additions & 0 deletions onnxscript/ir/passes/common/__init__.py
@@ -5,6 +5,7 @@
    "AddInitializersToInputsPass",
    "CheckerPass",
    "ClearMetadataAndDocStringPass",
    "CommonSubexpressionEliminationPass",
    "InlinePass",
    "LiftConstantsToInitializersPass",
    "LiftSubgraphInitializersToMainGraphPass",
@@ -30,3 +31,7 @@
    ShapeInferencePass,
    TopologicalSortPass,
)

from onnxscript.ir.passes.common.common_subexpression_elimination import (
    CommonSubexpressionEliminationPass,
)
153 changes: 153 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
@@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
    "CommonSubexpressionEliminationPass",
]

import logging
from typing import Sequence

from onnxscript import ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """Eliminate common subexpression in ONNX graphs."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Return the same ir.Model but with CSE applied to the graph."""
        modified = False
        graph = model.graph

        modified = _eliminate_common_subexpression(graph, modified)

        return ir.passes.PassResult(
            model,
            modified=modified,
        )


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpression in ONNX graphs."""

    # Map node info (operator identifier, number of outputs, input ids, attributes) to the node.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        # Skip control flow ops like Loop and If.
        control_flow_op: bool = False
        # Use equality to check if the node is a common subexpression.
        attributes = {}
        for k, v in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If won't be CSEd
            # because attribute: graph won't match.
            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                control_flow_op = True
                logger.debug("Skipping control flow op %s", node)
            # The attribute value could be directly taken from the original
            # protobuf, so we need to make a copy of it.
            value = v.value
            if v.type in (
                ir.AttributeType.INTS,
                ir.AttributeType.FLOATS,
                ir.AttributeType.STRINGS,
            ):
                # For INTS, FLOATS and STRINGS attributes, we convert them to tuples
                # to ensure they are hashable.
                value = tuple(value)

Codecov / codecov/patch warning: added line #L71 (onnxscript/ir/passes/common/common_subexpression_elimination.py#L71) was not covered by tests.
            attributes[k] = value

        if control_flow_op:
            # If the node is a control flow op, we skip it.
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(input) for input in node.inputs),
            tuple(sorted(attributes.items())),
        )
        # Check if the node is a common subexpression.
        if node_info in existing_node_info_to_the_node:
            # If it is, this node has an existing node with the same
            # operator, number of outputs, inputs, and attributes.
            # We replace the node with the existing node.
            modified = True
            existing_node = existing_node_info_to_the_node[node_info]
            _remove_node_and_replace_values(
                graph,
                remove_node=node,
                remove_values=node.outputs,
                new_values=existing_node.outputs,
            )
            logger.debug("Reusing node %s", existing_node)
        else:
            # If it is not, add to the mapping.
            existing_node_info_to_the_node[node_info] = node
    return modified
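
Note: the keying scheme above can be tried out without onnxscript. What follows is a minimal sketch under stated assumptions, not the PR's implementation: `ToyValue`, `ToyNode`, `_hashable`, and `toy_cse` are hypothetical stand-ins for `ir.Value`, `ir.Node`, the tuple conversion, and the pass. It assumes nodes arrive in topological order and ignores graph outputs and subgraph attributes.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based equality and hashing
class ToyValue:
    name: str


@dataclass(eq=False)
class ToyNode:
    op: str
    inputs: list[ToyValue]
    attributes: dict[str, object] = field(default_factory=dict)
    outputs: list[ToyValue] = field(default_factory=list)


def _hashable(value: object) -> object:
    # Lists cannot be part of a dict key, so convert them to tuples first.
    return tuple(value) if isinstance(value, list) else value


def toy_cse(nodes: list[ToyNode]) -> list[ToyNode]:
    """Return the surviving nodes; users of a duplicate are rewired to the first match."""
    seen: dict[tuple, ToyNode] = {}
    replaced: dict[ToyValue, ToyValue] = {}  # removed value -> surviving value
    kept: list[ToyNode] = []
    for node in nodes:
        # Redirect inputs that pointed at outputs of already-removed nodes.
        node.inputs = [replaced.get(v, v) for v in node.inputs]
        key = (
            node.op,
            len(node.outputs),
            tuple(id(v) for v in node.inputs),
            tuple(sorted((k, _hashable(v)) for k, v in node.attributes.items())),
        )
        existing = seen.get(key)
        if existing is None:
            seen[key] = node
            kept.append(node)
        else:
            for old, new in zip(node.outputs, existing.outputs):
                replaced[old] = new
    return kept


x = ToyValue("x")
a = ToyNode("Relu", [x], outputs=[ToyValue("a")])
b = ToyNode("Relu", [x], outputs=[ToyValue("b")])  # same op, input, and attributes as a
c = ToyNode("Add", [a.outputs[0], b.outputs[0]], outputs=[ToyValue("c")])
survivors = toy_cse([a, b, c])
```

Keying on input identity rather than input names means two nodes only match when they consume the very same value objects, which is why the sketch rewires inputs before building each key.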


def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    Args:
        graph: The graph to replace nodes and values in.
        remove_node: The node to remove.
        remove_values: The values to replace.
        new_values: The values to replace with.
    """
    # Reconnect the users of the deleted values to use the new values
    ir.convenience.replace_all_uses_with(remove_values, new_values)
    # Update graph/function outputs if the node generates output
    if any(remove_value.is_graph_output() for remove_value in remove_values):
        replacement_mapping = dict(zip(remove_values, new_values))
        for idx, graph_output in enumerate(graph.outputs):
            if graph_output in replacement_mapping:
                new_value = replacement_mapping[graph_output]
                if new_value.is_graph_output():
Collaborator

We can't rename it if it is also a graph input. We could change this condition to new_value.is_graph_output() or new_value.is_graph_input() ... I think it won't show up in CSE, but it will be necessary if we move this into an IR utility used in other scenarios.

Contributor Author

Is there a case in which we would CSE graph inputs? In this PR, we are deleting nodes and affecting their outputs.

Contributor Author

I see what you mean.

Contributor Author

@titaiwangms titaiwangms May 30, 2025

I will add that in onnx/ir-py#36.

Contributor Author

@gramalingam I don't see how new_value here could be a graph input. It has to come from a node output. Does that sound right?

                    # If the new value is also a graph output, we need to
                    # create an Identity node to preserve the remove_value.
                    identity_node = ir.node(
                        "Identity",
                        inputs=[new_value],
                        outputs=[
                            ir.Value(
                                name=graph_output.name,
                                type=graph_output.type,
                                shape=graph_output.shape,
                            )
                        ],
                    )
                    # reuse the name of the graph output
                    graph.outputs[idx] = identity_node.outputs[0]
                    graph.insert_before(
                        remove_node,
                        identity_node,
                    )
                else:
                    # If new_value is not a graph output, we just
                    # update it to use the name of the removed graph output.
                    new_value.name = graph_output.name
                    graph.outputs[idx] = new_value

    graph.remove(remove_node, safe=True)
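
Note: the two branches of the output handling in `_remove_node_and_replace_values` can be summarized as a small decision function. This is a hedged sketch that uses plain strings for value names instead of `ir.Value` objects; `plan_output_fix` and its return format are hypothetical and are not part of onnxscript.

```python
from __future__ import annotations


def plan_output_fix(
    removed: str, replacement: str, graph_outputs: list[str]
) -> tuple[str, list[str]]:
    """Decide how to keep the graph output name `removed` after CSE.

    Returns the chosen action and any extra nodes to insert, written as strings.
    """
    if removed not in graph_outputs:
        # The removed value was internal, so rewiring its users is enough.
        return "none", []
    if replacement in graph_outputs:
        # Both names must remain graph outputs, so copy through an Identity node.
        return "identity", [f"{removed} = Identity({replacement})"]
    # Only the removed value was an output: the surviving value takes over its name.
    return "rename", []


both_outputs = plan_output_fix("b", "a", ["a", "b"])
only_removed = plan_output_fix("b", "a", ["b"])
internal = plan_output_fix("b", "a", ["c"])
```

The Identity branch exists because a single value cannot carry two output names at once; renaming is cheaper and is used whenever the surviving value is free to take the removed name.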