1
1
# Copyright (c) Microsoft Corporation.
2
2
# Licensed under the MIT License.
3
3
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
4
- - BatchNormalization + Conv -> Conv
5
- - BatchNormalization + ConvTranpose -> ConvTranpose
6
- - BatchNormalization + Gemm -> Gemm
4
+ - BatchNormalization ∘ Conv -> Conv
5
+ - BatchNormalization ∘ ConvTranspose -> ConvTranspose
6
+ - BatchNormalization ∘ Gemm -> Gemm
7
7
8
8
Approach:
9
9
Given an inbound operation output: Y = W * X + B
15
15
"""
16
16
17
17
from abc import ABC , abstractmethod
18
+ from typing import Mapping
18
19
19
20
import numpy as np
20
21
21
22
from onnxscript import ir
22
23
from onnxscript .rewriter import pattern as orp
23
24
24
25
25
- class FuseBatchNormBase (orp .RewriteRuleClassBase , ABC ):
26
+ def _reshape_for_broadcast (x : np .ndarray , rank : int , axis : int = 1 ) -> np .ndarray :
27
+ # Build shape: 1s everywhere except -1 at the target axis
28
+ broadcast_shape = [1 if axis != i else - 1 for i in range (rank )]
29
+ return np .reshape (x , broadcast_shape )
30
+
31
+
32
+ class _FuseBatchNormBase (orp .RewriteRuleClassBase , ABC ):
26
33
"""Interface for BatchNormalization nodes fusion."""
27
34
28
35
def __init__ (
@@ -36,18 +43,9 @@ def __init__(
36
43
self .op_type = op_type
37
44
38
45
@abstractmethod
39
- def get_filters_axis (self , attributes ) -> int :
46
+ def get_filters_axis (self , attributes : Mapping [ str , ir . Attr ] ) -> int :
40
47
"""Return the axis along which BatchNorm scale should be broadcasted."""
41
48
42
- def _reshape_for_broadcast (self , x : np .ndarray , rank : int , axis : int = 1 ) -> np .ndarray :
43
- # Convert axis to positive
44
- if axis < 0 :
45
- axis += rank
46
-
47
- # Build shape: 1s everywhere except -1 at the target axis
48
- broadcast_shape = [1 if axis != i else - 1 for i in range (rank )]
49
- return np .reshape (x , broadcast_shape )
50
-
51
49
def rewrite (self , op , x : ir .Value , inbound_out : ir .Value , batchnorm_out : ir .Value ):
52
50
batchnorm_node = batchnorm_out .producer ()
53
51
# Get BatchNorm parameters
@@ -70,7 +68,7 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
70
68
# Reshape scale factor so it is broadcastable
71
69
axis = self .get_filters_axis (inbound_node .attributes )
72
70
fused_weights = ir .tensor (
73
- weights * self . _reshape_for_broadcast (scale_factor , weights .ndim , axis = axis )
71
+ weights * _reshape_for_broadcast (scale_factor , weights .ndim , axis = axis )
74
72
)
75
73
76
74
# Update bias
@@ -92,32 +90,37 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
92
90
attributes = inbound_node .attributes ,
93
91
)
94
92
95
- def check (self , context , x , inbound_out , batchnorm_out ) -> orp .MatchResult :
93
+ def check (
94
+ self , context , x , inbound_out : ir .Value , batchnorm_out : ir .Value
95
+ ) -> orp .MatchResult :
96
96
del context # Unused
97
97
check_result = orp .MatchResult ()
98
98
99
99
inbound_node = inbound_out .producer ()
100
100
batchnorm_node = batchnorm_out .producer ()
101
101
102
102
# Check that inbound weights + (inbound bias) + batchnorm params are initializers
103
+ # and that they are not graph inputs
103
104
initializers = [inbound_node .inputs [1 ], * batchnorm_node .inputs [1 :]]
104
105
if len (inbound_node .inputs ) > 2 :
105
106
initializers .append (inbound_node .inputs [2 ])
106
107
107
108
for initializer in initializers :
108
109
if not initializer .is_initializer () or initializer .const_value is None :
109
- return check_result .fail (f"{ initializer .name } is not a constant initializer" )
110
+ return check_result .fail (f"{ initializer .name } is not a constant initializer." )
111
+ if initializer .is_graph_input ():
112
+ return check_result .fail (f"{ initializer .name } is a graph input." )
110
113
111
114
return check_result
112
115
113
116
114
- class FuseBatchNormIntoConv (FuseBatchNormBase ):
117
+ class FuseBatchNormIntoConv (_FuseBatchNormBase ):
115
118
"""Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""
116
119
117
120
def __init__ (self ):
118
121
super ().__init__ ("Conv" )
119
122
120
- def get_filters_axis (self , attributes ) -> int :
123
+ def get_filters_axis (self , attributes : Mapping [ str , ir . Attr ] ) -> int :
121
124
return 0
122
125
123
126
def pattern (self , op , x ):
@@ -128,13 +131,13 @@ def pattern(self, op, x):
128
131
)
129
132
130
133
131
- class FuseBatchNormIntoConvTranspose (FuseBatchNormBase ):
134
+ class FuseBatchNormIntoConvTranspose (_FuseBatchNormBase ):
132
135
"""Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""
133
136
134
137
def __init__ (self ):
135
138
super ().__init__ ("ConvTranspose" )
136
139
137
- def get_filters_axis (self , attributes ) -> int :
140
+ def get_filters_axis (self , attributes : Mapping [ str , ir . Attr ] ) -> int :
138
141
return 1
139
142
140
143
def pattern (self , op , x ):
@@ -145,14 +148,16 @@ def pattern(self, op, x):
145
148
)
146
149
147
150
148
- class FuseBatchNormIntoGemm (FuseBatchNormBase ):
151
+ class FuseBatchNormIntoGemm (_FuseBatchNormBase ):
149
152
"""Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
150
153
151
154
def __init__ (self ):
152
155
super ().__init__ ("Gemm" )
153
156
154
- def get_filters_axis (self , attributes ) -> int :
155
- return 0 if attributes .get ("transB" ) is not None and attributes ["transB" ].value else 1
157
+ def get_filters_axis (self , attributes : Mapping [str , ir .Attr ]) -> int :
158
+ return (
159
+ 0 if attributes .get ("transB" ) is not None and attributes ["transB" ].as_int () else 1
160
+ )
156
161
157
162
def pattern (self , op , x ):
158
163
return op .BatchNormalization (
0 commit comments