From cb348aec2199ede62721880fff699c90e334fcd4 Mon Sep 17 00:00:00 2001
From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com>
Date: Wed, 17 Sep 2025 22:50:31 +0200
Subject: [PATCH] Arm backend: Support channels-last input and output

Differential Revision: D82449155

Pull Request resolved: https://github.com/pytorch/executorch/pull/14259
(cherry picked from commit 5348ea9503326987ccd06245be992a51420f6722)
---
 .../arm/_passes/to_tosa_memory_format_pass.py | 111 +++++++---------
 backends/arm/constants.py                    |  12 ++
 .../to_dim_order_copy_support.py             |   1 +
 backends/arm/process_node.py                 |   7 -
 backends/arm/runtime/EthosUBackend.cpp       |   9 --
 backends/arm/test/misc/test_dim_order.py     | 123 ++++++++++++++++++
 .../arm/test/misc/test_dim_order_guards.py   |  67 ----------
 .../arm/test/models/test_mobilenet_v2_arm.py |  17 +++
 .../arm/test/models/test_torch_functions.py  |   1 -
 .../test/passes/test_to_tosa_memory_format.py |  10 +-
 backends/arm/test/runner_utils.py            | 108 ++++++++++-----
 backends/arm/test/targets.bzl                |   2 +-
 docs/source/backends-arm-ethos-u.md          |   9 ++
 13 files changed, 296 insertions(+), 181 deletions(-)
 create mode 100644 backends/arm/test/misc/test_dim_order.py
 delete mode 100644 backends/arm/test/misc/test_dim_order_guards.py

diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py
index e4436d638f4..ac16cbaf8cb 100644
--- a/backends/arm/_passes/to_tosa_memory_format_pass.py
+++ b/backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -9,13 +9,23 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
+from executorch.backends.arm._passes import (
+    AnnotateOutputDimOrderPass,
+)
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
-    get_output_dim_orders,
     is_param_node,
 )
+from executorch.backends.arm.constants import (
+    HWCM_ORDER,
+    NCHW_ORDER,
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNCHW_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
""" - NHWC_order = (0, 2, 3, 1) - NHWC_inverse_order = (0, 3, 1, 2) - HWCM_order = (2, 3, 0, 1) - NNHWC_order = (0, 1, 3, 4, 2) - NNHWC_inverse_order = (0, 1, 4, 2, 3) - def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program super().__init__() @@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module): args=( input_node, list( - ToTosaMemoryFormatPass.NNHWC_inverse_order + NNHWC_INVERSE_ORDER if len(get_first_fake_tensor(input_node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_inverse_order + else NHWC_INVERSE_ORDER ), ), from_node=node, @@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module): args=( node, list( - ToTosaMemoryFormatPass.NNHWC_order + NNHWC_ORDER if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order + else NHWC_ORDER ), ), from_node=node, ) permute_node.meta["tosa_dim_order"] = ( - ToTosaMemoryFormatPass.NNHWC_order + NNHWC_ORDER if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order + else NHWC_ORDER ) node.meta["tosa_dim_order"] = tuple( range(len(get_first_fake_tensor(node).size())) @@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: # call_function and placeholder allowed due to # index.Tensor being able to come in as both - if node.op not in ["call_function", "placeholder", "output"]: + if node.op != "call_function": continue # Transpose views @@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): graph_module, ) - # Transpose inputs - elif _is_input(node, self.exported_program): - input_shape = get_first_fake_tensor(node).size() - if len(input_shape) in (4, 5): - ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) + output_node = graph_module.graph.output_node() - # Transpose outputs - elif node.op == "output": - output_shape = get_first_fake_tensor(node).size() + # Transpose inputs if they are in (N)NCHW format + inputs = [ + n for n in graph_module.graph.nodes if _is_input(n, self.exported_program) + ] + for input_node in inputs: + input_dim_order = get_first_fake_tensor(input_node).dim_order() + if input_dim_order in (NCHW_ORDER, NNCHW_ORDER): + self.insert_output_transpose(input_node, graph_module) + + # Transpose outputs if they are in (N)NCHW format + outputs = output_node.args[0] + output_dim_orders = output_node.meta.get("original_dim_orders") + if output_dim_orders is None: + raise RuntimeError( + f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." 
+ ) - if len(output_shape) in (4, 5): - for input_node in node.all_input_nodes: - ToTosaMemoryFormatPass.insert_input_transpose( - node, input_node, graph_module - ) + for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type] + if output_dim_order in ( + NCHW_ORDER, + NNCHW_ORDER, + ): + self.insert_input_transpose( + output_node, output_node_input, graph_module + ) def remove_dim_order_kwargs( self, graph_module: torch.fx.GraphModule, node: torch.fx.Node @@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule): node_data = get_first_fake_tensor(node).data self.remove_dim_order_kwargs(graph_module, node) - # Inputs and outputs are always in (N)NCHW format + # Inputs and outputs may vary in dim_order if _is_input(node, self.exported_program) or node.op == "output": - dim_order = tuple(range(node_data.dim())) + dim_order = node_data.dim_order() elif node_data.dim() == 4: - dim_order = self.NHWC_order + dim_order = NHWC_ORDER if self.is_weight_node_for_depthwise_conv2d(node): # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d). - dim_order = self.HWCM_order + dim_order = HWCM_ORDER elif node_data.dim() == 5: - dim_order = self.NNHWC_order + dim_order = NNHWC_ORDER else: dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] @@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) - - def requires(self, graph_module) -> None: - """ - This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline. - """ - - dim_orders = get_output_dim_orders(graph_module) - original_dim_orders = graph_module.graph.output_node().meta.get( - "original_dim_orders" - ) - output_node = graph_module.graph.output_node() - - if original_dim_orders is None: - raise RuntimeError( - f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run." - ) - - if len(dim_orders) != len(original_dim_orders): - raise RuntimeError( - f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run." - ) - - for node, dim_order, original_dim_order in zip( - output_node.args[0], dim_orders, original_dim_orders - ): - if dim_order != original_dim_order: - raise RuntimeError( - f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run." 
- ) diff --git a/backends/arm/constants.py b/backends/arm/constants.py index fd8710d3ead..b9995410b23 100644 --- a/backends/arm/constants.py +++ b/backends/arm/constants.py @@ -29,3 +29,15 @@ DEQUANT_PER_TENSOR_OP_T, ) PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP) + +NHWC_ORDER: Final = (0, 2, 3, 1) +NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2) +NNHWC_ORDER: Final = (0, 1, 3, 4, 2) +NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3) + +NCHW_ORDER: Final = (0, 1, 2, 3) +NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1) +NNCHW_ORDER: Final = (0, 1, 2, 3, 4) +NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2) + +HWCM_ORDER: Final = (2, 3, 0, 1) diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index e21f8a68ad6..ced9b7c5afc 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -89,6 +89,7 @@ def _merge_supported_types( torch.int32, torch.bfloat16, torch.float16, + torch.float32, ], } ALL_SUPPORTED_TYPES = _merge_supported_types( diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 9ca435c60c5..5093ea32d4c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -70,13 +70,6 @@ def process_inputs( tosa_spec: TosaSpecification, ): """Serialize an input node""" - # inputs need to be in default dim_order (contiguous memory format) - meta = node.meta["val"] - if meta.dim_order() != tuple(range(meta.dim())): - raise RuntimeError( - f"Arm backend only supports contiguous memory format for inputs. " - f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}" - ) try: tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index 8f63569eece..08589c34c69 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { handles.inputs->io[i].elem_size); return Error::InvalidProgram; } - supported = executorch::runtime::is_contiguous_dim_order( - tensor_in.dim_order().data(), tensor_in.dim()); - if (!supported) { - ET_LOG( - Error, - "Input %d expected contiguous dim_order, but got non-contiguous dim_order", - i); - return Error::InvalidProgram; - } // Select a compatible copy routine including checking for input layouts // which require permutation. diff --git a/backends/arm/test/misc/test_dim_order.py b/backends/arm/test/misc/test_dim_order.py new file mode 100644 index 00000000000..6b0b79add99 --- /dev/null +++ b/backends/arm/test/misc/test_dim_order.py @@ -0,0 +1,123 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, +) + + +input_t1 = Tuple[torch.Tensor] # Input x + + +class ChannelsLastInput(torch.nn.Module): + """ + Test a complex case with (channels last, channels first) input, + and (channels first, channels last) output. 
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+        torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)),
+    )
+
+    def forward(self, x, y):
+        x = x * x
+        return y, x
+
+
+class ChannelsFirstOutput(torch.nn.Module):
+    """
+    Test converting to channels_first inside the delegate.
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+    )
+
+    def forward(self, x):
+        x = x.clone(memory_format=torch.contiguous_format) * x
+        return x
+
+
+class ChannelsLastOutput(torch.nn.Module):
+    """
+    Test changing the dim_order inside the delegate.
+    """
+
+    inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),)
+
+    def forward(self, x):
+        x = x * x
+        x = x.clone(memory_format=torch.channels_last)
+        return x
+
+
+class ChannelsLastInsidePartition(torch.nn.Module):
+    """
+    Test dim_order changes inside the partition, but no dim_order changes at input/output.
+    """
+
+    inputs: input_t1 = (torch.randn((1, 2, 3, 3)),)
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3))
+
+    def forward(self, x):
+        return (
+            self.conv2d(x.clone(memory_format=torch.channels_last)).clone(
+                memory_format=torch.contiguous_format
+            )
+            * 1
+        )
+
+
+test_modules = {
+    "channels_last_input": ChannelsLastInput,
+    "channels_first_output": ChannelsFirstOutput,
+    "channels_last_output": ChannelsLastOutput,
+    "channels_last_inside_partition": ChannelsLastInsidePartition,
+}
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_FP(module):
+    pipeline = TosaPipelineFP[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_INT(module):
+    pipeline = TosaPipelineINT[input_t1](
+        module(), module.inputs, [], symmetric_io_quantization=True
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("module", test_modules)
+def test_dim_order_u55_INT(module):
+    pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("module", test_modules)
+def test_dim_order_u85_INT(module):
+    pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py
deleted file mode 100644
index 80a3c014abc..00000000000
--- a/backends/arm/test/misc/test_dim_order_guards.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-from typing import Tuple
-
-import pytest
-
-import torch
-from executorch.backends.arm.test import common
-
-from executorch.backends.arm.test.tester.test_pipeline import (
-    TosaPipelineFP,
-    TosaPipelineINT,
-)
-
-
-input_t1 = Tuple[torch.Tensor]  # Input x
-
-
-class Conv2D(torch.nn.Module):
-    inputs: dict[str, input_t1] = {
-        "randn": (torch.randn(1, 2, 20, 20).to(memory_format=torch.channels_last),),
-    }
-
-    def __init__(self):
-        super().__init__()
-        self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3))
-
-    def forward(self, x):
-        return self.conv2d(x)
-
-
-@common.parametrize("test_data", Conv2D.inputs)
-def test_tosa_FP_pipeline(test_data: input_t1):
-    module = Conv2D()
-    pipeline = TosaPipelineFP[input_t1](
-        module,
-        test_data,
-        [],
-        [],
-        use_to_edge_transform_and_lower=False,
-    )
-    pos = pipeline.find_pos("partition")
-    pipeline._stages = pipeline._stages[:pos]
-    pipeline.run()
-    with pytest.raises(RuntimeError):
-        pipeline.tester.partition()
-
-
-@common.parametrize("test_data", Conv2D.inputs)
-def test_tosa_INT_pipeline(test_data: input_t1):
-    module = Conv2D()
-    pipeline = TosaPipelineINT[input_t1](
-        module,
-        test_data,
-        [],
-        [],
-        use_to_edge_transform_and_lower=False,
-    )
-    pos = pipeline.find_pos("partition")
-    pipeline._stages = pipeline._stages[:pos]
-    pipeline.run()
-    with pytest.raises(RuntimeError):
-        pipeline.tester.partition()
diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py
index d4e3bbc8e28..84de432155e 100644
--- a/backends/arm/test/models/test_mobilenet_v2_arm.py
+++ b/backends/arm/test/models/test_mobilenet_v2_arm.py
@@ -46,6 +46,23 @@ def test_mv2_tosa_FP():
     pipeline.run()
 
 
+def test_mv2_tosa_FP_channels_last():
+    input_tensor = model_inputs[0].to(memory_format=torch.channels_last)
+    pipeline = TosaPipelineFP[input_t](
+        mv2,
+        (input_tensor,),
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+    )
+    # Changing the memory format inserts an unsupported as_strided_copy op
+    # into the graph, causing a graph break.
+    pipeline.change_args(
+        "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
+    )
+    pipeline.run()
+
+
 @common.parametrize("per_channel_quantization", quant_test_data)
 def test_mv2_tosa_INT(per_channel_quantization):
     pipeline = TosaPipelineINT[input_t](
diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py
index 580438f6da8..de45dbe0356 100644
--- a/backends/arm/test/models/test_torch_functions.py
+++ b/backends/arm/test/models/test_torch_functions.py
@@ -101,7 +101,6 @@ def forward(self, *args):
             "Requires dynamic output shape.",
         "topk": "NotImplementedError: No registered serialization name for found",
         "sort": "NotImplementedError: No registered serialization name for found",
-        "norm": "An error occurred when running the 'KeepDimsFalseToSqueezePass' pass after the following passes:",
     },
 )
 def test_torch_fns_FP(test_data):
diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py
index 1e9b8ffc63d..643a3bf5733 100644
--- a/backends/arm/test/passes/test_to_tosa_memory_format.py
+++ b/backends/arm/test/passes/test_to_tosa_memory_format.py
@@ -6,7 +6,10 @@
 from typing import Tuple
 
 import torch
-from executorch.backends.arm._passes import ToTosaMemoryFormatPass
+from executorch.backends.arm._passes import (
+    AnnotateOutputDimOrderPass,
+    ToTosaMemoryFormatPass,
+)
 from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -177,7 +180,10 @@ def test_to_tosa_memory_format_tosa_INT(module):
         ops_after_pass=module.ops_after_pass,
         ops_not_after_pass=module.ops_not_after_pass,
         pass_list=[RemoveGetItemPass],
-        passes_with_exported_program=[ToTosaMemoryFormatPass],
+        passes_with_exported_program=[
+            AnnotateOutputDimOrderPass,
+            ToTosaMemoryFormatPass,
+        ],
     )
     pipeline.pop_stage(
         "run_method_and_compare_outputs"
diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py
index 1b59b186a2e..3d002eff25e 100644
--- a/backends/arm/test/runner_utils.py
+++ b/backends/arm/test/runner_utils.py
@@ -13,11 +13,19 @@
 
 from pathlib import Path
+from types import NoneType
 from typing import Any, cast, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
+from executorch.backends.arm.constants import (
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.backends.arm.ethosu import EthosUCompileSpec
 from executorch.backends.arm.test.conftest import is_option_enabled
@@ -157,6 +165,36 @@ def get_output_quantization_params(
     return quant_params
 
 
+def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
+    dtype = _torch_to_numpy_dtype_dict[tensor.dtype]
+    array = tensor.detach().numpy().astype(dtype)
+    dim_order = tensor.dim_order()
+    if dim_order == NHWC_ORDER:
+        a = array.transpose(NHWC_ORDER)
+        return a
+    elif dim_order == NNHWC_ORDER:
+        return array.transpose(NNHWC_ORDER)
+    else:
+        return array
+
+
+def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor:
+    output_tensor = get_first_fake_tensor(output_node)
+    shape = output_tensor.shape
+    dim_order = output_tensor.dim_order()
+    if dim_order == NHWC_ORDER:
+        shape_with_dim_order = [shape[i] for i in NHWC_ORDER]
+        tensor = torch.from_numpy(array).reshape(shape_with_dim_order)
+        return tensor.permute(NHWC_INVERSE_ORDER).to(memory_format=torch.channels_last)
+    elif dim_order == NNHWC_ORDER:
+        shape_with_dim_order = [shape[i] for i in NNHWC_ORDER]
+        tensor = torch.from_numpy(array).reshape(shape_with_dim_order)
+        return tensor.permute(NNHWC_INVERSE_ORDER).to(memory_format=torch.channels_last)
+    else:
+        tensor = torch.from_numpy(array).reshape(shape)
+        return tensor
+
+
 class TosaReferenceModelDispatch(TorchFunctionMode):
     """A context manager for executing call_delegate nodes using the reference model"""
 
@@ -168,7 +206,8 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
         tosa_buffer = lowered_backend_module.processed_bytes
         compile_spec = TosaCompileSpec.from_list(lowered_backend_module.compile_specs)
 
-        return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs)
+        output_node = lowered_backend_module.original_module.graph.output_node()
+        return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs, output_node)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
@@ -190,6 +229,22 @@ def __torch_function__(self, func, types, args=..., kwargs=None):
             )
 
         kwargs = kwargs or {}
+
+        # This is a hack: the Q/DQ ops do not handle channels-last input correctly. The simplest
+        # and most robust workaround is to run them in channels-first format and convert back to channels-last.
+        if func in (
+            torch.ops.quantized_decomposed.quantize_per_tensor.out,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.out,
+            torch.ops.quantized_decomposed.quantize_per_channel.out,
+            torch.ops.quantized_decomposed.dequantize_per_channel.out,
+        ):
+
+            input_dim_order = args[0].dim_order()
+            if input_dim_order in (NHWC_ORDER, NNHWC_ORDER):
+                args = [args[0].to(memory_format=torch.contiguous_format), *args[1:]]
+                res = func(*args, **kwargs)
+                return res.to(memory_format=torch.channels_last)
+
         return func(*args, **kwargs)
 
 
@@ -244,14 +299,13 @@ def get_output_from_file(
     output_np = []
     output_node = exported_program.graph_module.graph.output_node()
     for i, node in enumerate(output_node.args[0]):
-        output_shape = node.meta["val"].shape
         output_dtype = node.meta["val"].dtype
         tosa_ref_output = np.fromfile(
             os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
             _torch_to_numpy_dtype_dict[output_dtype],
         )
 
-        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
+        output_np.append(numpy_to_torch_tensor(tosa_ref_output, node))
     return tuple(output_np)
 
 
@@ -437,11 +491,14 @@ def prep_data_for_save(
     quant_param: Optional[QuantizationParams] = None,
 ):
     if isinstance(data, torch.Tensor):
-        data_np = np.array(data.detach(), order="C").astype(
-            _torch_to_numpy_dtype_dict[data.dtype]
-        )
+        data_np = torch_tensor_to_numpy(data)
+    elif isinstance(data, (int, float, bool, NoneType)):
+        return np.array(data)
     else:
-        data_np = np.array(data)
+        raise RuntimeError(
+            f"Input type {type(data)} could not be converted to a numpy array."
+        )
+
     if quant_param is not None:
         assert quant_param.node_name in input_name, (
             f"The quantization params name '{quant_param.node_name}' does not "
@@ -455,30 +512,8 @@ def prep_data_for_save(
             f"{quant_param.dtype}".replace("torch.", "")
         )  # Use string format of dtype to convert to numpy dtype
-    return data_np
-
-
-def save_npy(
-    path: str,
-    data,
-    input_name: str,
-    quant_param: Optional[QuantizationParams] = None,
-) -> str:
-    """Serializes and saves 'data' as a .npy file, possibly quantizing it before.
-
-    Parameters:
-    path: the directory where to save the data.
-    data: the data to save.
-    input_name: the name of the file, without file-ending.
-    quant_param: the parameters to use for quantization.
-    Returns:
-    the full file path of the output.
-    """
-    data_np = prep_data_for_save(data, input_name, quant_param)
-    file_path = os.path.join(path, input_name + ".npy")
-    np.save(file_path, data_np, allow_pickle=False)
-    return file_path
+    return data_np
 
 
 def save_bytes(
@@ -691,9 +726,12 @@ def run_tosa_graph(
     graph: Any,
     tosa_version: TosaSpecification,
     inputs: list[torch.Tensor],
+    output_node: Node,
 ) -> list[torch.Tensor]:
     """Runs the TOSA reference model with inputs and returns the result."""
-    inputs_np = [input.numpy() for input in inputs]
+
+    # Convert tensors to numpy arrays with the correct dim_order
+    inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs]
 
     if isinstance(tosa_version, Tosa_1_00):
         import tosa_reference_model as reference_model
@@ -715,7 +753,13 @@ def run_tosa_graph(
         status == reference_model.GraphStatus.TOSA_VALID
     ), "Non-valid TOSA given to reference model."
 
-    return [torch.from_numpy(output) for output in outputs_np]
+    # Convert output numpy arrays to tensors with the same dim_order as the output nodes
+    result = [
+        numpy_to_torch_tensor(output_array, node)
+        for output_array, node in zip(outputs_np, output_node.args[0])
+    ]
+
+    return result
 
 
 def get_target_board(compile_spec: ArmCompileSpec) -> str | None:
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index f240855cdf4..7634eed7a53 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -39,7 +39,7 @@ def define_arm_tests():
         "misc/test_bn_relu_folding_qat.py",
         "misc/test_custom_partition.py",
         "misc/test_debug_hook.py",
-        "misc/test_dim_order_guards.py",
+        "misc/test_dim_order.py",
         "misc/test_outputs_order.py",
     ]
 
diff --git a/docs/source/backends-arm-ethos-u.md b/docs/source/backends-arm-ethos-u.md
index 9b3d02b21c1..0a5d1dded74 100644
--- a/docs/source/backends-arm-ethos-u.md
+++ b/docs/source/backends-arm-ethos-u.md
@@ -273,5 +273,14 @@ non delegated Aten ops manually by setting `EXECUTORCH_SELECT_OPS_LIST`. To enab
 when building the executor_runner.
 
 
+## Memory formats
+
+For tensors of rank 4 and higher, two differing [memory format](https://pytorch.org/blog/tensor-memory-format-matters/) standards are in use.
+PyTorch defaults to the contiguous (channels-first, NCHW) memory format, whereas TOSA only supports the channels-last (NHWC) memory format.
+To bridge this, the backend inserts a transpose at the beginning of the graph if the incoming memory format is contiguous, and a corresponding
+transpose at the end if the outgoing memory format is contiguous. This means you can avoid transposing the data unnecessarily if the runtime
+integration and the full network are converted to use channels-last. A word of caution, however: changing the memory format has been observed
+to have side effects, such as unsupported ops being inserted into the graph, and it is not yet widely tested, so the feature should be considered experimental.
+
 ## See Also
 - [Arm Ethos-U Backend Tutorial](tutorial-arm.md)
\ No newline at end of file
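
Not part of the patch: a minimal usage sketch of the channels-last flow that the new "Memory formats" docs section describes. It uses only public torch APIs; the Conv2d module and shapes are illustrative stand-ins for a full lowered network, and memory-format propagation through convolutions is typical behavior rather than a guarantee.

import torch

# Illustrative stand-in for a network that would be delegated to the backend.
model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()

# A channels-last input already has dim_order (0, 2, 3, 1), i.e. NHWC, so the
# backend would not need to insert an input transpose for it.
x = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)
print(x.dim_order())  # (0, 2, 3, 1)

with torch.no_grad():
    y = model(x)

# Convolutions typically propagate the memory format, so the output stays
# channels-last and no output transpose would be needed either.
print(y.dim_order())  # (0, 2, 3, 1)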
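
Similarly, a small self-contained sketch (again, not part of the patch) of the layout round trip that the new torch_tensor_to_numpy/numpy_to_torch_tensor helpers in runner_utils.py rely on. The constants mirror backends/arm/constants.py, and the 4D channels-last case is assumed.

import numpy as np
import torch

NHWC_ORDER = (0, 2, 3, 1)          # dim_order of a 4D channels-last tensor
NHWC_INVERSE_ORDER = (0, 3, 1, 2)  # permutation that undoes NHWC_ORDER

t = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)
t = t.to(memory_format=torch.channels_last)
assert t.dim_order() == NHWC_ORDER

# .numpy() keeps the logical NCHW shape but the NHWC strides, so transposing by
# NHWC_ORDER exposes the physical layout as a C-contiguous NHWC array, which is
# what the TOSA reference model consumes.
a = t.numpy().transpose(NHWC_ORDER)
assert a.flags["C_CONTIGUOUS"]
assert a.shape == (1, 3, 4, 2)  # (N, H, W, C)

# Reversing: wrap the NHWC array and permute back to NCHW semantics. The
# result compares equal to the original channels-last tensor.
back = torch.from_numpy(a).permute(NHWC_INVERSE_ORDER)
assert torch.equal(back, t)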