Skip to content

[rewriter] Transpose initializer rule #2255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 23 commits into from
Closed
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
llama_rule_sets,
no_op,
pattern,
transpose_initializer,
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -32,6 +33,7 @@
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*llama_rule_sets.llama_p0_rule_set().rules,
transpose_initializer.rule,
)


59 changes: 59 additions & 0 deletions onnxscript/rewriter/transpose_initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rules to collapse Transpose nodes into initializers."""

from __future__ import annotations

import logging

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp

logger = logging.getLogger(__name__)


class TransposeInitializer(orp.RewriteRuleClassBase):
"""Folds Transpose nodes into initializers."""

def __init__(self):
super().__init__("TransposeInitializer", remove_nodes=True)

def pattern(self, op, initializer):
return op.Transpose(initializer, _allow_other_attributes=True)

def rewrite(self, op, initializer: ir.Value) -> ir.Value:
original_transpose = initializer.consumers()[0]
perm_attr = original_transpose.attributes.get("perm")
assert isinstance(perm_attr, ir.Attr)
array = ir_utils.get_numpy_value(initializer)

Check warning on line 31 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L28-L31

Added lines #L28 - L31 were not covered by tests
if array is None:
# Do nothing
logger.debug("Failed to obtain the initializer value. Do nothing")

Check warning on line 34 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L34

Added line #L34 was not covered by tests
# perm=None is filtered out when the attribute is constructed so we are ok
return op.Transpose(initializer, perm=perm_attr)

Check warning on line 36 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L36

Added line #L36 was not covered by tests

if perm_attr is not None:
perm = perm_attr.as_ints()

Check warning on line 39 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L39

Added line #L39 was not covered by tests
else:
perm = None
transposed = np.transpose(array, axes=perm)
new_name = f"{initializer.const_value.name}_transposed"
return op.initializer(ir.tensor(transposed, name=new_name))

Check warning on line 44 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L41-L44

Added lines #L41 - L44 were not covered by tests

def check(self, context, initializer: ir.Value) -> orp.MatchResult:
del context # Unused
check_result = orp.MatchResult()
if initializer.const_value is None:
return check_result.fail("Value is not an initializer, const_value is None")
if initializer.producer() is not None:
return check_result.fail("Value is not an initializer, producer is not None")
if initializer.uses() != 1:
return check_result.fail("Initializer is used by more than one node")
# TODO(justinchuby): Avoid matching when it is a graph input
return check_result

Check warning on line 56 in onnxscript/rewriter/transpose_initializer.py

Codecov / codecov/patch

onnxscript/rewriter/transpose_initializer.py#L56

Added line #L56 was not covered by tests


rule = TransposeInitializer.rule()
Loading
Oops, something went wrong.