Skip to content

Support common subexpression elimination pass (CSE) #2304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
May 30, 2025
Merged
Changes from 7 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still faling
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
@@ -2788,7 +2788,7 @@ def __init__(
model_version: int | None = None,
doc_string: str | None = None,
functions: Sequence[Function] = (),
meta_data_props: dict[str, str] | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
self.graph: Graph = graph
self.ir_version = ir_version
@@ -2799,7 +2799,7 @@ def __init__(
self.doc_string = doc_string
self._functions = {func.identifier(): func for func in functions}
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = meta_data_props
self._metadata_props: dict[str, str] | None = metadata_props

@property
def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
84 changes: 84 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
"CommonSubexpressionEliminationPass",
]

import logging

from onnxscript import ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that removes duplicated computations from an ONNX model."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE starting from ``model.graph`` and report the outcome.

        Args:
            model: The model to optimize. The same object is returned inside
                the result.

        Returns:
            A ``PassResult`` wrapping ``model`` whose ``modified`` flag is the
            value returned by ``_common_subexpression_elimination``.
        """
        changed = _common_subexpression_elimination(model.graph, False)
        return ir.passes.PassResult(model, modified=changed)


def _common_subexpression_elimination(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpressions in ``graph`` and its subgraphs, in place.

    Two nodes are considered the same subexpression when they share the
    operator identifier, the number of outputs, the very same input value
    objects (compared by ``id``) and equal attributes. The later node is
    removed and the uses of its outputs are redirected to the outputs of the
    first such node via ``ir.convenience.replace_nodes_and_values``.

    Args:
        graph: The graph to optimize. It is mutated in place.
        modified: Whether an enclosing call has already changed the model.

    Returns:
        ``True`` if ``modified`` was already true, or if at least one node was
        removed from ``graph`` or from a graph held in a node attribute.
    """
    # Attribute types whose value is a sequence of plain scalars. The sequence
    # may be a list (unhashable), so it is frozen into a tuple before it
    # becomes part of a dictionary key.
    repeated_scalar_types = (
        ir.AttributeType.INTS,
        ir.AttributeType.FLOATS,
        ir.AttributeType.STRINGS,
    )

    # Signature of every node kept so far -> that node.
    existing_node_info_to_the_node: dict[
        tuple[
            tuple[str, str, str],  # op_identifier
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        attributes: dict[str, object] = {}
        comparable = True
        for attr_name, attr in node.attributes.items():
            if not isinstance(attr, ir.Attr):
                # Not a concrete attribute, so there is no value to compare.
                # Keep the node instead of failing on an assert that would be
                # stripped under ``python -O`` anyway.
                comparable = False
                continue
            attr_value = attr.value
            if isinstance(attr_value, ir.Graph):
                modified = _common_subexpression_elimination(attr_value, modified)
            elif attr.type == ir.AttributeType.GRAPHS:
                for subgraph in attr_value:
                    if isinstance(subgraph, ir.Graph):
                        modified = _common_subexpression_elimination(subgraph, modified)
            elif attr.type in repeated_scalar_types:
                attr_value = tuple(attr_value)
            # The attribute type is part of the key so that equal-looking
            # values of different kinds are never conflated.
            attributes[attr_name] = (attr.type, attr_value)

        if not comparable:
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(value) for value in node.inputs),
            tuple(sorted(attributes.items())),
        )
        try:
            existing_node = existing_node_info_to_the_node.get(node_info)
        except TypeError:
            # Some attribute value is still unhashable; such a node cannot be
            # matched against others, so leave it untouched.
            logger.debug("Skipping node %s: attributes are not hashable", node.name)
            continue

        if existing_node is None:
            # First occurrence of this expression: remember it.
            existing_node_info_to_the_node[node_info] = node
            continue

        # Duplicate expression: reroute the users of ``node`` to
        # ``existing_node`` and drop ``node``. ``new_nodes`` is deliberately
        # empty: ``existing_node`` is already part of the graph, and inserting
        # it again would relocate it after nodes that may consume its outputs.
        modified = True
        ir.convenience.replace_nodes_and_values(
            graph,
            insertion_point=node,
            old_nodes=[node],
            new_nodes=[],
            old_values=node.outputs,
            new_values=existing_node.outputs,
        )
        logger.debug("Reusing node %s", existing_node.name)
    return modified
237 changes: 237 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnxruntime as ort

from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op
from onnxscript.ir.passes.common import common_subexpression_elimination


class TestCommonSubexpressionEliminationPass(unittest.TestCase):
    """End-to-end tests for ``CommonSubexpressionEliminationPass``."""

    def check_graph(self, model: ir.Model, inputs: list[np.ndarray], delta_nodes: list[int]):
        """Run the CSE pass on ``model`` and verify node counts and outputs.

        Args:
            model: The model to optimize; the pass mutates it in place.
            inputs: One array per graph input, in the order of
                ``model.graph.inputs``. They are cast to float32 and fed to
                both the original and the optimized model.
            delta_nodes: For every graph yielded by ``model.graphs()``, the
                number of nodes the pass is expected to remove from it.
        """
        self.assertEqual(len(list(model.graphs())), len(delta_nodes))
        # zip() would silently drop unmatched entries, so check up front.
        self.assertEqual(len(model.graph.inputs), len(inputs))

        # Record the state of the original model.
        # 1. model graph node counts
        original_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
        model_proto = ir.serde.serialize_model(model)

        # 2. model outputs
        ort_inputs = {
            graph_input.name: np.asarray(array, dtype=np.float32)
            for graph_input, array in zip(model.graph.inputs, inputs)
        }
        original_model_session = ort.InferenceSession(model_proto.SerializeToString())
        original_model_results = original_model_session.run(None, ort_inputs)

        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)

        result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])

        # Check that every graph lost exactly the expected number of nodes.
        np.testing.assert_array_equal(
            original_graphs_node_count,
            np.add(result_graphs_node_count, delta_nodes),
        )
        self.assertEqual(
            result.modified,
            bool(np.any(original_graphs_node_count > result_graphs_node_count)),
        )

        result_proto = ir.serde.serialize_model(result.model)
        result_session = ort.InferenceSession(result_proto.SerializeToString())
        result_results = result_session.run(None, ort_inputs)

        # Check that both models produce the same outputs for the same inputs.
        self.assertEqual(len(original_model_results), len(result_results))
        for original_model_result, result_result in zip(
            original_model_results, result_results
        ):
            np.testing.assert_allclose(
                original_model_result, result_result, rtol=1e-5, atol=1e-5
            )

    def test_two_branches_with_the_same_operations_is_csed(self):
        """Test if two branches with the same operations are CSEd.

        Equivalent PyTorch-style program:

            def f(x):
                a = x.cos()
                b = x.cos()
                c = a + a
                d = b + b
                return c + d
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.Cos(x)
            b = op.Cos(x)
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)

        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])

    def test_more_operations_in_two_branches_with_the_same_operations_is_csed(self):
        """Test if two longer branches with the same operations are CSEd.

        Equivalent PyTorch-style program:

            def f(x):
                a = x.cos().sin()
                b = x.cos().sin()
                c = a + a
                d = b + b
                return c + d
        """

        @script()
        def test_model(x: FLOAT[1]) -> FLOAT[1]:
            a = op.Sin(op.Cos(x))
            b = op.Sin(op.Cos(x))
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(1)], delta_nodes=[3])

    def test_multiple_same_ops_with_attributes_are_csed(self):
        """Test if multiple identical ops with equal attributes are CSEd.

        Equivalent PyTorch-style program:

            def f(x):
                a = x.sum()
                b = x.sum()
                c = x.sum()
                d = x.sum()
                return a + b + c + d
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=False)
            c = op.ReduceSum(x, keepdims=False)
            d = op.ReduceSum(x, keepdims=False)
            return a + b + c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[3])

    def test_the_ops_with_the_same_inputs_but_different_attributes_are_not_csed(self):
        """Test if ops with the same inputs but different attributes are not CSEd.

        Equivalent PyTorch-style program:

            def f(x):
                a = x.sum()
                b = x.sum(keepdims=True)
                return a + b
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=True)
            return a + b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])

    def test_control_flow_if_ops_are_not_csed_as_graph_attr_is_not_matched(self):
        """Test if control flow ops are not CSEd.

        Equivalent PyTorch-style program:

            def f(a, b):
                rank = a.rank()
                if rank == 2:
                    result1 = a - b
                else:
                    result1 = a + b
                if rank == 2:
                    result2 = a - b
                else:
                    result2 = a + b
                return result1 + result2
        """

        @script()
        def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(a))
            if rank == 2:
                result1 = a - b
            else:
                result1 = a + b
            if rank == 2:
                result2 = a - b
            else:
                result2 = a + b
            return result1 + result2

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(
            model,
            [np.random.rand(2, 2), np.random.rand(2, 2)],
            delta_nodes=[0, 0, 0, 0, 0],
        )

    def test_subgraph_is_csed(self):
        """Test if duplicated nodes inside control flow subgraphs are CSEd.

        Equivalent PyTorch-style program:

            def f(x):
                rank = x.rank()
                if rank == 2:
                    a = x.cos()
                    b = x.cos()
                    c = a + a
                    d = b + b
                else:
                    a = x.sin()
                    b = x.sin()
                    c = a + a
                    d = b + b
                return c + d
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(x))
            if rank == 2:
                a = op.Cos(x)
                b = op.Cos(x)
                c = a + a
                d = b + b
            else:
                a = op.Sin(x)
                b = op.Sin(x)
                c = a + a
                d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0, 2, 2])
2 changes: 1 addition & 1 deletion onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
@@ -511,7 +511,7 @@ def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
model_version=_get_field(proto, "model_version"),
doc_string=_get_field(proto, "doc_string"),
functions=functions,
meta_data_props=deserialize_metadata_props(proto.metadata_props),
metadata_props=deserialize_metadata_props(proto.metadata_props),
)

# Handle experimental value info for functions created by the dynamo exporter in IR version 9
Loading
Oops, something went wrong.