From b79289ce379bea150013b4906b65f6792ac856ea Mon Sep 17 00:00:00 2001 From: Jack <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 27 Aug 2025 01:05:51 -0400 Subject: [PATCH] Revert "[EXIR] Register _clone_dim_order op and map aten.clone (#12971)" This reverts commit 130cafc755725d670109f7e99f4a77a912887b49. --- backends/apple/coreml/compiler/torch_ops.py | 23 ------ backends/apple/coreml/test/test_torch_ops.py | 23 ------ backends/arm/_passes/remove_clone_pass.py | 2 +- backends/arm/operator_support/__init__.py | 1 - .../clone_dim_order_support.py | 76 ------------------- ...test_partition_decomposed_quantized_ops.py | 2 +- backends/arm/test/ops/test_clone.py | 2 +- .../arm/test/passes/test_remove_clone_pass.py | 6 +- exir/passes/dim_order_ops_registry.py | 19 ----- exir/tests/test_memory_format_ops_pass.py | 52 ------------- .../test_memory_format_ops_pass_utils.py | 30 -------- 11 files changed, 5 insertions(+), 231 deletions(-) delete mode 100644 backends/arm/operator_support/clone_dim_order_support.py diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 33ec771ce56..e53670951e0 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -15,7 +15,6 @@ from coremltools.converters.mil.frontend.torch.ops import ( _get_inputs, _get_kwinputs, - noop, NUM_TO_NUMPY_DTYPE, NUM_TO_TORCH_DTYPE, split, @@ -92,28 +91,6 @@ def _to_dim_order_copy(context, node): to(context, node) -@register_torch_op( - torch_alias=[ - "dim_order_ops::_clone_dim_order", - "dim_order_ops._clone_dim_order", - ], - override=False, -) -def _clone_dim_order(context, node): - dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0] - node.kwinputs.pop("dim_order") - - # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format. - dim_order = [int(d) for d in dim_order.val] - memory_format = get_memory_format(dim_order) - assert ( - memory_format == _torch.contiguous_format - ), "Only contiguous memory format is supported in CoreML" - - # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone. 
- noop(context, node) - - # https://github.com/apple/coremltools/pull/2558 @register_torch_op( torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index 29f09127d16..4fdbfdd8f21 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -213,28 +213,6 @@ def test_dequantize_codebook_embedding(self): et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) - def test__clone_dim_order_contiguous(self): - class Model(torch.nn.Module): - def forward(self, x): - return torch.ops.dim_order_ops._clone_dim_order( - x, dim_order=[0, 1, 2, 3] - ) - - model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),) - ep = torch.export.export(model, example_inputs) - delegated_program = executorch.exir.to_edge_transform_and_lower( - ep, - partitioner=[self._coreml_partitioner()], - ) - for node in delegated_program.exported_program().graph.nodes: - if node.op == "call_function": - assert node.target.__name__ in [ - "executorch_call_delegate", - "getitem", - ], f"Got unexpected node target after delegation: {node.target.__name__}" - et_prog = delegated_program.to_executorch() - self._compare_outputs(et_prog, model, example_inputs) - if __name__ == "__main__": test_runner = TestTorchOps() @@ -245,4 +223,3 @@ def forward(self, x): test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() test_runner.test_dequantize_codebook_linear() test_runner.test_dequantize_codebook_embedding() - test_runner.test__clone_dim_order_contiguous() diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py index 896d3f54673..a2822c7378e 100644 --- a/backends/arm/_passes/remove_clone_pass.py +++ b/backends/arm/_passes/remove_clone_pass.py @@ -14,7 +14,7 @@ class RemoveClonePass(ExportPass): """Remove all clones from graph_module""" def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.dim_order_ops._clone_dim_order.default: + if op != exir_ops.edge.aten.clone.default: return super().call_operator(op, args, kwargs, meta) if len(args) != 1: diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 5557a2116c6..2075e0f554f 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,7 +6,6 @@ # pyre-unsafe from . import ( # noqa - clone_dim_order_support, convolution_support, embedding_support, ethos_u55_support, diff --git a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py deleted file mode 100644 index 7269f7e7932..00000000000 --- a/backends/arm/operator_support/clone_dim_order_support.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -import logging - -import torch -import torch.fx as fx - -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - -logger = logging.getLogger(__name__) - - -@register_tosa_support_check -class CloneDimOrderSupport(SupportedTOSAOperatorCheck): - targets = [ - exir_ops.edge.dim_order_ops._clone_dim_order.default, - ] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def is_node_tosa_supported( - self, node: fx.Node, tosa_spec: TosaSpecification - ) -> bool: - assert node.target in self.targets - - # Check input type - assert len(node.all_input_nodes) == 1 - input_val = node.all_input_nodes[0].meta["val"] - assert isinstance(input_val, torch._subclasses.FakeTensor) - input_dtype = input_val.dtype - - # Check output type - output_val = node.meta["val"] - assert isinstance(output_val, torch._subclasses.FakeTensor) - if output_val.dtype != input_dtype: - self.reporter.report_reject( - node, - f"Input dtype {input_val.dtype} does not match {output_val.dtype}.", - ) - return False - - # Check memory format - if "memory_format" in node.kwargs: - if node.kwargs["memory_format"] in (torch.preserve_format,): - self.reporter.report_reject( - node, - f"Argument 'memory_format' is not supported for " - f"{node.target} right now.", - ) - return False - - # Check dim_order - if "dim_order" in node.kwargs: - dim_order = node.kwargs["dim_order"] - # pyre-ignore[6] - if dim_order != list(range(len(dim_order))): # type: ignore[arg-type] - self.reporter.report_reject( - node, - f"Argument {dim_order=} is not supported for " - f"{node.target} right now.", - ) - return False - - return True diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py index 04ecd57e7b1..1aaa2950337 100644 --- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py +++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py @@ -38,7 +38,7 @@ ] linear_residual_exir_op: list[str] = [ "executorch_exir_dialects_edge__ops_aten_gelu_default", - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + "executorch_exir_dialects_edge__ops_aten_clone_default", "executorch_exir_dialects_edge__ops_aten_linear_default", "executorch_exir_dialects_edge__ops_aten_add_Tensor", ] diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 3afc52a5cf0..b4f2879be48 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -23,7 +23,7 @@ ) aten_op = "torch.ops.aten.clone.default" -exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" +exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default" input_t = Tuple[torch.Tensor] diff --git a/backends/arm/test/passes/test_remove_clone_pass.py b/backends/arm/test/passes/test_remove_clone_pass.py index 5c2171795f7..dea0bb06f5e 100755 --- a/backends/arm/test/passes/test_remove_clone_pass.py +++ b/backends/arm/test/passes/test_remove_clone_pass.py @@ -35,11 +35,9 @@ def test_remove_clone_tosa_INT(): module.get_inputs(), quantize=True, ops_before_pass={ - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + 
"executorch_exir_dialects_edge__ops_aten_clone_default": 1, }, - ops_not_after_pass=[ - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ], + ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"], pass_list=[RemoveClonePass], ) pipeline.run() diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 7a5dff387c1..f3fc009f109 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -28,14 +28,6 @@ "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" ) -lib.define( - "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor" -) - -lib.define( - "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" -) - def _op_impl(target, *args, **kwargs): kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None)) @@ -65,23 +57,12 @@ def _empty_dim_order_out_impl(*args, **kwargs): return _op_impl(torch.ops.aten.empty.out, *args, **kwargs) -@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd") -def _clone_dim_order_impl(*args, **kwargs): - return _op_impl(torch.ops.aten.clone.default, *args, **kwargs) - - -@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd") -def _clone_dim_order_out_impl(*args, **kwargs): - return _op_impl(torch.ops.aten.clone.out, *args, **kwargs) - - """ Defines a map of edge ops to the corresponding dim_order ops for quick lookup """ DimOrderOpsMap = { exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default, - exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default, } """ diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 2384f6123a9..84cd0faa485 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -27,10 +27,7 @@ AmbiguousDimOrderError, MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, - PropagateToCloneChannelsLastModule, PropagateToCopyChannalsLastModule, - SimpleCloneChannelsLastModule, - SimpleCloneContiguousModule, SimpleEmptyChannelLastModule, SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, @@ -94,36 +91,6 @@ def test_op_empty_replacement_contiguous(self) -> None: ), ) - def test_op_clone_replacement_contiguous(self) -> None: - model = SimpleCloneContiguousModule() - MemoryFormatOpsPassTestUtils.memory_format_test_runner( - self, - MemoryFormatTestSet( - module=model.eval(), - op=torch.ops.aten.clone.default, - sample_input=( - torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last), - ), - target_memory_format=torch.contiguous_format, - _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, - ), - ) - - def test_op_clone_replacement_channels_last(self) -> None: - model = SimpleCloneChannelsLastModule() - MemoryFormatOpsPassTestUtils.memory_format_test_runner( - self, - MemoryFormatTestSet( - module=model.eval(), - op=torch.ops.aten.clone.default, - sample_input=( - torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format), - ), - target_memory_format=torch.channels_last, - _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, - ), - ) - def test_op_dim_order_update(self) -> None: MemoryFormatOpsPassTestUtils.memory_format_test_runner( self, @@ -161,25 +128,6 @@ def 
test_op_dim_order_propagation(self) -> None: check_unambiguous_dim_order=True, ) - def test_op_clone_dim_order_propagation(self) -> None: - MemoryFormatOpsPassTestUtils.memory_format_test_runner( - self, - MemoryFormatTestSet( - module=PropagateToCloneChannelsLastModule().eval(), - op=torch.ops.aten.clone.default, - sample_input=( - torch.rand_like( - torch.zeros([2, 2, 2, 2]), - dtype=torch.float32, - memory_format=torch.contiguous_format, - ), - ), - target_memory_format=torch.channels_last, - _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, - ), - check_unambiguous_dim_order=True, - ) - def test_op_dim_order_propagation_ambiguous(self) -> None: try: MemoryFormatOpsPassTestUtils.memory_format_test_runner( diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index f5a786c6f74..6daf38b187f 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -38,10 +38,6 @@ "torch.ops.aten.empty.memory_format", "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default", ), - torch.ops.aten.clone.default: ( - "torch.ops.aten.clone.default", - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", - ), } @@ -74,22 +70,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(dtype=torch.double, memory_format=torch.channels_last) -class SimpleCloneContiguousModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.clone(memory_format=torch.contiguous_format) - - -class SimpleCloneChannelsLastModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.clone(memory_format=torch.channels_last) - - class SimpleEmptyContiguoustModule(torch.nn.Module): def __init__(self): super().__init__() @@ -122,16 +102,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return t1 * t2 -class PropagateToCloneChannelsLastModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - t1 = x.clone(memory_format=torch.channels_last) - t2 = t1 + t1 - return t1 * t2 - - class AmbiguousDimOrderError(RuntimeError): pass
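
The net effect of this revert is visible in the DimOrderOpsMap hunk above: aten.clone is no longer rewritten to dim_order_ops._clone_dim_order during edge lowering, so backend passes and tests match the plain clone op again (RemoveClonePass now targets exir_ops.edge.aten.clone.default, and the Arm tests expect executorch_exir_dialects_edge__ops_aten_clone_default). Below is a minimal sketch of that behavior, not part of the patch; it assumes a default to_edge configuration, whereas the tests above go through to_edge_transform_and_lower with backend partitioners.

import torch

from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops


class CloneModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A plain clone, with no explicit memory_format argument.
        return x.clone()


# Export and lower to the edge dialect (default config assumed here).
ep = torch.export.export(CloneModule(), (torch.randn(1, 3, 8, 8),))
edge = to_edge(ep)

targets = {
    node.target
    for node in edge.exported_program().graph.nodes
    if node.op == "call_function"
}

# With this revert applied, the plain edge clone op survives lowering and no
# dim_order_ops._clone_dim_order node is introduced; this is the same target
# that RemoveClonePass matches in the hunk above.
assert exir_ops.edge.aten.clone.default in targets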