Merged
111 changes: 49 additions & 62 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -9,13 +9,23 @@
import logging

import torch
from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
AnnotateDecomposedMatmulPass,
)
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
get_output_dim_orders,
is_param_node,
)
from executorch.backends.arm.constants import (
HWCM_ORDER,
NCHW_ORDER,
NHWC_INVERSE_ORDER,
NHWC_ORDER,
NNCHW_ORDER,
NNHWC_INVERSE_ORDER,
NNHWC_ORDER,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
"""

NHWC_order = (0, 2, 3, 1)
NHWC_inverse_order = (0, 3, 1, 2)
HWCM_order = (2, 3, 0, 1)
NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)

def __init__(self, exported_program: ExportedProgram) -> None:
self.exported_program = exported_program
super().__init__()
@@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module):
args=(
input_node,
list(
ToTosaMemoryFormatPass.NNHWC_inverse_order
NNHWC_INVERSE_ORDER
if len(get_first_fake_tensor(input_node).size()) == 5
else ToTosaMemoryFormatPass.NHWC_inverse_order
else NHWC_INVERSE_ORDER
),
),
from_node=node,
@@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module):
args=(
node,
list(
ToTosaMemoryFormatPass.NNHWC_order
NNHWC_ORDER
if len(get_first_fake_tensor(node).size()) == 5
else ToTosaMemoryFormatPass.NHWC_order
else NHWC_ORDER
),
),
from_node=node,
)

permute_node.meta["tosa_dim_order"] = (
ToTosaMemoryFormatPass.NNHWC_order
NNHWC_ORDER
if len(get_first_fake_tensor(node).size()) == 5
else ToTosaMemoryFormatPass.NHWC_order
else NHWC_ORDER
)
node.meta["tosa_dim_order"] = tuple(
range(len(get_first_fake_tensor(node).size()))
@@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
# call_function and placeholder allowed due to
# index.Tensor being able to come in as both
if node.op not in ["call_function", "placeholder", "output"]:
if node.op != "call_function":
continue

# Transpose views
@@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
graph_module,
)

# Transpose inputs
elif _is_input(node, self.exported_program):
input_shape = get_first_fake_tensor(node).size()
if len(input_shape) in (4, 5):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
output_node = graph_module.graph.output_node()

# Transpose outputs
elif node.op == "output":
output_shape = get_first_fake_tensor(node).size()
# Transpose inputs if they are in (N)NCHW format
inputs = [
n for n in graph_module.graph.nodes if _is_input(n, self.exported_program)
]
for input_node in inputs:
input_dim_order = get_first_fake_tensor(input_node).dim_order()
if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
self.insert_output_transpose(input_node, graph_module)

# Transpose outputs if they are in (N)NCHW format
outputs = output_node.args[0]
output_dim_orders = output_node.meta.get("original_dim_orders")
if output_dim_orders is None:
raise RuntimeError(
f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
)

if len(output_shape) in (4, 5):
for input_node in node.all_input_nodes:
ToTosaMemoryFormatPass.insert_input_transpose(
node, input_node, graph_module
)
for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type]
if output_dim_order in (
NCHW_ORDER,
NNCHW_ORDER,
):
self.insert_input_transpose(
output_node, output_node_input, graph_module
)

def remove_dim_order_kwargs(
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
@@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule):
node_data = get_first_fake_tensor(node).data

self.remove_dim_order_kwargs(graph_module, node)
# Inputs and outputs are always in (N)NCHW format
# Inputs and outputs may vary in dim_order
if _is_input(node, self.exported_program) or node.op == "output":
dim_order = tuple(range(node_data.dim()))
dim_order = node_data.dim_order()
elif node_data.dim() == 4:
dim_order = self.NHWC_order
dim_order = NHWC_ORDER
if self.is_weight_node_for_depthwise_conv2d(node):
# The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
dim_order = HWCM_ORDER
elif node_data.dim() == 5:
dim_order = self.NNHWC_order
dim_order = NNHWC_ORDER
else:
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]

@@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

def requires(self, graph_module) -> None:
"""
This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
"""

dim_orders = get_output_dim_orders(graph_module)
original_dim_orders = graph_module.graph.output_node().meta.get(
"original_dim_orders"
)
output_node = graph_module.graph.output_node()

if original_dim_orders is None:
raise RuntimeError(
f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
)

if len(dim_orders) != len(original_dim_orders):
raise RuntimeError(
f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
)

for node, dim_order, original_dim_order in zip(
output_node.args[0], dim_orders, original_dim_orders
):
if dim_order != original_dim_order:
raise RuntimeError(
f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
)
12 changes: 12 additions & 0 deletions backends/arm/constants.py
@@ -29,3 +29,15 @@
DEQUANT_PER_TENSOR_OP_T,
)
PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP)

NHWC_ORDER: Final = (0, 2, 3, 1)
NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)

NCHW_ORDER: Final = (0, 1, 2, 3)
NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)

HWCM_ORDER: Final = (2, 3, 0, 1)
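As a sanity check (illustrative, not part of this diff), the 4D orders above line up with what PyTorch itself reports; this assumes a recent PyTorch release where Tensor.dim_order() is available.

import torch

x = torch.empty(1, 2, 3, 4)                   # contiguous, i.e. NCHW layout
assert x.dim_order() == (0, 1, 2, 3)          # matches NCHW_ORDER

y = x.to(memory_format=torch.channels_last)   # channels-last layout
assert y.dim_order() == (0, 2, 3, 1)          # matches NHWC_ORDER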
@@ -89,6 +89,7 @@ def _merge_supported_types(
torch.int32,
torch.bfloat16,
torch.float16,
torch.float32,
],
}
ALL_SUPPORTED_TYPES = _merge_supported_types(
7 changes: 0 additions & 7 deletions backends/arm/process_node.py
@@ -70,13 +70,6 @@ def process_inputs(
tosa_spec: TosaSpecification,
):
"""Serialize an input node"""
# inputs need to be in default dim_order (contiguous memory format)
meta = node.meta["val"]
if meta.dim_order() != tuple(range(meta.dim())):
raise RuntimeError(
f"Arm backend only supports contiguous memory format for inputs. "
f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
)
try:
tosa_arg = TosaArg(node, tosa_spec)
except ValueError as e:
9 changes: 0 additions & 9 deletions backends/arm/runtime/EthosUBackend.cpp
@@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
handles.inputs->io[i].elem_size);
return Error::InvalidProgram;
}
supported = executorch::runtime::is_contiguous_dim_order(
tensor_in.dim_order().data(), tensor_in.dim());
if (!supported) {
ET_LOG(
Error,
"Input %d expected contiguous dim_order, but got non-contiguous dim_order",
i);
return Error::InvalidProgram;
}

// Select a compatible copy routine including checking for input layouts
// which require permutation.
123 changes: 123 additions & 0 deletions backends/arm/test/misc/test_dim_order.py
@@ -0,0 +1,123 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
)


input_t1 = Tuple[torch.Tensor] # Input x


class ChannelsLastInput(torch.nn.Module):
"""
Test a complex case with (channels last, channels first) input,
and (channels first, channels last) output.
"""

inputs: input_t1 = (
torch.arange(1, 25, dtype=torch.float32)
.reshape((1, 2, 3, 4))
.to(memory_format=torch.channels_last),
torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)),
)

def forward(self, x, y):
x = x * x
return y, x


class ChannelsFirstOutput(torch.nn.Module):
"""
Test converting to channels_first inside the delegate.
"""

inputs: input_t1 = (
torch.arange(1, 25, dtype=torch.float32)
.reshape((1, 2, 3, 4))
.to(memory_format=torch.channels_last),
)

def forward(self, x):
x = x.clone(memory_format=torch.contiguous_format) * x
return x


class ChannelsLastOutput(torch.nn.Module):
"""
Test changing of dim_order inside the delegate.
"""

inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),)

def forward(self, x):
x = x * x
x = x.clone(memory_format=torch.channels_last)
return x


class ChannelsLastInsidePartition(torch.nn.Module):
"""
Test dim_order changes inside the partition, but no dim_order changes at input/output.
"""

inputs: input_t1 = (torch.randn((1, 2, 3, 3)),)

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3))

def forward(self, x):
return (
self.conv2d(x.clone(memory_format=torch.channels_last)).clone(
memory_format=torch.contiguous_format
)
* 1
)


test_modules = {
"channels_last_input": ChannelsLastInput,
"channels_first_output": ChannelsFirstOutput,
"channels_last_output": ChannelsLastOutput,
"channels_last_inside_partition": ChannelsLastInsidePartition,
}


@common.parametrize("module", test_modules)
def test_dim_order_tosa_FP(module):
pipeline = TosaPipelineFP[input_t1](module(), module.inputs, [])
pipeline.run()


@common.parametrize("module", test_modules)
def test_dim_order_tosa_INT(module):
pipeline = TosaPipelineINT[input_t1](
module(), module.inputs, [], symmetric_io_quantization=True
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("module", test_modules)
def test_dim_order_u55_INT(module):
pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, [])
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("module", test_modules)
def test_dim_order_u85_INT(module):
pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, [])
pipeline.run()
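The modules above rely on clone(memory_format=...) to change the layout inside the delegated graph. A minimal sketch (not part of this diff) of what that call does to a tensor's dim order; the shape is illustrative.

import torch

x = torch.randn(1, 2, 2, 2)                      # contiguous NCHW input
y = x.clone(memory_format=torch.channels_last)   # same values, channels-last layout

assert x.dim_order() == (0, 1, 2, 3)
assert y.dim_order() == (0, 2, 3, 1)
assert torch.equal(x, y)                         # only the memory layout differs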