diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py
index dcbdfb03f7b..b906c06b329 100644
--- a/backends/arm/_passes/to_tosa_memory_format_pass.py
+++ b/backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -26,6 +26,9 @@
     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):

     _passes_required_after: Set[Type[ExportPass]] = set()

-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
@@ -112,25 +113,26 @@ def memory_format_differs(shape):

     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if the reshape changes the channel dimension or the product of the batch dimensions"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False

         C_old = input_shape[-3]
         C_new = output_shape[-3]

-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+
+            for dim in shape[:-3]:
+                product = product * dim
+
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)

         return (N_old != N_new) or (C_old != C_new)

@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return

+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(
+            get_first_fake_tensor(input_node).size()
+        )  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):

     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(get_first_fake_tensor(node).size())
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_output_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )

-        permute_node.meta["tosa_dim_order"] = (
-            NNHWC_ORDER
-            if len(get_first_fake_tensor(node).size()) == 5
-            else NHWC_ORDER
-        )
+        rank = len(get_first_fake_tensor(node).size())
+        if rank == 6:
+            permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
+        elif rank == 5:
+            permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
+        else:
+            permute_node.meta["tosa_dim_order"] = NHWC_ORDER
+
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             ]
             for input_node in inputs:
                 input_dim_order = get_first_fake_tensor(input_node).dim_order()
-                if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+                if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                     self.insert_output_transpose(input_node, graph_module)

         # Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 if output_dim_order in (
                     NCHW_ORDER,
                     NNCHW_ORDER,
+                    NNNCHW_ORDER,
                 ):
                     self.insert_input_transpose(
                         output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                     dim_order = HWCM_ORDER
                 elif node_data.dim() == 5:
                     dim_order = NNHWC_ORDER
+                elif node_data.dim() == 6:
+                    dim_order = NNNHWC_ORDER
                 else:
                     dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
diff --git a/backends/arm/constants.py b/backends/arm/constants.py
index b9995410b23..0e562f12e88 100644
--- a/backends/arm/constants.py
+++ b/backends/arm/constants.py
@@ -34,10 +34,13 @@
 NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
 NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
 NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
+NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3)
+NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4)

 NCHW_ORDER: Final = (0, 1, 2, 3)
-NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
 NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
-NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
+NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)

 HWCM_ORDER: Final = (2, 3, 0, 1)
+
+MAX_RANK: Final = 6
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index b580fbb9a9a..4e085f4e34d 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -19,7 +19,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
-from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55DtypeSupport,
     EthosU55NotSupported,
@@ -126,7 +126,7 @@ def tosa_support_factory(
     negative_checks: list[OperatorSupportBase] = [
         CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
-        RankCheck(reporter, max_rank=5),
+        RankCheck(reporter, max_rank=MAX_RANK),
         *[
             reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
             for check in (additional_checks if additional_checks else [])
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index bea8fe2eddc..17990e5ab09 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -366,6 +366,8 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.adaptive_avg_pool2d.default,
     torch.ops.aten.alias_copy.default,
+    torch.ops.aten.pixel_shuffle.default,
+    torch.ops.aten.pixel_unshuffle.default,
 ]

diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py
index 2629d8eb257..54f8aa7421d 100644
--- a/backends/arm/scripts/parse_test_names.py
+++ b/backends/arm/scripts/parse_test_names.py
@@ -26,6 +26,8 @@
     "_native_batch_norm_legit_no_training.default",
     "_native_batch_norm_legit.no_stats",
     "alias_copy.default",
+    "pixel_shuffle.default",
+    "pixel_unshuffle.default",
 ]

 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
index f9d814d044b..b6cd567b4a6 100644
--- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
@@ -24,16 +24,12 @@ class TestSD3Transformer2DModel(unittest.TestCase):

     # Adjust nbr below as we increase op support.
     ops_after_partitioner_FP = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "torch.ops.higher_order.executorch_call_delegate": 1,
     }

     ops_after_partitioner_INT = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
         "torch.ops.higher_order.executorch_call_delegate": 2,
     }
diff --git a/backends/arm/test/ops/test_pixel_shuffling.py b/backends/arm/test/ops/test_pixel_shuffling.py
new file mode 100644
index 00000000000..5aeb8b2d1bb
--- /dev/null
+++ b/backends/arm/test/ops/test_pixel_shuffling.py
@@ -0,0 +1,233 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import pytest
+
+import torch
+
+from executorch.backends.arm.constants import MAX_RANK
+
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+from torch import nn
+
+aten_op_pixel_unshuffle = "torch.ops.aten.pixel_unshuffle.default"
+exir_op_pixel_unshuffle = (
+    "executorch_exir_dialects_edge__ops_aten_pixel_unshuffle_default"
+)
+
+aten_op_pixel_shuffle = "torch.ops.aten.pixel_shuffle.default"
+exir_op_pixel_shuffle = "executorch_exir_dialects_edge__ops_aten_pixel_shuffle_default"
+
+input_t1 = Tuple[torch.Tensor]  # single positional input (1-tuple)
+
+max_rank_input_supported = MAX_RANK - 2
+
+
+class PixelUnShuffle(nn.Module):
+
+    upscale_factor = 2
+    test_data_generators = {
+        "rand_4d": lambda: (torch.randn(1, 12, 64, 64),),
+        "test_4d": lambda: (torch.tensor([[[[10.0, 20.0], [30.0, 40.0]]]]),),
+        "test_3d": lambda: (torch.tensor([[[10.0, 20.0], [30.0, 40.0]]]),),
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.space_to_depth = nn.PixelUnshuffle(self.upscale_factor)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        if inputs.dim() > max_rank_input_supported:
+            raise RuntimeError(
+                f"Max rank of input for pixel_unshuffle is currently {max_rank_input_supported}, got {inputs.dim()}"
+            )
+        return self.space_to_depth(inputs)
+
+
+class PixelShuffle(nn.Module):
+
+    upscale_factor = 2
+    test_data_generators = {
+        "rand_4d": lambda: (torch.randn(1, 12, 64, 64),),
+        "test_4d": lambda: (torch.tensor([[[[10.0]], [[20.0]], [[30.0]], [[40.0]]]]),),
+        "test_3d": lambda: (torch.tensor([[[10.0]], [[20.0]], [[30.0]], [[40.0]]]),),
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.depth_to_space = nn.PixelShuffle(self.upscale_factor)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        if inputs.dim() > max_rank_input_supported:
+            raise RuntimeError(
+                f"Max rank of input for pixel_shuffle is currently {max_rank_input_supported}, got {inputs.dim()}"
+            )
+        return self.depth_to_space(inputs)
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+def test_pixel_unshuffle_tosa_FP(test_data: input_t1):
+    pipeline = TosaPipelineFP[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+def test_pixel_unshuffle_tosa_INT(test_data: input_t1):
+    pipeline = TosaPipelineINT[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+def test_pixel_shuffle_tosa_FP(test_data: input_t1):
+    pipeline = TosaPipelineFP[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+def test_pixel_shuffle_tosa_INT(test_data: input_t1):
+    pipeline = TosaPipelineINT[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+@common.SkipIfNoModelConverter
+def test_pixel_unshuffle_vgf_FP(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+        tosa_version="TOSA-1.0+FP",
+        run_on_vulkan_runtime=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+@common.SkipIfNoModelConverter
+def test_pixel_unshuffle_vgf_INT(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+        tosa_version="TOSA-1.0+INT",
+        run_on_vulkan_runtime=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+@common.SkipIfNoModelConverter
+def test_pixel_shuffle_vgf_FP(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+        tosa_version="TOSA-1.0+FP",
+        run_on_vulkan_runtime=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+@common.SkipIfNoModelConverter
+def test_pixel_shuffle_vgf_INT(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+        tosa_version="TOSA-1.0+INT",
+        run_on_vulkan_runtime=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+@common.XfailIfNoCorstone300
+def test_pixel_unshuffle_u55_INT(test_data: input_t1):
+    pipeline = EthosU55PipelineINT[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+        run_on_fvp=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelUnShuffle.test_data_generators)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(reason="MLETORCH-1424: rand test fails")
+def test_pixel_unshuffle_u85_INT(test_data: input_t1):
+    pipeline = EthosU85PipelineINT[input_t1](
+        PixelUnShuffle(),
+        test_data(),
+        aten_op_pixel_unshuffle,
+        exir_op_pixel_unshuffle,
+        run_on_fvp=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+@common.XfailIfNoCorstone300
+def test_pixel_shuffle_u55_INT(test_data: input_t1):
+    pipeline = EthosU55PipelineINT[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+        run_on_fvp=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", PixelShuffle.test_data_generators)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(reason="MLETORCH-1424: rand test fails")
+def test_pixel_shuffle_u85_INT(test_data: input_t1):
+    pipeline = EthosU85PipelineINT[input_t1](
+        PixelShuffle(),
+        test_data(),
+        aten_op_pixel_shuffle,
+        exir_op_pixel_shuffle,
+        run_on_fvp=True,
+    )
+    pipeline.run()
diff --git a/backends/arm/tosa/dialect/ops/transpose.py b/backends/arm/tosa/dialect/ops/transpose.py
index 9c5aba05394..8d5bf8bac70 100644
--- a/backends/arm/tosa/dialect/ops/transpose.py
+++ b/backends/arm/tosa/dialect/ops/transpose.py
@@ -26,9 +26,9 @@ def TRANSPOSE(a, perms):

     # By utilizing an edge IR passthrough operator we can keep the edge program in
     # channels-first/contiguous and get the desired behavior in the TOSA lowering.
-    if len(perms) not in (4, 5):
+    if len(perms) not in (4, 5, 6):
         raise TosaValueError(
-            f"Only 4D and 5D tensors are supported, got {len(perms)}: {perms}",
+            f"Only 4D, 5D and 6D tensors are supported, got {len(perms)}: {perms}",
             op="TRANSPOSE",
         )
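
Reviewer note, not part of the patch: a minimal standalone sketch of what the new rank-6 dim orders express and what the permutation guard in ToTosaMemoryFormatPass checks. The order values below are copied from backends/arm/constants.py in this patch; the tensor shape is arbitrary and chosen only for illustration.

# Standalone illustration only -- not part of the diff above.
import torch

# Values taken from backends/arm/constants.py in this patch.
NNNCHW_ORDER = (0, 1, 2, 3, 4, 5)
NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)          # move C (dim 3) behind H and W
NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)  # move C back in front of H and W

# The guard added in insert_input_transpose/insert_output_transpose: a dim order
# is only usable as a transpose if it is a permutation of range(rank).
for order in (NNNCHW_ORDER, NNNHWC_ORDER, NNNHWC_INVERSE_ORDER):
    assert sorted(order) == list(range(6)), f"bad perm {order}"

# Arbitrary rank-6 tensor: three batch-like dims, then C, H, W.
x = torch.randn(2, 3, 4, 8, 16, 16)            # NNNCHW layout
x_nhwc = x.permute(NNNHWC_ORDER)               # shape (2, 3, 4, 16, 16, 8)
x_back = x_nhwc.permute(NNNHWC_INVERSE_ORDER)  # the two orders are inverses
assert torch.equal(x_back, x)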