Skip to content

Support common subexpression elimination pass (CSE) #2304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
May 30, 2025
Merged
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still faling
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
@@ -373,6 +373,5 @@ def replace_nodes_and_values(
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

# insert new nodes after the index node
if new_nodes:
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.remove(old_nodes, safe=True)
45 changes: 40 additions & 5 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
]

import logging
from typing import Sequence

from onnxscript import ir

@@ -88,16 +89,50 @@
# We replace the node with the existing node.
modified = True
existing_node = existing_node_info_to_the_node[node_info]
ir.convenience.replace_nodes_and_values(
_remove_node_and_replace__values(
graph,
insertion_point=node,
old_nodes=[node],
new_nodes=[], # Delete the duplicate node.
old_values=node.outputs,
remove_nodes=[node],
remove_values=node.outputs,
new_values=existing_node.outputs,
)
logger.debug("Reusing node %s", existing_node)
else:
# If it is not, add to the mapping.
existing_node_info_to_the_node[node_info] = node
return modified


def _remove_node_and_replace__values(
graph: ir.Graph,
/,
remove_nodes: ir.Node,
remove_values: Sequence[ir.Value],
new_values: Sequence[ir.Value],
) -> None:
"""Replaces nodes and values in the graph or function.

Args:
graph: The graph to replace nodes and values in.
remove_nodes: The nodes to remove.
remove_values: The values to replace.
new_values: The values to replace with.
"""

for old_value, new_value in zip(remove_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider
# merging old and new type/shape info.
new_value.type = old_value.type
new_value.shape = old_value.shape
new_value.const_value = old_value.const_value
new_value.name = old_value.name

# Reconnect the users of the deleted values to use the new values
ir.convenience.replace_all_uses_with(remove_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(remove_values, new_values))
for idx, graph_or_function_output in enumerate(graph.outputs):
if graph_or_function_output in replacement_mapping:
graph.outputs[idx] = replacement_mapping[graph_or_function_output]

Check warning on line 136 in onnxscript/ir/passes/common/common_subexpression_elimination.py

Codecov / codecov/patch

onnxscript/ir/passes/common/common_subexpression_elimination.py#L136

Added line #L136 was not covered by tests

graph.remove(remove_nodes, safe=True)
Loading
Oops, something went wrong.