diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/fuse_batchnorm.py
new file mode 100644
index 000000000..b8b5c143d
--- /dev/null
+++ b/onnxscript/rewriter/fuse_batchnorm.py
@@ -0,0 +1,188 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
+- BatchNormalization ∘ Conv -> Conv
+- BatchNormalization ∘ ConvTranspose -> ConvTranspose
+- BatchNormalization ∘ Gemm -> Gemm
+
+Approach:
+    Given an inbound operation output: Y = W * X + B
+    and a BatchNormalization output: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)
+
+    The fusion updates the inbound weights and bias as follows:
+        - W_fused = W * (gamma / std)
+        - B_fused = (B - μ) * (gamma / std) + β
+"""
+
+from abc import ABC, abstractmethod
+from typing import Mapping
+
+import numpy as np
+
+from onnxscript import ir
+from onnxscript.rewriter import pattern as orp
+
+
+def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
+    # Build shape: 1s everywhere except -1 at the target axis
+    broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
+    return np.reshape(x, broadcast_shape)
+
+
+class _FuseBatchNormBase(orp.RewriteRuleClassBase, ABC):
+    """Base class for BatchNormalization fusion rules."""
+
+    def __init__(
+        self,
+        op_type: str,
+        name: str | None = None,
+        remove_nodes: bool = True,
+        as_function: bool = False,
+    ) -> None:
+        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
+        self.op_type = op_type
+
+    @abstractmethod
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
+        """Return the axis along which the BatchNorm scale should be broadcast."""
+
+    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
+        batchnorm_node = batchnorm_out.producer()
+        # Get BatchNorm parameters
+        gamma, beta, input_mean, input_var = [
+            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
+        ]
+
+        # 1e-5 is the default value for epsilon according to
+        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
+        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
+        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()
+
+        # Compute the scale_factor to update the inbound weights and bias
+        scale_factor = gamma / np.sqrt(input_var + eps)
+
+        # Update inbound weights
+        inbound_node = inbound_out.producer()
+        weights = inbound_node.inputs[1].const_value.numpy()
+
+        # Reshape scale factor so it is broadcastable
+        axis = self.get_filters_axis(inbound_node.attributes)
+        fused_weights = ir.tensor(
+            weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
+        )
+
+        # Update bias
+        if len(inbound_node.inputs) > 2:
+            original_bias = inbound_node.inputs[2].const_value.numpy()
+            bias_name = inbound_node.inputs[2].name
+        else:
+            original_bias = np.zeros_like(input_mean)
+            bias_name = x.name + "_bias"
+        fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)
+
+        return op.op(
+            self.op_type,
+            inputs=[
+                x,
+                op.initializer(fused_weights, name=inbound_node.inputs[1].name),
+                op.initializer(fused_bias, name=bias_name),
+            ],
+            attributes=inbound_node.attributes,
+        )
+
+    def check(
+        self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value
+    ) -> orp.MatchResult:
+        del context  # Unused
+        check_result = orp.MatchResult()
+
+        inbound_node = inbound_out.producer()
+        batchnorm_node = batchnorm_out.producer()
+
+        # Check that inbound weights + (inbound bias) + batchnorm params are initializers
+        # and that they are not graph inputs
+        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
+        if len(inbound_node.inputs) > 2:
+            initializers.append(inbound_node.inputs[2])
+
+        for initializer in initializers:
+            if not initializer.is_initializer() or initializer.const_value is None:
+                return check_result.fail(f"{initializer.name} is not a constant initializer.")
+            if initializer.is_graph_input():
+                return check_result.fail(f"{initializer.name} is a graph input.")
+
+        return check_result
+
+
+class FuseBatchNormIntoConv(_FuseBatchNormBase):
+    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""
+
+    def __init__(self):
+        super().__init__("Conv")
+
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
+        return 0
+
+    def pattern(self, op, x):
+        return op.BatchNormalization(
+            op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
+            _allow_other_inputs=True,
+            _outputs=["batchnorm_out"],
+        )
+
+
+class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
+    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""
+
+    def __init__(self):
+        super().__init__("ConvTranspose")
+
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
+        return 1
+
+    def pattern(self, op, x):
+        return op.BatchNormalization(
+            op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
+            _allow_other_inputs=True,
+            _outputs=["batchnorm_out"],
+        )
+
+
+class FuseBatchNormIntoGemm(_FuseBatchNormBase):
+    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
+
+    def __init__(self):
+        super().__init__("Gemm")
+
+    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
+        return (
+            0 if attributes.get("transB") is not None and attributes["transB"].as_int() else 1
+        )
+
+    def pattern(self, op, x):
+        return op.BatchNormalization(
+            op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
+            _allow_other_inputs=True,
+            _outputs=["batchnorm_out"],
+        )
+
+
+fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
+fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
+fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()
+
+
+def fuse_batchnorm_rule_set() -> orp.RewriteRuleSet:
+    """Returns a set of rewrite rules that fuse BatchNormalization nodes
+    into preceding nodes such as Conv, ConvTranspose, and Gemm.
+
+    Returns:
+        RewriteRuleSet
+    """
+    return orp.RewriteRuleSet(
+        [
+            fuse_batchnorm_into_conv_rule,
+            fuse_batchnorm_into_convtranspose_rule,
+            fuse_batchnorm_into_gemm_rule,
+        ]
+    )
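The algebra in the module docstring of fuse_batchnorm.py above can be sanity-checked numerically. The following is a minimal standalone NumPy sketch (not part of the patch; the per-channel variable names are illustrative), assuming the inbound op reduces to a per-channel affine map y = w * x + b and BatchNormalization runs in inference mode:

    import numpy as np

    rng = np.random.default_rng(0)
    c = 4  # number of channels
    x, w, b, gamma, beta, mean = (rng.normal(size=c).astype(np.float32) for _ in range(6))
    var = np.abs(rng.normal(size=c)).astype(np.float32)
    eps = 1e-5

    # Reference: inbound affine op followed by BatchNormalization (inference mode).
    std = np.sqrt(var + eps)
    y_bn = gamma * ((w * x + b) - mean) / std + beta

    # Fused parameters, matching the formulas used in _FuseBatchNormBase.rewrite.
    scale_factor = gamma / std
    w_fused = w * scale_factor
    b_fused = (b - mean) * scale_factor + beta

    # The fused affine op reproduces the op + BatchNorm output.
    np.testing.assert_allclose(w_fused * x + b_fused, y_bn, rtol=1e-5, atol=1e-6)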
diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/fuse_batchnorm_test.py
new file mode 100644
index 000000000..20d272abd
--- /dev/null
+++ b/onnxscript/rewriter/fuse_batchnorm_test.py
@@ -0,0 +1,257 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import numpy as np
+import onnx.checker
+import onnx.parser
+import parameterized
+
+from onnxscript import ir
+from onnxscript.rewriter import fuse_batchnorm, testing
+
+
+class FuseBatchnormTest(unittest.TestCase):
+    def _create_batchnorm_params(self, size: int):
+        return [
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="gamma"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="beta"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(size).astype(np.float32), name="input_mean"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(size)).astype(np.float32), name="input_var"
+            ),
+        ]
+
+    @parameterized.parameterized.expand(
+        [
+            ("bias_false", False),
+            ("bias_true", True),
+        ]
+    )
+    def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
+        convtranspose_inputs = "X, W"
+        parameters = (
+            "float[32, 64, 3, 3] W, "
+            "float[64] gamma, "
+            "float[64] beta, "
+            "float[64] input_mean, "
+            "float[64] input_var"
+        )
+        if convtranspose_bias:
+            parameters += ", float[64] B"
+            convtranspose_inputs += ", B"
+
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
+            <{parameters}>
+            {{
+                X1 = ConvTranspose({convtranspose_inputs})
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }}
+        """)
+        # Add initializers
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
+            ),
+            *self._create_batchnorm_params(size=64),
+        ]
+        if convtranspose_bias:
+            initializers.append(
+                onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B")
+            )
+        model_proto.graph.initializer.extend(initializers)
+
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+
+        # Apply rule
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # Check that BatchNorm was fused
+        self.assertEqual(count, 1)
+        self.assertEqual(len(model.graph), 1)
+
+        # Check inference
+        testing.assert_numerically_equal(
+            model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),)
+        )
+
+        output_model_proto = ir.serde.serialize_model(model)
+        onnx.checker.check_model(output_model_proto, True)
+
+    @parameterized.parameterized.expand(
+        [
+            ("bias_false", False),
+            ("bias_true", True),
+        ]
+    )
+    def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
+        conv_inputs = "X, W"
+        parameters = (
+            "float[64, 32, 3, 3] W, "
+            "float[64] gamma, "
+            "float[64] beta, "
+            "float[64] input_mean, "
+            "float[64] input_var"
+        )
+        if conv_bias:
+            parameters += ", float[64] B"
+            conv_inputs += ", B"
+
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
+            <{parameters}>
+            {{
+                X1 = Conv({conv_inputs})
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }}
+        """)
+        # Add initializers
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
+            ),
+            *self._create_batchnorm_params(size=64),
+        ]
+        if conv_bias:
+            initializers.append(
+                onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B")
+            )
+        model_proto.graph.initializer.extend(initializers)
+
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+
+        # Apply rule
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # Check that BatchNorm was fused
+        self.assertEqual(count, 1)
+        self.assertEqual(len(model.graph), 1)
+
+        # Check inference
+        testing.assert_numerically_equal(
+            model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),)
+        )
+
+        output_model_proto = ir.serde.serialize_model(model)
+        onnx.checker.check_model(output_model_proto, True)
+
+    @parameterized.parameterized.expand(
+        [
+            ("bias_false_transB_0", False, 0),
+            ("bias_true_transB_0", True, 0),
+            ("bias_false_transB_1", False, 1),
+            ("bias_true_transB_1", True, 1),
+        ]
+    )
+    def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int):
+        gemm_inputs = "X, W"
+        parameters = (
+            f"float{'[64, 32]' if transB else '[32, 64]'} W, "
+            "float[64] gamma, "
+            "float[64] beta, "
+            "float[64] input_mean, "
+            "float[64] input_var"
+        )
+
+        if gemm_bias:
+            parameters += ", float[64] B"
+            gemm_inputs += ", B"
+
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32] X) => (float [N, ?] Y)
+            <{parameters}>
+            {{
+                X1 = Gemm<transB={transB}>({gemm_inputs})
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }}
+        """)
+        weights = np.random.randn(32, 64).astype(np.float32)
+        if transB:
+            weights = weights.T
+
+        # Add initializers
+        initializers = [
+            onnx.numpy_helper.from_array(weights, name="W"),
+            *self._create_batchnorm_params(size=64),
+        ]
+        if gemm_bias:
+            initializers.append(
+                onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B")
+            )
+        model_proto.graph.initializer.extend(initializers)
+
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+
+        # Apply rule
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # Check that BatchNorm was fused
+        self.assertEqual(count, 1)
+        self.assertEqual(len(model.graph), 1)
+
+        # Check inference
+        testing.assert_numerically_equal(
+            model_proto, model, (np.random.rand(1, 32).astype(np.float32),)
+        )
+
+        output_model_proto = ir.serde.serialize_model(model)
+        onnx.checker.check_model(output_model_proto, True)
+
+    def test_fuse_batchnorm_non_initializers(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W, float[64] B,
+                        float[64] gamma, float[64] beta, float[64] input_var,
+                        float[64] input_mean) => (float [N, ?, ?, ?] Y)
+            {
+                X1 = Conv(X, W, B)
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }
+        """)
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # No changes were applied
+        self.assertEqual(count, 0)
+
+    def test_fuse_batchnorm_graph_inputs(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W) => (float [N, ?, ?, ?] Y)
+            {
+                X1 = Conv(X, W)
+                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
+            }
+        """)
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
+            ),
+            *self._create_batchnorm_params(size=64),
+        ]
+        model_proto.graph.initializer.extend(initializers)
+        onnx.checker.check_model(model_proto, True)
+
+        model = ir.serde.deserialize_model(model_proto)
+        count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
+
+        # No changes were applied as W is a graph input
+        self.assertEqual(count, 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
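For reference, a minimal end-to-end usage sketch of the new rule set (not part of the patch; the .onnx paths are placeholders), mirroring the deserialize/apply/serialize calls exercised in the tests above:

    import onnx

    from onnxscript import ir
    from onnxscript.rewriter import fuse_batchnorm

    model_proto = onnx.load("model.onnx")  # placeholder input path
    model = ir.serde.deserialize_model(model_proto)

    # Apply the fusion rules; apply_to_model returns the number of rewrites performed.
    count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model)
    print(f"Fused {count} BatchNormalization node(s)")

    onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")  # placeholder output path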