Rewriter: Fold Batchnorm nodes #2312
Pull Request Overview
This PR adds functionality to fold BatchNormalization nodes into the preceding Conv, ConvTranspose, and Gemm nodes, simplifying the graph and improving runtime efficiency.
- Implements rewrite rules for Gemm, Conv, and ConvTranspose to absorb batchnorm parameters into weights and biases.
- Adds comprehensive unit tests covering folding behavior with/without bias and a non-initializer scenario.
- Registers the new fold_batchnorm rule in the default rewrite-rule set.
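The arithmetic behind the fold can be sketched in a few lines. This is a minimal, pure-Python illustration for the Gemm case only, assuming `transB=0` (weights shaped `(K, N)`), inference-mode BatchNormalization, and a present bias (a missing bias can be treated as zeros). The helper names `gemm`, `batchnorm`, and `fold_batchnorm_into_gemm` are hypothetical and are not taken from `fold_batchnorm.py`.

```python
import math


def gemm(x, w, b):
    """Compute y = x @ w + b for one row vector x; w has shape (K, N)."""
    return [sum(x[k] * w[k][n] for k in range(len(x))) + b[n] for n in range(len(b))]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNormalization applied per output feature."""
    return [
        gamma[n] * (y[n] - mean[n]) / math.sqrt(var[n] + eps) + beta[n]
        for n in range(len(y))
    ]


def fold_batchnorm_into_gemm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w_folded, b_folded) such that gemm(x, w_folded, b_folded)
    matches batchnorm(gemm(x, w, b), ...)."""
    # Per-output-feature scale factor shared by the weights and the bias.
    scale = [gamma[n] / math.sqrt(var[n] + eps) for n in range(len(gamma))]
    # Scale each output column of the weight matrix.
    w_folded = [[w[k][n] * scale[n] for n in range(len(scale))] for k in range(len(w))]
    # Shift the bias by the running mean, rescale, then add beta.
    b_folded = [(b[n] - mean[n]) * scale[n] + beta[n] for n in range(len(scale))]
    return w_folded, b_folded


# Illustrative values only (not taken from the PR's tests).
x = [1.0, -2.0, 0.5]
w = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.7]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.05, -0.3]
mean, var = [0.2, -0.1], [0.9, 1.6]

reference = batchnorm(gemm(x, w, b), gamma, beta, mean, var)
w_f, b_f = fold_batchnorm_into_gemm(w, b, gamma, beta, mean, var)
folded = gemm(x, w_f, b_f)
```

The same per-output-channel scale applies to Conv and ConvTranspose; only the weight axis that carries the output channels differs (axis 0 for Conv weights, axis 1 for ConvTranspose weights in the ONNX layout).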
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
onnxscript/rewriter/fold_batchnorm.py | Implements rewrite rules and helper functions to fold BatchNormalization. |
onnxscript/rewriter/fold_batchnorm_test.py | Adds tests for folding BatchNorm into ConvTranspose, Conv, Gemm, and non-init case. |
onnxscript/rewriter/__init__.py | Registers fold_batchnorm in the default rewrite-rule set. |
Comments suppressed due to low confidence (2)
onnxscript/rewriter/fold_batchnorm.py:70
- [nitpick] Using `x.name + "_bias"` can produce non-obvious or conflicting initializer names. Consider deriving the bias name from the inbound node's output or weight name for clarity and uniqueness.
bias_name = x.name + "_bias"
onnxscript/rewriter/fold_batchnorm_test.py:38
- Add a test case where the BatchNormalization node has a non-default `epsilon` attribute to ensure the folding logic correctly respects custom epsilon values.
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
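To illustrate why such a test matters, here is a small sketch of how the folded scale factor depends on `epsilon`. It assumes the standard inference formula `gamma / sqrt(var + epsilon)` and the ONNX default of `1e-5`; the helper `bn_scale` and the numeric values are hypothetical and chosen only to make the gap visible.

```python
import math


def bn_scale(gamma, var, eps):
    """Scale factor that folding multiplies into the weights."""
    return gamma / math.sqrt(var + eps)


gamma, var = 1.0, 0.01

# Scale computed with the ONNX default epsilon.
default_scale = bn_scale(gamma, var, eps=1e-5)
# Scale computed with a custom epsilon set on the node.
custom_scale = bn_scale(gamma, var, eps=1e-2)

# Relative error introduced if the fold ignored the node's epsilon attribute.
relative_gap = abs(default_scale - custom_scale) / custom_scale
```

With a small running variance the gap is large, so a fold that silently used the default epsilon would change the model's output.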
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
Coverage Diff
Metric | main | #2312 | +/- |
---|---|---|---|
Coverage | 73.68% | 73.81% | +0.13% |
Files | 240 | 242 | +2 |
Lines | 30978 | 31138 | +160 |
Branches | 3517 | 3530 | +13 |
Hits | 22825 | 22984 | +159 |
Misses | 6932 | 6932 | |
Partials | 1221 | 1222 | +1 |
Thank you!
- Could you follow `pattern.RewriteRuleClassBase`, as in onnxscript/onnxscript/rewriter/llama_rule_sets.py (lines 35 to 48 in 2ae13be)?

      class CastIdentity(orp.RewriteRuleClassBase):
          """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""

          def pattern(self, op, x, to):
              return op.Cast(x, to=to)

          def rewrite(self, op, x: ir.Value, to: ir.Attr):
              return op.Identity(x)

          def check(self, context, x, to) -> orp.MatchResult:
              check_result = orp.MatchResult()
              if x.dtype != to.as_int():
                  return check_result.fail("Input and output types are not the same")
              return check_result

  You may create additional subclasses for the different variants of the rules.
- In the subclasses I would also create a docstring to describe what the rules do in plain language.
- You may also follow the readme (https://github.com/microsoft/onnxscript#coding-style) to format the code with lintrunner.
Force-pushed: 1441168 to 7221f06
@justinchuby Added the suggestion; please let me know if it aligns with what you had in mind.
The rules look good to me, although I am not sure it should be part of the default rule set? @gramalingam do you have any recommendations? Should we enable multiple levels of optimization, or is this fine?
- Fuses BatchNormalization nodes into the following nodes (Conv, ConvTranspose, Gemm)
- Make the rule optional
- Improve code/test (checks, type-checking)
Force-pushed: 7221f06 to a0b7bee
Thanks a lot!
Fuses `BatchNormalization` nodes into the following nodes (`Conv`, `ConvTranspose`, `Gemm`) (#2301)