-
Notifications
You must be signed in to change notification settings - Fork 72
[Pass] Graph extractor pass #2119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
28beefd
b4ab30f
30fb4c4
7455c2e
78aefc2
ca03ed9
716413f
307c62c
b0d1843
4479afd
c8b391d
18d70b1
3fcbbbb
c99bab7
157d601
0fd082f
35c45cf
2e5191a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Passes for extracting subgraphs from a graph.""" | ||
|
||
from __future__ import annotations

# Public API of this module. Per PEP 8, module-level dunders go after the
# ``from __future__`` import and before every other import.
__all__ = [
    "ExtractGraphPass",
]

import itertools
import logging
from collections.abc import Collection

from onnxscript import ir

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _find_subgraph_bounded_by_values(
    graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value]
) -> tuple[list[ir.Node], list[ir.Value]]:
    """Finds the subgraph bounded by the given inputs and outputs.

    The search starts at ``outputs`` and walks backwards through the producer
    node of every value it meets. Values listed in ``inputs`` are treated as
    already visited, so the walk never goes past them.

    Args:
        graph: The graph to search. Iterating it must yield its nodes, and
            ``graph.initializers`` must support ``name in ...`` lookups.
        inputs: The values at which the backward walk stops.
        outputs: The values from which the backward walk starts.

    Returns:
        A tuple ``(nodes, initializers)``. ``nodes`` holds every node reached
        by the walk, in the order the nodes appear in ``graph``.
        ``initializers`` holds the values from ``inputs`` whose name is
        registered in ``graph.initializers``, followed by the initializers
        encountered during the walk.

    Raises:
        KeyError: If a reached producer node is not one of the nodes yielded
            by iterating ``graph``.
    """
    # NOTE(review): only ``node.inputs`` is followed. Values that are only
    # referenced inside a node's graph-valued attributes are not traversed
    # here - confirm whether callers rely on that.

    # Position of each node in the graph, used to restore the original order.
    node_index = {node: idx for idx, node in enumerate(graph)}
    all_nodes: list[ir.Node] = []
    value_stack: list[ir.Value] = [*outputs]
    visited_nodes: set[ir.Node] = set()
    # Seeding with the inputs is what bounds the walk from above.
    visited_values: set[ir.Value] = set(inputs)
    initializers = [val for val in inputs if val.name in graph.initializers]

    while value_stack:
        value = value_stack.pop()
        if value in visited_values:
            continue
        if value.name in graph.initializers:
            # Record the initializer
            assert value.const_value is not None
            initializers.append(value)
        visited_values.add(value)

        node = value.producer()
        if node is None or node in visited_nodes:
            # Either the value has no producer, or the producer's inputs
            # have already been queued.
            continue
        visited_nodes.add(node)
        all_nodes.append(node)
        # Node inputs may be None; test for that before the set lookup.
        value_stack.extend(
            operand
            for operand in node.inputs
            if operand is not None and operand not in visited_values
        )

    # Preserve the original order
    all_nodes.sort(key=node_index.__getitem__)
    return all_nodes, initializers
|
||
|
||
class ExtractGraphPass(ir.passes.InPlacePass):
    """This pass extracts a subgraph from the given graph.

    The subgraph is selected by value names; ``call`` replaces ``model.graph``
    with a new graph built from the selected nodes.
    """

    def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None:
        """Extracts sub-model from an ONNX model.

        The sub-model is defined by the names of the input and output tensors *exactly*.

        Args:
            input_names: The names of the inputs to extract. Must be deduplicated.
            output_names: The names of the outputs to extract. Must be deduplicated.
        """
        super().__init__()
        self.input_names = input_names
        self.output_names = output_names

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Replaces ``model.graph`` with the subgraph bounded by the configured names.

        Args:
            model: The model to modify in place.

        Returns:
            A pass result wrapping ``model``, always with ``modified=True``.

        Raises:
            KeyError: If a configured name is missing from the value mapping
                (``requires`` reports this case as a ``PreconditionError``).
        """
        values = ir.convenience.create_value_mapping(model.graph)
        inputs = [values[name] for name in self.input_names]
        outputs = [values[name] for name in self.output_names]
        extracted_nodes, initializers = _find_subgraph_bounded_by_values(
            model.graph, inputs, outputs
        )

        # Take the selected nodes out of the current graph before they are
        # handed to the new one.
        model.graph.remove(extracted_nodes)
        # Create inputs for the new graph as the old inputs are owned by the old nodes
        new_inputs = [
            ir.Value(
                name=old_input.name,
                shape=old_input.shape,
                type=old_input.type,
                doc_string=old_input.doc_string,
                const_value=old_input.const_value,
            )
            for old_input in inputs
        ]
        ir.convenience.replace_all_uses_with(inputs, new_inputs)

        # Replace the model graph, carrying over the graph-level metadata
        model.graph = ir.Graph(
            new_inputs,
            outputs,
            nodes=extracted_nodes,
            initializers=initializers,
            doc_string=model.graph.doc_string,
            opset_imports=model.graph.opset_imports,
            name=model.graph.name,
            metadata_props=model.graph.metadata_props,
        )

        return ir.passes.PassResult(model, modified=True)

    def requires(self, model: ir.Model) -> None:
        """Checks that every configured input and output name exists in the model.

        Missing type or shape information is only logged as a warning.

        Args:
            model: The model the pass is about to run on.

        Raises:
            ir.passes.PreconditionError: If any input or output name is not
                found in the model graph's value mapping.
        """
        # All inputs and outputs can be found in the model
        values = ir.convenience.create_value_mapping(model.graph)
        known_names = set(values.keys())
        input_names_not_found = sorted(set(self.input_names) - known_names)
        if input_names_not_found:
            raise ir.passes.PreconditionError(
                f"Input names not found in the model: {input_names_not_found}"
            )
        output_names_not_found = sorted(set(self.output_names) - known_names)
        if output_names_not_found:
            raise ir.passes.PreconditionError(
                f"Output names not found in the model: {output_names_not_found}"
            )

        # All inputs and outputs must have type and shape
        for name in itertools.chain(self.input_names, self.output_names):
            value = values[name]
            if value.type is None:
                logger.warning(
                    "Value %%%s does not have a type: '%r'. "
                    "Consider setting its type or running shape inference first.",
                    name,
                    value,
                )
            if value.shape is None:
                logger.warning(
                    "Value %%%s does not have a shape: '%r'. "
                    "Consider setting its shape or running shape inference first.",
                    name,
                    value,
                )
        # TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change the copilot output |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning: Code scanning / lintrunner RUFF/format Warning
Run lintrunner -a to apply this patch.
Check warning: Code scanning / lintrunner RUFF-FORMAT/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
import unittest | ||
Check warning: Code scanning / lintrunner RUFF/I001 Warning
Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports |
||
import numpy as np | ||
|
||
from onnxscript import ir | ||
from onnxscript.ir.passes.common.graph_extration import ExtractGraphPass | ||
|
||
|
||
class TestExtractGraphPass(unittest.TestCase):
    """Tests for ``ExtractGraphPass`` on small hand-built models."""

    def test_extract_subgraph(self):
        """Extracting up to the Mul output of an Add -> Mul graph keeps both nodes."""
        inputs = [
            ir.Value(
                name="input_a",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
            ir.Value(
                name="input_b",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
        ]

        add_node = ir.node("Add", inputs=inputs)
        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

        model = ir.Model(
            graph=ir.Graph(
                inputs=inputs,
                outputs=mul_node.outputs,
                nodes=[add_node, mul_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Perform extract graph pass
        extract_pass = ExtractGraphPass(
            input_names=["input_a"],
            output_names=[mul_node.outputs[0].name],
        )
        result = extract_pass(model)
        self.assertTrue(result.modified)
        self.assertEqual(len(result.model.graph.nodes), 2)
        self.assertEqual(result.model.graph.nodes[0].op_type, "Add")
        self.assertEqual(result.model.graph.nodes[1].op_type, "Mul")

    def test_extract_subgraph_with_initializers(self):
        """A Constant node feeding the extracted nodes is extracted with them."""
        inputs = [
            ir.Value(
                name="input_a",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
            ir.Value(
                name="input_b",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
        ]

        constant_tensor = ir.tensor(
            np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy())
        )
        const_node = ir.node(
            "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
        )
        add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

        model = ir.Model(
            graph=ir.Graph(
                inputs=inputs,
                outputs=mul_node.outputs,
                nodes=[const_node, add_node, mul_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Perform extract graph pass
        extract_pass = ExtractGraphPass(
            input_names=["input_a"],
            output_names=[mul_node.outputs[0].name],
        )
        result = extract_pass(model)
        self.assertTrue(result.modified)
        self.assertEqual(len(result.model.graph.nodes), 3)
        self.assertEqual(result.model.graph.nodes[0].op_type, "Constant")
        self.assertEqual(result.model.graph.nodes[1].op_type, "Add")
        self.assertEqual(result.model.graph.nodes[2].op_type, "Mul")

    def test_extract_subgraph_with_subgraph(self):
        """An If node is extracted with both branch graphs still holding two nodes."""
        input_value = ir.Value(
            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
        )

        # "then" branch: Constant -> Add
        then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        then_const_node = ir.node(
            "Constant",
            inputs=[],
            attributes={"value": then_constant_tensor},
            num_outputs=1,
        )
        add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
        then_graph = ir.Graph(
            inputs=[input_value],
            outputs=[add_node.outputs[0]],
            nodes=[then_const_node, add_node],
            opset_imports={"": 20},
        )

        # "else" branch: Constant -> Mul
        else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        else_const_node = ir.node(
            "Constant",
            inputs=[],
            attributes={"value": else_constant_tensor},
            num_outputs=1,
        )
        mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
        else_graph = ir.Graph(
            inputs=[input_value],
            outputs=[mul_node.outputs[0]],
            nodes=[else_const_node, mul_node],
            opset_imports={"": 20},
        )

        cond_node = ir.node(
            "If",
            inputs=[input_value],
            attributes={"then_branch": then_graph, "else_branch": else_graph},
            num_outputs=1,
        )
        main_graph = ir.Graph(
            inputs=[input_value],
            outputs=cond_node.outputs,
            nodes=[cond_node],
            opset_imports={"": 20},
        )
        main_graph.sort()
        model = ir.Model(
            graph=main_graph,
            ir_version=10,
        )

        # Perform extract graph pass
        extract_pass = ExtractGraphPass(
            input_names=["input"],
            output_names=[cond_node.outputs[0].name],
        )
        result = extract_pass(model)
        self.assertTrue(result.modified)
        self.assertEqual(len(result.model.graph.nodes), 1)
        if_node = result.model.graph.nodes[0]
        self.assertEqual(if_node.op_type, "If")
        self.assertEqual(len(if_node.attributes["then_branch"].nodes), 2)
        self.assertEqual(len(if_node.attributes["else_branch"].nodes), 2)

    def test_extract_partial_subgraph(self):
        """Nodes after the requested output (here Sub) are left out."""
        inputs = [
            ir.Value(
                name="input_a",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
            ir.Value(
                name="input_b",
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((2, 3)),
            ),
        ]

        add_node = ir.node("Add", inputs=inputs)
        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
        sub_node = ir.node("Sub", inputs=[mul_node.outputs[0], inputs[0]])

        model = ir.Model(
            graph=ir.Graph(
                inputs=inputs,
                outputs=sub_node.outputs,
                nodes=[add_node, mul_node, sub_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Perform extract graph pass
        extract_pass = ExtractGraphPass(
            input_names=["input_a"],
            output_names=[mul_node.outputs[0].name],
        )
        result = extract_pass(model)
        self.assertTrue(result.modified)
        self.assertEqual(len(result.model.graph.nodes), 2)
        self.assertEqual(result.model.graph.nodes[0].op_type, "Add")
        self.assertEqual(result.model.graph.nodes[1].op_type, "Mul")
|
||
|
||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Uh oh!
There was an error while loading. Please reload this page.