fix: handling of default attrs in SimplifiedLayerNormalization + LayerNormalization🐛 #2396
Conversation
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2396      +/-   ##
==========================================
- Coverage   70.37%   62.37%    -8.01%
==========================================
  Files         199      200        +1
  Lines       25216    25473      +257
  Branches     2686     2688        +2
==========================================
- Hits        17747    15888     -1859
- Misses       6540     8762     +2222
+ Partials      929      823      -106

☔ View full report in Codecov by Sentry.
skip_sum_pattern_2 = op.Add(input, skip)
skip_sum = pattern.OrValue([skip_sum_pattern_1, skip_sum_pattern_2], name="skip_sum")

skip_sum = op.Add(input, skip)
if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)
I chose to enable commute(...), since we didn't check for all input-order variants of this addition, only of the ones in the lines above.
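For illustration, here is a minimal, self-contained sketch of how commuted-input variants get registered; the toy rule and the fused op name are hypothetical, and RewriteRule.commute() is assumed to return the rule's variants with the inputs of commutative ops (such as Add) swapped:

from onnxscript.rewriter import pattern


def _add_mul_pattern(op, x, y, c):
    # Target pattern: Mul(Add(x, y), c). Without commutation this would only
    # match Add(x, y), not Add(y, x).
    return op.Mul(op.Add(x, y), c)


def _add_mul_rewrite(op, x, y, c, **_):
    # Hypothetical fused op, purely for illustration.
    return op.FusedAddMul(x, y, c, _domain="com.example")


_toy_rule = pattern.RewriteRule(_add_mul_pattern, _add_mul_rewrite)

# Registering the commuted variants means both Add(x, y) and Add(y, x) are matched.
toy_ruleset = pattern.RewriteRuleSet([*_toy_rule.commute()])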
    encoder_layers_0_self_attn_layer_norm_weight
)

encoder_layers_1_fc2_bias = opset20.Identity(encoder_layers_0_self_attn_k_proj_bias)
Check warning (Code scanning / CodeQL): Variable defined multiple times
encoder_layers_1_fc2_bias = opset20.Identity(encoder_layers_0_self_attn_k_proj_bias)
encoder_layers_1_fc1_bias = opset20.Identity(encoder_layers_0_fc1_bias)
encoder_layers_1_self_attn_layer_norm_bias = opset20.Identity(
Check warning (Code scanning / CodeQL): Variable defined multiple times
Pull Request Overview

This PR fixes how default attributes (epsilon, stash_type) are handled in both LayerNormalization and SimplifiedLayerNormalization fusions, adds a BART encoder model to the fusion tests, and introduces commuted-input support for SkipLayerNormalization rules.

- Extract the default epsilon from the matched node instead of requiring it in the pattern signature (sketched below)
- Add test_bart_encoder to validate fusion with default-attribute cases
- Enable commuted-input variants by applying .commute() to fusion rules
Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File | Description
---|---
skip_normalization_test.py | Added test_bart_encoder to cover default-attribute fusions
skip_normalization.py | Refactored patterns to drop default attrs, extract epsilon in rewrite, and apply rule commutation
Comments suppressed due to low confidence (2)

onnxscript/rewriter/ort_fusions/skip_normalization_test.py:73
- The test uses fuse_skip_layer_normalization(model) but there is no import for that symbol in this file. Please add `from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_layer_normalization` (or adjust the import path) to ensure the function is available.

    fuse_skip_layer_normalization(model)

onnxscript/rewriter/ort_fusions/skip_normalization.py:231
- The new .commute() calls are applied only to the full SkipLayerNormalization rules. To allow commuted inputs for SkipSimplifiedLayerNormalization as well, you should apply .commute() to the simplified-layer ruleset (if defined) or include those here before applying apply_fusion_rules.

    skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
        **_,
    ):
        epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
You extract epsilon from the matched node but do not extract or forward stash_type. If a non-default stash_type was used, it will be lost in the fused op. Consider retrieving stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type") and passing it into SkipSimplifiedLayerNormalization.

Suggested change:

    epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon")
    stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type")
I guess there is no stash type for fused layer norm ops? https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.SkipLayerNormalization
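One way to reconcile the two comments above would be to fuse only when stash_type is absent or at its default, so a non-default value is never silently dropped. A hypothetical sketch, assuming a boolean-returning check hook on the rule class and that get_int returns None for a missing attribute (neither is taken from this PR's diff):

def check(self, context, simplified_layer_norm, **_):
    # com.microsoft.SkipLayerNormalization has no stash_type attribute, so only fuse
    # when the matched node's stash_type is absent or equal to the default (1).
    stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type")
    return stash_type is None or stash_type == 1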
if self._has_bias and not self._bias_pre_add:
    skip_sum = op.Add(skip_sum, bias)

normalized = op.LayerNormalization(
    skip_sum,
    gamma,
    beta,
@gramalingam beta is an optional input. I'd lean toward matching both variants (with and without bias).
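A sketch of one way to match both variants inside the pattern method, reusing the pattern.OrValue approach already used for skip_sum above; the variable names (skip_sum, gamma, beta) are taken from the quoted pattern, everything else is an assumption:

# Match LayerNormalization with and without the optional beta input.
normalized_with_beta = op.LayerNormalization(skip_sum, gamma, beta)
normalized_without_beta = op.LayerNormalization(skip_sum, gamma)
normalized = pattern.OrValue(
    [normalized_with_beta, normalized_without_beta], name="normalized"
)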
SkipLayerNormFusion currently does not fuse ops if stash_type is at its default (=1) or epsilon is at its default (=1e-5) for LayerNormalization and SimplifiedLayerNormalization.

This PR fixes the handling of default attributes for LayerNormalization, SimplifiedLayerNormalization and EmbedLayerNormalization.

Closes #2378.
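For reference, a hypothetical usage sketch of the fusion entry point mentioned in the review (fuse_skip_layer_normalization); the model path, the ir.from_proto/ir.to_proto conversion step, and the returned fusion count are assumptions rather than details taken from this PR:

import onnx
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_layer_normalization

model_proto = onnx.load("bart_encoder.onnx")  # assumed input model with default-attribute LayerNormalization nodes
model = ir.from_proto(model_proto)            # convert to the onnxscript IR the rewriter operates on
count = fuse_skip_layer_normalization(model)  # assumed to return the number of fusions applied
print(f"SkipLayerNormalization fusions: {count}")
onnx.save(ir.to_proto(model), "bart_encoder_fused.onnx")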
@shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated.