Skip to content

[Pass] Graph extractor pass #2119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,8 @@
"convert_attribute",
"convert_attributes",
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
]

import typing
2 changes: 2 additions & 0 deletions onnxscript/ir/convenience.py
Original file line number Diff line number Diff line change
@@ -9,11 +9,13 @@
"convert_attributes",
"replace_all_uses_with",
"replace_nodes_and_values",
"create_value_mapping",
]

from onnxscript.ir._convenience import (
convert_attribute,
convert_attributes,
create_value_mapping,
replace_all_uses_with,
replace_nodes_and_values,
)
40 changes: 22 additions & 18 deletions onnxscript/ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
@@ -70,12 +70,31 @@

Class attributes:
in_place: Whether the pass modifies the model in place.
destructive: Whether the pass will destroy the input model when ``in_place=False``.
"""

in_place: bool = True
destructive: bool = False

def __call__(self, model: ir.Model) -> PassResult:
return self.call(model)
# Check preconditions
try:
self.requires(model)
except PreconditionError:
raise
except Exception as e:
raise PreconditionError("Pre-condition failed") from e

Check warning on line 86 in onnxscript/ir/passes/_pass_infra.py

Codecov / codecov/patch

onnxscript/ir/passes/_pass_infra.py#L83-L86

Added lines #L83 - L86 were not covered by tests

result = self.call(model)

# Check postconditions
try:
self.ensures(model)
except PostconditionError:
raise
except Exception as e:
raise PostconditionError("Post-condition failed") from e

Check warning on line 96 in onnxscript/ir/passes/_pass_infra.py

Codecov / codecov/patch

onnxscript/ir/passes/_pass_infra.py#L93-L96

Added lines #L93 - L96 were not covered by tests
return result

@abc.abstractmethod
def call(self, model: ir.Model) -> PassResult:
@@ -111,12 +130,10 @@
def __init__(
self,
passes: Sequence[PassBase],
check_invariants: bool = False,
steps: int = 1,
):
# TODO(justinchuby): Implement constraints
self.passes = list(passes)
self.check_invariants = check_invariants
self.steps = steps

def __call__(self, model: ir.Model) -> PassResult:
@@ -137,17 +154,10 @@
modified = False
for i, pass_ in enumerate(self.passes):
logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step)

# 1. Check preconditions
if self.check_invariants:
try:
pass_.requires(model)
except Exception as e:
raise PreconditionError(f"Pre-condition failed for {pass_}") from e

# 2. Run the pass
try:
pass_result = pass_(model)
except (PreconditionError, PostconditionError):
raise

Check warning on line 160 in onnxscript/ir/passes/_pass_infra.py

Codecov / codecov/patch

onnxscript/ir/passes/_pass_infra.py#L159-L160

Added lines #L159 - L160 were not covered by tests
except Exception as e:
prev_pass_names = [str(p) for p in self.passes[:i]]
raise PassError(
@@ -163,10 +173,4 @@
model = pass_result.model
modified = modified or pass_result.modified

# 3. Check postconditions
if self.check_invariants:
try:
pass_.ensures(model)
except Exception as e:
raise PostconditionError(f"Post-condition failed for {pass_}") from e
return PassResult(model, modified)
158 changes: 158 additions & 0 deletions onnxscript/ir/passes/common/graph_extration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Passes for extracting subgraphs from a graph."""

from __future__ import annotations

import itertools

__all__ = [
"ExtractGraphPass",
]

import logging
from collections.abc import Collection

from onnxscript import ir

logger = logging.getLogger(__name__)


def _find_subgraph_bounded_by_values(
graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value]
) -> tuple[list[ir.Node], list[ir.Value]]:
"""Finds the subgraph bounded by the given inputs and outputs.

Args:
graph: The graph to search.
inputs: The inputs to the subgraph.
outputs: The outputs of the subgraph.

Returns:
A list of nodes in the subgraph and the initializers used.
"""
node_index = {node: idx for idx, node in enumerate(graph)}
all_nodes = []
value_stack: list[ir.Value] = [*outputs]
visited_nodes: set[ir.Node] = set()
visited_values: set[ir.Value] = set(inputs)
initializers = [val for val in inputs if val.name in graph.initializers]

Check warning on line 39 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L34-L39

Added lines #L34 - L39 were not covered by tests
while value_stack:
value = value_stack.pop()

Check warning on line 41 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L41

Added line #L41 was not covered by tests
if value in visited_values:
continue

Check warning on line 43 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L43

Added line #L43 was not covered by tests
if value.name in graph.initializers:
# Record the initializer
assert value.const_value is not None
initializers.append(value)
visited_values.add(value)

Check warning on line 48 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L46-L48

Added lines #L46 - L48 were not covered by tests
if (node := value.producer()) is not None:
if node not in visited_nodes:
visited_nodes.add(node)
all_nodes.append(node)

Check warning on line 52 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L51-L52

Added lines #L51 - L52 were not covered by tests
for input in node.inputs:
if input not in visited_values and input is not None:
value_stack.append(input)

Check warning on line 55 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L55

Added line #L55 was not covered by tests
# Preserve the original order
all_nodes.sort(key=lambda n: node_index[n])
return all_nodes, initializers

Check warning on line 58 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L57-L58

Added lines #L57 - L58 were not covered by tests


class ExtractGraphPass(ir.passes.PassBase):
"""This pass extracts a subgraph from the given graph."""

# This pass does not modify the model in place
in_place = False
# This pass destroys the input model
destructive = True

def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None:
"""Extracts sub-model from an ONNX model.

The sub-model is defined by the names of the input and output tensors *exactly*.

Args:
input_names: The names of the inputs to extract. Must be deduplicated.
output_names: The names of the outputs to extract. Must be deduplicated.
"""
super().__init__()
self.input_names = input_names
self.output_names = output_names

Check warning on line 80 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L78-L80

Added lines #L78 - L80 were not covered by tests

def call(self, model: ir.Model) -> ir.passes.PassResult:
values = ir.convenience.create_value_mapping(model.graph)
inputs = [values[name] for name in self.input_names]
outputs = [values[name] for name in self.output_names]
extracted_nodes, initializers = _find_subgraph_bounded_by_values(

Check warning on line 86 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L83-L86

Added lines #L83 - L86 were not covered by tests
model.graph, inputs, outputs
)

model.graph.remove(extracted_nodes)

Check warning on line 90 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L90

Added line #L90 was not covered by tests
# Create inputs for the new graph as the old inputs are owned by the old nodes
new_inputs = []

Check warning on line 92 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L92

Added line #L92 was not covered by tests
for input in inputs:
new_inputs.append(

Check warning on line 94 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L94

Added line #L94 was not covered by tests
ir.Value(
name=input.name,
shape=input.shape,
type=input.type,
doc_string=input.doc_string,
const_value=input.const_value,
)
)
ir.convenience.replace_all_uses_with(inputs, new_inputs)

Check warning on line 103 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L103

Added line #L103 was not covered by tests

new_model = ir.Model(

Check warning on line 105 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L105

Added line #L105 was not covered by tests
ir.Graph(
new_inputs,
outputs,
nodes=extracted_nodes,
initializers=initializers,
doc_string=model.graph.doc_string,
opset_imports=model.graph.opset_imports,
name=model.graph.name,
metadata_props=model.graph.metadata_props,
),
ir_version=model.ir_version,
producer_name=model.producer_name,
producer_version=model.producer_version,
domain=model.domain,
model_version=model.model_version,
doc_string=model.doc_string,
functions=tuple(model.functions.values()),
meta_data_props=model.metadata_props,
)
return ir.passes.PassResult(new_model, modified=True)

Check warning on line 125 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L125

Added line #L125 was not covered by tests

def requires(self, model: ir.Model) -> None:
# All inputs and outputs can be found in the model
values = ir.convenience.create_value_mapping(model.graph)
input_names_not_found = sorted(set(self.input_names) - set(values.keys()))

Check warning on line 130 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L129-L130

Added lines #L129 - L130 were not covered by tests
if input_names_not_found:
raise ir.passes.PreconditionError(

Check warning on line 132 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L132

Added line #L132 was not covered by tests
f"Input names not found in the model: {input_names_not_found}"
)
output_names_not_found = sorted(set(self.output_names) - set(values.keys()))

Check warning on line 135 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L135

Added line #L135 was not covered by tests
if output_names_not_found:
raise ir.passes.PreconditionError(

Check warning on line 137 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L137

Added line #L137 was not covered by tests
f"Output names not found in the model: {output_names_not_found}"
)

# All inputs and outputs must have type and shape
for name in itertools.chain(self.input_names, self.output_names):
value = values[name]

Check warning on line 143 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L143

Added line #L143 was not covered by tests
if value.type is None:
logger.warning(

Check warning on line 145 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L145

Added line #L145 was not covered by tests
"Value %%%s does not have a type: '%r'. "
"Consider setting its type or running shape inference first.",
name,
value,
)
if value.shape is None:
logger.warning(

Check warning on line 152 in onnxscript/ir/passes/common/graph_extration.py

Codecov / codecov/patch

onnxscript/ir/passes/common/graph_extration.py#L152

Added line #L152 was not covered by tests
"Value %%%s does not have a shape: '%r'. "
"Consider setting its shape or running shape inference first.",
name,
value,
)
# TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs
Loading
Oops, something went wrong.