Skip to content

Commit 3bc9282

Browse files
Arm backend: Add docstrings for tosa/specification.py (#14536)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 042e087 commit 3bc9282

File tree

1 file changed

+131
-30
lines changed

1 file changed

+131
-30
lines changed

backends/arm/tosa/specification.py

Lines changed: 131 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
"""Provide TOSA specification parsing and context utilities.
78
8-
#
9-
# Main implementation of AoT flow to partition and preprocess for Arm target
10-
# backends. Converts via TOSA as an intermediate form supported by AoT and
11-
# JIT compiler flows.
12-
#
9+
Use these helpers to parse and validate TOSA profile/extension strings and to
10+
manage a lowering-time context for the active specification.
11+
12+
"""
1313

1414
import contextvars
1515
import re
@@ -19,36 +19,39 @@
1919

2020

2121
class TosaSpecification:
22-
"""
23-
This class implements a representation of TOSA specification
24-
(https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
25-
(with extension) and a level (8k).
26-
For 1.00 releases the profile is INT or FP, and the extensions are for
27-
INT: int16, int4, var, cf
28-
FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
22+
"""Represent a TOSA specification.
2923
30-
The TOSA specification is encoded in the string represenatation
31-
TOSA-major.minor.patch+profile[+level][+extensions]
24+
A specification includes a semantic version, one or more profiles, and
25+
optional extensions and levels (for example ``8k``).
26+
The encoded form follows ``TOSA-<major>.<minor>.<patch>+<PROFILE>[+<LEVEL>][+<EXT>...]``.
27+
Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions
28+
use lowercase.
29+
30+
Attributes:
31+
version (Version): Parsed TOSA semantic version.
32+
is_U55_subset (bool): True if the ``u55`` subset is requested.
3233
33-
Profiles are uppercase letters and extensions and level is lowercase.
3434
"""
3535

3636
version: Version
3737
is_U55_subset: bool
3838

3939
def support_integer(self) -> bool:
40-
"""
41-
Returns true if any integer operations are supported for the specification.
42-
"""
40+
"""Return True if integer operations are supported."""
4341
raise NotImplementedError
4442

4543
def support_float(self) -> bool:
46-
"""
47-
Returns true if any float operations are supported for the specification.
48-
"""
44+
"""Return True if floating-point operations are supported."""
4945
raise NotImplementedError
5046

5147
def __init__(self, version: Version, extras: List[str]):
48+
"""Initialize the base specification.
49+
50+
Args:
51+
version (Version): Parsed TOSA semantic version.
52+
extras (List[str]): Remaining tokens such as profiles, levels, and extensions.
53+
54+
"""
5255
self.version = version
5356

5457
self.is_U55_subset = "u55" in extras
@@ -57,11 +60,20 @@ def __init__(self, version: Version, extras: List[str]):
5760

5861
@staticmethod
5962
def create_from_string(repr: str) -> "TosaSpecification":
60-
"""
61-
Creates a TOSA specification class from a string representation:
62-
TOSA-1.00.0+INT+FP+int4+cf
63-
"""
63+
"""Create a specification from a standard string format.
64+
65+
Example: ``TOSA-1.00.0+INT+FP+int4+cf``.
6466
67+
Args:
68+
repr (str): Standard representation string.
69+
70+
Returns:
71+
TosaSpecification: Parsed specification instance.
72+
73+
Raises:
74+
ValueError: If the representation is malformed or version is unsupported.
75+
76+
"""
6577
pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
6678
match = re.match(pattern, repr)
6779
if match:
@@ -80,6 +92,18 @@ def create_from_string(repr: str) -> "TosaSpecification":
8092

8193

8294
class Tosa_1_00(TosaSpecification):
95+
"""Provide TOSA 1.00 profile and extension semantics.
96+
97+
This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level,
98+
and allowed extensions based on the selected profiles.
99+
100+
Attributes:
101+
profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``.
102+
level_8k (bool): True if the ``8k`` level is enabled.
103+
extensions (List[str]): Enabled extensions valid for the chosen profiles.
104+
105+
"""
106+
83107
profiles: List[str]
84108
level_8k: bool
85109
extensions: List[str]
@@ -91,6 +115,16 @@ class Tosa_1_00(TosaSpecification):
91115
}
92116

93117
def __init__(self, version: Version, extras: List[str]):
118+
"""Initialize the 1.00 specification and validate extras.
119+
120+
Args:
121+
version (Version): Semantic version (major=1, minor=0).
122+
extras (List[str]): Tokens including profiles, level, and extensions.
123+
124+
Raises:
125+
ValueError: If no/too many profiles are provided or extensions are invalid.
126+
127+
"""
94128
super().__init__(version, extras)
95129

96130
# Check that we have at least one profile in the extensions list
@@ -129,12 +163,20 @@ def __init__(self, version: Version, extras: List[str]):
129163
self.extensions = extras
130164

131165
def _get_profiles_string(self) -> str:
166+
"""Return the ``+``-joined profile segment (e.g., ``+INT+FP``)."""
132167
return "".join(["+" + p for p in self.profiles])
133168

134169
def _get_extensions_string(self) -> str:
170+
"""Return the ``+``-joined extensions segment (e.g., ``+int4+cf``)."""
135171
return "".join(["+" + e for e in self.extensions])
136172

137173
def __repr__(self):
174+
"""Return the standard specification string format.
175+
176+
Returns:
177+
str: Standard form like ``TOSA-1.00.0+INT+8k+int4``.
178+
179+
"""
138180
extensions = self._get_extensions_string()
139181
if self.level_8k:
140182
extensions += "+8k"
@@ -143,22 +185,48 @@ def __repr__(self):
143185
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
144186

145187
def __hash__(self) -> int:
188+
"""Return a stable hash for use in sets and dict keys.
189+
190+
Returns:
191+
int: Hash value derived from version and profiles.
192+
193+
"""
146194
return hash(str(self.version) + self._get_profiles_string())
147195

148196
def __eq__(self, other: object) -> bool:
197+
"""Return True if another instance represents the same spec.
198+
199+
Args:
200+
other (object): Object to compare.
201+
202+
Returns:
203+
bool: True if versions and profiles match.
204+
205+
"""
149206
if isinstance(other, Tosa_1_00):
150207
return (self.version == other.version) and (
151208
self._get_profiles_string() == other._get_profiles_string()
152209
)
153210
return False
154211

155212
def support_integer(self):
213+
"""Return True if the ``INT`` profile is present."""
156214
return "INT" in self.profiles
157215

158216
def support_float(self):
217+
"""Return True if the ``FP`` profile is present."""
159218
return "FP" in self.profiles
160219

161220
def support_extension(self, extension: str) -> bool:
221+
"""Return True if an extension is supported and enabled.
222+
223+
Args:
224+
extension (str): Extension name (for example ``int4``, ``bf16``).
225+
226+
Returns:
227+
bool: True if the extension is valid for the active profiles and selected.
228+
229+
"""
162230
for p in self.profiles:
163231
if extension in self.valid_extensions[p] and extension in self.extensions:
164232
return True
@@ -167,30 +235,63 @@ def support_extension(self, extension: str) -> bool:
167235

168236

169237
class TosaLoweringContext:
170-
"""
171-
A context manager to handle the TOSA specific aspects of the lowering process.
172-
For now it only handles the TOSA specification context, but it can be extended
173-
to include other policies or configurations.
238+
"""Manage the TOSA specification context for lowering.
239+
240+
For now, only the active ``TosaSpecification`` is tracked, but this can be
241+
extended to carry additional lowering policies or configuration.
242+
243+
Attributes:
244+
tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec.
245+
spec (TosaSpecification): Specification passed to the context manager.
246+
174247
"""
175248

176249
# Define a context variable for the spec
177250
tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec")
178251

179252
def __init__(self, spec: TosaSpecification):
253+
"""Initialize the lowering context with a specification.
254+
255+
Args:
256+
spec (TosaSpecification): Active specification to put into context.
257+
258+
"""
180259
self.spec = spec
181260

182261
def __enter__(self):
262+
"""Set the context variable and return self.
263+
264+
Returns:
265+
TosaLoweringContext: This context manager instance.
266+
267+
"""
183268
# Set the spec in the context variable and store the token for later reset
184269
self.token = TosaLoweringContext.tosa_spec_var.set(self.spec)
185270
return self
186271

187272
def __exit__(self, exc_type, exc_value, traceback):
273+
"""Reset the context variable to its previous state.
274+
275+
Args:
276+
exc_type (type | None): Exception type, if any.
277+
exc_value (BaseException | None): Exception instance, if any.
278+
traceback (TracebackType | None): Traceback, if any.
279+
280+
"""
188281
# Reset the context variable to its previous state
189282
TosaLoweringContext.tosa_spec_var.reset(self.token)
190283

191284

192-
# A helper function to retrieve the current spec anywhere in your code
193285
def get_context_spec() -> TosaSpecification:
286+
"""Get the current ``TosaSpecification`` from the lowering context.
287+
288+
Returns:
289+
TosaSpecification: Active specification retrieved from the context var.
290+
291+
Raises:
292+
RuntimeError: If called outside a ``TosaLoweringContext``.
293+
294+
"""
194295
try:
195296
return TosaLoweringContext.tosa_spec_var.get()
196297
except LookupError:

0 commit comments

Comments
 (0)