Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Make tuple packing and unpacking more efficient #16022

Merged
merged 10 commits into from
Sep 12, 2023
4 changes: 4 additions & 0 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
Expand Down Expand Up @@ -272,6 +273,9 @@ def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_unborrow(self, op: Unborrow) -> GenAndKill[T]:
return self.visit_register_op(op)


class DefinedVisitor(BaseAnalysisVisitor[Value]):
"""Visitor for finding defined registers.
Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
Expand Down Expand Up @@ -422,3 +423,6 @@ def visit_load_address(self, op: LoadAddress) -> None:

def visit_keep_alive(self, op: KeepAlive) -> None:
pass

def visit_unborrow(self, op: Unborrow) -> None:
pass
4 changes: 4 additions & 0 deletions mypyc/analysis/selfleaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
)
Expand Down Expand Up @@ -184,6 +185,9 @@ def visit_load_address(self, op: LoadAddress) -> GenAndKill:
def visit_keep_alive(self, op: KeepAlive) -> GenAndKill:
return CLEAN

def visit_unborrow(self, op: Unborrow) -> GenAndKill:
return CLEAN

def check_register_op(self, op: RegisterOp) -> GenAndKill:
if any(src is self.self_reg for src in op.sources()):
return DIRTY
Expand Down
11 changes: 9 additions & 2 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
Expand Down Expand Up @@ -260,7 +261,6 @@ def visit_tuple_set(self, op: TupleSet) -> None:
else:
for i, item in enumerate(op.items):
self.emit_line(f"{dest}.f{i} = {self.reg(item)};")
self.emit_inc_ref(dest, tuple_type)

def visit_assign(self, op: Assign) -> None:
dest = self.reg(op.dest)
Expand Down Expand Up @@ -499,7 +499,8 @@ def visit_tuple_get(self, op: TupleGet) -> None:
dest = self.reg(op)
src = self.reg(op.src)
self.emit_line(f"{dest} = {src}.f{op.index};")
self.emit_inc_ref(dest, op.type)
if not op.is_borrowed:
self.emit_inc_ref(dest, op.type)

def get_dest_assign(self, dest: Value) -> str:
if not dest.is_void:
Expand Down Expand Up @@ -746,6 +747,12 @@ def visit_keep_alive(self, op: KeepAlive) -> None:
# This is a no-op.
pass

def visit_unborrow(self, op: Unborrow) -> None:
# This is a no-op that propagates the source value.
dest = self.reg(op)
src = self.reg(op.src)
self.emit_line(f"{dest} = {src};")

# Helpers

def label(self, label: BasicBlock) -> str:
Expand Down
67 changes: 65 additions & 2 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,9 @@ def __init__(self, items: list[Value], line: int) -> None:
def sources(self) -> list[Value]:
return self.items.copy()

def stolen(self) -> list[Value]:
return self.items.copy()

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_tuple_set(self)

Expand All @@ -801,13 +804,14 @@ class TupleGet(RegisterOp):

error_kind = ERR_NEVER

def __init__(self, src: Value, index: int, line: int = -1) -> None:
def __init__(self, src: Value, index: int, line: int = -1, *, borrow: bool = False) -> None:
super().__init__(line)
self.src = src
self.index = index
assert isinstance(src.type, RTuple), "TupleGet only operates on tuples"
assert index >= 0
self.type = src.type.types[index]
self.is_borrowed = borrow

def sources(self) -> list[Value]:
return [self.src]
Expand Down Expand Up @@ -1387,21 +1391,76 @@ class KeepAlive(RegisterOp):
If we didn't have "keep_alive x", x could be freed immediately
after taking the address of 'item', resulting in a read after free
on the second line.

If 'steal' is true, the value is considered to be stolen at
this op, i.e. it won't be decref'd. You need to ensure that
the value is freed otherwise, perhaps by using borrowing
followed by Unborrow.

Be careful with steal=True -- this can cause memory leaks.
"""

error_kind = ERR_NEVER

def __init__(self, src: list[Value]) -> None:
def __init__(self, src: list[Value], *, steal: bool = False) -> None:
assert src
self.src = src
self.steal = steal

def sources(self) -> list[Value]:
return self.src.copy()

def stolen(self) -> list[Value]:
if self.steal:
return self.src.copy()
return []

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_keep_alive(self)


class Unborrow(RegisterOp):
"""A no-op op to create a regular reference from a borrowed one.

Borrowed references can only be used temporarily and the reference
counts won't be managed. This value will be refcounted normally.

This is mainly useful if you split an aggregate value, such as
a tuple, into components using borrowed values (to avoid increfs),
and want to treat the components as sharing the original managed
reference. You'll also need to use KeepAlive with steal=True to
"consume" the original tuple reference:

# t is a 2-tuple
r0 = borrow t[0]
r1 = borrow t[1]
r2 = unborrow r0
r3 = unborrow r1
# now (r2, r3) represent the tuple as separate items, and the
# original tuple can be considered dead and available to be
# stolen
keep_alive steal t

Be careful with this -- this can easily cause double freeing.
"""

error_kind = ERR_NEVER

def __init__(self, src: Value) -> None:
assert src.is_borrowed
self.src = src
self.type = src.type

def sources(self) -> list[Value]:
return [self.src]

def stolen(self) -> list[Value]:
return []

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_unborrow(self)


@trait
class OpVisitor(Generic[T]):
"""Generic visitor over ops (uses the visitor design pattern)."""
Expand Down Expand Up @@ -1548,6 +1607,10 @@ def visit_load_address(self, op: LoadAddress) -> T:
def visit_keep_alive(self, op: KeepAlive) -> T:
raise NotImplementedError

@abstractmethod
def visit_unborrow(self, op: Unborrow) -> T:
raise NotImplementedError


# TODO: Should the following definition live somewhere else?

Expand Down
14 changes: 12 additions & 2 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
Expand Down Expand Up @@ -153,7 +154,7 @@ def visit_init_static(self, op: InitStatic) -> str:
return self.format("%s = %r :: %s", name, op.value, op.namespace)

def visit_tuple_get(self, op: TupleGet) -> str:
return self.format("%r = %r[%d]", op, op.src, op.index)
return self.format("%r = %s%r[%d]", op, self.borrow_prefix(op), op.src, op.index)

def visit_tuple_set(self, op: TupleSet) -> str:
item_str = ", ".join(self.format("%r", item) for item in op.items)
Expand Down Expand Up @@ -274,7 +275,16 @@ def visit_load_address(self, op: LoadAddress) -> str:
return self.format("%r = load_address %s", op, op.src)

def visit_keep_alive(self, op: KeepAlive) -> str:
return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
if op.steal:
steal = "steal "
else:
steal = ""
return self.format(
"keep_alive {}{}".format(steal, ", ".join(self.format("%r", v) for v in op.src))
)

def visit_unborrow(self, op: Unborrow) -> str:
return self.format("%r = unborrow %r", op, op.src)

# Helpers

Expand Down
3 changes: 3 additions & 0 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def goto_and_activate(self, block: BasicBlock) -> None:
self.goto(block)
self.activate_block(block)

def keep_alive(self, values: list[Value], *, steal: bool = False) -> None:
self.add(KeepAlive(values, steal=steal))

def push_error_handler(self, handler: BasicBlock | None) -> None:
self.error_handlers.append(handler)

Expand Down
23 changes: 23 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
Register,
Return,
TupleGet,
Unborrow,
Unreachable,
Value,
)
from mypyc.ir.rtypes import (
RInstance,
RTuple,
c_pyssize_t_rprimitive,
exc_rtuple,
is_tagged,
Expand Down Expand Up @@ -183,8 +185,29 @@ def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:

line = stmt.rvalue.line
rvalue_reg = builder.accept(stmt.rvalue)

if builder.non_function_scope() and stmt.is_final_def:
builder.init_final_static(first_lvalue, rvalue_reg)

# Special-case multiple assignments like 'x, y = expr' to reduce refcount ops.
if (
isinstance(first_lvalue, (TupleExpr, ListExpr))
and isinstance(rvalue_reg.type, RTuple)
and len(rvalue_reg.type.types) == len(first_lvalue.items)
and len(lvalues) == 1
and all(is_simple_lvalue(item) for item in first_lvalue.items)
and any(t.is_refcounted for t in rvalue_reg.type.types)
):
n = len(first_lvalue.items)
for i in range(n):
target = builder.get_assignment_target(first_lvalue.items[i])
rvalue_item = builder.add(TupleGet(rvalue_reg, i, borrow=True))
rvalue_item = builder.add(Unborrow(rvalue_item))
builder.assign(target, rvalue_item, line)
builder.builder.keep_alive([rvalue_reg], steal=True)
builder.flush_keep_alives()
return

for lvalue in lvalues:
target = builder.get_assignment_target(lvalue)
builder.assign(target, rvalue_reg, line)
Expand Down
29 changes: 16 additions & 13 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,16 @@ L0:
[case testMultipleAssignmentBasicUnpacking]
from typing import Tuple, Any

def from_tuple(t: Tuple[int, str]) -> None:
def from_tuple(t: Tuple[bool, None]) -> None:
x, y = t

def from_any(a: Any) -> None:
x, y = a
[out]
def from_tuple(t):
t :: tuple[int, str]
r0, x :: int
r1, y :: str
t :: tuple[bool, None]
r0, x :: bool
r1, y :: None
L0:
r0 = t[0]
x = r0
Expand Down Expand Up @@ -563,16 +563,19 @@ def from_any(a: Any) -> None:
[out]
def from_tuple(t):
t :: tuple[int, object]
r0 :: int
r1, x, r2 :: object
r3, y :: int
r0, r1 :: int
r2, x, r3, r4 :: object
r5, y :: int
L0:
r0 = t[0]
r1 = box(int, r0)
x = r1
r2 = t[1]
r3 = unbox(int, r2)
y = r3
r0 = borrow t[0]
r1 = unborrow r0
r2 = box(int, r1)
x = r2
r3 = borrow t[1]
r4 = unborrow r3
r5 = unbox(int, r4)
y = r5
keep_alive steal t
return 1
def from_any(a):
a, r0, r1 :: object
Expand Down