From 877119fd69871ebc112e3793c2c1fad9f593d3d2 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 29 Jul 2025 10:49:34 -0700 Subject: [PATCH 01/15] Register clone_dim_order op; add test for op replacement --- exir/passes/dim_order_ops_registry.py | 19 ++++++++++++++++++ exir/tests/test_memory_format_ops_pass.py | 15 ++++++++++++++ .../test_memory_format_ops_pass_utils.py | 20 +++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index f3fc009f109..7a5dff387c1 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -28,6 +28,14 @@ "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor" +) + +lib.define( + "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" +) + def _op_impl(target, *args, **kwargs): kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None)) @@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs): return _op_impl(torch.ops.aten.empty.out, *args, **kwargs) +@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd") +def _clone_dim_order_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.clone.default, *args, **kwargs) + + +@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd") +def _clone_dim_order_out_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.clone.out, *args, **kwargs) + + """ Defines a map of edge ops to the corresponding dim_order ops for quick lookup """ DimOrderOpsMap = { exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default, + exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default, } """ diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 84cd0faa485..da29d37b382 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -28,6 +28,7 @@ MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, + SimpleCloneChannelsLastModule, SimpleEmptyChannelLastModule, SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, @@ -389,3 +390,17 @@ def test_mobilenet_v3_xnnpack(self) -> None: rtol=1e-3, ), ) + + def test_op_clone_dim_order_replacement(self): + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + exported = export(model.eval(), (x,), strict=True) + epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( + epm.exported_program().graph_module.code + ) diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 6daf38b187f..2ce928ca30c 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -38,6 +38,10 @@ "torch.ops.aten.empty.memory_format", "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default", ), + torch.ops.aten.clone.default: ( + 
"torch.ops.aten.clone.default", + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + ), } @@ -70,6 +74,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(dtype=torch.double, memory_format=torch.channels_last) +class SimpleCloneContiguousModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.clone(memory_format=torch.contiguous_format) + + +class SimpleCloneChannelsLastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.clone(memory_format=torch.channels_last) + + class SimpleEmptyContiguoustModule(torch.nn.Module): def __init__(self): super().__init__() From f75845d9e5dc829bf38bfe82a39d9f10bf2ac4c3 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:09:41 -0700 Subject: [PATCH 02/15] Rename clone_dim_order op registration test --- exir/tests/test_memory_format_ops_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index da29d37b382..645df2acc9d 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -391,7 +391,7 @@ def test_mobilenet_v3_xnnpack(self) -> None: ), ) - def test_op_clone_dim_order_replacement(self): + def test_op_clone_dim_order_registration(self): model = SimpleCloneChannelsLastModule() x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) clone_dim_order_op_str = ( From cff39c9ae6bb28ffaae7959a370467249201c5a9 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Sat, 2 Aug 2025 20:04:24 -0700 Subject: [PATCH 03/15] Add graph level and end to end tests for _clone_dim_order op --- exir/tests/test_memory_format_ops_pass.py | 88 ++++++++++++++++--- .../test_memory_format_ops_pass_utils.py | 10 +++ 2 files changed, 84 insertions(+), 14 deletions(-) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 645df2acc9d..2eeee244825 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -27,8 +27,10 @@ AmbiguousDimOrderError, MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, + PropagateToCloneChannelsLastModule, PropagateToCopyChannalsLastModule, SimpleCloneChannelsLastModule, + SimpleCloneContiguousModule, SimpleEmptyChannelLastModule, SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, @@ -92,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None: ), ) + def test_op_clone_replacement_contiguous(self) -> None: + model = SimpleCloneContiguousModule() + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=model.eval(), + op=torch.ops.aten.clone.default, + sample_input=( + torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last), + ), + target_memory_format=torch.contiguous_format, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + + def test_op_clone_replacement_channels_last(self) -> None: + model = SimpleCloneChannelsLastModule() + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=model.eval(), + op=torch.ops.aten.clone.default, + sample_input=( + torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format), + ), + 
target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + def test_op_dim_order_update(self) -> None: MemoryFormatOpsPassTestUtils.memory_format_test_runner( self, @@ -129,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None: check_unambiguous_dim_order=True, ) + def test_op_clone_dim_order_propagation(self) -> None: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=PropagateToCloneChannelsLastModule().eval(), + op=torch.ops.aten.clone.default, + sample_input=( + torch.rand_like( + torch.zeros([2, 2, 2, 2]), + dtype=torch.float32, + memory_format=torch.contiguous_format, + ), + ), + target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + check_unambiguous_dim_order=True, + ) + def test_op_dim_order_propagation_ambiguous(self) -> None: try: MemoryFormatOpsPassTestUtils.memory_format_test_runner( @@ -154,6 +205,29 @@ def test_op_dim_order_propagation_ambiguous(self) -> None: except AmbiguousDimOrderError: pass # Expected error + def test_op_clone_dim_order_graph_replacement(self): + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + _clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + exported = export(model.eval(), (x,), strict=True) + epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Verify one _clone_dim_order op exists and aten.clone.default nodes have been removed. + ( + FileCheck() + .check_not( + "aten.clone.default" + ) # Check before first _clone_dim_order_op_str match. + .check_count(_clone_dim_order_op_str, 1, exactly=True) + .check_not( + "aten.clone.default" + ) # Check after _clone_dim_order_op_str match. + .run(epm.exported_program().graph_module.code) + ) + # Only test dim order replacement result in lean mode test. # This test is irrelevant with operator mode. 
def test_dim_order_replacement(self) -> None: @@ -390,17 +464,3 @@ def test_mobilenet_v3_xnnpack(self) -> None: rtol=1e-3, ), ) - - def test_op_clone_dim_order_registration(self): - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) - clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) - - exported = export(model.eval(), (x,), strict=True) - epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)) - - FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( - epm.exported_program().graph_module.code - ) diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 2ce928ca30c..f5a786c6f74 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -122,6 +122,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return t1 * t2 +class PropagateToCloneChannelsLastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t1 = x.clone(memory_format=torch.channels_last) + t2 = t1 + t1 + return t1 * t2 + + class AmbiguousDimOrderError(RuntimeError): pass From 1fe461f7baf30f8be5b09d752acda6f798851212 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Sat, 2 Aug 2025 20:09:28 -0700 Subject: [PATCH 04/15] Remove _clone_dim_order op registration (moved to PR #12974) --- exir/passes/dim_order_ops_registry.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 7a5dff387c1..eddffc5e980 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -28,14 +28,6 @@ "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" ) -lib.define( - "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor" -) - -lib.define( - "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" -) - def _op_impl(target, *args, **kwargs): kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None)) @@ -65,16 +57,6 @@ def _empty_dim_order_out_impl(*args, **kwargs): return _op_impl(torch.ops.aten.empty.out, *args, **kwargs) -@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd") -def _clone_dim_order_impl(*args, **kwargs): - return _op_impl(torch.ops.aten.clone.default, *args, **kwargs) - - -@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd") -def _clone_dim_order_out_impl(*args, **kwargs): - return _op_impl(torch.ops.aten.clone.out, *args, **kwargs) - - """ Defines a map of edge ops to the corresponding dim_order ops for quick lookup """ From 95db027728a1a80e50fb8361a7b86a30c9b41e5a Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:09:32 -0700 Subject: [PATCH 05/15] Register _clone_dim_order op --- exir/passes/dim_order_ops_registry.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index eddffc5e980..7a5dff387c1 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -28,6 +28,14 @@ "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) 
out) -> Tensor(a!)" ) +lib.define( + "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor" +) + +lib.define( + "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" +) + def _op_impl(target, *args, **kwargs): kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None)) @@ -57,6 +65,16 @@ def _empty_dim_order_out_impl(*args, **kwargs): return _op_impl(torch.ops.aten.empty.out, *args, **kwargs) +@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd") +def _clone_dim_order_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.clone.default, *args, **kwargs) + + +@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd") +def _clone_dim_order_out_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.clone.out, *args, **kwargs) + + """ Defines a map of edge ops to the corresponding dim_order ops for quick lookup """ From c7caa2731edc9615ed9a4e539809bd92fbc4b7ce Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:23:22 -0700 Subject: [PATCH 06/15] Register _clone_dim_order as no-op in CoreML --- backends/apple/coreml/compiler/torch_ops.py | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 81306c9a2fd..e8969c6a7bd 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -15,6 +15,7 @@ from coremltools.converters.mil.frontend.torch.ops import ( _get_inputs, _get_kwinputs, + noop, NUM_TO_NUMPY_DTYPE, NUM_TO_TORCH_DTYPE, split, @@ -67,6 +68,28 @@ def _to_dim_order_copy(context, node): to(context, node) +@register_torch_op( + torch_alias=[ + "dim_order_ops::_clone_dim_order", + "dim_order_ops._clone_dim_order", + ], + override=False, +) +def _clone_dim_order(context, node): + dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0] + node.kwinputs.pop("dim_order") + + # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format. + dim_order = [int(d) for d in dim_order.val] + memory_format = get_memory_format(dim_order) + assert ( + memory_format == _torch.contiguous_format + ), "Only contiguous memory format is supported in CoreML" + + # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone. 
+ noop(context, node) + + # https://github.com/apple/coremltools/pull/2558 @register_torch_op( torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], From e54605f1366b95fc4657297f8d6de8230b149b8f Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:39:23 -0700 Subject: [PATCH 07/15] Remove redundant _clone_dim_order graph check --- exir/tests/test_memory_format_ops_pass.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 2eeee244825..2384f6123a9 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -205,29 +205,6 @@ def test_op_dim_order_propagation_ambiguous(self) -> None: except AmbiguousDimOrderError: pass # Expected error - def test_op_clone_dim_order_graph_replacement(self): - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) - _clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) - - exported = export(model.eval(), (x,), strict=True) - epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)) - - # Verify one _clone_dim_order op exists and aten.clone.default nodes have been removed. - ( - FileCheck() - .check_not( - "aten.clone.default" - ) # Check before first _clone_dim_order_op_str match. - .check_count(_clone_dim_order_op_str, 1, exactly=True) - .check_not( - "aten.clone.default" - ) # Check after _clone_dim_order_op_str match. - .run(epm.exported_program().graph_module.code) - ) - # Only test dim order replacement result in lean mode test. # This test is irrelevant with operator mode. 
def test_dim_order_replacement(self) -> None: From 246bc441f54657d38935e6a5a5b6573117e9dab0 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:49:24 -0700 Subject: [PATCH 08/15] Add _clone_dim_order to RemoveClonePass and update op name in tests --- backends/arm/_passes/remove_clone_pass.py | 6 +++++- .../test/misc/test_partition_decomposed_quantized_ops.py | 2 +- backends/arm/test/ops/test_clone.py | 2 +- backends/arm/test/passes/test_remove_clone_pass.py | 6 ++++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py index a2822c7378e..d795cd42bcf 100644 --- a/backends/arm/_passes/remove_clone_pass.py +++ b/backends/arm/_passes/remove_clone_pass.py @@ -14,7 +14,11 @@ class RemoveClonePass(ExportPass): """Remove all clones from graph_module""" def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.clone.default: + clone_ops = ( + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + ) + if op not in clone_ops: return super().call_operator(op, args, kwargs, meta) if len(args) != 1: diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py index 1aaa2950337..04ecd57e7b1 100644 --- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py +++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py @@ -38,7 +38,7 @@ ] linear_residual_exir_op: list[str] = [ "executorch_exir_dialects_edge__ops_aten_gelu_default", - "executorch_exir_dialects_edge__ops_aten_clone_default", + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", "executorch_exir_dialects_edge__ops_aten_linear_default", "executorch_exir_dialects_edge__ops_aten_add_Tensor", ] diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 7a24848697e..5c5f5e9979a 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -23,7 +23,7 @@ ) aten_op = "torch.ops.aten.clone.default" -exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default" +exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" input_t = Tuple[torch.Tensor] diff --git a/backends/arm/test/passes/test_remove_clone_pass.py b/backends/arm/test/passes/test_remove_clone_pass.py index dea0bb06f5e..5c2171795f7 100755 --- a/backends/arm/test/passes/test_remove_clone_pass.py +++ b/backends/arm/test/passes/test_remove_clone_pass.py @@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT(): module.get_inputs(), quantize=True, ops_before_pass={ - "executorch_exir_dialects_edge__ops_aten_clone_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, }, - ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"], + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ], pass_list=[RemoveClonePass], ) pipeline.run() From c48467c17dd2d600e165dd5ca98735eacdcec863 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:52:11 -0700 Subject: [PATCH 09/15] Register _clone_dim_order under TOSA support check --- .../clone_dim_order_support.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 backends/arm/operator_support/clone_dim_order_support.py diff --git 
a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py new file mode 100644 index 00000000000..2c6aacb7d38 --- /dev/null +++ b/backends/arm/operator_support/clone_dim_order_support.py @@ -0,0 +1,76 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import torch +import torch.fx as fx + +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) + + +@register_tosa_support_check +class CloneDimOrderSupport(SupportedTOSAOperatorCheck): + targets = [ + exir_ops.edge.dim_order_ops._clone_dim_order.default, + ] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + assert node.target in self.targets + + supported_dtypes = {torch.bool, torch.int8, torch.int16, torch.int32} + if tosa_spec.support_float(): + supported_dtypes |= {torch.bfloat16, torch.float16, torch.float32} + + # Check input type + assert len(node.all_input_nodes) == 1 + input_val = node.all_input_nodes[0].meta["val"] + assert isinstance(input_val, torch._subclasses.FakeTensor) + input_dtype = input_val.dtype + if input_dtype not in supported_dtypes: + self.reporter.report_reject( + node, + f"Input dtype {input_val.dtype} is not supported in {node.target}.", + ) + return False + + # Check output type + output_val = node.meta["val"] + assert isinstance(output_val, torch._subclasses.FakeTensor) + if output_val.dtype != input_dtype: + self.reporter.report_reject( + node, + f"Input dtype {input_val.dtype} does not match {output_val.dtype}.", + ) + return False + + # Check dim_order + if "dim_order" in node.kwargs: + dim_order = node.kwargs["dim_order"] + # pyre-ignore[6] + if dim_order != list(range(len(dim_order))): # type: ignore[arg-type] + self.reporter.report_reject( + node, + f"Argument {dim_order=} is not supported for " + f"{node.target} right now.", + ) + return False + + return True From 5546360e8ab1a1b248befb06c589eaf169f8d9bd Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Fri, 15 Aug 2025 19:14:44 -0700 Subject: [PATCH 10/15] Add clone_dim_order_support to TOSA operator support list --- backends/arm/operator_support/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 2075e0f554f..5557a2116c6 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,6 +6,7 @@ # pyre-unsafe from . 
import ( # noqa + clone_dim_order_support, convolution_support, embedding_support, ethos_u55_support, From 5c5e65a3f9188a5279c3386388fdf95c2250ca19 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Fri, 15 Aug 2025 19:18:06 -0700 Subject: [PATCH 11/15] Register node visitor for _clone_dim_order --- backends/arm/operators/op_clone_dim_order.py | 43 ++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 backends/arm/operators/op_clone_dim_order.py diff --git a/backends/arm/operators/op_clone_dim_order.py b/backends/arm/operators/op_clone_dim_order.py new file mode 100644 index 00000000000..79369ade516 --- /dev/null +++ b/backends/arm/operators/op_clone_dim_order.py @@ -0,0 +1,43 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import Any, List + +import torch + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) +from executorch.backends.arm.tosa_mapping import TosaArg + + +@register_node_visitor +class CloneDimOrderVisitor(NodeVisitor): + """ + Implement _clone_dim_order as an identity operation. + """ + + target = "dim_order_ops._clone_dim_order.default" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + + validate_num_inputs(self.target, inputs, 1) + + # Since only contiguous dim order is currently supported, treat clone as an identity op. + tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]) From 7a0bc6af31ce15cc4405e598df0d33f99f958333 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:15:13 -0700 Subject: [PATCH 12/15] Remove visitor node registration for _clone_dim_order --- backends/arm/operators/op_clone_dim_order.py | 43 -------------------- 1 file changed, 43 deletions(-) delete mode 100644 backends/arm/operators/op_clone_dim_order.py diff --git a/backends/arm/operators/op_clone_dim_order.py b/backends/arm/operators/op_clone_dim_order.py deleted file mode 100644 index 79369ade516..00000000000 --- a/backends/arm/operators/op_clone_dim_order.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Any, List - -import torch - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, -) -from executorch.backends.arm.tosa_mapping import TosaArg - - -@register_node_visitor -class CloneDimOrderVisitor(NodeVisitor): - """ - Implement _clone_dim_order as an identity operation. 
- """ - - target = "dim_order_ops._clone_dim_order.default" - - tosa_specs = NodeVisitor.tosa_specs - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 1) - - # Since only contiguous dim order is currently supported, treat clone as an identity op. - tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]) From 74e2cce064e7cb89ff2e3a96a12844d51973ba0d Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:17:22 -0700 Subject: [PATCH 13/15] Remove aten.clone check from RemoveClonePass --- backends/arm/_passes/remove_clone_pass.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py index d795cd42bcf..896d3f54673 100644 --- a/backends/arm/_passes/remove_clone_pass.py +++ b/backends/arm/_passes/remove_clone_pass.py @@ -14,11 +14,7 @@ class RemoveClonePass(ExportPass): """Remove all clones from graph_module""" def call_operator(self, op, args, kwargs, meta): - clone_ops = ( - exir_ops.edge.aten.clone.default, - exir_ops.edge.dim_order_ops._clone_dim_order.default, - ) - if op not in clone_ops: + if op != exir_ops.edge.dim_order_ops._clone_dim_order.default: return super().call_operator(op, args, kwargs, meta) if len(args) != 1: From f9f9515816f4b3ff9dbca70b028267e842fec563 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:18:17 -0700 Subject: [PATCH 14/15] Remove input dtype gating and add memory_format check --- .../clone_dim_order_support.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py index 2c6aacb7d38..7269f7e7932 100644 --- a/backends/arm/operator_support/clone_dim_order_support.py +++ b/backends/arm/operator_support/clone_dim_order_support.py @@ -35,21 +35,11 @@ def is_node_tosa_supported( ) -> bool: assert node.target in self.targets - supported_dtypes = {torch.bool, torch.int8, torch.int16, torch.int32} - if tosa_spec.support_float(): - supported_dtypes |= {torch.bfloat16, torch.float16, torch.float32} - # Check input type assert len(node.all_input_nodes) == 1 input_val = node.all_input_nodes[0].meta["val"] assert isinstance(input_val, torch._subclasses.FakeTensor) input_dtype = input_val.dtype - if input_dtype not in supported_dtypes: - self.reporter.report_reject( - node, - f"Input dtype {input_val.dtype} is not supported in {node.target}.", - ) - return False # Check output type output_val = node.meta["val"] @@ -61,6 +51,16 @@ def is_node_tosa_supported( ) return False + # Check memory format + if "memory_format" in node.kwargs: + if node.kwargs["memory_format"] in (torch.preserve_format,): + self.reporter.report_reject( + node, + f"Argument 'memory_format' is not supported for " + f"{node.target} right now.", + ) + return False + # Check dim_order if "dim_order" in node.kwargs: dim_order = node.kwargs["dim_order"] From 6839212c21b029d2174066960a33c88641bdec6e Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:41:24 -0700 Subject: [PATCH 15/15] Add Core ML test for _clone_dim_order --- backends/apple/coreml/test/test_torch_ops.py | 23 
++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index 0d6b581ee72..25691777aec 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -221,6 +221,28 @@ def test_dequantize_codebook_embedding(self): et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) + def test__clone_dim_order_contiguous(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.dim_order_ops._clone_dim_order( + x, dim_order=[0, 1, 2, 3] + ) + + model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),) + ep = torch.export.export(model, example_inputs) + delegated_program = executorch.exir.to_edge_transform_and_lower( + ep, + partitioner=[self._coreml_partitioner()], + ) + for node in delegated_program.exported_program().graph.nodes: + if node.op == "call_function": + assert node.target.__name__ in [ + "executorch_call_delegate", + "getitem", + ], f"Got unexpected node target after delegation: {node.target.__name__}" + et_prog = delegated_program.to_executorch() + self._compare_outputs(et_prog, model, example_inputs) + if __name__ == "__main__": test_runner = TestTorchOps() @@ -231,3 +253,4 @@ def test_dequantize_codebook_embedding(self): test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() test_runner.test_dequantize_codebook_linear() test_runner.test_dequantize_codebook_embedding() + test_runner.test__clone_dim_order_contiguous()
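
Taken together, the patches above register _clone_dim_order in the dim_order_ops library, map aten.clone to it in DimOrderOpsMap so the memory format ops pass rewrites clones during edge lowering, and teach the Core ML and Arm (TOSA) backends to accept the new op. Below is a minimal end-to-end sketch of the behavior these patches exercise, based on the eager registration in dim_order_ops_registry.py and the graph checks in test_memory_format_ops_pass.py; the module name, tensor shapes, and assertions are illustrative additions for this note, not code from the PR:

    import torch
    from executorch.exir import EdgeCompileConfig, to_edge
    from torch.export import export

    # Eager behavior: the CompositeImplicitAutograd impl converts dim_order back to a
    # memory_format and dispatches to aten.clone, so [0, 2, 3, 1] yields channels_last.
    x = torch.randn(1, 3, 8, 8)
    y = torch.ops.dim_order_ops._clone_dim_order(x, dim_order=[0, 2, 3, 1])
    assert y.is_contiguous(memory_format=torch.channels_last)

    # Lowering behavior: with dim order enabled (_skip_dim_order=False), aten.clone.default
    # nodes are replaced by dim_order_ops._clone_dim_order.default in the edge graph.
    class CloneToChannelsLast(torch.nn.Module):  # mirrors SimpleCloneChannelsLastModule
        def forward(self, x):
            return x.clone(memory_format=torch.channels_last)

    ep = export(CloneToChannelsLast().eval(), (torch.randn(3, 4, 5, 6),), strict=True)
    edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
    code = edge.exported_program().graph_module.code
    assert "_clone_dim_order" in code and "aten_clone" not in code

Backends then decide what the layout-preserving clone means for them: Core ML asserts the requested dim_order is contiguous and lowers it as a no-op clone, while the Arm flow admits it via CloneDimOrderSupport (which rejects dtype- or layout-changing variants) and removes the remaining identity clones in RemoveClonePass.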