Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions backends/arm/test/misc/test_outputs_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-unsafe
import tempfile
from pathlib import Path

import pytest
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.exir import to_edge_transform_and_lower
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from tosa import TosaGraph


class Network(nn.Module):
def __init__(self, batch_norm=False):
super().__init__()
self.conv2d_0 = nn.Sequential(
nn.Conv2d(1, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.conv2d_1 = nn.Sequential(
nn.Conv2d(8, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.conv2d_2 = nn.Sequential(
nn.Conv2d(8, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.out_0 = nn.Sequential(nn.Conv2d(8, 1, 3, padding=1, bias=False), nn.ReLU())
self.out_1 = nn.Sequential(nn.Conv2d(8, 2, 3, padding=1, bias=False), nn.ReLU())
self.out_2 = nn.Sequential(nn.Conv2d(8, 3, 3, padding=1, bias=False), nn.ReLU())

def forward(self, x):
x = self.conv2d_0(x)
x = self.conv2d_1(x)
x = self.conv2d_2(x)
out0 = self.out_0(x)
out1 = self.out_1(x)
out2 = self.out_2(x)
return out0, out1, out2


def _read_tosa_outputs(tosa_path: Path):
# Find output tensor names in order and return shapes
buf = tosa_path.read_bytes()
buf_arr = bytearray(buf)
graph = TosaGraph.TosaGraph.GetRootAsTosaGraph(buf_arr, 0)
region = graph.Regions(0)
block = region.Blocks(0)
# Build a dict name - tensor‑shape
tensors = {}
for i in range(block.TensorsLength()):
t = block.Tensors(i)
name = t.Name().decode()
# NHWC
shape = [t.Shape(j) for j in range(t.ShapeLength())]
tensors[name] = shape
shapes = []
for i in range(block.OutputsLength()):
out_name = block.Outputs(i).decode()
shapes.append(tensors[out_name])
return shapes


@pytest.mark.parametrize("batch_size", [1, 4])
def test_network_output_order_and_restore(tmp_path, batch_size):
model = Network(batch_norm=True).eval()
# Prepare spec
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build()
# Setup quantizer
quantizer = TOSAQuantizer(compile_spec)
quantizer.set_global(
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
)
# Trace the model
dummy = torch.randn(batch_size, 1, 28, 28)
fx_mod = torch.export.export_for_training(model, (dummy,)).module()
model = prepare_pt2e(fx_mod, quantizer)
model(dummy)
model = convert_pt2e(model)
# Export to aten dialect
aten_gm = torch.export.export(model, args=(dummy,), strict=True)
with tempfile.TemporaryDirectory() as tmpdir:
art_dir = Path(tmpdir)
part = TOSAPartitioner(
ArmCompileSpecBuilder()
.tosa_compile_spec(spec)
.dump_intermediate_artifacts_to(str(art_dir))
.build()
)
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
# Expect exactly one .tosa file in the artefact dir
tosa_files = list(art_dir.glob("*.tosa"))
assert (
len(tosa_files) == 1
), f"Expected 1 .tosa artefact, found {len(tosa_files)} in {art_dir}"
out_shapes = _read_tosa_outputs(tosa_files[0])
# We use shape that is unique to output to check
# that we preserve output order
channel_dims = [s[1] for s in reversed(out_shapes)]
assert channel_dims == [1, 2, 3], (
"Outputs in .tosa do not keep author order: "
f"expected [1, 2, 3], got {channel_dims}"
)
58 changes: 56 additions & 2 deletions backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# JIT compiler flows.
#
import logging
from typing import cast, final, List
from collections import deque
from itertools import count
from typing import cast, Dict, final, List, Set

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
Expand All @@ -26,12 +28,38 @@
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
from torch.fx import Node
from torch.fx import Graph, Node

# TOSA backend debug functionality
logger = logging.getLogger(__name__)


def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
"""
Returns dictionary: node name -> external ids

Assign id to an output node of the model so we can trace it.
"""
node2external_id = {}

def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
q = deque(start_nodes)
while q:
n = q.popleft()
if n in seen:
continue
seen.add(n)
node2external_id[n.name] = idx
# Walk backwards so we touch every producer
q.extend(n.all_input_nodes)

out = next(n for n in ep_graph.nodes if n.op == "output")
seen: Set[Node] = set()
for idx, val in enumerate(out.args[0]):
bfs_mark([val], idx, seen)
return node2external_id


def arm_get_first_delegation_tag(graph_module) -> str:
"""Get the first delegation tag from the graph_module or return empty string."""
for node in graph_module.graph.nodes:
Expand Down Expand Up @@ -75,6 +103,9 @@ def preprocess( # noqa: C901
if output_format != "tosa":
raise ValueError(f'Invalid output format {output_format}, must be "tosa"')

# Assign to every node external id
node_2_id = _annotate_external_ids(edge_program.graph)

tosa_spec = get_tosa_spec(compile_spec)
if tosa_spec is None:
raise ValueError(
Expand Down Expand Up @@ -107,6 +138,29 @@ def preprocess( # noqa: C901
from executorch.backends.arm.operators.node_visitor import get_node_visitors

node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)

# Re-shuffle output nodes to preserve author's order
def _external_id(n: Node, node_2_id, fallback: int) -> int:
return node_2_id.get(n.name, fallback)

out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
_counter = count()

# sort nodes by the key that is id
def _sort_key(t: Node) -> int:
return _external_id(t, node_2_id, next(_counter))

orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))

current_order = tuple(out_node.args[0])
if orig_ord != current_order:
replacement = (
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
)
out_node.args = (replacement,)
graph_module.graph.lint()
graph_module.recompile()

input_count = 0
for node in graph_module.graph.nodes:
node = cast(Node, node)
Expand Down
Loading