From c7308a66ef1f4e60f7dbfa75177634666817d861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sat, 18 Mar 2023 19:31:15 +0100 Subject: [PATCH] Flesh out resolve_types (#1099) * Flesh out resolve_types * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add changelog entry * Fix flake? * Update 1099.change.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hynek Schlawack --- changelog.d/1099.change.md | 1 + src/attr/__init__.pyi | 6 ++++-- src/attr/_compat.py | 1 + src/attr/_funcs.py | 18 +++++++++++++++--- tests/test_annotations.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 changelog.d/1099.change.md diff --git a/changelog.d/1099.change.md b/changelog.d/1099.change.md new file mode 100644 index 000000000..a2114ae87 --- /dev/null +++ b/changelog.d/1099.change.md @@ -0,0 +1 @@ +`attrs.resolve_types()` can now pass `include_extras` to `typing.get_type_hints()` on Python 3.9+, and does so by default. diff --git a/src/attr/__init__.pyi b/src/attr/__init__.pyi index 1e70b593e..003dac4b4 100644 --- a/src/attr/__init__.pyi +++ b/src/attr/__init__.pyi @@ -69,6 +69,7 @@ _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] class AttrsInstance(AttrsInstance_, Protocol): pass +_A = TypeVar("_A", bound=AttrsInstance) # _make -- class _Nothing(enum.Enum): @@ -488,11 +489,12 @@ def fields(cls: Type[AttrsInstance]) -> Any: ... def fields_dict(cls: Type[AttrsInstance]) -> Dict[str, Attribute[Any]]: ... def validate(inst: AttrsInstance) -> None: ... def resolve_types( - cls: _C, + cls: _A, globalns: Optional[Dict[str, Any]] = ..., localns: Optional[Dict[str, Any]] = ..., attribs: Optional[List[Attribute[Any]]] = ..., -) -> _C: ... + include_extras: bool = ..., +) -> _A: ... # TODO: add support for returning a proper attrs class from the mypy plugin # we use Any instead of _CountingAttr so that e.g. `make_class('Foo', diff --git a/src/attr/_compat.py b/src/attr/_compat.py index bd44d4f4c..c3bf5e33b 100644 --- a/src/attr/_compat.py +++ b/src/attr/_compat.py @@ -13,6 +13,7 @@ PYPY = platform.python_implementation() == "PyPy" +PY_3_9_PLUS = sys.version_info[:2] >= (3, 9) PY310 = sys.version_info[:2] >= (3, 10) PY_3_12_PLUS = sys.version_info[:2] >= (3, 12) diff --git a/src/attr/_funcs.py b/src/attr/_funcs.py index 6fa2456dc..fe110739b 100644 --- a/src/attr/_funcs.py +++ b/src/attr/_funcs.py @@ -3,7 +3,7 @@ import copy -from ._compat import get_generic_base +from ._compat import PY_3_9_PLUS, get_generic_base from ._make import NOTHING, _obj_setattr, fields from .exceptions import AttrsAttributeNotFoundError @@ -379,7 +379,9 @@ def evolve(inst, **changes): return cls(**changes) -def resolve_types(cls, globalns=None, localns=None, attribs=None): +def resolve_types( + cls, globalns=None, localns=None, attribs=None, include_extras=True +): """ Resolve any strings and forward annotations in type annotations. @@ -399,6 +401,10 @@ def resolve_types(cls, globalns=None, localns=None, attribs=None): :param Optional[list] attribs: List of attribs for the given class. This is necessary when calling from inside a ``field_transformer`` since *cls* is not an *attrs* class yet. + :param bool include_extras: Resolve more accurately, if possible. + Pass ``include_extras`` to ``typing.get_hints``, if supported by the + typing module. On supported Python versions (3.9+), this resolves the + types more accurately. :raise TypeError: If *cls* is not a class. :raise attrs.exceptions.NotAnAttrsClassError: If *cls* is not an *attrs* @@ -411,6 +417,7 @@ class and you didn't pass any attribs. .. versionadded:: 20.1.0 .. versionadded:: 21.1.0 *attribs* + .. versionadded:: 23.1.0 *include_extras* """ # Since calling get_type_hints is expensive we cache whether we've @@ -418,7 +425,12 @@ class and you didn't pass any attribs. if getattr(cls, "__attrs_types_resolved__", None) != cls: import typing - hints = typing.get_type_hints(cls, globalns=globalns, localns=localns) + kwargs = {"globalns": globalns, "localns": localns} + + if PY_3_9_PLUS: + kwargs["include_extras"] = include_extras + + hints = typing.get_type_hints(cls, **kwargs) for field in fields(cls) if attribs is None else attribs: if field.name in hints: # Since fields have been frozen we must work around it. diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 0ff68c779..f5ad41d0b 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -516,6 +516,34 @@ class C: assert str is attr.fields(C).y.type assert None is attr.fields(C).z.type + @pytest.mark.skipif( + sys.version_info[:2] < (3, 9), + reason="Incompatible behavior on older Pythons", + ) + def test_extra_resolve(self): + """ + `get_type_hints` returns extra type hints. + """ + from typing import Annotated + + globals = {"Annotated": Annotated} + + @attr.define + class C: + x: 'Annotated[float, "test"]' + + attr.resolve_types(C, globals) + + assert attr.fields(C).x.type == Annotated[float, "test"] + + @attr.define + class D: + x: 'Annotated[float, "test"]' + + attr.resolve_types(D, globals, include_extras=False) + + assert attr.fields(D).x.type == float + def test_resolve_types_auto_attrib(self, slots): """ Types can be resolved even when strings are involved.