Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,10 @@ def add_function(self, func_ir: FuncIR, line: int) -> None:
self.function_names.add(name)
self.functions.append(func_ir)

def get_current_class_ir(self) -> ClassIR | None:
type_info = self.fn_info.fitem.info
return self.mapper.type_to_ir.get(type_info)


def gen_arg_defaults(builder: IRBuilder) -> None:
"""Generate blocks for arguments that have default values.
Expand Down
35 changes: 9 additions & 26 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
from mypyc.ir.ops import (
Assign,
BasicBlock,
Call,
ComparisonOp,
Integer,
LoadAddress,
Expand Down Expand Up @@ -98,7 +97,11 @@
join_formatted_strings,
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.irbuild.specialize import (
apply_function_specialization,
apply_method_specialization,
translate_object_new,
)
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, name_op
Expand Down Expand Up @@ -473,35 +476,15 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe
if callee.name in base.method_decls:
break
else:
if callee.name == "__new__":
result = translate_object_new(builder, expr, MemberExpr(callee.call, "__new__"))
if result:
return result
if ir.is_ext_class and ir.builtin_base is None and not ir.inherits_python:
if callee.name == "__init__" and len(expr.args) == 0:
# Call translates to object.__init__(self), which is a
# no-op, so omit the call.
return builder.none()
elif callee.name == "__new__":
# object.__new__(cls)
assert (
len(expr.args) == 1
), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}"
typ_arg = expr.args[0]
method_args = builder.fn_info.fitem.arg_names
if (
isinstance(typ_arg, NameExpr)
and len(method_args) > 0
and method_args[0] == typ_arg.name
):
subtype = builder.accept(expr.args[0])
return builder.add(Call(ir.setup, [subtype], expr.line))

if callee.name == "__new__":
call = "super().__new__()"
if not ir.is_ext_class:
builder.error(f"{call} not supported for non-extension classes", expr.line)
if ir.inherits_python:
builder.error(
f"{call} not supported for classes inheriting from non-native classes",
expr.line,
)
return translate_call(builder, expr, callee)

decl = base.method_decl(callee.name)
Expand Down
44 changes: 44 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
NameExpr,
RefExpr,
StrExpr,
SuperExpr,
TupleExpr,
Var,
)
from mypy.types import AnyType, TypeOfAny
from mypyc.ir.ops import (
BasicBlock,
Call,
Extend,
Integer,
RaiseStandardError,
Expand Down Expand Up @@ -68,6 +70,7 @@
is_list_rprimitive,
is_uint8_rprimitive,
list_rprimitive,
object_rprimitive,
set_rprimitive,
str_rprimitive,
uint8_rprimitive,
Expand Down Expand Up @@ -1002,3 +1005,44 @@ def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value
if isinstance(arg, (StrExpr, BytesExpr)) and len(arg.value) == 1:
return Integer(ord(arg.value))
return None


@specialize_function("__new__", object_rprimitive)
def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
fn = builder.fn_info
if fn.name != "__new__":
return None

is_super_new = isinstance(expr.callee, SuperExpr)
is_object_new = (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and callee.expr.fullname == "builtins.object"
)
if not (is_super_new or is_object_new):
return None

ir = builder.get_current_class_ir()
if ir is None:
return None

call = '"object.__new__()"'
if not ir.is_ext_class:
builder.error(f"{call} not supported for non-extension classes", expr.line)
return None
if ir.inherits_python:
builder.error(
f"{call} not supported for classes inheriting from non-native classes", expr.line
)
return None
if len(expr.args) != 1:
builder.error(f"{call} supported only with 1 argument, got {len(expr.args)}", expr.line)
return None

typ_arg = expr.args[0]
method_args = fn.fitem.arg_names
if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name:
subtype = builder.accept(expr.args[0])
return builder.add(Call(ir.setup, [subtype], expr.line))

return None
34 changes: 34 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ListExpr,
Lvalue,
MatchStmt,
NameExpr,
OperatorAssignmentStmt,
RaiseStmt,
ReturnStmt,
Expand Down Expand Up @@ -170,10 +171,43 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:
builder.nonlocal_control[-1].gen_return(builder, retval, stmt.line)


def check_unsupported_cls_assignment(builder: IRBuilder, stmt: AssignmentStmt) -> None:
fn = builder.fn_info
method_args = fn.fitem.arg_names
if fn.name != "__new__" or len(method_args) == 0:
return

ir = builder.get_current_class_ir()
if ir is None or ir.inherits_python or not ir.is_ext_class:
return

cls_arg = method_args[0]

def flatten(lvalues: list[Expression]) -> list[Expression]:
flat = []
for lvalue in lvalues:
if isinstance(lvalue, (TupleExpr, ListExpr)):
flat += flatten(lvalue.items)
else:
flat.append(lvalue)
return flat

lvalues = flatten(stmt.lvalues)

for lvalue in lvalues:
if isinstance(lvalue, NameExpr) and lvalue.name == cls_arg:
# Disallowed because it could break the transformation of object.__new__ calls
# inside __new__ methods.
builder.error(
f'Assignment to argument "{cls_arg}" in "__new__" method unsupported', stmt.line
)


def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
lvalues = stmt.lvalues
assert lvalues
builder.disallow_class_assignments(lvalues, stmt.line)
check_unsupported_cls_assignment(builder, stmt)
first_lvalue = lvalues[0]
if stmt.type and isinstance(stmt.rvalue, TempNode):
# This is actually a variable annotation without initializer. Don't generate
Expand Down
Loading
Loading