Skip to content

Rewriter: Fold Batchnorm nodes #2312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 27, 2025

Conversation

AyoubMDL
Copy link
Contributor

@AyoubMDL AyoubMDL commented May 18, 2025

Fuses BatchNormalization nodes into the following nodes (Conv, ConvTranspose, Gemm)
(#2301)

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds functionality to fold BatchNormalization nodes into the preceding Conv, ConvTranspose, and Gemm nodes, simplifying the graph and improving runtime efficiency.

  • Implements rewrite rules for Gemm, Conv, and ConvTranspose to absorb batchnorm parameters into weights and biases.
  • Adds comprehensive unit tests covering folding behavior with/without bias and a non-initializer scenario.
  • Registers the new fold_batchnorm rule in the default rewrite-rule set.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
onnxscript/rewriter/fold_batchnorm.py Implements rewrite rules and helper functions to fold BatchNormalization.
onnxscript/rewriter/fold_batchnorm_test.py Adds tests for folding BatchNorm into ConvTranspose, Conv, Gemm, and non-init case.
onnxscript/rewriter/init.py Registers fold_batchnorm in the default rewrite-rule set.
Comments suppressed due to low confidence (2)

onnxscript/rewriter/fold_batchnorm.py:70

  • [nitpick] Using x.name + "_bias" can produce non-obvious or conflicting initializer names. Consider deriving the bias name from the inbound node's output or weight name for clarity and uniqueness.
            bias_name = x.name + "_bias"

onnxscript/rewriter/fold_batchnorm_test.py:38

  • Add a test case where the BatchNormalization node has a non-default epsilon attribute to ensure folding logic correctly respects custom epsilon values.
                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)

Copy link

codecov bot commented May 18, 2025

Codecov Report

Attention: Patch coverage is 98.75000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 73.81%. Comparing base (a3ce145) to head (a0b7bee).
Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/rewriter/fuse_batchnorm_test.py 97.75% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2312      +/-   ##
==========================================
+ Coverage   73.68%   73.81%   +0.13%     
==========================================
  Files         240      242       +2     
  Lines       30978    31138     +160     
  Branches     3517     3530      +13     
==========================================
+ Hits        22825    22984     +159     
  Misses       6932     6932              
- Partials     1221     1222       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

  1. Could you follow
    class CastIdentity(orp.RewriteRuleClassBase):
    """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
    def pattern(self, op, x, to):
    return op.Cast(x, to=to)
    def rewrite(self, op, x: ir.Value, to: ir.Attr):
    return op.Identity(x)
    def check(self, context, x, to) -> orp.MatchResult:
    check_result = orp.MatchResult()
    if x.dtype != to.as_int():
    return check_result.fail("Input and output types are not the same")
    return check_result
    to implement the rules as an pattern.RewriteRuleClassBase? You may create additional subclasses for the different variants for the rules.
  2. In the subclasses I would also create a docstring to describe what the rules do in plain language.
  3. You may also follow the readme (https://github.com/microsoft/onnxscript#coding-style) to format the code with lintrunner.

@AyoubMDL AyoubMDL force-pushed the fold-batchnorm-rewrite branch from 1441168 to 7221f06 Compare May 20, 2025 17:07
@AyoubMDL
Copy link
Contributor Author

  1. Could you follow
    class CastIdentity(orp.RewriteRuleClassBase):
    """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
    def pattern(self, op, x, to):
    return op.Cast(x, to=to)
    def rewrite(self, op, x: ir.Value, to: ir.Attr):
    return op.Identity(x)
    def check(self, context, x, to) -> orp.MatchResult:
    check_result = orp.MatchResult()
    if x.dtype != to.as_int():
    return check_result.fail("Input and output types are not the same")
    return check_result

    to implement the rules as an pattern.RewriteRuleClassBase? You may create additional subclasses for the different variants for the rules.

@justinchuby Added the suggestion—please let me know if it aligns with what you had in mind.

@justinchuby
Copy link
Collaborator

The rules look good to me, although I am not sure it should be part of the default rule set? @gramalingam do you have any recommendations? Should we enable multiple levels of optimization, or is this fine?

AyoubMDL added 2 commits May 24, 2025 20:26
- Fuses Batchnorm node into the following nodes (Conv, ConvTranspose,
  Gemm)
- Make the rule optional
- Improve code/test (checks, type-checking)
@AyoubMDL AyoubMDL force-pushed the fold-batchnorm-rewrite branch from 7221f06 to a0b7bee Compare May 24, 2025 19:00
@AyoubMDL AyoubMDL requested a review from justinchuby May 24, 2025 19:02
Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!

@justinchuby justinchuby enabled auto-merge (squash) May 27, 2025 17:21
@justinchuby justinchuby disabled auto-merge May 27, 2025 17:21
@justinchuby justinchuby merged commit 276bf27 into microsoft:main May 27, 2025
26 of 29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

3 participants