From 6c70e748871e6435e774cc48a6e559dfa69b3197 Mon Sep 17 00:00:00 2001
From: Oscar Andersson
Date: Mon, 30 Dec 2024 15:00:59 +0100
Subject: [PATCH] Move batchnorm decomposition to pass

The decomposition logic of batchnorm is better suited for a pass than a
node visitor.

Signed-off-by: Oscar Andersson
Change-Id: I859912dc6ed437aa96a3bcdd819531dc8e02230e
---
 backends/arm/_passes/arm_pass_manager.py     |   5 +
 .../arm/_passes/decompose_batchnorm_pass.py  | 138 ++++++++++++
 .../arm/_passes/decompose_layernorm_pass.py  |   5 +-
 backends/arm/operators/__init__.py           |   1 -
 backends/arm/operators/op_batch_norm.py      | 211 ------------------
 backends/arm/tosa_utils.py                   |  40 ----
 6 files changed, 146 insertions(+), 254 deletions(-)
 create mode 100644 backends/arm/_passes/decompose_batchnorm_pass.py
 delete mode 100644 backends/arm/operators/op_batch_norm.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7e0960968ca..686bfbcd8af 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -24,6 +24,9 @@
 from executorch.backends.arm._passes.convert_squeezes_to_view import (  # type: ignore[import-not-found]
     ConvertSqueezesToViewPass,
 )
+from executorch.backends.arm._passes.decompose_batchnorm_pass import (
+    DecomposeBatchNormPass,
+)
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
 from executorch.backends.arm._passes.decompose_layernorm_pass import (
     DecomposeLayerNormPass,
@@ -87,6 +90,7 @@ def _transform(self, graph_module: GraphModule):
     def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
+        self.add_pass(DecomposeBatchNormPass())
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
@@ -121,6 +125,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeBatchNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
diff --git a/backends/arm/_passes/decompose_batchnorm_pass.py b/backends/arm/_passes/decompose_batchnorm_pass.py
new file mode 100644
index 00000000000..d33e8e3b51a
--- /dev/null
+++ b/backends/arm/_passes/decompose_batchnorm_pass.py
@@ -0,0 +1,138 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import operator
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+edge_bn_ops = (exir_ops.edge.aten._native_batch_norm_legit_no_training.default,)
+
+
+def get_bn_decomposition(op) -> tuple:
+    """
+    Returns decomposition of batchnorm in edge ops.
+    Raises RuntimeError if op is not batchnorm edge op.
+    """
+    if op in edge_bn_ops:
+        return (
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.rsqrt.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.view_copy.default,
+            exir_ops.edge.aten.full.default,
+        )
+    else:
+        raise RuntimeError(f"Can't get decomposition for {op}")
+
+
+class DecomposeBatchNormPass(ExportPass):
+    """
+    Decompose BatchNorm to:
+    %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
+    e.g.
+    %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
+    ->
+    %op1 = sub(%activations, %running_mean)
+    %op2 = add(%running_var, %epsilon_const)
+    %op3 = rsqrt(%op2)
+    %op4 = mul(%op1, %op3)
+    %op5 = mul(%op4, %weights)
+    %output = add(%op5, %bias)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in edge_bn_ops:
+                continue
+
+            args = node.args
+            meta = node.meta
+            (
+                activations,
+                weights,
+                bias,
+                running_mean,
+                running_var,
+                momentum,
+                epsilon,
+            ) = args
+            if momentum != 0.1:
+                raise RuntimeError(f"Expected momentum=0.1 but got {momentum}")
+
+            shape = meta["val"][0].size()
+            dtype = meta["val"][0].dtype
+            rank = len(shape)
+            running_mean_shape = running_mean.meta["val"].shape
+            running_mean_reshaped_shape = [1] * rank
+            running_mean_reshaped_shape[1] = running_mean_shape[0]
+            epsilon_reshaped_shape = [1] * rank
+
+            sub, add, rsqrt, mul, view, full = get_bn_decomposition(node.target)
+            with graph_module.graph.inserting_before(node):
+                mean_reshaped = create_node(
+                    graph_module.graph,
+                    view,
+                    args=(running_mean, running_mean_reshaped_shape),
+                )
+                op1 = create_node(
+                    graph_module.graph, sub, args=(activations, mean_reshaped)
+                )
+                full = create_node(
+                    graph_module.graph,
+                    full,
+                    args=(epsilon_reshaped_shape, epsilon),
+                    kwargs={"dtype": dtype},
+                )
+                var_reshaped = create_node(
+                    graph_module.graph,
+                    view,
+                    args=(running_var, running_mean_reshaped_shape),
+                )
+                op2 = create_node(graph_module.graph, add, args=(var_reshaped, full))
+                op3 = create_node(graph_module.graph, rsqrt, args=(op2,))
+                op4 = create_node(graph_module.graph, mul, args=(op1, op3))
+                if weights is not None:
+                    weights_reshaped = create_node(
+                        graph_module.graph,
+                        view,
+                        args=(weights, running_mean_reshaped_shape),
+                    )
+                    op5 = create_node(
+                        graph_module.graph, mul, args=(op4, weights_reshaped)
+                    )
+                else:
+                    op5 = op4
+                output = op5
+                if bias is not None:
+                    bias_reshaped_shape = running_mean_reshaped_shape
+                    bias_reshaped = create_node(
+                        graph_module.graph, view, args=(bias, bias_reshaped_shape)
+                    )
+                    output = create_node(
+                        graph_module.graph, add, args=(op5, bias_reshaped)
+                    )
+
+                users = [user for user in node.users if node != user]
+                node.replace_all_uses_with(output)
+                for user in users:
+                    if user.target == operator.getitem:
+                        user.replace_all_uses_with(output)
+                graph_module.graph.erase_node(node)
+                graph_module.graph.eliminate_dead_code()
+                modified = True
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py
index 3739337101e..cc4a81caae0 100644
--- a/backends/arm/_passes/decompose_layernorm_pass.py
+++ b/backends/arm/_passes/decompose_layernorm_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -82,9 +82,10 @@ def call(self, graph_module: torch.fx.GraphModule):
             n_dims = len(normalized_shape)
             if isinstance(meta["val"], tuple):
                 shape = meta["val"][0].size()
+                dtype = meta["val"][0].dtype
             else:
                 shape = meta["val"].size()
-            dtype = meta["val"][0].dtype
+                dtype = meta["val"].dtype
             rank = len(shape)
             dims = list(range(-1, -1 * (n_dims + 1), -1))
             dims = [dim % rank for dim in dims]
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index aece200047d..f57ba092bc4 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -9,7 +9,6 @@
     node_visitor,
     op_add,
     op_avg_pool2d,
-    op_batch_norm,
     op_bmm,
     op_cat,
     op_clamp,
diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py
deleted file mode 100644
index dc423d0b4a2..00000000000
--- a/backends/arm/operators/op_batch_norm.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright 2023-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-from typing import List
-
-import serializer.tosa_serializer as ts  # type: ignore
-import torch
-from executorch.backends.arm.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
-from serializer.tosa_serializer import TosaOp
-
-
-@register_node_visitor
-class BatchNormVisitor(NodeVisitor):
-    target = "aten._native_batch_norm_legit_no_training.default"
-
-    tosa_specs = [
-        TosaSpecification.create_from_string("TOSA-0.80+MI"),
-    ]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
-    # For BatchNorm2D, mean and var are calculated over the channel dimension
-    # But TOSA doesn't allow subtraction of inputs with different ranks
-    # Need to augment the shapes to match the ranks with activations
-    def augment_shape_rank(self, shape, dim_order):
-        nchw_shape = (1, *shape, 1, 1)
-        return tosa_shape(nchw_shape, dim_order)
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        tosa_graph: ts.TosaSerializer,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        # Decompose batch norm into sequence
-        (activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
-            inputs
-        )
-
-        input_dtype = activations.dtype
-
-        assert (
-            0.1 == momentum.number
-        ), "Expected 0.1 momentum, not currently encoded into TOSA"
-
-        # %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
-        # e.g.
-        # %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
-        # ->
-        # %op1 = tosa.SUB(%activations, %running_mean)
-        # %op2 = tosa.ADD(%running_var, %epsilon_const)
-        # %op3 = tosa.RSQRT(%op2)
-        # %op4 = tosa.MUL(%op1, %op3)
-        # %op5 = tosa.MUL(%op4, %weights)
-        # %output = tosa.ADD(%op5, %bias)
-
-        # Reshape mean to match rank of activations
-        mean_reshaped = promote_shape(
-            tosa_graph,
-            running_mean,
-            self.augment_shape_rank(running_mean.shape, output.dim_order),
-            input_dtype,
-        )
-
-        # Subtract mean
-        # %op1 = tosa.SUB(%activations, %running_mean)
-        op1 = tosa_graph.addIntermediate(
-            tosa_shape(output.shape, output.dim_order), input_dtype
-        )
-        tosa_graph.addOperator(
-            TosaOp.Op().SUB,
-            [activations.name, mean_reshaped.name],
-            [op1.name],
-        )
-        # Adding eplison to variance
-        # %op2 = tosa.ADD(%running_var, %epsilon_const)
-        epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
-        op2 = tosa_graph.addIntermediate(
-            tosa_shape(running_var.shape, running_var.dim_order), input_dtype
-        )
-        tosa_graph.addOperator(
-            TosaOp.Op().ADD,
-            [running_var.name, epsilon_const.name],
-            [op2.name],
-        )
-        # Push downward the variance
-        # %op3 = tosa.RSQRT(%op2)
-        op3 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
-        tosa_graph.addOperator(TosaOp.Op().RSQRT, [op2.name], [op3.name])
-
-        # Reshape variable to match rank of activations
-        op3_reshaped = promote_shape(
-            tosa_graph,
-            op3,
-            self.augment_shape_rank(running_var.shape, output.dim_order),
-            input_dtype,
-        )
-
-        # Handle non existing weights and bias
-        if not weights.name and not bias.name:
-            # Multiply shifted activations with reciprocal variance
-            # %output = tosa.MUL(%op1, %op3) e.g. Now we have %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )
-            attr_mul = ts.TosaSerializerAttribute()
-            attr_mul.MulAttribute(0)
-            tosa_graph.addOperator(
-                TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [output.name], attr_mul
-            )
-            return
-        else:
-            # Multiply shifted activations with reciprocal variance
-            # %op4 = tosa.MUL(%op1, %op3)
-            op4 = tosa_graph.addIntermediate(
-                tosa_shape(output.shape, output.dim_order), input_dtype
-            )
-            attr_mul = ts.TosaSerializerAttribute()
-            attr_mul.MulAttribute(0)
-            tosa_graph.addOperator(
-                TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [op4.name], attr_mul
-            )
-
-        # Now we have %op4 = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )
-
-        if weights.name and not bias.name:
-            # Handle only weights but no bias
-
-            # Reshape weights to match rank of activations
-            weights_reshaped = promote_shape(
-                tosa_graph,
-                weights,
-                self.augment_shape_rank(weights.shape, output.dim_order),
-                input_dtype,
-            )
-
-            # %output = tosa.MUL(%op4, %weights)
-            attr_mul = ts.TosaSerializerAttribute()
-            attr_mul.MulAttribute(0)
-            tosa_graph.addOperator(
-                TosaOp.Op().MUL,
-                [op4.name, weights_reshaped.name],
-                [output.name],
-                attr_mul,
-            )
-            return
-
-        if not weights.name and bias.name:
-            # Handle only bias but no weights
-
-            # Reshape bias to match rank of activations
-            bias_reshaped = promote_shape(
-                tosa_graph,
-                bias,
-                self.augment_shape_rank(bias.shape, output.dim_order),
-                input_dtype,
-            )
-
-            # %output = tosa.ADD(%op4, %bias)
-            tosa_graph.addOperator(
-                TosaOp.Op().ADD,
-                [op4.name, bias_reshaped.name],
-                [output.name],
-            )
-            return
-
-        # We have both weights and bias
-
-        # Reshape weights to match rank of activations
-        weights_reshaped = promote_shape(
-            tosa_graph,
-            weights,
-            self.augment_shape_rank(weights.shape, output.dim_order),
-            input_dtype,
-        )
-
-        # %op5 = tosa.MUL(%op4, %weights)
-        op5 = tosa_graph.addIntermediate(
-            tosa_shape(output.shape, output.dim_order), input_dtype
-        )
-        attr_mul = ts.TosaSerializerAttribute()
-        attr_mul.MulAttribute(0)
-        tosa_graph.addOperator(
-            TosaOp.Op().MUL,
-            [op4.name, weights_reshaped.name],
-            [op5.name],
-            attr_mul,
-        )
-
-        # Reshape bias to match rank of activations
-        bias_reshaped = promote_shape(
-            tosa_graph,
-            bias,
-            self.augment_shape_rank(bias.shape, output.dim_order),
-            input_dtype,
-        )
-
-        # %output = tosa.ADD(%op5, %bias)
-        tosa_graph.addOperator(
-            TosaOp.Op().ADD,
-            [op5.name, bias_reshaped.name],
-            [output.name],
-        )
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index a57e8aa10c9..15d29b57482 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -9,7 +9,6 @@
 import os
 from typing import Any
 
-import numpy as np
 import serializer.tosa_serializer as ts  # type: ignore
 import torch
 from executorch.backends.arm.tosa_mapping import TosaArg
@@ -72,45 +71,6 @@ def dbg_fail(node, tosa_graph, path):
     raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
 
 
-# Helper function to match TOSA's broadcasting rank requirement
-# Ref: TOSA 0.80 specification - 1.9.3. Data Layouts from
-# https://www.mlplatform.org/tosa/tosa_spec.html
-def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
-    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
-    reshape_res = tosa_fb.addIntermediate(promoted_shape, out_dtype)
-    attr = ts.TosaSerializerAttribute()
-    attr.ReshapeAttribute(promoted_shape)
-    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [reshape_res.name], attr)
-    return reshape_res
-
-
-# Helper transpose function to match TOSA's shape requirements
-# E.g., TOSA 0.80 specification - 2.3.3 CONV2D shapes:
-# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
-def transpose_helper(tosa_fb, input, new_order, out_dtype):
-    # Check new_order's length is equal to input rank
-    assert len(input.shape) == len(new_order), "Wrong shape order length"
-
-    # Check no duplications
-    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"
-
-    # Check all dims are valid
-    for idx in new_order:
-        if idx < 0:
-            assert True, "Negative dim number"
-        elif idx >= len(input.shape):
-            assert True, "Dim is greater than input rank"
-
-    input_shape_transpoed = [input.shape[i] for i in new_order]
-    attr = ts.TosaSerializerAttribute()
-    attr.TransposeAttribute(new_order)
-    input_transposed = tosa_fb.addIntermediate(input_shape_transpoed, out_dtype)
-    tosa_fb.addOperator(
-        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
-    )
-    return input_transposed
-
-
 def getNodeArgs(node: Node) -> list[TosaArg]:
     return [TosaArg(arg) for arg in node.args]
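
For reference, a minimal sketch of the arithmetic DecomposeBatchNormPass emits, checked against torch.nn.functional.batch_norm in inference mode. The shapes, values, and tolerance below are assumptions chosen for the example only and are not part of the change itself.

import torch

# Assumed example shapes: NCHW activations with C=3 per-channel statistics.
x = torch.randn(2, 3, 4, 4)
weight, bias = torch.randn(3), torch.randn(3)
running_mean, running_var = torch.randn(3), torch.rand(3) + 0.5
eps = 1e-5

# Reference result from ATen batch norm in inference mode.
expected = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

# Same steps the pass inserts: view to (1, C, 1, 1), sub, add(eps), rsqrt,
# mul, mul(gamma), add(beta), broadcasting over the NCHW activations.
c = (1, 3, 1, 1)
actual = (x - running_mean.view(c)) * torch.rsqrt(running_var.view(c) + eps)
actual = actual * weight.view(c) + bias.view(c)

assert torch.allclose(expected, actual, atol=1e-5)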