Skip to content

[rewriter] Transpose initializer rule #2255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 23 commits into from
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions onnxscript/rewriter/transpose_initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Rules to collapse Transpose nodes into initializers."""
from __future__ import annotations
from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp

import logging

logger = logging.getLogger(__name__)

class TransposeInitializer(orp.RewriteRuleClassBase):
"""Folds Transpose nodes into initializers."""

def __init__(self):
super().__init__("TransposeInitializer", remove_nodes=True)

def pattern(self, op, initializer):
return op.Transpose(initializer, _allow_other_attributes=True)

def rewrite(self, op, initializer: ir.Value) -> ir.Value:
array = ir_utils.get_const_value(initializer)

Check warning on line 21 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L21

Added line #L21 was not covered by tests
if array is None:
# Do nothing
logger.debug("Failed to obtain the initializer value. Do nothing")

Check warning on line 24 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L24

Added line #L24 was not covered by tests
# TODO: Handle both when perms is None and when perms is not None
return op.Transpose(initializer, perms)

Check warning on line 26 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L26

Added line #L26 was not covered by tests
# TODO Obtain perms from the matched node
return op.initializer(ir.tensor())

Check warning on line 28 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L28

Added line #L28 was not covered by tests

def check(self, context, initializer: ir.Value) -> orp.MatchResult:
del context # Unused
check_result = orp.MatchResult()

Check warning on line 32 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L31-L32

Added lines #L31 - L32 were not covered by tests
if initializer.const_value is None:
return check_result.fail("Value is not an initializer, const_value is None")

Check warning on line 34 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L34

Added line #L34 was not covered by tests
if initializer.producer() is not None:
return check_result.fail("Value is not an initializer, producer is not None")
return check_result

Check warning on line 37 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L36-L37

Added lines #L36 - L37 were not covered by tests


rule = TransposeInitializer.rule()
Loading
Oops, something went wrong.