Skip to content

Clean up rewriter code: improve efficiency, finish TODOs, and enhance documentation #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""ONNX Model Rewriter.

This module provides pattern-based rewriting capabilities for ONNX models.
The rewriter allows you to define patterns that match subgraphs in ONNX models
and replace them with equivalent but potentially more efficient implementations.

Main Components:
- pattern: Main API for defining and applying rewrite rules
- rewrite(): High-level function to apply rules to models
- RewritePass: Integration with the IR passes framework

Example Usage:

```python
import onnxscript.rewriter as rewriter

# Apply default optimization rules
optimized_model = rewriter.rewrite(model)

# Apply custom rules
from onnxscript.rewriter import pattern

class MyOptimization(pattern.RewriteRuleClassBase):
def pattern(self, op, x):
return op.Add(x, op.Constant(value=0.0))

def rewrite(self, op, x, zero=None):
return op.Identity(x)

custom_rules = [MyOptimization.rule()]
optimized_model = rewriter.rewrite(model, custom_rules)
```
"""
from __future__ import annotations

from typing import Sequence, TypeVar, Union
Expand Down
52 changes: 51 additions & 1 deletion onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rewrite rules for ONNX models."""
"""Rewrite rules for ONNX models.

This module provides the core functionality for pattern-based rewriting of ONNX models.
It includes:

- RewriteRuleClassBase: Recommended base class for implementing rewrite rules using a class-based API
- RewriteRule: Defines a single pattern-to-replacement rewrite transformation
- RewriteRuleSet: Manages a collection of rewrite rules and applies them to models
- Supporting utilities for pattern matching, replacement, and context management

The rewriter enables users to define patterns that match subgraphs in ONNX models
and replace them with equivalent but potentially more efficient implementations.

Example usage with class-based rules (recommended):

```python
class ConstantFolding(RewriteRuleClassBase):
\"\"\"Fold Add operations with two constants\"\"\"

def pattern(self, op, x, y):
return op.Add(x, y)

def check(self, context, x, y):
# Only apply if both inputs are constants
return (x.const_value is not None and
y.const_value is not None)

def rewrite(self, op, x, y):
# Compute the result and create a constant
result = x.const_value + y.const_value
return op.Constant(value=result)

# Apply the rule
rule = ConstantFolding.rule()
rule.apply_to_model(model)
```

Function-based usage (lower-level API):

```python
def add_zero_pattern(op, x):
return op.Add(x, op.Constant(value=0.0))

def identity_replacement(op, x):
return op.Identity(x)

# Create and apply the rule
rule = RewriteRule(add_zero_pattern, identity_replacement)
rule.apply_to_model(model)
```
"""

from __future__ import annotations

Expand Down
80 changes: 80 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,85 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Pattern-based rewriter API for ONNX models.

This module provides the main user-facing API for the ONNX pattern rewriter.
It allows users to define patterns that match subgraphs in ONNX models and
replace them with more efficient implementations.

Recommended Usage with Class-Based Rules:

```python
from onnxscript.rewriter import pattern

class AddZeroElimination(pattern.RewriteRuleClassBase):
\"\"\"Removes addition with zero: Add(x, 0) -> Identity(x)\"\"\"

def pattern(self, op, x):
zero = op.Constant(value=0.0)
return op.Add(x, zero)

def check(self, context, x, zero):
# Optional: Add conditions for when to apply this rule
return zero.const_value is not None and zero.const_value.item() == 0.0

def rewrite(self, op, x, zero=None):
return op.Identity(x)

# Create and apply the rule
rule = AddZeroElimination.rule()
rule.apply_to_model(model)
```

Multiple pattern example:

```python
class TransposeElimination(pattern.RewriteRuleClassBase):
\"\"\"Removes redundant transpose: Transpose(Transpose(x, perm), reverse_perm) -> x\"\"\"

def pattern(self, op, x, perm):
return op.Transpose(x, perm=perm)

def check(self, context, x, perm):
# Only apply if permutation is identity (no-op transpose)
if perm.is_ref():
return False
if perm.type == ir.AttributeType.INTS:
perm_list = perm.as_ints()
return perm_list == list(range(len(perm_list)))
return False

def rewrite(self, op, x, perm=None):
return op.Identity(x)

# Apply multiple rules as a set
rules = pattern.RewriteRuleSet([
AddZeroElimination.rule(),
TransposeElimination.rule()
])
rules.apply_to_model(model)
```

Function-based rules (lower-level API):

```python
# For simple cases, you can still use function-based rules
def mul_one_pattern(op, x):
one = op.Constant(value=1.0)
return op.Mul(x, one)

def identity_replacement(op, x):
return op.Identity(x)

rule = pattern.RewriteRule(mul_one_pattern, identity_replacement)
```

Classes and functions exported:
- RewriteRuleClassBase: Recommended base class for implementing rewrite rules
- RewriteRule: Core class for defining pattern-to-replacement rules
- RewriteRuleSet: Collection of rules with application logic
- Pattern building utilities: OpsetPatternBuilder, pattern_builder, etc.
- Matching utilities: MatchResult, MatchingTracer, etc.
"""
from __future__ import annotations

from onnxscript.ir import _tape
Expand Down