Skip to content

Commit

Permalink
[dataclass_transform] support implicit default for "init" parameter i…
Browse files Browse the repository at this point in the history
…n field specifiers (#15010)

(Basic functionality was implemented by @wesleywright in #14870. I added
overload resolution.)

This note from PEP 681 was missed in the initial implementation of field 
specifiers:

> If unspecified, init defaults to True. Field specifier functions can
use overloads that implicitly specify the value of init using a literal
bool value type (Literal[False] or Literal[True]).

This commit adds support for reading a default from the declared type of
the `init` parameter if possible. Otherwise, it continues to use the
typical default of `True`.

The implementation was non-trivial, since regular overload resolution
can't be used in the dataclass plugin, which is applied before type
checking. As a workaround, I added a simple overload resolution helper
that should be enough to support typical use cases. It doesn't do full
overload resolution using types, but it knows about `None`,
`Literal[True]` and `Literal[False]` and a few other things.

---------

Co-authored-by: Wesley Collin Wright <wesleyw@dropbox.com>
  • Loading branch information
JukkaL and wesleywright committed Apr 5, 2023
1 parent 7beaec2 commit 06aa182
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 12 deletions.
77 changes: 75 additions & 2 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from mypy.argmap import map_actuals_to_formals
from mypy.fixup import TypeFixer
from mypy.nodes import (
ARG_POS,
Expand All @@ -13,6 +14,7 @@
Expression,
FuncDef,
JsonDict,
NameExpr,
Node,
PassStmt,
RefExpr,
Expand All @@ -22,20 +24,27 @@
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.semanal_shared import (
ALLOW_INCOMPATIBLE_OVERRIDE,
parse_bool,
require_bool_literal_argument,
set_callable_name,
)
from mypy.typeops import ( # noqa: F401 # Part of public API
try_getting_str_literals as try_getting_str_literals,
)
from mypy.types import (
AnyType,
CallableType,
Instance,
LiteralType,
NoneType,
Overloaded,
Type,
TypeOfAny,
TypeType,
TypeVarType,
deserialize_type,
get_proper_type,
is_optional,
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
Expand Down Expand Up @@ -87,6 +96,71 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None:
return None


def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
"""Perform limited lookup of a matching overload item.
Full overload resolution is only supported during type checking, but plugins
sometimes need to resolve overloads. This can be used in some such use cases.
Resolve overloads based on these things only:
* Match using argument kinds and names
* If formal argument has type None, only accept the "None" expression in the callee
* If formal argument has type Literal[True] or Literal[False], only accept the
relevant bool literal
Return the first matching overload item, or the last one if nothing matches.
"""
for item in overload.items[:-1]:
ok = True
mapped = map_actuals_to_formals(
call.arg_kinds,
call.arg_names,
item.arg_kinds,
item.arg_names,
lambda i: AnyType(TypeOfAny.special_form),
)

# Look for extra actuals
matched_actuals = set()
for actuals in mapped:
matched_actuals.update(actuals)
if any(i not in matched_actuals for i in range(len(call.args))):
ok = False

for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
if kind.is_required() and not actuals:
# Missing required argument
ok = False
break
elif actuals:
args = [call.args[i] for i in actuals]
arg_type = get_proper_type(arg_type)
arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
if isinstance(arg_type, NoneType):
if not arg_none:
ok = False
break
elif (
arg_none
and not is_optional(arg_type)
and not (
isinstance(arg_type, Instance)
and arg_type.type.fullname == "builtins.object"
)
and not isinstance(arg_type, AnyType)
):
ok = False
break
elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool:
if not any(parse_bool(arg) == arg_type.value for arg in args):
ok = False
break
if ok:
return item
return overload.items[-1]


def _get_callee_type(call: CallExpr) -> CallableType | None:
"""Return the type of the callee, regardless of its syntatic form."""

Expand All @@ -103,8 +177,7 @@ def _get_callee_type(call: CallExpr) -> CallableType | None:
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
callee_node_type = get_proper_type(callee_node.type)
if isinstance(callee_node_type, Overloaded):
# We take the last overload.
return callee_node_type.items[-1]
return find_shallow_matching_overload_item(callee_node_type, call)
elif isinstance(callee_node_type, CallableType):
return callee_node_type

Expand Down
29 changes: 27 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
_get_callee_type,
_get_decorator_bool_argument,
add_attribute_to_class,
add_method_to_class,
Expand All @@ -48,7 +49,7 @@
from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
from mypy.typeops import map_type_from_supertype, try_getting_literals_from_type
from mypy.types import (
AnyType,
CallableType,
Expand Down Expand Up @@ -517,7 +518,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:

is_in_init_param = field_args.get("init")
if is_in_init_param is None:
is_in_init = True
is_in_init = self._get_default_init_value_for_field_specifier(stmt.rvalue)
else:
is_in_init = bool(self._api.parse_bool(is_in_init_param))

Expand Down Expand Up @@ -760,6 +761,30 @@ def _get_bool_arg(self, name: str, default: bool) -> bool:
return require_bool_literal_argument(self._api, expression, name, default)
return default

def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool:
"""
Find a default value for the `init` parameter of the specifier being called. If the
specifier's type signature includes an `init` parameter with a type of `Literal[True]` or
`Literal[False]`, return the appropriate boolean value from the literal. Otherwise,
fall back to the standard default of `True`.
"""
if not isinstance(call, CallExpr):
return True

specifier_type = _get_callee_type(call)
if specifier_type is None:
return True

parameter = specifier_type.argument_by_name("init")
if parameter is None:
return True

literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool")
if literals is None or len(literals) != 1:
return True

return literals[0]

def _infer_dataclass_attr_init_type(
self, sym: SymbolTableNode, name: str, context: Context
) -> Type | None:
Expand Down
9 changes: 3 additions & 6 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
calculate_tuple_fallback,
find_dataclass_transform_spec,
has_placeholder,
parse_bool,
require_bool_literal_argument,
set_callable_name as set_callable_name,
)
Expand Down Expand Up @@ -6465,12 +6466,8 @@ def is_initial_mangled_global(self, name: str) -> bool:
return name == unmangle(name) + "'"

def parse_bool(self, expr: Expression) -> bool | None:
if isinstance(expr, NameExpr):
if expr.fullname == "builtins.True":
return True
if expr.fullname == "builtins.False":
return False
return None
# This wrapper is preserved for plugins.
return parse_bool(expr)

def parse_str_literal(self, expr: Expression) -> str | None:
"""Attempt to find the string literal value of the given expression. Returns `None` if no
Expand Down
12 changes: 11 additions & 1 deletion mypy/semanal_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Decorator,
Expression,
FuncDef,
NameExpr,
Node,
OverloadedFuncDef,
RefExpr,
Expand Down Expand Up @@ -451,11 +452,20 @@ def require_bool_literal_argument(
default: bool | None = None,
) -> bool | None:
"""Attempt to interpret an expression as a boolean literal, and fail analysis if we can't."""
value = api.parse_bool(expression)
value = parse_bool(expression)
if value is None:
api.fail(
f'"{name}" argument must be a True or False literal', expression, code=LITERAL_REQ
)
return default

return value


def parse_bool(expr: Expression) -> bool | None:
if isinstance(expr, NameExpr):
if expr.fullname == "builtins.True":
return True
if expr.fullname == "builtins.False":
return False
return None
148 changes: 147 additions & 1 deletion mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,21 @@
from mypy.indirection import TypeIndirectionVisitor
from mypy.join import join_simple, join_types
from mypy.meet import meet_types, narrow_declared_type
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, INVARIANT
from mypy.nodes import (
ARG_NAMED,
ARG_OPT,
ARG_POS,
ARG_STAR,
ARG_STAR2,
CONTRAVARIANT,
COVARIANT,
INVARIANT,
ArgKind,
CallExpr,
Expression,
NameExpr,
)
from mypy.plugins.common import find_shallow_matching_overload_item
from mypy.state import state
from mypy.subtypes import is_more_precise, is_proper_subtype, is_same_type, is_subtype
from mypy.test.helpers import Suite, assert_equal, assert_type, skip
Expand Down Expand Up @@ -1287,3 +1301,135 @@ def assert_union_result(self, t: ProperType, expected: list[Type]) -> None:
t2 = remove_instance_last_known_values(t)
assert type(t2) is UnionType
assert t2.items == expected


class ShallowOverloadMatchingSuite(Suite):
def setUp(self) -> None:
self.fx = TypeFixture()

def test_simple(self) -> None:
fx = self.fx
ov = self.make_overload([[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_NAMED)]])
# Match first only
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0)
# Match second only
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1)
# No match -- invalid keyword arg name
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 1)
# No match -- missing arg
self.assert_find_shallow_matching_overload_item(ov, make_call(), 1)
# No match -- extra arg
self.assert_find_shallow_matching_overload_item(
ov, make_call(("foo", "x"), ("foo", "z")), 1
)

def test_match_using_types(self) -> None:
fx = self.fx
ov = self.make_overload(
[
[("x", fx.nonet, ARG_POS)],
[("x", fx.lit_false, ARG_POS)],
[("x", fx.lit_true, ARG_POS)],
[("x", fx.anyt, ARG_POS)],
]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.False", None)), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2)
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3)

def test_none_special_cases(self) -> None:
fx = self.fx
ov = self.make_overload(
[[("x", fx.callable(fx.nonet), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
ov = self.make_overload([[("x", fx.str_type, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
ov = self.make_overload(
[[("x", UnionType([fx.str_type, fx.a]), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
ov = self.make_overload([[("x", fx.o, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
ov = self.make_overload(
[[("x", UnionType([fx.str_type, fx.nonet]), ARG_POS)], [("x", fx.nonet, ARG_POS)]]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)
ov = self.make_overload([[("x", fx.anyt, ARG_POS)], [("x", fx.nonet, ARG_POS)]])
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0)

def test_optional_arg(self) -> None:
fx = self.fx
ov = self.make_overload(
[[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_OPT)], [("z", fx.anyt, ARG_NAMED)]]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1)
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 2)

def test_two_args(self) -> None:
fx = self.fx
ov = self.make_overload(
[
[("x", fx.nonet, ARG_OPT), ("y", fx.anyt, ARG_OPT)],
[("x", fx.anyt, ARG_OPT), ("y", fx.anyt, ARG_OPT)],
]
)
self.assert_find_shallow_matching_overload_item(ov, make_call(), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("None", "x")), 0)
self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 1)
self.assert_find_shallow_matching_overload_item(
ov, make_call(("foo", "y"), ("None", "x")), 0
)
self.assert_find_shallow_matching_overload_item(
ov, make_call(("foo", "y"), ("bar", "x")), 1
)

def assert_find_shallow_matching_overload_item(
self, ov: Overloaded, call: CallExpr, expected_index: int
) -> None:
c = find_shallow_matching_overload_item(ov, call)
assert c in ov.items
assert ov.items.index(c) == expected_index

def make_overload(self, items: list[list[tuple[str, Type, ArgKind]]]) -> Overloaded:
result = []
for item in items:
arg_types = []
arg_names = []
arg_kinds = []
for name, typ, kind in item:
arg_names.append(name)
arg_types.append(typ)
arg_kinds.append(kind)
result.append(
CallableType(
arg_types, arg_kinds, arg_names, ret_type=NoneType(), fallback=self.fx.o
)
)
return Overloaded(result)


def make_call(*items: tuple[str, str | None]) -> CallExpr:
args: list[Expression] = []
arg_names = []
arg_kinds = []
for arg, name in items:
shortname = arg.split(".")[-1]
n = NameExpr(shortname)
n.fullname = arg
args.append(n)
arg_names.append(name)
if name:
arg_kinds.append(ARG_NAMED)
else:
arg_kinds.append(ARG_POS)
return CallExpr(NameExpr("f"), args, arg_kinds, arg_names)
Loading

0 comments on commit 06aa182

Please sign in to comment.