From 3d7289995f1c2915464d4b41db1546469a6b8ea9 Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Mon, 31 Mar 2025 20:33:07 -0400 Subject: [PATCH 1/9] TDD: write tests first tests that match the desired behavior for PEP-585 types passed to singledispatch --- Lib/test/test_functools.py | 118 +++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 2b49615178f136..9e25c3d9804bd7 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import abc import builtins import collections @@ -2136,6 +2137,123 @@ def cached_staticmeth(x, y): class TestSingleDispatch(unittest.TestCase): + + def test_pep585_basic(self): + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + # previously this failed with: 'not a class' + g.register(list[int], g_list_int) + self.assertEqual(g([1]), "list of ints") + self.assertIs(g.dispatch(list[int]), g_list_int) + + def test_pep585_annotation(self): + @functools.singledispatch + def g(obj): + return "base" + # previously this failed with: 'not a class' + @g.register + def g_list_int(li: list[int]): + return "list of ints" + self.assertEqual(g([1,2,3]), "list of ints") + self.assertIs(g.dispatch(tuple[int]), g_list_int) + + def test_pep585_all_must_match(self): + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + def g_list_not_ints(l): + # should only trigger if list doesnt match `list[int]` + # ie. at least one element is not an int + return "!all(int)" + + g.register(list[int], g_list_int) + g.register(list, g_list_not_ints) + + self.assertEqual(g([1,2,3]), "list of ints") + self.assertEqual(g([1,2,3, "hello"]), "!all(int)") + self.assertEqual(g([3.14]), "!all(int)") + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[str]), g_list_not_ints) + self.assertIs(g.dispatch(list[float]), g_list_not_ints) + self.assertIs(g.dispatch(list[int|str]), g_list_not_ints) + + def test_pep585_specificity(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list(l: list): + return "basic list" + @g.register + def g_list_int(li: list[int]): + return "int" + @g.register + def g_list_str(ls: list[str]): + return "str" + @g.register + def g_list_mixed_int_str(lmis:list[int|str]): + return "int|str" + @g.register + def g_list_mixed_int_float(lmif: list[int|float]): + return "int|float" + @g.register + def g_list_mixed_int_float_str(lmifs: list[int|float|str]): + return "int|float|str" + + # this matches list, list[int], list[int|str], list[int|float|str], list[int|...|...|...|...] + # but list[int] is the most specific, so that is correct + self.assertEqual(g([1,2,3]), "int") + + # this cannot match list[int] because of the string + # it does match list[int|float|str] but this is incorrect because, + # the most specific is list[int|str] + self.assertEqual(g([1,2,3, "hello"]), "int|str") + + # list[float] is not mapped so, + # list[int|float] is the most specific + self.assertEqual(g([3.14]), "int|float") + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[float]), g_list_mixed_int_float) + self.assertIs(g.dispatch(list[int|str]), g_list_mixed_int_str) + + def test_pep585_ambiguous(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list_int_float(l: list[int|float]): + return "int|float" + @g.register + def g_list_int_str(l: list[int|str]): + return "int|str" + @g.register + def g_list_int(l: list[int]): + return "int only" + + self.assertEqual(g([3.1]), "int|float") # floats only + self.assertEqual(g(["hello"]), "int|str") # strings only + self.assertEqual(g([3.14, 1]), "int|float") # ints and floats + self.assertEqual(g(["hello", 1]), "int|str") # ints and strings + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[str]), g_list_int_str) + self.assertIs(g.dispatch(list[float]), g_list_int_float) + self.assertIs(g.dispatch(list[int|str]), g_list_int_str) + self.assertIs(g.dispatch(list[int|float]), g_list_int_float) + + # these should fail because it's unclear which target is "correct" + with self.assertRaises(RuntimeError): + g([1]) + + self.assertRaises(RuntimeError, g.dispatch(list[int])) + def test_simple_overloads(self): @functools.singledispatch def g(obj): From 20616d9a33579d8d6dc5c944887edca77a5594f8 Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 06:31:48 -0400 Subject: [PATCH 2/9] _is_valid_dispatch_type GenericAlias --- Lib/functools.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Lib/functools.py b/Lib/functools.py index 714070c6ac9460..6f379eb623ca05 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -913,6 +913,11 @@ def dispatch(cls): def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True + + if isinstance(cls, GenericAlias): + from typing import get_args + return all(isinstance(arg, (type, UnionType)) for arg in get_args(cls)) + return (isinstance(cls, UnionType) and all(isinstance(arg, type) for arg in cls.__args__)) From dc94cbae9a689fa555d5748cacd169406e7f608b Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 05:55:35 -0400 Subject: [PATCH 3/9] _dispatch --- Lib/functools.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 6f379eb623ca05..0fcb4f96800e96 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -887,13 +887,14 @@ def singledispatch(func): dispatch_cache = weakref.WeakKeyDictionary() cache_token = None - def dispatch(cls): + def dispatch(cls_obj): """generic_func.dispatch(cls) -> Runs the dispatch algorithm to return the best available implementation for the given *cls* registered on *generic_func*. """ + cls = cls_obj.__class__ nonlocal cache_token if cache_token is not None: current_token = get_cache_token() @@ -981,7 +982,7 @@ def wrapper(*args, **kw): if not args: raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return dispatch(args[0].__class__)(*args, **kw) + return dispatch(args[0])(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func @@ -1069,7 +1070,7 @@ def __call__(self, /, *args, **kwargs): 'singledispatchmethod method') raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs) + return self._dispatch(args[0]).__get__(self._obj, self._cls)(*args, **kwargs) def __getattr__(self, name): # Resolve these attributes lazily to speed up creation of From 609858ff25d26066a50b34368921176de647435c Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 06:08:47 -0400 Subject: [PATCH 4/9] _find_impl --- Lib/functools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 0fcb4f96800e96..d7939274b74475 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -843,8 +843,8 @@ def is_strict_base(typ): mro.append(subcls) return _c3_mro(cls, abcs=mro) -def _find_impl(cls, registry): - """Returns the best matching implementation from *registry* for type *cls*. +def _find_impl(cls_obj, registry): + """Returns the best matching implementation from *registry* for type *cls_obj*. Where there is no registered implementation for a specific type, its method resolution order is used to find a more generic implementation. @@ -853,6 +853,7 @@ def _find_impl(cls, registry): *object* type, this function may return None. """ + cls = cls_obj if isinstance(cls_obj, type) else cls_obj.__class__ mro = _compose_mro(cls, registry.keys()) match = None for t in mro: From 3c2eefe991be552037b54aa748d80b142c60daee Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 06:06:32 -0400 Subject: [PATCH 5/9] ignore cache for pep585 --- Lib/functools.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index d7939274b74475..5bc30dd1472aba 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -902,9 +902,13 @@ def dispatch(cls_obj): if cache_token != current_token: dispatch_cache.clear() cache_token = current_token - try: - impl = dispatch_cache[cls] - except KeyError: + + + # if PEP-585 types are not registered for the given *cls*, + # then we can use the cache. Otherwise, the cache cannot be used + # because we need to confirm every item matches first + from typing import get_origin + if not any(i for i in registry.keys() if get_origin(i) == cls): try: impl = registry[cls] except KeyError: From 386d6f7955ede829fc859eb6491c00bd8a505e56 Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 11:32:03 -0400 Subject: [PATCH 6/9] _find_impl_match --- Lib/functools.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 5bc30dd1472aba..2de77f8e6b4c58 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -843,7 +843,7 @@ def is_strict_base(typ): mro.append(subcls) return _c3_mro(cls, abcs=mro) -def _find_impl(cls_obj, registry): +def _find_impl_match(cls_obj, registry): """Returns the best matching implementation from *registry* for type *cls_obj*. Where there is no registered implementation for a specific type, its method @@ -856,6 +856,32 @@ def _find_impl(cls_obj, registry): cls = cls_obj if isinstance(cls_obj, type) else cls_obj.__class__ mro = _compose_mro(cls, registry.keys()) match = None + + from typing import get_origin, get_args + + if (not isinstance(cls_obj, type) and + len(cls_obj) > 0 and # dont try to match the types of empty containers + any(i for i in registry.keys() if get_origin(i) == cls)): + # check containers that match cls first + for t in [i for i in registry.keys() if get_origin(i) == cls]: + if not all((isinstance(i, get_args(t)) for i in cls_obj)): + continue + + if match is None: + match = t + + else: + match_args = get_args(get_args(match)[0]) + t_args = get_args(get_args(t)[0]) + if len(match_args) == len(t_args): + raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) + + elif len(t_args) Date: Tue, 1 Apr 2025 18:11:15 -0400 Subject: [PATCH 7/9] finalizing --- Lib/functools.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 2de77f8e6b4c58..3d0e47aa8635af 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -843,6 +843,10 @@ def is_strict_base(typ): mro.append(subcls) return _c3_mro(cls, abcs=mro) +def _pep585_registry_matches(cls, registry): + from typing import get_origin + return (i for i in registry.keys() if get_origin(i) == cls) + def _find_impl_match(cls_obj, registry): """Returns the best matching implementation from *registry* for type *cls_obj*. @@ -861,9 +865,9 @@ def _find_impl_match(cls_obj, registry): if (not isinstance(cls_obj, type) and len(cls_obj) > 0 and # dont try to match the types of empty containers - any(i for i in registry.keys() if get_origin(i) == cls)): + any(_pep585_registry_matches(cls, registry))): # check containers that match cls first - for t in [i for i in registry.keys() if get_origin(i) == cls]: + for t in _pep585_registry_matches(cls, registry): if not all((isinstance(i, get_args(t)) for i in cls_obj)): continue @@ -898,10 +902,11 @@ def _find_impl_match(cls_obj, registry): return match def _find_impl(cls_obj, registry): - return ( + return registry.get( _find_impl_match(cls_obj, registry) ) + def singledispatch(func): """Single-dispatch generic function decorator. @@ -920,6 +925,18 @@ def singledispatch(func): dispatch_cache = weakref.WeakKeyDictionary() cache_token = None + def _fetch_dispatch_with_cache(cls): + try: + impl = dispatch_cache[cls] + except KeyError: + try: + impl = registry[cls] + except KeyError: + impl = _find_impl(cls, registry) + dispatch_cache[cls] = impl + return impl + + def dispatch(cls_obj): """generic_func.dispatch(cls) -> @@ -935,18 +952,13 @@ def dispatch(cls_obj): dispatch_cache.clear() cache_token = current_token - # if PEP-585 types are not registered for the given *cls*, # then we can use the cache. Otherwise, the cache cannot be used # because we need to confirm every item matches first - from typing import get_origin - if not any(i for i in registry.keys() if get_origin(i) == cls): - try: - impl = registry[cls] - except KeyError: - impl = _find_impl(cls, registry) - dispatch_cache[cls] = impl - return impl + if not any(_pep585_registry_matches(cls, registry)): + return _fetch_dispatch_with_cache(cls) + + return _find_impl(cls_obj, registry) def _is_valid_dispatch_type(cls): if isinstance(cls, type): From 6b132be1f1992dcee92db590d2182f575ecb101d Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Tue, 1 Apr 2025 18:11:38 -0400 Subject: [PATCH 8/9] singledispatchmethod pep585 tests --- Lib/test/test_functools.py | 127 +++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 9e25c3d9804bd7..6fc7d3c9a80ad9 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2254,6 +2254,133 @@ def g_list_int(l: list[int]): self.assertRaises(RuntimeError, g.dispatch(list[int])) + def test_pep585_method_basic(self): + class A: + @functools.singledispatchmethod + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + + a = A() + a.g.register(list[int], A.g_list_int) + self.assertEqual(a.g([1]), "list of ints") + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + + def test_pep585_method_annotation(self): + class A: + @functools.singledispatchmethod + def g(obj): + return "base" + # previously this failed with: 'not a class' + @g.register + def g_list_int(li: list[int]): + return "list of ints" + a = A() + self.assertEqual(a.g([1,2,3]), "list of ints") + self.assertIs(g.dispatch(tuple[int]), A.g_list_int) + + def test_pep585_method_all_must_match(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + def g_list_not_ints(l): + # should only trigger if list doesnt match `list[int]` + # ie. at least one element is not an int + return "!all(int)" + + a = A() + a.g.register(list[int], A.g_list_int) + a.g.register(list, A.g_list_not_ints) + + self.assertEqual(a.g([1,2,3]), "list of ints") + self.assertEqual(a.g([1,2,3, "hello"]), "!all(int)") + self.assertEqual(a.g([3.14]), "!all(int)") + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[str]), A.g_list_not_ints) + self.assertIs(a.g.dispatch(list[float]), A.g_list_not_ints) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_not_ints) + + def test_pep585_method_specificity(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list(l: list): + return "basic list" + @g.register + def g_list_int(li: list[int]): + return "int" + @g.register + def g_list_str(ls: list[str]): + return "str" + @g.register + def g_list_mixed_int_str(lmis:list[int|str]): + return "int|str" + @g.register + def g_list_mixed_int_float(lmif: list[int|float]): + return "int|float" + @g.register + def g_list_mixed_int_float_str(lmifs: list[int|float|str]): + return "int|float|str" + + a = A() + + # this matches list, list[int], list[int|str], list[int|float|str], list[int|...|...|...|...] + # but list[int] is the most specific, so that is correct + self.assertEqual(a.g([1,2,3]), "int") + + # this cannot match list[int] because of the string + # it does match list[int|float|str] but this is incorrect because, + # the most specific is list[int|str] + self.assertEqual(a.g([1,2,3, "hello"]), "int|str") + + # list[float] is not mapped so, + # list[int|float] is the most specific + self.assertEqual(a.g([3.14]), "int|float") + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[float]), A.g_list_mixed_int_float) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_mixed_int_str) + + def test_pep585_method_ambiguous(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list_int_float(l: list[int|float]): + return "int|float" + @g.register + def g_list_int_str(l: list[int|str]): + return "int|str" + @g.register + def g_list_int(l: list[int]): + return "int only" + + a = A() + + self.assertEqual(a.g([3.1]), "int|float") # floats only + self.assertEqual(a.g(["hello"]), "int|str") # strings only + self.assertEqual(a.g([3.14, 1]), "int|float") # ints and floats + self.assertEqual(a.g(["hello", 1]), "int|str") # ints and strings + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[str]), A.g_list_int_str) + self.assertIs(a.g.dispatch(list[float]), A.g_list_int_float) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_int_str) + self.assertIs(a.g.dispatch(list[int|float]), A.g_list_int_float) + + # these should fail because it's unclear which target is "correct" + self.assertRaises(RuntimeError, a.g([1])) + + self.assertRaises(RuntimeError, a.g.dispatch(list[int])) + def test_simple_overloads(self): @functools.singledispatch def g(obj): From a396d8a71e3beafb8ae1b5ca819044df54fb91f7 Mon Sep 17 00:00:00 2001 From: Skyler Curtis Date: Wed, 2 Apr 2025 01:33:40 -0400 Subject: [PATCH 9/9] fixing test failures --- Lib/functools.py | 4 ++- Lib/test/test_functools.py | 52 ++++++++++++++++---------------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 3d0e47aa8635af..3cd4c23c7104cf 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -969,7 +969,8 @@ def _is_valid_dispatch_type(cls): return all(isinstance(arg, (type, UnionType)) for arg in get_args(cls)) return (isinstance(cls, UnionType) and - all(isinstance(arg, type) for arg in cls.__args__)) + all(isinstance(arg, (type, GenericAlias)) for arg in cls.__args__)) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -987,6 +988,7 @@ def register(cls, func=None): f"Invalid first argument to `register()`. " f"{cls!r} is not a class or union type." ) + ann = getattr(cls, '__annotate__', None) if ann is None: raise TypeError( diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 6fc7d3c9a80ad9..fbc6623e83be01 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -3483,18 +3483,12 @@ def test_register_genericalias(self): def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int], lambda arg: "types.GenericAlias") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int], lambda arg: "typing.GenericAlias") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") - - self.assertEqual(f([1]), "default") - self.assertEqual(f([1.0]), "default") - self.assertEqual(f(""), "default") + f.register(list[int], lambda arg: "types.GenericAlias") + f.register(list[float] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") + + self.assertEqual(f([1]), "types.GenericAlias") + self.assertEqual(f([1.0]), "types.UnionTypes(types.GenericAlias)") + self.assertEqual(f(""), "types.UnionTypes(types.GenericAlias)") self.assertEqual(f(b""), "default") def test_register_genericalias_decorator(self): @@ -3502,41 +3496,39 @@ def test_register_genericalias_decorator(self): def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int]) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int]) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int] | str) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int] | str) + f.register(list[int]) + #f.register(typing.List[int]) + f.register(list[int] | str) + #f.register(typing.List[int] | str) def test_register_genericalias_annotation(self): @functools.singledispatch def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): - @f.register - def _(arg: list[int]): - return "types.GenericAlias" + @f.register + def _(arg: list[int]): + return "types.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): @f.register def _(arg: typing.List[float]): return "typing.GenericAlias" - with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): - @f.register - def _(arg: list[int] | str): - return "types.UnionType(types.GenericAlias)" + + @f.register + def _(arg: list[bytes] | str): + return "types.UnionType(types.GenericAlias)" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): @f.register def _(arg: typing.List[float] | bytes): return "typing.Union[typing.GenericAlias]" - self.assertEqual(f([1]), "default") + self.assertEqual(f([1]), "types.GenericAlias") self.assertEqual(f([1.0]), "default") - self.assertEqual(f(""), "default") + self.assertEqual(f(""), "types.UnionType(types.GenericAlias)") self.assertEqual(f(b""), "default") + self.assertEqual(f([b""]), "types.UnionType(types.GenericAlias)") def test_forward_reference(self): @functools.singledispatch