diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index ee109a949..0b32804c9 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -49,31 +49,39 @@ def type_hint_to_tuple(type_hint) -> tuple: return (type_hint,) +def _get_type_hints(type_hint) -> tuple[type | None, typing.Any]: + hint = typing.get_origin(type_hint) + if hint is typing.Annotated: + return typing.get_origin(type_hint.__origin__), type_hint.__origin__ + else: + return hint, type_hint + + def type_hint_is_as_or_more_specific_than(hint, other) -> bool: - hint_origin = typing.get_origin(hint) - other_origin = typing.get_origin(other) + hint_origin, hint_type = _get_type_hints(hint) + other_origin, other_type = _get_type_hints(other) if {hint_origin, other_origin} & {types.UnionType, typing.Union}: # If either hint is a union, turn both into tuples and call recursively return all( any( type_hint_is_as_or_more_specific_than(h, o) - for o in type_hint_to_tuple(other) + for o in type_hint_to_tuple(other_type) ) - for h in type_hint_to_tuple(hint) + for h in type_hint_to_tuple(hint_type) ) elif hint_origin is None and other_origin is None: # Once both are raw classes, just do a subclass test try: - return issubclass(hint, other) + return issubclass(hint_type, other_type) except TypeError: - return hint == other + return hint_type == other_type elif other_origin is None and hint_origin is not None: # When the hint adds specificity to an empty origin - return hint_origin == other + return hint_origin == other_type elif hint_origin == other_origin: # If they both have an origin, break into arguments and treat cases - hint_args = typing.get_args(hint) - other_args = typing.get_args(other) + hint_args = typing.get_args(hint_type) + other_args = typing.get_args(other_type) if len(hint_args) == 0 and len(other_args) > 0: # Failing to specify anything is not being more specific return False diff --git a/tests/unit/test_type_hinting.py b/tests/unit/test_type_hinting.py index fc20ce1e9..e4e3454d3 100644 --- a/tests/unit/test_type_hinting.py +++ b/tests/unit/test_type_hinting.py @@ -4,6 +4,7 @@ from pint import UnitRegistry from pyiron_workflow.type_hinting import ( + _get_type_hints, type_hint_is_as_or_more_specific_than, valid_value, ) @@ -95,6 +96,16 @@ def test_hint_comparisons(self): typing.Callable[[int, float], float], False, ), + ( + typing.Annotated[int, "foo"], + int, + True, + ), + ( + int, + typing.Annotated[int, "foo"], + True, + ), ]: with self.subTest( target=target, reference=reference, expected=is_more_specific @@ -102,8 +113,22 @@ def test_hint_comparisons(self): self.assertEqual( type_hint_is_as_or_more_specific_than(target, reference), is_more_specific, + msg=f"{target} is {'not ' if not is_more_specific else ''}more specific than {reference}", ) + def test_get_type_hints(self): + for hint, origin in [ + (int | float, type(int| float)), + (typing.Annotated[int | float, "foo"], type(int | float)), + (int, None), + (typing.Annotated[int, "foo"], None), + (typing.Annotated[list[int], "foo"], list), + (list[int], list), + ]: + with self.subTest(hint=hint, origin=origin): + self.assertEqual(_get_type_hints(hint)[0], origin) + + if __name__ == "__main__": unittest.main()