-
Notifications
You must be signed in to change notification settings - Fork 934
Add structure for use of TOSA CUSTOM ops #18837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
b606983
Arm backend: TosaPartitoner allow custom ops
robell 9d0be16
Arm backend: Add TOSA CUSTOM dialect op, visitor
robell 224b23c
lintfix
robell 654b52c
better doc naming
robell a0647fc
fix for public manifest
robell 52d0903
automation fixes
robell a3b7ee1
fixes
robell File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Any, List | ||
|
|
||
| import torch | ||
| import tosa_serializer as ts | ||
|
|
||
| from executorch.backends.arm.operators.node_visitor import ( | ||
| NodeVisitor, | ||
| register_node_visitor, | ||
| ) | ||
| from executorch.backends.arm.tosa.mapping import TosaArg | ||
|
|
||
|
|
||
| @register_node_visitor | ||
| class CustomVisitor(NodeVisitor): | ||
| """Lower the TOSA CUSTOM op from the TOSA backend dialect.""" | ||
|
|
||
| target = "tosa.CUSTOM.default" | ||
|
|
||
| def define_node( | ||
| self, | ||
| node: torch.fx.Node, | ||
| tosa_graph: Any, | ||
| inputs: List[TosaArg], | ||
| output: TosaArg, | ||
| ) -> None: | ||
| allowed_kwargs = {"operator_name", "domain_name", "implementation_attrs"} | ||
| unexpected = set(node.kwargs.keys()) - allowed_kwargs | ||
| if unexpected: | ||
| raise ValueError( | ||
| f"tosa.CUSTOM received unexpected kwargs: {sorted(unexpected)}" | ||
| ) | ||
|
|
||
| operator_name = node.kwargs.get("operator_name") | ||
| domain_name = node.kwargs.get("domain_name") | ||
| implementation_attrs = node.kwargs.get("implementation_attrs") | ||
|
|
||
| if operator_name is None or domain_name is None: | ||
| raise ValueError( | ||
| "tosa.CUSTOM requires operator_name and domain_name in kwargs" | ||
| ) | ||
|
|
||
| if implementation_attrs is None: | ||
| impl_list = [] | ||
| elif isinstance(implementation_attrs, list): | ||
| # NOTE: PyTorch schemas do not support a bytes type; we pass | ||
| # implementation_attrs as int[] representing raw bytes. | ||
| impl_list = [int(x) for x in implementation_attrs] | ||
| else: | ||
| raise TypeError( | ||
| "implementation_attrs must be None or list[int]; " | ||
| f"got {type(implementation_attrs)}" | ||
| ) | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| attr = ts.TosaSerializerAttribute() | ||
| attr.CustomAttribute( | ||
| operator_name=operator_name, | ||
| domain_name=domain_name, | ||
| implementation_attrs=impl_list, | ||
| ) | ||
|
|
||
| expanded = [TosaArg(item, self.tosa_spec) for item in inputs[0].special] | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| input_names = [arg.name for arg in expanded] | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| output_names = ( | ||
| output.multiple_output_names | ||
| if getattr(output, "multiple_output_names", None) | ||
| else [output.name] | ||
| ) | ||
| if len(output_names) != 1: | ||
| # TODO: Support multi-output CUSTOM ops with per-output meta/shape. | ||
| raise ValueError( | ||
| f"tosa.CUSTOM currently requires a single output, got {len(output_names)}" | ||
| ) | ||
| self._serialize_operator( | ||
| node, | ||
| tosa_graph, | ||
| ts.Op.CUSTOM, | ||
| input_names, | ||
| output_names, | ||
| attr, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| """Fake-op support for the generic TOSA ``CUSTOM`` dialect op. | ||
|
|
||
| The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a | ||
| stable operator identity (for example ``myns.my_op``) plus an | ||
| opaque payload in ``implementation_attrs``. That is enough for serialization, | ||
| but not enough for FakeTensor propagation unless we also teach the compiler how | ||
| to model the output tensors of the specific wrapped op. | ||
|
|
||
| This module provides a lightweight registration mechanism for those compiler | ||
| side fake implementations: | ||
|
|
||
| 1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``. | ||
| 2. The wrapped custom op registers a thin adapter with | ||
| ``@register_fake_tosa("namespace::op")``. | ||
| 3. The generic ``CUSTOM`` fake implementation looks up that adapter by the | ||
| ``operator_name`` argument and invokes it with the full custom-op calling | ||
| convention ``(inputs, operator_name, domain_name, implementation_attrs)``. | ||
|
|
||
| The adapter should stay thin: it should only translate from the generic TOSA | ||
| CUSTOM signature back to the wrapped op's fake semantics. The real semantic | ||
| logic should continue to live in the original fake implementation where | ||
| possible. | ||
|
|
||
| """ | ||
|
|
||
| import inspect | ||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
| from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op | ||
|
|
||
| from executorch.backends.arm.tosa.specification import ( | ||
| get_context_spec, | ||
| TosaSpecification, | ||
| ) | ||
|
|
||
| _TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {} | ||
|
|
||
|
|
||
| def _normalize_tosa_custom_operator_name(operator_name: str) -> str: | ||
| """Normalize operator names so ``ns::op`` and ``ns.op`` map identically.""" | ||
| return operator_name.replace("::", ".") | ||
|
|
||
|
|
||
| def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable: | ||
| """Validate the signature expected by ``register_fake_tosa``. | ||
|
|
||
| Registered fake implementations must accept the generic TOSA CUSTOM fake | ||
| calling convention: | ||
|
|
||
| ``(inputs, operator_name, domain_name, implementation_attrs)`` | ||
|
|
||
| and return ``list[Tensor]``. | ||
|
|
||
| """ | ||
| if not callable(fake_impl): | ||
| raise TypeError( | ||
| "Expected tosa.CUSTOM fake impl to be callable, " f"got {type(fake_impl)}" | ||
| ) | ||
|
|
||
| params = tuple(inspect.signature(fake_impl).parameters.values()) | ||
| positional_kinds = { | ||
| inspect.Parameter.POSITIONAL_ONLY, | ||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | ||
| } | ||
| if len(params) != 4 or any(param.kind not in positional_kinds for param in params): | ||
| raise TypeError( | ||
| "tosa.CUSTOM fake impl must have signature " | ||
| "(inputs, operator_name, domain_name, implementation_attrs)" | ||
| ) | ||
| return fake_impl | ||
|
|
||
|
|
||
| def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]: | ||
| """Register a fake implementation for a specific wrapped TOSA custom op. | ||
|
|
||
| Args: | ||
| operator_name: Stable custom operator identifier. Both ``ns::op`` and | ||
| ``ns.op`` spellings are accepted. | ||
|
|
||
| Returns: | ||
| A decorator that registers a callable with signature | ||
| ``(inputs, operator_name, domain_name, implementation_attrs)`` and | ||
| returning ``list[Tensor]``. | ||
|
|
||
| Example: | ||
| ``@register_fake_tosa("my_namespace::my_op")`` | ||
|
|
||
| """ | ||
| normalized_name = _normalize_tosa_custom_operator_name(operator_name) | ||
|
|
||
| def decorator(fake_impl: Callable) -> Callable: | ||
| validated = validate_tosa_custom_fake_impl(fake_impl) | ||
| _TOSA_CUSTOM_FAKE_IMPLS[normalized_name] = validated | ||
| return fake_impl | ||
|
|
||
| return decorator | ||
|
|
||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def has_fake_tosa_impl(operator_name: str) -> bool: | ||
| """Return whether a wrapped custom op has a registered fake impl.""" | ||
| normalized_name = _normalize_tosa_custom_operator_name(operator_name) | ||
| return normalized_name in _TOSA_CUSTOM_FAKE_IMPLS | ||
|
|
||
|
|
||
| def run_registered_fake_tosa_impl( | ||
| inputs: list[torch.Tensor], | ||
| operator_name: str, | ||
| domain_name: str, | ||
| implementation_attrs: list[int], | ||
| ) -> list[torch.Tensor]: | ||
| """Invoke the registered fake implementation for a wrapped custom op.""" | ||
| normalized_name = _normalize_tosa_custom_operator_name(operator_name) | ||
| fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(normalized_name) | ||
| if fake_impl is None: | ||
| raise RuntimeError( | ||
| f"tosa.CUSTOM requires a registered fake impl for {normalized_name}" | ||
| ) | ||
| outputs = fake_impl(inputs, operator_name, domain_name, implementation_attrs) | ||
| if not isinstance(outputs, list): | ||
| raise TypeError( | ||
| "tosa.CUSTOM fake impl must return list[Tensor], " f"got {type(outputs)}" | ||
| ) | ||
| if not outputs: | ||
| raise RuntimeError("tosa.CUSTOM fake impl must return at least one output") | ||
| if not all(isinstance(output, torch.Tensor) for output in outputs): | ||
| raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]") | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return outputs | ||
|
|
||
|
|
||
| @register_fake_tosa_op( | ||
| "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]", | ||
| TosaSpecification.all_versions_and_profiles(), | ||
| ) | ||
| def CUSTOM( | ||
| inputs: list[torch.Tensor], | ||
| operator_name: str, | ||
| domain_name: str, | ||
| implementation_attrs: list[int], | ||
| ) -> list[torch.Tensor]: | ||
| """Fake implementation for TOSA CUSTOM op. | ||
|
|
||
| The CUSTOM op is backend-defined. The fake implementation dispatches to a | ||
| registered compiler-side fake implementation for the specific custom op. | ||
|
|
||
| """ | ||
| _ = get_context_spec() # ensure a spec context exists | ||
| if not inputs: | ||
| raise RuntimeError("tosa.CUSTOM requires at least one input tensor") | ||
| return run_registered_fake_tosa_impl( | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| inputs, | ||
| operator_name, | ||
| domain_name, | ||
| implementation_attrs, | ||
| ) | ||
robell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.