# Source file: onnxscript/ir/passes/common/common_subexpression_elimination.py
# NOTE: onnxscript/ir/passes/common/__init__.py lists
# "CommonSubexpressionEliminationPass" in its __all__ and imports the class
# from this module.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
    "CommonSubexpressionEliminationPass",
]

import logging
from typing import Sequence

from onnxscript import ir

logger = logging.getLogger(__name__)

# Standard ONNX operators whose outputs can differ between two nodes that have
# identical inputs and attributes. Merging such nodes would change the model's
# semantics, so they are never treated as common subexpressions.
# Dropout is included because it is random when training mode is enabled.
_NON_DETERMINISTIC_OP_TYPES = frozenset(
    {
        "Bernoulli",
        "Dropout",
        "Multinomial",
        "RandomNormal",
        "RandomNormalLike",
        "RandomUniform",
        "RandomUniformLike",
    }
)
# Both spellings of the default ONNX operator domain.
_DEFAULT_ONNX_DOMAINS = frozenset({"", "ai.onnx"})


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """Eliminate common subexpression in ONNX graphs.

    Two nodes of the main graph are considered the same subexpression when
    they share the operator identifier, the number of outputs, the identity of
    every input value and all attribute values. Nodes that carry subgraph
    attributes (control flow such as ``If`` and ``Loop``) and non-deterministic
    operators are left untouched. Subgraphs are not processed.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Return the same ir.Model but with CSE applied to the graph."""
        modified = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(
            model,
            modified=modified,
        )


def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True if ``node`` is a standard ONNX op with random outputs."""
    return (
        node.domain in _DEFAULT_ONNX_DOMAINS
        and node.op_type in _NON_DETERMINISTIC_OP_TYPES
    )


def _comparable_attributes(node: ir.Node) -> tuple[tuple[str, object], ...] | None:
    """Build the attribute part of the CSE key for ``node``.

    Returns:
        The ``(name, value)`` pairs sorted by name, or ``None`` when the node
        has a GRAPH/GRAPHS attribute and therefore must not be CSE'd.
    """
    attributes: dict[str, object] = {}
    for name, attr in node.attributes.items():
        # TODO(exporter team): CSE subgraphs.
        # NOTE: control flow ops like Loop and If won't be CSEd
        # because attribute: graph won't match.
        if attr.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
            logger.debug("Skipping control flow op %s", node)
            return None
        # The attribute value could be directly taken from the original
        # protobuf, so we need to make a copy of it.
        value = attr.value
        if attr.type in (
            ir.AttributeType.INTS,
            ir.AttributeType.FLOATS,
            ir.AttributeType.STRINGS,
        ):
            # Repeated INT/FLOAT/STRING values are converted to tuples so that
            # they are hashable and compare by content.
            value = tuple(value)
        attributes[name] = value
    # Attribute names are unique, so sorting never compares the values.
    return tuple(sorted(attributes.items()))


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpression in ONNX graphs.

    Args:
        graph: The graph to process in place. Only its own nodes are visited.
        modified: The modification flag so far; returned unchanged when no
            node is removed.

    Returns:
        True if at least one node was removed, otherwise the incoming
        ``modified`` value.
    """
    # Maps (operator identifier, number of outputs, input value ids,
    # attributes) to the first node seen with that signature.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        if _is_non_deterministic_op(node):
            # Two random nodes are distinct samples even when they look equal.
            logger.debug("Skipping non-deterministic op %s", node)
            continue

        attributes = _comparable_attributes(node)
        if attributes is None:
            # The node is a control flow op; leave it alone.
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(node_input) for node_input in node.inputs),
            attributes,
        )
        try:
            existing_node = existing_node_info_to_the_node.get(node_info)
        except TypeError:
            # Hashing the key failed because some attribute value is
            # unhashable; such a node cannot be compared, so keep it.
            logger.debug("Skipping node with unhashable attribute values %s", node)
            continue

        if existing_node is None:
            # First occurrence of this signature: remember it.
            existing_node_info_to_the_node[node_info] = node
            continue

        # An earlier node has the same operator, number of outputs, inputs
        # and attributes: reuse its outputs and drop this node.
        modified = True
        _remove_node_and_replace_values(
            graph,
            remove_node=node,
            remove_values=node.outputs,
            new_values=existing_node.outputs,
        )
        logger.debug("Reusing node %s", existing_node)
    return modified


def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    Args:
        graph: The graph to replace nodes and values in.
        remove_node: The node to remove.
        remove_values: The values to replace.
        new_values: The values to replace with.
    """
    # Reconnect the users of the deleted values to use the new values
    ir.convenience.replace_all_uses_with(remove_values, new_values)
    # Update graph/function outputs if the node generates output
    if any(remove_value.is_graph_output() for remove_value in remove_values):
        replacement_mapping = dict(zip(remove_values, new_values))
        for idx, graph_output in enumerate(graph.outputs):
            if graph_output not in replacement_mapping:
                continue
            new_value = replacement_mapping[graph_output]
            if new_value.is_graph_output():
                # The replacement is already a graph output, so it cannot take
                # over this output's name as well. Insert an Identity node
                # whose fresh output value reuses the removed output's name,
                # type and shape.
                identity_node = ir.node(
                    "Identity",
                    inputs=[new_value],
                    outputs=[
                        ir.Value(
                            name=graph_output.name,
                            type=graph_output.type,
                            shape=graph_output.shape,
                        )
                    ],
                )
                graph.outputs[idx] = identity_node.outputs[0]
                graph.insert_before(
                    remove_node,
                    identity_node,
                )
            else:
                # The replacement is not a graph output yet: rename it to the
                # removed output's name and expose it in the same position.
                new_value.name = graph_output.name
                graph.outputs[idx] = new_value

    graph.remove(remove_node, safe=True)
# Source file: onnxscript/ir/passes/common/common_subexpression_elimination_test.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnxruntime as ort

from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op
from onnxscript.ir.passes.common import common_subexpression_elimination


class TestCommonSubexpressionEliminationPass(unittest.TestCase):
    """Tests for ``CommonSubexpressionEliminationPass``."""

    def check_graph(
        self, model: ir.Model, inputs: list[np.ndarray], delta_nodes: list[int]
    ):
        """Check if the model applied the CSE pass correctly.

        The model is run with onnxruntime before and after the pass; node
        counts, the ``modified`` flag and the numerical outputs are compared.

        Args:
            model: The model to check. It is modified in place by the pass.
            inputs: One array per graph input. Only the ``shape`` of each
                array is used; the actual input data is generated here.
            delta_nodes: The expected number of removed nodes for each graph
                yielded by ``model.graphs()``. The length of this list should
                match the number of graphs in the model (to support subgraphs
                in the future).

        Raises:
            AssertionError: If the model does not match the expected number of
                nodes or outputs.
        """
        self.assertEqual(len(list(model.graphs())), len(delta_nodes))
        # Record the state of the original model.
        # 1. model graph node counts
        original_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
        model_proto = ir.serde.serialize_model(model)

        # 2. model outputs
        # A fixed seed keeps a failing run reproducible.
        rng = np.random.default_rng(0)
        ort_inputs = {
            k.name: rng.random(v.shape).astype(np.float32)
            for k, v in zip(model.graph.inputs, inputs)
        }
        original_model_session = ort.InferenceSession(model_proto.SerializeToString())
        original_model_results = original_model_session.run(None, ort_inputs)

        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)

        result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
        # Check if the number of nodes in the model is correct
        np.testing.assert_array_equal(
            original_graphs_node_count, np.add(result_graphs_node_count, delta_nodes)
        )
        self.assertEqual(
            result.modified,
            bool(np.any(original_graphs_node_count > result_graphs_node_count)),
        )

        result_proto = ir.serde.serialize_model(result.model)
        result_session = ort.InferenceSession(result_proto.SerializeToString())
        result_results = result_session.run(None, ort_inputs)

        # Check if the models produce the same output
        # with the same inputs
        for idx, original_model_result in enumerate(original_model_results):
            np.testing.assert_allclose(
                original_model_result, result_results[idx], rtol=1e-5, atol=1e-5
            )

    def test_duplicate_operations_are_csed(self):
        """Test if the same operations are CSEd.

        Equivalent PyTorch-style computation:

            def f(x):
                a = x.cos()
                b = x.cos()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.Cos(x)
            b = op.Cos(x)
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)

        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])

    def test_more_operations_in_duplicated_operations_is_csed(self):
        """Test if chains of duplicated operations are CSEd.

        Equivalent PyTorch-style computation:

            def f(x):
                a = x.cos().sin()
                b = x.cos().sin()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(1)
        """

        @script()
        def test_model(x: FLOAT[1]) -> FLOAT[1]:
            a = op.Sin(op.Cos(x))
            b = op.Sin(op.Cos(x))
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(1)], delta_nodes=[3])

    def test_multiple_same_ops_with_attributes_are_csed(self):
        """Test if multiple same ops are CSEd.

        Equivalent PyTorch-style computation:

            def f(x):
                a = x.sum()
                b = x.sum()
                c = x.sum()
                d = x.sum()
                return a + b + c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=False)
            c = op.ReduceSum(x, keepdims=False)
            d = op.ReduceSum(x, keepdims=False)
            return a + b + c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[3])

    def test_the_ops_with_the_same_inputs_but_different_attributes_are_not_csed(self):
        """Test if the ops with the same inputs but different attributes are not CSEd.

        Equivalent PyTorch-style computation:

            def f(x):
                a = x.sum()
                b = x.sum(keepdims=True)
                return a + b

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=True)
            return a + b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])

    def test_control_flow_if_ops_are_not_csed_as_graph_attr_is_not_matched(self):
        """Test if control flow ops are not CSEd.

        Equivalent PyTorch-style computation:

            def f(a, b):
                rank = a.rank()
                if rank == 2:
                    result1 = a - b
                else:
                    result1 = a + b
                if rank == 2:
                    result2 = a - b
                else:
                    result2 = a + b
                return result1 + result2

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(a))
            if rank == 2:
                result1 = a - b
            else:
                result1 = a + b
            if rank == 2:
                result2 = a - b
            else:
                result2 = a + b
            return result1 + result2

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(
            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0]
        )

    def test_the_nodes_following_control_flow_ops_are_csed(self):
        """Test if the nodes following control flow ops are CSEd.

        Equivalent PyTorch-style computation:

            def f(a, b):
                rank = a.rank()
                if rank == 2:
                    x = a - b
                else:
                    x = a + b
                a = x.cos().sin()
                b = x.cos().sin()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(a))
            if rank == 2:
                x = a - b
            else:
                x = a + b
            a = op.Sin(op.Cos(x))
            b = op.Sin(op.Cos(x))
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(
            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0]
        )

    def test_graph_output_value_replacement_preserves_name(self):
        """Test that a replaced graph output keeps its original name."""

        @script()
        def test_model(x: FLOAT[2, 2]) -> (FLOAT[2, 2], FLOAT[2, 2]):
            a = op.Cos(x)
            b = op.Cos(x)
            return a + b, b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # Set custom output names
        output_name_0 = "my_output_0"
        output_name_1 = "my_output_1"
        model.graph.outputs[0].name = output_name_0
        model.graph.outputs[1].name = output_name_1
        original_output_value_0 = model.graph.outputs[0]
        original_output_value_1 = model.graph.outputs[1]

        # Run CSE pass
        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
        new_output_value_0 = result.model.graph.outputs[0]
        new_output_value_1 = result.model.graph.outputs[1]

        # Output 0 is not produced by a duplicated node, so its Value object
        # is kept; output 1 comes from the duplicated Cos and must be
        # replaced by a different Value object.
        self.assertIs(original_output_value_0, new_output_value_0)
        self.assertIsNot(original_output_value_1, new_output_value_1)
        # But the names should be preserved
        self.assertEqual(new_output_value_0.name, output_name_0)
        self.assertEqual(new_output_value_1.name, output_name_1)

    def test_identity_inserted_when_both_outputs_are_graph_outputs(self):
        """Test that an Identity node keeps two duplicated graph outputs distinct."""

        @script()
        def test_model(x: FLOAT[2, 2]) -> (FLOAT[2, 2], FLOAT[2, 2]):
            a = op.Cos(x)
            b = op.Cos(x)
            return a, b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # Set custom output names
        output_name_0 = "output0"
        output_name_1 = "output1"
        model.graph.outputs[0].name = output_name_0
        model.graph.outputs[1].name = output_name_1

        # Run CSE pass
        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
        new_graph = result.model.graph

        # There should be an Identity node in the graph
        identity_nodes = [node for node in new_graph if node.op_type == "Identity"]
        self.assertTrue(
            identity_nodes, "No Identity node inserted for duplicated graph outputs."
        )

        # The outputs should still have the correct names
        self.assertEqual(new_graph.outputs[0].name, output_name_0)
        self.assertEqual(new_graph.outputs[1].name, output_name_1)