Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
attr_define_makers: Final = {"attr.define", "attr.mutable", "attrs.define", "attrs.mutable"}
attr_attrib_makers: Final = {"attr.ib", "attr.attrib", "attr.attr", "attr.field", "attrs.field"}
attr_optional_converters: Final = {"attr.converters.optional", "attrs.converters.optional"}
attr_converter_classes: Final = {"attr.Converter", "attrs.Converter"}

SELF_TVAR_NAME: Final = "_AT"
MAGIC_ATTR_NAME: Final = "__attrs_attrs__"
Expand Down Expand Up @@ -720,6 +721,19 @@ def _parse_converter(
else:
is_attr_converters_optional = False

if (
isinstance(converter_expr, CallExpr)
and isinstance(converter_expr.callee, RefExpr)
and converter_expr.callee.fullname in attr_converter_classes
and converter_expr.args
and converter_expr.args[0]
):
# Special handling for attrs.Converter(callable, takes_self=..., takes_field=...).
# The first positional argument is the actual conversion callable; the keyword
# arguments only affect what mypy gets passed at runtime, but the init type
# is still the first parameter of the wrapped callable.
converter_expr = converter_expr.args[0]

converter_type: Type | None = None
if isinstance(converter_expr, RefExpr) and converter_expr.node:
if isinstance(converter_expr.node, FuncDef):
Expand All @@ -734,6 +748,11 @@ def _parse_converter(
converter_type = converter_expr.node.type
elif isinstance(converter_expr.node, TypeInfo):
converter_type = type_object_type(converter_expr.node)
elif isinstance(converter_expr.node, Var) and converter_expr.node.type:
# The converter is a variable annotated with a callable type.
var_type = get_proper_type(converter_expr.node.type)
if isinstance(var_type, FunctionLike):
converter_type = var_type
elif (
isinstance(converter_expr, IndexExpr)
and isinstance(converter_expr.analyzed, TypeApplication)
Expand All @@ -751,6 +770,10 @@ def _parse_converter(
)
else:
converter_type = None
elif isinstance(converter_expr, CallExpr):
# The converter is the result of a call, e.g. converter=make_converter(arg).
# Use the return type of the callee as the converter type.
converter_type = _callable_return_type(converter_expr)

if isinstance(converter_expr, LambdaExpr):
# TODO: should we send a fail if converter_expr.min_args > 1?
Expand Down Expand Up @@ -794,6 +817,43 @@ def _parse_converter(
return converter_info


def _callable_return_type(call: CallExpr) -> Type | None:
"""Return the return type of call if it is statically known to be callable.

This is used to support converters created by higher-order functions, e.g.
converter=make_converter(arg). We don't perform full type inference at the
call site; we just look at the statically declared return type of the callee.
Generic returns are returned as-is and may contain unresolved type variables.
"""
callee = call.callee
callee_type: Type | None = None
if isinstance(callee, RefExpr) and callee.node:
if isinstance(callee.node, (FuncDef, OverloadedFuncDef)):
callee_type = callee.node.type
elif isinstance(callee.node, Var):
callee_type = callee.node.type
elif isinstance(callee, CallExpr):
# Chained calls like factory()(arg).
callee_type = _callable_return_type(callee)
if callee_type is None:
return None
callee_type = get_proper_type(callee_type)
if isinstance(callee_type, CallableType):
ret = get_proper_type(callee_type.ret_type)
if isinstance(ret, FunctionLike):
return ret
elif isinstance(callee_type, Overloaded):
# Without type inference at the call site we can't pick the correct
# overload. As a heuristic, take the first overload whose return type is
# itself a callable (this matches helpers like attrs.converters.pipe,
# whose first overload is the most specific callable form).
for item in callee_type.items:
ret = get_proper_type(item.ret_type)
if isinstance(ret, FunctionLike):
return ret
return None


def is_valid_overloaded_converter(defn: OverloadedFuncDef) -> bool:
return all(
(not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike))
Expand Down
138 changes: 134 additions & 4 deletions test-data/unit/check-plugin-attrs.test
Original file line number Diff line number Diff line change
Expand Up @@ -892,9 +892,9 @@ class A:
reveal_type(A)
[out]
main:16: error: Cannot determine __init__ type from converter
main:16: error: Argument "converter" has incompatible type "Callable[[], str]"; expected "Callable[[Any], str]"
main:16: error: Argument "converter" has incompatible type "Callable[[], str]"; expected "Callable[[Any], Never] | Converter"
main:17: error: Cannot determine __init__ type from converter
main:17: error: Argument "converter" has incompatible type overloaded function; expected "Callable[[Any], int]"
main:17: error: Argument "converter" has incompatible type overloaded function; expected "Callable[[Any], Never] | Converter"
main:18: note: Revealed type is "def (bad: Any, bad_overloaded: Any) -> __main__.A"
[builtins fixtures/list.pyi]

Expand All @@ -920,9 +920,9 @@ class A:
reveal_type(A)
[out]
main:17: error: Cannot determine __init__ type from converter
main:17: error: Argument "converter" has incompatible type "Callable[[], str]"; expected "Callable[[Any], str]"
main:17: error: Argument "converter" has incompatible type "Callable[[], str]"; expected "Callable[[Any], Never] | Converter"
main:18: error: Cannot determine __init__ type from converter
main:18: error: Argument "converter" has incompatible type overloaded function; expected "Callable[[Any], int]"
main:18: error: Argument "converter" has incompatible type overloaded function; expected "Callable[[Any], Never] | Converter"
main:19: note: Revealed type is "def (bad: Any, bad_overloaded: Any) -> __main__.A"
[builtins fixtures/list.pyi]

Expand All @@ -942,6 +942,136 @@ class C:
reveal_type(C) # N: Revealed type is "def (x: Any, y: Any, z: Any) -> __main__.C"
[builtins fixtures/list.pyi]

[case testAttrsUsingHigherOrderConverter]
# Regression test for https://github.com/python/mypy/issues/15736
from typing import Any, Callable
import attr

def make_converter(_length: int) -> Callable[[str], str]:
def converter(val: str) -> str:
return val
return converter

def make_untyped_converter(_length: int) -> Callable[[Any], Any]:
def f(val: Any) -> Any:
return val
return f

@attr.s
class C:
a: str = attr.ib(converter=make_converter(40))
b: str = attr.ib(converter=make_untyped_converter(40))

reveal_type(C) # N: Revealed type is "def (a: builtins.str, b: Any) -> __main__.C"
reveal_type(C("hi", 5).a) # N: Revealed type is "builtins.str"
[builtins fixtures/list.pyi]

[case testAttrsUsingCallableVariableConverter]
from typing import Callable
import attr

def to_str(x: int) -> str:
return ""
my_converter: Callable[[int], str] = to_str

@attr.s
class C:
x: str = attr.ib(converter=my_converter)

reveal_type(C) # N: Revealed type is "def (x: builtins.int) -> __main__.C"
reveal_type(C(15).x) # N: Revealed type is "builtins.str"
[builtins fixtures/list.pyi]

[case testAttrsUsingHigherOrderConverterChainedCall]
from typing import Callable
import attr

def outer() -> Callable[[int], Callable[[str], str]]:
def middle(_n: int) -> Callable[[str], str]:
def inner(v: str) -> str:
return v
return inner
return middle

@attr.s
class C:
x: str = attr.ib(converter=outer()(40))

reveal_type(C) # N: Revealed type is "def (x: builtins.str) -> __main__.C"
[builtins fixtures/list.pyi]

[case testAttrsUsingConverterClass]
import attr
import attrs

def to_int(val: str) -> int:
return int(val)

def with_self(val: str, instance: object) -> int:
return int(val)

def with_field(val: str, attr: object) -> int:
return int(val)

def with_both(val: str, instance: object, attr: object) -> int:
return int(val)

@attr.s
class A:
x: int = attr.ib(converter=attr.Converter(to_int))
y: int = attr.ib(converter=attr.Converter(with_self, takes_self=True))
z: int = attr.ib(converter=attr.Converter(with_field, takes_field=True))
w: int = attr.ib(converter=attr.Converter(with_both, takes_self=True, takes_field=True))

reveal_type(A) # N: Revealed type is "def (x: builtins.str, y: builtins.str, z: builtins.str, w: builtins.str) -> __main__.A"

@attrs.define
class B:
x: int = attrs.field(converter=attrs.Converter(to_int))

reveal_type(B) # N: Revealed type is "def (x: builtins.str) -> __main__.B"
[builtins fixtures/plugin_attrs.pyi]

[case testAttrsUsingPipeConverter]
import attr
from attr.converters import pipe

def to_str(val: int) -> str:
return ""

def repeat(val: str) -> str:
return val

@attr.s
class C:
x: str = attr.ib(converter=pipe(to_str, repeat))

reveal_type(C) # N: Revealed type is "def (x: Any) -> __main__.C"
[builtins fixtures/plugin_attrs.pyi]

[case testAttrsUsingDefaultIfNoneConverter]
from typing import Optional
import attr
from attr.converters import default_if_none

@attr.s
class C:
x: int = attr.ib(default=None, converter=default_if_none(0))

reveal_type(C) # N: Revealed type is "def (x: Any =) -> __main__.C"
[builtins fixtures/plugin_attrs.pyi]

[case testAttrsUsingToBoolConverter]
import attr
from attr.converters import to_bool

@attr.s
class C:
x: bool = attr.ib(converter=to_bool)

reveal_type(C) # N: Revealed type is "def (x: Any) -> __main__.C"
[builtins fixtures/plugin_attrs.pyi]

[case testAttrsUsingConverterAndSubclass]
import attr

Expand Down
24 changes: 18 additions & 6 deletions test-data/unit/lib-stub/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,20 @@ from typing import TypeVar, overload, Callable, Any, Type, Optional, Union, Sequ
_T = TypeVar('_T')
_C = TypeVar('_C', bound=type)

class Converter:
# Simplified non-generic stub for testing the attrs plugin. The real attrs
# ``Converter`` is ``Generic[In, Out]``; the plugin doesn't rely on that.
def __init__(
self,
converter: Callable[..., Any],
*,
takes_self: bool = ...,
takes_field: bool = ...,
) -> None: ...

_ValidatorType = Callable[[Any, Any, _T], Any]
_ConverterType = Callable[[Any], _T]
_FieldConverterType = Union[_ConverterType[_T], Converter]
_FilterType = Callable[[Any, Any], bool]
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]

Expand Down Expand Up @@ -36,7 +48,7 @@ def attrib(default: None = ...,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType[_T]] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -53,7 +65,7 @@ def attrib(default: _T,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType[_T]] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -70,7 +82,7 @@ def attrib(default: Optional[_T] = ...,
convert: Optional[_ConverterType[_T]] = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: object = ...,
converter: Optional[_ConverterType[_T]] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand Down Expand Up @@ -203,7 +215,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -221,7 +233,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -239,7 +251,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/lib-stub/attr/converters.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Optional, Callable, overload
from typing import Any, TypeVar, Optional, Callable, overload
from . import _ConverterType

_T = TypeVar("_T")
Expand All @@ -10,3 +10,5 @@ def optional(
def default_if_none(default: _T) -> _ConverterType[_T]: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ...
def pipe(*converters: Callable[[Any], Any]) -> Callable[[Any], Any]: ...
def to_bool(val: Any) -> bool: ...
9 changes: 5 additions & 4 deletions test-data/unit/lib-stub/attrs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from typing import TypeVar, overload, Callable, Any, Optional, Union, Sequence,
Protocol, ClassVar, Type
from typing_extensions import TypeGuard

from attr import Attribute as Attribute
from attr import Attribute as Attribute, Converter as Converter


class AttrsInstance(Protocol):
Expand All @@ -14,6 +14,7 @@ _C = TypeVar('_C', bound=type)

_ValidatorType = Callable[[Any, Any, _T], Any]
_ConverterType = Callable[[Any], _T]
_FieldConverterType = Union[_ConverterType[_T], Converter]
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]

@overload
Expand Down Expand Up @@ -95,7 +96,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -114,7 +115,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand All @@ -133,7 +134,7 @@ def field(
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
converter: Optional[_FieldConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/lib-stub/attrs/converters.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Optional, Callable, overload
from typing import Any, TypeVar, Optional, Callable, overload
from attr import _ConverterType

_T = TypeVar("_T")
Expand All @@ -10,3 +10,5 @@ def optional(
def default_if_none(default: _T) -> _ConverterType[_T]: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ...
def pipe(*converters: Callable[[Any], Any]) -> Callable[[Any], Any]: ...
def to_bool(val: Any) -> bool: ...
Loading