Skip to content

Commit 024a9cd

Browse files
authored
Add type checks to Attr methods (#2310)
Add type checks and raise `TypeError` in `Attr` class methods in `onnxscript/ir/_core.py`. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/microsoft/onnxscript/pull/2310?shareId=249b81eb-c684-4866-81f8-a62209ca79d4).
1 parent b90c1ad commit 024a9cd

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

onnxscript/ir/_core.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3183,34 +3183,58 @@ def __repr__(self) -> str:
31833183
# Well typed getters
31843184
def as_float(self) -> float:
31853185
"""Get the attribute value as a float."""
3186+
if self.type != _enums.AttributeType.FLOAT:
3187+
raise TypeError(
3188+
f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
3189+
)
31863190
# Do not use isinstance check because it may prevent np.float32 etc. from being used
31873191
return float(self.value)
31883192

31893193
def as_int(self) -> int:
31903194
"""Get the attribute value as an int."""
3195+
if self.type != _enums.AttributeType.INT:
3196+
raise TypeError(
3197+
f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
3198+
)
31913199
# Do not use isinstance check because it may prevent np.int32 etc. from being used
31923200
return int(self.value)
31933201

31943202
def as_string(self) -> str:
31953203
"""Get the attribute value as a string."""
3204+
if self.type != _enums.AttributeType.STRING:
3205+
raise TypeError(
3206+
f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
3207+
)
31963208
if not isinstance(self.value, str):
31973209
raise TypeError(f"Value of attribute '{self!r}' is not a string.")
31983210
return self.value
31993211

32003212
def as_tensor(self) -> _protocols.TensorProtocol:
32013213
"""Get the attribute value as a tensor."""
3214+
if self.type != _enums.AttributeType.TENSOR:
3215+
raise TypeError(
3216+
f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
3217+
)
32023218
if not isinstance(self.value, _protocols.TensorProtocol):
32033219
raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
32043220
return self.value
32053221

32063222
def as_graph(self) -> Graph:
32073223
"""Get the attribute value as a graph."""
3224+
if self.type != _enums.AttributeType.GRAPH:
3225+
raise TypeError(
3226+
f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
3227+
)
32083228
if not isinstance(self.value, Graph):
32093229
raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
32103230
return self.value
32113231

32123232
def as_floats(self) -> Sequence[float]:
32133233
"""Get the attribute value as a sequence of floats."""
3234+
if self.type != _enums.AttributeType.FLOATS:
3235+
raise TypeError(
3236+
f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
3237+
)
32143238
if not isinstance(self.value, Sequence):
32153239
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
32163240
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
@@ -3219,6 +3243,10 @@ def as_floats(self) -> Sequence[float]:
32193243

32203244
def as_ints(self) -> Sequence[int]:
32213245
"""Get the attribute value as a sequence of ints."""
3246+
if self.type != _enums.AttributeType.INTS:
3247+
raise TypeError(
3248+
f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
3249+
)
32223250
if not isinstance(self.value, Sequence):
32233251
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
32243252
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
@@ -3227,6 +3255,10 @@ def as_ints(self) -> Sequence[int]:
32273255

32283256
def as_strings(self) -> Sequence[str]:
32293257
"""Get the attribute value as a sequence of strings."""
3258+
if self.type != _enums.AttributeType.STRINGS:
3259+
raise TypeError(
3260+
f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
3261+
)
32303262
if not isinstance(self.value, Sequence):
32313263
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
32323264
if onnxscript.DEBUG:
@@ -3237,6 +3269,10 @@ def as_strings(self) -> Sequence[str]:
32373269

32383270
def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
32393271
"""Get the attribute value as a sequence of tensors."""
3272+
if self.type != _enums.AttributeType.TENSORS:
3273+
raise TypeError(
3274+
f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
3275+
)
32403276
if not isinstance(self.value, Sequence):
32413277
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
32423278
if onnxscript.DEBUG:
@@ -3247,6 +3283,10 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
32473283

32483284
def as_graphs(self) -> Sequence[Graph]:
32493285
"""Get the attribute value as a sequence of graphs."""
3286+
if self.type != _enums.AttributeType.GRAPHS:
3287+
raise TypeError(
3288+
f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
3289+
)
32503290
if not isinstance(self.value, Sequence):
32513291
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
32523292
if onnxscript.DEBUG:

onnxscript/ir/_core_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,56 @@ def test_as_graphs(self):
17141714
attr = _core.Attr("test", ir.AttributeType.GRAPHS, [_core.Graph((), (), nodes=())])
17151715
self.assertIsInstance(attr.as_graphs()[0], _core.Graph)
17161716

1717+
def test_as_float_type_error(self):
1718+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1719+
with self.assertRaises(TypeError):
1720+
attr.as_float()
1721+
1722+
def test_as_int_type_error(self):
1723+
attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
1724+
with self.assertRaises(TypeError):
1725+
attr.as_int()
1726+
1727+
def test_as_string_type_error(self):
1728+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1729+
with self.assertRaises(TypeError):
1730+
attr.as_string()
1731+
1732+
def test_as_tensor_type_error(self):
1733+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1734+
with self.assertRaises(TypeError):
1735+
attr.as_tensor()
1736+
1737+
def test_as_graph_type_error(self):
1738+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1739+
with self.assertRaises(TypeError):
1740+
attr.as_graph()
1741+
1742+
def test_as_floats_type_error(self):
1743+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1744+
with self.assertRaises(TypeError):
1745+
attr.as_floats()
1746+
1747+
def test_as_ints_type_error(self):
1748+
attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
1749+
with self.assertRaises(TypeError):
1750+
attr.as_ints()
1751+
1752+
def test_as_strings_type_error(self):
1753+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1754+
with self.assertRaises(TypeError):
1755+
attr.as_strings()
1756+
1757+
def test_as_tensors_type_error(self):
1758+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1759+
with self.assertRaises(TypeError):
1760+
attr.as_tensors()
1761+
1762+
def test_as_graphs_type_error(self):
1763+
attr = _core.Attr("test", ir.AttributeType.INT, 42)
1764+
with self.assertRaises(TypeError):
1765+
attr.as_graphs()
1766+
17171767

17181768
class LazyTensorTest(unittest.TestCase):
17191769
def test_lazy_tensor_initialization(self):

0 commit comments

Comments
 (0)