diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 85f876d47b..bf39f5c94e 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -1,5 +1,6 @@ from typing import Callable +import pytest from conftest import assert_print_op from xdsl.dialects.arith import Addi, Arith, Constant @@ -441,3 +442,34 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: rewriter.replace_op(return_op, [new_op]) rewrite_and_compare(prog, expected, transformation) + + +# Test erase operation +def test_erase_op(): + prog = """\ +"builtin.module"() ({ + %0 = "arith.constant"() {"value" = 42 : i32} : () -> i32 + %1 = "arith.addi"(%0, %0) : (i32, i32) -> i32 +}) : () -> () +""" + + expected = """\ +"builtin.module"() ({ + %0 = "arith.addi"(%1, %1) : (i32, i32) -> i32 +}) : () -> () +""" + + def transformation_safe(module: ModuleOp, rewriter: Rewriter) -> None: + constant_op = module.ops.first + assert constant_op is not None + rewriter.erase_op(constant_op, safe_erase=True) + + def transformation_unsafe(module: ModuleOp, rewriter: Rewriter) -> None: + constant_op = module.ops.first + assert constant_op is not None + rewriter.erase_op(constant_op, safe_erase=False) + + rewrite_and_compare(prog, expected, transformation_unsafe) + + with pytest.raises(Exception): + rewrite_and_compare(prog, expected, transformation_safe) diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 01f91192aa..0dc14368f1 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -146,7 +146,7 @@ def erase_op(self, op: Operation, safe_erase: bool = True): """ self.has_done_action = True if op == self.current_operation: - return self.erase_matched_op() + return self.erase_matched_op(safe_erase) if not self._can_modify_op(op): raise Exception( "PatternRewriter can only erase operations that are the matched operation"