From ac8079a81af8b1f704bcd90f24442ae736b87589 Mon Sep 17 00:00:00 2001 From: David Euresti Date: Mon, 26 Mar 2018 07:10:32 -0700 Subject: [PATCH] Annotate __init__ with type hints This just adds the annotations found at run-time to the `__annotations__` attribute of the created `__init__` function Fixes #249 --- src/attr/_make.py | 12 +++++++++--- tests/test_annotations.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/attr/_make.py b/src/attr/_make.py index ebe12078b..cfbaf807b 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -1034,7 +1034,7 @@ def _make_init(attrs, post_init, frozen, slots, super_attr_map): sha1.hexdigest() ) - script, globs = _attrs_to_init_script( + script, globs, annotations = _attrs_to_init_script( attrs, frozen, slots, @@ -1063,7 +1063,9 @@ def _make_init(attrs, post_init, frozen, slots, super_attr_map): unique_filename, ) - return locs["__init__"] + __init__ = locs["__init__"] + __init__.__annotations__ = annotations + return __init__ def _add_init(cls, frozen): @@ -1259,6 +1261,7 @@ def fmt_setter_with_converter(attr_name, value_var): # This is a dictionary of names to validator and converter callables. # Injecting this into __init__ globals lets us avoid lookups. names_for_globals = {} + annotations = {} for a in attrs: if a.validator: @@ -1349,6 +1352,9 @@ def fmt_setter_with_converter(attr_name, value_var): else: lines.append(fmt_setter(attr_name, arg_name)) + if a.init is True and a.converter is None and a.type is not None: + annotations[arg_name] = a.type + if attrs_to_validate: # we can skip this if there are no validators. names_for_globals["_config"] = _config lines.append("if _config._run_validators is True:") @@ -1368,7 +1374,7 @@ def __init__(self, {args}): """.format( args=", ".join(args), lines="\n ".join(lines) if lines else "pass", - ), names_for_globals + ), names_for_globals, annotations class Attribute(object): diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 602f21bd5..a6fc49700 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -13,6 +13,8 @@ from attr.exceptions import UnannotatedAttributeError +NoneType = type(None) + class TestAnnotations: """ @@ -32,6 +34,7 @@ class C: assert int is attr.fields(C).x.type assert str is attr.fields(C).y.type assert None is attr.fields(C).z.type + assert C.__init__.__annotations__ == {'x': int, 'y': str} def test_catches_basic_type_conflict(self): """ @@ -57,6 +60,8 @@ class C: assert typing.List[int] is attr.fields(C).x.type assert typing.Optional[str] is attr.fields(C).y.type + assert C.__init__.__annotations__ == {'x': typing.List[int], + 'y': typing.Union[str, NoneType]} def test_only_attrs_annotations_collected(self): """ @@ -68,6 +73,7 @@ class C: y: int assert 1 == len(attr.fields(C)) + assert C.__init__.__annotations__ == {'x': typing.List[int]} @pytest.mark.parametrize("slots", [True, False]) def test_auto_attribs(self, slots): @@ -115,6 +121,14 @@ class C: i.y = 23 assert 23 == i.y + assert C.__init__.__annotations__ == { + 'a': int, + 'x': typing.List[int], + 'y': int, + 'z': int, + 'foo': typing.Any + } + @pytest.mark.parametrize("slots", [True, False]) def test_auto_attribs_unannotated(self, slots): """ @@ -154,3 +168,14 @@ class C(A): assert "B(a=1, b=2)" == repr(B()) assert "C(a=1)" == repr(C()) + + assert A.__init__.__annotations__ == { + 'a': int, + } + assert B.__init__.__annotations__ == { + 'a': int, + 'b': int, + } + assert C.__init__.__annotations__ == { + 'a': int, + }