8 changes: 6 additions & 2 deletions .lintrunner.toml
@@ -506,8 +506,12 @@ command = [

[[linter]]
code = 'DOCFORMATTER'
-include_patterns = []
-exclude_patterns = ['**']
+include_patterns = [
+    'backends/arm/vgf/**/*.py',
+    'backends/arm/tosa/**/*.py',
+    'backends/arm/ethosu/**/*.py',
+]
+exclude_patterns = ['third-party/**', '**/third-party/**']
command = [
'python','-m','lintrunner_adapters','run','docformatter_linter','--config=pyproject.toml','--','@{{PATHSFILE}}'
]
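With these patterns enabled, docformatter now checks the listed backends/arm paths. As a rough sketch (the exact options live in pyproject.toml, which is not part of this diff), the docstring shape it enforces matches the edits in the rest of this PR:

    def example(x: int) -> int:
        """Summarize the behaviour on the same line as the opening quotes.

        Longer description wrapped consistently, with a blank line left
        before the closing quotes, as in the docstring changes below.

        """
        return x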
1 change: 1 addition & 0 deletions backends/arm/ethosu/compile_spec.py
@@ -27,6 +27,7 @@ class EthosUCompileSpec(ArmCompileSpec):
Vela.
config_ini (str | None): Path to a Vela .ini configuration file.
Defaults to ``"Arm/vela.ini"``.
+
"""

_TARGET_KEY = "target"
6 changes: 3 additions & 3 deletions backends/arm/ethosu/partitioner.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -14,12 +14,12 @@

@final
class EthosUPartitioner(TOSAPartitioner):
-"""
-Partitions subgraphs supported by the Arm Ethos-U backend.
+"""Partitions subgraphs supported by the Arm Ethos-U backend.

Args:
compile_spec: List of CompileSpec objects for Ethos-U backend.
additional_checks: Optional sequence of additional operator support checks.
+
"""

def __init__(
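For orientation, a hedged usage sketch of the two Ethos-U classes touched above; the import paths are inferred from the file layout, and the target string and constructor arguments are assumptions rather than values taken from this diff:

    from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
    from executorch.backends.arm.ethosu.partitioner import EthosUPartitioner

    # Assumed target name; per the docstring above, config_ini defaults to "Arm/vela.ini".
    compile_spec = EthosUCompileSpec("ethos-u55-128")
    partitioner = EthosUPartitioner(compile_spec)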
9 changes: 6 additions & 3 deletions backends/arm/tosa/backend.py
@@ -55,6 +55,7 @@ def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:

Returns:
dict[str, int]: Mapping from *leaf output node name* to external output index.
+
"""
node2external_id = {}

@@ -116,8 +117,8 @@ def _sort_key(t: Node) -> int:


def _get_matching_fake_tensor(node: Node):
-"""Return a fake tensor with the same properties as node,
-but with .dim_order() == node.meta["tosa_dim_order"]
+"""Return a fake tensor with the same properties as node, but with
+.dim_order() == node.meta["tosa_dim_order"]
"""
fake_tensor = node.meta["val"]
desired_dim_order = node.meta["tosa_dim_order"]
@@ -267,14 +268,16 @@ def _preprocess(  # noqa: C901

@staticmethod
def _regularize_submodule(submodule: GraphModule, submodule_node: Node):
-"""To make a submodule fit into the normal flow of a graph_module, we need to do some regularizations.
+"""To make a submodule fit into the normal flow of a graph_module, we
+need to do some regularizations.

- Buffers created before passes are treated as input to the submodule. Buffers created during passes
are treated as "normal" buffers, i.e. gathered from the state_dict.
To make it easy to tell them apart, mark all placeholders with "is_input = True" before running passes.
- Make sure output node args[0] is always iterable.
- Match the dim_order() of the input tensors with the dim orders of the submodule_node inputs.
- Match the dim_order() of the out tensors with the dim orders of the submodule_node outputs.
+
"""
submodule_inputs: list[Node] = []
for node in submodule.graph.nodes:
1 change: 1 addition & 0 deletions backends/arm/tosa/compile_spec.py
@@ -16,6 +16,7 @@ class TosaCompileSpec(ArmCompileSpec):
Args:
tosa_spec (TosaSpecification | str): Target spec object or version
string supported by ``TosaSpecification.create_from_string``.
+
"""

def __init__(self, tosa_spec: TosaSpecification | str):
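A minimal usage sketch for this class; the import path is inferred from the file layout and the spec string is only an example of the format accepted by TosaSpecification.create_from_string:

    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec

    # The version string is parsed via TosaSpecification.create_from_string.
    compile_spec = TosaCompileSpec("TOSA-1.0+INT")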
11 changes: 6 additions & 5 deletions backends/arm/tosa/dialect/ops/gather.py
@@ -19,11 +19,12 @@
TosaSpecification.all_versions_and_profiles(),
)
def GATHER(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
-"""
-Expected signature (per TOSA):
-values: [N, K, C] (rank 3)
-indices: [N, W] (rank 2, int32)
-output: [N, W, C] (rank 3)
+"""Expected signature (per TOSA):
+
+values: [N, K, C] (rank 3)
+indices: [N, W] (rank 2, int32)
+output: [N, W, C] (rank 3)
+
"""
tosa_spec = get_context_spec()

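The shapes in that docstring correspond to the following reference semantics; this sketch is an independent illustration written with torch.gather, not code from the backend:

    import torch

    def tosa_gather_reference(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # values: [N, K, C], indices: [N, W]  ->  output: [N, W, C]
        # output[n, w, c] = values[n, indices[n, w], c]
        n, k, c = values.shape
        _, w = indices.shape
        index = indices.long().unsqueeze(-1).expand(n, w, c)
        return torch.gather(values, dim=1, index=index)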
2 changes: 2 additions & 0 deletions backends/arm/tosa/dialect/ops/matmul.py
@@ -21,7 +21,9 @@
def MATMUL(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
tosa_spec = get_context_spec()
"""Performs matrix multiplication on two input tensors.
+
Additionally validates TOSA constraints of a MATMUL op.
+
"""
if x1.dtype != x2.dtype:
raise TosaValueError(
5 changes: 4 additions & 1 deletion backends/arm/tosa/dialect/ops/rescale.py
@@ -23,8 +23,11 @@ def RESCALE(
x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int
) -> torch.Tensor:
tosa_spec = get_context_spec()
-"""Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
+"""Casts the input tensor to dtype `dtype` to produce the correct tensor
+meta for a _rescale op.
+
Additionally validates TOSA constraints of a RESCALE op.
+
"""
if not tosa_spec.support_integer():
raise TosaValueError(
4 changes: 2 additions & 2 deletions backends/arm/tosa/dialect/ops_registration.py
@@ -28,8 +28,7 @@
def register_fake_tosa_op(
op_schema: str, tosa_specs: Iterable[TosaSpecification]
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator for registering a TOSA operation.
"""Decorator for registering a TOSA operation.

Parameters:
op_schema : A string that defines the operation schema.
@@ -39,6 +38,7 @@
The decorated function is registered with the given op_schema by calling
register_tosa_dialect_op(op_schema, func) only once per function. The resulting
callable is then inserted into _tosa_registered_ops for each spec.
+
"""

def decorator(func: Callable[P, R]) -> Callable[P, R]:
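A hypothetical registration, only to show how the decorator described above is applied; the schema string, the op body, and the import paths are illustrative assumptions (real registrations live in backends/arm/tosa/dialect/ops/):

    import torch

    from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
    from executorch.backends.arm.tosa.specification import TosaSpecification

    @register_fake_tosa_op(
        "EXAMPLE_OP(Tensor x) -> Tensor",  # hypothetical op_schema
        TosaSpecification.all_versions_and_profiles(),
    )
    def EXAMPLE_OP(x: torch.Tensor) -> torch.Tensor:
        # Fake/meta implementation: only the output tensor metadata matters here.
        return torch.empty_like(x)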
1 change: 1 addition & 0 deletions backends/arm/tosa/mapping.py
@@ -164,6 +164,7 @@ class TosaArg:
special (list | None): Captured list when the argument is a sequence.
number (float | int | None): Captured numeric value when provided.
multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise.
+
"""

def __process_node(self, argument: torch.fx.Node):
6 changes: 3 additions & 3 deletions backends/arm/tosa/partitioner.py
@@ -385,9 +385,9 @@ def ops_to_not_decompose(  # noqa: C901
}

def filter_fn(node: torch.fx.Node) -> bool:
-"""Filter function applied to ops in 'ops_to_not_decompose'.
-Returns True if the op should not be decomposed.
-If this function returns True, the partitioner *must* accept the node, or the lowering fails.
+"""Filter function applied to ops in 'ops_to_not_decompose'. Returns
+True if the op should not be decomposed. If this function returns
+True, the partitioner *must* accept the node, or the lowering fails.

Args:
node (torch.fx.Node): FX node to evaluate.
26 changes: 15 additions & 11 deletions backends/arm/tosa/specification.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
"""Provide TOSA specification parsing and context utilities.

Use these helpers to parse and validate TOSA profile/extension strings and to
@@ -24,12 +23,13 @@ def __init__(self):
self._mapping: Dict[TosaSpecification, List[T]] = {}

def add(self, spec: "TosaSpecification", value: T) -> None:
-"""
-Adds a value to the mapping for the given TOSA specification.
+"""Adds a value to the mapping for the given TOSA specification.
+
The specification is normalized to its canonical form, which means that
-only the version and profiles are considered, without extensions.
-This allows for grouping of values under the same TOSA specification
+only the version and profiles are considered, without extensions. This
+allows for grouping of values under the same TOSA specification
regardless of the extensions they may have.
+
"""

if spec.is_U55_subset or spec.extensions:
@@ -61,10 +61,12 @@ def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]:
return [spec]

def get(self, spec: "TosaSpecification") -> List[T]:
-"""
-Returns a list of values associated with the given TOSA specification.
+"""Returns a list of values associated with the given TOSA
+specification.
+
The specification is normalized to its canonical form, which means that
only the version and profiles are considered, without extensions.
+
"""

base_specs = self._get_base_specs(spec)
@@ -215,8 +217,8 @@ def create_from_string(repr: str) -> "TosaSpecification":
raise ValueError(f"Failed to parse TOSA specification representation: {repr}")

def _canonical_key(self) -> "TosaSpecification":
-"""
-Returns a new TosaSpecification instance with only version and profiles (no extensions).
+"""Returns a new TosaSpecification instance with only version and
+profiles (no extensions).
"""
raise NotImplementedError

@@ -366,9 +368,11 @@ def support_extension(self, extension: str) -> bool:
return False

def _canonical_key(self) -> "Tosa_1_00":
-"""
-Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions).
+"""Returns a new Tosa_1_00 instance with only major.minor version and
+profiles (no extensions).
+
Patch version is set to zero for normalization.
+
"""
from packaging.version import Version

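To illustrate the normalization these docstrings describe, a rough sketch; both spec strings are assumptions used purely as examples, and only the create_from_string entry point is confirmed by this diff:

    from executorch.backends.arm.tosa.specification import TosaSpecification

    spec_with_ext = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
    spec_base = TosaSpecification.create_from_string("TOSA-1.0+INT")
    # Values registered under spec_with_ext are retrievable via spec_base, because
    # the canonical key keeps only the version and profiles and drops extensions.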
1 change: 1 addition & 0 deletions backends/arm/vgf/compile_spec.py
@@ -25,6 +25,7 @@ class VgfCompileSpec(ArmCompileSpec):
target. Strings are parsed via ``TosaSpecification.create_from_string``.
Defaults to ``"TOSA-1.0+FP+INT"``.
compiler_flags (list[str] | None): Optional converter-backend flags.
+
"""

def __init__(
6 changes: 3 additions & 3 deletions backends/arm/vgf/partitioner.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -14,12 +14,12 @@

@final
class VgfPartitioner(TOSAPartitioner):
-"""
-Partitions subgraphs supported by the Arm Vgf backend.
+"""Partitions subgraphs supported by the Arm Vgf backend.

Args:
compile_spec: The Vgf compilation specification.
additional_checks: Optional sequence of additional operator support checks.
+
"""

def __init__(
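And a matching hedged sketch for the Vgf pair; the import paths are inferred from the file layout, while "TOSA-1.0+FP+INT" is the documented default spec string:

    from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
    from executorch.backends.arm.vgf.partitioner import VgfPartitioner

    compile_spec = VgfCompileSpec("TOSA-1.0+FP+INT")
    partitioner = VgfPartitioner(compile_spec)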