diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index 827c70a0eb4db..8b6b1d52c988f 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -24,6 +24,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, IncRef, @@ -271,6 +272,9 @@ def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]: def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]: return self.visit_register_op(op) + def visit_get_element(self, op: GetElement) -> GenAndKill[T]: + return self.visit_register_op(op) + def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]: return self.visit_register_op(op) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index efa3cf046f2c4..238d416729241 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -22,6 +22,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, IncRef, @@ -449,6 +450,9 @@ def visit_load_mem(self, op: LoadMem) -> None: def visit_set_mem(self, op: SetMem) -> None: pass + def visit_get_element(self, op: GetElement) -> None: + pass + def visit_get_element_ptr(self, op: GetElementPtr) -> None: pass diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 8f46cbe3312bc..b2e837565307f 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -16,6 +16,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, InitStatic, @@ -179,6 +180,9 @@ def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill: def visit_load_mem(self, op: LoadMem) -> GenAndKill: return CLEAN + def visit_get_element(self, op: GetElement) -> GenAndKill: + return CLEAN + def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill: return CLEAN diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index 6164f0f5d026a..3f1bbab58895f 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -42,6 +42,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, IncRef, @@ -795,6 +796,12 @@ def visit_set_mem(self, op: SetMem) -> None: if dest != src: self.emit_line(f"*({dest_type} *){dest} = {src};") + def visit_get_element(self, op: GetElement) -> None: + dest = self.reg(op) + src = self.reg(op.src) + dest_type = self.ctype(op.type) + self.emit_line(f"{dest} = ({dest_type}){src}.{op.field};") + def visit_get_element_ptr(self, op: GetElementPtr) -> None: dest = self.reg(op) src = self.reg(op.src) diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index ff3e74bc32574..8140660361acd 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -337,6 +337,8 @@ def set_sources(self, new: list[Value]) -> None: (self.src,) = new def stolen(self) -> list[Value]: + if not self.dest.type.is_refcounted: + return [] return [self.src] def accept(self, visitor: OpVisitor[T]) -> T: @@ -1679,9 +1681,36 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_set_mem(self) +@final +class GetElement(RegisterOp): + """Get the value of a struct element from a struct value.""" + + error_kind = ERR_NEVER + is_borrowed = True + + def __init__(self, src: Value, field: str, line: int = -1) -> None: + super().__init__(line) + assert isinstance(src.type, RStruct) + self.type = src.type.field_type(field) + self.src = src + self.src_type = src.type + self.field = field + + def sources(self) -> list[Value]: + return [self.src] + + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + + def accept(self, visitor: OpVisitor[T]) -> T: + return visitor.visit_get_element(self) + + @final class GetElementPtr(RegisterOp): - """Get the address of a struct element. + """Get the address of a struct element from a pointer to a struct. + + If you have a struct value, use GetElement instead. Note that you may need to use KeepAlive to avoid the struct being freed, if it's reference counted, such as PyObject *. @@ -1691,6 +1720,7 @@ class GetElementPtr(RegisterOp): def __init__(self, src: Value, src_type: RType, field: str, line: int = -1) -> None: super().__init__(line) + assert not isinstance(src.type, RStruct) self.type = pointer_rprimitive self.src = src self.src_type = src_type @@ -2008,6 +2038,10 @@ def visit_load_mem(self, op: LoadMem) -> T: def visit_set_mem(self, op: SetMem) -> T: raise NotImplementedError + @abstractmethod + def visit_get_element(self, op: GetElement) -> T: + raise NotImplementedError + @abstractmethod def visit_get_element_ptr(self, op: GetElementPtr) -> T: raise NotImplementedError diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 67a71d961f8d5..d0db9f2460a1d 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -29,6 +29,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, IncRef, @@ -280,6 +281,9 @@ def visit_load_mem(self, op: LoadMem) -> str: def visit_set_mem(self, op: SetMem) -> str: return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type) + def visit_get_element(self, op: GetElement) -> str: + return self.format("%r = %r.%s", op, op.src, op.field) + def visit_get_element_ptr(self, op: GetElementPtr) -> str: return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type) diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 0905f20dfea1e..96f6f7c676f1a 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -865,6 +865,8 @@ def __init__(self, name: str, names: list[str], types: list[RType]) -> None: self.name = name self.names = names self.types = types + self.is_refcounted = any(t.is_refcounted for t in self.types) + # generate dummy names if len(self.names) < len(self.types): for i in range(len(self.types) - len(self.names)): @@ -872,6 +874,12 @@ def __init__(self, name: str, names: list[str], types: list[RType]) -> None: self.offsets, self.size = compute_aligned_offsets_and_size(types) self._ctype = name + def field_type(self, name: str) -> RType: + for n, t in zip(self.names, self.types): + if n == name: + return t + assert False, f"{self.name} has no field '{name}'" + def accept(self, visitor: RTypeVisitor[T]) -> T: return visitor.visit_rstruct(self) diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 4724e9d96fe83..0ae48b7c80d7a 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -20,6 +20,7 @@ FloatNeg, FloatOp, GetAttr, + GetElement, GetElementPtr, Goto, IncRef, @@ -212,6 +213,9 @@ def visit_load_mem(self, op: LoadMem) -> Value | None: def visit_set_mem(self, op: SetMem) -> Value | None: return self.add(op) + def visit_get_element(self, op: GetElement) -> Value | None: + return self.add(op) + def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None: return self.add(op) @@ -355,6 +359,9 @@ def visit_set_mem(self, op: SetMem) -> None: op.dest = self.fix_op(op.dest) op.src = self.fix_op(op.src) + def visit_get_element(self, op: GetElement) -> None: + op.src = self.fix_op(op.src) + def visit_get_element_ptr(self, op: GetElementPtr) -> None: op.src = self.fix_op(op.src)