Merged
101 changes: 64 additions & 37 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -26,6 +26,9 @@
NNCHW_ORDER,
NNHWC_INVERSE_ORDER,
NNHWC_ORDER,
NNNCHW_ORDER,
NNNHWC_INVERSE_ORDER,
NNNHWC_ORDER,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):

_passes_required_after: Set[Type[ExportPass]] = set()

NHWC_order = (0, 2, 3, 1)
NHWC_inverse_order = (0, 3, 1, 2)
HWCM_order = (2, 3, 0, 1)
NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)

def __init__(self, exported_program: ExportedProgram) -> None:
self.exported_program = exported_program
super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
@staticmethod
def memory_format_differs(shape):
"""Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
if len(shape) >= 5:
if len(shape) >= 6:
C = shape[3]
H = shape[4]
W = shape[5]
elif len(shape) == 5:
C = shape[2]
H = shape[3]
W = shape[4]
@@ -112,25 +113,26 @@ def memory_format_differs(shape):

@staticmethod
def is_channel_reshape(input_shape, output_shape):
"""Returns true if the reshape changes the channel dimension"""
if not (
(len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
or (len(input_shape) == 4 and len(output_shape) == 5)
or (len(input_shape) == 5 and len(output_shape) == 4)
):
"""Returns true if reshape changes the channel dimension or batch product dimension(s)"""

valid_ranks = {4, 5, 6}

if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
return False

C_old = input_shape[-3]
C_new = output_shape[-3]

N_new = (
output_shape[0]
if len(output_shape) == 4
else output_shape[0] * output_shape[1]
)
N_old = (
input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
)
def get_batch_prod_dim(shape):
product = 1

for dim in shape[:-3]:
product = product * dim

return product

N_old = get_batch_prod_dim(input_shape)
N_new = get_batch_prod_dim(output_shape)

return (N_old != N_new) or (C_old != C_new)

@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
node.replace_input_with(input_node, pre_permute_node)
return

if len(get_first_fake_tensor(input_node).size()) == 6:
mem_format = NNNHWC_INVERSE_ORDER
elif len(get_first_fake_tensor(input_node).size()) == 5:
mem_format = NNHWC_INVERSE_ORDER
else:
mem_format = NHWC_INVERSE_ORDER
# Guard: mem_format must be a true permutation for the current rank
rank = len(get_first_fake_tensor(input_node).size())
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"

with graph_module.graph.inserting_before(node):
permute_node = create_node(
graph_module.graph,
exir_ops.backend.tosa.TRANSPOSE.default,
args=(
input_node,
list(
NNHWC_INVERSE_ORDER
if len(get_first_fake_tensor(input_node).size()) == 5
else NHWC_INVERSE_ORDER
),
list(mem_format),
),
from_node=node,
)
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):

@staticmethod
def insert_output_transpose(node, graph_module):

if len(get_first_fake_tensor(node).size()) == 6:
mem_format = NNNHWC_ORDER
elif len(get_first_fake_tensor(node).size()) == 5:
mem_format = NNHWC_ORDER
else:
mem_format = NHWC_ORDER
# Guard: mem_format must be a true permutation for the current rank
rank = len(get_first_fake_tensor(node).size())
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"

with graph_module.graph.inserting_after(node):
permute_node = create_node(
graph_module.graph,
exir_ops.backend.tosa.TRANSPOSE.default,
args=(
node,
list(
NNHWC_ORDER
if len(get_first_fake_tensor(node).size()) == 5
else NHWC_ORDER
),
list(mem_format),
),
from_node=node,
)

permute_node.meta["tosa_dim_order"] = (
NNHWC_ORDER
if len(get_first_fake_tensor(node).size()) == 5
else NHWC_ORDER
)
rank = len(get_first_fake_tensor(node).size())
if rank == 6:
permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
elif rank == 5:
permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
else:
permute_node.meta["tosa_dim_order"] = NHWC_ORDER

node.meta["tosa_dim_order"] = tuple(
range(len(get_first_fake_tensor(node).size()))
)
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
]
for input_node in inputs:
input_dim_order = get_first_fake_tensor(input_node).dim_order()
if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
self.insert_output_transpose(input_node, graph_module)

# Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
if output_dim_order in (
NCHW_ORDER,
NNCHW_ORDER,
NNNCHW_ORDER,
):
self.insert_input_transpose(
output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
dim_order = HWCM_ORDER
elif node_data.dim() == 5:
dim_order = NNHWC_ORDER
elif node_data.dim() == 6:
dim_order = NNNHWC_ORDER
else:
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]

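The pass now picks a channels-last dim order from the tensor rank (4, 5, or 6) and asserts that the chosen order is a valid permutation. A minimal standalone sketch of that selection, assuming the order constants from backends/arm/constants.py; pick_channels_last_order is a hypothetical helper used only for illustration:

```python
# Constants mirror backends/arm/constants.py; pick_channels_last_order is a
# hypothetical helper, not part of the pass.
NHWC_ORDER = (0, 2, 3, 1)
NNHWC_ORDER = (0, 1, 3, 4, 2)
NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)

def pick_channels_last_order(rank: int) -> tuple:
    if rank == 6:
        order = NNNHWC_ORDER
    elif rank == 5:
        order = NNHWC_ORDER
    else:
        order = NHWC_ORDER
    # Same guard as in the pass: the order must be a permutation of range(rank).
    assert sorted(order) == list(range(rank)), f"bad perm {order} for rank {rank}"
    return order

print(pick_channels_last_order(6))  # (0, 1, 2, 4, 5, 3)
print(pick_channels_last_order(4))  # (0, 2, 3, 1)
```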
7 changes: 5 additions & 2 deletions backends/arm/constants.py
@@ -34,10 +34,13 @@
NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3)
NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4)

NCHW_ORDER: Final = (0, 1, 2, 3)
NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)

HWCM_ORDER: Final = (2, 3, 0, 1)

MAX_RANK: Final = 6
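A quick sanity check (not part of the patch) that the new rank-6 orders are consistent: composing NNNHWC_ORDER with NNNHWC_INVERSE_ORDER in either direction yields the identity permutation, so a transpose inserted on the way into the delegate is undone by the matching transpose on the way out.

```python
# Sanity check: each channels-last order and its inverse compose to the identity,
# i.e. x.permute(ORDER).permute(INVERSE_ORDER) restores the original dim order.
NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)
NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)

def compose(p, q):
    # x.permute(p).permute(q) is equivalent to x.permute(compose(p, q))
    return tuple(p[i] for i in q)

assert compose(NNNHWC_ORDER, NNNHWC_INVERSE_ORDER) == (0, 1, 2, 3, 4, 5)
assert compose(NNNHWC_INVERSE_ORDER, NNNHWC_ORDER) == (0, 1, 2, 3, 4, 5)
```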
4 changes: 2 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -19,7 +19,7 @@
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55DtypeSupport,
EthosU55NotSupported,
@@ -126,7 +126,7 @@ def tosa_support_factory(
negative_checks: list[OperatorSupportBase] = [
CheckInt64InputsAndOutputs(exported_program, reporter),
CheckFloat64Inputs(exported_program, reporter),
RankCheck(reporter, max_rank=5),
RankCheck(reporter, max_rank=MAX_RANK),
*[
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
for check in (additional_checks if additional_checks else [])
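With MAX_RANK = 6, RankCheck now admits rank-6 tensors into the TOSA partition and keeps rejecting anything higher. A rough sketch of the intended gate, with a hypothetical is_rank_supported helper standing in for the real check:

```python
# Hypothetical sketch of the rank gate; is_rank_supported and the shapes are
# illustrative only, not the real RankCheck API.
MAX_RANK = 6

def is_rank_supported(shapes, max_rank: int = MAX_RANK) -> bool:
    # All input/output tensor shapes must be at most max_rank dimensional.
    return all(len(shape) <= max_rank for shape in shapes)

print(is_rank_supported([(1, 2, 3, 4, 5, 6)]))     # True: rank 6 now delegates
print(is_rank_supported([(1, 2, 3, 4, 5, 6, 7)]))  # False: rank 7 still falls back
```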
2 changes: 2 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -366,6 +366,8 @@ def _match_pattern(
torch.ops.aten.dropout_.default,
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.alias_copy.default,
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.pixel_unshuffle.default,
]


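pixel_shuffle and pixel_unshuffle are the main motivation for the rank-6 support above: their standard decomposition views the NCHW input as a rank-6 tensor, permutes it, and reshapes back. A reference sketch of that decomposition (a generic illustration, not necessarily the exact graph ExecuTorch produces):

```python
import torch
import torch.nn.functional as F

def pixel_shuffle_reference(x: torch.Tensor, r: int) -> torch.Tensor:
    # (N, C*r*r, H, W) -> (N, C, H*r, W*r) via a rank-6 intermediate, which is
    # why the memory-format pass and RankCheck now handle tensors up to rank 6.
    n, c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    x = x.view(n, c, r, r, h, w)       # rank-6 view
    x = x.permute(0, 1, 4, 2, 5, 3)    # reorder to (N, C, H, r, W, r)
    return x.reshape(n, c, h * r, w * r)

x = torch.randn(1, 8, 4, 4)
assert torch.equal(pixel_shuffle_reference(x, 2), F.pixel_shuffle(x, 2))
```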
2 changes: 2 additions & 0 deletions backends/arm/scripts/parse_test_names.py
@@ -26,6 +26,8 @@
"_native_batch_norm_legit_no_training.default",
"_native_batch_norm_legit.no_stats",
"alias_copy.default",
"pixel_shuffle.default",
"pixel_unshuffle.default",
]
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

@@ -24,16 +24,12 @@ class TestSD3Transformer2DModel(unittest.TestCase):

# Adjust nbr below as we increase op support.
ops_after_partitioner_FP = {
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 1,
}

ops_after_partitioner_INT = {
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 2,
}