From 56045f81172ae52b8d8165a4e3d46c6bd6d7fff7 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Sat, 18 Dec 2021 08:21:58 +0100 Subject: [PATCH] Fix ForwardRef issue in Python 3.9.8 (#280) --- apischema/typing.py | 58 ++++++++++++++++++-------------------------- apischema/visitor.py | 12 +++------ tests/test_typing.py | 11 ++++++--- 3 files changed, 35 insertions(+), 46 deletions(-) diff --git a/apischema/typing.py b/apischema/typing.py index fd21ddd1..e0cc1f90 100644 --- a/apischema/typing.py +++ b/apischema/typing.py @@ -7,12 +7,12 @@ Any, Callable, Collection, + Dict, Generic, Set, Tuple, Type, TypeVar, - _eval_type, ) @@ -152,44 +152,32 @@ def generic_mro(tp): return tuple(result.get(sub_cls, sub_cls) for sub_cls in cls.__mro__) -def _class_annotations(cls, globalns, localns): - hints = {} - if globalns is None: - base_globals = sys.modules[cls.__module__].__dict__ - else: - base_globals = globalns - for name, value in cls.__dict__.get("__annotations__", {}).items(): - if value is None: - value = type(None) - if isinstance(value, str): - value = ForwardRef(value, is_argument=False) - hints[name] = _eval_type(value, base_globals, localns) - return hints - +def resolve_type_hints(obj: Any) -> Dict[str, Any]: + """Wrap get_type_hints to resolve type vars in case of generic inheritance. -def get_type_hints2(obj, globalns=None, localns=None): # type: ignore - if isinstance(obj, type) or isinstance(get_origin(obj), type): - hints = {} + `obj` can also be a parametrized generic class.""" + origin_or_obj = get_origin(obj) or obj + hints = get_type_hints(origin_or_obj, include_extras=True) + if isinstance(origin_or_obj, type): for base in reversed(generic_mro(obj)): - origin = get_origin(base) - if hasattr(origin, "__orig_bases__"): - parameters = _collect_type_vars(origin.__orig_bases__) - substitution = dict(zip(parameters, get_args(base))) - annotations = _class_annotations(get_origin(base), globalns, localns) - for name, tp in annotations.items(): - if isinstance(tp, TypeVar): - hints[name] = substitution.get(tp, tp) - elif getattr(tp, "__parameters__", ()): - hints[name] = tp[ - tuple(substitution.get(p, p) for p in tp.__parameters__) + base_origin = get_origin(base) + if base_origin is not None and getattr(base_origin, "__parameters__", ()): # type: ignore + substitution = dict(zip(base_origin.__parameters__, get_args(base))) + base_annotations = getattr(base_origin, "__dict__", {}).get( + "__annotations__", {} + ) + for name, hint in get_type_hints(base_origin).items(): + if name not in base_annotations: + continue + if isinstance(hint, TypeVar): + hints[name] = substitution.get(hint, hint) + elif getattr(hint, "__parameters__", ()): + hints[name] = hint[ + tuple(substitution.get(p, p) for p in hint.__parameters__) ] else: - hints[name] = tp - else: - hints.update(_class_annotations(base, globalns, localns)) - return hints - else: - return get_type_hints(obj, globalns, localns, include_extras=True) + hints[name] = hint + return hints _T = TypeVar("_T") diff --git a/apischema/visitor.py b/apischema/visitor.py index f8c26a5b..6060dc60 100644 --- a/apischema/visitor.py +++ b/apischema/visitor.py @@ -31,23 +31,19 @@ get_args, get_origin, get_type_hints, - get_type_hints2, is_annotated, is_literal, is_named_tuple, is_type_var, is_typed_dict, required_keys, + resolve_type_hints, ) from apischema.utils import PREFIX, get_origin_or_type, has_type_vars, is_dataclass TUPLE_TYPE = get_origin(Tuple[Any]) -def type_hints_cache(obj) -> Mapping[str, AnyType]: - return get_type_hints2(obj) - - def dataclass_types_and_fields( tp: AnyType, ) -> Tuple[Mapping[str, AnyType], Sequence[Field], Sequence[Field]]: @@ -55,7 +51,7 @@ def dataclass_types_and_fields( cls = get_origin_or_type(tp) assert is_dataclass(cls) - types = get_type_hints2(tp) + types = resolve_type_hints(tp) fields, init_fields = [], [] for field in getattr(cls, _FIELDS).values(): assert isinstance(field, Field) @@ -188,7 +184,7 @@ def visit(self, tp: AnyType) -> Result: # NamedTuple if is_named_tuple(origin): if hasattr(origin, "__annotations__"): - types = type_hints_cache(origin) + types = resolve_type_hints(origin) elif hasattr(origin, "__field_types"): # pragma: no cover types = origin.__field_types # type: ignore else: # pragma: no cover @@ -200,7 +196,7 @@ def visit(self, tp: AnyType) -> Result: return self.literal(origin.__values__) # type: ignore if is_typed_dict(origin): return self.typed_dict( - origin, type_hints_cache(origin), required_keys(origin) + origin, resolve_type_hints(origin), required_keys(origin) ) if is_type_var(origin): if origin.__constraints__: diff --git a/tests/test_typing.py b/tests/test_typing.py index f85d1960..02489f11 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -3,7 +3,12 @@ from pytest import mark -from apischema.typing import _TypedDictMeta, generic_mro, get_type_hints2, required_keys +from apischema.typing import ( + _TypedDictMeta, + generic_mro, + required_keys, + resolve_type_hints, +) T = TypeVar("T") U = TypeVar("U") @@ -44,8 +49,8 @@ def test_generic_mro(tp, result, _): @mark.parametrize("tp, _, result", test_cases) -def test_get_type_hints2(tp, _, result): - assert get_type_hints2(tp) == result +def test_resolve_type_hints(tp, _, result): + assert resolve_type_hints(tp) == result def test_required_keys():