diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 4f2f539118d7..12b5bc7f8f82 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -1437,6 +1437,10 @@ def add_function(self, func_ir: FuncIR, line: int) -> None: self.function_names.add(name) self.functions.append(func_ir) + def get_current_class_ir(self) -> ClassIR | None: + type_info = self.fn_info.fitem.info + return self.mapper.type_to_ir.get(type_info) + def gen_arg_defaults(builder: IRBuilder) -> None: """Generate blocks for arguments that have default values. diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 4409b1acff26..1f39b09c0995 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -57,7 +57,6 @@ from mypyc.ir.ops import ( Assign, BasicBlock, - Call, ComparisonOp, Integer, LoadAddress, @@ -98,7 +97,11 @@ join_formatted_strings, tokenizer_printf_style, ) -from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization +from mypyc.irbuild.specialize import ( + apply_function_specialization, + apply_method_specialization, + translate_object_new, +) from mypyc.primitives.bytes_ops import bytes_slice_op from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op from mypyc.primitives.generic_ops import iter_op, name_op @@ -473,35 +476,15 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe if callee.name in base.method_decls: break else: + if callee.name == "__new__": + result = translate_object_new(builder, expr, MemberExpr(callee.call, "__new__")) + if result: + return result if ir.is_ext_class and ir.builtin_base is None and not ir.inherits_python: if callee.name == "__init__" and len(expr.args) == 0: # Call translates to object.__init__(self), which is a # no-op, so omit the call. return builder.none() - elif callee.name == "__new__": - # object.__new__(cls) - assert ( - len(expr.args) == 1 - ), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}" - typ_arg = expr.args[0] - method_args = builder.fn_info.fitem.arg_names - if ( - isinstance(typ_arg, NameExpr) - and len(method_args) > 0 - and method_args[0] == typ_arg.name - ): - subtype = builder.accept(expr.args[0]) - return builder.add(Call(ir.setup, [subtype], expr.line)) - - if callee.name == "__new__": - call = "super().__new__()" - if not ir.is_ext_class: - builder.error(f"{call} not supported for non-extension classes", expr.line) - if ir.inherits_python: - builder.error( - f"{call} not supported for classes inheriting from non-native classes", - expr.line, - ) return translate_call(builder, expr, callee) decl = base.method_decl(callee.name) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 576b7a7ebffd..29820787d10c 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -30,12 +30,14 @@ NameExpr, RefExpr, StrExpr, + SuperExpr, TupleExpr, Var, ) from mypy.types import AnyType, TypeOfAny from mypyc.ir.ops import ( BasicBlock, + Call, Extend, Integer, RaiseStandardError, @@ -68,6 +70,7 @@ is_list_rprimitive, is_uint8_rprimitive, list_rprimitive, + object_rprimitive, set_rprimitive, str_rprimitive, uint8_rprimitive, @@ -1002,3 +1005,44 @@ def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value if isinstance(arg, (StrExpr, BytesExpr)) and len(arg.value) == 1: return Integer(ord(arg.value)) return None + + +@specialize_function("__new__", object_rprimitive) +def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: + fn = builder.fn_info + if fn.name != "__new__": + return None + + is_super_new = isinstance(expr.callee, SuperExpr) + is_object_new = ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, NameExpr) + and callee.expr.fullname == "builtins.object" + ) + if not (is_super_new or is_object_new): + return None + + ir = builder.get_current_class_ir() + if ir is None: + return None + + call = '"object.__new__()"' + if not ir.is_ext_class: + builder.error(f"{call} not supported for non-extension classes", expr.line) + return None + if ir.inherits_python: + builder.error( + f"{call} not supported for classes inheriting from non-native classes", expr.line + ) + return None + if len(expr.args) != 1: + builder.error(f"{call} supported only with 1 argument, got {len(expr.args)}", expr.line) + return None + + typ_arg = expr.args[0] + method_args = fn.fitem.arg_names + if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name: + subtype = builder.accept(expr.args[0]) + return builder.add(Call(ir.setup, [subtype], expr.line)) + + return None diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index eeeb40ac672f..c83c5550d059 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -33,6 +33,7 @@ ListExpr, Lvalue, MatchStmt, + NameExpr, OperatorAssignmentStmt, RaiseStmt, ReturnStmt, @@ -170,10 +171,43 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None: builder.nonlocal_control[-1].gen_return(builder, retval, stmt.line) +def check_unsupported_cls_assignment(builder: IRBuilder, stmt: AssignmentStmt) -> None: + fn = builder.fn_info + method_args = fn.fitem.arg_names + if fn.name != "__new__" or len(method_args) == 0: + return + + ir = builder.get_current_class_ir() + if ir is None or ir.inherits_python or not ir.is_ext_class: + return + + cls_arg = method_args[0] + + def flatten(lvalues: list[Expression]) -> list[Expression]: + flat = [] + for lvalue in lvalues: + if isinstance(lvalue, (TupleExpr, ListExpr)): + flat += flatten(lvalue.items) + else: + flat.append(lvalue) + return flat + + lvalues = flatten(stmt.lvalues) + + for lvalue in lvalues: + if isinstance(lvalue, NameExpr) and lvalue.name == cls_arg: + # Disallowed because it could break the transformation of object.__new__ calls + # inside __new__ methods. + builder.error( + f'Assignment to argument "{cls_arg}" in "__new__" method unsupported', stmt.line + ) + + def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None: lvalues = stmt.lvalues assert lvalues builder.disallow_class_assignments(lvalues, stmt.line) + check_unsupported_cls_assignment(builder, stmt) first_lvalue = lvalues[0] if stmt.type and isinstance(stmt.rvalue, TempNode): # This is actually a variable annotation without initializer. Don't generate diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 78ca7b68cefb..92857f525cca 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1662,6 +1662,7 @@ L0: [case testDunderNew] from __future__ import annotations +from typing import Any class Test: val: int @@ -1686,6 +1687,169 @@ class NewClassMethod: def fn2() -> NewClassMethod: return NewClassMethod.__new__(42) +class NotTransformed: + def __new__(cls, val: int) -> Any: + return super().__new__(str) + + def factory(cls: Any, val: int) -> Any: + cls = str + return super().__new__(cls) + +[out] +def Test.__new__(cls, val): + cls :: object + val :: int + r0, obj :: __main__.Test + r1 :: bool +L0: + r0 = __mypyc__Test_setup(cls) + obj = r0 + obj.val = val; r1 = is_error + return obj +def fn(): + r0 :: object + r1 :: __main__.Test +L0: + r0 = __main__.Test :: type + r1 = Test.__new__(r0, 84) + return r1 +def NewClassMethod.__new__(cls, val): + cls :: object + val :: int + r0, obj :: __main__.NewClassMethod + r1 :: bool +L0: + r0 = __mypyc__NewClassMethod_setup(cls) + obj = r0 + obj.val = val; r1 = is_error + return obj +def fn2(): + r0 :: object + r1 :: __main__.NewClassMethod +L0: + r0 = __main__.NewClassMethod :: type + r1 = NewClassMethod.__new__(r0, 84) + return r1 +def NotTransformed.__new__(cls, val): + cls :: object + val :: int + r0 :: object + r1 :: str + r2, r3 :: object + r4 :: object[2] + r5 :: object_ptr + r6 :: object + r7 :: str + r8, r9 :: object + r10 :: object[1] + r11 :: object_ptr + r12 :: object + r13 :: str +L0: + r0 = builtins :: module + r1 = 'super' + r2 = CPyObject_GetAttr(r0, r1) + r3 = __main__.NotTransformed :: type + r4 = [r3, cls] + r5 = load_address r4 + r6 = PyObject_Vectorcall(r2, r5, 2, 0) + keep_alive r3, cls + r7 = '__new__' + r8 = CPyObject_GetAttr(r6, r7) + r9 = load_address PyUnicode_Type + r10 = [r9] + r11 = load_address r10 + r12 = PyObject_Vectorcall(r8, r11, 1, 0) + keep_alive r9 + r13 = cast(str, r12) + return r13 +def NotTransformed.factory(cls, val): + cls :: object + val :: int + r0, r1 :: object + r2 :: str + r3, r4 :: object + r5 :: object[2] + r6 :: object_ptr + r7 :: object + r8 :: str + r9 :: object + r10 :: object[1] + r11 :: object_ptr + r12 :: object +L0: + r0 = load_address PyUnicode_Type + cls = r0 + r1 = builtins :: module + r2 = 'super' + r3 = CPyObject_GetAttr(r1, r2) + r4 = __main__.NotTransformed :: type + r5 = [r4, cls] + r6 = load_address r5 + r7 = PyObject_Vectorcall(r3, r6, 2, 0) + keep_alive r4, cls + r8 = '__new__' + r9 = CPyObject_GetAttr(r7, r8) + r10 = [cls] + r11 = load_address r10 + r12 = PyObject_Vectorcall(r9, r11, 1, 0) + keep_alive cls + return r12 + +[case testObjectDunderNew_64bit] +from __future__ import annotations +from mypy_extensions import mypyc_attr +from typing import Any + +class Test: + val: int + + def __new__(cls, val: int) -> Test: + obj = object.__new__(cls) + obj.val = val + return obj + +def fn() -> Test: + return Test.__new__(Test, 42) + +class NewClassMethod: + val: int + + @classmethod + def __new__(cls, val: int) -> NewClassMethod: + obj = object.__new__(cls) + obj.val = val + return obj + +def fn2() -> NewClassMethod: + return NewClassMethod.__new__(42) + +class NotTransformed: + def __new__(cls, val: int) -> Any: + return object.__new__(str) + + def factory(cls: Any, val: int) -> Any: + cls = str + return object.__new__(cls) + +@mypyc_attr(native_class=False) +class NonNative: + def __new__(cls: Any) -> Any: + cls = str + return cls("str") + +class InheritsPython(dict): + def __new__(cls: Any) -> Any: + cls = dict + return cls({}) + +class ObjectNewOutsideDunderNew: + def __init__(self) -> None: + object.__new__(ObjectNewOutsideDunderNew) + +def object_new_outside_class() -> None: + object.__new__(Test) + [out] def Test.__new__(cls, val): cls :: object @@ -1721,19 +1885,185 @@ L0: r0 = __main__.NewClassMethod :: type r1 = NewClassMethod.__new__(r0, 84) return r1 +def NotTransformed.__new__(cls, val): + cls :: object + val :: int + r0 :: object + r1 :: str + r2, r3 :: object + r4 :: str + r5 :: object[2] + r6 :: object_ptr + r7 :: object + r8 :: str +L0: + r0 = builtins :: module + r1 = 'object' + r2 = CPyObject_GetAttr(r0, r1) + r3 = load_address PyUnicode_Type + r4 = '__new__' + r5 = [r2, r3] + r6 = load_address r5 + r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0) + keep_alive r2, r3 + r8 = cast(str, r7) + return r8 +def NotTransformed.factory(cls, val): + cls :: object + val :: int + r0, r1 :: object + r2 :: str + r3 :: object + r4 :: str + r5 :: object[2] + r6 :: object_ptr + r7 :: object +L0: + r0 = load_address PyUnicode_Type + cls = r0 + r1 = builtins :: module + r2 = 'object' + r3 = CPyObject_GetAttr(r1, r2) + r4 = '__new__' + r5 = [r3, cls] + r6 = load_address r5 + r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0) + keep_alive r3, cls + return r7 +def __new___NonNative_obj.__get__(__mypyc_self__, instance, owner): + __mypyc_self__, instance, owner, r0 :: object + r1 :: bit + r2 :: object +L0: + r0 = load_address _Py_NoneStruct + r1 = instance == r0 + if r1 goto L1 else goto L2 :: bool +L1: + return __mypyc_self__ +L2: + r2 = PyMethod_New(__mypyc_self__, instance) + return r2 +def __new___NonNative_obj.__call__(__mypyc_self__, cls): + __mypyc_self__ :: __main__.__new___NonNative_obj + cls, r0 :: object + r1 :: str + r2 :: object[1] + r3 :: object_ptr + r4 :: object +L0: + r0 = load_address PyUnicode_Type + cls = r0 + r1 = 'str' + r2 = [r1] + r3 = load_address r2 + r4 = PyObject_Vectorcall(cls, r3, 1, 0) + keep_alive r1 + return r4 +def InheritsPython.__new__(cls): + cls, r0 :: object + r1 :: dict + r2 :: object[1] + r3 :: object_ptr + r4 :: object +L0: + r0 = load_address PyDict_Type + cls = r0 + r1 = PyDict_New() + r2 = [r1] + r3 = load_address r2 + r4 = PyObject_Vectorcall(cls, r3, 1, 0) + keep_alive r1 + return r4 +def ObjectNewOutsideDunderNew.__init__(self): + self :: __main__.ObjectNewOutsideDunderNew + r0 :: object + r1 :: str + r2, r3 :: object + r4 :: str + r5 :: object[2] + r6 :: object_ptr + r7 :: object +L0: + r0 = builtins :: module + r1 = 'object' + r2 = CPyObject_GetAttr(r0, r1) + r3 = __main__.ObjectNewOutsideDunderNew :: type + r4 = '__new__' + r5 = [r2, r3] + r6 = load_address r5 + r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0) + keep_alive r2, r3 + return 1 +def object_new_outside_class(): + r0 :: object + r1 :: str + r2, r3 :: object + r4 :: str + r5 :: object[2] + r6 :: object_ptr + r7 :: object +L0: + r0 = builtins :: module + r1 = 'object' + r2 = CPyObject_GetAttr(r0, r1) + r3 = __main__.Test :: type + r4 = '__new__' + r5 = [r2, r3] + r6 = load_address r5 + r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0) + keep_alive r2, r3 + return 1 [case testUnsupportedDunderNew] from __future__ import annotations from mypy_extensions import mypyc_attr +from typing import Any @mypyc_attr(native_class=False) class NonNative: def __new__(cls) -> NonNative: - return super().__new__(cls) # E: super().__new__() not supported for non-extension classes + return super().__new__(cls) # E: "object.__new__()" not supported for non-extension classes class InheritsPython(dict): def __new__(cls) -> InheritsPython: - return super().__new__(cls) # E: super().__new__() not supported for classes inheriting from non-native classes + return super().__new__(cls) # E: "object.__new__()" not supported for classes inheriting from non-native classes + +@mypyc_attr(native_class=False) +class NonNativeObjectNew: + def __new__(cls) -> NonNativeObjectNew: + return object.__new__(cls) # E: "object.__new__()" not supported for non-extension classes + +class InheritsPythonObjectNew(dict): + def __new__(cls) -> InheritsPythonObjectNew: + return object.__new__(cls) # E: "object.__new__()" not supported for classes inheriting from non-native classes + +class ClsAssignment: + def __new__(cls: Any) -> Any: + cls = str # E: Assignment to argument "cls" in "__new__" method unsupported + return super().__new__(cls) + +class ClsTupleAssignment: + def __new__(class_i_want: Any, val: int) -> Any: + class_i_want, val = dict, 1 # E: Assignment to argument "class_i_want" in "__new__" method unsupported + return object.__new__(class_i_want) + +class ClsListAssignment: + def __new__(cls: Any, val: str) -> Any: + [cls, val] = [object, "object"] # E: Assignment to argument "cls" in "__new__" method unsupported + return object.__new__(cls) + +class ClsNestedAssignment: + def __new__(cls: Any, val1: str, val2: int) -> Any: + [val1, [val2, cls]] = ["val1", [2, int]] # E: Assignment to argument "cls" in "__new__" method unsupported + return object.__new__(cls) + +class WrongNumberOfArgs: + def __new__(cls): + return super().__new__() # E: "object.__new__()" supported only with 1 argument, got 0 + +class WrongNumberOfArgsObjectNew: + def __new__(cls): + return object.__new__(cls, 1) # E: "object.__new__()" supported only with 1 argument, got 2 [case testClassWithFreeList] from mypy_extensions import mypyc_attr, trait diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 6c4ddc03887a..3d0250cd24ee 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -3493,6 +3493,156 @@ Add(0, 5)=5 running __new__ with 1 and 0 Add(1, 0)=1 +[case testObjectDunderNew] +from __future__ import annotations +from typing import Any, Union + +from testutil import assertRaises + +class Add: + l: IntLike + r: IntLike + + def __new__(cls, l: IntLike, r: IntLike) -> Any: + return ( + l if r == 0 else + r if l == 0 else + object.__new__(cls) + ) + + def __init__(self, l: IntLike, r: IntLike): + self.l = l + self.r = r + +IntLike = Union[int, Add] + +class RaisesException: + def __new__(cls, val: int) -> RaisesException: + if val == 0: + raise RuntimeError("Invalid value!") + return object.__new__(cls) + + def __init__(self, val: int) -> None: + self.val = val + +class ClsArgNotPassed: + def __new__(cls) -> Any: + return object.__new__(str) + +class SkipsBase(Add): + def __new__(cls) -> Any: + obj = object.__new__(cls) + obj.l = 0 + obj.r = 0 + return obj + +def test_dunder_new() -> None: + add_instance: Any = Add(1, 5) + assert type(add_instance) == Add + assert add_instance.l == 1 + assert add_instance.r == 5 + + # TODO: explicit types should not be needed but mypy does not use + # the return type of __new__ which makes mypyc add casts to Add. + right_int: Any = Add(0, 5) + assert type(right_int) == int + assert right_int == 5 + + left_int: Any = Add(1, 0) + assert type(left_int) == int + assert left_int == 1 + + with assertRaises(RuntimeError, "Invalid value!"): + _ = RaisesException(0) + + not_raised = RaisesException(1) + assert not_raised.val == 1 + + with assertRaises(TypeError, "object.__new__(str) is not safe, use str.__new__()"): + _ = ClsArgNotPassed() + + skip = SkipsBase.__new__(SkipsBase) + assert type(skip) == SkipsBase + assert skip.l == 0 + assert skip.r == 0 + +[case testObjectDunderNewInInterpreted] +from __future__ import annotations +from typing import Any, Union + +class Add: + l: IntLike + r: IntLike + + def __new__(cls, l: IntLike, r: IntLike) -> Any: + print(f'running __new__ with {l} and {r}') + + return ( + l if r == 0 else + r if l == 0 else + object.__new__(cls) + ) + + def __init__(self, l: IntLike, r: IntLike): + self.l = l + self.r = r + + def __repr__(self) -> str: + return f'({self.l} + {self.r})' + +IntLike = Union[int, Add] + +class RaisesException: + def __new__(cls, val: int) -> RaisesException: + if val == 0: + raise RuntimeError("Invalid value!") + return object.__new__(cls) + + def __init__(self, val: int) -> None: + self.val = val + +class ClsArgNotPassed: + def __new__(cls) -> Any: + return object.__new__(str) + +class SkipsBase(Add): + def __new__(cls) -> Any: + obj = object.__new__(cls) + obj.l = 0 + obj.r = 0 + return obj + +[file driver.py] +from native import Add, ClsArgNotPassed, RaisesException, SkipsBase + +from testutil import assertRaises + +print(f'{Add(1, 5)=}') +print(f'{Add(0, 5)=}') +print(f'{Add(1, 0)=}') + +with assertRaises(RuntimeError, "Invalid value!"): + raised = RaisesException(0) + +not_raised = RaisesException(1) +assert not_raised.val == 1 + +with assertRaises(TypeError, "object.__new__(str) is not safe, use str.__new__()"): + str_as_cls = ClsArgNotPassed() + +skip = SkipsBase.__new__(SkipsBase) +assert type(skip) == SkipsBase +assert skip.l == 0 +assert skip.r == 0 + +[out] +running __new__ with 1 and 5 +Add(1, 5)=(1 + 5) +running __new__ with 0 and 5 +Add(0, 5)=5 +running __new__ with 1 and 0 +Add(1, 0)=1 + [case testInheritedDunderNew] from __future__ import annotations from mypy_extensions import mypyc_attr @@ -3795,6 +3945,53 @@ assert t.generic == "{}" assert t.bitfield == 0x0C assert t.default == 10 +[case testUntransformedDunderNewCalls] +from testutil import assertRaises +from typing import Any + +class TestStrCls: + def __new__(cls): + return str.__new__(cls) + + @classmethod + def factory(cls): + return str.__new__(cls) + +class TestStrStr: + def __new__(cls): + return str.__new__(str) + + @classmethod + def factory(cls): + return str.__new__(str) + +class TestStrInt: + def __new__(cls): + return str.__new__(int) + + @classmethod + def factory(cls): + return str.__new__(int) + +def test_untransformed_dunder_new() -> None: + with assertRaises(TypeError, "str.__new__(TestStrCls): TestStrCls is not a subtype of str"): + i = TestStrCls() + + j: Any = TestStrStr() + assert j == "" + + with assertRaises(TypeError, "str.__new__(int): int is not a subtype of str"): + k = TestStrInt() + + with assertRaises(TypeError, "str.__new__(TestStrCls): TestStrCls is not a subtype of str"): + i = TestStrCls.factory() + + j = TestStrStr.factory() + assert j == "" + + with assertRaises(TypeError, "str.__new__(int): int is not a subtype of str"): + k = TestStrInt.factory() + [case testPerTypeFreeList] from __future__ import annotations