Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ForwardRef issue in Python 3.9.8 #280

Merged
merged 1 commit into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 23 additions & 35 deletions apischema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
Any,
Callable,
Collection,
Dict,
Generic,
Set,
Tuple,
Type,
TypeVar,
Union,
_eval_type,
)


Expand Down Expand Up @@ -154,44 +154,32 @@ def generic_mro(tp):
return tuple(result.get(sub_cls, sub_cls) for sub_cls in cls.__mro__)


def _class_annotations(cls, globalns, localns):
hints = {}
if globalns is None:
base_globals = sys.modules[cls.__module__].__dict__
else:
base_globals = globalns
for name, value in cls.__dict__.get("__annotations__", {}).items():
if value is None:
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False)
hints[name] = _eval_type(value, base_globals, localns)
return hints

def resolve_type_hints(obj: Any) -> Dict[str, Any]:
"""Wrap get_type_hints to resolve type vars in case of generic inheritance.

def get_type_hints2(obj, globalns=None, localns=None): # type: ignore
if isinstance(obj, type) or isinstance(get_origin(obj), type):
hints = {}
`obj` can also be a parametrized generic class."""
origin_or_obj = get_origin(obj) or obj
hints = get_type_hints(origin_or_obj, include_extras=True)
if isinstance(origin_or_obj, type):
for base in reversed(generic_mro(obj)):
origin = get_origin(base)
if hasattr(origin, "__orig_bases__"):
parameters = _collect_type_vars(origin.__orig_bases__)
substitution = dict(zip(parameters, get_args(base)))
annotations = _class_annotations(get_origin(base), globalns, localns)
for name, tp in annotations.items():
if isinstance(tp, TypeVar):
hints[name] = substitution.get(tp, tp)
elif getattr(tp, "__parameters__", ()):
hints[name] = tp[
tuple(substitution.get(p, p) for p in tp.__parameters__)
base_origin = get_origin(base)
if base_origin is not None and getattr(base_origin, "__parameters__", ()):
substitution = dict(zip(base_origin.__parameters__, get_args(base)))
base_annotations = getattr(base_origin, "__dict__", {}).get(
"__annotations__", {}
)
for name, hint in get_type_hints(base_origin).items():
Copy link
Contributor

@thomascobb thomascobb Dec 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can confirm that adding include_extras=True to this line fixes #281

if name not in base_annotations:
continue
if isinstance(hint, TypeVar):
hints[name] = substitution.get(hint, hint)
elif getattr(hint, "__parameters__", ()):
hints[name] = hint[
tuple(substitution.get(p, p) for p in hint.__parameters__)
]
else:
hints[name] = tp
else:
hints.update(_class_annotations(base, globalns, localns))
return hints
else:
return get_type_hints(obj, globalns, localns, include_extras=True)
hints[name] = hint
return hints


_T = TypeVar("_T")
Expand Down
12 changes: 4 additions & 8 deletions apischema/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
get_args,
get_origin,
get_type_hints,
get_type_hints2,
is_annotated,
is_literal,
is_named_tuple,
is_type_var,
is_typed_dict,
is_union,
required_keys,
resolve_type_hints,
)
from apischema.utils import PREFIX, get_origin_or_type, has_type_vars, is_dataclass

Expand All @@ -50,18 +50,14 @@
TUPLE_TYPE = get_origin(Tuple[Any])


def type_hints_cache(obj) -> Mapping[str, AnyType]:
return get_type_hints2(obj)


def dataclass_types_and_fields(
tp: AnyType,
) -> Tuple[Mapping[str, AnyType], Sequence[Field], Sequence[Field]]:
from apischema.metadata.keys import INIT_VAR_METADATA

cls = get_origin_or_type(tp)
assert is_dataclass(cls)
types = get_type_hints2(tp)
types = resolve_type_hints(tp)
fields, init_fields = [], []
for field in getattr(cls, _FIELDS).values():
assert isinstance(field, Field)
Expand Down Expand Up @@ -200,7 +196,7 @@ def visit(self, tp: AnyType) -> Result:
# NamedTuple
if is_named_tuple(origin):
if hasattr(origin, "__annotations__"):
types = type_hints_cache(origin)
types = resolve_type_hints(origin)
elif hasattr(origin, "__field_types"): # pragma: no cover
types = origin.__field_types # type: ignore
else: # pragma: no cover
Expand All @@ -212,7 +208,7 @@ def visit(self, tp: AnyType) -> Result:
return self.literal(origin.__values__) # type: ignore
if is_typed_dict(origin):
return self.typed_dict(
origin, type_hints_cache(origin), required_keys(origin)
origin, resolve_type_hints(origin), required_keys(origin)
)
if is_type_var(origin):
if origin.__constraints__:
Expand Down
11 changes: 8 additions & 3 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from pytest import mark

from apischema.typing import _TypedDictMeta, generic_mro, get_type_hints2, required_keys
from apischema.typing import (
_TypedDictMeta,
generic_mro,
required_keys,
resolve_type_hints,
)

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -44,8 +49,8 @@ def test_generic_mro(tp, result, _):


@mark.parametrize("tp, _, result", test_cases)
def test_get_type_hints2(tp, _, result):
assert get_type_hints2(tp) == result
def test_resolve_type_hints(tp, _, result):
assert resolve_type_hints(tp) == result


def test_required_keys():
Expand Down