Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variadic tuple packing/unpacking #16205

Merged
merged 10 commits into from
Oct 8, 2023
23 changes: 22 additions & 1 deletion mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
Type,
TypedDictType,
TypeOfAny,
TypeVarTupleType,
UnpackType,
get_proper_type,
)

Expand Down Expand Up @@ -174,6 +176,7 @@ def expand_actual_type(
actual_kind: nodes.ArgKind,
formal_name: str | None,
formal_kind: nodes.ArgKind,
allow_unpack: bool = False,
) -> Type:
"""Return the actual (caller) type(s) of a formal argument with the given kinds.

Expand All @@ -189,6 +192,11 @@ def expand_actual_type(
original_actual = actual_type
actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, TypeVarTupleType):
# This code path is hit when *Ts is passed to a callable and various
# special-handling didn't catch this. The best thing we can do is to use
# the upper bound.
actual_type = get_proper_type(actual_type.upper_bound)
if isinstance(actual_type, Instance) and actual_type.args:
from mypy.subtypes import is_subtype

Expand All @@ -209,7 +217,20 @@ def expand_actual_type(
self.tuple_index = 1
else:
self.tuple_index += 1
return actual_type.items[self.tuple_index - 1]
item = actual_type.items[self.tuple_index - 1]
if isinstance(item, UnpackType) and not allow_unpack:
# An upack item that doesn't have special handling, use upper bound as above.
unpacked = get_proper_type(item.type)
if isinstance(unpacked, TypeVarTupleType):
fallback = get_proper_type(unpacked.upper_bound)
else:
fallback = unpacked
assert (
isinstance(fallback, Instance)
and fallback.type.fullname == "builtins.tuple"
)
item = fallback.args[0]
return item
elif isinstance(actual_type, ParamSpecType):
# ParamSpec is valid in *args but it can't be unpacked.
return actual_type
Expand Down
107 changes: 99 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,13 @@
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand Down Expand Up @@ -3430,6 +3433,37 @@ def is_assignable_slot(self, lvalue: Lvalue, typ: Type | None) -> bool:
return all(self.is_assignable_slot(lvalue, u) for u in typ.items)
return False

def flatten_rvalues(self, rvalues: list[Expression]) -> list[Expression]:
"""Flatten expression list by expanding those * items that have tuple type.

For each regular type item in the tuple type use a TempNode(), for an Unpack
item use a corresponding StarExpr(TempNode()).
"""
new_rvalues = []
for rv in rvalues:
if not isinstance(rv, StarExpr):
new_rvalues.append(rv)
continue
typ = get_proper_type(self.expr_checker.accept(rv.expr))
if not isinstance(typ, TupleType):
new_rvalues.append(rv)
continue
for t in typ.items:
if not isinstance(t, UnpackType):
new_rvalues.append(TempNode(t))
else:
unpacked = get_proper_type(t.type)
if isinstance(unpacked, TypeVarTupleType):
fallback = unpacked.upper_bound
else:
assert (
isinstance(unpacked, Instance)
and unpacked.type.fullname == "builtins.tuple"
)
fallback = unpacked
new_rvalues.append(StarExpr(TempNode(fallback)))
return new_rvalues

def check_assignment_to_multiple_lvalues(
self,
lvalues: list[Lvalue],
Expand All @@ -3439,18 +3473,16 @@ def check_assignment_to_multiple_lvalues(
) -> None:
if isinstance(rvalue, (TupleExpr, ListExpr)):
# Recursively go into Tuple or List expression rhs instead of
# using the type of rhs, because this allowed more fine grained
# using the type of rhs, because this allows more fine-grained
# control in cases like: a, b = [int, str] where rhs would get
# type List[object]
rvalues: list[Expression] = []
iterable_type: Type | None = None
last_idx: int | None = None
for idx_rval, rval in enumerate(rvalue.items):
for idx_rval, rval in enumerate(self.flatten_rvalues(rvalue.items)):
if isinstance(rval, StarExpr):
typs = get_proper_type(self.expr_checker.accept(rval.expr))
if isinstance(typs, TupleType):
rvalues.extend([TempNode(typ) for typ in typs.items])
elif self.type_is_iterable(typs) and isinstance(typs, Instance):
if self.type_is_iterable(typs) and isinstance(typs, Instance):
if iterable_type is not None and iterable_type != self.iterable_item_type(
typs, rvalue
):
Expand Down Expand Up @@ -3517,8 +3549,32 @@ def check_assignment_to_multiple_lvalues(
self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type)

def check_rvalue_count_in_assignment(
self, lvalues: list[Lvalue], rvalue_count: int, context: Context
self,
lvalues: list[Lvalue],
rvalue_count: int,
context: Context,
rvalue_unpack: int | None = None,
) -> bool:
if rvalue_unpack is not None:
if not any(isinstance(e, StarExpr) for e in lvalues):
self.fail("Variadic tuple unpacking requires a star target", context)
return False
if len(lvalues) > rvalue_count:
self.fail(message_registry.TOO_MANY_TARGETS_FOR_VARIADIC_UNPACK, context)
return False
left_star_index = next(i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr))
left_prefix = left_star_index
left_suffix = len(lvalues) - left_star_index - 1
right_prefix = rvalue_unpack
right_suffix = rvalue_count - rvalue_unpack - 1
if left_suffix > right_suffix or left_prefix > right_prefix:
# Case of asymmetric unpack like:
# rv: tuple[int, *Ts, int, int]
# x, y, *xs, z = rv
# it is technically valid, but is tricky to reason about.
# TODO: support this (at least if the r.h.s. unpack is a homogeneous tuple).
self.fail(message_registry.TOO_MANY_TARGETS_FOR_VARIADIC_UNPACK, context)
return True
if any(isinstance(lvalue, StarExpr) for lvalue in lvalues):
if len(lvalues) - 1 > rvalue_count:
self.msg.wrong_number_values_to_unpack(rvalue_count, len(lvalues) - 1, context)
Expand Down Expand Up @@ -3552,6 +3608,13 @@ def check_multi_assignment(
if len(relevant_items) == 1:
rvalue_type = get_proper_type(relevant_items[0])

if (
isinstance(rvalue_type, TupleType)
and find_unpack_in_list(rvalue_type.items) is not None
):
# Normalize for consistent handling with "old-style" homogeneous tuples.
rvalue_type = expand_type(rvalue_type, {})

if isinstance(rvalue_type, AnyType):
for lv in lvalues:
if isinstance(lv, StarExpr):
Expand Down Expand Up @@ -3663,7 +3726,10 @@ def check_multi_assignment_from_tuple(
undefined_rvalue: bool,
infer_lvalue_type: bool = True,
) -> None:
if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context):
rvalue_unpack = find_unpack_in_list(rvalue_type.items)
if self.check_rvalue_count_in_assignment(
lvalues, len(rvalue_type.items), context, rvalue_unpack=rvalue_unpack
):
star_index = next(
(i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)
)
Expand Down Expand Up @@ -3708,12 +3774,37 @@ def check_multi_assignment_from_tuple(
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
if star_lv:
list_expr = ListExpr(
[self.temp_node(rv_type, context) for rv_type in star_rv_types]
[
self.temp_node(rv_type, context)
if not isinstance(rv_type, UnpackType)
else StarExpr(self.temp_node(rv_type.type, context))
for rv_type in star_rv_types
]
)
list_expr.set_line(context)
self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type)
for lv, rv_type in zip(right_lvs, right_rv_types):
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
else:
# Store meaningful Any types for lvalues, errors are already given
# by check_rvalue_count_in_assignment()
if infer_lvalue_type:
for lv in lvalues:
if (
isinstance(lv, NameExpr)
and isinstance(lv.node, Var)
and lv.node.type is None
):
lv.node.type = AnyType(TypeOfAny.from_error)
elif isinstance(lv, StarExpr):
if (
isinstance(lv.expr, NameExpr)
and isinstance(lv.expr.node, Var)
and lv.expr.node.type is None
):
lv.expr.node.type = self.named_generic_type(
"builtins.list", [AnyType(TypeOfAny.from_error)]
)

def lvalue_type_for_inference(self, lvalues: list[Lvalue], rvalue_type: TupleType) -> Type:
star_index = next(
Expand Down