Skip to content

Commit 2afccbb

Browse files
committed
Arm Backend: Add support for copy.default
Signed-off-by: Agrima Khare <agrima.khare@arm.com> Change-Id: Ib344e18445c892983449b5183148a5d3892f38b6
1 parent 5cf193a commit 2afccbb

File tree

4 files changed

+186
-0
lines changed

4 files changed

+186
-0
lines changed

backends/arm/_passes/remove_noop_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def call_operator(self, op, args, kwargs, meta):
2525
if op not in (
2626
exir_ops.edge.dim_order_ops._clone_dim_order.default,
2727
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
28+
exir_ops.edge.aten.copy.default,
2829
):
2930
return super().call_operator(op, args, kwargs, meta)
3031

@@ -34,4 +35,6 @@ def call_operator(self, op, args, kwargs, meta):
3435
if input_dtype != output_dtype:
3536
return super().call_operator(op, args, kwargs, meta)
3637

38+
if op == exir_ops.edge.aten.copy.default:
39+
return args[1]
3740
return args[0]

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
exir_ops.edge.aten.acos.default,
119119
exir_ops.edge.aten.elu.default,
120120
exir_ops.edge.aten.bitwise_not.default,
121+
exir_ops.edge.aten.copy.default,
121122
}
122123

123124

@@ -233,6 +234,7 @@
233234
exir_ops.edge.aten.logit.default,
234235
exir_ops.edge.aten.acos.default,
235236
exir_ops.edge.aten.elu.default,
237+
exir_ops.edge.aten.copy.default,
236238
}
237239

238240

backends/arm/quantizer/quantization_annotator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,16 @@ def any_or_hardtanh_min_zero(n: Node):
573573
0,
574574
SharedQuantizationSpec((input_node, node)),
575575
)
576+
elif node.target in [torch.ops.aten.copy_.default]:
577+
input_node = ensure_type(Node, node.args[1])
578+
quant_properties.quant_inputs = [
579+
_QuantProperty(0, input_act_qspec),
580+
_QuantProperty(1, input_act_qspec),
581+
]
582+
quant_properties.quant_output = _QuantProperty(
583+
0,
584+
SharedQuantizationSpec((input_node, node)),
585+
)
576586
elif node.target in [
577587
torch.ops.aten.eq.Tensor,
578588
torch.ops.aten.ge.Tensor,

backends/arm/test/ops/test_copy.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
14+
EthosU85PipelineINT,
15+
TosaPipelineFP,
16+
TosaPipelineINT,
17+
VgfPipeline,
18+
)
19+
20+
21+
class CopyOutput(torch.nn.Module):
22+
def forward(self, x):
23+
y = torch.zeros(x.shape)
24+
return y.copy_(x / x) + x
25+
26+
27+
class CopyFirstArg(torch.nn.Module):
28+
def forward(self, x):
29+
y = torch.zeros(x.shape)
30+
return y.copy_(x) + x
31+
32+
33+
class CopySecondArg(torch.nn.Module):
34+
def forward(self, x):
35+
y = torch.zeros(x.shape)
36+
return x * y.copy_(x)
37+
38+
39+
class CopyBothArgs(torch.nn.Module):
40+
def forward(self, x):
41+
y = torch.zeros(x.shape)
42+
return y.copy_(x) + y.copy_(x)
43+
44+
45+
class CopyAfterOtherOp(torch.nn.Module):
46+
def forward(self, x):
47+
y = torch.zeros(x.shape)
48+
x = x * 2
49+
return y.copy_(x) + x
50+
51+
52+
class CopyParallelToOtherOp(torch.nn.Module):
53+
def forward(self, x):
54+
y = torch.zeros(x.shape)
55+
return x * 2 + y.copy_(x)
56+
57+
58+
test_suite = {
59+
"copy_output": lambda: (
60+
CopyOutput,
61+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
62+
),
63+
"copy_first_arg": lambda: (
64+
CopyFirstArg,
65+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
66+
),
67+
"copy_second_arg": lambda: (
68+
CopySecondArg,
69+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
70+
),
71+
"copy_both_args": lambda: (
72+
CopyBothArgs,
73+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
74+
),
75+
"copy_after_other_op": lambda: (
76+
CopyAfterOtherOp,
77+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
78+
),
79+
"copy_parallel_to_other_op": lambda: (
80+
CopyParallelToOtherOp,
81+
(torch.rand(1, 2, 3, 4, dtype=torch.float32),),
82+
),
83+
}
84+
85+
86+
aten_op = "torch.ops.aten.copy_.default"
87+
exir_op = "executorch_exir_dialects_edge__ops_aten_copy_default"
88+
89+
input_t = Tuple[torch.Tensor]
90+
91+
92+
@common.parametrize("input_data", test_suite)
93+
def test_copy_tosa_FP(input_data):
94+
module, input_tensor = input_data()
95+
pipeline = TosaPipelineFP[input_t](
96+
module(),
97+
input_tensor,
98+
aten_op=aten_op,
99+
exir_op=exir_op,
100+
)
101+
pipeline.run()
102+
103+
104+
@common.parametrize("input_data", test_suite)
105+
def test_copy_tosa_INT(input_data):
106+
module, input_tensor = input_data()
107+
108+
pipeline = TosaPipelineINT[input_t](
109+
module(),
110+
input_tensor,
111+
aten_op,
112+
exir_op,
113+
)
114+
pipeline.run()
115+
116+
117+
@common.parametrize("input_data", test_suite)
118+
@common.XfailIfNoCorstone300
119+
def test_copy_u55_INT(input_data):
120+
module, input_tensor = input_data()
121+
122+
pipeline = EthosU55PipelineINT[input_t](
123+
module(),
124+
input_tensor,
125+
aten_op,
126+
exir_op,
127+
)
128+
pipeline.run()
129+
130+
131+
@common.parametrize("input_data", test_suite)
132+
@common.XfailIfNoCorstone320
133+
def test_copy_u85_INT(input_data):
134+
module, input_tensor = input_data()
135+
136+
pipeline = EthosU85PipelineINT[input_t](
137+
module(),
138+
input_tensor,
139+
aten_op,
140+
exir_op,
141+
)
142+
143+
pipeline.run()
144+
145+
146+
@common.parametrize("test_data", test_suite)
147+
@common.SkipIfNoModelConverter
148+
def test_copy_vgf_FP(test_data):
149+
module, input_tensor = test_data()
150+
pipeline = VgfPipeline[input_t](
151+
module(),
152+
input_tensor,
153+
aten_op=aten_op,
154+
exir_op=exir_op,
155+
tosa_version="TOSA-1.0+FP",
156+
)
157+
pipeline.run()
158+
159+
160+
@common.parametrize("test_data", test_suite)
161+
@common.SkipIfNoModelConverter
162+
def test_copy_vgf_INT(test_data):
163+
module, input_tensor = test_data()
164+
pipeline = VgfPipeline[input_t](
165+
module(),
166+
input_tensor,
167+
aten_op,
168+
exir_op,
169+
tosa_version="TOSA-1.0+INT",
170+
)
171+
pipeline.run()

0 commit comments

Comments
 (0)