Skip to content

Commit

Permalink
Fix ForwardRef issue in Python 3.9.8 (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Dec 18, 2021
1 parent 299213b commit 56045f8
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 46 deletions.
58 changes: 23 additions & 35 deletions apischema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
Any,
Callable,
Collection,
Dict,
Generic,
Set,
Tuple,
Type,
TypeVar,
_eval_type,
)


Expand Down Expand Up @@ -152,44 +152,32 @@ def generic_mro(tp):
return tuple(result.get(sub_cls, sub_cls) for sub_cls in cls.__mro__)


def _class_annotations(cls, globalns, localns):
hints = {}
if globalns is None:
base_globals = sys.modules[cls.__module__].__dict__
else:
base_globals = globalns
for name, value in cls.__dict__.get("__annotations__", {}).items():
if value is None:
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False)
hints[name] = _eval_type(value, base_globals, localns)
return hints

def resolve_type_hints(obj: Any) -> Dict[str, Any]:
"""Wrap get_type_hints to resolve type vars in case of generic inheritance.
def get_type_hints2(obj, globalns=None, localns=None): # type: ignore
if isinstance(obj, type) or isinstance(get_origin(obj), type):
hints = {}
`obj` can also be a parametrized generic class."""
origin_or_obj = get_origin(obj) or obj
hints = get_type_hints(origin_or_obj, include_extras=True)
if isinstance(origin_or_obj, type):
for base in reversed(generic_mro(obj)):
origin = get_origin(base)
if hasattr(origin, "__orig_bases__"):
parameters = _collect_type_vars(origin.__orig_bases__)
substitution = dict(zip(parameters, get_args(base)))
annotations = _class_annotations(get_origin(base), globalns, localns)
for name, tp in annotations.items():
if isinstance(tp, TypeVar):
hints[name] = substitution.get(tp, tp)
elif getattr(tp, "__parameters__", ()):
hints[name] = tp[
tuple(substitution.get(p, p) for p in tp.__parameters__)
base_origin = get_origin(base)
if base_origin is not None and getattr(base_origin, "__parameters__", ()): # type: ignore
substitution = dict(zip(base_origin.__parameters__, get_args(base)))
base_annotations = getattr(base_origin, "__dict__", {}).get(
"__annotations__", {}
)
for name, hint in get_type_hints(base_origin).items():
if name not in base_annotations:
continue
if isinstance(hint, TypeVar):
hints[name] = substitution.get(hint, hint)
elif getattr(hint, "__parameters__", ()):
hints[name] = hint[
tuple(substitution.get(p, p) for p in hint.__parameters__)
]
else:
hints[name] = tp
else:
hints.update(_class_annotations(base, globalns, localns))
return hints
else:
return get_type_hints(obj, globalns, localns, include_extras=True)
hints[name] = hint
return hints


_T = TypeVar("_T")
Expand Down
12 changes: 4 additions & 8 deletions apischema/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,27 @@
get_args,
get_origin,
get_type_hints,
get_type_hints2,
is_annotated,
is_literal,
is_named_tuple,
is_type_var,
is_typed_dict,
required_keys,
resolve_type_hints,
)
from apischema.utils import PREFIX, get_origin_or_type, has_type_vars, is_dataclass

TUPLE_TYPE = get_origin(Tuple[Any])


def type_hints_cache(obj) -> Mapping[str, AnyType]:
return get_type_hints2(obj)


def dataclass_types_and_fields(
tp: AnyType,
) -> Tuple[Mapping[str, AnyType], Sequence[Field], Sequence[Field]]:
from apischema.metadata.keys import INIT_VAR_METADATA

cls = get_origin_or_type(tp)
assert is_dataclass(cls)
types = get_type_hints2(tp)
types = resolve_type_hints(tp)
fields, init_fields = [], []
for field in getattr(cls, _FIELDS).values():
assert isinstance(field, Field)
Expand Down Expand Up @@ -188,7 +184,7 @@ def visit(self, tp: AnyType) -> Result:
# NamedTuple
if is_named_tuple(origin):
if hasattr(origin, "__annotations__"):
types = type_hints_cache(origin)
types = resolve_type_hints(origin)
elif hasattr(origin, "__field_types"): # pragma: no cover
types = origin.__field_types # type: ignore
else: # pragma: no cover
Expand All @@ -200,7 +196,7 @@ def visit(self, tp: AnyType) -> Result:
return self.literal(origin.__values__) # type: ignore
if is_typed_dict(origin):
return self.typed_dict(
origin, type_hints_cache(origin), required_keys(origin)
origin, resolve_type_hints(origin), required_keys(origin)
)
if is_type_var(origin):
if origin.__constraints__:
Expand Down
11 changes: 8 additions & 3 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from pytest import mark

from apischema.typing import _TypedDictMeta, generic_mro, get_type_hints2, required_keys
from apischema.typing import (
_TypedDictMeta,
generic_mro,
required_keys,
resolve_type_hints,
)

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -44,8 +49,8 @@ def test_generic_mro(tp, result, _):


@mark.parametrize("tp, _, result", test_cases)
def test_get_type_hints2(tp, _, result):
assert get_type_hints2(tp) == result
def test_resolve_type_hints(tp, _, result):
assert resolve_type_hints(tp) == result


def test_required_keys():
Expand Down

0 comments on commit 56045f8

Please sign in to comment.