49 changes: 32 additions & 17 deletions backends/arm/_passes/convert_expand_copy_to_repeat.py
@@ -8,12 +8,43 @@
import logging
from typing import cast

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

logger = logging.getLogger(__name__)


def calculate_multiples(args):
input_node_or_tensor = args[0]

if isinstance(input_node_or_tensor, torch.fx.node.Node):
input_data = input_node_or_tensor.meta["val"]
else:
input_data = input_node_or_tensor.data

input_shape = input_data.shape

multiples = cast(list[int], args[1])
expanded_rank = len(multiples)

# Expanded shape is 'input_shape' front-padded with ones.
padding = expanded_rank - len(input_shape)
extended_shape = [
input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
]

# To convert expand arg to repeat arg, non-repeated dims should have
# multiples[dim] = 1. Passing -1 to expand arg means
# not changing the size of that dimension.
multiples = [
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
for i in range(expanded_rank)
]
return multiples


class ConvertExpandCopyToRepeatPass(ExportPass):
"""
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
@@ -26,23 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
if op != self.expand_copy:
return super().call_operator(op, args, kwargs, meta)

input_shape = args[0].data.shape
multiples = cast(list[int], args[1])
expanded_rank = len(multiples)

# Expanded shape is 'input_shape' front-padded with ones.
padding = expanded_rank - len(input_shape)
extended_shape = [
input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
]

# To convert expand arg to repeat arg, non-repeated dims should have
# multiples[dim] = 1. Passing -1 to expand arg means
# not changing the size of that dimension.
multiples = [
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
for i in range(expanded_rank)
]
multiples = calculate_multiples(args)

if all((x == 1 for x in multiples)):
# All dimensions/repetitions occur only once. Remove node
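For intuition, here is a minimal standalone sketch of the conversion the new calculate_multiples helper performs, using plain Python lists instead of FX node metadata; the function name expand_to_repeat_multiples and the example shapes are illustrative assumptions, not part of the change:

from typing import List

def expand_to_repeat_multiples(input_shape: List[int], multiples: List[int]) -> List[int]:
    """Turn aten.expand sizes into aten.repeat multiples."""
    expanded_rank = len(multiples)

    # Front-pad the input shape with ones so both lists have the same rank.
    padding = expanded_rank - len(input_shape)
    extended_shape = [1] * padding + list(input_shape)

    # Only singleton dims can actually repeat; -1 in the expand sizes means
    # "keep this dimension's size", so it never becomes a repeat.
    return [
        multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
        for i in range(expanded_rank)
    ]

# Expanding a (1, 3) tensor to sizes (4, 2, -1, 3) repeats only the padded singleton dims:
print(expand_to_repeat_multiples([1, 3], [4, 2, -1, 3]))  # [4, 2, 1, 1]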
10 changes: 10 additions & 0 deletions backends/arm/_passes/remove_clone_pass.py
@@ -6,9 +6,13 @@

# pyre-unsafe

import logging

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

logger = logging.getLogger(__name__)


class RemoveClonePass(ExportPass):
"""Remove all clones from graph_module"""
@@ -21,4 +25,10 @@ def call_operator(self, op, args, kwargs, meta):
raise ValueError(
f"clone operator expects exactly one argument, got {len(args)}"
)

if "memory_format" in kwargs:
logger.warning(
f"Removing clone with memory_format '{kwargs['memory_format']}'."
)

return args[0]
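As a rough illustration of what the pass does, the sketch below performs the analogous rewrite on a plain torch.fx trace rather than the EXIR edge dialect that RemoveClonePass actually operates on; the module M and the print standing in for logger.warning are illustrative assumptions, not part of the change:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.clone(memory_format=torch.contiguous_format) + 1

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "clone":
        if "memory_format" in node.kwargs:
            # The pass warns here before the memory_format argument is dropped.
            print(f"Removing clone with memory_format '{node.kwargs['memory_format']}'.")
        # Rewire users of the clone to its input, then drop the node.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # the generated forward now computes x + 1 with the clone removed

After the rewrite the graph computes x + 1 directly; the warning fires only when a memory_format argument would otherwise be silently discarded.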
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -6,6 +6,7 @@
# pyre-unsafe

from . import ( # noqa
clone_support,
convolution_support,
embedding_support,
ethos_u55_support,
37 changes: 37 additions & 0 deletions backends/arm/operator_support/clone_support.py
@@ -0,0 +1,37 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)


@register_tosa_support_check
class CloneSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.clone.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:

input_node = node.args[0]
if not isinstance(input_node, fx.Node):
self.reporter.report_reject(node, "Non tensor clones are not supported")
return False

return True
@@ -228,7 +228,6 @@ def is_node_supported(
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.var.dim,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.pow.Tensor_Scalar,
1 change: 1 addition & 0 deletions backends/arm/scripts/parse_test_names.py
@@ -25,6 +25,7 @@
"unflatten.int",
"_native_batch_norm_legit_no_training.default",
"_native_batch_norm_legit.no_stats",
"alias_copy.default",
]
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

4 changes: 3 additions & 1 deletion backends/arm/test/ops/test_alias_copy.py
@@ -41,7 +41,9 @@ def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor):
return torch.alias_copy(x)
return (
torch.alias_copy(x) * 1
) # Multiply by one to make sure it is partitioned.


@common.parametrize("test_data", AliasCopy.test_data)
122 changes: 69 additions & 53 deletions backends/arm/test/ops/test_clone.py
@@ -3,13 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the clone op which copies the data of the input tensor (possibly with new data format)
#

from typing import Tuple

import pytest
import torch

from executorch.backends.arm.test import common
@@ -28,57 +24,82 @@
input_t = Tuple[torch.Tensor]


class Clone(torch.nn.Module):
"""A simple module that clones an input tensor."""
class CloneFirstArg(torch.nn.Module):
def forward(self, x):
return x.clone() + x

def forward(self, x: torch.Tensor):
return x.clone()

class CloneSecondArg(torch.nn.Module):
def forward(self, x):
return x * x.clone()


class CloneOutput(torch.nn.Module):
def forward(self, x):
return (x / x).clone()


class CloneBothArgs(torch.nn.Module):
def forward(self, x):
return x.clone() + x.clone()


class CloneAfterOtherOp(torch.nn.Module):
def forward(self, x):
x = x * 2
return x.clone() + x


class CloneParallelToOtherOp(torch.nn.Module):
def forward(self, x):
return x * 2 + x.clone()

test_data_suite = {
"ones_1D_10": lambda: (torch.ones(10),),
"ones_1D_50": lambda: (torch.ones(50),),
"rand_1D_20": lambda: (torch.rand(20),),
"rand_2D_10x10": lambda: (torch.rand(10, 10),),
"rand_3D_5x5x5": lambda: (torch.rand(5, 5, 5),),
"rand_4D_2x3x4x5": lambda: (torch.rand(2, 3, 4, 5),),
"large_tensor": lambda: (torch.rand(1000),),
}

delegated_clones = {
"clone_first_arg": lambda: (CloneFirstArg, (torch.rand(1, 2, 3, 4),)),
"clone_second_arg": lambda: (CloneSecondArg, (torch.rand(1, 2, 3, 4),)),
"clone_output": lambda: (CloneOutput, (torch.rand(1, 2, 3, 4),)),
"clone_both_args": lambda: (CloneBothArgs, (torch.rand(1, 2, 3, 4),)),
"clone_after_other_op": lambda: (CloneAfterOtherOp, (torch.rand(1, 2, 3, 4),)),
"clone_parallel_to_other_op": lambda: (
CloneParallelToOtherOp,
(torch.rand(1, 2, 3, 4),),
),
}

@common.parametrize("test_data", test_data_suite)
def test_clone_tosa_FP(test_data: Tuple[torch.Tensor]):

@common.parametrize("input_data", delegated_clones)
def test_clone_tosa_FP(input_data):
module, input_tensor = input_data()
pipeline = TosaPipelineFP[input_t](
Clone(),
test_data(),
aten_op,
exir_op,
module(),
input_tensor,
[],
)

pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_clone_tosa_INT(test_data):
@common.parametrize("input_data", delegated_clones)
def test_clone_tosa_INT(input_data):
module, input_tensor = input_data()

pipeline = TosaPipelineINT[input_t](
Clone(),
test_data(),
module(),
input_tensor,
aten_op,
exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.parametrize("input_data", delegated_clones)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
)
def test_clone_u55_INT(test_data):
def test_clone_u55_INT(input_data):
module, input_tensor = input_data()

pipeline = EthosU55PipelineINT[input_t](
Clone(),
test_data(),
module(),
input_tensor,
aten_op,
exir_op,
run_on_fvp=True,
@@ -87,15 +108,14 @@ def test_clone_u55_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.parametrize("input_data", delegated_clones)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
)
def test_clone_u85_INT(test_data):
def test_clone_u85_INT(input_data):
module, input_tensor = input_data()

pipeline = EthosU85PipelineINT[input_t](
Clone(),
test_data(),
module(),
input_tensor,
aten_op,
exir_op,
run_on_fvp=True,
@@ -104,27 +124,23 @@ def test_clone_u85_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.parametrize("test_data", delegated_clones)
@common.SkipIfNoModelConverter
@pytest.mark.xfail(
reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
)
def test_clone_vgf_FP(test_data):
module, input_tensor = test_data()
pipeline = VgfPipeline[input_t](
Clone(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP"
module(), input_tensor, aten_op, exir_op, tosa_version="TOSA-1.0+FP"
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.parametrize("test_data", delegated_clones)
@common.SkipIfNoModelConverter
@pytest.mark.xfail(
reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
)
def test_clone_vgf_INT(test_data):
module, input_tensor = test_data()
pipeline = VgfPipeline[input_t](
Clone(),
test_data(),
module(),
input_tensor,
aten_op,
exir_op,
tosa_version="TOSA-1.0+INT",