Skip to content

[Pass] Graph extractor pass #2119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,8 @@
"convert_attribute",
"convert_attributes",
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
]

import typing
2 changes: 2 additions & 0 deletions onnxscript/ir/convenience.py
Original file line number Diff line number Diff line change
@@ -9,11 +9,13 @@
"convert_attributes",
"replace_all_uses_with",
"replace_nodes_and_values",
"create_value_mapping",
]

from onnxscript.ir._convenience import (
convert_attribute,
convert_attributes,
create_value_mapping,
replace_all_uses_with,
replace_nodes_and_values,
)
40 changes: 22 additions & 18 deletions onnxscript/ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
@@ -70,12 +70,31 @@ class PassBase(abc.ABC):

Class attributes:
in_place: Whether the pass modifies the model in place.
destructive: Whether the pass will destroy the input model when ``in_place=False``.
"""

in_place: bool = True
destructive: bool = False

def __call__(self, model: ir.Model) -> PassResult:
return self.call(model)
# Check preconditions
try:
self.requires(model)
except PreconditionError:
raise
except Exception as e:
raise PreconditionError("Pre-condition failed") from e

result = self.call(model)

# Check postconditions
try:
self.ensures(model)
except PostconditionError:
raise
except Exception as e:
raise PostconditionError("Post-condition failed") from e
return result

@abc.abstractmethod
def call(self, model: ir.Model) -> PassResult:
@@ -111,12 +130,10 @@ class PassManager:
def __init__(
self,
passes: Sequence[PassBase],
check_invariants: bool = False,
steps: int = 1,
):
# TODO(justinchuby): Implement constraints
self.passes = list(passes)
self.check_invariants = check_invariants
self.steps = steps

def __call__(self, model: ir.Model) -> PassResult:
@@ -137,17 +154,10 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult:
modified = False
for i, pass_ in enumerate(self.passes):
logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step)

# 1. Check preconditions
if self.check_invariants:
try:
pass_.requires(model)
except Exception as e:
raise PreconditionError(f"Pre-condition failed for {pass_}") from e

# 2. Run the pass
try:
pass_result = pass_(model)
except (PreconditionError, PostconditionError):
raise
except Exception as e:
prev_pass_names = [str(p) for p in self.passes[:i]]
raise PassError(
@@ -163,10 +173,4 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult:
model = pass_result.model
modified = modified or pass_result.modified

# 3. Check postconditions
if self.check_invariants:
try:
pass_.ensures(model)
except Exception as e:
raise PostconditionError(f"Post-condition failed for {pass_}") from e
return PassResult(model, modified)
136 changes: 136 additions & 0 deletions onnxscript/ir/passes/common/graph_extration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Passes for extracting subgraphs from a graph."""

from __future__ import annotations

import itertools

__all__ = [
"ExtractGraphPass",
]

import logging
from collections.abc import Collection

from onnxscript import ir

logger = logging.getLogger(__name__)


def _find_subgraph_bounded_by_values(
graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value]
) -> tuple[list[ir.Node], list[ir.Value]]:
"""Finds the subgraph bounded by the given inputs and outputs.

Args:
graph: The graph to search.
inputs: The inputs to the subgraph.
outputs: The outputs of the subgraph.

Returns:
A list of nodes in the subgraph and the initializers used.
"""
all_nodes = []
value_stack: list[ir.Value] = [*outputs]
visited_nodes: set[ir.Node] = set()
visited_values: set[ir.Value] = set()
initializers = []
while value_stack:
value = value_stack.pop()
if value in visited_values:
continue
if value.name in graph.initializers:
# Record the initializer
assert value.const_value is not None
initializers.append(value)
visited_values.add(value)
if (node := value.producer()) is not None:
if node not in visited_nodes:
visited_nodes.add(node)
all_nodes.append(node)
for input in node.inputs:
if input not in visited_values and input is not None:
value_stack.append(input)
return all_nodes, initializers


class ExtractGraphPass(ir.passes.PassBase):
"""This pass performs shape inference on the graph."""

# This pass does not modify the model in place
in_place = False
# This pass destroys the input model
destructive = True

def __init__(self, *, input_names: Collection[str], output_names: Collection[str]) -> None:
"""Extracts sub-model from an ONNX model.

The sub-model is defined by the names of the input and output tensors *exactly*.

Args:
input_names: The names of the inputs to extract.
output_names: The names of the outputs to extract.
"""
super().__init__()
self.input_names = input_names
self.output_names = output_names

def requires(self, model: ir.Model) -> None:
# All inputs and outputs can be found in the model
values = ir.convenience.create_value_mapping(model.graph)
input_names_not_found = sorted(set(self.input_names) - set(values.keys()))
if input_names_not_found:
raise ir.passes.PreconditionError(
f"Input names not found in the model: {input_names_not_found}"
)
output_names_not_found = sorted(set(self.output_names) - set(values.keys()))
if output_names_not_found:
raise ir.passes.PreconditionError(
f"Output names not found in the model: {output_names_not_found}"
)

# All inputs and outputs must have type and shape
for name in itertools.chain(self.input_names, self.output_names):
value = values[name]
if value.type is None:
raise ir.passes.PreconditionError(
f"Value {name} does not have a type: {value}. "
"Consider setting its type or running shape inference first."
)
if value.shape is None:
raise ir.passes.PreconditionError(
f"Value {name} does not have a shape: {value}. "
"Consider setting its shape or running shape inference first."
)

def call(self, model: ir.Model) -> ir.passes.PassResult:
values = ir.convenience.create_value_mapping(model.graph)
inputs = [values[name] for name in self.input_names]
outputs = [values[name] for name in self.output_names]
extracted_nodes, initializers = _find_subgraph_bounded_by_values(
model.graph, inputs, outputs
)
# Create a graph with the extracted nodes
model.graph.remove(extracted_nodes)
new_model = ir.Model(
ir.Graph(
inputs,
outputs,
nodes=extracted_nodes,
initializers=initializers,
doc_string=model.graph.doc_string,
opset_imports=model.graph.opset_imports,
name=model.graph.name,
metadata_props=model.graph.metadata_props,
),
ir_version=model.ir_version,
producer_name=model.producer_name,
producer_version=model.producer_version,
domain=model.domain,
model_version=model.model_version,
doc_string=model.doc_string,
functions=tuple(model.functions.values()),
meta_data_props=model.metadata_props,
)
return ir.passes.PassResult(new_model, modified=True)
Loading
Oops, something went wrong.