From 4340e1ac36b7a04a8ffdf251d0dbc93bf6b956e0 Mon Sep 17 00:00:00 2001 From: "Ka Wing, Li" <68145845+kingiler@users.noreply.github.com> Date: Thu, 6 Jul 2023 11:41:32 +0000 Subject: [PATCH 1/3] Honour safe erase in rewriter.earse_op --- xdsl/pattern_rewriter.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index a18b43491b..a24e467abb 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -4,18 +4,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from types import UnionType -from typing import ( - Callable, - TypeVar, - Union, - get_args, - get_origin, - Iterable, - Sequence, -) +from typing import Callable, Iterable, Sequence, TypeVar, Union, get_args, get_origin from xdsl.dialects.builtin import ModuleOp -from xdsl.ir import Operation, Region, Block, BlockArgument, Attribute, SSAValue +from xdsl.ir import Attribute, Block, BlockArgument, Operation, Region, SSAValue from xdsl.rewriter import Rewriter @@ -154,7 +146,7 @@ def erase_op(self, op: Operation, safe_erase: bool = True): """ self.has_done_action = True if op == self.current_operation: - return self.erase_matched_op() + return self.erase_matched_op(safe_erase) if not self._can_modify_op(op): raise Exception( "PatternRewriter can only erase operations that are the matched operation" From bd59a50c6f9b5b2e80277e8d34543574d9d2116d Mon Sep 17 00:00:00 2001 From: "Ka Wing, Li" <68145845+kingiler@users.noreply.github.com> Date: Fri, 7 Jul 2023 13:31:43 +0000 Subject: [PATCH 2/3] tests: Add test case for erase_op --- tests/test_rewriter.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 85f876d47b..456ee4a07f 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -1,5 +1,6 @@ from typing import Callable +import pytest from conftest import assert_print_op from xdsl.dialects.arith import Addi, Arith, Constant @@ -441,3 +442,34 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: rewriter.replace_op(return_op, [new_op]) rewrite_and_compare(prog, expected, transformation) + + +# Test erase operation +def test_erase_op(): + prog = """\ +"builtin.module"() ({ + %0 = "arith.constant"() {"value" = 42 : i32} : () -> i32 + %1 = "arith.addi"(%0, %0) : (i32, i32) -> i32 +}) : () -> () +""" + + expected = """\ +"builtin.module"() ({ + %0 = "arith.addi"(%1, %1) : (i32, i32) -> i32 +}) : () -> () +""" + + def transformation_safe(module: ModuleOp, rewriter: Rewriter) -> None: + constant_op = module.ops.first + assert constant_op is not None + rewriter.erase_op(constant_op, safe_erase=True) + + def transformation_unsafe(module: ModuleOp, rewriter: Rewriter) -> None: + constant_op = module.ops.first + assert constant_op is not None + rewriter.erase_op(constant_op, safe_erase=False) + + with pytest.raises(Exception): + rewrite_and_compare(prog, expected, transformation_safe) + + rewrite_and_compare(prog, expected, transformation_unsafe) From a30871c468988b9e273e1911bb7294008dce36ff Mon Sep 17 00:00:00 2001 From: "Ka Wing, Li" <68145845+kingiler@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:43:35 +0000 Subject: [PATCH 3/3] tests: Swap test order --- tests/test_rewriter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 456ee4a07f..bf39f5c94e 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -469,7 +469,7 @@ def transformation_unsafe(module: ModuleOp, rewriter: Rewriter) -> None: assert constant_op is not None rewriter.erase_op(constant_op, safe_erase=False) + rewrite_and_compare(prog, expected, transformation_unsafe) + with pytest.raises(Exception): rewrite_and_compare(prog, expected, transformation_safe) - - rewrite_and_compare(prog, expected, transformation_unsafe)