
# Setup
For this to work be sure to add the project to your PYTHONPATH, e.g. `export PYTHONPATH=$PYTHONPATH:/path/to/xdsl/src`

In [12]:
from xdsl.parser import Parser
from xdsl.printer import Printer
import xdsl.dialects.arith as arith
import xdsl.dialects.scf as scf
import xdsl.dialects.func as func
import xdsl.dialects.builtin as builtin
from xdsl.elevate import *
from xdsl.immutable_ir import *
import sys
sys.path.append('${workspaceFolder}')
sys.path.append('${workspaceFolder}/src')

# MLContext, containing information about the registered dialects
ctx = MLContext()

# Init dialects
arith.Arith(ctx)
builtin.Builtin(ctx)
func.Func(ctx)
scf.Scf(ctx)


['/home/martin/development/phd/projects/xDSL/xdsl/docs', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/martin/.local/lib/python3.10/site-packages', '/usr/lib/python3.10/site-packages', '/home/martin/development/phd/projects/xDSL/xdsl', '/home/martin/development/phd/projects/xDSL/xdsl/src', '${workspaceFolder}', '${workspaceFolder}/src', '${workspaceFolder}', '${workspaceFolder}/src']


Scf(ctx=MLContext(_registeredOps={'arith.constant': <class 'xdsl.dialects.arith.Constant'>, 'arith.addi': <class 'xdsl.dialects.arith.Addi'>, 'arith.muli': <class 'xdsl.dialects.arith.Muli'>, 'arith.subi': <class 'xdsl.dialects.arith.Subi'>, 'arith.floordivsi': <class 'xdsl.dialects.arith.FloorDiviSI'>, 'arith.remsi': <class 'xdsl.dialects.arith.RemSI'>, 'arith.addf': <class 'xdsl.dialects.arith.Addf'>, 'arith.mulf': <class 'xdsl.dialects.arith.Mulf'>, 'arith.cmpi': <class 'xdsl.dialects.arith.Cmpi'>, 'arith.andi': <class 'xdsl.dialects.arith.AndI'>, 'arith.ori': <class 'xdsl.dialects.arith.OrI'>, 'arith.xori': <class 'xdsl.dialects.arith.XOrI'>, 'module': <class 'xdsl.dialects.builtin.ModuleOp'>, 'func.func': <class 'xdsl.dialects.func.FuncOp'>, 'func.call': <class 'xdsl.dialects.func.Call'>, 'func.return': <class 'xdsl.dialects.func.Return'>, 'scf.if': <class 'xdsl.dialects.scf.If'>, 'scf.yield': <class 'xdsl.dialects.scf.Yield'>, 'scf.condition': <class 'xdsl.dialects.scf.Condition'

In [13]:
IR = \
"""module() {
  func.func()["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%3 : !i32, %0 : !i32)
    func.return(%4 : !i32)
  }
}
"""
parser = Parser(ctx, IR)
module: Operation = parser.parse_op()

printer = Printer()
printer.print_op(module)



module() {
  func.func() ["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%3 : !i32, %0 : !i32)
    func.return(%4 : !i32)
  }
}


# Rewrites
- operate on immutableIR
- match a single root `op` and replace it
- no way to reference parents (Block, Region, op) of matched `op`

In [7]:
imm_module: IOp = get_immutable_copy(module)


@dataclass(frozen=True)
class AddZero(Strategy):

    def impl(self, op: IOp) -> RewriteResult:
        match op:
            # match an IOp with the only restriction that its result is an Integer
            case IOp(results=[IResult(typ=IntegerType() as type)]):
                new_ops = new_op(Addi, 
                                operands=[
                                    op,
                                    new_op(Constant,
                                        attributes={
                                            "value": IntegerAttr.from_int_and_width(0, 32)
                                        }, result_types=[type])
                                ], 
                                result_types=[type])
                return success(new_ops)
            case _:
                return failure(self)

rewrite_result = backwards(AddZero()).apply(imm_module)

printer = Printer()
print("before:")
printer.print_op(module)

print("after AddZero Rewrite:")
printer = Printer()
printer.print_op(rewrite_result.result_op.get_mutable_copy())

before:
module() {
  func.func() ["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%3 : !i32, %0 : !i32)
    func.return(%4 : !i32)
  }
}
after AddZero Rewrite:
module() {
  func.func() ["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%3 : !i32, %0 : !i32)
    %5 : !i32 = arith.constant() ["value" = 0 : !i32]
    %6 : !i32 = arith.addi(%4 : !i32, %5 : !i32)
    func.return(%6 : !i32)
  }
}


In [9]:
@dataclass(frozen=True)
class CommuteAdd(Strategy):

    def impl(self, op: IOp) -> RewriteResult:
        match op:
            case IOp(op_type=arith.Addi,
                     operands=[operand0, operand1]):
                new_ops = from_op(op, operands=[operand1, operand0])
                return success(new_ops)
            case _:
                return failure(self)

imm_module: IOp = get_immutable_copy(module)
rewrite_result = backwards(CommuteAdd()).apply(imm_module)

printer = Printer()
print("before:")
printer.print_op(module)

print("after AddZero Rewrite:")

printer = Printer()
printer.print_op(rewrite_result.result_op.get_mutable_copy())

before:
module() {
  func.func() ["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%3 : !i32, %0 : !i32)
    func.return(%4 : !i32)
  }
}
after AddZero Rewrite:
module() {
  func.func() ["type" = !fun<[], [!i32]>] {
    %0 : !i32 = arith.constant() ["value" = 4 : !i32]
    %1 : !i32 = arith.constant() ["value" = 2 : !i32]
    %2 : !i32 = arith.constant() ["value" = 1 : !i32]
    %3 : !i32 = arith.addi(%2 : !i32, %1 : !i32)
    %4 : !i32 = arith.addi(%0 : !i32, %3 : !i32)
    func.return(%4 : !i32)
  }
}
