Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 131 additions & 30 deletions backends/arm/tosa/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
"""Provide TOSA specification parsing and context utilities.

#
# Main implementation of AoT flow to partition and preprocess for Arm target
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#
Use these helpers to parse and validate TOSA profile/extension strings and to
manage a lowering-time context for the active specification.

"""

import contextvars
import re
Expand All @@ -19,36 +19,39 @@


class TosaSpecification:
"""
This class implements a representation of TOSA specification
(https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
(with extension) and a level (8k).
For 1.00 releases the profile is INT or FP, and the extensions are for
INT: int16, int4, var, cf
FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
"""Represent a TOSA specification.

The TOSA specification is encoded in the string represenatation
TOSA-major.minor.patch+profile[+level][+extensions]
A specification includes a semantic version, one or more profiles, and
optional extensions and levels (for example ``8k``).
The encoded form follows ``TOSA-<major>.<minor>.<patch>+<PROFILE>[+<LEVEL>][+<EXT>...]``.
Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions
use lowercase.

Attributes:
version (Version): Parsed TOSA semantic version.
is_U55_subset (bool): True if the ``u55`` subset is requested.

Profiles are uppercase letters and extensions and level is lowercase.
"""

version: Version
is_U55_subset: bool

def support_integer(self) -> bool:
"""
Returns true if any integer operations are supported for the specification.
"""
"""Return True if integer operations are supported."""
raise NotImplementedError

def support_float(self) -> bool:
"""
Returns true if any float operations are supported for the specification.
"""
"""Return True if floating-point operations are supported."""
raise NotImplementedError

def __init__(self, version: Version, extras: List[str]):
"""Initialize the base specification.

Args:
version (Version): Parsed TOSA semantic version.
extras (List[str]): Remaining tokens such as profiles, levels, and extensions.

"""
self.version = version

self.is_U55_subset = "u55" in extras
Expand All @@ -57,11 +60,20 @@ def __init__(self, version: Version, extras: List[str]):

@staticmethod
def create_from_string(repr: str) -> "TosaSpecification":
"""
Creates a TOSA specification class from a string representation:
TOSA-1.00.0+INT+FP+int4+cf
"""
"""Create a specification from a standard string format.

Example: ``TOSA-1.00.0+INT+FP+int4+cf``.

Args:
repr (str): Standard representation string.

Returns:
TosaSpecification: Parsed specification instance.

Raises:
ValueError: If the representation is malformed or version is unsupported.

"""
pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
match = re.match(pattern, repr)
if match:
Expand All @@ -80,6 +92,18 @@ def create_from_string(repr: str) -> "TosaSpecification":


class Tosa_1_00(TosaSpecification):
"""Provide TOSA 1.00 profile and extension semantics.

This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level,
and allowed extensions based on the selected profiles.

Attributes:
profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``.
level_8k (bool): True if the ``8k`` level is enabled.
extensions (List[str]): Enabled extensions valid for the chosen profiles.

"""

profiles: List[str]
level_8k: bool
extensions: List[str]
Expand All @@ -91,6 +115,16 @@ class Tosa_1_00(TosaSpecification):
}

def __init__(self, version: Version, extras: List[str]):
"""Initialize the 1.00 specification and validate extras.

Args:
version (Version): Semantic version (major=1, minor=0).
extras (List[str]): Tokens including profiles, level, and extensions.

Raises:
ValueError: If no/too many profiles are provided or extensions are invalid.

"""
super().__init__(version, extras)

# Check that we have at least one profile in the extensions list
Expand Down Expand Up @@ -129,12 +163,20 @@ def __init__(self, version: Version, extras: List[str]):
self.extensions = extras

def _get_profiles_string(self) -> str:
"""Return the ``+``-joined profile segment (e.g., ``+INT+FP``)."""
return "".join(["+" + p for p in self.profiles])

def _get_extensions_string(self) -> str:
"""Return the ``+``-joined extensions segment (e.g., ``+int4+cf``)."""
return "".join(["+" + e for e in self.extensions])

def __repr__(self):
"""Return the standard specification string format.

Returns:
str: Standard form like ``TOSA-1.00.0+INT+8k+int4``.

"""
extensions = self._get_extensions_string()
if self.level_8k:
extensions += "+8k"
Expand All @@ -143,22 +185,48 @@ def __repr__(self):
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"

def __hash__(self) -> int:
"""Return a stable hash for use in sets and dict keys.

Returns:
int: Hash value derived from version and profiles.

"""
return hash(str(self.version) + self._get_profiles_string())

def __eq__(self, other: object) -> bool:
"""Return True if another instance represents the same spec.

Args:
other (object): Object to compare.

Returns:
bool: True if versions and profiles match.

"""
if isinstance(other, Tosa_1_00):
return (self.version == other.version) and (
self._get_profiles_string() == other._get_profiles_string()
)
return False

def support_integer(self):
"""Return True if the ``INT`` profile is present."""
return "INT" in self.profiles

def support_float(self):
"""Return True if the ``FP`` profile is present."""
return "FP" in self.profiles

def support_extension(self, extension: str) -> bool:
"""Return True if an extension is supported and enabled.

Args:
extension (str): Extension name (for example ``int4``, ``bf16``).

Returns:
bool: True if the extension is valid for the active profiles and selected.

"""
for p in self.profiles:
if extension in self.valid_extensions[p] and extension in self.extensions:
return True
Expand All @@ -167,30 +235,63 @@ def support_extension(self, extension: str) -> bool:


class TosaLoweringContext:
"""
A context manager to handle the TOSA specific aspects of the lowering process.
For now it only handles the TOSA specification context, but it can be extended
to include other policies or configurations.
"""Manage the TOSA specification context for lowering.

For now, only the active ``TosaSpecification`` is tracked, but this can be
extended to carry additional lowering policies or configuration.

Attributes:
tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec.
spec (TosaSpecification): Specification passed to the context manager.

"""

# Define a context variable for the spec
tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec")

def __init__(self, spec: TosaSpecification):
"""Initialize the lowering context with a specification.

Args:
spec (TosaSpecification): Active specification to put into context.

"""
self.spec = spec

def __enter__(self):
"""Set the context variable and return self.

Returns:
TosaLoweringContext: This context manager instance.

"""
# Set the spec in the context variable and store the token for later reset
self.token = TosaLoweringContext.tosa_spec_var.set(self.spec)
return self

def __exit__(self, exc_type, exc_value, traceback):
"""Reset the context variable to its previous state.

Args:
exc_type (type | None): Exception type, if any.
exc_value (BaseException | None): Exception instance, if any.
traceback (TracebackType | None): Traceback, if any.

"""
# Reset the context variable to its previous state
TosaLoweringContext.tosa_spec_var.reset(self.token)


# A helper function to retrieve the current spec anywhere in your code
def get_context_spec() -> TosaSpecification:
"""Get the current ``TosaSpecification`` from the lowering context.

Returns:
TosaSpecification: Active specification retrieved from the context var.

Raises:
RuntimeError: If called outside a ``TosaLoweringContext``.

"""
try:
return TosaLoweringContext.tosa_spec_var.get()
except LookupError:
Expand Down
Loading