diff --git a/docs/Toy/toy/tests/test_regalloc.py b/docs/Toy/toy/tests/test_regalloc.py new file mode 100644 index 0000000000..7bab6369a0 --- /dev/null +++ b/docs/Toy/toy/tests/test_regalloc.py @@ -0,0 +1,120 @@ +from io import StringIO +from xdsl.builder import Builder +from xdsl.dialects import riscv +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import MLContext + +from xdsl.transforms.riscv_register_allocation import ( + RISCVRegisterAllocation, +) + + +def context() -> MLContext: + ctx = MLContext() + return ctx + + +def riscv_code(module: ModuleOp) -> str: + stream = StringIO() + riscv.print_assembly(module, stream) + return stream.getvalue() + + +# Handwritten riscv dialect code to test register allocation + + +@ModuleOp +@Builder.implicit_region +def simple_linear_riscv(): + """ + The following riscv dialect IR is generated from the following C code: + + int main() { + int a = 1; + int b = 2; + int c = a + b; + int d = a - b * c; + int f = a * b + c + d; + int g = a - b + c * d + f; + int h = a * b * c * d * f * g; + return a + b * c - d + f * g; + } + + The goal of this test is to check that the register allocator is able to handle very simple linear code with no branching. + """ + + @Builder.implicit_region + def text_region(): + @Builder.implicit_region + def main_region() -> None: + zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res + v0 = riscv.AddiOp(zero, 1).rd + v1 = riscv.AddiOp(zero, 2).rd + + v3 = riscv.AddOp(v1, v0).rd + v4 = riscv.MulOp(v3, v1).rd + v5 = riscv.SubOp(v0, v4).rd + v6 = riscv.MulOp(v1, v0).rd + v7 = riscv.AddOp(v6, v3).rd + v8 = riscv.AddOp(v7, v5).rd + v9 = riscv.SubOp(v0, v1).rd + v10 = riscv.MulOp(v5, v3).rd + v11 = riscv.AddOp(v9, v10).rd + v12 = riscv.AddOp(v11, v8).rd + v13 = riscv.AddOp(v4, v0).rd + v14 = riscv.SubOp(v13, v5).rd + v15 = riscv.MulOp(v12, v8).rd + v16 = riscv.AddOp(v14, v15).rd + + riscv.MVOp(v16, rd=riscv.Registers.A0) + riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd + riscv.EcallOp() + + riscv.LabelOp("main", main_region) + + riscv.DirectiveOp(".text", None, text_region) + + +@ModuleOp +@Builder.implicit_region +def simple_linear_riscv_allocated(): + """ + Register allocated version based on BlockNaive strategy of the code in simple_linear_riscv. + """ + + @Builder.implicit_region + def text_region(): + @Builder.implicit_region + def main_region() -> None: + zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res + v0 = riscv.AddiOp(zero, 1, rd=riscv.Registers.T6).rd + v1 = riscv.AddiOp(zero, 2, rd=riscv.Registers.T5).rd + + v3 = riscv.AddOp(v1, v0, rd=riscv.Registers.T4).rd + v4 = riscv.MulOp(v3, v1, rd=riscv.Registers.T3).rd + v5 = riscv.SubOp(v0, v4, rd=riscv.Registers.S11).rd + v6 = riscv.MulOp(v1, v0, rd=riscv.Registers.S10).rd + v7 = riscv.AddOp(v6, v3, rd=riscv.Registers.S9).rd + v8 = riscv.AddOp(v7, v5, rd=riscv.Registers.S8).rd + v9 = riscv.SubOp(v0, v1, rd=riscv.Registers.S7).rd + v10 = riscv.MulOp(v5, v3, rd=riscv.Registers.S6).rd + v11 = riscv.AddOp(v9, v10, rd=riscv.Registers.S5).rd + v12 = riscv.AddOp(v11, v8, rd=riscv.Registers.S4).rd + v13 = riscv.AddOp(v4, v0, rd=riscv.Registers.S3).rd + v14 = riscv.SubOp(v13, v5, rd=riscv.Registers.S2).rd + v15 = riscv.MulOp(v12, v8, rd=riscv.Registers.A7).rd + v16 = riscv.AddOp(v14, v15, rd=riscv.Registers.A6).rd + + riscv.MVOp(v16, rd=riscv.Registers.A0) + riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd + riscv.EcallOp() + + riscv.LabelOp("main", main_region) + + riscv.DirectiveOp(".text", None, text_region) + + +def test_allocate_simple_linear(): + RISCVRegisterAllocation("BlockNaive").apply(context(), simple_linear_riscv) + + assert riscv_code(simple_linear_riscv) == riscv_code(simple_linear_riscv_allocated) diff --git a/tests/filecheck/dialects/riscv/riscv_register_allocation_block_naive.mlir b/tests/filecheck/dialects/riscv/riscv_register_allocation_block_naive.mlir new file mode 100644 index 0000000000..a252d41d67 --- /dev/null +++ b/tests/filecheck/dialects/riscv/riscv_register_allocation_block_naive.mlir @@ -0,0 +1,71 @@ +// RUN: xdsl-opt -p riscv-allocate-registers{allocation_strategy=BlockNaive} %s --print-op-generic | filecheck %s + +"builtin.module"() ({ + %0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<> + %1 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg + %2 = "riscv.add"(%0, %1) : (!riscv.reg<>, !riscv.reg) -> !riscv.reg<> + %3 = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg<> + %4 = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg<> + %5 = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg<> + %6 = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg<> + %7 = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg<> + %8 = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg<> + %9 = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg<> + %10 = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg<> + %11 = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg<> + %12 = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg<> + %13 = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg<> + %14 = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg<> + %15 = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg<> + %16 = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg<> + %17 = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg<> + %18 = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg<> + %19 = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg<> + %20 = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg<> + %21 = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg<> + %22 = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg<> + %23 = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg<> + %24 = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg<> + %25 = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg<> + %26 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<> + %27 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<> + %28 = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg<> + %29 = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg<> + %30 = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg<> + %31 = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<> +}) : () -> () + +// CHECK: "builtin.module"() ({ +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.add"(%{{\d+}}, %{{\d+}}) : (!riscv.reg, !riscv.reg) -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg +// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg +// CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/riscv/riscv_register_allocation.mlir b/tests/filecheck/dialects/riscv/riscv_register_allocation_jregs.mlir similarity index 83% rename from tests/filecheck/dialects/riscv/riscv_register_allocation.mlir rename to tests/filecheck/dialects/riscv/riscv_register_allocation_jregs.mlir index 6fb55bbc24..c3aedbd93a 100644 --- a/tests/filecheck/dialects/riscv/riscv_register_allocation.mlir +++ b/tests/filecheck/dialects/riscv/riscv_register_allocation_jregs.mlir @@ -1,4 +1,4 @@ -// RUN: xdsl-opt -p riscv-allocate-registers %s --print-op-generic | filecheck %s +// RUN: xdsl-opt -p riscv-allocate-registers{allocation_strategy=GlobalJRegs} %s --print-op-generic | filecheck %s "builtin.module"() ({ %0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<> diff --git a/xdsl/transforms/riscv_register_allocation.py b/xdsl/transforms/riscv_register_allocation.py index 5908142087..f2357c421f 100644 --- a/xdsl/transforms/riscv_register_allocation.py +++ b/xdsl/transforms/riscv_register_allocation.py @@ -1,19 +1,81 @@ +from abc import ABC +from dataclasses import dataclass from xdsl.dialects.builtin import ModuleOp from xdsl.dialects.riscv import Register, RegisterType, RISCVOp from xdsl.ir import MLContext from xdsl.passes import ModulePass -class RegisterAllocator: +class AbstractRegisterAllocator(ABC): + """ + Base class for register allocation strategies. + """ + + def __init__(self) -> None: + pass + + def allocate_registers(self, module: ModuleOp) -> None: + """ + Allocates unallocated registers in the module. + """ + + raise NotImplementedError() + + +class RegisterAllocatorBlockNaive(AbstractRegisterAllocator): idx: int def __init__(self) -> None: self.idx = 0 + """ + Since we've got neither right now a handling of a consistent ABI nor of a calling convention, + let's just assume that we have all the registers available for our use except the one explicitly reserved by the default riscv ABI. + """ + + self.available_registers = list(Register.ABI_INDEX_BY_NAME.keys()) + reserved_registers = set(["zero", "sp", "gp", "tp", "fp", "s0"]) + self.available_registers = [ + reg for reg in self.available_registers if reg not in reserved_registers + ] + def allocate_registers(self, module: ModuleOp) -> None: """ - Allocates unallocated registers in the module. Currently sets them to an infinite set - of `j` registers. + Sets unallocated registers for each block to a finite set of real available registers. + When it runs out of real registers for a block, it allocates j registers. + """ + + for region in module.regions: + for block in region.blocks: + block_registers = self.available_registers.copy() + + for op in block.walk(): + if not isinstance(op, RISCVOp): + # Don't perform register allocations on non-RISCV-ops + continue + + for result in op.results: + assert isinstance(result.typ, RegisterType) + if result.typ.data.name is None: + # If we run out of real registers, allocate a j register + if not block_registers: + result.typ = RegisterType(Register(f"j{self.idx}")) + self.idx += 1 + else: + result.typ = RegisterType( + Register(block_registers.pop()) + ) + + +class RegisterAllocatorJRegs(AbstractRegisterAllocator): + idx: int + + def __init__(self) -> None: + self.idx = 0 + + def allocate_registers(self, module: ModuleOp) -> None: + """ + Sets unallocated registers to an infinite set of `j` registers """ for op in module.walk(): if not isinstance(op, RISCVOp): @@ -27,14 +89,27 @@ def allocate_registers(self, module: ModuleOp) -> None: self.idx += 1 +@dataclass class RISCVRegisterAllocation(ModulePass): """ - Allocates unallocated registers in the module. Currently sets them to an infinite set - of `j` registers. + Allocates unallocated registers in the module. """ name = "riscv-allocate-registers" + allocation_strategy: str = "GlobalJRegs" + def apply(self, ctx: MLContext, op: ModuleOp) -> None: - allocator = RegisterAllocator() + allocator_strategies = { + "GlobalJRegs": RegisterAllocatorJRegs, + "BlockNaive": RegisterAllocatorBlockNaive, + } + + if self.allocation_strategy not in allocator_strategies: + raise ValueError( + f"Unknown register allocation strategy {self.allocation_strategy}. " + f"Available allocation types: {allocator_strategies.keys()}" + ) + + allocator = allocator_strategies[self.allocation_strategy]() allocator.allocate_registers(op)