Thanks for onnxscript 💯.

I'm currently comparing operator fusion in onnxruntime and onnxscript for BART, and noticed that the `SkipLayerNormFusion` currently does not fuse ops if `stash_type` is at its default (`1`) or `epsilon` is at its default (`1e-5`) for `LayerNormalization`. The current rule implementation only fuses if both attributes are explicitly set.
Minimal example:
```python
import onnxscript
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions.skip_normalization import (
    skip_layer_normalization_ruleset,
)

tape = ir.tape.Tape()
x: ir.Value = tape.initializer(ir.tensor([[[1.0, 2.0, 3.0]]], name="x"))
skip: ir.Value = tape.initializer(ir.tensor([[[0.5, 0.5, 0.5]]], name="skip"))
bias: ir.Value = tape.initializer(ir.tensor([0.5, 0.5, 0.5], name="bias"))
scale: ir.Value = tape.initializer(ir.tensor([1.0, 1.0, 1.0], name="scale"))

# With the attributes left implicit ({}), the fusion does not fire.
# Switch to {"stash_type": 1, "epsilon": 1e-4} and it does.
optionals = {}

skip_sum = tape.op("Add", [skip, x])
normalized = tape.op(
    "LayerNormalization",
    [skip_sum, scale, bias],
    attributes=optionals | {"axis": -1},
)

onnx_model = ir.Model(
    ir.Graph(
        inputs=[x],
        outputs=[normalized],
        nodes=tape.nodes,
        initializers=tape.initializers,
        opset_imports={"": 20},
    ),
    ir_version=10,
)

model_with_rewrite_applied = onnxscript.rewriter.rewrite(
    onnx_model, pattern_rewrite_rules=skip_layer_normalization_ruleset
)
ir.save(model_with_rewrite_applied, "layer_norm_fusion.onnx")
```
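As a quick sanity check (not part of the repro, and assuming `ir.Graph` iterates over its nodes), printing the op types shows whether the fusion fired:

```python
# "SkipLayerNormalization" appears only when the fusion fired;
# otherwise the original Add + LayerNormalization pair remains.
print([node.op_type for node in model_with_rewrite_applied.graph])
```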
Original PR: #2259.
Would you be open to also matching against the default attribute values for `LayerNormalization`? If so, I'd be happy to contribute a fix.
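For illustration, here is a minimal sketch of the idea, written against plain `onnx` protos rather than the onnxscript rewriter API (the helper name and the defaults table are mine): when matching, treat a missing attribute as its ONNX-spec default instead of requiring it to be present.

```python
import onnx

# ONNX-spec defaults for LayerNormalization attributes.
_LAYER_NORM_DEFAULTS = {"axis": -1, "epsilon": 1e-5, "stash_type": 1}

def attr_or_default(node: onnx.NodeProto, name: str):
    """Return the attribute value if set on the node, else the spec default."""
    for attr in node.attribute:
        if attr.name == name:
            return onnx.helper.get_attribute_value(attr)
    return _LAYER_NORM_DEFAULTS[name]
```

The actual fix would presumably live in the rule's match condition, comparing against these defaults whenever the attribute is absent.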