4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
+ """Provide TOSA specification parsing and context utilities.
7
8
8
- #
9
- # Main implementation of AoT flow to partition and preprocess for Arm target
10
- # backends. Converts via TOSA as an intermediate form supported by AoT and
11
- # JIT compiler flows.
12
- #
9
+ Use these helpers to parse and validate TOSA profile/extension strings and to
10
+ manage a lowering-time context for the active specification.
11
+
12
+ """
13
13
14
14
import contextvars
15
15
import re
19
19
20
20
21
21
class TosaSpecification :
22
- """
23
- This class implements a representation of TOSA specification
24
- (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
25
- (with extension) and a level (8k).
26
- For 1.00 releases the profile is INT or FP, and the extensions are for
27
- INT: int16, int4, var, cf
28
- FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
22
+ """Represent a TOSA specification.
29
23
30
- The TOSA specification is encoded in the string represenatation
31
- TOSA-major.minor.patch+profile[+level][+extensions]
24
+ A specification includes a semantic version, one or more profiles, and
25
+ optional extensions and levels (for example ``8k``).
26
+ The encoded form follows ``TOSA-<major>.<minor>.<patch>+<PROFILE>[+<LEVEL>][+<EXT>...]``.
27
+ Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions
28
+ use lowercase.
29
+
30
+ Attributes:
31
+ version (Version): Parsed TOSA semantic version.
32
+ is_U55_subset (bool): True if the ``u55`` subset is requested.
32
33
33
- Profiles are uppercase letters and extensions and level is lowercase.
34
34
"""
35
35
36
36
version : Version
37
37
is_U55_subset : bool
38
38
39
39
def support_integer (self ) -> bool :
40
- """
41
- Returns true if any integer operations are supported for the specification.
42
- """
40
+ """Return True if integer operations are supported."""
43
41
raise NotImplementedError
44
42
45
43
def support_float (self ) -> bool :
46
- """
47
- Returns true if any float operations are supported for the specification.
48
- """
44
+ """Return True if floating-point operations are supported."""
49
45
raise NotImplementedError
50
46
51
47
def __init__ (self , version : Version , extras : List [str ]):
48
+ """Initialize the base specification.
49
+
50
+ Args:
51
+ version (Version): Parsed TOSA semantic version.
52
+ extras (List[str]): Remaining tokens such as profiles, levels, and extensions.
53
+
54
+ """
52
55
self .version = version
53
56
54
57
self .is_U55_subset = "u55" in extras
@@ -57,11 +60,20 @@ def __init__(self, version: Version, extras: List[str]):
57
60
58
61
@staticmethod
59
62
def create_from_string (repr : str ) -> "TosaSpecification" :
60
- """
61
- Creates a TOSA specification class from a string representation:
62
- TOSA-1.00.0+INT+FP+int4+cf
63
- """
63
+ """Create a specification from a standard string format.
64
+
65
+ Example: ``TOSA-1.00.0+INT+FP+int4+cf``.
64
66
67
+ Args:
68
+ repr (str): Standard representation string.
69
+
70
+ Returns:
71
+ TosaSpecification: Parsed specification instance.
72
+
73
+ Raises:
74
+ ValueError: If the representation is malformed or version is unsupported.
75
+
76
+ """
65
77
pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
66
78
match = re .match (pattern , repr )
67
79
if match :
@@ -80,6 +92,18 @@ def create_from_string(repr: str) -> "TosaSpecification":
80
92
81
93
82
94
class Tosa_1_00 (TosaSpecification ):
95
+ """Provide TOSA 1.00 profile and extension semantics.
96
+
97
+ This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level,
98
+ and allowed extensions based on the selected profiles.
99
+
100
+ Attributes:
101
+ profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``.
102
+ level_8k (bool): True if the ``8k`` level is enabled.
103
+ extensions (List[str]): Enabled extensions valid for the chosen profiles.
104
+
105
+ """
106
+
83
107
profiles : List [str ]
84
108
level_8k : bool
85
109
extensions : List [str ]
@@ -91,6 +115,16 @@ class Tosa_1_00(TosaSpecification):
91
115
}
92
116
93
117
def __init__ (self , version : Version , extras : List [str ]):
118
+ """Initialize the 1.00 specification and validate extras.
119
+
120
+ Args:
121
+ version (Version): Semantic version (major=1, minor=0).
122
+ extras (List[str]): Tokens including profiles, level, and extensions.
123
+
124
+ Raises:
125
+ ValueError: If no/too many profiles are provided or extensions are invalid.
126
+
127
+ """
94
128
super ().__init__ (version , extras )
95
129
96
130
# Check that we have at least one profile in the extensions list
@@ -129,12 +163,20 @@ def __init__(self, version: Version, extras: List[str]):
129
163
self .extensions = extras
130
164
131
165
def _get_profiles_string (self ) -> str :
166
+ """Return the ``+``-joined profile segment (e.g., ``+INT+FP``)."""
132
167
return "" .join (["+" + p for p in self .profiles ])
133
168
134
169
def _get_extensions_string (self ) -> str :
170
+ """Return the ``+``-joined extensions segment (e.g., ``+int4+cf``)."""
135
171
return "" .join (["+" + e for e in self .extensions ])
136
172
137
173
def __repr__ (self ):
174
+ """Return the standard specification string format.
175
+
176
+ Returns:
177
+ str: Standard form like ``TOSA-1.00.0+INT+8k+int4``.
178
+
179
+ """
138
180
extensions = self ._get_extensions_string ()
139
181
if self .level_8k :
140
182
extensions += "+8k"
@@ -143,22 +185,48 @@ def __repr__(self):
143
185
return f"TOSA-{ self .version } { self ._get_profiles_string ()} { extensions } "
144
186
145
187
def __hash__ (self ) -> int :
188
+ """Return a stable hash for use in sets and dict keys.
189
+
190
+ Returns:
191
+ int: Hash value derived from version and profiles.
192
+
193
+ """
146
194
return hash (str (self .version ) + self ._get_profiles_string ())
147
195
148
196
def __eq__ (self , other : object ) -> bool :
197
+ """Return True if another instance represents the same spec.
198
+
199
+ Args:
200
+ other (object): Object to compare.
201
+
202
+ Returns:
203
+ bool: True if versions and profiles match.
204
+
205
+ """
149
206
if isinstance (other , Tosa_1_00 ):
150
207
return (self .version == other .version ) and (
151
208
self ._get_profiles_string () == other ._get_profiles_string ()
152
209
)
153
210
return False
154
211
155
212
def support_integer (self ):
213
+ """Return True if the ``INT`` profile is present."""
156
214
return "INT" in self .profiles
157
215
158
216
def support_float (self ):
217
+ """Return True if the ``FP`` profile is present."""
159
218
return "FP" in self .profiles
160
219
161
220
def support_extension (self , extension : str ) -> bool :
221
+ """Return True if an extension is supported and enabled.
222
+
223
+ Args:
224
+ extension (str): Extension name (for example ``int4``, ``bf16``).
225
+
226
+ Returns:
227
+ bool: True if the extension is valid for the active profiles and selected.
228
+
229
+ """
162
230
for p in self .profiles :
163
231
if extension in self .valid_extensions [p ] and extension in self .extensions :
164
232
return True
@@ -167,30 +235,63 @@ def support_extension(self, extension: str) -> bool:
167
235
168
236
169
237
class TosaLoweringContext :
170
- """
171
- A context manager to handle the TOSA specific aspects of the lowering process.
172
- For now it only handles the TOSA specification context, but it can be extended
173
- to include other policies or configurations.
238
+ """Manage the TOSA specification context for lowering.
239
+
240
+ For now, only the active ``TosaSpecification`` is tracked, but this can be
241
+ extended to carry additional lowering policies or configuration.
242
+
243
+ Attributes:
244
+ tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec.
245
+ spec (TosaSpecification): Specification passed to the context manager.
246
+
174
247
"""
175
248
176
249
# Define a context variable for the spec
177
250
tosa_spec_var : contextvars .ContextVar = contextvars .ContextVar ("tosa_spec" )
178
251
179
252
def __init__ (self , spec : TosaSpecification ):
253
+ """Initialize the lowering context with a specification.
254
+
255
+ Args:
256
+ spec (TosaSpecification): Active specification to put into context.
257
+
258
+ """
180
259
self .spec = spec
181
260
182
261
def __enter__ (self ):
262
+ """Set the context variable and return self.
263
+
264
+ Returns:
265
+ TosaLoweringContext: This context manager instance.
266
+
267
+ """
183
268
# Set the spec in the context variable and store the token for later reset
184
269
self .token = TosaLoweringContext .tosa_spec_var .set (self .spec )
185
270
return self
186
271
187
272
def __exit__ (self , exc_type , exc_value , traceback ):
273
+ """Reset the context variable to its previous state.
274
+
275
+ Args:
276
+ exc_type (type | None): Exception type, if any.
277
+ exc_value (BaseException | None): Exception instance, if any.
278
+ traceback (TracebackType | None): Traceback, if any.
279
+
280
+ """
188
281
# Reset the context variable to its previous state
189
282
TosaLoweringContext .tosa_spec_var .reset (self .token )
190
283
191
284
192
- # A helper function to retrieve the current spec anywhere in your code
193
285
def get_context_spec () -> TosaSpecification :
286
+ """Get the current ``TosaSpecification`` from the lowering context.
287
+
288
+ Returns:
289
+ TosaSpecification: Active specification retrieved from the context var.
290
+
291
+ Raises:
292
+ RuntimeError: If called outside a ``TosaLoweringContext``.
293
+
294
+ """
194
295
try :
195
296
return TosaLoweringContext .tosa_spec_var .get ()
196
297
except LookupError :
0 commit comments