Commit 276bf27

Rewriter: Fold Batchnorm nodes (#2312)
Fuses `BatchNormalization` nodes into the preceding `Conv`, `ConvTranspose`, or `Gemm` node (#2301)
1 parent 06bb751 commit 276bf27

2 files changed, +445 -0 lines changed

onnxscript/rewriter/fuse_batchnorm.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization ∘ Conv -> Conv
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
- BatchNormalization ∘ Gemm -> Gemm

Approach:
    Given an inbound operation output: Y = W * X + B
    And a BatchNormalization output: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)

    The fusion updates the inbound weights and bias as follows:
    - W_fused = W * (gamma / std)
    - B_fused = (B - μ) * (gamma / std) + β
"""
from abc import ABC, abstractmethod
from typing import Mapping

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import pattern as orp


def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
    # Build shape: 1s everywhere except -1 at the target axis
    broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
    return np.reshape(x, broadcast_shape)

class _FuseBatchNormBase(orp.RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion."""

    def __init__(
        self,
        op_type: str,
        name: str | None = None,
        remove_nodes: bool = True,
        as_function: bool = False,
    ) -> None:
        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
        self.op_type = op_type

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        batchnorm_node = batchnorm_out.producer()
        # Get BatchNorm parameters
        gamma, beta, input_mean, input_var = [
            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
        ]

        # 1e-5 is the default value for epsilon according to
        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()

        # Compute the scale_factor to update the inbound weights and bias
        scale_factor = gamma / np.sqrt(input_var + eps)

        # Update inbound weights
        inbound_node = inbound_out.producer()
        weights = inbound_node.inputs[1].const_value.numpy()

        # Reshape scale factor so it is broadcastable
        axis = self.get_filters_axis(inbound_node.attributes)
        fused_weights = ir.tensor(
            weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
        )

        # Update bias
        if len(inbound_node.inputs) > 2:
            original_bias = inbound_node.inputs[2].const_value.numpy()
            bias_name = inbound_node.inputs[2].name
        else:
            original_bias = np.zeros_like(input_mean)
            bias_name = x.name + "_bias"
        fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)

        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=inbound_node.inputs[1].name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=inbound_node.attributes,
        )

    def check(
        self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value
    ) -> orp.MatchResult:
        del context  # Unused
        check_result = orp.MatchResult()

        inbound_node = inbound_out.producer()
        batchnorm_node = batchnorm_out.producer()

        # Check that inbound weights + (inbound bias) + batchnorm params are initializers
        # and that they are not graph inputs
        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
        if len(inbound_node.inputs) > 2:
            initializers.append(inbound_node.inputs[2])

        for initializer in initializers:
            if not initializer.is_initializer() or initializer.const_value is None:
                return check_result.fail(f"{initializer.name} is not a constant initializer.")
            if initializer.is_graph_input():
                return check_result.fail(f"{initializer.name} is a graph input.")

        return check_result

class FuseBatchNormIntoConv(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

    def __init__(self):
        super().__init__("Conv")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        return 0

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )

class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

    def __init__(self):
        super().__init__("ConvTranspose")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        return 1

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )

class FuseBatchNormIntoGemm(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

    def __init__(self):
        super().__init__("Gemm")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        return (
            0 if attributes.get("transB") is not None and attributes["transB"].as_int() else 1
        )

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )

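(Editorial note on the axis choices: in ONNX, Conv weights are shaped (M, C/group, kH, kW), so output channels sit on axis 0; ConvTranspose weights are (C, M/group, kH, kW), with output channels on axis 1; Gemm's B input is (K, N), or (N, K) when transB=1. A quick illustration of the resulting broadcast shapes, reusing the `_reshape_for_broadcast` helper above with illustrative values:)

scale = np.arange(8.0)  # per-output-channel scale (illustrative)
print(_reshape_for_broadcast(scale, rank=4, axis=0).shape)  # Conv           -> (8, 1, 1, 1)
print(_reshape_for_broadcast(scale, rank=4, axis=1).shape)  # ConvTranspose  -> (1, 8, 1, 1)
print(_reshape_for_broadcast(scale, rank=2, axis=1).shape)  # Gemm, transB=0 -> (1, 8)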
fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()


def fuse_batchnorm_rule_set() -> orp.RewriteRuleSet:
    """Returns a set of rewrite rules that fuse BatchNormalization nodes
    into preceding nodes such as Conv, ConvTranspose, and Gemm.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            fuse_batchnorm_into_conv_rule,
            fuse_batchnorm_into_convtranspose_rule,
            fuse_batchnorm_into_gemm_rule,
        ]
    )
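(Editorial usage sketch, not part of the commit: applying the rule set to a model loaded into the onnxscript IR. The file names are placeholders, and `apply_to_model` is assumed to return the number of rewrites applied, as in the current `RewriteRuleSet` API.)

from onnxscript import ir
from onnxscript.rewriter.fuse_batchnorm import fuse_batchnorm_rule_set

model = ir.load("model.onnx")  # placeholder input path
n = fuse_batchnorm_rule_set().apply_to_model(model)
print(f"fused {n} BatchNormalization node(s)")
ir.save(model, "model_fused.onnx")  # placeholder output path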
