
Commit a0b7bee

review: apply requested changes

- Make the rule optional
- Improve code/test (checks, type-checking)

1 parent 4a665e0 · commit a0b7bee

3 files changed (+72 / -50 lines)


onnxscript/rewriter/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -18,7 +18,6 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
-    fuse_batchnorm,
     gemm_to_matmul_add,
     llama_rule_sets,
     no_op,
@@ -29,7 +28,6 @@
 _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
     *no_op.rules.rules,  # TODO: merge this rule into constant folding?
     *broadcast_to_matmul.rules.rules,
-    *fuse_batchnorm.fuse_batchnorm_rule_set().rules,
     gemm_to_matmul_add.rule,  # type: ignore[has-type]
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
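
With this change the BatchNormalization fusion rules are no longer part of `_DEFAULT_REWRITE_RULES`, so callers have to opt in explicitly. A minimal sketch of what that could look like, assuming a model file on disk (the path is a placeholder) and using only the calls that appear elsewhere in this commit:

```python
import onnx

from onnxscript import ir
from onnxscript.rewriter import fuse_batchnorm

# Load the ONNX model and convert it to the IR the rewriter operates on.
model_proto = onnx.load("model.onnx")  # placeholder path
model = ir.serde.deserialize_model(model_proto)

# Explicitly apply the now-optional BatchNormalization fusion rules.
count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
print(f"Applied {count} fusion(s)")

# Convert back to a ModelProto if needed.
fused_proto = ir.serde.serialize_model(model)
```

This mirrors how the rule set is invoked in `fuse_batchnorm_test.py` below; only the model loading and serialization around it are assumptions.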

onnxscript/rewriter/fuse_batchnorm.py

Lines changed: 29 additions & 24 deletions

@@ -1,9 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
-- BatchNormalization + Conv -> Conv
-- BatchNormalization + ConvTranpose -> ConvTranpose
-- BatchNormalization + Gemm -> Gemm
+- BatchNormalization Conv -> Conv
+- BatchNormalization ConvTranpose -> ConvTranpose
+- BatchNormalization Gemm -> Gemm

 Approach:
 Given an inbound operation output: Y = W * X + B
@@ -15,14 +15,21 @@
 """

 from abc import ABC, abstractmethod
+from typing import Mapping

 import numpy as np

 from onnxscript import ir
 from onnxscript.rewriter import pattern as orp


-class FuseBatchNormBase(orp.RewriteRuleClassBase, ABC):
+def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
+    # Build shape: 1s everywhere except -1 at the target axis
+    broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
+    return np.reshape(x, broadcast_shape)
+
+
+class _FuseBatchNormBase(orp.RewriteRuleClassBase, ABC):
     """Interface for BatchNormalization nodes fusion."""

     def __init__(
@@ -36,18 +43,9 @@ def __init__(
         self.op_type = op_type

     @abstractmethod
-    def get_filters_axis(self, attributes) -> int:
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
         """Return the axis along which BatchNorm scale should be broadcasted."""

-    def _reshape_for_broadcast(self, x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
-        # Convert axis to positive
-        if axis < 0:
-            axis += rank
-
-        # Build shape: 1s everywhere except -1 at the target axis
-        broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
-        return np.reshape(x, broadcast_shape)
-
     def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
         batchnorm_node = batchnorm_out.producer()
         # Get BatchNorm parameters
@@ -70,7 +68,7 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
         # Reshape scale factor so it is broadcastable
         axis = self.get_filters_axis(inbound_node.attributes)
         fused_weights = ir.tensor(
-            weights * self._reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
+            weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
         )

         # Update bias
@@ -92,32 +90,37 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
             attributes=inbound_node.attributes,
         )

-    def check(self, context, x, inbound_out, batchnorm_out) -> orp.MatchResult:
+    def check(
+        self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value
+    ) -> orp.MatchResult:
         del context  # Unused
         check_result = orp.MatchResult()

         inbound_node = inbound_out.producer()
         batchnorm_node = batchnorm_out.producer()

         # Check that inbound weights + (inbound bias) + batchnorm params are initializers
+        # and that they are not graph inputs
         initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
         if len(inbound_node.inputs) > 2:
             initializers.append(inbound_node.inputs[2])

         for initializer in initializers:
             if not initializer.is_initializer() or initializer.const_value is None:
-                return check_result.fail(f"{initializer.name} is not a constant initializer")
+                return check_result.fail(f"{initializer.name} is not a constant initializer.")
+            if initializer.is_graph_input():
+                return check_result.fail(f"{initializer.name} is a graph input.")

         return check_result


-class FuseBatchNormIntoConv(FuseBatchNormBase):
+class FuseBatchNormIntoConv(_FuseBatchNormBase):
     """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

     def __init__(self):
         super().__init__("Conv")

-    def get_filters_axis(self, attributes) -> int:
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
         return 0

     def pattern(self, op, x):
@@ -128,13 +131,13 @@ def pattern(self, op, x):
         )


-class FuseBatchNormIntoConvTranspose(FuseBatchNormBase):
+class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
     """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

     def __init__(self):
         super().__init__("ConvTranspose")

-    def get_filters_axis(self, attributes) -> int:
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
         return 1

     def pattern(self, op, x):
@@ -145,14 +148,16 @@ def pattern(self, op, x):
         )


-class FuseBatchNormIntoGemm(FuseBatchNormBase):
+class FuseBatchNormIntoGemm(_FuseBatchNormBase):
     """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

     def __init__(self):
         super().__init__("Gemm")

-    def get_filters_axis(self, attributes) -> int:
-        return 0 if attributes.get("transB") is not None and attributes["transB"].value else 1
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
+        return (
+            0 if attributes.get("transB") is not None and attributes["transB"].as_int() else 1
+        )

     def pattern(self, op, x):
         return op.BatchNormalization(
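
For context, the rewrite implemented in this file rests on a simple identity: with the inbound output Y = W * X + B followed by BatchNormalization, the normalization can be folded into W and B using the per-channel factor gamma / sqrt(var + eps). A small NumPy sketch of that algebra, with illustrative shapes and names that are not taken from the module itself:

```python
import numpy as np

rng = np.random.default_rng(0)
out_channels = 4

# Conv-style weights (out_channels, in_channels, kH, kW) and bias.
W = rng.standard_normal((out_channels, 3, 3, 3)).astype(np.float32)
B = rng.standard_normal(out_channels).astype(np.float32)

# Per-channel BatchNormalization parameters.
gamma = rng.standard_normal(out_channels).astype(np.float32)
beta = rng.standard_normal(out_channels).astype(np.float32)
mean = rng.standard_normal(out_channels).astype(np.float32)
var = np.abs(rng.standard_normal(out_channels)).astype(np.float32)
eps = 1e-5

# BatchNorm(Y) = gamma * (Y - mean) / sqrt(var + eps) + beta, so fold the scale into W and B.
scale = gamma / np.sqrt(var + eps)
fused_W = W * scale.reshape(-1, 1, 1, 1)  # scale broadcast along the filters axis (0 for Conv)
fused_B = (B - mean) * scale + beta

# Check the identity on a linear stand-in for the Conv (kernel summed over its spatial dims).
X = rng.standard_normal(3).astype(np.float32)
conv_out = W.sum(axis=(2, 3)) @ X + B
bn_out = gamma * (conv_out - mean) / np.sqrt(var + eps) + beta
fused_out = fused_W.sum(axis=(2, 3)) @ X + fused_B
np.testing.assert_allclose(bn_out, fused_out, rtol=1e-4, atol=1e-4)
```

The `scale.reshape(-1, 1, 1, 1)` step is what the new module-level `_reshape_for_broadcast` generalizes: the scale vector is placed on the axis returned by `get_filters_axis` (0 for Conv, 1 for ConvTranspose, and transB-dependent for Gemm).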

onnxscript/rewriter/fuse_batchnorm_test.py

Lines changed: 43 additions & 24 deletions

@@ -12,6 +12,22 @@


 class FuseBatchnormTest(unittest.TestCase):
+    def _create_batchnorm_params(self, size: int):
+        return [
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="gamma"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="beta"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="input_mean"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(size)).astype(np.float32), name="input_var"
+            ),
+        ]
+
     @parameterized.parameterized.expand(
         [
             ("bias_false", False),
@@ -45,14 +61,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
             onnx.numpy_helper.from_array(
                 np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
             ),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"),
-            onnx.numpy_helper.from_array(
-                np.random.randn(64).astype(np.float32), name="input_mean"
-            ),
-            onnx.numpy_helper.from_array(
-                np.abs(np.random.randn(64)).astype(np.float32), name="input_var"
-            ),
+            *self._create_batchnorm_params(size=64),
         ]
         if convtranspose_bias:
             initializers.append(
@@ -111,14 +120,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
             onnx.numpy_helper.from_array(
                 np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
             ),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"),
-            onnx.numpy_helper.from_array(
-                np.random.randn(64).astype(np.float32), name="input_mean"
-            ),
-            onnx.numpy_helper.from_array(
-                np.abs(np.random.randn(64)).astype(np.float32), name="input_var"
-            ),
+            *self._create_batchnorm_params(size=64),
         ]
         if conv_bias:
             initializers.append(
@@ -182,14 +184,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int):
         # Add initializers
         initializers = [
             onnx.numpy_helper.from_array(weights, name="W"),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"),
-            onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"),
-            onnx.numpy_helper.from_array(
-                np.random.randn(64).astype(np.float32), name="input_mean"
-            ),
-            onnx.numpy_helper.from_array(
-                np.abs(np.random.randn(64)).astype(np.float32), name="input_var"
-            ),
+            *self._create_batchnorm_params(size=64),
         ]
         if gemm_bias:
             initializers.append(
@@ -233,6 +228,30 @@ def test_fuse_batchnorm_non_initializers(self):
         # No changes were applied
         self.assertEqual(count, 0)

+    def test_fuse_batchnorm_graph_inputs(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W) => (float [N, ?, ?, ?] Y)
+            {
+                X1 = Conv(X, W)
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }
+        """)
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
+            ),
+            *self._create_batchnorm_params(size=64),
+        ]
+        model_proto.graph.initializer.extend(initializers)
+        onnx.checker.check_model(model_proto, True)
+
+        model = ir.serde.deserialize_model(model_proto)
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # No changes were applied as W is a graph input
+        self.assertEqual(count, 0)
+

 if __name__ == "__main__":
     unittest.main()
