Skip to content

Commit

Permalink
[mypyc] Support Python 3.12 type alias syntax (PEP 695) (#17384)
Browse files Browse the repository at this point in the history
The main tricky bit is supporting uses of type alias objects at runtime.
Python evaluates values of type aliases lazily, but there's no way to do
this using public APIs, so we directly modify the `TypeAliasType` object
that is used to represent a type alias at runtime in C. Unfortunately,
this is fragile and will need to be updated each time CPython updates
the internal representation of `TypeAliasType` objects.

Wrap the target of the type alias within a lambda expression, so that we
can easily create the lazy compute function in mypyc. This also reflects
how this is implemented in CPython.

Improve test stubs to avoid various false positives or confusing errors
in tests when type checking runtime operations on types. This also makes
some existing tests more realistic.

Follow-up to #17357.
  • Loading branch information
JukkaL committed Jun 17, 2024
1 parent 31faa43 commit b202552
Show file tree
Hide file tree
Showing 32 changed files with 310 additions and 81 deletions.
4 changes: 4 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
TryStmt,
TupleExpr,
TypeAlias,
TypeAliasStmt,
TypeInfo,
TypeVarExpr,
UnaryExpr,
Expand Down Expand Up @@ -5289,6 +5290,9 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
    """Type check the value of a Python 3.12 "type" statement.

    The statement's value is handed to the expression checker as-is.
    """
    alias_value = o.value
    self.expr_checker.accept(alias_value)

def make_fake_typeinfo(
self,
curr_module_fullname: str,
Expand Down
4 changes: 3 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,9 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = self.alias_type_in_runtime_context(
node, ctx=e, alias_definition=e.is_alias_rvalue or lvalue
)
elif isinstance(node, (TypeVarExpr, ParamSpecExpr, TypeVarTupleExpr)):
elif isinstance(node, TypeVarExpr):
return self.named_type("typing.TypeVar")
elif isinstance(node, (ParamSpecExpr, TypeVarTupleExpr)):
result = self.object_type()
else:
if isinstance(node, PlaceholderNode):
Expand Down
8 changes: 7 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,13 @@ def visit_TypeAlias(self, n: ast_TypeAlias) -> TypeAliasStmt | AssignmentStmt:
if NEW_GENERIC_SYNTAX in self.options.enable_incomplete_feature:
type_params = self.translate_type_params(n.type_params)
value = self.visit(n.value)
node = TypeAliasStmt(self.visit_Name(n.name), type_params, value)
# Since the value is evaluated lazily, wrap the value inside a lambda.
# This helps mypyc.
ret = ReturnStmt(value)
self.set_line(ret, n.value)
value_func = LambdaExpr(body=Block([ret]))
self.set_line(value_func, n.value)
node = TypeAliasStmt(self.visit_Name(n.name), type_params, value_func)
return self.set_line(node, n)
else:
self.fail(
Expand Down
4 changes: 2 additions & 2 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,10 +1653,10 @@ class TypeAliasStmt(Statement):

name: NameExpr
type_args: list[TypeParam]
value: Expression # Will get translated into a type
value: LambdaExpr # Return value will get translated into a type
invalid_recursive_alias: bool

def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None:
def __init__(self, name: NameExpr, type_args: list[TypeParam], value: LambdaExpr) -> None:
super().__init__()
self.name = name
self.type_args = type_args
Expand Down
13 changes: 11 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3766,6 +3766,10 @@ def analyze_alias(
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)

if python_3_12_type_alias:
with self.allow_unbound_tvars_set():
rvalue.accept(self)

analyzed, depends_on = analyze_type_alias(
typ,
self,
Expand Down Expand Up @@ -5360,7 +5364,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
tag = self.track_incomplete_refs()
res, alias_tvars, depends_on, qualified_tvars, empty_tuple_index = self.analyze_alias(
s.name.name,
s.value,
s.value.expr(),
allow_placeholder=True,
declared_type_vars=type_params,
all_declared_type_params_names=all_type_params_names,
Expand Down Expand Up @@ -5443,6 +5447,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
current_node = existing.node if existing else alias_node
assert isinstance(current_node, TypeAlias)
self.disable_invalid_recursive_aliases(s, current_node, s.value)
s.name.accept(self)
finally:
self.pop_type_args(s.type_args)

Expand All @@ -5457,7 +5462,11 @@ def visit_name_expr(self, expr: NameExpr) -> None:

def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None:
"""Bind name expression to a symbol table node."""
if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
if (
isinstance(sym.node, TypeVarExpr)
and self.tvar_scope.get_binding(sym)
and not self.allow_unbound_tvars
):
self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
elif isinstance(sym.node, PlaceholderNode):
self.process_placeholder(expr.name, "name", expr)
Expand Down
46 changes: 46 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
ARG_POS,
GDEF,
LDEF,
PARAM_SPEC_KIND,
TYPE_VAR_KIND,
TYPE_VAR_TUPLE_KIND,
ArgKind,
CallExpr,
Decorator,
Expand All @@ -44,6 +47,7 @@
TupleExpr,
TypeAlias,
TypeInfo,
TypeParam,
UnaryExpr,
Var,
)
Expand Down Expand Up @@ -1409,3 +1413,45 @@ def get_call_target_fullname(ref: RefExpr) -> str:
if isinstance(target, Instance):
return target.type.fullname
return ref.fullname


def create_type_params(
    builder: IRBuilder, typing_mod: Value, type_args: list[TypeParam], line: int
) -> list[Value]:
    """Build runtime objects for Python 3.12 (PEP 695) type parameters.

    "typing_mod" is the "_typing" module object; the TypeVar, TypeVarTuple
    and ParamSpec types are fetched from it as attributes.

    The result contains exactly one value per entry of "type_args", in the
    same order. Each value is a TypeVar, TypeVarTuple or ParamSpec instance.
    """
    created: list[Value] = []
    cached_type_var: Value | None = None
    for param in type_args:
        kind = param.kind
        if kind == TYPE_VAR_KIND:
            # Look up TypeVar only once and reuse it (minor optimization)
            if not cached_type_var:
                cached_type_var = builder.py_get_attr(typing_mod, "TypeVar", line)
            factory = cached_type_var
        elif kind == TYPE_VAR_TUPLE_KIND:
            factory = builder.py_get_attr(typing_mod, "TypeVarTuple", line)
        else:
            assert kind == PARAM_SPEC_KIND
            factory = builder.py_get_attr(typing_mod, "ParamSpec", line)

        if kind == TYPE_VAR_TUPLE_KIND:
            # TypeVarTuple is constructed from the name alone
            instance = builder.py_call(factory, [builder.load_str(param.name)], line)
        else:
            # Pass infer_variance=True to match runtime semantics
            instance = builder.py_call(
                factory,
                [builder.load_str(param.name), builder.true()],
                line,
                arg_kinds=[ARG_POS, ARG_NAMED],
                arg_names=[None, "infer_variance"],
            )
        builder.init_type_var(instance, param.name, line)
        created.append(instance)
    return created
37 changes: 10 additions & 27 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from typing import Callable, Final

from mypy.nodes import (
PARAM_SPEC_KIND,
TYPE_VAR_KIND,
TYPE_VAR_TUPLE_KIND,
AssignmentStmt,
CallExpr,
Expand Down Expand Up @@ -57,7 +55,7 @@
is_optional_type,
object_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.builder import IRBuilder, create_type_params
from mypyc.irbuild.function import (
gen_property_getter_ir,
gen_property_setter_ir,
Expand Down Expand Up @@ -475,35 +473,20 @@ def make_generic_base_class(
) -> Value:
"""Construct Generic[...] base class object for a new-style generic class (Python 3.12)."""
mod = builder.call_c(import_op, [builder.load_str("_typing")], line)
tvs = []
type_var_imported: Value | None = None
for type_param in type_args:
unpack = False
if type_param.kind == TYPE_VAR_KIND:
if type_var_imported:
# Reuse previously imported value as a minor optimization
tvt = type_var_imported
else:
tvt = builder.py_get_attr(mod, "TypeVar", line)
type_var_imported = tvt
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
tvt = builder.py_get_attr(mod, "TypeVarTuple", line)
unpack = True
else:
assert type_param.kind == PARAM_SPEC_KIND
tvt = builder.py_get_attr(mod, "ParamSpec", line)
tv = builder.py_call(tvt, [builder.load_str(type_param.name)], line)
builder.init_type_var(tv, type_param.name, line)
if unpack:
tvs = create_type_params(builder, mod, type_args, line)
args = []
for tv, type_param in zip(tvs, type_args):
if type_param.kind == TYPE_VAR_TUPLE_KIND:
# Evaluate *Ts for a TypeVarTuple
it = builder.call_c(iter_op, [tv], line)
tv = builder.call_c(next_op, [it], line)
tvs.append(tv)
args.append(tv)

gent = builder.py_get_attr(mod, "Generic", line)
if len(tvs) == 1:
arg = tvs[0]
if len(args) == 1:
arg = args[0]
else:
arg = builder.new_tuple(tvs, line)
arg = builder.new_tuple(args, line)

base = builder.call_c(py_get_item_op, [gent, arg], line)
return base
Expand Down
34 changes: 33 additions & 1 deletion mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from typing import Callable, Sequence

from mypy.nodes import (
ARG_NAMED,
ARG_POS,
AssertStmt,
AssignmentStmt,
AwaitExpr,
Expand All @@ -37,6 +39,7 @@
TempNode,
TryStmt,
TupleExpr,
TypeAliasStmt,
WhileStmt,
WithStmt,
YieldExpr,
Expand Down Expand Up @@ -74,7 +77,7 @@
object_rprimitive,
)
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op
from mypyc.irbuild.builder import IRBuilder, create_type_params, int_borrow_friendly_op
from mypyc.irbuild.for_helpers import for_loop_helper
from mypyc.irbuild.generator import add_raise_exception_blocks_to_generator_class
from mypyc.irbuild.nonlocalcontrol import (
Expand Down Expand Up @@ -105,7 +108,9 @@
coro_op,
import_from_many_op,
import_many_op,
import_op,
send_op,
set_type_alias_compute_function_op,
type_op,
yield_from_except_op,
)
Expand Down Expand Up @@ -1015,3 +1020,30 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value:

def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None:
m.accept(MatchVisitor(builder, m))


def transform_type_alias_stmt(builder: IRBuilder, s: TypeAliasStmt) -> None:
    """Generate IR for a Python 3.12 "type" statement (PEP 695 type alias).

    Creates a TypeAliasType object with a placeholder value, installs the
    function that lazily computes the real value, and assigns the alias
    object to the alias name.
    """
    line = s.line
    # Import "_typing" instead of "typing", since importing the latter can be
    # expensive. "_typing" provides everything that is needed here.
    typing_mod = builder.call_c(import_op, [builder.load_str("_typing")], line)
    tparam_values = create_type_params(builder, typing_mod, s.type_args, s.line)

    alias_cls = builder.py_get_attr(typing_mod, "TypeAliasType", line)
    # Positional arguments: the alias name and a None placeholder for the value
    call_args = [builder.load_str(s.name.name), builder.none()]
    call_arg_names: list[str | None] = [None, None]
    call_arg_kinds = [ARG_POS, ARG_POS]
    if s.type_args:
        # Only pass type_params for a generic alias
        call_args.append(builder.new_tuple(tparam_values, line))
        call_arg_names.append("type_params")
        call_arg_kinds.append(ARG_NAMED)
    alias_obj = builder.py_call(
        alias_cls, call_args, line, arg_names=call_arg_names, arg_kinds=call_arg_kinds
    )

    # The alias value must be computed lazily to match Python runtime behavior.
    # Python public APIs can't express this, so a C primitive is used to set
    # the function that computes the value.
    lazy_value_fn = s.value.accept(builder.visitor)
    builder.builder.primitive_op(
        set_type_alias_compute_function_op, [alias_obj, lazy_value_fn], line
    )

    builder.assign(builder.get_assignment_target(s.name), alias_obj, line)
3 changes: 2 additions & 1 deletion mypyc/irbuild/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
transform_raise_stmt,
transform_return_stmt,
transform_try_stmt,
transform_type_alias_stmt,
transform_while_stmt,
transform_with_stmt,
transform_yield_expr,
Expand Down Expand Up @@ -251,7 +252,7 @@ def visit_match_stmt(self, stmt: MatchStmt) -> None:
transform_match_stmt(self.builder, stmt)

def visit_type_alias_stmt(self, stmt: TypeAliasStmt) -> None:
self.bail('The "type" statement is not yet supported by mypyc', stmt.line)
transform_type_alias_stmt(self.builder, stmt)

# Expressions

Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyOb

PyObject *CPy_GetAIter(PyObject *obj);
PyObject *CPy_GetANext(PyObject *aiter);
void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value);

#ifdef __cplusplus
}
Expand Down
31 changes: 31 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,34 @@ PyObject *CPy_GetANext(PyObject *aiter)
error:
return NULL;
}

#ifdef CPY_3_12_FEATURES

// Copied from Python 3.12.3, since this struct is internal to CPython. It defines
// the structure of typing.TypeAliasType objects. We need it since compute_value is
// not part of the public API, and we need to set it to match Python runtime semantics.
//
// IMPORTANT: This needs to be kept in sync with CPython! TypeAliasType objects are
// accessed by casting to this struct, so the field order and types must match the
// layout used by the running interpreter.
typedef struct {
    PyObject_HEAD
    PyObject *name;
    PyObject *type_params;
    // Function used to lazily compute the value of the type alias
    PyObject *compute_value;
    // NOTE(review): presumably the already-computed alias value, NULL until it has
    // been computed -- confirm against the CPython sources
    PyObject *value;
    PyObject *module;
} typealiasobject;

// Set the function used to lazily compute the value of a TypeAliasType object.
//
// "alias" must be a typing.TypeAliasType instance (this is not checked), and
// "compute_value" is the function to install. Any existing value stored in the
// alias is cleared. A new reference to "compute_value" is taken; the caller keeps
// ownership of its own reference.
void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value) {
    typealiasobject *obj = (typealiasobject *)alias;
    // Py_CLEAR sets the field to NULL before dropping the reference, so code run
    // by a destructor during the decref never observes a dangling pointer.
    Py_CLEAR(obj->value);
    Py_INCREF(compute_value);
    // Py_XSETREF stores the new pointer first and only then releases the old one
    // (which may be NULL), for the same reason as above.
    Py_XSETREF(obj->compute_value, compute_value);
}

#endif
12 changes: 12 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,15 @@
return_type=c_pyssize_t_rprimitive,
error_kind=ERR_NEVER,
)

# Set the lazy value compute function of a TypeAliasType instance (Python 3.12+).
# This must only be used as part of initializing the object. Any existing value
# will be cleared.
#
# The primitive is implemented by the C function CPy_SetTypeAliasTypeComputeFunction.
# It returns nothing and is declared as never raising (ERR_NEVER).
set_type_alias_compute_function_op = custom_primitive_op(
    name="set_type_alias_compute_function",
    c_function_name="CPy_SetTypeAliasTypeComputeFunction",
    # (alias object, value compute function)
    arg_types=[object_rprimitive, object_rprimitive],
    return_type=void_rtype,
    error_kind=ERR_NEVER,
)
1 change: 1 addition & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __ne__(self, x: object) -> bool: pass

class type:
def __init__(self, o: object) -> None: ...
def __or__(self, o: object) -> Any: ...
__name__ : str
__annotations__: Dict[str, Any]

Expand Down
Loading

0 comments on commit b202552

Please sign in to comment.