diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index b372cd5a636..3edf27760b5 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -4,12 +4,12 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +"""Provide TOSA specification parsing and context utilities. -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# +Use these helpers to parse and validate TOSA profile/extension strings and to +manage a lowering-time context for the active specification. + +""" import contextvars import re @@ -19,36 +19,39 @@ class TosaSpecification: - """ - This class implements a representation of TOSA specification - (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile - (with extension) and a level (8k). - For 1.00 releases the profile is INT or FP, and the extensions are for - INT: int16, int4, var, cf - FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf + """Represent a TOSA specification. - The TOSA specification is encoded in the string represenatation - TOSA-major.minor.patch+profile[+level][+extensions] + A specification includes a semantic version, one or more profiles, and + optional extensions and levels (for example ``8k``). + The encoded form follows ``TOSA-..+[+][+...]``. + Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions + use lowercase. + + Attributes: + version (Version): Parsed TOSA semantic version. + is_U55_subset (bool): True if the ``u55`` subset is requested. - Profiles are uppercase letters and extensions and level is lowercase. """ version: Version is_U55_subset: bool def support_integer(self) -> bool: - """ - Returns true if any integer operations are supported for the specification. - """ + """Return True if integer operations are supported.""" raise NotImplementedError def support_float(self) -> bool: - """ - Returns true if any float operations are supported for the specification. - """ + """Return True if floating-point operations are supported.""" raise NotImplementedError def __init__(self, version: Version, extras: List[str]): + """Initialize the base specification. + + Args: + version (Version): Parsed TOSA semantic version. + extras (List[str]): Remaining tokens such as profiles, levels, and extensions. + + """ self.version = version self.is_U55_subset = "u55" in extras @@ -57,11 +60,20 @@ def __init__(self, version: Version, extras: List[str]): @staticmethod def create_from_string(repr: str) -> "TosaSpecification": - """ - Creates a TOSA specification class from a string representation: - TOSA-1.00.0+INT+FP+int4+cf - """ + """Create a specification from a standard string format. + + Example: ``TOSA-1.00.0+INT+FP+int4+cf``. + Args: + repr (str): Standard representation string. + + Returns: + TosaSpecification: Parsed specification instance. + + Raises: + ValueError: If the representation is malformed or version is unsupported. + + """ pattern = r"^(TOSA)-([\d.]+)\+(.+)$" match = re.match(pattern, repr) if match: @@ -80,6 +92,18 @@ def create_from_string(repr: str) -> "TosaSpecification": class Tosa_1_00(TosaSpecification): + """Provide TOSA 1.00 profile and extension semantics. + + This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level, + and allowed extensions based on the selected profiles. + + Attributes: + profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``. + level_8k (bool): True if the ``8k`` level is enabled. + extensions (List[str]): Enabled extensions valid for the chosen profiles. + + """ + profiles: List[str] level_8k: bool extensions: List[str] @@ -91,6 +115,16 @@ class Tosa_1_00(TosaSpecification): } def __init__(self, version: Version, extras: List[str]): + """Initialize the 1.00 specification and validate extras. + + Args: + version (Version): Semantic version (major=1, minor=0). + extras (List[str]): Tokens including profiles, level, and extensions. + + Raises: + ValueError: If no/too many profiles are provided or extensions are invalid. + + """ super().__init__(version, extras) # Check that we have at least one profile in the extensions list @@ -129,12 +163,20 @@ def __init__(self, version: Version, extras: List[str]): self.extensions = extras def _get_profiles_string(self) -> str: + """Return the ``+``-joined profile segment (e.g., ``+INT+FP``).""" return "".join(["+" + p for p in self.profiles]) def _get_extensions_string(self) -> str: + """Return the ``+``-joined extensions segment (e.g., ``+int4+cf``).""" return "".join(["+" + e for e in self.extensions]) def __repr__(self): + """Return the standard specification string format. + + Returns: + str: Standard form like ``TOSA-1.00.0+INT+8k+int4``. + + """ extensions = self._get_extensions_string() if self.level_8k: extensions += "+8k" @@ -143,9 +185,24 @@ def __repr__(self): return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}" def __hash__(self) -> int: + """Return a stable hash for use in sets and dict keys. + + Returns: + int: Hash value derived from version and profiles. + + """ return hash(str(self.version) + self._get_profiles_string()) def __eq__(self, other: object) -> bool: + """Return True if another instance represents the same spec. + + Args: + other (object): Object to compare. + + Returns: + bool: True if versions and profiles match. + + """ if isinstance(other, Tosa_1_00): return (self.version == other.version) and ( self._get_profiles_string() == other._get_profiles_string() @@ -153,12 +210,23 @@ def __eq__(self, other: object) -> bool: return False def support_integer(self): + """Return True if the ``INT`` profile is present.""" return "INT" in self.profiles def support_float(self): + """Return True if the ``FP`` profile is present.""" return "FP" in self.profiles def support_extension(self, extension: str) -> bool: + """Return True if an extension is supported and enabled. + + Args: + extension (str): Extension name (for example ``int4``, ``bf16``). + + Returns: + bool: True if the extension is valid for the active profiles and selected. + + """ for p in self.profiles: if extension in self.valid_extensions[p] and extension in self.extensions: return True @@ -167,30 +235,63 @@ def support_extension(self, extension: str) -> bool: class TosaLoweringContext: - """ - A context manager to handle the TOSA specific aspects of the lowering process. - For now it only handles the TOSA specification context, but it can be extended - to include other policies or configurations. + """Manage the TOSA specification context for lowering. + + For now, only the active ``TosaSpecification`` is tracked, but this can be + extended to carry additional lowering policies or configuration. + + Attributes: + tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec. + spec (TosaSpecification): Specification passed to the context manager. + """ # Define a context variable for the spec tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec") def __init__(self, spec: TosaSpecification): + """Initialize the lowering context with a specification. + + Args: + spec (TosaSpecification): Active specification to put into context. + + """ self.spec = spec def __enter__(self): + """Set the context variable and return self. + + Returns: + TosaLoweringContext: This context manager instance. + + """ # Set the spec in the context variable and store the token for later reset self.token = TosaLoweringContext.tosa_spec_var.set(self.spec) return self def __exit__(self, exc_type, exc_value, traceback): + """Reset the context variable to its previous state. + + Args: + exc_type (type | None): Exception type, if any. + exc_value (BaseException | None): Exception instance, if any. + traceback (TracebackType | None): Traceback, if any. + + """ # Reset the context variable to its previous state TosaLoweringContext.tosa_spec_var.reset(self.token) -# A helper function to retrieve the current spec anywhere in your code def get_context_spec() -> TosaSpecification: + """Get the current ``TosaSpecification`` from the lowering context. + + Returns: + TosaSpecification: Active specification retrieved from the context var. + + Raises: + RuntimeError: If called outside a ``TosaLoweringContext``. + + """ try: return TosaLoweringContext.tosa_spec_var.get() except LookupError: