Skip to content

Add tests for removing unused nodes in subgraphs #2265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions onnxscript/ir/passes/common/unused_removal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,206 @@
self.assertEqual(list(model.graph.node[0].output), ["z", "mean_out", "var_out"])
self.assertEqual(len(model.graph.node[0].attribute), 1)

def test_remove_unused_nodes_in_subgraph(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = If (x) <
then_branch = then_graph () => (then_y) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
then_y = Mul(x, x)
},
else_branch = else_graph () => (then_y) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
then_y = Mul(x, x)
}
>
}
"""
)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 1)
self.assertEqual(then_graph.node[0].op_type, "Mul")
self.assertEqual(len(else_graph.node), 1)
self.assertEqual(else_graph.node[0].op_type, "Mul")

def test_remove_unused_initializers_in_subgraph(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = If (x) <then_branch=then_graph, else_branch=else_graph>
}
<ir_version: 10, opset_import: [ "" : 17]>
then_graph (float[N] x) => (float[N] y)
<float two = {2.0}> {
four = Add(two, two)
y = Mul(x, x)
}
<ir_version: 10, opset_import: [ "" : 17]>
else_graph (float[N] x) => (float[N] y)
<float two = {2.0}> {
four = Add(two, two)
y = Mul(x, x)
}
"""
)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.initializer), 1)
self.assertEqual(len(else_graph.initializer), 1)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 1)
self.assertEqual(then_graph.node[0].op_type, "Mul")
self.assertEqual(len(then_graph.initializer), 0)
self.assertEqual(len(else_graph.node), 1)
self.assertEqual(else_graph.node[0].op_type, "Mul")
self.assertEqual(len(else_graph.initializer), 0)

Check warning on line 329 in onnxscript/ir/passes/common/unused_removal_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/unused_removal_test.py#L311-L329

Added lines #L311 - L329 were not covered by tests

def test_unused_inputs_in_subgraph_are_not_removed(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = If (x) <then_branch=then_graph, else_branch=else_graph>
}
<ir_version: 10, opset_import: [ "" : 17]>
then_graph (float[N] x, float[N] two) => (float[N] y) {
four = Add(two, two)
y = Mul(x, x)
}
<ir_version: 10, opset_import: [ "" : 17]>
else_graph (float[N] x, float[N] two) => (float[N] y) {
four = Add(two, two)
y = Mul(x, x)
}
"""
)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 1)
self.assertEqual(then_graph.node[0].op_type, "Mul")
self.assertEqual(len(then_graph.input), 2)
self.assertEqual(len(else_graph.node), 1)
self.assertEqual(else_graph.node[0].op_type, "Mul")
self.assertEqual(len(else_graph.input), 2)

Check warning on line 363 in onnxscript/ir/passes/common/unused_removal_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/unused_removal_test.py#L352-L363

Added lines #L352 - L363 were not covered by tests

def test_remove_unused_optional_outputs_in_subgraph(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = If (x) <then_branch=then_graph, else_branch=else_graph>
}
<ir_version: 10, opset_import: [ "" : 17]>
then_graph (float[N] x) => (float[N] y) {
y, indices = MaxPool <pads = [2, 2, 2, 2], kernel_shape = [5, 5]> (x)
}
<ir_version: 10, opset_import: [ "" : 17]>
else_graph (float[N] x) => (float[N] y) {
y, indices = MaxPool <pads = [2, 2, 2, 2], kernel_shape = [5, 5]> (x)
}
"""
)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 1)
self.assertEqual(then_graph.node[0].op_type, "MaxPool")
self.assertEqual(len(then_graph.node[0].output), 2)
self.assertEqual(len(else_graph.node), 1)
self.assertEqual(else_graph.node[0].op_type, "MaxPool")
self.assertEqual(len(else_graph.node[0].output), 2)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 1)
self.assertEqual(then_graph.node[0].op_type, "MaxPool")
self.assertEqual(then_graph.node[0].output, ["y"])
self.assertEqual(len(else_graph.node), 1)
self.assertEqual(else_graph.node[0].op_type, "MaxPool")
self.assertEqual(else_graph.node[0].output, ["y"])

Check warning on line 406 in onnxscript/ir/passes/common/unused_removal_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/unused_removal_test.py#L384-L406

Added lines #L384 - L406 were not covered by tests

def test_remove_trailing_unused_optional_outputs_in_subgraph(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = If (x) <then_branch=then_graph, else_branch=else_graph>
}
<ir_version: 10, opset_import: [ "" : 17]>
then_graph (float[N] x) => (float[N] y, float[N] mean) {
scale = Constant <value_ints=[3]> ()
B = Constant <value_ints=[3]> ()
y, mean, InvStdDev = LayerNormalization(x, scale, B)
}
<ir_version: 10, opset_import: [ "" : 17]>
else_graph (float[N] x) => (float[N] y, float[N] mean) {
scale = Constant <value_ints=[3]> ()
B = Constant <value_ints=[3]> ()
y, mean, InvStdDev = LayerNormalization(x, scale, B)
}
"""
)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 3)
self.assertEqual(then_graph.node[2].op_type, "LayerNormalization")
self.assertEqual(len(then_graph.node[2].output), 3)
self.assertEqual(len(else_graph.node), 3)
self.assertEqual(else_graph.node[2].op_type, "LayerNormalization")
self.assertEqual(len(else_graph.node[2].output), 3)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "If")
self.assertEqual(len(model.graph.node[0].attribute), 2)
then_graph = model.graph.node[0].attribute[0].g
else_graph = model.graph.node[0].attribute[1].g
self.assertEqual(len(then_graph.node), 3)
self.assertEqual(then_graph.node[2].op_type, "LayerNormalization")
self.assertEqual(list(then_graph.node[2].output), ["y", "mean"])
self.assertEqual(len(else_graph.node), 3)
self.assertEqual(else_graph.node[2].op_type, "LayerNormalization")
self.assertEqual(list(else_graph.node[2].output), ["y", "mean"])

Check warning on line 453 in onnxscript/ir/passes/common/unused_removal_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/unused_removal_test.py#L431-L453

Added lines #L431 - L453 were not covered by tests


if __name__ == "__main__":
unittest.main()
Loading