Skip to content

Add type checks to Attr methods #2310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
@@ -3183,34 +3183,58 @@ def __repr__(self) -> str:
# Well typed getters
def as_float(self) -> float:
"""Get the attribute value as a float."""
if self.type != _enums.AttributeType.FLOAT:
raise TypeError(
f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
)
# Do not use isinstance check because it may prevent np.float32 etc. from being used
return float(self.value)

def as_int(self) -> int:
"""Get the attribute value as an int."""
if self.type != _enums.AttributeType.INT:
raise TypeError(
f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
)
# Do not use isinstance check because it may prevent np.int32 etc. from being used
return int(self.value)

def as_string(self) -> str:
"""Get the attribute value as a string."""
if self.type != _enums.AttributeType.STRING:
raise TypeError(
f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
)
if not isinstance(self.value, str):
raise TypeError(f"Value of attribute '{self!r}' is not a string.")
return self.value

def as_tensor(self) -> _protocols.TensorProtocol:
"""Get the attribute value as a tensor."""
if self.type != _enums.AttributeType.TENSOR:
raise TypeError(
f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
)
if not isinstance(self.value, _protocols.TensorProtocol):
raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
return self.value

def as_graph(self) -> Graph:
"""Get the attribute value as a graph."""
if self.type != _enums.AttributeType.GRAPH:
raise TypeError(
f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
)
if not isinstance(self.value, Graph):
raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
return self.value

def as_floats(self) -> Sequence[float]:
"""Get the attribute value as a sequence of floats."""
if self.type != _enums.AttributeType.FLOATS:
raise TypeError(
f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
)
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
@@ -3219,6 +3243,10 @@ def as_floats(self) -> Sequence[float]:

def as_ints(self) -> Sequence[int]:
"""Get the attribute value as a sequence of ints."""
if self.type != _enums.AttributeType.INTS:
raise TypeError(
f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
)
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
@@ -3227,6 +3255,10 @@ def as_ints(self) -> Sequence[int]:

def as_strings(self) -> Sequence[str]:
"""Get the attribute value as a sequence of strings."""
if self.type != _enums.AttributeType.STRINGS:
raise TypeError(
f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
)
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
@@ -3237,6 +3269,10 @@ def as_strings(self) -> Sequence[str]:

def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
"""Get the attribute value as a sequence of tensors."""
if self.type != _enums.AttributeType.TENSORS:
raise TypeError(
f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
)
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
@@ -3247,6 +3283,10 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:

def as_graphs(self) -> Sequence[Graph]:
"""Get the attribute value as a sequence of graphs."""
if self.type != _enums.AttributeType.GRAPHS:
raise TypeError(
f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
)
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
50 changes: 50 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
@@ -1694,6 +1694,56 @@ def test_as_graphs(self):
attr = _core.Attr("test", ir.AttributeType.GRAPHS, [_core.Graph((), (), nodes=())])
self.assertIsInstance(attr.as_graphs()[0], _core.Graph)

def test_as_float_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_float()

def test_as_int_type_error(self):
attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
with self.assertRaises(TypeError):
attr.as_int()

def test_as_string_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_string()

def test_as_tensor_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_tensor()

def test_as_graph_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_graph()

def test_as_floats_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_floats()

def test_as_ints_type_error(self):
attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
with self.assertRaises(TypeError):
attr.as_ints()

def test_as_strings_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_strings()

def test_as_tensors_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_tensors()

def test_as_graphs_type_error(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42)
with self.assertRaises(TypeError):
attr.as_graphs()


class LazyTensorTest(unittest.TestCase):
def test_lazy_tensor_initialization(self):
Loading
Oops, something went wrong.