diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 0c3d4f31aac..3097d641978 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,6 +18,9 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, ) +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) from executorch.backends.arm._passes.convert_split_to_slice import ( ConvertSplitToSlicePass, ) @@ -95,6 +98,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertMmToBmmPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(ConvertMeanDimToAveragePoolPass()) + self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) @@ -133,7 +137,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertMeanDimToAveragePoolPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxesPass()) - + self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py new file mode 100644 index 00000000000..234e2ecda82 --- /dev/null +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -0,0 +1,33 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class ConvertFullLikeToFullPass(ExportPass): + """As per the full_like pytorch documentation, + `torch.full_like(input, fill_value)` is equivalent to + `torch.full(input.size(), + fill_value, + dtype=input.dtype, + layout=input.layout, + device=input.device + )` + Skip layout and device since it's not relevant for our backend. + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in [ + exir_ops.edge.aten.full_like.default, + ]: + return super().call_operator(op, args, kwargs, meta) + + tensor = args[0].data + full_args = (list(tensor.shape), args[1]) + full_kwargs = {"dtype": tensor.dtype} + return super().call_operator( + exir_ops.edge.aten.full.default, full_args, full_kwargs, meta + ) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 43ab9ea10b5..dd092968764 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -105,6 +105,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.le.Tensor, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 32f64963e87..f6f6221510f 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -134,6 +134,7 @@ def _match_pattern( torch.ops.aten.sum.dim_IntList, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, + torch.ops.aten.full_like.default, ] _one_to_one_shared_input_qspec = [ @@ -379,3 +380,11 @@ def annotate_graph( # type: ignore[return] _annotate_output(node, quant_properties.quant_output) arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined] + + # Quantization does not allow kwargs for some reason. + # Remove from ops we know have and where we know it does not break anything. + if node.target in [ + torch.ops.aten.full_like.default, + torch.ops.aten.full.default, + ]: + node.kwargs = {} diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index e3be7811dd1..05f1563f46c 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -27,7 +27,6 @@ class TestConformer(unittest.TestCase): # .to_executorch step, i.e. after Arm partitioner. ops_after_partitioner = { "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, - "executorch_exir_dialects_edge__ops_aten_full_like_default": 4, "executorch_exir_dialects_edge__ops_aten_max_default": 1, "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4, "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2, diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 586e6bd4db2..8347d01be4c 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -36,8 +35,8 @@ def forward(self, x: torch.Tensor): return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x class AddVariableFull(torch.nn.Module): - sizes = [ - (5), + sizes: list[tuple[int, ...]] = [ + (5,), (5, 5), (5, 5, 5), (1, 5, 5, 5), @@ -48,6 +47,21 @@ def forward(self, x: torch.Tensor, y): # Input + a full with the shape from the input and a given value 'y'. return x + torch.full(x.shape, y) + class FullLike(torch.nn.Module): + """Since full_like is replaced with full, we only need to test on reference model, not FVP.""" + + test_parameters = [ + ((torch.randn(2, 2, 2, 2) * 50, 3.2),), + ((torch.randn(2, 2, 2, 2) * 50, 3),), + (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3.2),), + (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3),), + ] + + def forward(self, input_tensor: torch.Tensor, value): + # Our backend can't handle tensors without users, which input_tensor doesn't have + # when the full_like is converted to a full. Therefore involve it in the output. + return input_tensor + torch.full_like(input_tensor, value) + def _test_full_tosa_MI_pipeline( self, module: torch.nn.Module, @@ -63,9 +77,7 @@ def _test_full_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() - .check_count({"torch.ops.aten.full.default": 1}) - .to_edge() - .partition() + .to_edge_transform_and_lower() .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() @@ -85,9 +97,7 @@ def _test_full_tosa_BI_pipeline( ) .quantize() .export() - .check_count({"torch.ops.aten.full.default": 1}) - .to_edge() - .partition() + .to_edge_transform_and_lower() .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() @@ -101,9 +111,7 @@ def _test_full_tosa_ethos_pipeline( ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) .quantize() .export() - .check_count({"torch.ops.aten.full.default": 1}) - .to_edge() - .partition() + .to_edge_transform_and_lower() .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() @@ -129,6 +137,10 @@ def test_const_full_tosa_MI(self): _input = torch.rand((2, 2, 3, 3)) * 10 self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,)) + @parameterized.expand(FullLike.test_parameters) + def test_full_like_tosa_MI(self, test_tensor: Tuple): + self._test_full_tosa_MI_pipeline(self.FullLike(), test_tensor) + def test_const_full_nhwc_tosa_BI(self): _input = torch.rand((2, 2, 3, 3)) * 10 self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,)) @@ -143,6 +155,10 @@ def test_full_tosa_MI(self, test_tensor: Tuple): def test_full_tosa_BI(self, test_tensor: Tuple): self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor) + @parameterized.expand(FullLike.test_parameters) + def test_full_like_tosa_BI(self, test_tensor: Tuple): + self._test_full_tosa_BI_pipeline(self.FullLike(), test_tensor) + @parameterized.expand(AddVariableFull.test_parameters) @pytest.mark.corstone_fvp def test_full_u55_BI(self, test_tensor: Tuple): diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 79a15f55383..800dfb8d6d4 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -61,7 +61,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="v0.80.1" +tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a" # vela vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"