diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index b255d2c4661..1a1a179f456 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -11,6 +11,7 @@ from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult @@ -124,3 +125,31 @@ def call_shape_operator( shape_meta.data[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE # Call the super (ArmPass) call operator with updated meta return self.call_operator(op, args, kwargs, shape_meta, updated) + + def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]): + """Return a scalar value for the current pass stage. + + In transform-for-annotation passes this returns the Python scalar + directly. In later passes it materializes a `(1,)` `aten.full` node + using the output dtype/device from `meta["val"]` when available. + + """ + + if self.is_tfa_pass: + return value + + kwargs = {} + if "val" in meta: + val = meta["val"] + if isinstance(val, tuple): + val = val[0] + kwargs = {"device": val.device, "dtype": val.dtype} + + return ArmPass.call_operator( + self, + op=exir_ops.edge.aten.full.default, + args=((1,), value), + kwargs=kwargs, + meta=meta, + updated=True, + ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f4dbcabe112..67d47ac3bf8 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -557,6 +557,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): DecomposeDivTensorModePass(tfa_pass=True), DecomposeWhereScalarOtherPass(tfa_pass=True), RewriteInplaceArithmeticPass(tfa_pass=True), + DecomposeAddSubAlphaPass(tfa_pass=True), + DecomposeLeakyReLUPass(tfa_pass=True), + DecomposeGroupNormPass(tfa_pass=True), + DecomposeLayerNormPass(tfa_pass=True), + DecomposeVarPass(tfa_pass=True), + DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), ] ) @@ -573,16 +579,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_passes( [ NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), - DecomposeAddSubAlphaPass(tfa_pass=True), - DecomposeGroupNormPass(tfa_pass=True), - DecomposeLayerNormPass(tfa_pass=True), - DecomposeVarPass(tfa_pass=True), - DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), DecomposeNotEqualPass(tfa_pass=True), DecomposeCosineSimilarityPass(tfa_pass=True), DecomposeGluPass(tfa_pass=True), DecomposeDivPass(tfa_pass=True), - DecomposeLeakyReLUPass(tfa_pass=True), DecomposeLinalgVectorNormPass(tfa_pass=True), DecomposeSqrtPass(tfa_pass=True), DecomposeAdaptiveAvgPool2dPass(tfa_pass=True), diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 9f55c0ca568..e57fad70aff 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -211,6 +211,38 @@ def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata: return NodeMetadata(plain_meta_dict) +def insert_scalar( + graph: torch.fx.Graph, + value: int | float, + meta: NodeMetadata | dict, + from_node: torch.fx.Node, + is_tfa_pass: bool = False, +) -> torch.fx.Node | int | float: + """Insert an `aten.full` scalar node for direct graph-rewrite passes.""" + + if is_tfa_pass: + return value + + kwargs = {} + val = None + if "val" in meta: + val = meta["val"] + if isinstance(val, tuple): + val = val[0] + kwargs = {"device": val.device, "dtype": val.dtype} + + scalar = create_node( + graph=graph, + op_target=exir_ops.edge.aten.full.default, + args=((1,), value), + kwargs=kwargs, + from_node=from_node, + ) + if val is not None: + scalar.meta["val"] = torch.full((1,), value, **kwargs) + return scalar + + def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: """Returns a FakeTensor from the meta field of 'node'. diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index b3668ea5d7b..d7db9c5bcf9 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -30,24 +30,20 @@ def _get_ops(op): if op is exir_ops.edge.aten.add.Tensor: return ( exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, ) return ( torch.ops.aten.mul.Tensor, - torch.ops.aten.full.default, torch.ops.aten.add.Tensor, ) if op in _SUB_OPS: if op is exir_ops.edge.aten.sub.Tensor: return ( exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.sub.Tensor, ) return ( torch.ops.aten.mul.Tensor, - torch.ops.aten.full.default, torch.ops.aten.sub.Tensor, ) raise RuntimeError(f"Unsupported operator {op}") @@ -72,19 +68,12 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): if not _should_decompose(alpha): return super().call_operator(op, args, kwargs, meta, updated) - mul_op, full_op, binary_op = _get_ops(op) + mul_op, binary_op = _get_ops(op) lhs, rhs = args - alpha_full = super().call_operator( - full_op, - ((1,), float(alpha)), - {"device": meta["val"].device, "dtype": meta["val"].dtype}, - meta, - updated=True, - ) scaled_rhs = super().call_operator( mul_op, - (rhs, alpha_full), + (rhs, super().call_scalar(alpha, meta)), {}, meta, updated=True, diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 0ac834eb519..707e6ec070d 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -42,7 +42,6 @@ def get_decomposition(op) -> tuple: exir_ops.edge.aten.gt.Scalar, exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.neg.default, ) @@ -79,15 +78,12 @@ def _build_polynomial( """Helper function to build polynomial from coefficients and variable. """ - full_like_op, add_op, mul_op_scalar, mul_op = ( - exir_ops.edge.aten.full_like.default, + add_op, mul_op_scalar, mul_op = ( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Scalar, exir_ops.edge.aten.mul.Tensor, ) - result = super().call_operator( - full_like_op, (variable, coefficients[0]), {}, meta, True - ) + result = super().call_scalar(coefficients[0], meta) for coeff in coefficients[1:]: result = super().call_operator( add_op, @@ -150,7 +146,6 @@ def call_operator(self, op, args, kwargs, meta): gt_op, lt_op, sub_op, - full_like_op, neg_op, ) = get_decomposition(op) @@ -179,7 +174,7 @@ def call_operator(self, op, args, kwargs, meta): # Step 2: Compute the transformed approximation for large values # Calculate z = -0.5 * (|x| - 1) - tmp_ones = super().call_operator(full_like_op, (x_abs, one), {}, meta, True) + tmp_ones = super().call_scalar(one, meta) tmp = super().call_operator(sub_op, (x_abs, tmp_ones), {}, meta, True) z = super().call_operator(mul_op_scalar, (tmp, neg_half), {}, meta, True) @@ -201,9 +196,7 @@ def call_operator(self, op, args, kwargs, meta): t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True) diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True) - tmp_neg_ones = super().call_operator( - full_like_op, (diff, neg_one), {}, meta, True - ) + tmp_neg_ones = super().call_scalar(neg_one, meta) asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True) asin_unsigned = self._combine_branches( @@ -218,9 +211,7 @@ def call_operator(self, op, args, kwargs, meta): if op in edge_acos_op: # If x <= 0.5: acos(x) = pi/2 - asin(x) - const_tensor = super().call_operator( - full_like_op, (x, pi_over_2), {}, meta, True - ) + const_tensor = super().call_scalar(pi_over_2, meta) acos_small = super().call_operator( sub_op, (const_tensor, asin), {}, meta, True ) diff --git a/backends/arm/_passes/decompose_erfinv_pass.py b/backends/arm/_passes/decompose_erfinv_pass.py index 7ebbf181a97..747209d943e 100644 --- a/backends/arm/_passes/decompose_erfinv_pass.py +++ b/backends/arm/_passes/decompose_erfinv_pass.py @@ -26,7 +26,6 @@ def get_erfinv_decomposition(op) -> tuple: if op in edge_erfinv_ops: # Ordered by first use in call_operator below. return ( - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.where.self, exir_ops.edge.aten.abs.default, @@ -140,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] ( - op_full_like, op_lt_t, op_where, op_abs, @@ -179,12 +177,10 @@ def call_operator(self, op, args, kwargs, meta): CORR_MAX = 0.5 TWO_OVER_SQRT_PI = 1.1283791670955126 - # ---- zeros / ones (tensor-shaped) ---- - zeros = super().call_operator(op_full_like, (x, 0.0), {}, meta, updated=True) - ones = super().call_operator(op_full_like, (x, 1.0), {}, meta, updated=True) - neg_ones = super().call_operator( - op_full_like, (x, -1.0), {}, meta, updated=True - ) + # ---- zeros / ones constants ---- + zeros = super().call_scalar(0.0, meta) + ones = super().call_scalar(1.0, meta) + neg_ones = super().call_scalar(-1.0, meta) # ---- s = sign(x): -1 for x<0 else +1 ---- x_lt0 = super().call_operator(op_lt_t, (x, zeros), {}, meta, updated=True) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 79ea5c6f12d..7815b5fa44f 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -27,7 +27,6 @@ def _get_gelu_ops(op) -> tuple: if op in edge_gelu: return ( - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.tanh.default, @@ -35,7 +34,6 @@ def _get_gelu_ops(op) -> tuple: ) if op in torch_gelu: return ( - torch.ops.aten.full.default, torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor, torch.ops.aten.tanh.default, @@ -98,30 +96,18 @@ def call_operator(self, op, args, kwargs, meta): # If quantized, node should be replace by table op return super().call_operator(op, args, kwargs, meta) - full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) + add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) input = get_node_arg(args, 0) # If approximate is default (none) it does not appear in kwargs approximate = get_node_arg(kwargs, "approximate", "none") - shape = meta["val"].size() - dtype = meta["val"].dtype - - FULL_0_5 = super().call_operator( - full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta - ) - FULL_1 = super().call_operator( - full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta - ) + FULL_0_5 = super().call_scalar(0.5, meta) + FULL_1 = super().call_scalar(1, meta) if approximate == "none": # Constant mirrors ExecuTorch implementation for parity. - FULL_SQRT1_2 = super().call_operator( - full_op, - ([1] * len(shape), 0.70710678118654752440), - {"dtype": dtype}, - meta, - ) + FULL_SQRT1_2 = super().call_scalar(0.70710678118654752440, meta) op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta) op2 = super().call_operator(erf_op, (op1,), {}, meta) @@ -131,21 +117,9 @@ def call_operator(self, op, args, kwargs, meta): elif approximate == "tanh": # Constants mirror ExecuTorch implementation for parity. - FULL_SQRT2 = super().call_operator( - full_op, - ([1] * len(shape), 1.41421356237309504880), - {"dtype": dtype}, - meta, - ) - FULL_2_SQRTPI = super().call_operator( - full_op, - ([1] * len(shape), 1.12837916709551257390), - {"dtype": dtype}, - meta, - ) - FULL_CUBE_COEFF = super().call_operator( - full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta - ) + FULL_SQRT2 = super().call_scalar(1.41421356237309504880, meta) + FULL_2_SQRTPI = super().call_scalar(1.12837916709551257390, meta) + FULL_CUBE_COEFF = super().call_scalar(0.044715, meta) # Mirrors ExecuTorch implementations for calculating this value SQRT_MUL = super().call_operator( diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index ff33d6a8127..2381dc2a443 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -9,7 +9,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_scalar from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -24,7 +24,6 @@ def get_group_norm_decomposition(op) -> tuple: exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.mul.Tensor, @@ -35,7 +34,6 @@ def get_group_norm_decomposition(op) -> tuple: torch.ops.aten.mean.dim, torch.ops.aten.sub.Tensor, torch.ops.aten.var.correction, - torch.ops.aten.full.default, torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, @@ -91,12 +89,8 @@ def call(self, graph_module: torch.fx.GraphModule): meta = node.meta if isinstance(meta["val"], tuple): shape = meta["val"][0].size() - dtype = meta["val"][0].dtype - device = meta["val"][0].device else: shape = meta["val"].size() - dtype = meta["val"].dtype - device = meta["val"].device match len(args): # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps case 8: @@ -126,13 +120,11 @@ def call(self, graph_module: torch.fx.GraphModule): channels_per_group = C // group grouped_shape = torch.Size([N, group, channels_per_group, HxW]) dims = [2, 3] - epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape)) weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1]) ( mean_op, sub_op, var_op, - full_op, add_op, rsqrt_op, mul_op, @@ -157,16 +149,17 @@ def call(self, graph_module: torch.fx.GraphModule): kwargs={"correction": 0, "keepdim": keepdim}, from_node=node, ) - full = create_node( + add0 = create_node( graph_module.graph, - full_op, - args=(epsilon_reshaped_shape, eps), - kwargs={"dtype": dtype, "device": device}, + add_op, + args=( + var, + insert_scalar( + graph_module.graph, eps, meta, node, self.is_tfa_pass + ), + ), from_node=node, ) - add0 = create_node( - graph_module.graph, add_op, args=(var, full), from_node=node - ) rsqrt = create_node( graph_module.graph, rsqrt_op, args=(add0,), from_node=node ) diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 992b21fd592..780e932733b 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -9,7 +9,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_scalar from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ( @@ -107,17 +107,12 @@ def call(self, graph_module: torch.fx.GraphModule): n_dims = len(normalized_shape) if isinstance(meta["val"], tuple): shape = meta["val"][0].size() - dtype = meta["val"][0].dtype - device = meta["val"][0].device else: shape = meta["val"].size() - dtype = meta["val"].dtype - device = meta["val"].device rank = len(shape) dims = list(range(-1, -1 * (n_dims + 1), -1)) dims = [dim % rank for dim in dims] weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)] - epsilon_reshaped_shape = [1] * rank ( mean_op, @@ -140,16 +135,17 @@ def call(self, graph_module: torch.fx.GraphModule): kwargs={"correction": 0, "keepdim": keepdim}, from_node=node, ) - full = create_node( + add0 = create_node( graph_module.graph, - full_op, - args=(epsilon_reshaped_shape, epsilon), - kwargs={"dtype": dtype, "device": device}, + add_op, + args=( + var, + insert_scalar( + graph_module.graph, epsilon, meta, node, self.is_tfa_pass + ), + ), from_node=node, ) - add0 = create_node( - graph_module.graph, add_op, args=(var, full), from_node=node - ) rsqrt = create_node( graph_module.graph, rsqrt_op, args=(add0,), from_node=node ) diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index d9b5bbe96df..eb8b5bda61a 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -9,9 +9,6 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( - ConvertFullLikeToFullPass, -) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,14 +20,12 @@ def _get_leaky_relu_ops(op) -> tuple: if op in edge_ops: return ( exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.add.Tensor, ) elif op in torch_ops: return ( torch.ops.aten.clamp.default, - torch.ops.aten.full_like.default, torch.ops.aten.mul.Tensor, torch.ops.aten.add.Tensor, ) @@ -45,13 +40,13 @@ class DecomposeLeakyReLUPass(ArmPass): Example: %op1 = clamp(x,0,None) (equivalent to max(0,x)) %op2 = clamp(x,None,0) (equivalent to min(0,x)) - %op3 = full_like(x,slope) + %op3 = slope %op4 = mul(%op3,%op2) %op5 = add(%op1,%op4) """ - _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta): @@ -59,19 +54,18 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] slope = args[1] if len(args) > 1 else 0.01 - clamp, full_like, mul, add = _get_leaky_relu_ops(op) + clamp, mul, add = _get_leaky_relu_ops(op) op1 = super().call_operator( op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta ) op2 = super().call_operator( op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta ) - op3 = super().call_operator( - op=full_like, - args=(x, slope), - kwargs={}, + op4 = super().call_operator( + op=mul, + args=(op2, super().call_scalar(slope, meta)), + kwargs=kwargs, meta=meta, ) - op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta) op5 = super().call_operator(op=add, args=(op1, op4), kwargs=kwargs, meta=meta) return op5 diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 81d2339812f..dec890c5561 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -25,13 +25,11 @@ def get_meandim_decomposition(op) -> tuple: if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return ( exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, ) if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return ( torch.ops.aten.sum.dim_IntList, - torch.ops.aten.full.default, torch.ops.aten.mul.Tensor, ) raise RuntimeError(f"Can't get meandim decomposition for op {op}") @@ -123,7 +121,6 @@ def call_operator(self, op, args, kwargs, meta): dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] - dtype = meta["val"].dtype view_op = get_view(op) # Reshape to 4D @@ -155,7 +152,7 @@ def call_operator(self, op, args, kwargs, meta): x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) x = self._maybe_insert_q_dq_after(x, meta) # Reduce remaining dims by sum - x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype) + x = self._reduce_by_sum(op, x, dims_to_reduce, meta) # Reshape to correct output shape if necessary if list(x.data.shape) != output_shape: @@ -163,48 +160,20 @@ def call_operator(self, op, args, kwargs, meta): return x - def _reduce_by_sum(self, op, input_node, dims, meta, dtype): + def _reduce_by_sum(self, op, input_node, dims, meta): if len(dims) == 0: return input_node input_shape = input_node.data.size() - output_shape = meta["val"].size() N = prod((n for i, n in enumerate(input_shape) if i in dims)) - sum_op, full_op, mul_op = get_meandim_decomposition(op) + sum_op, mul_op = get_meandim_decomposition(op) sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True) - full = super().call_operator( - full_op, - ([1] * len(output_shape), 1 / N), - {"dtype": dtype, "device": input_node.data.device}, - meta, - True, - ) + divisor = super().call_scalar(1 / N, meta) if (quant_ops := get_quantization(input_node.node.target)) is not None: - # Insert Q and DQ nodes after full op. - # Since the value of full is known, we can compute quant params such that dq(q_max_value) q_op, dq_op = quant_ops - qmax = input_node.node.args[4] - full_quant_args = ( - 1 / (N * qmax), # Scale to map qmax to 1/N - 0, # Zero point - *input_node.node.args[3:], - ) - q_args = (full, *full_quant_args) - full = super().call_operator( - q_op, - q_args, - kwargs={}, - meta=meta, - updated=True, - ) - dq_args = (full, *full_quant_args) - full = super().call_operator( - dq_op, dq_args, kwargs={}, meta=meta, updated=True - ) - - # Insert Q and DQ nodes after sum op. - # Scale needs to be adjusted with N, since it was computed on data after the division with N. + # Scale needs to be adjusted with N, since it was computed on data + # after the division with N. sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:]) q_args = (sum, *sum_quant_args) sum = super().call_operator( @@ -219,7 +188,7 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype): dq_op, dq_args, kwargs={}, meta=meta, updated=True ) - return super().call_operator(mul_op, (sum, full), {}, meta, True) + return super().call_operator(mul_op, (sum, divisor), {}, meta, True) def _reduce_by_average_pool(self, op, input_node, dims, meta): dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2] diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 72da356f70a..fcf61cf5129 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -26,7 +26,6 @@ def get_var_decomposition(op) -> tuple: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.full.default, ) if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim): return ( @@ -34,7 +33,6 @@ def get_var_decomposition(op) -> tuple: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Tensor, torch.ops.aten.sum.dim_IntList, - torch.ops.aten.full, ) raise RuntimeError(f"Can't get var decomposition for op {op}") @@ -73,7 +71,6 @@ def call_operator(self, op, args, kwargs, meta): if shape == []: shape = [1 for _ in input_shape] - dtype = meta["val"].dtype # Get dim from args based on argument type dim = get_node_arg(args, key=list, default_value=list(range(len(shape)))) @@ -92,18 +89,17 @@ def call_operator(self, op, args, kwargs, meta): for d in dim: N *= input_shape[d] - mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op) + mean_op, diff_op, mul_op, sum_op = get_var_decomposition(op) mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True) diff = super().call_operator(diff_op, (x, mean), {}, meta, True) squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True) sum = super().call_operator( sum_op, (squared_diff, dim, keepdim), {}, meta, True ) - full = super().call_operator( - full_op, - ([], 1 / max(0, N - correction)), - {"dtype": dtype, "device": x.data.device}, + return super().call_operator( + mul_op, + (sum, super().call_scalar(1 / max(0, N - correction), meta)), + {}, meta, True, ) - return super().call_operator(mul_op, (sum, full), {}, meta, True) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 5cf462626c1..0473caf91e7 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -31,6 +31,8 @@ class ScalarsToAttributePass(ArmPass): torch.ops.aten.rsub.Scalar, torch.ops.aten.mul.Tensor, torch.ops.aten.div.Tensor, + torch.ops.aten.div_.Tensor, + torch.ops.aten.div.Tensor_mode, ] def _convert_scalar_args( @@ -92,22 +94,20 @@ def _convert_scalar_args( sub.meta["val"] = n.meta["val"] graph_module.graph.erase_node(n) - def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None: + def handle_control_nodes(self, graph_module: GraphModule) -> None: """Apply scalar argument conversion on subgraphs of control-flow nodes. """ for _, submodule, _ in get_cond_while_submodules_nested(graph_module): for submodule_node in submodule.graph.nodes: - # use aten.full.default for scalar constants in control subgraphs self._convert_scalar_args(submodule, submodule_node) - graph_module.recompile() def call(self, graph_module: GraphModule) -> PassResult: # convert scalars in control-flow subgraphs and main graph for node in list(graph_module.graph.nodes): n = cast(Node, node) - self.handle_control_nodes(n, graph_module) self._convert_scalar_args(graph_module, n) + self.handle_control_nodes(graph_module) graph_module.recompile() graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True)