Skip to content

Commit b3e26e7

Browse files
authored
[mypyc] Fix inheritance of async defs (#20044)
When inferring a precise generator return type for an async def (or generator), make the generate returned by an override in a subclass inherit from the base class generator. This means that the environment has to be moved to a separate class in the base class generator. Don't infer a precise generator return type when an override might have a less precise return type, since it would break LSP. Fixes mypyc/mypyc#1141.
1 parent 6aa44da commit b3e26e7

File tree

9 files changed

+212
-24
lines changed

9 files changed

+212
-24
lines changed

mypyc/codegen/emitclass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def setter_name(cl: ClassIR, attribute: str, names: NameGenerator) -> str:
410410

411411

412412
def generate_object_struct(cl: ClassIR, emitter: Emitter) -> None:
413-
seen_attrs: set[tuple[str, RType]] = set()
413+
seen_attrs: set[str] = set()
414414
lines: list[str] = []
415415
lines += ["typedef struct {", "PyObject_HEAD", "CPyVTableItem *vtable;"]
416416
if cl.has_method("__call__"):
@@ -427,9 +427,11 @@ def generate_object_struct(cl: ClassIR, emitter: Emitter) -> None:
427427
lines.append(f"{BITMAP_TYPE} {attr};")
428428
bitmap_attrs.append(attr)
429429
for attr, rtype in base.attributes.items():
430-
if (attr, rtype) not in seen_attrs:
430+
# Generated class may redefine certain attributes with different
431+
# types in subclasses (this would be unsafe for user-defined classes).
432+
if attr not in seen_attrs:
431433
lines.append(f"{emitter.ctype_spaced(rtype)}{emitter.attr(attr)};")
432-
seen_attrs.add((attr, rtype))
434+
seen_attrs.add(attr)
433435

434436
if isinstance(rtype, RTuple):
435437
emitter.declare_tuple_struct(rtype)

mypyc/codegen/emitmodule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,8 @@ def emit_module_exec_func(
10641064
"(PyObject *){t}_template, NULL, modname);".format(t=type_struct)
10651065
)
10661066
emitter.emit_lines(f"if (unlikely(!{type_struct}))", " goto fail;")
1067+
name_prefix = cl.name_prefix(emitter.names)
1068+
emitter.emit_line(f"CPyDef_{name_prefix}_trait_vtable_setup();")
10671069

10681070
emitter.emit_lines("if (CPyGlobalsInit() < 0)", " goto fail;")
10691071

mypyc/ir/func_ir.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,11 @@ def __init__(
149149
module_name: str,
150150
sig: FuncSignature,
151151
kind: int = FUNC_NORMAL,
152+
*,
152153
is_prop_setter: bool = False,
153154
is_prop_getter: bool = False,
155+
is_generator: bool = False,
156+
is_coroutine: bool = False,
154157
implicit: bool = False,
155158
internal: bool = False,
156159
) -> None:
@@ -161,6 +164,8 @@ def __init__(
161164
self.kind = kind
162165
self.is_prop_setter = is_prop_setter
163166
self.is_prop_getter = is_prop_getter
167+
self.is_generator = is_generator
168+
self.is_coroutine = is_coroutine
164169
if class_name is None:
165170
self.bound_sig: FuncSignature | None = None
166171
else:
@@ -219,6 +224,8 @@ def serialize(self) -> JsonDict:
219224
"kind": self.kind,
220225
"is_prop_setter": self.is_prop_setter,
221226
"is_prop_getter": self.is_prop_getter,
227+
"is_generator": self.is_generator,
228+
"is_coroutine": self.is_coroutine,
222229
"implicit": self.implicit,
223230
"internal": self.internal,
224231
}
@@ -240,10 +247,12 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncDecl:
240247
data["module_name"],
241248
FuncSignature.deserialize(data["sig"], ctx),
242249
data["kind"],
243-
data["is_prop_setter"],
244-
data["is_prop_getter"],
245-
data["implicit"],
246-
data["internal"],
250+
is_prop_setter=data["is_prop_setter"],
251+
is_prop_getter=data["is_prop_getter"],
252+
is_generator=data["is_generator"],
253+
is_coroutine=data["is_coroutine"],
254+
implicit=data["implicit"],
255+
internal=data["internal"],
247256
)
248257

249258

mypyc/irbuild/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ def curr_env_reg(self) -> Value:
9898
def can_merge_generator_and_env_classes(self) -> bool:
9999
# In simple cases we can place the environment into the generator class,
100100
# instead of having two separate classes.
101-
return self.is_generator and not self.is_nested and not self.contains_nested
101+
if self._generator_class and not self._generator_class.ir.is_final_class:
102+
result = False
103+
else:
104+
result = self.is_generator and not self.is_nested and not self.contains_nested
105+
return result
102106

103107

104108
class ImplicitClass:

mypyc/irbuild/function.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
instantiate_callable_class,
7070
setup_callable_class,
7171
)
72-
from mypyc.irbuild.context import FuncInfo
72+
from mypyc.irbuild.context import FuncInfo, GeneratorClass
7373
from mypyc.irbuild.env_class import (
7474
add_vars_to_env,
7575
finalize_env_class,
@@ -246,6 +246,12 @@ def c() -> None:
246246
is_generator = fn_info.is_generator
247247
builder.enter(fn_info, ret_type=sig.ret_type)
248248

249+
if is_generator:
250+
fitem = builder.fn_info.fitem
251+
assert isinstance(fitem, FuncDef), fitem
252+
generator_class_ir = builder.mapper.fdef_to_generator[fitem]
253+
builder.fn_info.generator_class = GeneratorClass(generator_class_ir)
254+
249255
# Functions that contain nested functions need an environment class to store variables that
250256
# are free in their nested functions. Generator functions need an environment class to
251257
# store a variable denoting the next instruction to be executed when the __next__ function
@@ -357,8 +363,8 @@ def gen_func_ir(
357363
builder.module_name,
358364
sig,
359365
func_decl.kind,
360-
func_decl.is_prop_getter,
361-
func_decl.is_prop_setter,
366+
is_prop_getter=func_decl.is_prop_getter,
367+
is_prop_setter=func_decl.is_prop_setter,
362368
)
363369
func_ir = FuncIR(func_decl, args, blocks, fitem.line, traceback_name=fitem.name)
364370
else:

mypyc/irbuild/generator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
object_rprimitive,
4040
)
4141
from mypyc.irbuild.builder import IRBuilder, calculate_arg_defaults, gen_arg_defaults
42-
from mypyc.irbuild.context import FuncInfo, GeneratorClass
42+
from mypyc.irbuild.context import FuncInfo
4343
from mypyc.irbuild.env_class import (
4444
add_args_to_env,
4545
add_vars_to_env,
@@ -166,10 +166,8 @@ def setup_generator_class(builder: IRBuilder) -> ClassIR:
166166
builder.fn_info.env_class = generator_class_ir
167167
else:
168168
generator_class_ir.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_info.env_class)
169-
generator_class_ir.mro = [generator_class_ir]
170169

171170
builder.classes.append(generator_class_ir)
172-
builder.fn_info.generator_class = GeneratorClass(generator_class_ir)
173171
return generator_class_ir
174172

175173

mypyc/irbuild/prepare.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,15 @@ def prepare_func_def(
202202
else (FUNC_STATICMETHOD if fdef.is_static else FUNC_NORMAL)
203203
)
204204
sig = mapper.fdef_to_sig(fdef, options.strict_dunders_typing)
205-
decl = FuncDecl(fdef.name, class_name, module_name, sig, kind)
205+
decl = FuncDecl(
206+
fdef.name,
207+
class_name,
208+
module_name,
209+
sig,
210+
kind,
211+
is_generator=fdef.is_generator,
212+
is_coroutine=fdef.is_coroutine,
213+
)
206214
mapper.func_to_decl[fdef] = decl
207215
return decl
208216

@@ -217,7 +225,7 @@ def create_generator_class_for_func(
217225
"""
218226
assert fdef.is_coroutine or fdef.is_generator
219227
name = "_".join(x for x in [fdef.name, class_name] if x) + "_gen" + name_suffix
220-
cir = ClassIR(name, module_name, is_generated=True, is_final_class=True)
228+
cir = ClassIR(name, module_name, is_generated=True, is_final_class=class_name is None)
221229
cir.reuse_freed_instance = True
222230
mapper.fdef_to_generator[fdef] = cir
223231

@@ -816,14 +824,70 @@ def adjust_generator_classes_of_methods(mapper: Mapper) -> None:
816824
This is a separate pass after type map has been built, since we need all classes
817825
to be processed to analyze class hierarchies.
818826
"""
819-
for fdef, ir in mapper.func_to_decl.items():
827+
828+
generator_methods = []
829+
830+
for fdef, fn_ir in mapper.func_to_decl.items():
820831
if isinstance(fdef, FuncDef) and (fdef.is_coroutine or fdef.is_generator):
821-
gen_ir = create_generator_class_for_func(ir.module_name, ir.class_name, fdef, mapper)
832+
gen_ir = create_generator_class_for_func(
833+
fn_ir.module_name, fn_ir.class_name, fdef, mapper
834+
)
822835
# TODO: We could probably support decorators sometimes (static and class method?)
823836
if not fdef.is_decorated:
824-
# Give a more precise type for generators, so that we can optimize
825-
# code that uses them. They return a generator object, which has a
826-
# specific class. Without this, the type would have to be 'object'.
827-
ir.sig.ret_type = RInstance(gen_ir)
828-
if ir.bound_sig:
829-
ir.bound_sig.ret_type = RInstance(gen_ir)
837+
name = fn_ir.name
838+
precise_ret_type = True
839+
if fn_ir.class_name is not None:
840+
class_ir = mapper.type_to_ir[fdef.info]
841+
subcls = class_ir.subclasses()
842+
if subcls is None:
843+
# Override could be of a different type, so we can't make assumptions.
844+
precise_ret_type = False
845+
else:
846+
for s in subcls:
847+
if name in s.method_decls:
848+
m = s.method_decls[name]
849+
if (
850+
m.is_generator != fn_ir.is_generator
851+
or m.is_coroutine != fn_ir.is_coroutine
852+
):
853+
# Override is of a different kind, and the optimization
854+
# to use a precise generator return type doesn't work.
855+
precise_ret_type = False
856+
else:
857+
class_ir = None
858+
859+
if precise_ret_type:
860+
# Give a more precise type for generators, so that we can optimize
861+
# code that uses them. They return a generator object, which has a
862+
# specific class. Without this, the type would have to be 'object'.
863+
fn_ir.sig.ret_type = RInstance(gen_ir)
864+
if fn_ir.bound_sig:
865+
fn_ir.bound_sig.ret_type = RInstance(gen_ir)
866+
if class_ir is not None:
867+
if class_ir.is_method_final(name):
868+
gen_ir.is_final_class = True
869+
generator_methods.append((name, class_ir, gen_ir))
870+
871+
new_bases = {}
872+
873+
for name, class_ir, gen in generator_methods:
874+
# For generator methods, we need to have subclass generator classes inherit from
875+
# baseclass generator classes when there are overrides to maintain LSP.
876+
base = class_ir.real_base()
877+
if base is not None:
878+
if base.has_method(name):
879+
base_sig = base.method_sig(name)
880+
if isinstance(base_sig.ret_type, RInstance):
881+
base_gen = base_sig.ret_type.class_ir
882+
new_bases[gen] = base_gen
883+
884+
# Add generator inheritance relationships by adjusting MROs.
885+
for deriv, base in new_bases.items():
886+
if base.children is not None:
887+
base.children.append(deriv)
888+
while True:
889+
deriv.mro.append(base)
890+
deriv.base_mro.append(base)
891+
if base not in new_bases:
892+
break
893+
base = new_bases[base]

mypyc/test-data/run-async.test

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,3 +1291,77 @@ class CancelledError(Exception): ...
12911291
def run(x: object) -> object: ...
12921292
def get_running_loop() -> Any: ...
12931293
def create_task(x: object) -> Any: ...
1294+
1295+
[case testAsyncInheritance1]
1296+
from typing import final, Coroutine, Any, TypeVar
1297+
1298+
import asyncio
1299+
1300+
class Base1:
1301+
async def foo(self) -> int:
1302+
return 1
1303+
1304+
class Derived1(Base1):
1305+
async def foo(self) -> int:
1306+
return await super().foo() + 1
1307+
1308+
async def base1_foo(b: Base1) -> int:
1309+
return await b.foo()
1310+
1311+
async def derived1_foo(b: Derived1) -> int:
1312+
return await b.foo()
1313+
1314+
def test_async_inheritance() -> None:
1315+
assert asyncio.run(base1_foo(Base1())) == 1
1316+
assert asyncio.run(base1_foo(Derived1())) == 2
1317+
assert asyncio.run(derived1_foo(Derived1())) == 2
1318+
1319+
@final
1320+
class FinalClass:
1321+
async def foo(self) -> int:
1322+
return 3
1323+
1324+
async def final_class_foo(b: FinalClass) -> int:
1325+
return await b.foo()
1326+
1327+
def test_final_class() -> None:
1328+
assert asyncio.run(final_class_foo(FinalClass())) == 3
1329+
1330+
class Base2:
1331+
async def foo(self) -> int:
1332+
return 4
1333+
1334+
async def bar(self) -> int:
1335+
return 5
1336+
1337+
class Derived2(Base2):
1338+
# Does not override "foo"
1339+
async def bar(self) -> int:
1340+
return 6
1341+
1342+
async def base2_foo(b: Base2) -> int:
1343+
return await b.foo()
1344+
1345+
def test_no_override() -> None:
1346+
assert asyncio.run(base2_foo(Base2())) == 4
1347+
assert asyncio.run(base2_foo(Derived2())) == 4
1348+
1349+
class Base3:
1350+
async def foo(self) -> int:
1351+
return 7
1352+
1353+
class Derived3(Base3):
1354+
def foo(self) -> Coroutine[Any, Any, int]:
1355+
async def inner() -> int:
1356+
return 8
1357+
return inner()
1358+
1359+
async def base3_foo(b: Base3) -> int:
1360+
return await b.foo()
1361+
1362+
def test_override_non_async() -> None:
1363+
assert asyncio.run(base3_foo(Base3())) == 7
1364+
assert asyncio.run(base3_foo(Derived3())) == 8
1365+
1366+
[file asyncio/__init__.pyi]
1367+
def run(x: object) -> object: ...

mypyc/test-data/run-generators.test

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,32 @@ def test_same_names() -> None:
907907
# matches the variable name in the input code, since internally it's generated
908908
# with a prefix.
909909
list(undefined())
910+
911+
[case testGeneratorInheritance]
912+
from typing import Iterator
913+
914+
class Base1:
915+
def foo(self) -> Iterator[int]:
916+
yield 1
917+
918+
class Derived1(Base1):
919+
def foo(self) -> Iterator[int]:
920+
yield 2
921+
yield 3
922+
923+
def base1_foo(b: Base1) -> list[int]:
924+
a = []
925+
for x in b.foo():
926+
a.append(x)
927+
return a
928+
929+
def derived1_foo(b: Derived1) -> list[int]:
930+
a = []
931+
for x in b.foo():
932+
a.append(x)
933+
return a
934+
935+
def test_generator_override() -> None:
936+
assert base1_foo(Base1()) == [1]
937+
assert base1_foo(Derived1()) == [2, 3]
938+
assert derived1_foo(Derived1()) == [2, 3]

0 commit comments

Comments
 (0)