From f8ed4450644964f6dc12d007cbd5e020db976c84 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 25 Sep 2025 14:25:32 +0200 Subject: [PATCH] Arm backend: Merge FP and INT pass pipelines Merges FP and INT pass pipelines into one pipeline. Signed-off-by: Oscar Andersson Change-Id: I42d228d54c46bd9e524c63431830e45a08052a46 --- .../arm/_passes/annotate_decomposed_matmul.py | 1 + backends/arm/_passes/arm_pass_manager.py | 104 ++++-------------- backends/arm/_passes/broadcast_args_pass.py | 4 + backends/arm/_passes/cast_to_int32_pass.py | 10 +- backends/arm/_passes/convert_elu_params.py | 7 ++ .../arm/_passes/convert_int_pow_to_mul.py | 8 ++ backends/arm/_passes/decompose_acosh_pass.py | 8 ++ .../_passes/decompose_asin_and_acos_pass.py | 9 ++ backends/arm/_passes/decompose_asinh_pass.py | 8 ++ backends/arm/_passes/decompose_atan_pass.py | 8 ++ backends/arm/_passes/decompose_atanh_pass.py | 8 ++ backends/arm/_passes/decompose_cosh_pass.py | 8 ++ backends/arm/_passes/decompose_elu_pass.py | 21 +++- backends/arm/_passes/decompose_expm1_pass.py | 8 ++ backends/arm/_passes/decompose_gelu_pass.py | 7 ++ .../decompose_int16_activation_conv2d_pass.py | 4 +- backends/arm/_passes/decompose_sinh_pass.py | 8 ++ backends/arm/_passes/decompose_sqrt_pass.py | 8 ++ .../fold_qdq_with_annotated_qparams_pass.py | 2 +- backends/arm/_passes/mm_to_bmm_pass.py | 16 --- .../replace_scalar_with_tensor_pass.py | 10 +- backends/arm/operators/op_tosa_rescale.py | 2 +- backends/arm/test/ops/test_pow.py | 10 +- .../test/passes/test_broadcast_args_pass.py | 1 + 24 files changed, 167 insertions(+), 113 deletions(-) diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index f378802d2c0..c8be7c7c04e 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -64,6 +64,7 @@ def call(self, graph_module: GraphModule) -> PassResult: ) matmul_targets = { exir_ops.edge.aten.bmm.default, + exir_ops.edge.aten.mm.default, } for partition in matmul_partitions: quantized_input = all( diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d0647967577..a086d23dc40 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -155,79 +155,26 @@ def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module - def _tosa_INT_pipeline( + def _tosa_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) - self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchArgDtypePass()) - if self.tosa_spec.is_U55_subset: - self.add_pass(CastToInt32Pass()) - - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(DecomposeGroupNormPass()) + self.add_pass(DecomposeLayerNormPass()) + self.add_pass(DecomposeBatchNormNoStatsPass()) + self.add_pass(DecomposeVarPass()) + self.add_pass( + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + ) self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(QuantizeOperatorArguments()) self.add_pass(ConvertELUParamsPass()) + self.add_pass(ConvertSplitToSlicePass()) + self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(FuseDuplicateUsersPass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) - if self.tosa_spec.is_U55_subset: - self.add_pass(BroadcastArgsPass()) - self.add_pass(DecomposeLinearPass()) - self.add_pass(DecomposeAdaptiveAvgPool2dPass()) - self.add_pass(DecomposeAvgPool2d()) - self.add_pass(ComputeConstantOpsAOT(exported_program)) - - self.add_pass(DecomposeGroupedConv()) - - self.add_pass(ConvertExpandCopyToRepeatPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(Conv1dUnsqueezePass()) - self.add_pass(DecomposeMaxPool2DPass()) - self.add_pass(SizeAdjustInputPass()) - self.add_pass(DecomposeSelectPass()) - self.add_pass(ConvertSqueezesToViewPass()) - - self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantArgsPass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) - # If we have a conv2d with int16 activation split up into a convolution - # and an addition, to work-around the lack of support for int48 in torch - # needs to happen before RewriteConv2dPass, but after the table ops are inserted - # to be able to validate that conv2d has right dtype arguments. - self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(RewriteConv2dPass(exported_program)) - - self.add_pass(RewriteMatmulPass()) - self.add_pass(RewriteUpsamplePass()) - self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - - self.add_pass(InsertRescaleInt32Pass()) - self.add_pass(DecomposeSumPass()) - self.add_pass(ToTosaMemoryFormatPass(exported_program)) - self.add_pass(RemoveNoopPass()) - self.add_pass(InsertRescalePass()) - - self.validate_constraints_mandatory() - return self._transform(graph_module) - - def _tosa_FP_pipeline( - self, exported_program: ExportedProgram, graph_module: GraphModule - ) -> GraphModule: - self.add_pass(AnnotateOutputDimOrderPass()) - self.add_pass(FuseDuplicateUsersPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) self.add_pass(DecomposeMaskedFill()) @@ -252,32 +199,20 @@ def _tosa_FP_pipeline( self.add_pass(DecomposeRemainderPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(FuseQuantizedActivationPass()) - self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertSplitToSlicePass()) self.add_pass(FuseBatchnorm2DPass(exported_program)) self.add_pass(ConvertMmToBmmPass()) self.add_pass(DecomposeGluPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeBatchNormNoStatsPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeAddSubAlphaPass()) self.add_pass(DecomposeSoftmaxPass()) self.add_pass(DecomposeGeluPass()) self.add_pass(ConvertFullLikeToFullPass()) - self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) self.add_pass(ConvertAnyDefaultDimDimsPass()) self.add_pass(MatchArgDtypePass()) - self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeAdaptiveAvgPool2dPass()) @@ -290,22 +225,26 @@ def _tosa_FP_pipeline( self.add_pass(DecomposeGroupedConv()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(DecomposeSumPass()) self.add_pass(DecomposeCumsumPass(exported_program)) self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeMaxPool2DPass()) self.add_pass(SizeAdjustInputPass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) + self.add_pass(CastToInt32Pass()) + self.add_pass(BroadcastArgsPass()) self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) - self.add_pass(RewriteConv2dPass(exported_program)) + self.add_pass(DecomposeConv2dWithInt16ActivationPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(RewriteUpsamplePass()) self.add_pass(InsertTableOpsPass(exported_program)) + self.add_pass(RewriteUpsamplePass()) + self.add_pass(RewriteConv2dPass(exported_program)) self.add_pass(RewriteMatmulPass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) + self.add_pass(InsertRescaleInt32Pass()) + self.add_pass(DecomposeSumPass()) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) @@ -317,10 +256,11 @@ def transform_to_backend_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ): """Apply passes before transforming program to backend""" - if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): - return self._tosa_FP_pipeline(exported_program, graph_module) - elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): - return self._tosa_INT_pipeline(exported_program, graph_module) + if self.tosa_spec in ( + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ): + return self._tosa_pipeline(exported_program, graph_module) else: raise NotImplementedError( f"No pass pipeline implemented for {self.tosa_spec=}" diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py index 659e6aca686..131b749b702 100644 --- a/backends/arm/_passes/broadcast_args_pass.py +++ b/backends/arm/_passes/broadcast_args_pass.py @@ -11,6 +11,7 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops @@ -34,6 +35,9 @@ class BroadcastArgsPass(ArmPass): } def call(self, graph_module: GraphModule) -> PassResult: + tosa_spec = get_context_spec() + if not tosa_spec.is_U55_subset: + return PassResult(graph_module, False) for node in graph_module.graph.nodes: if node.op != "call_function" or node.target not in self.targeted_ops: continue diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index db626bf5695..40f7e347b0f 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -8,8 +8,10 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult class CastToInt32Pass(ArmPass): @@ -22,6 +24,12 @@ class CastToInt32Pass(ArmPass): exir_ops.edge.aten.bitwise_right_shift.Tensor, } + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + tosa_spec = get_context_spec() + if not tosa_spec.is_U55_subset: + return PassResult(graph_module, False) + return super().call(graph_module) + def call_operator(self, op, args, kwargs, meta): if op not in self.targeted_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 86c1c52c5b7..6225bf92707 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -8,6 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm.constants import DQ_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -30,6 +31,12 @@ def call(self, graph_module: torch.fx.GraphModule): op="call_function", target=exir_ops.edge.aten.elu.default ) for node in node_list: + input_node = node.all_input_nodes[0] + is_quantized = ( + input_node.op == "call_function" and input_node.target in DQ_OPS + ) + if not is_quantized: + continue with graph.inserting_after(node): replace_node = create_node(graph, exir_ops.edge.aten.elu.default) old_args = list(node.args) diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/convert_int_pow_to_mul.py index 2d8c72748a2..e2c8bd0c4d6 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/convert_int_pow_to_mul.py @@ -24,6 +24,14 @@ def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + x = args[0] exp = args[1] diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 8b10cccb913..1d29986433b 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -41,6 +41,14 @@ def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_acosh_op: return super().call_operator(op, args, kwargs, meta, updated) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta, updated) + log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = ( exir_ops.edge.aten.log.default, exir_ops.edge.aten.sqrt.default, diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 3a0f87af835..e0da9eb9014 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -123,6 +123,15 @@ def _combine_branches( def call_operator(self, op, args, kwargs, meta): if op not in (edge_asin_op + edge_acos_op): return super().call_operator(op, args, kwargs, meta) + + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + logging.info( f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}." ) diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 7ffe75cd255..1131feea9c6 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -40,6 +40,14 @@ def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + log_op, sqrt_op, mul_op, add_op_scalar, add_op = ( exir_ops.edge.aten.log.default, exir_ops.edge.aten.sqrt.default, diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 6f1adccd257..a3b4081755a 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -80,6 +80,14 @@ def call_operator(self, op, args, kwargs, meta): if op is not edge_atan: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + logging.info( f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}." ) diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index 1a41e77eacc..789dafed9ef 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -50,6 +50,14 @@ def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ops = _get_atanh_ops(op) ( op_mul_tensor, diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index 6716ba499ad..fe84f2bde9b 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -39,6 +39,14 @@ def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: return super().call_operator(op, args, kwargs, meta, updated) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + x = args exp_op, mul_op, neg_op, add_op = ( diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index ba3d32b7529..5428465c619 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -64,6 +64,14 @@ def call_operator(self, op, args, kwargs, meta): if op not in edge_elu_ops: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ( expm1_op, ge_op, @@ -75,8 +83,17 @@ def call_operator(self, op, args, kwargs, meta): alpha = args[1] if len(args) > 1 else 1.0 if alpha == 0: - relu_op = exir_ops.edge.aten.relu.default - return super().call_operator(relu_op, (input,), {}, meta, updated=True) + relu_op = exir_ops.edge.aten.clamp.default + return super().call_operator( + relu_op, + ( + input, + 0, + ), + {}, + meta, + updated=True, + ) expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True) mul_node = super().call_operator( diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 0fe95d37ba2..09a891c34dc 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -92,6 +92,14 @@ def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ( op_pow, op_div, diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 532f5d859fe..2a25e6dbb6d 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -94,6 +94,13 @@ class DecomposeGeluPass(ArmPass): def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index 06da5bbab22..ac4f271b744 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -44,9 +44,7 @@ def call_operator(self, op, args, kwargs, meta): "int16 activation for convolution requires TOSA int16 extension" ) else: - raise NotImplementedError( - "Decomposition to conv+add only implemented for activation of int16 type" - ) + return super().call_operator(op, args, kwargs, meta) # convolution with bias and activation is int16 bias = args[2] diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 772cc7c4741..731f9a5dbf3 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -44,6 +44,14 @@ def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + x = args sub_op, exp_op, neg_op, mul_op = ( diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 50731388fed..6d78c70634f 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -38,6 +38,14 @@ def call_operator(self, op, args, kwargs, meta): if op not in (edge_sqrt_ops + aten_sqrt_ops): return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + pow_op = get_sqrt_decomposition(op) return super().call_operator(pow_op, (args[0], 0.5), {}, meta, updated=True) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 52e96878042..2a0e889f87c 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -152,7 +152,7 @@ def fold_and_annotate_arg( if len(n.users) == 0: graph_module.graph.erase_node(n) - def call(self, graph_module: GraphModule) -> PassResult: + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 48dbde43802..9ff15e2850b 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -12,15 +12,10 @@ from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, - insert_q_dq_pair, ) from executorch.backends.arm._passes.convert_squeezes_to_view import ( ConvertSqueezesToViewPass, ) -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - FoldAndAnnotateQParamsPass, -) -from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node @@ -38,7 +33,6 @@ class ConvertMmToBmmPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { ConvertSqueezesToViewPass, - FoldAndAnnotateQParamsPass, } def call(self, graph_module: torch.fx.GraphModule): @@ -68,11 +62,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) node.replace_input_with(input_node, unsqueeze_before) - # If Quantized we must insert unsqueeze --> q --> dq --> node - if input_node.target in DQ_OPS: - q_params = input_node.args[1:] - insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) - # Replace mm node with bmm with graph.inserting_before(node): bmm_node = create_node( @@ -101,11 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule): for user in original_users: user.replace_input_with(bmm_node, squeeze_after) - # If quantized, insert mm --> q --> dq --> squeeze - if all(original_user.target in Q_OPS for original_user in original_users): - q_params = original_users[0].args[1:] - insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) - modified_graph = True if modified_graph: diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index f5ab5f633ba..3a3ae0d5081 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -94,11 +94,13 @@ def __init__(self): def call_operator(self, op, args, kwargs, meta): tosa_spec = get_context_spec() + included_ops = {} if tosa_spec.support_integer(): - included_ops = _int_profile_ops - elif tosa_spec.support_float(): - included_ops = _fp_profile_ops - else: + included_ops |= _int_profile_ops + if tosa_spec.support_float(): + included_ops |= _fp_profile_ops + + if included_ops == {}: raise ValueError("Profile must support either INT or FP") if op in included_ops: diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 75268938579..feb7d1ef28a 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -53,7 +53,7 @@ def define_node( and input_zp != 0 ): raise ValueError( - f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" + f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype {input_dtype=}, {input_zp=}" ) if output_dtype not in [torch.int8, torch.int16] and output_zp != 0: raise ValueError( diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index 14fb05109cc..c11bc985101 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -118,11 +118,12 @@ def test_pow_tensor_tensor_vgf_FP(test_data: Pow_TensorTensor.input_t): x_fail = { "exp_two": "TOSA constraints: If x <0 .", - "non_neg_base_exp_pos_decimal": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", } -@common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) +@common.parametrize( + "test_data", Pow_TensorScalar.test_data, xfails=x_fail, strict=False +) def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): base, exp = test_data() pipeline = TosaPipelineFP[Pow_TensorScalar.input_t]( @@ -186,7 +187,10 @@ def test_pow_tensor_scalar_vgf_FP(test_data: Pow_TensorScalar.input_t): pipeline.run() -@common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) +@common.parametrize( + "test_data", + Pow_TensorScalar.test_data, +) @common.SkipIfNoModelConverter def test_pow_tensor_scalar_vgf_INT(test_data: Pow_TensorScalar.input_t): base, exp = test_data() diff --git a/backends/arm/test/passes/test_broadcast_args_pass.py b/backends/arm/test/passes/test_broadcast_args_pass.py index 719a0ddd622..d026997ca57 100644 --- a/backends/arm/test/passes/test_broadcast_args_pass.py +++ b/backends/arm/test/passes/test_broadcast_args_pass.py @@ -50,5 +50,6 @@ def test_multiple_broacasts_model(module: NeedsMultipleBroadcastsModel): ops_not_before_pass=ops_not_before_pass, ops_after_pass=ops_after_pass, pass_list=[BroadcastArgsPass], + tosa_extensions=["u55"], ) pipeline.run()