Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,19 @@


def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
"""
Returns dictionary: node name -> external ids
"""Assign deterministic output IDs to nodes reachable from graph outputs.

Args:
ep_graph (Graph): FX graph produced by export preprocessing.

Returns:
dict[str, int]: Mapping from node name to external output index.

Assign id to an output node of the model so we can trace it.
"""
node2external_id = {}

def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
"""Walk producer graph from ``start_nodes`` and record external IDs."""
q = deque(start_nodes)
while q:
n = q.popleft()
Expand All @@ -71,7 +76,19 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):


def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
"""Reorder graph outputs to match ascending external IDs.

Args:
graph_module (GraphModule): Graph to reorder in place.
node_to_id_map (dict[str, int]): Mapping from node name to output index.

Returns:
GraphModule: Updated graph module with deterministic output ordering.

"""

def _external_id(n: Node, node_2_id, fallback: int) -> int:
"""Return the external ID for ``n`` or ``fallback`` when absent."""
return node_2_id.get(n.name, fallback)

out_node = graph_module.graph.output_node()
Expand All @@ -80,6 +97,7 @@ def _external_id(n: Node, node_2_id, fallback: int) -> int:

# sort nodes by the key that is id
def _sort_key(t: Node) -> int:
"""Key function that orders outputs by external ID or position."""
return _external_id(t, node_to_id_map, next(_counter))

orig_ord = tuple(sorted(out_list, key=_sort_key))
Expand All @@ -95,14 +113,14 @@ def _sort_key(t: Node) -> int:


def arm_get_first_delegation_tag(graph_module) -> str:
"""Return the first delegation tag from the FX graph.
"""Return the first delegation tag discovered in the FX graph.

Args:
graph_module: FX GraphModule produced by the Arm passes.
graph_module (GraphModule): Module produced by Arm partitioning.

Returns:
str: The first non-empty delegation tag found on any node, or an empty
string if none is present.
str: First non-empty delegation tag or an empty string when no tag is
recorded.

"""
for node in graph_module.graph.nodes:
Expand All @@ -125,6 +143,17 @@ class TOSABackend(BackendDetails):

@staticmethod
def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]):
"""Convert an exported program using the provided compile specs.

Args:
edge_program (ExportedProgram): Program generated by Torch export.
compile_specs (List[CompileSpec]): Raw compile specifications from
``executorch.apply_backend``.

Returns:
PreprocessResult: Result containing serialized TOSA bytes.

"""
return TOSABackend._preprocess(
edge_program, TosaCompileSpec.from_list(compile_specs)
)
Expand All @@ -142,7 +171,7 @@ def _preprocess( # noqa: C901

Args:
edge_program (ExportedProgram): Program to lower to TOSA.
compile_spec (List[CompileSpec]): Backend options. Recognized keys:
compile_spec (TosaCompileSpec): Backend options. Recognized keys:
- output_format: Must be "tosa".
- tosa_spec: Target TOSA version/capabilities.
- debug_artifact_path: Directory for debug outputs.
Expand Down Expand Up @@ -233,7 +262,20 @@ def _preprocess_module( # noqa: C901
debug_hook: DebugHook | None,
submodule_name: str | None = None,
):
"""Convert 'graph_module' to a tosa_graph"""
"""Convert an FX ``graph_module`` to TOSA serializer calls.

Args:
graph_module (GraphModule): Module to lower recursively.
edge_program (ExportedProgram): Original exported program.
compile_spec (TosaCompileSpec): Backend options with TOSA settings.
tosa_graph (ts.TosaSerializer): Serializer receiving operators.
debug_hook (DebugHook | None): Optional debug instrumentation.
submodule_name (str | None): Name used when visiting nested blocks.

Raises:
RuntimeError: If an FX node with an unsupported op kind is found.

"""
tosa_spec = compile_spec.tosa_spec
node_to_id_map = _annotate_external_ids(graph_module.graph)
artifact_path = compile_spec.get_intermediate_path()
Expand Down Expand Up @@ -305,24 +347,17 @@ def _preprocess_module( # noqa: C901
def filter_tosa_compile_specs(
compile_spec: ArmCompileSpec,
) -> TosaCompileSpec:
"""
Filter out the CompileSpec elements relevant for the TOSA backend.
This is needed to compose a backend targetting hardware IP with the
TOSABackend, since we first want to use the TOSABackend to generate
the TOSA flatbuffer representation as an intermediate step. The TOSA
flatbuffer can then be consumed by the backend targetting specific
hardware.
"""Extract the TOSA-specific settings from a composite compile spec.

Args:
compile_spec (ArmCompileSpec): Compile specification that may
include both TOSA and hardware-specific options.

Returns:
TosaCompileSpec: TOSA-only specification ready for
``TOSABackend.preprocess``.
``TOSABackend.preprocess``.

"""

return (
TosaCompileSpec(compile_spec.tosa_spec)
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())
Expand Down
45 changes: 29 additions & 16 deletions backends/arm/tosa/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Provide PyTorch-to-TOSA mapping helpers.

Use these utilities to translate PyTorch dtypes and FX node metadata into
the TOSA serializer types and shapes used during initial compilation.
Use these utilities to translate PyTorch dtypes and FX node metadata into the
TOSA serializer types and shapes used during initial compilation.

"""

Expand Down Expand Up @@ -34,18 +33,27 @@


class TosaSpecialDtype(Enum):
"""
Special TOSA data types that are not natively supported in PyTorch, to be
used in specific scenarios as a value in the key from meta_key().
"""
"""Special TOSA dtypes not natively expressed in PyTorch."""

INT48 = ts.DType.INT48

def get_tosa_dtype(self) -> ts.DType:
"""Return the underlying ``ts.DType`` enumerant.

Returns:
ts.DType: Serializer dtype associated with the enum entry.

"""
return self.value

@staticmethod
def meta_key() -> str:
"""Return the FX ``meta`` key that stores special dtypes.

Returns:
str: Metadata key used to encode :class:`TosaSpecialDtype`.

"""
return "tosa_special_dtype"


Expand All @@ -57,7 +65,7 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
tosa_spec (TosaSpecification): Active spec (reserved for future checks).

Returns:
Any: Matching ``ts.DType`` enum value.
ts.DType: Matching serializer dtype.

Raises:
ValueError: If the dtype is unsupported or unknown.
Expand Down Expand Up @@ -95,8 +103,8 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping.

Returns:
tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``,
``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``.
tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing
tensor dtype, shape, and dimension order.

Raises:
ValueError: If ``meta['val']`` is not a ``FakeTensor``.
Expand Down Expand Up @@ -130,12 +138,14 @@ class TosaArg:
consistent structure suitable for TOSA serialization.

Attributes:
name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise.
name (str): Node name when argument is a ``torch.fx.Node``; empty
otherwise.
dtype (ts.DType | None): Inferred dtype when available.
shape (tuple[int, ...] | None): Inferred shape when available.
dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``.
dim_order (tuple[int, ...] | None): Dimension order, defaulting to
``range(len(shape))``.
special (list | None): Captured list when the argument is a sequence.
number (float | int | None): Captured numeric value when given.
number (float | int | None): Captured numeric value when provided.
tosa_spec (TosaSpecification): Active specification used for mapping.
multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise.
"""
Expand Down Expand Up @@ -174,7 +184,7 @@ def __process_list(self, argument):
"""Capture a sequence argument as ``special``.

Args:
argument (Sequence): Sequence to store.
argument (Sequence[Any]): Sequence to store.

"""
self.special: list = list(argument)
Expand All @@ -194,10 +204,13 @@ def __init__(
"""Initialize the argument wrapper and populate fields.

Args:
argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``.
tosa_spec (Optional[TosaSpecification]): Active specification; required.
argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``,
``float``, ``torch.dtype``, or ``None``.
tosa_spec (Optional[TosaSpecification]): Active specification;
required for metadata extraction.

Raises:
ValueError: If ``tosa_spec`` is missing or has the wrong type.
RuntimeError: If ``argument`` is of an unsupported type.

"""
Expand Down
8 changes: 6 additions & 2 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Provide a partitioner for delegating subgraphs to the TOSA backend.

Implement logic to identify and tag regions of an ``ExportedProgram`` that can
Expand All @@ -11,6 +10,7 @@
- Partition graphs based on operator support and additional checks.
- Prune trivial no-op partitions that would lower to empty TOSA graphs.
- Tag constant data and report reasons for rejected nodes.

"""

import logging
Expand Down Expand Up @@ -142,6 +142,7 @@ def reject_partition(
partition (object): Proposed partition object from the
capability partitioner.
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.

"""
for node in partition.nodes:
if "delegation_tag" in node.meta:
Expand All @@ -158,6 +159,7 @@ class TOSAPartitioner(Partitioner):
Construct this partitioner for compile specs targeting TOSA. The partition
algorithm uses capability checks and optional additional operator-support
rules to tag nodes with a delegation tag per subgraph.

"""

def __init__(
Expand Down Expand Up @@ -191,14 +193,16 @@ def _tag_module( # noqa
reporter: WhyNoPartitionReporter,
tag_iterator: count | None = None,
) -> set[str]:
"""Tag nodes in a module, possibly a submodule, from the containing program.
"""Tag nodes in a module or submodule from the containing program.

Args:
module: A GraphModule from `containing_program` to tag nodes in.
containing_program: The ExportedProgram that contains the module.
reporter: A reporter to report why nodes were rejected.

Returns:
A set of strings with the partition tags.

"""
tags: set[str] = set()
if tag_iterator is None:
Expand Down
Loading