From dc75c193d64a4f7439a8a865c57cadc3f6cc9667 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 21 Jan 2023 13:23:51 -0500 Subject: [PATCH] [mypyc] Support __(r)divmod__ dunder --- mypyc/codegen/emitclass.py | 2 ++ mypyc/primitives/generic_ops.py | 11 +++++++++++ mypyc/test-data/fixtures/ir.py | 11 +++++++++++ mypyc/test-data/irbuild-any.test | 15 +++++++++++++++ mypyc/test-data/run-dunders.test | 5 +++++ 5 files changed, 44 insertions(+) diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 72e16345a325..79fdd9103371 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -82,6 +82,8 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: "__rtruediv__": ("nb_true_divide", generate_bin_op_wrapper), "__floordiv__": ("nb_floor_divide", generate_bin_op_wrapper), "__rfloordiv__": ("nb_floor_divide", generate_bin_op_wrapper), + "__divmod__": ("nb_divmod", generate_bin_op_wrapper), + "__rdivmod__": ("nb_divmod", generate_bin_op_wrapper), "__lshift__": ("nb_lshift", generate_bin_op_wrapper), "__rlshift__": ("nb_lshift", generate_bin_op_wrapper), "__rshift__": ("nb_rshift", generate_bin_op_wrapper), diff --git a/mypyc/primitives/generic_ops.py b/mypyc/primitives/generic_ops.py index f6817ad024b7..4f04608d11f3 100644 --- a/mypyc/primitives/generic_ops.py +++ b/mypyc/primitives/generic_ops.py @@ -75,6 +75,17 @@ priority=0, ) + +function_op( + name="builtins.divmod", + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyNumber_Divmod", + error_kind=ERR_MAGIC, + priority=0, +) + + for op, funcname in [ ("+=", "PyNumber_InPlaceAdd"), ("-=", "PyNumber_InPlaceSubtract"), diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 2f3c18e9c731..37aab1d826d7 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -8,6 +8,7 @@ T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) +T_contra = TypeVar('T_contra', contravariant=True) S = TypeVar('S') K = TypeVar('K') # for keys in mapping V = TypeVar('V') # for values in mapping @@ -15,6 +16,11 @@ class __SupportsAbs(Protocol[T_co]): def __abs__(self) -> T_co: pass +class __SupportsDivMod(Protocol[T_contra, T_co]): + def __divmod__(self, other: T_contra) -> T_co: ... + +class __SupportsRDivMod(Protocol[T_contra, T_co]): + def __rdivmod__(self, other: T_contra) -> T_co: ... class object: def __init__(self) -> None: pass @@ -42,6 +48,7 @@ def __pow__(self, n: int, modulo: Optional[int] = None) -> int: pass def __floordiv__(self, x: int) -> int: pass def __truediv__(self, x: float) -> float: pass def __mod__(self, x: int) -> int: pass + def __divmod__(self, x: float) -> Tuple[float, float]: pass def __neg__(self) -> int: pass def __pos__(self) -> int: pass def __abs__(self) -> int: pass @@ -307,6 +314,10 @@ def zip(x: Iterable[T], y: Iterable[S]) -> Iterator[Tuple[T, S]]: ... def zip(x: Iterable[T], y: Iterable[S], z: Iterable[V]) -> Iterator[Tuple[T, S, V]]: ... def eval(e: str) -> Any: ... def abs(x: __SupportsAbs[T]) -> T: ... +@overload +def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ... +@overload +def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ... def exit() -> None: ... def min(x: T, y: T) -> T: ... def max(x: T, y: T) -> T: ... diff --git a/mypyc/test-data/irbuild-any.test b/mypyc/test-data/irbuild-any.test index bcf9a1880635..8cc626100262 100644 --- a/mypyc/test-data/irbuild-any.test +++ b/mypyc/test-data/irbuild-any.test @@ -198,3 +198,18 @@ L0: b = r4 return 1 +[case testFunctionBasedOps] +def f() -> None: + a = divmod(5, 2) +[out] +def f(): + r0, r1, r2 :: object + r3, a :: tuple[float, float] +L0: + r0 = object 5 + r1 = object 2 + r2 = PyNumber_Divmod(r0, r1) + r3 = unbox(tuple[float, float], r2) + a = r3 + return 1 + diff --git a/mypyc/test-data/run-dunders.test b/mypyc/test-data/run-dunders.test index 0b156e5c3af8..23323c7244de 100644 --- a/mypyc/test-data/run-dunders.test +++ b/mypyc/test-data/run-dunders.test @@ -402,6 +402,9 @@ class C: def __floordiv__(self, y: int) -> int: return self.x + y + 30 + def __divmod__(self, y: int) -> int: + return self.x + y + 40 + def test_generic() -> None: a: Any = C() assert a + 3 == 8 @@ -417,11 +420,13 @@ def test_generic() -> None: assert a @ 3 == 18 assert a / 2 == 27 assert a // 2 == 37 + assert divmod(a, 2) == 47 def test_native() -> None: c = C() assert c + 3 == 8 assert c - 3 == 2 + assert divmod(c, 3) == 48 def test_error() -> None: a: Any = C()