Skip to content

[Pass] Graph extractor pass #2119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions onnxscript/ir/passes/common/graph_extration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Passes for extracting subgraphs from a graph."""

from __future__ import annotations

import itertools

__all__ = [
"ExtractGraphPass",
]

import logging
from collections.abc import Collection

from onnxscript import ir

logger = logging.getLogger(__name__)


def _find_subgraph_bounded_by_values(
graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value]
) -> tuple[list[ir.Node], list[ir.Value]]:
"""Finds the subgraph bounded by the given inputs and outputs.

Args:
graph: The graph to search.
inputs: The inputs to the subgraph.
outputs: The outputs of the subgraph.

Returns:
A list of nodes in the subgraph and the initializers used.
"""
node_index = {node: idx for idx, node in enumerate(graph)}
all_nodes = []
value_stack: list[ir.Value] = [*outputs]
visited_nodes: set[ir.Node] = set()
visited_values: set[ir.Value] = set(inputs)
initializers = [val for val in inputs if val.name in graph.initializers]
while value_stack:
value = value_stack.pop()
if value in visited_values:
continue
if value.name in graph.initializers:
# Record the initializer
assert value.const_value is not None
initializers.append(value)
visited_values.add(value)
if (node := value.producer()) is not None:
if node not in visited_nodes:
visited_nodes.add(node)
all_nodes.append(node)
for input in node.inputs:
if input not in visited_values and input is not None:
value_stack.append(input)
# Preserve the original order
all_nodes.sort(key=lambda n: node_index[n])
return all_nodes, initializers


class ExtractGraphPass(ir.passes.InPlacePass):
"""This pass extracts a subgraph from the given graph."""

def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None:
"""Extracts sub-model from an ONNX model.

The sub-model is defined by the names of the input and output tensors *exactly*.

Args:
input_names: The names of the inputs to extract. Must be deduplicated.
output_names: The names of the outputs to extract. Must be deduplicated.
"""
super().__init__()
self.input_names = input_names
self.output_names = output_names

def call(self, model: ir.Model) -> ir.passes.PassResult:
values = ir.convenience.create_value_mapping(model.graph)
inputs = [values[name] for name in self.input_names]
outputs = [values[name] for name in self.output_names]
extracted_nodes, initializers = _find_subgraph_bounded_by_values(
model.graph, inputs, outputs
)

model.graph.remove(extracted_nodes)
# Create inputs for the new graph as the old inputs are owned by the old nodes
new_inputs = []
for input in inputs:
new_inputs.append(
ir.Value(
name=input.name,
shape=input.shape,
type=input.type,
doc_string=input.doc_string,
const_value=input.const_value,
)
)
ir.convenience.replace_all_uses_with(inputs, new_inputs)

# Replace the model graph
model.graph = ir.Graph(
new_inputs,
outputs,
nodes=extracted_nodes,
initializers=initializers,
doc_string=model.graph.doc_string,
opset_imports=model.graph.opset_imports,
name=model.graph.name,
metadata_props=model.graph.metadata_props,
)

return ir.passes.PassResult(model, modified=True)

def requires(self, model: ir.Model) -> None:
# All inputs and outputs can be found in the model
values = ir.convenience.create_value_mapping(model.graph)
input_names_not_found = sorted(set(self.input_names) - set(values.keys()))
if input_names_not_found:
raise ir.passes.PreconditionError(
f"Input names not found in the model: {input_names_not_found}"
)
output_names_not_found = sorted(set(self.output_names) - set(values.keys()))
if output_names_not_found:
raise ir.passes.PreconditionError(
f"Output names not found in the model: {output_names_not_found}"
)

# All inputs and outputs must have type and shape
for name in itertools.chain(self.input_names, self.output_names):
value = values[name]
if value.type is None:
logger.warning(
"Value %%%s does not have a type: '%r'. "
"Consider setting its type or running shape inference first.",
name,
value,
)
if value.shape is None:
logger.warning(
"Value %%%s does not have a shape: '%r'. "
"Consider setting its shape or running shape inference first.",
name,
value,
)
# TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs
154 changes: 154 additions & 0 deletions onnxscript/ir/passes/common/graph_extration_test.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the copilot output

Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
import unittest

Check warning

Code scanning / lintrunner

RUFF/I001 Warning

Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports
import numpy as np

from onnxscript import ir
from onnxscript.ir.passes.common.graph_extration import ExtractGraphPass


class TestExtractGraphPass(unittest.TestCase):
def test_extract_subgraph(self):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
]

add_node = ir.node("Add", inputs=inputs)
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Perform extract graph pass
extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name])
result = extract_pass(model)
self.assertTrue(result.modified)
self.assertEqual(len(result.model.graph.nodes), 2)
self.assertEqual(result.model.graph.nodes[0].op_type, "Add")
self.assertEqual(result.model.graph.nodes[1].op_type, "Mul")

def test_extract_subgraph_with_initializers(self):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy()))
const_node = ir.node(
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Perform extract graph pass
extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name])
result = extract_pass(model)
self.assertTrue(result.modified)
self.assertEqual(len(result.model.graph.nodes), 3)
self.assertEqual(result.model.graph.nodes[0].op_type, "Constant")
self.assertEqual(result.model.graph.nodes[1].op_type, "Add")
self.assertEqual(result.model.graph.nodes[2].op_type, "Mul")

def test_extract_subgraph_with_subgraph(self):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
then_const_node = ir.node(
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
then_graph = ir.Graph(
inputs=[input_value],
outputs=[add_node.outputs[0]],
nodes=[then_const_node, add_node],
opset_imports={"": 20},
)
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
else_const_node = ir.node(
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
)
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
else_graph = ir.Graph(
inputs=[input_value],
outputs=[mul_node.outputs[0]],
nodes=[else_const_node, mul_node],
opset_imports={"": 20},
)
cond_node = ir.node(
"If",
inputs=[input_value],
attributes={"then_branch": then_graph, "else_branch": else_graph},
num_outputs=1,
)
main_graph = ir.Graph(
inputs=[input_value],
outputs=cond_node.outputs,
nodes=[cond_node],
opset_imports={"": 20},
)
main_graph.sort()
model = ir.Model(
graph=main_graph,
ir_version=10,
)

# Perform extract graph pass
extract_pass = ExtractGraphPass(input_names=["input"], output_names=[cond_node.outputs[0].name])
result = extract_pass(model)
self.assertTrue(result.modified)
self.assertEqual(len(result.model.graph.nodes), 1)
self.assertEqual(result.model.graph.nodes[0].op_type, "If")
self.assertEqual(len(result.model.graph.nodes[0].attributes["then_branch"].nodes), 2)
self.assertEqual(len(result.model.graph.nodes[0].attributes["else_branch"].nodes), 2)

def test_extract_partial_subgraph(self):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))),
]

add_node = ir.node("Add", inputs=inputs)
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
sub_node = ir.node("Sub", inputs=[mul_node.outputs[0], inputs[0]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=sub_node.outputs,
nodes=[add_node, mul_node, sub_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Perform extract graph pass
extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name])
result = extract_pass(model)
self.assertTrue(result.modified)
self.assertEqual(len(result.model.graph.nodes), 2)
self.assertEqual(result.model.graph.nodes[0].op_type, "Add")
self.assertEqual(result.model.graph.nodes[1].op_type, "Mul")


if __name__ == "__main__":
unittest.main()
Loading