From a607260e06f08f25d8406b8dead9f82c759c1c57 Mon Sep 17 00:00:00 2001 From: samwaseda Date: Fri, 24 Jan 2025 18:53:33 +0000 Subject: [PATCH 1/5] Update tests --- pyiron_workflow/type_hinting.py | 26 +++++++++++++++++--------- tests/unit/test_type_hinting.py | 23 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index ee109a949..fd8a1e297 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -49,31 +49,39 @@ def type_hint_to_tuple(type_hint) -> tuple: return (type_hint,) +def _get_type_hints(type_hint) -> typing.Union[type, None]: + hint = typing.get_origin(type_hint) + if hint is typing.Annotated: + return typing.get_origin(type_hint.__origin__), type_hint.__origin__ + else: + return hint, type_hint + + def type_hint_is_as_or_more_specific_than(hint, other) -> bool: - hint_origin = typing.get_origin(hint) - other_origin = typing.get_origin(other) + hint_origin, hint_type = _get_type_hints(hint) + other_origin, other_type = _get_type_hints(other) if {hint_origin, other_origin} & {types.UnionType, typing.Union}: # If either hint is a union, turn both into tuples and call recursively return all( any( type_hint_is_as_or_more_specific_than(h, o) - for o in type_hint_to_tuple(other) + for o in type_hint_to_tuple(other_type) ) - for h in type_hint_to_tuple(hint) + for h in type_hint_to_tuple(hint_type) ) elif hint_origin is None and other_origin is None: # Once both are raw classes, just do a subclass test try: - return issubclass(hint, other) + return issubclass(hint_type, other_type) except TypeError: - return hint == other + return hint_type == other_type elif other_origin is None and hint_origin is not None: # When the hint adds specificity to an empty origin - return hint_origin == other + return hint_origin == other_type elif hint_origin == other_origin: # If they both have an origin, break into arguments and treat cases - hint_args = typing.get_args(hint) - other_args = typing.get_args(other) + hint_args = typing.get_args(hint_type) + other_args = typing.get_args(other_type) if len(hint_args) == 0 and len(other_args) > 0: # Failing to specify anything is not being more specific return False diff --git a/tests/unit/test_type_hinting.py b/tests/unit/test_type_hinting.py index fc20ce1e9..a10343f4d 100644 --- a/tests/unit/test_type_hinting.py +++ b/tests/unit/test_type_hinting.py @@ -6,6 +6,7 @@ from pyiron_workflow.type_hinting import ( type_hint_is_as_or_more_specific_than, valid_value, + _get_type_hints, ) @@ -95,6 +96,16 @@ def test_hint_comparisons(self): typing.Callable[[int, float], float], False, ), + ( + typing.Annotated[int, "foo"], + int, + True, + ), + ( + int, + typing.Annotated[int, "foo"], + True, + ), ]: with self.subTest( target=target, reference=reference, expected=is_more_specific @@ -102,8 +113,20 @@ def test_hint_comparisons(self): self.assertEqual( type_hint_is_as_or_more_specific_than(target, reference), is_more_specific, + msg=f"{target} is {'not ' if not is_more_specific else ''}more specific than {reference}", ) + def test_get_type_hints(self): + for hint, origin in [ + (int | float, type(int| float)), + (typing.Annotated[int | float, "foo"], type(int | float)), + (int, None), + (typing.Annotated[int, "foo"], None), + ]: + with self.subTest(hint=hint, origin=origin): + self.assertEqual(_get_type_hints(hint)[0], origin) + + if __name__ == "__main__": unittest.main() From 753112f88f6e779d8132d52ceb28a358c34c1514 Mon Sep 17 00:00:00 2001 From: samwaseda Date: Sat, 25 Jan 2025 11:18:29 +0000 Subject: [PATCH 2/5] Update type checking --- pyiron_workflow/type_hinting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index fd8a1e297..1399514c7 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -49,7 +49,7 @@ def type_hint_to_tuple(type_hint) -> tuple: return (type_hint,) -def _get_type_hints(type_hint) -> typing.Union[type, None]: +def _get_type_hints(type_hint) -> tuple[typing.Optional[type], typing.Any]: hint = typing.get_origin(type_hint) if hint is typing.Annotated: return typing.get_origin(type_hint.__origin__), type_hint.__origin__ From a336818ab887af9fcfe589130a04b8f1afc0643f Mon Sep 17 00:00:00 2001 From: samwaseda Date: Sat, 25 Jan 2025 11:20:05 +0000 Subject: [PATCH 3/5] Apply what ruff says --- pyiron_workflow/type_hinting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 1399514c7..0b32804c9 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -49,7 +49,7 @@ def type_hint_to_tuple(type_hint) -> tuple: return (type_hint,) -def _get_type_hints(type_hint) -> tuple[typing.Optional[type], typing.Any]: +def _get_type_hints(type_hint) -> tuple[type | None, typing.Any]: hint = typing.get_origin(type_hint) if hint is typing.Annotated: return typing.get_origin(type_hint.__origin__), type_hint.__origin__ From 8f641bba6c6ef0a49f9d84da6e2d88fbc278562a Mon Sep 17 00:00:00 2001 From: samwaseda Date: Sat, 25 Jan 2025 11:23:23 +0000 Subject: [PATCH 4/5] I don't know what the error says but I'm guessing that the order matters? --- tests/unit/test_type_hinting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_type_hinting.py b/tests/unit/test_type_hinting.py index a10343f4d..c98d0d8d8 100644 --- a/tests/unit/test_type_hinting.py +++ b/tests/unit/test_type_hinting.py @@ -4,9 +4,9 @@ from pint import UnitRegistry from pyiron_workflow.type_hinting import ( + _get_type_hints, type_hint_is_as_or_more_specific_than, valid_value, - _get_type_hints, ) From b34d0400df3ceb85fb184f6ebb29abfc7b0b4071 Mon Sep 17 00:00:00 2001 From: Sam Dareska <37879103+samwaseda@users.noreply.github.com> Date: Thu, 30 Jan 2025 23:15:09 +0100 Subject: [PATCH 5/5] Update tests/unit/test_type_hinting.py Co-authored-by: Liam Huber --- tests/unit/test_type_hinting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_type_hinting.py b/tests/unit/test_type_hinting.py index c98d0d8d8..e4e3454d3 100644 --- a/tests/unit/test_type_hinting.py +++ b/tests/unit/test_type_hinting.py @@ -122,6 +122,8 @@ def test_get_type_hints(self): (typing.Annotated[int | float, "foo"], type(int | float)), (int, None), (typing.Annotated[int, "foo"], None), + (typing.Annotated[list[int], "foo"], list), + (list[int], list), ]: with self.subTest(hint=hint, origin=origin): self.assertEqual(_get_type_hints(hint)[0], origin)