diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index d3c0ae0a1b3..f9cf838a526 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -23,6 +23,7 @@
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
+from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 2cefd3bdaca..f0a86b1ce84 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -26,6 +26,7 @@
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
     DecomposeAvgPool2d,
+    DecomposeBatchNormNoStatsPass,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
@@ -164,6 +165,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeGroupNormPass())
         self.add_pass(DecomposeLayerNormPass())
+        self.add_pass(DecomposeBatchNormNoStatsPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(
             DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py
new file mode 100644
index 00000000000..5fdb8db2d7c
--- /dev/null
+++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py
@@ -0,0 +1,219 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import operator
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class DecomposeBatchNormNoStatsPass(ArmPass):
+    """
+    Decompose BatchNorm2d with track_running_stats=False (seen in the graph as e.g.
+    aten._native_batch_norm_legit.no_stats, aten._native_batch_norm_legit_no_training
+    or aten.batch_norm) into a sequence of elementwise operations:
+
+        # let input = x, rm = running_mean, rv = running_var, eps: float
+        rm_view = view(rm, weights_shape)
+        rv_view = view(rv, weights_shape)
+        centered = sub(x, rm_view)
+        eps_full = full(eps_shape, eps)
+        var_eps = add(rv_view, eps_full)
+        inv_sqrt = rsqrt(var_eps)
+        normed = mul(centered, inv_sqrt)
+        w_view = view(weight, weights_shape)
+        weighted = mul(normed, w_view)
+        b_view = view(bias, weights_shape)
+        biased = add(weighted, b_view)
+
+    Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
+        bn_ops = (
+            exir_ops.edge.aten._native_batch_norm_legit.no_stats,
+            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
+            torch.ops.aten.batch_norm.default,
+            torch.ops.aten.native_batch_norm.default,
+        )
+
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in bn_ops:
+                continue
+
+            if node.target in (
+                torch.ops.aten.batch_norm.default,
+                torch.ops.aten.native_batch_norm.default,
+            ):
+                # signature: (input, weight, bias, mean, var, training, momentum, eps[, cudnn_enabled])
+                # positional arg 5 is the training flag
+                training = node.kwargs.get("training", False)
+                if len(node.args) > 5:
+                    training = node.args[5]
+                if training:
+                    # skip training-mode batch norm
+                    continue
+
+            # Extract args
+            args = node.args
+            meta = node.meta
+
+            # Default eps
+            eps: float = torch.finfo().eps
+            # weight and bias may be None
+            x = args[0]
+            weight = args[1] if len(args) > 1 else None
+            bias = args[2] if len(args) > 2 else None
+            running_mean = args[3]
+            running_var = args[4]
+            # eps sits at index 7 for aten.batch_norm/native_batch_norm (which also
+            # carry the training flag at index 5) and at index 6 for the
+            # _native_batch_norm_legit_no_training variant
+            if node.target in (
+                torch.ops.aten.batch_norm.default,
+                torch.ops.aten.native_batch_norm.default,
+            ):
+                if len(args) > 7:
+                    eps = args[7]
+            elif len(args) > 6:
+                eps = args[6]
+
+            # Determine shapes
+            val = meta.get("val")
+            ref_tensor = val[0] if isinstance(val, tuple) else val
+            shape = tuple(ref_tensor.size())
+            dtype = ref_tensor.dtype
+            rank = len(shape)
+
+            # channel dimension is 1 for BatchNorm2d
+            channel_axis = 1
+            weights_shape = [1] * rank
+            weights_shape[channel_axis] = shape[channel_axis]
+            num_features = shape[channel_axis]
+
+            # Ops to use
+            sub_op = exir_ops.edge.aten.sub.Tensor
+            view_op = exir_ops.edge.aten.view_copy.default
+            full_op = exir_ops.edge.aten.full.default
+            add_op = exir_ops.edge.aten.add.Tensor
+            rsqrt_op = exir_ops.edge.aten.rsqrt.default
+            mul_op = exir_ops.edge.aten.mul.Tensor
+
+            # Begin decomposition
+            with graph_module.graph.inserting_before(node):
+                # reshape running stats
+                rm_view = create_node(
+                    graph_module.graph,
+                    view_op,
+                    args=(running_mean, weights_shape),
+                    from_node=node,
+                )
+                rv_view = create_node(
+                    graph_module.graph,
+                    view_op,
+                    args=(running_var, weights_shape),
+                    from_node=node,
+                )
+                # center input
+                centered = create_node(
+                    graph_module.graph,
+                    sub_op,
+                    args=(x, rm_view),
+                    from_node=node,
+                )
+                # epsilon tensor
+                eps_shape = [1] * rank
+                eps_full = create_node(
+                    graph_module.graph,
+                    full_op,
+                    args=(eps_shape, eps),
+                    kwargs={"dtype": dtype},
+                    from_node=node,
+                )
+                # var + eps
+                var_eps = create_node(
+                    graph_module.graph,
+                    add_op,
+                    args=(rv_view, eps_full),
+                    from_node=node,
+                )
+                # inverse sqrt
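+                # rsqrt(var + eps) = 1 / sqrt(var + eps); multiplying the centered
+                # input by this yields (x - mean) / sqrt(var + eps)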
+                inv_sqrt = create_node(
+                    graph_module.graph,
+                    rsqrt_op,
+                    args=(var_eps,),
+                    from_node=node,
+                )
+                # normalized
+                normed = create_node(
+                    graph_module.graph,
+                    mul_op,
+                    args=(centered, inv_sqrt),
+                    from_node=node,
+                )
+
+                # weight
+                if weight is None:
+                    one = create_node(
+                        graph_module.graph,
+                        full_op,
+                        args=([num_features], 1),
+                        kwargs={"dtype": dtype},
+                        from_node=node,
+                    )
+                    w_view = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(one, weights_shape),
+                        from_node=node,
+                    )
+                else:
+                    w_view = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(weight, weights_shape),
+                        from_node=node,
+                    )
+                weighted = create_node(
+                    graph_module.graph,
+                    mul_op,
+                    args=(normed, w_view),
+                    from_node=node,
+                )
+
+                # bias
+                if bias is None:
+                    zero = create_node(
+                        graph_module.graph,
+                        full_op,
+                        args=([num_features], 0),
+                        kwargs={"dtype": dtype},
+                        from_node=node,
+                    )
+                    b_view = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(zero, weights_shape),
+                        from_node=node,
+                    )
+                else:
+                    b_view = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(bias, weights_shape),
+                        from_node=node,
+                    )
+                final_out = create_node(
+                    graph_module.graph,
+                    add_op,
+                    args=(weighted, b_view),
+                    from_node=node,
+                )
+
+            users = [u for u in node.users if u is not node]
+            node.replace_all_uses_with(final_out)
+            for u in users:
+                if u.target == operator.getitem:
+                    u.replace_all_uses_with(final_out)
+            graph_module.graph.erase_node(node)
+            graph_module.graph.eliminate_dead_code()
+
+        graph_module.recompile()
+        new_gm = super().call(graph_module).graph_module
+        return PassResult(new_gm, True)
diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py
index 7f98a48b203..eb0d4306e6e 100644
--- a/backends/arm/test/ops/test_batch_norm.py
+++ b/backends/arm/test/ops/test_batch_norm.py
@@ -224,6 +224,8 @@ class BatchNorm2dNoStats(torch.nn.Module):
     Decomposes into _native_batch_norm_legit.no_stats
     """
 
+    aten_ops = ["torch.ops.aten.batch_norm.default"]
+
     def __init__(
         self,
         num_features: int,
@@ -250,29 +252,60 @@ def forward(self, x):
         return self.batch_norm_2d(x)
 
 
-@pytest.mark.skip(
-    reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
-)
-def test_native_batch_norm_legit_no_stats_tosa_MI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+def test_native_batch_norm_legit_no_stats_tosa_MI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineMI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_tosa_BI():
-    pass
+def test_native_batch_norm_legit_no_stats_tosa_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        qtol=1,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_u55_BI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+def test_native_batch_norm_legit_no_stats_u55_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = EthosU55PipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        run_on_fvp=True,
+        qtol=1,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_u85_BI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+def test_native_batch_norm_legit_no_stats_u85_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = EthosU85PipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        run_on_fvp=False,
+        qtol=1,
+    )
+    pipeline.run()
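For reference, a minimal standalone sketch (not part of the patch) of the arithmetic the pass emits, checked against eval-mode torch.nn.functional.batch_norm. The shapes, eps value and explicit running statistics below are illustrative assumptions, not values taken from the test suite.

import torch

# Illustrative inputs (assumed shapes/values).
x = torch.randn(2, 3, 8, 8)
weight = torch.randn(3)
bias = torch.randn(3)
running_mean = torch.randn(3)
running_var = torch.rand(3) + 0.5
eps = 1e-5

# Reference: eval-mode batch norm.
reference = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

# Same elementwise sequence as the decomposed graph: view stats/params to
# (1, C, 1, 1), then sub -> add eps -> rsqrt -> mul -> mul weight -> add bias.
weights_shape = (1, 3, 1, 1)
centered = x - running_mean.view(weights_shape)
inv_sqrt = torch.rsqrt(running_var.view(weights_shape) + eps)
decomposed = centered * inv_sqrt * weight.view(weights_shape) + bias.view(weights_shape)

torch.testing.assert_close(decomposed, reference)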