Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add foundation for TypeVar defaults (PEP 696) #14872

Merged
merged 3 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7191,6 +7191,7 @@ def detach_callable(typ: CallableType) -> CallableType:
id=var.id,
values=var.values,
upper_bound=var.upper_bound,
default=var.default,
variance=var.variance,
)
)
Expand Down
66 changes: 60 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4189,7 +4189,14 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
# Used for list and set expressions, as well as for tuples
# containing star expressions that don't refer to a
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
tv = TypeVarType(
"T",
"T",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
constructor = CallableType(
[tv],
[nodes.ARG_STAR],
Expand Down Expand Up @@ -4357,8 +4364,22 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
return dt

# Define type variables (used in constructors below).
kt = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
vt = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
kt = TypeVarType(
"KT",
"KT",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vt = TypeVarType(
"VT",
"VT",
id=-2,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)

# Collect function arguments, watching out for **expr.
args: list[Expression] = []
Expand Down Expand Up @@ -4722,7 +4743,14 @@ def check_generator_or_comprehension(

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
tv = TypeVarType(
"T",
"T",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
tv_list: list[Type] = [tv]
constructor = CallableType(
tv_list,
Expand All @@ -4742,8 +4770,22 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
ktdef = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
vtdef = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
ktdef = TypeVarType(
"KT",
"KT",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vtdef = TypeVarType(
"VT",
"VT",
id=-2,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
constructor = CallableType(
[ktdef, vtdef],
[nodes.ARG_POS, nodes.ARG_POS],
Expand Down Expand Up @@ -5264,6 +5306,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
return False
return super().visit_callable_type(t)

def visit_type_var(self, t: TypeVarType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])


def has_coroutine_decorator(t: Type) -> bool:
"""Whether t came from a function decorated with `@coroutine`."""
Expand Down
8 changes: 6 additions & 2 deletions mypy/copytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
return self.copy_common(t, t.copy_modified())

def visit_param_spec(self, t: ParamSpecType) -> ProperType:
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
dup = ParamSpecType(
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
)
return self.copy_common(t, dup)

def visit_parameters(self, t: Parameters) -> ProperType:
Expand All @@ -86,7 +88,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
return self.copy_common(t, dup)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
dup = TypeVarTupleType(
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
)
return self.copy_common(t, dup)

def visit_unpack_type(self, t: UnpackType) -> ProperType:
Expand Down
9 changes: 1 addition & 8 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,7 @@ def freshen_function_type_vars(callee: F) -> F:
tvs = []
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
tv = ParamSpecType.new_unification_variable(v)
tv = v.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
Expand Down
10 changes: 8 additions & 2 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
value.accept(self.type_fixer)
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
p.upper_bound.accept(self.type_fixer)
p.default.accept(self.type_fixer)

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
if self.current_info is not None:
Expand Down Expand Up @@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
if tvt.values:
for vt in tvt.values:
vt.accept(self)
if tvt.upper_bound is not None:
tvt.upper_bound.accept(self)
tvt.upper_bound.accept(self)
tvt.default.accept(self)

def visit_param_spec(self, p: ParamSpecType) -> None:
p.upper_bound.accept(self)
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.upper_bound.accept(self)
t.default.accept(self)

def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)
Expand Down
6 changes: 3 additions & 3 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
return set()

def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound)
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)

def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return set()
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound)
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
Expand Down
25 changes: 21 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,26 +2439,35 @@ class TypeVarLikeExpr(SymbolNode, Expression):
Note that they are constructed by the semantic analyzer.
"""

__slots__ = ("_name", "_fullname", "upper_bound", "variance")
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")

_name: str
_fullname: str
# Upper bound: only subtypes of upper_bound are valid as values. By default
# this is 'object', meaning no restriction.
upper_bound: mypy.types.Type
# Default: used to resolve the TypeVar if the default is not explicitly given.
# By default this is 'AnyType(TypeOfAny.from_omitted_generics)'. See PEP 696.
default: mypy.types.Type
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
# Variance of the type variable. Invariant is the default.
# TypeVar(..., covariant=True) defines a covariant type variable.
# TypeVar(..., contravariant=True) defines a contravariant type
# variable.
variance: int

def __init__(
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
self,
name: str,
fullname: str,
upper_bound: mypy.types.Type,
default: mypy.types.Type,
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
variance: int = INVARIANT,
) -> None:
super().__init__()
self._name = name
self._fullname = fullname
self.upper_bound = upper_bound
self.default = default
self.variance = variance

@property
Expand Down Expand Up @@ -2496,9 +2505,10 @@ def __init__(
fullname: str,
values: list[mypy.types.Type],
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2511,6 +2521,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"values": [t.serialize() for t in self.values],
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2522,6 +2533,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
data["fullname"],
[mypy.types.deserialize_type(v) for v in data["values"]],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2540,6 +2552,7 @@ def serialize(self) -> JsonDict:
"name": self._name,
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2550,6 +2563,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
data["name"],
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2569,9 +2583,10 @@ def __init__(
fullname: str,
upper_bound: mypy.types.Type,
tuple_fallback: mypy.types.Instance,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.tuple_fallback = tuple_fallback

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2584,6 +2599,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"tuple_fallback": self.tuple_fallback.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2595,6 +2611,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.Instance.deserialize(data["tuple_fallback"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand Down
7 changes: 6 additions & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,9 +772,14 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
id=-1,
values=[],
upper_bound=object_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand Down
7 changes: 6 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ def transform(self) -> bool:
# Type variable for self types in generated methods.
obj_type = self._api.named_type("builtins.object")
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
[],
obj_type,
AnyType(TypeOfAny.from_omitted_generics),
)
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand All @@ -273,6 +277,7 @@ def transform(self) -> bool:
id=-1,
values=[],
upper_bound=obj_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
order_return_type = self._api.named_type("builtins.bool")
order_args = [
Expand Down
Loading
Loading