From 74e1d7a9fbd0198f1edb3b0ab129a45af9bf338c Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 16 Mar 2026 10:10:47 +0100 Subject: [PATCH 1/2] Arm backend: use scalars instead of fulls in TFA Scalars are then converted to buffers by the ScalarsToAttribute pass. This both simplifies the code, and allows affected ops to be moved to device with model.to(device=...). Note that this does not solve all issues with device kwargs after TFA, only specifically for scalar cases. Signed-off-by: Erik Lundell Change-Id: I65c4ac69ec8fbfae98d8660a670658af6ea2eca8 --- backends/arm/_passes/arm_pass.py | 29 ++++++++++++ backends/arm/_passes/arm_pass_manager.py | 12 ++--- backends/arm/_passes/arm_pass_utils.py | 32 +++++++++++++ .../_passes/decompose_add_sub_alpha_pass.py | 15 +------ backends/arm/_passes/decompose_gelu_pass.py | 40 +++-------------- .../arm/_passes/decompose_groupnorm_pass.py | 25 ++++------- .../arm/_passes/decompose_layernorm_pass.py | 22 ++++----- .../arm/_passes/decompose_leaky_relu_pass.py | 20 +++------ .../arm/_passes/decompose_meandim_pass.py | 45 +++---------------- backends/arm/_passes/decompose_var_pass.py | 14 +++--- .../arm/_passes/scalars_to_attribute_pass.py | 2 + 11 files changed, 115 insertions(+), 141 deletions(-) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index b255d2c4661..1a1a179f456 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -11,6 +11,7 @@ from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult @@ -124,3 +125,31 @@ def call_shape_operator( shape_meta.data[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE # Call the super (ArmPass) call operator with updated 
meta return self.call_operator(op, args, kwargs, shape_meta, updated) + + def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]): + """Return a scalar value for the current pass stage. + + In transform-for-annotation passes this returns the Python scalar + directly. In later passes it materializes a `(1,)` `aten.full` node + using the output dtype/device from `meta["val"]` when available. + + """ + + if self.is_tfa_pass: + return value + + kwargs = {} + if "val" in meta: + val = meta["val"] + if isinstance(val, tuple): + val = val[0] + kwargs = {"device": val.device, "dtype": val.dtype} + + return ArmPass.call_operator( + self, + op=exir_ops.edge.aten.full.default, + args=((1,), value), + kwargs=kwargs, + meta=meta, + updated=True, + ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f4dbcabe112..67d47ac3bf8 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -557,6 +557,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): DecomposeDivTensorModePass(tfa_pass=True), DecomposeWhereScalarOtherPass(tfa_pass=True), RewriteInplaceArithmeticPass(tfa_pass=True), + DecomposeAddSubAlphaPass(tfa_pass=True), + DecomposeLeakyReLUPass(tfa_pass=True), + DecomposeGroupNormPass(tfa_pass=True), + DecomposeLayerNormPass(tfa_pass=True), + DecomposeVarPass(tfa_pass=True), + DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), ] ) @@ -573,16 +579,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_passes( [ NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), - DecomposeAddSubAlphaPass(tfa_pass=True), - DecomposeGroupNormPass(tfa_pass=True), - DecomposeLayerNormPass(tfa_pass=True), - DecomposeVarPass(tfa_pass=True), - DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), DecomposeNotEqualPass(tfa_pass=True), DecomposeCosineSimilarityPass(tfa_pass=True), 
DecomposeGluPass(tfa_pass=True), DecomposeDivPass(tfa_pass=True), - DecomposeLeakyReLUPass(tfa_pass=True), DecomposeLinalgVectorNormPass(tfa_pass=True), DecomposeSqrtPass(tfa_pass=True), DecomposeAdaptiveAvgPool2dPass(tfa_pass=True), diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 9f55c0ca568..e57fad70aff 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -211,6 +211,38 @@ def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata: return NodeMetadata(plain_meta_dict) +def insert_scalar( + graph: torch.fx.Graph, + value: int | float, + meta: NodeMetadata | dict, + from_node: torch.fx.Node, + is_tfa_pass: bool = False, +) -> torch.fx.Node | int | float: + """Insert an `aten.full` scalar node for direct graph-rewrite passes.""" + + if is_tfa_pass: + return value + + kwargs = {} + val = None + if "val" in meta: + val = meta["val"] + if isinstance(val, tuple): + val = val[0] + kwargs = {"device": val.device, "dtype": val.dtype} + + scalar = create_node( + graph=graph, + op_target=exir_ops.edge.aten.full.default, + args=((1,), value), + kwargs=kwargs, + from_node=from_node, + ) + if val is not None: + scalar.meta["val"] = torch.full((1,), value, **kwargs) + return scalar + + def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: """Returns a FakeTensor from the meta field of 'node'. 
diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index b3668ea5d7b..d7db9c5bcf9 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -30,24 +30,20 @@ def _get_ops(op): if op is exir_ops.edge.aten.add.Tensor: return ( exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, ) return ( torch.ops.aten.mul.Tensor, - torch.ops.aten.full.default, torch.ops.aten.add.Tensor, ) if op in _SUB_OPS: if op is exir_ops.edge.aten.sub.Tensor: return ( exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.sub.Tensor, ) return ( torch.ops.aten.mul.Tensor, - torch.ops.aten.full.default, torch.ops.aten.sub.Tensor, ) raise RuntimeError(f"Unsupported operator {op}") @@ -72,19 +68,12 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): if not _should_decompose(alpha): return super().call_operator(op, args, kwargs, meta, updated) - mul_op, full_op, binary_op = _get_ops(op) + mul_op, binary_op = _get_ops(op) lhs, rhs = args - alpha_full = super().call_operator( - full_op, - ((1,), float(alpha)), - {"device": meta["val"].device, "dtype": meta["val"].dtype}, - meta, - updated=True, - ) scaled_rhs = super().call_operator( mul_op, - (rhs, alpha_full), + (rhs, super().call_scalar(alpha, meta)), {}, meta, updated=True, diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 79ea5c6f12d..7815b5fa44f 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -27,7 +27,6 @@ def _get_gelu_ops(op) -> tuple: if op in edge_gelu: return ( - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.tanh.default, @@ -35,7 +34,6 @@ def _get_gelu_ops(op) -> tuple: ) if op in torch_gelu: return ( - 
torch.ops.aten.full.default, torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor, torch.ops.aten.tanh.default, @@ -98,30 +96,18 @@ def call_operator(self, op, args, kwargs, meta): # If quantized, node should be replace by table op return super().call_operator(op, args, kwargs, meta) - full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) + add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) input = get_node_arg(args, 0) # If approximate is default (none) it does not appear in kwargs approximate = get_node_arg(kwargs, "approximate", "none") - shape = meta["val"].size() - dtype = meta["val"].dtype - - FULL_0_5 = super().call_operator( - full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta - ) - FULL_1 = super().call_operator( - full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta - ) + FULL_0_5 = super().call_scalar(0.5, meta) + FULL_1 = super().call_scalar(1, meta) if approximate == "none": # Constant mirrors ExecuTorch implementation for parity. - FULL_SQRT1_2 = super().call_operator( - full_op, - ([1] * len(shape), 0.70710678118654752440), - {"dtype": dtype}, - meta, - ) + FULL_SQRT1_2 = super().call_scalar(0.70710678118654752440, meta) op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta) op2 = super().call_operator(erf_op, (op1,), {}, meta) @@ -131,21 +117,9 @@ def call_operator(self, op, args, kwargs, meta): elif approximate == "tanh": # Constants mirror ExecuTorch implementation for parity. 
- FULL_SQRT2 = super().call_operator( - full_op, - ([1] * len(shape), 1.41421356237309504880), - {"dtype": dtype}, - meta, - ) - FULL_2_SQRTPI = super().call_operator( - full_op, - ([1] * len(shape), 1.12837916709551257390), - {"dtype": dtype}, - meta, - ) - FULL_CUBE_COEFF = super().call_operator( - full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta - ) + FULL_SQRT2 = super().call_scalar(1.41421356237309504880, meta) + FULL_2_SQRTPI = super().call_scalar(1.12837916709551257390, meta) + FULL_CUBE_COEFF = super().call_scalar(0.044715, meta) # Mirrors ExecuTorch implementations for calculating this value SQRT_MUL = super().call_operator( diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index ff33d6a8127..2381dc2a443 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -9,7 +9,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_scalar from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -24,7 +24,6 @@ def get_group_norm_decomposition(op) -> tuple: exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.mul.Tensor, @@ -35,7 +34,6 @@ def get_group_norm_decomposition(op) -> tuple: torch.ops.aten.mean.dim, torch.ops.aten.sub.Tensor, torch.ops.aten.var.correction, - torch.ops.aten.full.default, torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, @@ -91,12 +89,8 @@ def call(self, 
graph_module: torch.fx.GraphModule): meta = node.meta if isinstance(meta["val"], tuple): shape = meta["val"][0].size() - dtype = meta["val"][0].dtype - device = meta["val"][0].device else: shape = meta["val"].size() - dtype = meta["val"].dtype - device = meta["val"].device match len(args): # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps case 8: @@ -126,13 +120,11 @@ def call(self, graph_module: torch.fx.GraphModule): channels_per_group = C // group grouped_shape = torch.Size([N, group, channels_per_group, HxW]) dims = [2, 3] - epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape)) weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1]) ( mean_op, sub_op, var_op, - full_op, add_op, rsqrt_op, mul_op, @@ -157,16 +149,17 @@ def call(self, graph_module: torch.fx.GraphModule): kwargs={"correction": 0, "keepdim": keepdim}, from_node=node, ) - full = create_node( + add0 = create_node( graph_module.graph, - full_op, - args=(epsilon_reshaped_shape, eps), - kwargs={"dtype": dtype, "device": device}, + add_op, + args=( + var, + insert_scalar( + graph_module.graph, eps, meta, node, self.is_tfa_pass + ), + ), from_node=node, ) - add0 = create_node( - graph_module.graph, add_op, args=(var, full), from_node=node - ) rsqrt = create_node( graph_module.graph, rsqrt_op, args=(add0,), from_node=node ) diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 992b21fd592..780e932733b 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -9,7 +9,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_scalar from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from 
executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ( @@ -107,17 +107,12 @@ def call(self, graph_module: torch.fx.GraphModule): n_dims = len(normalized_shape) if isinstance(meta["val"], tuple): shape = meta["val"][0].size() - dtype = meta["val"][0].dtype - device = meta["val"][0].device else: shape = meta["val"].size() - dtype = meta["val"].dtype - device = meta["val"].device rank = len(shape) dims = list(range(-1, -1 * (n_dims + 1), -1)) dims = [dim % rank for dim in dims] weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)] - epsilon_reshaped_shape = [1] * rank ( mean_op, @@ -140,16 +135,17 @@ def call(self, graph_module: torch.fx.GraphModule): kwargs={"correction": 0, "keepdim": keepdim}, from_node=node, ) - full = create_node( + add0 = create_node( graph_module.graph, - full_op, - args=(epsilon_reshaped_shape, epsilon), - kwargs={"dtype": dtype, "device": device}, + add_op, + args=( + var, + insert_scalar( + graph_module.graph, epsilon, meta, node, self.is_tfa_pass + ), + ), from_node=node, ) - add0 = create_node( - graph_module.graph, add_op, args=(var, full), from_node=node - ) rsqrt = create_node( graph_module.graph, rsqrt_op, args=(add0,), from_node=node ) diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index d9b5bbe96df..eb8b5bda61a 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -9,9 +9,6 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( - ConvertFullLikeToFullPass, -) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,14 +20,12 @@ def _get_leaky_relu_ops(op) -> tuple: if op in edge_ops: return ( exir_ops.edge.aten.clamp.default, - 
exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.add.Tensor, ) elif op in torch_ops: return ( torch.ops.aten.clamp.default, - torch.ops.aten.full_like.default, torch.ops.aten.mul.Tensor, torch.ops.aten.add.Tensor, ) @@ -45,13 +40,13 @@ class DecomposeLeakyReLUPass(ArmPass): Example: %op1 = clamp(x,0,None) (equivalent to max(0,x)) %op2 = clamp(x,None,0) (equivalent to min(0,x)) - %op3 = full_like(x,slope) + %op3 = slope %op4 = mul(%op3,%op2) %op5 = add(%op1,%op4) """ - _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta): @@ -59,19 +54,18 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] slope = args[1] if len(args) > 1 else 0.01 - clamp, full_like, mul, add = _get_leaky_relu_ops(op) + clamp, mul, add = _get_leaky_relu_ops(op) op1 = super().call_operator( op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta ) op2 = super().call_operator( op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta ) - op3 = super().call_operator( - op=full_like, - args=(x, slope), - kwargs={}, + op4 = super().call_operator( + op=mul, + args=(op2, super().call_scalar(slope, meta)), + kwargs=kwargs, meta=meta, ) - op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta) op5 = super().call_operator(op=add, args=(op1, op4), kwargs=kwargs, meta=meta) return op5 diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 81d2339812f..dec890c5561 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -25,13 +25,11 @@ def get_meandim_decomposition(op) -> tuple: if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return ( exir_ops.edge.aten.sum.dim_IntList, - 
exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, ) if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return ( torch.ops.aten.sum.dim_IntList, - torch.ops.aten.full.default, torch.ops.aten.mul.Tensor, ) raise RuntimeError(f"Can't get meandim decomposition for op {op}") @@ -123,7 +121,6 @@ def call_operator(self, op, args, kwargs, meta): dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] - dtype = meta["val"].dtype view_op = get_view(op) # Reshape to 4D @@ -155,7 +152,7 @@ def call_operator(self, op, args, kwargs, meta): x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) x = self._maybe_insert_q_dq_after(x, meta) # Reduce remaining dims by sum - x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype) + x = self._reduce_by_sum(op, x, dims_to_reduce, meta) # Reshape to correct output shape if necessary if list(x.data.shape) != output_shape: @@ -163,48 +160,20 @@ def call_operator(self, op, args, kwargs, meta): return x - def _reduce_by_sum(self, op, input_node, dims, meta, dtype): + def _reduce_by_sum(self, op, input_node, dims, meta): if len(dims) == 0: return input_node input_shape = input_node.data.size() - output_shape = meta["val"].size() N = prod((n for i, n in enumerate(input_shape) if i in dims)) - sum_op, full_op, mul_op = get_meandim_decomposition(op) + sum_op, mul_op = get_meandim_decomposition(op) sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True) - full = super().call_operator( - full_op, - ([1] * len(output_shape), 1 / N), - {"dtype": dtype, "device": input_node.data.device}, - meta, - True, - ) + divisor = super().call_scalar(1 / N, meta) if (quant_ops := get_quantization(input_node.node.target)) is not None: - # Insert Q and DQ nodes after full op. 
- # Since the value of full is known, we can compute quant params such that dq(q_max_value) q_op, dq_op = quant_ops - qmax = input_node.node.args[4] - full_quant_args = ( - 1 / (N * qmax), # Scale to map qmax to 1/N - 0, # Zero point - *input_node.node.args[3:], - ) - q_args = (full, *full_quant_args) - full = super().call_operator( - q_op, - q_args, - kwargs={}, - meta=meta, - updated=True, - ) - dq_args = (full, *full_quant_args) - full = super().call_operator( - dq_op, dq_args, kwargs={}, meta=meta, updated=True - ) - - # Insert Q and DQ nodes after sum op. - # Scale needs to be adjusted with N, since it was computed on data after the division with N. + # Scale needs to be adjusted with N, since it was computed on data + # after the division with N. sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:]) q_args = (sum, *sum_quant_args) sum = super().call_operator( @@ -219,7 +188,7 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype): dq_op, dq_args, kwargs={}, meta=meta, updated=True ) - return super().call_operator(mul_op, (sum, full), {}, meta, True) + return super().call_operator(mul_op, (sum, divisor), {}, meta, True) def _reduce_by_average_pool(self, op, input_node, dims, meta): dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2] diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 72da356f70a..fcf61cf5129 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -26,7 +26,6 @@ def get_var_decomposition(op) -> tuple: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.full.default, ) if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim): return ( @@ -34,7 +33,6 @@ def get_var_decomposition(op) -> tuple: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Tensor, torch.ops.aten.sum.dim_IntList, - torch.ops.aten.full, ) raise RuntimeError(f"Can't 
get var decomposition for op {op}") @@ -73,7 +71,6 @@ def call_operator(self, op, args, kwargs, meta): if shape == []: shape = [1 for _ in input_shape] - dtype = meta["val"].dtype # Get dim from args based on argument type dim = get_node_arg(args, key=list, default_value=list(range(len(shape)))) @@ -92,18 +89,17 @@ def call_operator(self, op, args, kwargs, meta): for d in dim: N *= input_shape[d] - mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op) + mean_op, diff_op, mul_op, sum_op = get_var_decomposition(op) mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True) diff = super().call_operator(diff_op, (x, mean), {}, meta, True) squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True) sum = super().call_operator( sum_op, (squared_diff, dim, keepdim), {}, meta, True ) - full = super().call_operator( - full_op, - ([], 1 / max(0, N - correction)), - {"dtype": dtype, "device": x.data.device}, + return super().call_operator( + mul_op, + (sum, super().call_scalar(1 / max(0, N - correction), meta)), + {}, meta, True, ) - return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 5cf462626c1..cb4d96efcaf 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -31,6 +31,8 @@ class ScalarsToAttributePass(ArmPass): torch.ops.aten.rsub.Scalar, torch.ops.aten.mul.Tensor, torch.ops.aten.div.Tensor, + torch.ops.aten.div_.Tensor, + torch.ops.aten.div.Tensor_mode, ] def _convert_scalar_args( From 2087f355a10816368429f5861fd7e25624bffa2f Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 16 Mar 2026 15:46:43 +0100 Subject: [PATCH 2/2] Arm backend: Clean up some pass inefficiencies. - The ScalarToAttribute pass went through all submodules for each node, it only needs to do it once. - Some exir passes used full_like for scalars. 
This creates very large buffers of the same size as the input, when a single value is enough. Signed-off-by: Erik Lundell Change-Id: Ie48cec7dd2f78855eda90811fd1bc7dfec7d3a15 --- .../_passes/decompose_asin_and_acos_pass.py | 19 +++++-------------- backends/arm/_passes/decompose_erfinv_pass.py | 12 ++++-------- .../arm/_passes/scalars_to_attribute_pass.py | 6 ++---- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 0ac834eb519..707e6ec070d 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -42,7 +42,6 @@ def get_decomposition(op) -> tuple: exir_ops.edge.aten.gt.Scalar, exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.neg.default, ) @@ -79,15 +78,12 @@ def _build_polynomial( """Helper function to build polynomial from coefficients and variable. 
""" - full_like_op, add_op, mul_op_scalar, mul_op = ( - exir_ops.edge.aten.full_like.default, + add_op, mul_op_scalar, mul_op = ( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Scalar, exir_ops.edge.aten.mul.Tensor, ) - result = super().call_operator( - full_like_op, (variable, coefficients[0]), {}, meta, True - ) + result = super().call_scalar(coefficients[0], meta) for coeff in coefficients[1:]: result = super().call_operator( add_op, @@ -150,7 +146,6 @@ def call_operator(self, op, args, kwargs, meta): gt_op, lt_op, sub_op, - full_like_op, neg_op, ) = get_decomposition(op) @@ -179,7 +174,7 @@ def call_operator(self, op, args, kwargs, meta): # Step 2: Compute the transformed approximation for large values # Calculate z = -0.5 * (|x| - 1) - tmp_ones = super().call_operator(full_like_op, (x_abs, one), {}, meta, True) + tmp_ones = super().call_scalar(one, meta) tmp = super().call_operator(sub_op, (x_abs, tmp_ones), {}, meta, True) z = super().call_operator(mul_op_scalar, (tmp, neg_half), {}, meta, True) @@ -201,9 +196,7 @@ def call_operator(self, op, args, kwargs, meta): t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True) diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True) - tmp_neg_ones = super().call_operator( - full_like_op, (diff, neg_one), {}, meta, True - ) + tmp_neg_ones = super().call_scalar(neg_one, meta) asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True) asin_unsigned = self._combine_branches( @@ -218,9 +211,7 @@ def call_operator(self, op, args, kwargs, meta): if op in edge_acos_op: # If x <= 0.5: acos(x) = pi/2 - asin(x) - const_tensor = super().call_operator( - full_like_op, (x, pi_over_2), {}, meta, True - ) + const_tensor = super().call_scalar(pi_over_2, meta) acos_small = super().call_operator( sub_op, (const_tensor, asin), {}, meta, True ) diff --git a/backends/arm/_passes/decompose_erfinv_pass.py b/backends/arm/_passes/decompose_erfinv_pass.py index 
7ebbf181a97..747209d943e 100644 --- a/backends/arm/_passes/decompose_erfinv_pass.py +++ b/backends/arm/_passes/decompose_erfinv_pass.py @@ -26,7 +26,6 @@ def get_erfinv_decomposition(op) -> tuple: if op in edge_erfinv_ops: # Ordered by first use in call_operator below. return ( - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.where.self, exir_ops.edge.aten.abs.default, @@ -140,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] ( - op_full_like, op_lt_t, op_where, op_abs, @@ -179,12 +177,10 @@ def call_operator(self, op, args, kwargs, meta): CORR_MAX = 0.5 TWO_OVER_SQRT_PI = 1.1283791670955126 - # ---- zeros / ones (tensor-shaped) ---- - zeros = super().call_operator(op_full_like, (x, 0.0), {}, meta, updated=True) - ones = super().call_operator(op_full_like, (x, 1.0), {}, meta, updated=True) - neg_ones = super().call_operator( - op_full_like, (x, -1.0), {}, meta, updated=True - ) + # ---- zeros / ones constants ---- + zeros = super().call_scalar(0.0, meta) + ones = super().call_scalar(1.0, meta) + neg_ones = super().call_scalar(-1.0, meta) # ---- s = sign(x): -1 for x<0 else +1 ---- x_lt0 = super().call_operator(op_lt_t, (x, zeros), {}, meta, updated=True) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index cb4d96efcaf..0473caf91e7 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -94,22 +94,20 @@ def _convert_scalar_args( sub.meta["val"] = n.meta["val"] graph_module.graph.erase_node(n) - def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None: + def handle_control_nodes(self, graph_module: GraphModule) -> None: """Apply scalar argument conversion on subgraphs of control-flow nodes. 
""" for _, submodule, _ in get_cond_while_submodules_nested(graph_module): for submodule_node in submodule.graph.nodes: - # use aten.full.default for scalar constants in control subgraphs self._convert_scalar_args(submodule, submodule_node) - graph_module.recompile() def call(self, graph_module: GraphModule) -> PassResult: # convert scalars in control-flow subgraphs and main graph for node in list(graph_module.graph.nodes): n = cast(Node, node) - self.handle_control_nodes(n, graph_module) self._convert_scalar_args(graph_module, n) + self.handle_control_nodes(graph_module) graph_module.recompile() graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True)