From 90f75e1069f2d692480bcd305fc35b4fe7847e18 Mon Sep 17 00:00:00 2001 From: "Miss Islington (bot)" <31488909+miss-islington@users.noreply.github.com> Date: Fri, 1 Mar 2024 19:01:27 +0100 Subject: [PATCH] [3.12] gh-112281: Allow `Union` with unhashable `Annotated` metadata (GH-112283) (#116213) Co-authored-by: Nikita Sobolev Co-authored-by: Alex Waygood --- Lib/test/test_types.py | 20 ++++ Lib/test/test_typing.py | 107 +++++++++++++++++- Lib/typing.py | 45 +++++--- ...-11-20-16-15-44.gh-issue-112281.gH4EVk.rst | 2 + 4 files changed, 156 insertions(+), 18 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index b86392f43cc5be..5ffe4085f09548 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -709,6 +709,26 @@ def test_hash(self): self.assertEqual(hash(int | str), hash(str | int)) self.assertEqual(hash(int | str), hash(typing.Union[int, str])) + def test_union_of_unhashable(self): + class UnhashableMeta(type): + __hash__ = None + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... + + self.assertEqual((A | B).__args__, (A, B)) + union1 = A | B + with self.assertRaises(TypeError): + hash(union1) + + union2 = int | B + with self.assertRaises(TypeError): + hash(union2) + + union3 = A | int + with self.assertRaises(TypeError): + hash(union3) + def test_instancecheck_and_subclasscheck(self): for x in (int | str, typing.Union[int, str]): with self.subTest(x=x): diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 7f9c10dd2a54e6..e0f71464796bca 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -2,10 +2,11 @@ import collections import collections.abc from collections import defaultdict -from functools import lru_cache, wraps +from functools import lru_cache, wraps, reduce import gc import inspect import itertools +import operator import pickle import re import sys @@ -1770,6 +1771,26 @@ def test_union_union(self): v = Union[u, Employee] self.assertEqual(v, Union[int, float, Employee]) + def test_union_of_unhashable(self): + class UnhashableMeta(type): + __hash__ = None + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... + + self.assertEqual(Union[A, B].__args__, (A, B)) + union1 = Union[A, B] + with self.assertRaises(TypeError): + hash(union1) + + union2 = Union[int, B] + with self.assertRaises(TypeError): + hash(union2) + + union3 = Union[A, int] + with self.assertRaises(TypeError): + hash(union3) + def test_repr(self): self.assertEqual(repr(Union), 'typing.Union') u = Union[Employee, int] @@ -5295,10 +5316,8 @@ def some(self): self.assertFalse(hasattr(WithOverride.some, "__override__")) def test_multiple_decorators(self): - import functools - def with_wraps(f): # similar to `lru_cache` definition - @functools.wraps(f) + @wraps(f) def wrapper(*args, **kwargs): return f(*args, **kwargs) return wrapper @@ -8183,6 +8202,76 @@ def test_flatten(self): self.assertEqual(A.__metadata__, (4, 5)) self.assertEqual(A.__origin__, int) + def test_deduplicate_from_union(self): + # Regular: + self.assertEqual(get_args(Annotated[int, 1] | int), + (Annotated[int, 1], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], int]), + (Annotated[int, 1], int)) + self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int), + (Annotated[int, 1], Annotated[int, 2], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]), + (Annotated[int, 1], Annotated[int, 2], int)) + self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int), + (Annotated[int, 1], Annotated[str, 1], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]), + (Annotated[int, 1], Annotated[str, 1], int)) + + # Duplicates: + self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int, + Annotated[int, 1] | int) + self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int], + Union[Annotated[int, 1], int]) + + # Unhashable metadata: + self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int), + (str, Annotated[int, {}], Annotated[int, set()], int)) + self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]), + (str, Annotated[int, {}], Annotated[int, set()], int)) + self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int), + (str, Annotated[int, {}], Annotated[str, {}], int)) + self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]), + (str, Annotated[int, {}], Annotated[str, {}], int)) + + self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int), + (Annotated[int, 1], str, Annotated[str, {}], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]), + (Annotated[int, 1], str, Annotated[str, {}], int)) + + import dataclasses + @dataclasses.dataclass + class ValueRange: + lo: int + hi: int + v = ValueRange(1, 2) + self.assertEqual(get_args(Annotated[int, v] | None), + (Annotated[int, v], types.NoneType)) + self.assertEqual(get_args(Union[Annotated[int, v], None]), + (Annotated[int, v], types.NoneType)) + self.assertEqual(get_args(Optional[Annotated[int, v]]), + (Annotated[int, v], types.NoneType)) + + # Unhashable metadata duplicated: + self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int, + Annotated[int, {}] | int) + self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int, + int | Annotated[int, {}]) + self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int], + Union[Annotated[int, {}], int]) + self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int], + Union[int, Annotated[int, {}]]) + + def test_order_in_union(self): + expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int + for args in itertools.permutations(get_args(expr1)): + with self.subTest(args=args): + self.assertEqual(expr1, reduce(operator.or_, args)) + + expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int] + for args in itertools.permutations(get_args(expr2)): + with self.subTest(args=args): + self.assertEqual(expr2, Union[args]) + def test_specialize(self): L = Annotated[List[T], "my decoration"] LI = Annotated[List[int], "my decoration"] @@ -8203,6 +8292,16 @@ def test_hash_eq(self): {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]}, {Annotated[int, 4, 5], Annotated[T, 4, 5]} ) + # Unhashable `metadata` raises `TypeError`: + a1 = Annotated[int, []] + with self.assertRaises(TypeError): + hash(a1) + + class A: + __hash__ = None + a2 = Annotated[int, A()] + with self.assertRaises(TypeError): + hash(a2) def test_instantiate(self): class C: diff --git a/Lib/typing.py b/Lib/typing.py index 1e4c725be473b7..7581c16119d851 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -314,19 +314,33 @@ def _unpack_args(args): newargs.append(arg) return newargs -def _deduplicate(params): +def _deduplicate(params, *, unhashable_fallback=False): # Weed out strict duplicates, preserving the first of each occurrence. - all_params = set(params) - if len(all_params) < len(params): - new_params = [] - for t in params: - if t in all_params: - new_params.append(t) - all_params.remove(t) - params = new_params - assert not all_params, all_params - return params - + try: + return dict.fromkeys(params) + except TypeError: + if not unhashable_fallback: + raise + # Happens for cases like `Annotated[dict, {'x': IntValidator()}]` + return _deduplicate_unhashable(params) + +def _deduplicate_unhashable(unhashable_params): + new_unhashable = [] + for t in unhashable_params: + if t not in new_unhashable: + new_unhashable.append(t) + return new_unhashable + +def _compare_args_orderless(first_args, second_args): + first_unhashable = _deduplicate_unhashable(first_args) + second_unhashable = _deduplicate_unhashable(second_args) + t = list(second_unhashable) + try: + for elem in first_unhashable: + t.remove(elem) + except ValueError: + return False + return not t def _remove_dups_flatten(parameters): """Internal helper for Union creation and substitution. @@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters): else: params.append(p) - return tuple(_deduplicate(params)) + return tuple(_deduplicate(params, unhashable_fallback=True)) def _flatten_literal_params(parameters): @@ -1548,7 +1562,10 @@ def copy_with(self, params): def __eq__(self, other): if not isinstance(other, (_UnionGenericAlias, types.UnionType)): return NotImplemented - return set(self.__args__) == set(other.__args__) + try: # fast path + return set(self.__args__) == set(other.__args__) + except TypeError: # not hashable, slow path + return _compare_args_orderless(self.__args__, other.__args__) def __hash__(self): return hash(frozenset(self.__args__)) diff --git a/Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst b/Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst new file mode 100644 index 00000000000000..01f6689bb471cd --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst @@ -0,0 +1,2 @@ +Allow creating :ref:`union of types` for +:class:`typing.Annotated` with unhashable metadata.