Skip to content

Commit

Permalink
Annotate __init__ with type hints
Browse files Browse the repository at this point in the history
This just adds the annotations found at run-time to the
`__annotations__` attribute of the created `__init__` function

Fixes #249
  • Loading branch information
euresti committed Mar 26, 2018
1 parent c2fef8c commit ac8079a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/attr/_make.py
Expand Up @@ -1034,7 +1034,7 @@ def _make_init(attrs, post_init, frozen, slots, super_attr_map):
sha1.hexdigest()
)

script, globs = _attrs_to_init_script(
script, globs, annotations = _attrs_to_init_script(
attrs,
frozen,
slots,
Expand Down Expand Up @@ -1063,7 +1063,9 @@ def _make_init(attrs, post_init, frozen, slots, super_attr_map):
unique_filename,
)

return locs["__init__"]
__init__ = locs["__init__"]
__init__.__annotations__ = annotations
return __init__


def _add_init(cls, frozen):
Expand Down Expand Up @@ -1259,6 +1261,7 @@ def fmt_setter_with_converter(attr_name, value_var):
# This is a dictionary of names to validator and converter callables.
# Injecting this into __init__ globals lets us avoid lookups.
names_for_globals = {}
annotations = {}

for a in attrs:
if a.validator:
Expand Down Expand Up @@ -1349,6 +1352,9 @@ def fmt_setter_with_converter(attr_name, value_var):
else:
lines.append(fmt_setter(attr_name, arg_name))

if a.init is True and a.converter is None and a.type is not None:
annotations[arg_name] = a.type

if attrs_to_validate: # we can skip this if there are no validators.
names_for_globals["_config"] = _config
lines.append("if _config._run_validators is True:")
Expand All @@ -1368,7 +1374,7 @@ def __init__(self, {args}):
""".format(
args=", ".join(args),
lines="\n ".join(lines) if lines else "pass",
), names_for_globals
), names_for_globals, annotations


class Attribute(object):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_annotations.py
Expand Up @@ -13,6 +13,8 @@

from attr.exceptions import UnannotatedAttributeError

NoneType = type(None)


class TestAnnotations:
"""
Expand All @@ -32,6 +34,7 @@ class C:
assert int is attr.fields(C).x.type
assert str is attr.fields(C).y.type
assert None is attr.fields(C).z.type
assert C.__init__.__annotations__ == {'x': int, 'y': str}

def test_catches_basic_type_conflict(self):
"""
Expand All @@ -57,6 +60,8 @@ class C:

assert typing.List[int] is attr.fields(C).x.type
assert typing.Optional[str] is attr.fields(C).y.type
assert C.__init__.__annotations__ == {'x': typing.List[int],
'y': typing.Union[str, NoneType]}

def test_only_attrs_annotations_collected(self):
"""
Expand All @@ -68,6 +73,7 @@ class C:
y: int

assert 1 == len(attr.fields(C))
assert C.__init__.__annotations__ == {'x': typing.List[int]}

@pytest.mark.parametrize("slots", [True, False])
def test_auto_attribs(self, slots):
Expand Down Expand Up @@ -115,6 +121,14 @@ class C:
i.y = 23
assert 23 == i.y

assert C.__init__.__annotations__ == {
'a': int,
'x': typing.List[int],
'y': int,
'z': int,
'foo': typing.Any
}

@pytest.mark.parametrize("slots", [True, False])
def test_auto_attribs_unannotated(self, slots):
"""
Expand Down Expand Up @@ -154,3 +168,14 @@ class C(A):

assert "B(a=1, b=2)" == repr(B())
assert "C(a=1)" == repr(C())

assert A.__init__.__annotations__ == {
'a': int,
}
assert B.__init__.__annotations__ == {
'a': int,
'b': int,
}
assert C.__init__.__annotations__ == {
'a': int,
}

0 comments on commit ac8079a

Please sign in to comment.