diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index e39d8d605f4..496bd48d8fd 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -438,6 +438,7 @@ def _tosa_pipeline(
         ConvertSplitToSlicePass(),
         QuantizeClampArgumentsPass(),
         RemoveGetItemPass(),
+        FuseBatchNorm2dPass(exported_program),
         DecomposeBatchNormNoStatsPass(),
         DecomposeLogitPass(),
         DecomposeMaskedFillPass(),
@@ -501,7 +502,6 @@ def _tosa_pipeline(
         RewriteBoolBitwiseToLogicalPass(),
         DecomposeRemainderPass(),
         DecomposeDivTensorModePass(),
-        FuseBatchNorm2dPass(exported_program),
         ConvertMmToBmmPass(),
         DecomposeGluPass(),
         DecomposeDivPass(),
diff --git a/backends/arm/_passes/fuse_batch_norm2d_pass.py b/backends/arm/_passes/fuse_batch_norm2d_pass.py
index 13889adb5fe..a13ed9da922 100644
--- a/backends/arm/_passes/fuse_batch_norm2d_pass.py
+++ b/backends/arm/_passes/fuse_batch_norm2d_pass.py
@@ -56,6 +56,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
                 != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
             ):
                 continue
+            if get_first_fake_tensor(node).dtype == torch.bfloat16:
+                # Don't fuse if the data type is bfloat16, as the fused weights may
+                # not be accurate enough and can cause a significant accuracy drop.
+                continue

             # Get data from batchnorm
             input_node = node.all_input_nodes[0]
@@ -153,8 +157,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
             if not (
                 (input_bias_node is None)
                 or (
-                    isinstance(input_weight_node, Node)
-                    and input_weight_node.op == "placeholder"
+                    isinstance(input_bias_node, Node)
+                    and input_bias_node.op == "placeholder"
                 )
             ):
                 raise RuntimeError(
diff --git a/backends/arm/test/misc/test_transpose_counts.py b/backends/arm/test/misc/test_transpose_counts.py
index 068dd28cabc..899a210bba5 100644
--- a/backends/arm/test/misc/test_transpose_counts.py
+++ b/backends/arm/test/misc/test_transpose_counts.py
@@ -445,16 +445,16 @@ def forward(self, x):
         Model6GruLinear(), (torch.randn(2, 16, 8),), 2
     ),
     "model_7_dwconv_batchnorm_linear": TransposeCountCase(
-        Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 3
+        Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 1
     ),
     "model_8_conv_batchnorm_maxpool_residual": TransposeCountCase(
-        Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
+        Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
     ),
     "model_9_dilated_conv_batchnorm_avgpool_residual": TransposeCountCase(
-        Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
+        Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
     ),
     "model_10_dwconv_batchnorm_linear_cat": TransposeCountCase(
-        Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 3
+        Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 1
     ),
 }

diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py
index 7bc11efe095..d631b1e1a6d 100644
--- a/backends/arm/test/ops/test_batch_norm.py
+++ b/backends/arm/test/ops/test_batch_norm.py
@@ -200,6 +200,20 @@ def test_native_batch_norm_legit_no_training_tosa_FP_conv(test_data: Tuple):
     pipeline.run()


+@common.parametrize("test_data", test_data_suite)
+def test_native_batch_norm_legit_no_training_tosa_FP_conv_fuses_before_decompose(
+    test_data: Tuple,
+):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineFP[input_t1](
+        BatchNorm2dConv(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dConv.aten_ops,
+    )
+    pipeline.count_tosa_ops({"CONV2D": 1, "RSQRT": 0, "SUB": 0})
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_suite)
 def test_native_batch_norm_legit_no_training_tosa_INT_conv(test_data: Tuple):
     test_data, model_params = test_data()
diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
index bef98e478a1..df965fd7620 100644
--- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py
+++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -54,6 +54,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class MergeOneOfTwoBNBf16(MergeOneOfTwoBN):
+    ops_before_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
+    ops_after_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
+
+    def __init__(self, affine: bool):
+        super().__init__(affine)
+        self.to(torch.bfloat16)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 256, 256, dtype=torch.bfloat16),)
+
+
 class MergeTwosOfTwoBN(torch.nn.Module):
     ops_before_pass: ClassVar[Dict[str, int]] = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
@@ -163,3 +175,18 @@ def test_fuse_batch_norm2d_tosa_FP(module: ModuleWithBatchNormAttrs) -> None:
         passes_with_exported_program=[FuseBatchNorm2dPass],
     )
     pipeline.run()
+
+
+def test_fuse_batch_norm2d_tosa_FP_bf16_skips_fusion() -> None:
+    module = cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBNBf16(True))
+    nn_module = cast(torch.nn.Module, module)
+    pipeline = PassPipeline[input_t](
+        nn_module,
+        module.get_inputs(),
+        quantize=False,
+        ops_before_pass=module.ops_before_pass,
+        ops_after_pass=module.ops_after_pass,
+        passes_with_exported_program=[FuseBatchNorm2dPass],
+        tosa_extensions=["bf16"],
+    )
+    pipeline.run()
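
For reference, the rewrite FuseBatchNorm2dPass applies is the standard conv-BN folding: running it ahead of DecomposeBatchNormNoStatsPass lets the batch norm be absorbed into the preceding convolution before it is decomposed, which is what the new test asserts with {"CONV2D": 1, "RSQRT": 0, "SUB": 0} and why the expected transpose counts drop. A minimal sketch of the folding arithmetic, assuming the usual inference-time batch-norm parameters; the function and variable names below are illustrative, not the pass's actual code:

    import torch

    def fold_bn_into_conv(conv_w, conv_b, bn_mean, bn_var, bn_w, bn_b, eps=1e-5):
        # Illustrative sketch, not the pass's implementation. Per output channel,
        # BN(x) = bn_w * (x - bn_mean) / sqrt(bn_var + eps) + bn_b, so it folds
        # into the conv as w' = w * s and b' = (b - bn_mean) * s + bn_b,
        # where s = bn_w / sqrt(bn_var + eps).
        scale = bn_w / torch.sqrt(bn_var + eps)
        fused_w = conv_w * scale.reshape(-1, 1, 1, 1)  # broadcast over out-channels
        if conv_b is None:
            conv_b = torch.zeros_like(bn_mean)
        fused_b = (conv_b - bn_mean) * scale + bn_b
        return fused_w, fused_b

Computing that scale in bfloat16 (about 8 bits of mantissa) perturbs the fused weights enough to cost end-to-end accuracy, which is what the new dtype guard and the bf16 test above protect against.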