From ed4b043edecb405dadef3b71fe10aa1cf19eab1a Mon Sep 17 00:00:00 2001 From: Agrima Khare Date: Tue, 7 Oct 2025 11:21:38 +0100 Subject: [PATCH] Arm Backend: Add support for copy.default Signed-off-by: Agrima Khare Change-Id: Ib344e18445c892983449b5183148a5d3892f38b6 --- backends/arm/_passes/remove_noop_pass.py | 3 + .../tosa_profile_supported_op_lists.py | 2 + .../arm/quantizer/quantization_annotator.py | 10 + backends/arm/test/ops/test_copy.py | 171 ++++++++++++++++++ 4 files changed, 186 insertions(+) create mode 100644 backends/arm/test/ops/test_copy.py diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 9758ac7ba24..8ac808809ef 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -25,6 +25,7 @@ def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.dim_order_ops._clone_dim_order.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.copy.default, ): return super().call_operator(op, args, kwargs, meta) @@ -34,4 +35,6 @@ def call_operator(self, op, args, kwargs, meta): if input_dtype != output_dtype: return super().call_operator(op, args, kwargs, meta) + if op == exir_ops.edge.aten.copy.default: + return args[1] return args[0] diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index 61f51165a33..5a83d4e241e 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -118,6 +118,7 @@ exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, exir_ops.edge.aten.bitwise_not.default, + exir_ops.edge.aten.copy.default, } @@ -233,6 +234,7 @@ exir_ops.edge.aten.logit.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.copy.default, } diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 99c18953efb..bd0fcd7d64f 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -574,6 +574,16 @@ def any_or_hardtanh_min_zero(n: Node): 0, SharedQuantizationSpec((input_node, node)), ) + elif node.target in [torch.ops.aten.copy_.default]: + input_node = ensure_type(Node, node.args[1]) + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, input_act_qspec), + ] + quant_properties.quant_output = _QuantProperty( + 0, + SharedQuantizationSpec((input_node, node)), + ) elif node.target in [ torch.ops.aten.eq.Tensor, torch.ops.aten.ge.Tensor, diff --git a/backends/arm/test/ops/test_copy.py b/backends/arm/test/ops/test_copy.py new file mode 100644 index 00000000000..95f13c4653f --- /dev/null +++ b/backends/arm/test/ops/test_copy.py @@ -0,0 +1,171 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + + +class CopyOutput(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x / x) + x + + +class CopyFirstArg(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x) + x + + +class CopySecondArg(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return x * y.copy_(x) + + +class CopyBothArgs(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x) + y.copy_(x) + + +class CopyAfterOtherOp(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + x = x * 2 + return y.copy_(x) + x + + +class CopyParallelToOtherOp(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return x * 2 + y.copy_(x) + + +test_suite = { + "copy_output": lambda: ( + CopyOutput, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_first_arg": lambda: ( + CopyFirstArg, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_second_arg": lambda: ( + CopySecondArg, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_both_args": lambda: ( + CopyBothArgs, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_after_other_op": lambda: ( + CopyAfterOtherOp, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_parallel_to_other_op": lambda: ( + CopyParallelToOtherOp, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), +} + + +aten_op = "torch.ops.aten.copy_.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_copy_default" + +input_t = Tuple[torch.Tensor] + + +@common.parametrize("input_data", test_suite) +def test_copy_tosa_FP(input_data): + module, input_tensor = input_data() + pipeline = TosaPipelineFP[input_t]( + module(), + input_tensor, + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +def test_copy_tosa_INT(input_data): + module, input_tensor = input_data() + + pipeline = TosaPipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +@common.XfailIfNoCorstone300 +def test_copy_u55_INT(input_data): + module, input_tensor = input_data() + + pipeline = EthosU55PipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +@common.XfailIfNoCorstone320 +def test_copy_u85_INT(input_data): + module, input_tensor = input_data() + + pipeline = EthosU85PipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_suite) +@common.SkipIfNoModelConverter +def test_copy_vgf_FP(test_data): + module, input_tensor = test_data() + pipeline = VgfPipeline[input_t]( + module(), + input_tensor, + aten_op=aten_op, + exir_op=exir_op, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.parametrize("test_data", test_suite) +@common.SkipIfNoModelConverter +def test_copy_vgf_INT(test_data): + module, input_tensor = test_data() + pipeline = VgfPipeline[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + tosa_version="TOSA-1.0+INT", + ) + pipeline.run()