Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Support __pow__, __rpow__, and __ipow__ dunders #14616

Merged
merged 2 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
generate_dunder_wrapper,
generate_get_wrapper,
generate_hash_wrapper,
generate_ipow_wrapper,
generate_len_wrapper,
generate_richcompare_wrapper,
generate_set_del_item_wrapper,
Expand Down Expand Up @@ -109,6 +110,11 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"__ior__": ("nb_inplace_or", generate_dunder_wrapper),
"__ixor__": ("nb_inplace_xor", generate_dunder_wrapper),
"__imatmul__": ("nb_inplace_matrix_multiply", generate_dunder_wrapper),
# Ternary operations. (yes, really)
# These are special cased in generate_bin_op_wrapper().
"__pow__": ("nb_power", generate_bin_op_wrapper),
"__rpow__": ("nb_power", generate_bin_op_wrapper),
"__ipow__": ("nb_inplace_power", generate_ipow_wrapper),
}

AS_ASYNC_SLOT_DEFS: SlotTable = {
Expand Down
105 changes: 88 additions & 17 deletions mypyc/codegen/emitwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,32 @@ def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
return gen.wrapper_name()


def generate_ipow_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""Generate a wrapper for native __ipow__.

Since __ipow__ fills a ternary slot, but almost no one defines __ipow__ to take three
arguments, the wrapper needs to tweaked to force it to accept three arguments.
"""
gen = WrapperGenerator(cl, emitter)
gen.set_target(fn)
assert len(fn.args) in (2, 3), "__ipow__ should only take 2 or 3 arguments"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of arguments accepted by dunders mostly isn't checked by mypy or mypyc, so it looks like we can actually trip the assertion -- but that's a more general problem and doesn't need to be addressed now (mypyc/mypyc#949).

gen.arg_names = ["self", "exp", "mod"]
gen.emit_header()
gen.emit_arg_processing()
handle_third_pow_argument(
fn,
emitter,
gen,
if_unsupported=[
'PyErr_SetString(PyExc_TypeError, "__ipow__ takes 2 positional arguments but 3 were given");',
"return NULL;",
],
)
gen.emit_call()
gen.finish()
return gen.wrapper_name()


def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""Generates a wrapper for a native binary dunder method.

Expand All @@ -311,13 +337,16 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""
gen = WrapperGenerator(cl, emitter)
gen.set_target(fn)
gen.arg_names = ["left", "right"]
if fn.name in ("__pow__", "__rpow__"):
gen.arg_names = ["left", "right", "mod"]
else:
gen.arg_names = ["left", "right"]
wrapper_name = gen.wrapper_name()

gen.emit_header()
if fn.name not in reverse_op_methods and fn.name in reverse_op_method_names:
# There's only a reverse operator method.
generate_bin_op_reverse_only_wrapper(emitter, gen)
generate_bin_op_reverse_only_wrapper(fn, emitter, gen)
else:
rmethod = reverse_op_methods[fn.name]
fn_rev = cl.get_method(rmethod)
Expand All @@ -334,6 +363,7 @@ def generate_bin_op_forward_only_wrapper(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
gen.emit_call(not_implemented_handler="goto typefail;")
gen.emit_error_handling()
emitter.emit_label("typefail")
Expand All @@ -352,19 +382,16 @@ def generate_bin_op_forward_only_wrapper(
# if not isinstance(other, int):
# return NotImplemented
# ...
rmethod = reverse_op_methods[fn.name]
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], rmethod
)
)
generate_bin_op_reverse_dunder_call(fn, emitter, reverse_op_methods[fn.name])
gen.finish()


def generate_bin_op_reverse_only_wrapper(emitter: Emitter, gen: WrapperGenerator) -> None:
def generate_bin_op_reverse_only_wrapper(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
gen.arg_names = ["right", "left"]
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
gen.emit_call()
gen.emit_error_handling()
emitter.emit_label("typefail")
Expand All @@ -390,7 +417,14 @@ def generate_bin_op_both_wrappers(
)
)
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
gen.emit_call(not_implemented_handler="goto typefail;")
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail2;"])
# Ternary __rpow__ calls aren't a thing so immediately bail
# if ternary __pow__ returns NotImplemented.
if fn.name == "__pow__" and len(fn.args) == 3:
fwd_not_implemented_handler = "goto typefail2;"
else:
fwd_not_implemented_handler = "goto typefail;"
gen.emit_call(not_implemented_handler=fwd_not_implemented_handler)
gen.emit_error_handling()
emitter.emit_line("}")
emitter.emit_label("typefail")
Expand All @@ -402,22 +436,59 @@ def generate_bin_op_both_wrappers(
gen.set_target(fn_rev)
gen.arg_names = ["right", "left"]
gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False)
handle_third_pow_argument(fn_rev, emitter, gen, if_unsupported=["goto typefail2;"])
gen.emit_call()
gen.emit_error_handling()
emitter.emit_line("} else {")
emitter.emit_line(f"_Py_IDENTIFIER({fn_rev.name});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], fn_rev.name
)
)
generate_bin_op_reverse_dunder_call(fn, emitter, fn_rev.name)
emitter.emit_line("}")
emitter.emit_label("typefail2")
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
emitter.emit_line("return Py_NotImplemented;")
gen.finish()


def generate_bin_op_reverse_dunder_call(fn: FuncIR, emitter: Emitter, rmethod: str) -> None:
if fn.name in ("__pow__", "__rpow__"):
# Ternary pow() will never call the reverse dunder.
emitter.emit_line("if (obj_mod == Py_None) {")
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], rmethod
)
)
if fn.name in ("__pow__", "__rpow__"):
emitter.emit_line("} else {")
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
emitter.emit_line("return Py_NotImplemented;")
emitter.emit_line("}")


def handle_third_pow_argument(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator, *, if_unsupported: list[str]
) -> None:
if fn.name not in ("__pow__", "__rpow__", "__ipow__"):
return

if (fn.name in ("__pow__", "__ipow__") and len(fn.args) == 2) or fn.name == "__rpow__":
# If the power dunder only supports two arguments and the third
# argument (AKA mod) is set to a non-default value, simply bail.
#
# Importantly, this prevents any ternary __rpow__ calls from
# happening (as per the language specification).
emitter.emit_line("if (obj_mod != Py_None) {")
for line in if_unsupported:
emitter.emit_line(line)
emitter.emit_line("}")
# The slot wrapper will receive three arguments, but the call only
# supports two so make sure that the third argument isn't passed
# along. This is needed as two-argument __(i)pow__ is allowed and
# rather common.
if len(gen.arg_names) == 3:
gen.arg_names.pop()


RICHCOMPARE_OPS = {
"__lt__": "Py_LT",
"__gt__": "Py_GT",
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ CPyTagged CPyObject_Hash(PyObject *o);
PyObject *CPyObject_GetAttr3(PyObject *v, PyObject *name, PyObject *defl);
PyObject *CPyIter_Next(PyObject *iter);
PyObject *CPyNumber_Power(PyObject *base, PyObject *index);
PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index);
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


Expand Down
5 changes: 5 additions & 0 deletions mypyc/lib-rt/generic_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ PyObject *CPyNumber_Power(PyObject *base, PyObject *index)
return PyNumber_Power(base, index, Py_None);
}

PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index)
{
return PyNumber_InPlacePower(base, index, Py_None);
}

PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
PyObject *start_obj = CPyTagged_AsObject(start);
PyObject *end_obj = CPyTagged_AsObject(end);
Expand Down
27 changes: 19 additions & 8 deletions mypyc/primitives/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,25 @@
priority=0,
)

binary_op(
name="**",
arg_types=[object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name="CPyNumber_Power",
priority=0,
)
for op, c_function in (("**", "CPyNumber_Power"), ("**=", "CPyNumber_InPlacePower")):
binary_op(
name=op,
arg_types=[object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name=c_function,
priority=0,
)

for arg_count, c_function in ((2, "CPyNumber_Power"), (3, "PyNumber_Power")):
function_op(
name="builtins.pow",
arg_types=[object_rprimitive] * arg_count,
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name=c_function,
priority=0,
)

binary_op(
name="in",
Expand Down
22 changes: 22 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def __divmod__(self, other: T_contra) -> T_co: ...
class __SupportsRDivMod(Protocol[T_contra, T_co]):
def __rdivmod__(self, other: T_contra) -> T_co: ...

_M = TypeVar("_M", contravariant=True)

class __SupportsPow2(Protocol[T_contra, T_co]):
def __pow__(self, other: T_contra) -> T_co: ...

class __SupportsPow3NoneOnly(Protocol[T_contra, T_co]):
def __pow__(self, other: T_contra, modulo: None = ...) -> T_co: ...

class __SupportsPow3(Protocol[T_contra, _M, T_co]):
def __pow__(self, other: T_contra, modulo: _M) -> T_co: ...

__SupportsSomeKindOfPow = Union[
__SupportsPow2[Any, Any], __SupportsPow3NoneOnly[Any, Any] | __SupportsPow3[Any, Any, Any]
]

class object:
def __init__(self) -> None: pass
def __eq__(self, x: object) -> bool: pass
Expand Down Expand Up @@ -99,6 +114,7 @@ def __add__(self, n: float) -> float: pass
def __sub__(self, n: float) -> float: pass
def __mul__(self, n: float) -> float: pass
def __truediv__(self, n: float) -> float: pass
def __pow__(self, n: float) -> float: pass
def __neg__(self) -> float: pass
def __pos__(self) -> float: pass
def __abs__(self) -> float: pass
Expand Down Expand Up @@ -318,6 +334,12 @@ def abs(x: __SupportsAbs[T]) -> T: ...
def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ...
@overload
def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ...
@overload
def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
@overload
def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
@overload
def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ...
def exit() -> None: ...
def min(x: T, y: T) -> T: ...
def max(x: T, y: T) -> T: ...
Expand Down
25 changes: 25 additions & 0 deletions mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ L0:
[case testFunctionBasedOps]
def f() -> None:
a = divmod(5, 2)
def f2() -> int:
return pow(2, 5)
def f3() -> float:
return pow(2, 5, 3)
[out]
def f():
r0, r1, r2 :: object
Expand All @@ -212,4 +216,25 @@ L0:
r3 = unbox(tuple[float, float], r2)
a = r3
return 1
def f2():
r0, r1, r2 :: object
r3 :: int
L0:
r0 = object 2
r1 = object 5
r2 = CPyNumber_Power(r0, r1)
r3 = unbox(int, r2)
return r3
def f3():
r0, r1, r2, r3 :: object
r4 :: int
r5 :: object
L0:
r0 = object 2
r1 = object 5
r2 = object 3
r3 = PyNumber_Power(r0, r1, r2)
r4 = unbox(int, r3)
r5 = box(int, r4)
return r5

Loading