Skip to content

Commit

Permalink
Add foundation for TypeVar defaults (PEP 696) (#14872)
Browse files Browse the repository at this point in the history
Start implementing [PEP 696](https://peps.python.org/pep-0696/) TypeVar
defaults. This PR
* Adds a `default` parameter to `TypeVarLikeExpr` and `TypeVarLikeType`.
* Updates most visitors to account for the new `default` parameter.
* Update existing calls to add value for `default` =>
`AnyType(TypeOfAny.from_omitted_generics)`.

A followup PR will update the semantic analyzer and add basic tests for
`TypeVar`, `ParamSpec`, and `TypeVarTuple` calls with a `default`
argument. -> #14873

Ref #14851
  • Loading branch information
cdce8p committed May 29, 2023
1 parent 7fe1fdd commit a568f3a
Show file tree
Hide file tree
Showing 23 changed files with 356 additions and 84 deletions.
1 change: 1 addition & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7191,6 +7191,7 @@ def detach_callable(typ: CallableType) -> CallableType:
id=var.id,
values=var.values,
upper_bound=var.upper_bound,
default=var.default,
variance=var.variance,
)
)
Expand Down
66 changes: 60 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4189,7 +4189,14 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
# Used for list and set expressions, as well as for tuples
# containing star expressions that don't refer to a
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
tv = TypeVarType(
"T",
"T",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
constructor = CallableType(
[tv],
[nodes.ARG_STAR],
Expand Down Expand Up @@ -4357,8 +4364,22 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
return dt

# Define type variables (used in constructors below).
kt = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
vt = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
kt = TypeVarType(
"KT",
"KT",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vt = TypeVarType(
"VT",
"VT",
id=-2,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)

# Collect function arguments, watching out for **expr.
args: list[Expression] = []
Expand Down Expand Up @@ -4722,7 +4743,14 @@ def check_generator_or_comprehension(

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
tv = TypeVarType(
"T",
"T",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
tv_list: list[Type] = [tv]
constructor = CallableType(
tv_list,
Expand All @@ -4742,8 +4770,22 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
ktdef = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
vtdef = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
ktdef = TypeVarType(
"KT",
"KT",
id=-1,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vtdef = TypeVarType(
"VT",
"VT",
id=-2,
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
constructor = CallableType(
[ktdef, vtdef],
[nodes.ARG_POS, nodes.ARG_POS],
Expand Down Expand Up @@ -5264,6 +5306,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
return False
return super().visit_callable_type(t)

def visit_type_var(self, t: TypeVarType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])


def has_coroutine_decorator(t: Type) -> bool:
"""Whether t came from a function decorated with `@coroutine`."""
Expand Down
8 changes: 6 additions & 2 deletions mypy/copytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
return self.copy_common(t, t.copy_modified())

def visit_param_spec(self, t: ParamSpecType) -> ProperType:
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
dup = ParamSpecType(
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
)
return self.copy_common(t, dup)

def visit_parameters(self, t: Parameters) -> ProperType:
Expand All @@ -86,7 +88,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
return self.copy_common(t, dup)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
dup = TypeVarTupleType(
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
)
return self.copy_common(t, dup)

def visit_unpack_type(self, t: UnpackType) -> ProperType:
Expand Down
9 changes: 1 addition & 8 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,7 @@ def freshen_function_type_vars(callee: F) -> F:
tvs = []
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
tv = ParamSpecType.new_unification_variable(v)
tv = v.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
Expand Down
10 changes: 8 additions & 2 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
value.accept(self.type_fixer)
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
p.upper_bound.accept(self.type_fixer)
p.default.accept(self.type_fixer)

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
if self.current_info is not None:
Expand Down Expand Up @@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
if tvt.values:
for vt in tvt.values:
vt.accept(self)
if tvt.upper_bound is not None:
tvt.upper_bound.accept(self)
tvt.upper_bound.accept(self)
tvt.default.accept(self)

def visit_param_spec(self, p: ParamSpecType) -> None:
p.upper_bound.accept(self)
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.upper_bound.accept(self)
t.default.accept(self)

def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)
Expand Down
6 changes: 3 additions & 3 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
return set()

def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound)
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)

def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return set()
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound)
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
Expand Down
25 changes: 21 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,26 +2439,35 @@ class TypeVarLikeExpr(SymbolNode, Expression):
Note that they are constructed by the semantic analyzer.
"""

__slots__ = ("_name", "_fullname", "upper_bound", "variance")
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")

_name: str
_fullname: str
# Upper bound: only subtypes of upper_bound are valid as values. By default
# this is 'object', meaning no restriction.
upper_bound: mypy.types.Type
# Default: used to resolve the TypeVar if the default is not explicitly given.
# By default this is 'AnyType(TypeOfAny.from_omitted_generics)'. See PEP 696.
default: mypy.types.Type
# Variance of the type variable. Invariant is the default.
# TypeVar(..., covariant=True) defines a covariant type variable.
# TypeVar(..., contravariant=True) defines a contravariant type
# variable.
variance: int

def __init__(
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
self,
name: str,
fullname: str,
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__()
self._name = name
self._fullname = fullname
self.upper_bound = upper_bound
self.default = default
self.variance = variance

@property
Expand Down Expand Up @@ -2496,9 +2505,10 @@ def __init__(
fullname: str,
values: list[mypy.types.Type],
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2511,6 +2521,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"values": [t.serialize() for t in self.values],
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2522,6 +2533,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
data["fullname"],
[mypy.types.deserialize_type(v) for v in data["values"]],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2540,6 +2552,7 @@ def serialize(self) -> JsonDict:
"name": self._name,
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2550,6 +2563,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
data["name"],
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2569,9 +2583,10 @@ def __init__(
fullname: str,
upper_bound: mypy.types.Type,
tuple_fallback: mypy.types.Instance,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.tuple_fallback = tuple_fallback

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2584,6 +2599,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"tuple_fallback": self.tuple_fallback.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2595,6 +2611,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.Instance.deserialize(data["tuple_fallback"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand Down
7 changes: 6 additions & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,9 +772,14 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
id=-1,
values=[],
upper_bound=object_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand Down
7 changes: 6 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ def transform(self) -> bool:
# Type variable for self types in generated methods.
obj_type = self._api.named_type("builtins.object")
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
[],
obj_type,
AnyType(TypeOfAny.from_omitted_generics),
)
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand All @@ -273,6 +277,7 @@ def transform(self) -> bool:
id=-1,
values=[],
upper_bound=obj_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
order_return_type = self._api.named_type("builtins.bool")
order_args = [
Expand Down

0 comments on commit a568f3a

Please sign in to comment.