diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 778e14355d68..34837a73adbd 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -23,7 +23,7 @@ TYPE_PREFIX, ) from mypyc.ir.class_ir import ClassIR, all_concrete_classes -from mypyc.ir.func_ir import FuncDecl, FuncIR, get_text_signature +from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncDecl, FuncIR, get_text_signature from mypyc.ir.ops import BasicBlock, Value from mypyc.ir.rtypes import ( RInstance, @@ -1222,10 +1222,11 @@ def emit_cpyfunction_instance( cfunc = f"(PyCFunction){cname}" func_flags = "METH_FASTCALL | METH_KEYWORDS" doc = f"PyDoc_STR({native_function_doc_initializer(fn)})" + has_self_arg = "true" if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD else "false" code_flags = "CO_COROUTINE" self.emit_line( - f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags});' + f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags}, {has_self_arg});' ) self.emit_line(f"if (unlikely(!{wrapper_name}))") self.emit_line(error_stmt) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 5526508f5aca..a46a1d2df332 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -971,7 +971,7 @@ typedef struct { PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname, PyCFunction func, int func_flags, const char *func_doc, - int first_line, int code_flags); + int first_line, int code_flags, bool has_self_arg); PyObject* CPyFunction_get_name(PyObject *op, void *context); int CPyFunction_set_name(PyObject *op, PyObject *value, void *context); PyObject* CPyFunction_get_code(PyObject *op, void *context); diff --git a/mypyc/lib-rt/function_wrapper.c b/mypyc/lib-rt/function_wrapper.c index 578e4870968d..8348168776c2 100644 --- a/mypyc/lib-rt/function_wrapper.c +++ b/mypyc/lib-rt/function_wrapper.c @@ -177,16 +177,21 @@ static PyObject* CPyFunction_Vectorcall(PyObject *func, PyObject *const *args, s PyCFunction meth = ((PyCFunctionObject *)f)->m_ml->ml_meth; self = ((PyCFunctionObject *)f)->m_self; + if (!self) { + self = args[0]; + args += 1; + nargs -= 1; + } return ((_PyCFunctionFastWithKeywords)(void(*)(void))meth)(self, args, nargs, kwnames); } static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* name, - PyObject *module, PyObject* code) { + PyObject *module, PyObject* code, bool set_self) { PyCFunctionObject *cf = (PyCFunctionObject *)op; CPyFunction_weakreflist(op) = NULL; cf->m_ml = ml; - cf->m_self = (PyObject *) op; + cf->m_self = set_self ? (PyObject *) op : NULL; Py_XINCREF(module); cf->m_module = module; @@ -226,9 +231,10 @@ static PyMethodDef* CPyMethodDef_New(const char *name, PyCFunction func, int fla PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname, PyCFunction func, int func_flags, const char *func_doc, - int first_line, int code_flags) { + int first_line, int code_flags, bool has_self_arg) { PyMethodDef *method = NULL; PyObject *code = NULL, *op = NULL; + bool set_self = false; if (!CPyFunctionType) { CPyFunctionType = (PyTypeObject *)PyType_FromSpec(&CPyFunction_spec); @@ -245,8 +251,15 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu if (unlikely(!code)) { goto err; } + + // Set m_self inside the function wrapper only if the wrapped function has no self arg + // to pass m_self as the self arg when the function is called. + // When the function has a self arg, it will come in the args vector passed to the + // vectorcall handler. + set_self = !has_self_arg; op = (PyObject *)CPyFunction_Init(PyObject_GC_New(CPyFunction, CPyFunctionType), - method, PyUnicode_FromString(funcname), module, code); + method, PyUnicode_FromString(funcname), module, + code, set_self); if (unlikely(!op)) { goto err; } diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 95efc3492b60..c28a8175b951 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -228,6 +228,7 @@ async def sleep(t: float) -> None: ... [case testAsyncWith] from testutil import async_val +from typing import Any class async_ctx: async def __aenter__(self) -> str: @@ -242,15 +243,22 @@ async def async_with() -> str: async with async_ctx() as x: return await async_val("body") +async def async_with_vectorcall() -> str: + ctx: Any = async_ctx() + async with ctx: + return await async_val("vc") [file driver.py] -from native import async_with +from native import async_with, async_with_vectorcall from testutil import run_generator yields, val = run_generator(async_with(), [None, 'x', None]) assert yields == ('enter', 'body', 'exit'), yields assert val == 'x', val +yields, val = run_generator(async_with_vectorcall(), [None, 'x', None]) +assert yields == ('enter', 'vc', 'exit'), yields +assert val == 'x', val [case testAsyncReturn] from testutil import async_val @@ -1516,7 +1524,9 @@ def test_method() -> None: assert str(T.returns_one_async).startswith("