From 55fd7abd6c7c463866638c0b293c59eed171756e Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Fri, 11 Oct 2024 15:11:50 +0200
Subject: [PATCH] Fix scalar arithmetic and add test cases

Add UnsqueezeScalarPlaceholders pass to make scalars rank 1
Add MatchArgRanksPass to guarantee the same rank for all inputs to ops that require it.
Additional fixes to make Scalar tests pass
Map which cases work and which don't.

Signed-off-by: Erik Lundell
Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f

Fix shape issues

Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649
---
 .../annotate_channels_last_dim_order_pass.py |   8 +-
 backends/arm/_passes/arm_pass_manager.py     |   8 +-
 backends/arm/_passes/arm_pass_utils.py       |  20 +++
 backends/arm/_passes/decompose_div_pass.py   |   9 +-
 backends/arm/_passes/match_arg_ranks_pass.py | 126 ++++++++++++++
 .../arm/_passes/scalars_to_attribute_pass.py |   8 +-
 .../unsqueeze_scalar_placeholders_pass.py    |  53 ++++++
 .../quantization_annotation/mul_annotator.py |   2 +-
 .../quantization_annotation/sub_annotator.py |  15 +-
 backends/arm/test/misc/test_lifted_tensor.py |  93 +++++++++-
 backends/arm/test/ops/test_scalars.py        | 162 ++++++++++++++++++
 backends/arm/test/tester/arm_tester.py       |   4 +-
 12 files changed, 476 insertions(+), 32 deletions(-)
 create mode 100644 backends/arm/_passes/match_arg_ranks_pass.py
 create mode 100644 backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
 create mode 100644 backends/arm/test/ops/test_scalars.py

diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
index 222c0a7cb36..b4365bf75e3 100644
--- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
+++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -9,6 +9,7 @@ from typing import cast
 
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.tosa_quant_utils import dq_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
         NHWC_Order = (0, 2, 3, 1)
         HWCM_Order = (2, 3, 0, 1)
         for node in graph_module.graph.nodes:
-            if isinstance(
-                node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
-            ):
-                node_data = node.meta["val"][0].data
-            else:
-                node_data = node.meta["val"].data
+            node_data = get_first_fake_tensor(node).data
 
             if len(node_data.shape) == 4:
                 dim_order = NHWC_Order
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index c4e806a842d..3e061dbfeb0 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -22,6 +22,7 @@ from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,
 )
+from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -30,6 +31,9 @@
     ScalarsToAttributePass,
 )
 from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
+from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
+    UnsqueezeScalarPlaceholdersPass,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -45,10 +49,12 @@ def transform_to_backend_pipeline(
    ):
        """Apply passes before transforming program to backend"""
        self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
        self.add_pass(SizeAdjustConv2DPass())
        self.add_pass(RemoveClonePass())
        self.add_pass(ConvertExpandCopyToRepeatPass())
        self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(MatchArgRanksPass(exported_program))
        self.add_pass(DecomposeDivPass())
        self.add_pass(InsertSqueezeAfterSumPass())
        self.add_pass(ConvertSplitToSlicePass())
@@ -61,6 +67,6 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)
 
     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
-        self.add_pass(DecomposeDivPass())
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(DecomposeDivPass())
         return self._transform(graph_module)
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 34704d2cedf..0e74701ab6d 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -7,9 +7,11 @@ from typing import Optional
 
 import torch
+import torch.fx
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._ops import OpOverload
+from torch._subclasses.fake_tensor import FakeTensor
 
 
 def create_node(
@@ -64,3 +66,21 @@ def insert_q_dq_pair(
     # node's first use
     q.args = (anchor,) + q_params
     return dq
+
+
+def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
+    """
+    Returns a FakeTensor from the meta field of 'node'.
+    If node.meta["val"] contains multiple fake tensors, the first one is returned.
+    """
+    if isinstance(
+        node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        fake_tensor = node.meta["val"][0]
+    else:
+        fake_tensor = node.meta["val"]
+
+    assert isinstance(
+        fake_tensor, FakeTensor
+    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    return fake_tensor
diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py
index 13ee8d8dff7..5cdc79c1c3e 100644
--- a/backends/arm/_passes/decompose_div_pass.py
+++ b/backends/arm/_passes/decompose_div_pass.py
@@ -8,15 +8,18 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
+edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
+aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)
+
 
 def get_div_decomposition(op) -> tuple:
     """
     Returns the (reciprocal_op, mul_op) pair, where the ops depend on whether
     the div op is in exir_ops or torch.ops.aten.
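 
     E.g. get_div_decomposition(torch.ops.aten.div.Tensor) returns
     (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor), so that
     div(x, y) can be rewritten as mul(x, reciprocal(y)).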
""" - if op == exir_ops.edge.aten.div.Tensor: + if op in edge_div_ops: return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor) - if op == torch.ops.aten.div.Tensor: + if op in aten_div_ops: return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor) raise RuntimeError(f"Can't get div decomposition for op {op}") @@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass): """ def call_operator(self, op, args, kwargs, meta): - if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor): + if op not in (edge_div_ops + aten_div_ops): return super().call_operator(op, args, kwargs, meta) reciprocal_op, mul_op = get_div_decomposition(op) diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py new file mode 100644 index 00000000000..e0cbcf294f6 --- /dev/null +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + + +class MatchArgRanksPass(ExportPass): + """ + For ops in 'targeted_ops', make sure that the inputs share the same rank. + New dimensions are inserted at from the beginning of the + """ + + def __init__(self, exported_program): + super().__init__() + self.exported_program = exported_program + + targeted_ops = [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + ] + + def _match_op_rank(self, graph_module, node, arg, max_rank): + """ + In graph_module, insert a view between arg and node to make the + rank of arg match the other args to node. + """ + shape = get_first_fake_tensor(arg).shape + rank = len(shape) + new_shape = list([1] * (max_rank - rank) + list(shape)) + with graph_module.graph.inserting_before(node): + view = create_node( + graph_module.graph, + exir_ops.edge.aten.view_copy.default, + args=(arg, new_shape), + kwargs={}, + ) + node.replace_input_with(arg, view) + + def _match_buffer_rank(self, arg, max_rank): + """ + Change arg's fake tensor meta to match max_rank if: + - arg is found in inputs_to_buffers or inputs_to_parameters. 
+ """ + fake_tensor = get_first_fake_tensor(arg) + shape = fake_tensor.shape + rank = len(shape) + new_shape = list([1] * (max_rank - rank) + list(shape)) + + buffer_name = None + if arg.name in self.exported_program.graph_signature.inputs_to_buffers: + buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ + arg.name + ] + elif arg.name in self.exported_program.graph_signature.inputs_to_parameters: + buffer_name = self.exported_program.graph_signature.inputs_to_parameters[ + arg.name + ] + if buffer_name: + new_tensor = self.exported_program.state_dict[buffer_name].reshape( + new_shape + ) + self.exported_program.state_dict[buffer_name] = new_tensor + arg.meta["val"] = fake_tensor.fake_mode.from_tensor( + new_tensor, static_shapes=True + ) + + def call(self, graph_module: GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + node = cast(Node, node) + + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + + # Calculate max rank of all inputs to node + max_rank = 1 + for arg in node.args: + if isinstance(arg, Node): + shape = get_first_fake_tensor(arg).shape + max_rank = max(max_rank, len(shape)) + + # Adjust output shape of args if needed. + for arg in node.args: + if not isinstance(arg, Node): + continue + shape = get_first_fake_tensor(arg).shape + rank = len(shape) + if rank == max_rank: + continue + + # If the argument is call_function, match shape by inserting view node. + if arg.op == "call_function": + self._match_op_rank(graph_module, node, arg, max_rank) + else: + # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta. + self._match_buffer_rank(arg, max_rank) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) + + def ensures(self, graph_module): + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + arg0_rank = node.args[0].meta["val"].dim() + arg1_rank = node.args[1].meta["val"].dim() + if arg0_rank != arg1_rank: + raise ValueError( + "Arguments of arithmetic operators need to have the same rank!" 
+                )
diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py
index e9e547b9c96..f1c3297165f 100644
--- a/backends/arm/_passes/scalars_to_attribute_pass.py
+++ b/backends/arm/_passes/scalars_to_attribute_pass.py
@@ -7,7 +7,7 @@ from typing import cast, Union
 
 import torch
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
@@ -22,10 +22,14 @@ class ScalarsToAttributePass(ExportPass):
 
     targeted_ops = [
         torch.ops.aten.add.Tensor,
+        torch.ops.aten.add_.Tensor,
         torch.ops.aten.sub.Tensor,
         torch.ops.aten.sub_.Tensor,
+        torch.ops.aten.rsub.Scalar,
         torch.ops.aten.mul.Tensor,
+        torch.ops.aten.mul_.Tensor,
         torch.ops.aten.div.Tensor,
+        torch.ops.aten.div_.Tensor,
     ]
 
     def call(self, graph_module: GraphModule) -> PassResult:
@@ -37,7 +41,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             biggest_rank = 1
             for arg in n.args:
                 if isinstance(arg, Node):
-                    _, shape, _ = extract_tensor_meta(arg.meta)
+                    shape = get_first_fake_tensor(arg).shape
                     biggest_rank = max(biggest_rank, len(shape))
 
             new_args = []
diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
new file mode 100644
index 00000000000..ad9844b5269
--- /dev/null
+++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
@@ -0,0 +1,53 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class UnsqueezeScalarPlaceholdersPass(ExportPass):
+    """
+    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
+    This pass unsqueezes such placeholders to make sure the shape is at least (1,).
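+
+    E.g. a scalar placeholder backed by torch.tensor(2.0), with shape (),
+    is replaced by torch.tensor([2.0]) with shape (1,); both the state dict
+    entry and the fake tensor meta are updated.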
+ """ + + def __init__(self, exported_program): + self.exported_program = exported_program + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op != "placeholder": + continue + rank = node.meta["val"].dim() + if rank == 0: + if not ( + node.name in self.exported_program.graph_signature.inputs_to_buffers + or node.name + in self.exported_program.graph_signature.inputs_to_parameters + ): + continue + tensor = self.exported_program.state_dict[node.name] + if tensor.dim() == 0: + self.exported_program.state_dict[node.name] = tensor.unsqueeze(0) + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor.unsqueeze(0), static_shapes=True + ) + else: + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor, static_shapes=True + ) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) + + def ensures(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op == "placeholder": + rank = node.meta["val"].dim() + if rank == 0: + raise ValueError("Placeholders of rank 0 are not supported!") diff --git a/backends/arm/quantizer/quantization_annotation/mul_annotator.py b/backends/arm/quantizer/quantization_annotation/mul_annotator.py index 5df697f4b14..47190d380e0 100644 --- a/backends/arm/quantizer/quantization_annotation/mul_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mul_annotator.py @@ -24,7 +24,7 @@ def _annotate_mul( annotated_partitions = [] for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.mul.Tensor,): + if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor): continue mul_node = node annotated_partitions.append([mul_node]) diff --git a/backends/arm/quantizer/quantization_annotation/sub_annotator.py b/backends/arm/quantizer/quantization_annotation/sub_annotator.py index 92f1808d023..437f3e22e75 100644 --- a/backends/arm/quantizer/quantization_annotation/sub_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/sub_annotator.py @@ -6,8 +6,6 @@ # pyre-unsafe -import itertools -import operator from typing import Callable, List, Optional import torch @@ -16,7 +14,6 @@ from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @register_annotator("sub") @@ -25,14 +22,12 @@ def _annotate_sub( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - sub_partitions = get_source_partitions( - gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn - ) - sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values())) annotated_partitions = [] - for sub_partition in sub_partitions: - annotated_partitions.append(sub_partition.nodes) - sub_node = sub_partition.output_nodes[0] + for node in gm.graph.nodes: + if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor): + continue + annotated_partitions.append(node) + sub_node = node if arm_quantizer_utils.is_annotated(sub_node): continue diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 90aa7e2950c..29b2887431c 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ 
@@ -3,40 +3,119 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
 import unittest
+from typing import Union
 
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
 
 
 class LiftedTensor(torch.nn.Module):
-    def __init__(self):
+
+    test_data = [
+        # (operator, test_data) where test_data is (tensor, length)
+        (operator.add, (torch.randn(2, 2), 2)),
+        (operator.truediv, (torch.ones(2, 2), 2)),
+        (operator.mul, (torch.randn(2, 2), 2)),
+        (operator.sub, (torch.rand(2, 2), 2)),
+    ]
+
+    def __init__(self, op: callable):
         super().__init__()
+        self.op = op
         self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]])
 
     def forward(self, x: torch.Tensor, length) -> torch.Tensor:
         sliced = self.lifted_tensor[:, :length]
-        return sliced + x
+        return self.op(sliced, x)
+
+
+class LiftedScalarTensor(torch.nn.Module):
+    test_data = [
+        # (operator, test_data, arg1)
+        (operator.add, (torch.randn(2, 2),), 1.0),
+        (operator.truediv, (torch.randn(4, 2),), 1.0),
+        (operator.mul, (torch.randn(1, 2),), 2.0),
+        (operator.sub, (torch.randn(3),), 1.0),
+    ]
+
+    def __init__(self, op: callable, arg1: Union[int, float, torch.Tensor]):
+        super().__init__()
+        self.op = op
+        self.arg1 = arg1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.op(x, self.arg1)
 
 
 class TestLiftedTensor(unittest.TestCase):
     """Tests the ArmPartitioner with a placeholder of type lifted tensor."""
 
-    def test_partition_lifted_tensor(self):
+    @parameterized.expand(LiftedTensor.test_data)
+    def test_partition_lifted_tensor_tosa_MI(self, op, data):
         tester = (
             ArmTester(
-                LiftedTensor(),
-                example_inputs=(torch.ones(2, 2), 2),
+                LiftedTensor(op),
+                example_inputs=data,
                 compile_spec=common.get_tosa_compile_spec(),
             )
             .export()
             .to_edge()
-            .dump_artifact()
         )
         signature = tester.get_artifact().exported_program().graph_signature
         assert len(signature.lifted_tensor_constants) > 0
         tester.partition()
         tester.to_executorch()
-        tester.run_method_and_compare_outputs((torch.ones(2, 2), 2))
+        tester.run_method_and_compare_outputs(data)
+
+    @parameterized.expand(LiftedTensor.test_data)
+    def test_partition_lifted_tensor_tosa_BI(self, op, data):
+        tester = (
+            ArmTester(
+                LiftedTensor(op),
+                example_inputs=data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .to_edge()
+        )
+        signature = tester.get_artifact().exported_program().graph_signature
+        assert len(signature.lifted_tensor_constants) == 0
+        tester.partition()
+        tester.to_executorch()
+        tester.run_method_and_compare_outputs(data)
+
+    @parameterized.expand(LiftedScalarTensor.test_data)
+    def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1):
+        (
+            ArmTester(
+                LiftedScalarTensor(op, arg1),
+                example_inputs=data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .to_edge()
+            .partition()
+            .to_executorch()
+            .run_method_and_compare_outputs(data)
+        )
+
+    @parameterized.expand(LiftedScalarTensor.test_data)
+    def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1):
+        (
+            ArmTester(
+                LiftedScalarTensor(op, arg1),
+                example_inputs=data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .to_edge()
+            .partition()
+            .to_executorch()
+            .run_method_and_compare_outputs(data)
+        )
diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py
new file mode 100644
index 00000000000..154ca82022c
--- /dev/null
+++ b/backends/arm/test/ops/test_scalars.py
@@ -0,0 +1,162 @@
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
+
+"""
+Summary of non-working cases.
+MI:
+    Any case with an int scalar: a to_copy is inserted to cast the value, which we don't partition.
+    This makes the constant end up outside our partition and the input to the delegate becomes
+    a to_copy placeholder. In ArmTester, the placeholder is then interpreted as an input.
+    Potential fix: partition int -> float to_copy ops in ArmBackend.
+    # MLETORCH-407
+    Op(scalar, tensor):
+    One issue is that lift_constant_tensor_pass looks for a fake_tensor in the meta of the first
+    node, which does not work when the first node is a scalar.
+    Fixing that, the lowering fails since edge_program.graph_signature.inputs_to_buffers is changed from
+    {"_lifted_tensor_constant0":"_lifted_tensor_constant0"} to {"x":"_lifted_tensor_constant0"}
+    somewhere in _transform in the to_edge step. This makes ArmPartitioner miss tagging the
+    data in tag_constant_data.
+    # MLETORCH-408
+
+BI:
+    sub(Scalar, Tensor) becomes rsub, which either fails since the scalar does not become an attribute
+    in scalars_to_attribute_pass, or, if added to targeted_ops in that pass, fails since rsub expects a
+    Scalar.
+    Potential fix: create a pass to convert rsub.Scalar to sub.Tensor.
+"""
+
+
+class TestScalars(unittest.TestCase):
+    """Tests various scalar cases."""
+
+    class Add(torch.nn.Module):
+        def forward(self, x, y):
+            return x + y
+
+    class Sub(torch.nn.Module):
+        def forward(self, x, y):
+            return x - y
+
+    class Div(torch.nn.Module):
+        def forward(self, x, y):
+            return x / y
+
+    class Mul(torch.nn.Module):
+        def forward(self, x, y):
+            return x * y
+
+    class AddInplace(torch.nn.Module):
+        def forward(self, x, y):
+            x += y
+            return x
+
+    class SubInplace(torch.nn.Module):
+        def forward(self, x, y):
+            x -= y
+            return x
+
+    class DivInplace(torch.nn.Module):
+        def forward(self, x, y):
+            x /= y
+            return x
+
+    class MulInplace(torch.nn.Module):
+        def forward(self, x, y):
+            x *= y
+            return x
+
+    class AddConst(torch.nn.Module):
+        def forward(self, x):
+            x = 1.0 + x
+            return x
+
+    # Inplace ops end with '_' (from aten naming)
+    ops = [
+        ("Add", Add()),
+        ("Sub", Sub()),
+        ("Mul", Mul()),
+        ("Div", Div()),
+        ("Add_", AddInplace()),
+        ("Sub_", SubInplace()),
+        ("Mul_", MulInplace()),
+        ("Div_", DivInplace()),
+    ]
+
+    const_ops = [("Add", AddConst())]
+
+    dtypes = [("int", 3), ("float", 3.0)]
+    sizes = [("r1", (1,)), ("r4", (2, 4, 5, 3))]
+
+    # Create combinations of tests
+    tensor_scalar_tests = []
+    for op in ops:
+        for dtype in dtypes:
+            for size in sizes:
+                test_name = f"{op[0]}_{dtype[0]}_{size[0]}"
+                tensor = torch.rand(size[1])
+                scalar = dtype[1]
+                tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
+
+                # Don't add (scalar, tensor) test case for inplace ops.
+                if op[0][-1] == "_":
+                    continue
+
+                # sub(scalar, tensor) does not work in any case.
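+                # (in MI this is the Op(scalar, tensor) case above, MLETORCH-408;
+                # in BI it decomposes to rsub, see the BI note in the module summary)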
+                if op[0][0:3] == "Sub":
+                    continue
+                tensor_scalar_tests.append((test_name + "_st", op[1], scalar, tensor))
+
+    tensor_const_tests = []
+    for op in const_ops:
+        for size in sizes:
+            test_name = f"{op[0]}_{size[0]}"
+            tensor = torch.rand(size[1])
+            tensor_const_tests.append((test_name, op[1], tensor))
+
+    def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .to_edge()
+            .partition()
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .to_edge()
+            .partition()
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    # Most MI tests fail, just show one working for now.
+    @parameterized.expand((tensor_scalar_tests[6],))
+    def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
+        self._test_add_tosa_MI_pipeline(op, (x, y))
+
+    # op(Scalar float, tensor) works if the scalar is constant.
+    @parameterized.expand(tensor_const_tests)
+    def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
+        self._test_add_tosa_MI_pipeline(op, (x,))
+
+    @parameterized.expand(tensor_scalar_tests)
+    def test_BI(self, test_name: str, op: torch.nn.Module, x, y):
+        self._test_add_tosa_BI_pipeline(op, (x, y))
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 053ddc3a8ef..59d326109d3 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -293,9 +293,9 @@ def run_method_and_compare_outputs(
         test_input: list[torch.Tensor] = []
         for arg in reference_input:
             if isinstance(arg, torch.Tensor):
-                test_input.append(arg)
+                test_input.append(arg.clone())
             if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor):
-                test_input.extend(list(arg))
+                test_input.extend([tensor.clone() for tensor in arg])
 
         if (
             is_nhwc