Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
TYPE_PREFIX,
)
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
from mypyc.ir.func_ir import FuncDecl, FuncIR, get_text_signature
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncDecl, FuncIR, get_text_signature
from mypyc.ir.ops import BasicBlock, Value
from mypyc.ir.rtypes import (
RInstance,
Expand Down Expand Up @@ -1222,10 +1222,11 @@ def emit_cpyfunction_instance(
cfunc = f"(PyCFunction){cname}"
func_flags = "METH_FASTCALL | METH_KEYWORDS"
doc = f"PyDoc_STR({native_function_doc_initializer(fn)})"
has_self_arg = "true" if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD else "false"

code_flags = "CO_COROUTINE"
self.emit_line(
f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags});'
f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags}, {has_self_arg});'
)
self.emit_line(f"if (unlikely(!{wrapper_name}))")
self.emit_line(error_stmt)
Expand Down
2 changes: 1 addition & 1 deletion mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ typedef struct {

PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname,
PyCFunction func, int func_flags, const char *func_doc,
int first_line, int code_flags);
int first_line, int code_flags, bool has_self_arg);
PyObject* CPyFunction_get_name(PyObject *op, void *context);
int CPyFunction_set_name(PyObject *op, PyObject *value, void *context);
PyObject* CPyFunction_get_code(PyObject *op, void *context);
Expand Down
21 changes: 17 additions & 4 deletions mypyc/lib-rt/function_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,21 @@ static PyObject* CPyFunction_Vectorcall(PyObject *func, PyObject *const *args, s
PyCFunction meth = ((PyCFunctionObject *)f)->m_ml->ml_meth;

self = ((PyCFunctionObject *)f)->m_self;
if (!self) {
self = args[0];
args += 1;
nargs -= 1;
}
return ((_PyCFunctionFastWithKeywords)(void(*)(void))meth)(self, args, nargs, kwnames);
}


static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* name,
PyObject *module, PyObject* code) {
PyObject *module, PyObject* code, bool set_self) {
PyCFunctionObject *cf = (PyCFunctionObject *)op;
CPyFunction_weakreflist(op) = NULL;
cf->m_ml = ml;
cf->m_self = (PyObject *) op;
cf->m_self = set_self ? (PyObject *) op : NULL;

Py_XINCREF(module);
cf->m_module = module;
Expand Down Expand Up @@ -226,9 +231,10 @@ static PyMethodDef* CPyMethodDef_New(const char *name, PyCFunction func, int fla

PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname,
PyCFunction func, int func_flags, const char *func_doc,
int first_line, int code_flags) {
int first_line, int code_flags, bool has_self_arg) {
PyMethodDef *method = NULL;
PyObject *code = NULL, *op = NULL;
bool set_self = false;

if (!CPyFunctionType) {
CPyFunctionType = (PyTypeObject *)PyType_FromSpec(&CPyFunction_spec);
Expand All @@ -245,8 +251,15 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu
if (unlikely(!code)) {
goto err;
}

// Set m_self inside the function wrapper only if the wrapped function has no self arg
// to pass m_self as the self arg when the function is called.
// When the function has a self arg, it will come in the args vector passed to the
// vectorcall handler.
set_self = !has_self_arg;
op = (PyObject *)CPyFunction_Init(PyObject_GC_New(CPyFunction, CPyFunctionType),
method, PyUnicode_FromString(funcname), module, code);
method, PyUnicode_FromString(funcname), module,
code, set_self);
if (unlikely(!op)) {
goto err;
}
Expand Down
14 changes: 12 additions & 2 deletions mypyc/test-data/run-async.test
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ async def sleep(t: float) -> None: ...

[case testAsyncWith]
from testutil import async_val
from typing import Any

class async_ctx:
async def __aenter__(self) -> str:
Expand All @@ -242,15 +243,22 @@ async def async_with() -> str:
async with async_ctx() as x:
return await async_val("body")

async def async_with_vectorcall() -> str:
ctx: Any = async_ctx()
async with ctx:
return await async_val("vc")

[file driver.py]
from native import async_with
from native import async_with, async_with_vectorcall
from testutil import run_generator

yields, val = run_generator(async_with(), [None, 'x', None])
assert yields == ('enter', 'body', 'exit'), yields
assert val == 'x', val

yields, val = run_generator(async_with_vectorcall(), [None, 'x', None])
assert yields == ('enter', 'vc', 'exit'), yields
assert val == 'x', val

[case testAsyncReturn]
from testutil import async_val
Expand Down Expand Up @@ -1516,7 +1524,9 @@ def test_method() -> None:
assert str(T.returns_one_async).startswith("<function T.returns_one_async"), str(T.returns_one_async)

t = T()
assert asyncio.run(t.returns_one_async()) == 1
# Call through variable to make sure the call is through vectorcall and not optimized to a native call.
f: Any = t.returns_one_async
assert asyncio.run(f()) == 1

assert not is_coroutine(T.returns_two)
assert is_coroutine(T.returns_two_async)
Expand Down