Commit 7221f06

feat(rewriter): introduce fuse batchnorm
- Fuses a BatchNormalization node into the preceding node (Conv, ConvTranspose, or Gemm)
1 parent 644e30c

3 files changed: +423 -0 lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
+    fuse_batchnorm,
     gemm_to_matmul_add,
     llama_rule_sets,
     no_op,
@@ -28,6 +29,7 @@
 _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
     *no_op.rules.rules,  # TODO: merge this rule into constant folding?
     *broadcast_to_matmul.rules.rules,
+    *fuse_batchnorm.fuse_batchnorm_rule_set().rules,
     gemm_to_matmul_add.rule,  # type: ignore[has-type]
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
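
Note: since the rules are added to _DEFAULT_REWRITE_RULES, the fusion runs as part of the default rewrite pass. A minimal usage sketch, assuming the public onnxscript.rewriter.rewrite entry point accepts and returns an onnx.ModelProto (the file name is hypothetical):

    # Sketch: run the default rewrite pass, which now includes the BatchNorm fusion rules.
    import onnx
    from onnxscript import rewriter

    model = onnx.load("model_with_conv_bn.onnx")  # hypothetical input model
    rewritten = rewriter.rewrite(model)           # BatchNormalization nodes fused where possible
    onnx.save(rewritten, "model_fused.onnx")

Alternatively, the fusion rules can be applied on their own via fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(...) on an ir.Model, assuming RewriteRuleSet exposes apply_to_model as in the rest of the rewriter.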

onnxscript/rewriter/fuse_batchnorm.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization + Conv -> Conv
- BatchNormalization + ConvTranspose -> ConvTranspose
- BatchNormalization + Gemm -> Gemm

Approach:
    Given an inbound operation output: Y = W * X + B
    and a BatchNormalization output: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps),

    the fusion updates the inbound weights and bias as follows:
    - W_fused = W * (gamma / std)
    - B_fused = (B - μ) * (gamma / std) + β
"""

from abc import ABC, abstractmethod

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import pattern as orp


class FuseBatchNormBase(orp.RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion."""

    def __init__(
        self,
        op_type: str,
        name: str | None = None,
        remove_nodes: bool = True,
        as_function: bool = False,
    ) -> None:
        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
        self.op_type = op_type

    @abstractmethod
    def get_filters_axis(self, attributes) -> int:
        """Return the axis along which the BatchNorm scale should be broadcasted."""

    def _reshape_for_broadcast(self, x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
        # Convert axis to positive
        if axis < 0:
            axis += rank

        # Build shape: 1s everywhere except -1 at the target axis
        broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
        return np.reshape(x, broadcast_shape)

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        batchnorm_node = batchnorm_out.producer()
        # Get BatchNorm parameters
        gamma, beta, input_mean, input_var = [
            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
        ]

        # 1e-5 is the default value for epsilon according to
        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()

        # Compute the scale_factor to update the inbound weights and bias
        scale_factor = gamma / np.sqrt(input_var + eps)

        # Update inbound weights
        inbound_node = inbound_out.producer()
        weights = inbound_node.inputs[1].const_value.numpy()

        # Reshape scale factor so it is broadcastable
        axis = self.get_filters_axis(inbound_node.attributes)
        fused_weights = ir.tensor(
            weights * self._reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
        )

        # Update bias
        if len(inbound_node.inputs) > 2:
            original_bias = inbound_node.inputs[2].const_value.numpy()
            bias_name = inbound_node.inputs[2].name
        else:
            original_bias = np.zeros_like(input_mean)
            bias_name = x.name + "_bias"
        fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)

        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=inbound_node.inputs[1].name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=inbound_node.attributes,
        )

    def check(self, context, x, inbound_out, batchnorm_out) -> orp.MatchResult:
        del context  # Unused
        check_result = orp.MatchResult()

        inbound_node = inbound_out.producer()
        batchnorm_node = batchnorm_out.producer()

        # Check that inbound weights + (inbound bias) + batchnorm params are initializers
        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
        if len(inbound_node.inputs) > 2:
            initializers.append(inbound_node.inputs[2])

        for initializer in initializers:
            if not initializer.is_initializer() or initializer.const_value is None:
                return check_result.fail(f"{initializer.name} is not a constant initializer")

        return check_result


class FuseBatchNormIntoConv(FuseBatchNormBase):
    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

    def __init__(self):
        super().__init__("Conv")

    def get_filters_axis(self, attributes) -> int:
        return 0

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )


class FuseBatchNormIntoConvTranspose(FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

    def __init__(self):
        super().__init__("ConvTranspose")

    def get_filters_axis(self, attributes) -> int:
        return 1

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )


class FuseBatchNormIntoGemm(FuseBatchNormBase):
    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

    def __init__(self):
        super().__init__("Gemm")

    def get_filters_axis(self, attributes) -> int:
        return 0 if attributes.get("transB") is not None and attributes["transB"].value else 1

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )


fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()


def fuse_batchnorm_rule_set() -> orp.RewriteRuleSet:
    """Returns a set of rewrite rules that fuse BatchNormalization nodes
    into preceding nodes such as Conv, ConvTranspose, and Gemm.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            fuse_batchnorm_into_conv_rule,
            fuse_batchnorm_into_convtranspose_rule,
            fuse_batchnorm_into_gemm_rule,
        ]
    )
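
Note: the weight/bias update formulas in the module docstring can be sanity-checked with plain numpy, treating the inbound op as a per-channel affine map Y = W * X + B over an NCHW tensor. This is an illustrative check only, not part of the commit; all names below are local to the snippet.

    # Standalone numpy check of the fusion formulas.
    import numpy as np

    rng = np.random.default_rng(0)
    c = 4                                             # number of channels
    x = rng.standard_normal((1, c, 8, 8)).astype(np.float32)
    w = rng.standard_normal(c).astype(np.float32)     # per-channel "weight"
    b = rng.standard_normal(c).astype(np.float32)     # per-channel "bias"
    gamma = rng.standard_normal(c).astype(np.float32)
    beta = rng.standard_normal(c).astype(np.float32)
    mean = rng.standard_normal(c).astype(np.float32)
    var = rng.random(c).astype(np.float32)
    eps = 1e-5

    def bcast(v):
        # Reshape a per-channel vector for NCHW broadcasting.
        return v.reshape(1, -1, 1, 1)

    std = np.sqrt(var + eps)
    # Reference: inbound op followed by BatchNormalization
    y = bcast(w) * x + bcast(b)
    y_bn = bcast(gamma) * (y - bcast(mean)) / bcast(std) + bcast(beta)
    # Fused: inbound op with folded weights and bias
    w_fused = w * (gamma / std)
    b_fused = (b - mean) * (gamma / std) + beta
    y_fused = bcast(w_fused) * x + bcast(b_fused)

    np.testing.assert_allclose(y_bn, y_fused, rtol=1e-4, atol=1e-5)
    print("fusion formulas verified")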
