From c94d8e3cd02e8be5e4594d84ec77c84b3faf7948 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sat, 9 Mar 2024 09:16:41 +0000 Subject: [PATCH] [mypyc] Provide an easier way to define IR-to-IR transforms (#16998) This makes it easy to define simple IR-to-IR transforms by subclassing `IRTansform` and overriding some visit methods. Add an implementation of a simple copy propagation optimization as an example. This will be used by the implementation of mypyc/mypyc#854, and this can also be used for various optimizations. The IR transform preserves the identities of ops that are not modified. This means that the old IR is no longer valid after the transform, but the transform can be fast since we don't need to allocate many objects if only a small subset of ops will be modified by a transform. --- mypyc/codegen/emitmodule.py | 13 +- mypyc/irbuild/builder.py | 6 +- mypyc/irbuild/ll_builder.py | 8 +- mypyc/test-data/opt-copy-propagation.test | 400 ++++++++++++++++++++++ mypyc/test/test_copy_propagation.py | 47 +++ mypyc/transform/copy_propagation.py | 94 +++++ mypyc/transform/ir_transform.py | 353 +++++++++++++++++++ 7 files changed, 904 insertions(+), 17 deletions(-) create mode 100644 mypyc/test-data/opt-copy-propagation.test create mode 100644 mypyc/test/test_copy_propagation.py create mode 100644 mypyc/transform/copy_propagation.py create mode 100644 mypyc/transform/ir_transform.py diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 6c0dfd43b9af..0035bd53188b 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -56,6 +56,7 @@ from mypyc.irbuild.prepare import load_type_map from mypyc.namegen import NameGenerator, exported_name from mypyc.options import CompilerOptions +from mypyc.transform.copy_propagation import do_copy_propagation from mypyc.transform.exceptions import insert_exception_handling from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.transform.uninit import insert_uninit_checks @@ -225,18 +226,16 @@ def compile_scc_to_ir( if errors.num_errors > 0: return modules - # Insert uninit checks. for module in modules.values(): for fn in module.functions: + # Insert uninit checks. insert_uninit_checks(fn) - # Insert exception handling. - for module in modules.values(): - for fn in module.functions: + # Insert exception handling. insert_exception_handling(fn) - # Insert refcount handling. - for module in modules.values(): - for fn in module.functions: + # Insert refcount handling. insert_ref_count_opcodes(fn) + # Perform copy propagation optimization. + do_copy_propagation(fn, compiler_options) return modules diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index f201a4737f89..52891d68e3b2 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -160,7 +160,7 @@ def __init__( options: CompilerOptions, singledispatch_impls: dict[FuncDef, list[RegisterImplInfo]], ) -> None: - self.builder = LowLevelIRBuilder(current_module, errors, mapper, options) + self.builder = LowLevelIRBuilder(errors, options) self.builders = [self.builder] self.symtables: list[dict[SymbolNode, SymbolTarget]] = [{}] self.runtime_args: list[list[RuntimeArg]] = [[]] @@ -1111,9 +1111,7 @@ def flatten_classes(self, arg: RefExpr | TupleExpr) -> list[ClassIR] | None: def enter(self, fn_info: FuncInfo | str = "") -> None: if isinstance(fn_info, str): fn_info = FuncInfo(name=fn_info) - self.builder = LowLevelIRBuilder( - self.current_module, self.errors, self.mapper, self.options - ) + self.builder = LowLevelIRBuilder(self.errors, self.options) self.builder.set_module(self.module_name, self.module_path) self.builders.append(self.builder) self.symtables.append({}) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index d1ea91476a66..45c06e11befd 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -126,7 +126,6 @@ short_int_rprimitive, str_rprimitive, ) -from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.util import concrete_arg_kind from mypyc.options import CompilerOptions from mypyc.primitives.bytes_ops import bytes_compare @@ -220,12 +219,8 @@ class LowLevelIRBuilder: - def __init__( - self, current_module: str, errors: Errors, mapper: Mapper, options: CompilerOptions - ) -> None: - self.current_module = current_module + def __init__(self, errors: Errors | None, options: CompilerOptions) -> None: self.errors = errors - self.mapper = mapper self.options = options self.args: list[Register] = [] self.blocks: list[BasicBlock] = [] @@ -2394,6 +2389,7 @@ def _create_dict(self, keys: list[Value], values: list[Value], line: int) -> Val return self.call_c(dict_new_op, [], line) def error(self, msg: str, line: int) -> None: + assert self.errors is not None, "cannot generate errors in this compiler phase" self.errors.error(msg, self.module_path, line) diff --git a/mypyc/test-data/opt-copy-propagation.test b/mypyc/test-data/opt-copy-propagation.test new file mode 100644 index 000000000000..49b80f4385fc --- /dev/null +++ b/mypyc/test-data/opt-copy-propagation.test @@ -0,0 +1,400 @@ +-- Test cases for copy propagation optimization. This also tests IR transforms in general, +-- as copy propagation was the first IR transform that was implemented. + +[case testCopyPropagationSimple] +def g() -> int: + return 1 + +def f() -> int: + y = g() + return y +[out] +def g(): +L0: + return 2 +def f(): + r0 :: int +L0: + r0 = g() + return r0 + +[case testCopyPropagationChain] +def f(x: int) -> int: + y = x + z = y + return z +[out] +def f(x): + x :: int +L0: + return x + +[case testCopyPropagationChainPartial] +def f(x: int) -> int: + y = x + z = y + x = 2 + return z +[out] +def f(x): + x, y :: int +L0: + y = x + x = 4 + return y + +[case testCopyPropagationChainBad] +def f(x: int) -> int: + y = x + z = y + y = 2 + return z +[out] +def f(x): + x, y, z :: int +L0: + y = x + z = y + y = 4 + return z + +[case testCopyPropagationMutatedSource1] +def f(x: int) -> int: + y = x + x = 1 + return y +[out] +def f(x): + x, y :: int +L0: + y = x + x = 2 + return y + +[case testCopyPropagationMutatedSource2] +def f() -> int: + z = 1 + y = z + z = 2 + return y +[out] +def f(): + z, y :: int +L0: + z = 2 + y = z + z = 4 + return y + +[case testCopyPropagationTooComplex] +def f(b: bool, x: int) -> int: + if b: + y = x + return y + else: + y = 1 + return y +[out] +def f(b, x): + b :: bool + x, y :: int +L0: + if b goto L1 else goto L2 :: bool +L1: + y = x + return y +L2: + y = 2 + return y + +[case testCopyPropagationArg] +def f(x: int) -> int: + x = 2 + return x +[out] +def f(x): + x :: int +L0: + x = 4 + return x + +[case testCopyPropagationPartiallyDefined1] +def f(b: bool) -> int: + if b: + x = 1 + y = x + return y +[out] +def f(b): + b :: bool + r0, x :: int + r1 :: bool + y :: int +L0: + r0 = :: int + x = r0 + if b goto L1 else goto L2 :: bool +L1: + x = 2 +L2: + if is_error(x) goto L3 else goto L4 +L3: + r1 = raise UnboundLocalError('local variable "x" referenced before assignment') + unreachable +L4: + y = x + return y + +-- The remaining test cases test basic IRTransform functionality and are not +-- all needed for testing copy propagation as such. + +[case testIRTransformBranch] +from mypy_extensions import i64 + +def f(x: bool) -> int: + y = x + if y: + return 1 + else: + return 2 +[out] +def f(x): + x :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + return 2 +L2: + return 4 + +[case testIRTransformAssignment] +def f(b: bool, x: int) -> int: + y = x + if b: + return y + else: + return 1 +[out] +def f(b, x): + b :: bool + x :: int +L0: + if b goto L1 else goto L2 :: bool +L1: + return x +L2: + return 2 + +[case testIRTransformRegisterOps1] +from __future__ import annotations +from typing import cast + +class C: + a: int + + def m(self, x: int) -> None: pass + +def get_attr(x: C) -> int: + y = x + return y.a + +def set_attr(x: C) -> None: + y = x + y.a = 1 + +def tuple_get(x: tuple[int, int]) -> int: + y = x + return y[0] + +def tuple_set(x: int, xx: int) -> tuple[int, int]: + y = x + z = xx + return y, z + +def call(x: int) -> int: + y = x + return call(y) + +def method_call(c: C, x: int) -> None: + y = x + c.m(y) + +def cast_op(x: object) -> str: + y = x + return cast(str, y) + +def box(x: int) -> object: + y = x + return y + +def unbox(x: object) -> int: + y = x + return cast(int, y) + +def call_c(x: list[str]) -> None: + y = x + y.append("x") + +def keep_alive(x: C) -> int: + y = x + return y.a + 1 +[out] +def C.m(self, x): + self :: __main__.C + x :: int +L0: + return 1 +def get_attr(x): + x :: __main__.C + r0 :: int +L0: + r0 = x.a + return r0 +def set_attr(x): + x :: __main__.C + r0 :: bool +L0: + x.a = 2; r0 = is_error + return 1 +def tuple_get(x): + x :: tuple[int, int] + r0 :: int +L0: + r0 = x[0] + return r0 +def tuple_set(x, xx): + x, xx :: int + r0 :: tuple[int, int] +L0: + r0 = (x, xx) + return r0 +def call(x): + x, r0 :: int +L0: + r0 = call(x) + return r0 +def method_call(c, x): + c :: __main__.C + x :: int + r0 :: None +L0: + r0 = c.m(x) + return 1 +def cast_op(x): + x :: object + r0 :: str +L0: + r0 = cast(str, x) + return r0 +def box(x): + x :: int + r0 :: object +L0: + r0 = box(int, x) + return r0 +def unbox(x): + x :: object + r0 :: int +L0: + r0 = unbox(int, x) + return r0 +def call_c(x): + x :: list + r0 :: str + r1 :: i32 + r2 :: bit +L0: + r0 = 'x' + r1 = PyList_Append(x, r0) + r2 = r1 >= 0 :: signed + return 1 +def keep_alive(x): + x :: __main__.C + r0, r1 :: int +L0: + r0 = borrow x.a + r1 = CPyTagged_Add(r0, 2) + keep_alive x + return r1 + +[case testIRTransformRegisterOps2] +from mypy_extensions import i32, i64 + +def truncate(x: i64) -> i32: + y = x + return i32(y) + +def extend(x: i32) -> i64: + y = x + return i64(y) + +def int_op(x: i64, xx: i64) -> i64: + y = x + z = xx + return y + z + +def comparison_op(x: i64, xx: i64) -> bool: + y = x + z = xx + return y == z + +def float_op(x: float, xx: float) -> float: + y = x + z = xx + return y + z + +def float_neg(x: float) -> float: + y = x + return -y + +def float_comparison_op(x: float, xx: float) -> bool: + y = x + z = xx + return y == z +[out] +def truncate(x): + x :: i64 + r0 :: i32 +L0: + r0 = truncate x: i64 to i32 + return r0 +def extend(x): + x :: i32 + r0 :: i64 +L0: + r0 = extend signed x: i32 to i64 + return r0 +def int_op(x, xx): + x, xx, r0 :: i64 +L0: + r0 = x + xx + return r0 +def comparison_op(x, xx): + x, xx :: i64 + r0 :: bit +L0: + r0 = x == xx + return r0 +def float_op(x, xx): + x, xx, r0 :: float +L0: + r0 = x + xx + return r0 +def float_neg(x): + x, r0 :: float +L0: + r0 = -x + return r0 +def float_comparison_op(x, xx): + x, xx :: float + r0 :: bit +L0: + r0 = x == xx + return r0 + +-- Note that transforms of these ops aren't tested here: +-- * LoadMem +-- * SetMem +-- * GetElementPtr +-- * LoadAddress +-- * Unborrow diff --git a/mypyc/test/test_copy_propagation.py b/mypyc/test/test_copy_propagation.py new file mode 100644 index 000000000000..c729e3d186c3 --- /dev/null +++ b/mypyc/test/test_copy_propagation.py @@ -0,0 +1,47 @@ +"""Runner for copy propagation optimization tests.""" + +from __future__ import annotations + +import os.path + +from mypy.errors import CompileError +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase +from mypyc.common import TOP_LEVEL_NAME +from mypyc.ir.pprint import format_func +from mypyc.options import CompilerOptions +from mypyc.test.testutil import ( + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + remove_comment_lines, + use_custom_builtins, +) +from mypyc.transform.copy_propagation import do_copy_propagation +from mypyc.transform.uninit import insert_uninit_checks + +files = ["opt-copy-propagation.test"] + + +class TestCopyPropagation(MypycDataSuite): + files = files + base_path = test_temp_dir + + def run_case(self, testcase: DataDrivenTestCase) -> None: + with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase): + expected_output = remove_comment_lines(testcase.output) + try: + ir = build_ir_for_single_file(testcase.input) + except CompileError as e: + actual = e.messages + else: + actual = [] + for fn in ir: + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): + continue + insert_uninit_checks(fn) + do_copy_propagation(fn, CompilerOptions()) + actual.extend(format_func(fn)) + + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/transform/copy_propagation.py b/mypyc/transform/copy_propagation.py new file mode 100644 index 000000000000..49de616f85a3 --- /dev/null +++ b/mypyc/transform/copy_propagation.py @@ -0,0 +1,94 @@ +"""Simple copy propagation optimization. + +Example input: + + x = f() + y = x + +The register x is redundant and we can directly assign its value to y: + + y = f() + +This can optimize away registers that are assigned to once. +""" + +from __future__ import annotations + +from mypyc.ir.func_ir import FuncIR +from mypyc.ir.ops import Assign, AssignMulti, LoadAddress, LoadErrorValue, Register, Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.options import CompilerOptions +from mypyc.sametype import is_same_type +from mypyc.transform.ir_transform import IRTransform + + +def do_copy_propagation(fn: FuncIR, options: CompilerOptions) -> None: + """Perform copy propagation optimization for fn.""" + + # Anything with an assignment count >1 will not be optimized + # here, as it would be require data flow analysis and we want to + # keep this simple and fast, at least until we've made data flow + # analysis much faster. + counts: dict[Value, int] = {} + replacements: dict[Value, Value] = {} + for arg in fn.arg_regs: + # Arguments are always assigned to initially + counts[arg] = 1 + + for block in fn.blocks: + for op in block.ops: + if isinstance(op, Assign): + c = counts.get(op.dest, 0) + counts[op.dest] = c + 1 + # Does this look like a supported assignment? + # TODO: Something needs LoadErrorValue assignments to be preserved? + if ( + c == 0 + and is_same_type(op.dest.type, op.src.type) + and not isinstance(op.src, LoadErrorValue) + ): + replacements[op.dest] = op.src + elif c == 1: + # Too many assignments -- don't replace this one + replacements.pop(op.dest, 0) + elif isinstance(op, AssignMulti): + # Copy propagation not supported for AssignMulti destinations + counts[op.dest] = 2 + replacements.pop(op.dest, 0) + elif isinstance(op, LoadAddress): + # We don't support taking the address of an arbitrary Value, + # so we'll need to preserve the operands of LoadAddress. + if isinstance(op.src, Register): + counts[op.src] = 2 + replacements.pop(op.src, 0) + + # Follow chains of propagation with more than one assignment. + for src, dst in list(replacements.items()): + if counts.get(dst, 0) > 1: + # Not supported + del replacements[src] + else: + while dst in replacements: + dst = replacements[dst] + if counts.get(dst, 0) > 1: + # Not supported + del replacements[src] + if src in replacements: + replacements[src] = dst + + builder = LowLevelIRBuilder(None, options) + transform = CopyPropagationTransform(builder, replacements) + transform.transform_blocks(fn.blocks) + fn.blocks = builder.blocks + + +class CopyPropagationTransform(IRTransform): + def __init__(self, builder: LowLevelIRBuilder, map: dict[Value, Value]) -> None: + super().__init__(builder) + self.op_map.update(map) + self.removed = set(map) + + def visit_assign(self, op: Assign) -> Value | None: + if op.dest in self.removed: + return None + return self.add(op) diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py new file mode 100644 index 000000000000..1bcfc8fb5feb --- /dev/null +++ b/mypyc/transform/ir_transform.py @@ -0,0 +1,353 @@ +"""Helpers for implementing generic IR to IR transforms.""" + +from __future__ import annotations + +from typing import Final, Optional + +from mypyc.ir.ops import ( + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + DecRef, + Extend, + FloatComparisonOp, + FloatNeg, + FloatOp, + GetAttr, + GetElementPtr, + Goto, + IncRef, + InitStatic, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + OpVisitor, + RaiseStandardError, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unborrow, + Unbox, + Unreachable, + Value, +) +from mypyc.irbuild.ll_builder import LowLevelIRBuilder + + +class IRTransform(OpVisitor[Optional[Value]]): + """Identity transform. + + Subclass and override to perform changes to IR. + + Subclass IRTransform and override any OpVisitor visit_* methods + that perform any IR changes. The default implementations implement + an identity transform. + + A visit method can return None to remove ops. In this case the + transform must ensure that no op uses the original removed op + as a source after the transform. + + You can retain old BasicBlock and op references in ops. The transform + will automatically patch these for you as needed. + """ + + def __init__(self, builder: LowLevelIRBuilder) -> None: + self.builder = builder + # Subclasses add additional op mappings here. A None value indicates + # that the op/register is deleted. + self.op_map: dict[Value, Value | None] = {} + + def transform_blocks(self, blocks: list[BasicBlock]) -> None: + """Transform basic blocks that represent a single function. + + The result of the transform will be collected at self.builder.blocks. + """ + block_map: dict[BasicBlock, BasicBlock] = {} + op_map = self.op_map + for block in blocks: + new_block = BasicBlock() + block_map[block] = new_block + self.builder.activate_block(new_block) + new_block.error_handler = block.error_handler + for op in block.ops: + new_op = op.accept(self) + if new_op is not op: + op_map[op] = new_op + + # Update all op/block references to point to the transformed ones. + patcher = PatchVisitor(op_map, block_map) + for block in self.builder.blocks: + for op in block.ops: + op.accept(patcher) + if block.error_handler is not None: + block.error_handler = block_map.get(block.error_handler, block.error_handler) + + def add(self, op: Op) -> Value: + return self.builder.add(op) + + def visit_goto(self, op: Goto) -> Value: + return self.add(op) + + def visit_branch(self, op: Branch) -> Value: + return self.add(op) + + def visit_return(self, op: Return) -> Value: + return self.add(op) + + def visit_unreachable(self, op: Unreachable) -> Value: + return self.add(op) + + def visit_assign(self, op: Assign) -> Value | None: + return self.add(op) + + def visit_assign_multi(self, op: AssignMulti) -> Value | None: + return self.add(op) + + def visit_load_error_value(self, op: LoadErrorValue) -> Value | None: + return self.add(op) + + def visit_load_literal(self, op: LoadLiteral) -> Value | None: + return self.add(op) + + def visit_get_attr(self, op: GetAttr) -> Value | None: + return self.add(op) + + def visit_set_attr(self, op: SetAttr) -> Value | None: + return self.add(op) + + def visit_load_static(self, op: LoadStatic) -> Value | None: + return self.add(op) + + def visit_init_static(self, op: InitStatic) -> Value | None: + return self.add(op) + + def visit_tuple_get(self, op: TupleGet) -> Value | None: + return self.add(op) + + def visit_tuple_set(self, op: TupleSet) -> Value | None: + return self.add(op) + + def visit_inc_ref(self, op: IncRef) -> Value | None: + return self.add(op) + + def visit_dec_ref(self, op: DecRef) -> Value | None: + return self.add(op) + + def visit_call(self, op: Call) -> Value | None: + return self.add(op) + + def visit_method_call(self, op: MethodCall) -> Value | None: + return self.add(op) + + def visit_cast(self, op: Cast) -> Value | None: + return self.add(op) + + def visit_box(self, op: Box) -> Value | None: + return self.add(op) + + def visit_unbox(self, op: Unbox) -> Value | None: + return self.add(op) + + def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None: + return self.add(op) + + def visit_call_c(self, op: CallC) -> Value | None: + return self.add(op) + + def visit_truncate(self, op: Truncate) -> Value | None: + return self.add(op) + + def visit_extend(self, op: Extend) -> Value | None: + return self.add(op) + + def visit_load_global(self, op: LoadGlobal) -> Value | None: + return self.add(op) + + def visit_int_op(self, op: IntOp) -> Value | None: + return self.add(op) + + def visit_comparison_op(self, op: ComparisonOp) -> Value | None: + return self.add(op) + + def visit_float_op(self, op: FloatOp) -> Value | None: + return self.add(op) + + def visit_float_neg(self, op: FloatNeg) -> Value | None: + return self.add(op) + + def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None: + return self.add(op) + + def visit_load_mem(self, op: LoadMem) -> Value | None: + return self.add(op) + + def visit_set_mem(self, op: SetMem) -> Value | None: + return self.add(op) + + def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None: + return self.add(op) + + def visit_load_address(self, op: LoadAddress) -> Value | None: + return self.add(op) + + def visit_keep_alive(self, op: KeepAlive) -> Value | None: + return self.add(op) + + def visit_unborrow(self, op: Unborrow) -> Value | None: + return self.add(op) + + +class PatchVisitor(OpVisitor[None]): + def __init__( + self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock] + ) -> None: + self.op_map: Final = op_map + self.block_map: Final = block_map + + def fix_op(self, op: Value) -> Value: + new = self.op_map.get(op, op) + assert new is not None, "use of removed op" + return new + + def fix_block(self, block: BasicBlock) -> BasicBlock: + return self.block_map.get(block, block) + + def visit_goto(self, op: Goto) -> None: + op.label = self.fix_block(op.label) + + def visit_branch(self, op: Branch) -> None: + op.value = self.fix_op(op.value) + op.true = self.fix_block(op.true) + op.false = self.fix_block(op.false) + + def visit_return(self, op: Return) -> None: + op.value = self.fix_op(op.value) + + def visit_unreachable(self, op: Unreachable) -> None: + pass + + def visit_assign(self, op: Assign) -> None: + op.src = self.fix_op(op.src) + + def visit_assign_multi(self, op: AssignMulti) -> None: + op.src = [self.fix_op(s) for s in op.src] + + def visit_load_error_value(self, op: LoadErrorValue) -> None: + pass + + def visit_load_literal(self, op: LoadLiteral) -> None: + pass + + def visit_get_attr(self, op: GetAttr) -> None: + op.obj = self.fix_op(op.obj) + + def visit_set_attr(self, op: SetAttr) -> None: + op.obj = self.fix_op(op.obj) + op.src = self.fix_op(op.src) + + def visit_load_static(self, op: LoadStatic) -> None: + pass + + def visit_init_static(self, op: InitStatic) -> None: + op.value = self.fix_op(op.value) + + def visit_tuple_get(self, op: TupleGet) -> None: + op.src = self.fix_op(op.src) + + def visit_tuple_set(self, op: TupleSet) -> None: + op.items = [self.fix_op(item) for item in op.items] + + def visit_inc_ref(self, op: IncRef) -> None: + op.src = self.fix_op(op.src) + + def visit_dec_ref(self, op: DecRef) -> None: + op.src = self.fix_op(op.src) + + def visit_call(self, op: Call) -> None: + op.args = [self.fix_op(arg) for arg in op.args] + + def visit_method_call(self, op: MethodCall) -> None: + op.obj = self.fix_op(op.obj) + op.args = [self.fix_op(arg) for arg in op.args] + + def visit_cast(self, op: Cast) -> None: + op.src = self.fix_op(op.src) + + def visit_box(self, op: Box) -> None: + op.src = self.fix_op(op.src) + + def visit_unbox(self, op: Unbox) -> None: + op.src = self.fix_op(op.src) + + def visit_raise_standard_error(self, op: RaiseStandardError) -> None: + if isinstance(op.value, Value): + op.value = self.fix_op(op.value) + + def visit_call_c(self, op: CallC) -> None: + op.args = [self.fix_op(arg) for arg in op.args] + + def visit_truncate(self, op: Truncate) -> None: + op.src = self.fix_op(op.src) + + def visit_extend(self, op: Extend) -> None: + op.src = self.fix_op(op.src) + + def visit_load_global(self, op: LoadGlobal) -> None: + pass + + def visit_int_op(self, op: IntOp) -> None: + op.lhs = self.fix_op(op.lhs) + op.rhs = self.fix_op(op.rhs) + + def visit_comparison_op(self, op: ComparisonOp) -> None: + op.lhs = self.fix_op(op.lhs) + op.rhs = self.fix_op(op.rhs) + + def visit_float_op(self, op: FloatOp) -> None: + op.lhs = self.fix_op(op.lhs) + op.rhs = self.fix_op(op.rhs) + + def visit_float_neg(self, op: FloatNeg) -> None: + op.src = self.fix_op(op.src) + + def visit_float_comparison_op(self, op: FloatComparisonOp) -> None: + op.lhs = self.fix_op(op.lhs) + op.rhs = self.fix_op(op.rhs) + + def visit_load_mem(self, op: LoadMem) -> None: + op.src = self.fix_op(op.src) + + def visit_set_mem(self, op: SetMem) -> None: + op.dest = self.fix_op(op.dest) + op.src = self.fix_op(op.src) + + def visit_get_element_ptr(self, op: GetElementPtr) -> None: + op.src = self.fix_op(op.src) + + def visit_load_address(self, op: LoadAddress) -> None: + if isinstance(op.src, LoadStatic): + new = self.fix_op(op.src) + assert isinstance(new, LoadStatic) + op.src = new + + def visit_keep_alive(self, op: KeepAlive) -> None: + op.src = [self.fix_op(s) for s in op.src] + + def visit_unborrow(self, op: Unborrow) -> None: + op.src = self.fix_op(op.src)