
Handling of default attrs of LayerNormalization in SkipLayerNormFusion #2378

Open
@KarelZe

Description


Thanks for onnxscript💯 .

I'm currently comparing operator fusion in onnxruntime and onnxscript for BART and noticed that SkipLayerNormFusion currently does not fuse ops if stash_type is left at its default (1) or epsilon is left at its default (1e-05) on LayerNormalization.

The current rule implementation only fuses if both attributes are explicitly set.

Minimal example:

import onnxscript
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions.skip_normalization import (
    skip_layer_normalization_ruleset,
)

tape = ir.tape.Tape()

x: ir.Value = tape.initializer(ir.tensor([[[1.0, 2.0, 3.0]]], name="x"))
skip: ir.Value = tape.initializer(ir.tensor([[[0.5, 0.5, 0.5]]], name="skip"))

bias: ir.Value = tape.initializer(ir.tensor([0.5, 0.5, 0.5], name="bias"))
scale: ir.Value = tape.initializer(ir.tensor([1.0, 1.0, 1.0], name="scale"))

optionals = {}  # set to {"stash_type": 1, "epsilon": 1e-4} to make the fusion apply  # <-- switch here.

skip_sum = tape.op("Add", [skip, x])
normalized = tape.op(
    "LayerNormalization",
    [skip_sum, scale, bias],
    attributes=optionals | {"axis": -1},
)

onnx_model = ir.Model(
    graph := ir.Graph(
        inputs=[x],
        outputs=[normalized],
        nodes=tape.nodes,
        initializers=tape.initializers,
        opset_imports={"": 20},
    ),
    ir_version=10,
)

model_with_rewrite_applied = onnxscript.rewriter.rewrite(
    onnx_model, pattern_rewrite_rules=skip_layer_normalization_ruleset
)
ir.save(model_with_rewrite_applied, "layer_norm_fusion.onnx")
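Conceptually, the fix I have in mind is for the pattern to treat a missing optional attribute as equal to its ONNX default instead of requiring it to be present. Here is a minimal, self-contained sketch of that matching logic in plain Python (this is not the actual onnxscript rewriter API; the helper names are hypothetical and only illustrate the default-handling idea):

```python
# ONNX defaults for LayerNormalization's attributes (opset 17+).
LAYER_NORM_DEFAULTS = {"axis": -1, "epsilon": 1e-05, "stash_type": 1}

def get_attr_or_default(attributes: dict, name: str):
    """Return the explicitly set value, falling back to the ONNX default."""
    return attributes.get(name, LAYER_NORM_DEFAULTS[name])

def matches_for_fusion(attributes: dict) -> bool:
    # The match should not require epsilon/stash_type to be present:
    # read them with their defaults and forward them to the fused op,
    # instead of failing the match when they are absent.
    epsilon = get_attr_or_default(attributes, "epsilon")
    stash_type = get_attr_or_default(attributes, "stash_type")
    return isinstance(epsilon, float) and isinstance(stash_type, int)

# Both of these would then match, with or without explicit optionals:
print(matches_for_fusion({"axis": -1}))                                   # no optionals set
print(matches_for_fusion({"axis": -1, "stash_type": 1, "epsilon": 1e-4}))  # optionals set
```

With this kind of fallback, the minimal example above would be fused in both configurations of `optionals`.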

Original PR: #2259.

Would you be open to also matching the default attribute values for LayerNormalization? If yes, I'd be happy to contribute a fix.

@shubhambhokare1 @justinchuby
