Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: riscv: Add basic register allocator strategy #995

Merged
merged 29 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d13dc82
dialects/riscv: Add LabelOp to RISCV dialect
compor May 24, 2023
18bb1bd
dialects/riscv: Add optional regions to LabelOp
compor May 24, 2023
e202937
dialects: risc: Add assembly writer case for LabelOp
compor May 24, 2023
3a956a0
dialects: risc: Add test case for LabelOp assembly writer
compor May 24, 2023
8125bd5
dialects: risc: Add test case for LabelOp with region
compor May 24, 2023
e57fdc0
dialects: risc: Split test cases for LabelOp
compor May 24, 2023
38dd1ef
dialects: riscv: Add a naive register allocator strategy
adutilleul May 24, 2023
37f37c0
dialects: riscv: fix some formatting for regalloc
adutilleul May 24, 2023
986653c
dialects: riscv: add some basic filecheck for regalloc
adutilleul May 24, 2023
f6d7f7b
dialects/riscv: remove ctor
compor May 24, 2023
16c236e
dialects: riscv: simplify RISCVRegisterAllocation
adutilleul May 25, 2023
109290f
dialects/riscv: Add LabelOp to RISCV dialect
compor May 24, 2023
a3366cf
dialects/riscv: Add optional regions to LabelOp
compor May 24, 2023
ca0f2ce
dialects: risc: Add assembly writer case for LabelOp
compor May 24, 2023
cd71742
dialects: risc: Add test case for LabelOp assembly writer
compor May 24, 2023
0550202
dialects: risc: Add test case for LabelOp with region
compor May 24, 2023
061da8e
dialects: risc: Split test cases for LabelOp
compor May 24, 2023
8f6aa4a
dialects/riscv: Add docstring to LabelOp
compor May 24, 2023
4a6283b
dialects/riscv: Add filecheck tests
compor May 25, 2023
39e0e4a
dialects/riscv: Remove trivial f-strings
compor May 25, 2023
fa25a79
dialects/riscv: Use utility method for comment generation
compor May 25, 2023
7b0ca72
Merge remote-tracking branch 'origin/christos/riscv/labelop' into adu…
adutilleul May 25, 2023
06c477b
dialects: riscv: misc nitpicks in regalloc
adutilleul May 29, 2023
bd842f4
dialects: riscv: remove riscv-emu dep for regalloc tests
adutilleul May 29, 2023
a7ec22a
dialects: riscv: stick to list for `available_registers` in regalloc
adutilleul May 29, 2023
6777ce5
Merge branch 'main' into adutilleul/riscv/regalloc
adutilleul May 30, 2023
bd3e7a1
dialects: riscv: fix formatting in `test_regalloc.py`
adutilleul May 30, 2023
c72eb63
Merge branch 'main' into adutilleul/riscv/regalloc
adutilleul Jun 1, 2023
b3fff0a
dialects: riscv: fix type issue in regalloc filecheck test
adutilleul Jun 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
120 changes: 120 additions & 0 deletions docs/Toy/toy/tests/test_regalloc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from io import StringIO
from xdsl.builder import Builder
from xdsl.dialects import riscv
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext

from xdsl.transforms.riscv_register_allocation import (
RISCVRegisterAllocation,
)


def context() -> MLContext:
ctx = MLContext()
return ctx


def riscv_code(module: ModuleOp) -> str:
stream = StringIO()
riscv.print_assembly(module, stream)
return stream.getvalue()


# Handwritten riscv dialect code to test register allocation


@ModuleOp
@Builder.implicit_region
def simple_linear_riscv():
"""
The following riscv dialect IR is generated from the following C code:

int main() {
int a = 1;
int b = 2;
int c = a + b;
int d = a - b * c;
int f = a * b + c + d;
int g = a - b + c * d + f;
int h = a * b * c * d * f * g;
return a + b * c - d + f * g;
}

The goal of this test is to check that the register allocator is able to handle very simple linear code with no branching.
"""

@Builder.implicit_region
def text_region():
@Builder.implicit_region
def main_region() -> None:
zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res
v0 = riscv.AddiOp(zero, 1).rd
v1 = riscv.AddiOp(zero, 2).rd

v3 = riscv.AddOp(v1, v0).rd
v4 = riscv.MulOp(v3, v1).rd
v5 = riscv.SubOp(v0, v4).rd
v6 = riscv.MulOp(v1, v0).rd
v7 = riscv.AddOp(v6, v3).rd
v8 = riscv.AddOp(v7, v5).rd
v9 = riscv.SubOp(v0, v1).rd
v10 = riscv.MulOp(v5, v3).rd
v11 = riscv.AddOp(v9, v10).rd
v12 = riscv.AddOp(v11, v8).rd
v13 = riscv.AddOp(v4, v0).rd
v14 = riscv.SubOp(v13, v5).rd
v15 = riscv.MulOp(v12, v8).rd
v16 = riscv.AddOp(v14, v15).rd

riscv.MVOp(v16, rd=riscv.Registers.A0)
riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd
riscv.EcallOp()

riscv.LabelOp("main", main_region)

riscv.DirectiveOp(".text", None, text_region)


@ModuleOp
@Builder.implicit_region
def simple_linear_riscv_allocated():
"""
Register allocated version based on BlockNaive strategy of the code in simple_linear_riscv.
"""

@Builder.implicit_region
def text_region():
@Builder.implicit_region
def main_region() -> None:
zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res
v0 = riscv.AddiOp(zero, 1, rd=riscv.Registers.T6).rd
v1 = riscv.AddiOp(zero, 2, rd=riscv.Registers.T5).rd

v3 = riscv.AddOp(v1, v0, rd=riscv.Registers.T4).rd
v4 = riscv.MulOp(v3, v1, rd=riscv.Registers.T3).rd
v5 = riscv.SubOp(v0, v4, rd=riscv.Registers.S11).rd
v6 = riscv.MulOp(v1, v0, rd=riscv.Registers.S10).rd
v7 = riscv.AddOp(v6, v3, rd=riscv.Registers.S9).rd
v8 = riscv.AddOp(v7, v5, rd=riscv.Registers.S8).rd
v9 = riscv.SubOp(v0, v1, rd=riscv.Registers.S7).rd
v10 = riscv.MulOp(v5, v3, rd=riscv.Registers.S6).rd
v11 = riscv.AddOp(v9, v10, rd=riscv.Registers.S5).rd
v12 = riscv.AddOp(v11, v8, rd=riscv.Registers.S4).rd
v13 = riscv.AddOp(v4, v0, rd=riscv.Registers.S3).rd
v14 = riscv.SubOp(v13, v5, rd=riscv.Registers.S2).rd
v15 = riscv.MulOp(v12, v8, rd=riscv.Registers.A7).rd
v16 = riscv.AddOp(v14, v15, rd=riscv.Registers.A6).rd

riscv.MVOp(v16, rd=riscv.Registers.A0)
riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd
riscv.EcallOp()

riscv.LabelOp("main", main_region)

riscv.DirectiveOp(".text", None, text_region)


def test_allocate_simple_linear():
RISCVRegisterAllocation("BlockNaive").apply(context(), simple_linear_riscv)

assert riscv_code(simple_linear_riscv) == riscv_code(simple_linear_riscv_allocated)
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: xdsl-opt -p riscv-allocate-registers{allocation_strategy=BlockNaive} %s --print-op-generic | filecheck %s

"builtin.module"() ({
%0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
%1 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<s0>
%2 = "riscv.add"(%0, %1) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%3 = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg<>
%4 = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg<>
%5 = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg<>
%6 = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg<>
%7 = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg<>
%8 = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg<>
%9 = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg<>
%10 = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg<>
%11 = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg<>
%12 = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg<>
%13 = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg<>
%14 = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg<>
%15 = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg<>
%16 = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg<>
%17 = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg<>
%18 = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg<>
%19 = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg<>
%20 = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg<>
%21 = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg<>
%22 = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg<>
%23 = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg<>
%24 = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg<>
%25 = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg<>
%26 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
%27 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<>
%28 = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg<>
%29 = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg<>
%30 = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg<>
%31 = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<>
}) : () -> ()

// CHECK: "builtin.module"() ({
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<t6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<s0>
// CHECK-NEXT: %{{\d+}} = "riscv.add"(%{{\d+}}, %{{\d+}}) : (!riscv.reg<t6>, !riscv.reg<s0>) -> !riscv.reg<t5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg<t4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg<t3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg<s11>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg<s10>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg<s9>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg<s8>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg<s7>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg<s6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg<s5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg<s4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg<s3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg<s2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg<a6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg<a5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg<a4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg<a3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg<a2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg<a1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg<s1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg<t2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg<t1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<t0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<ra>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg<j0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg<j1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg<j2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<j3>
// CHECK-NEXT: }) : () -> ()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt -p riscv-allocate-registers %s --print-op-generic | filecheck %s
// RUN: xdsl-opt -p riscv-allocate-registers{allocation_strategy=GlobalJRegs} %s --print-op-generic | filecheck %s

"builtin.module"() ({
%0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
Expand Down
87 changes: 81 additions & 6 deletions xdsl/transforms/riscv_register_allocation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,81 @@
from abc import ABC
from dataclasses import dataclass
from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.riscv import Register, RegisterType, RISCVOp
from xdsl.ir import MLContext
from xdsl.passes import ModulePass


class RegisterAllocator:
class AbstractRegisterAllocator(ABC):
"""
Base class for register allocation strategies.
"""

def __init__(self) -> None:
pass

def allocate_registers(self, module: ModuleOp) -> None:
"""
Allocates unallocated registers in the module.
"""

raise NotImplementedError()


class RegisterAllocatorBlockNaive(AbstractRegisterAllocator):
idx: int

def __init__(self) -> None:
self.idx = 0

"""
Since we've got neither right now a handling of a consistent ABI nor of a calling convention,
let's just assume that we have all the registers available for our use except the one explicitly reserved by the default riscv ABI.
"""

self.available_registers = list(Register.ABI_INDEX_BY_NAME.keys())
superlopuh marked this conversation as resolved.
Show resolved Hide resolved
reserved_registers = set(["zero", "sp", "gp", "tp", "fp", "s0"])
self.available_registers = [
reg for reg in self.available_registers if reg not in reserved_registers
]

def allocate_registers(self, module: ModuleOp) -> None:
"""
Allocates unallocated registers in the module. Currently sets them to an infinite set
of `j` registers.
Sets unallocated registers for each block to a finite set of real available registers.
When it runs out of real registers for a block, it allocates j registers.
"""

for region in module.regions:
for block in region.blocks:
block_registers = self.available_registers.copy()

for op in block.walk():
if not isinstance(op, RISCVOp):
# Don't perform register allocations on non-RISCV-ops
continue

for result in op.results:
assert isinstance(result.typ, RegisterType)
if result.typ.data.name is None:
# If we run out of real registers, allocate a j register
if not block_registers:
result.typ = RegisterType(Register(f"j{self.idx}"))
self.idx += 1
else:
result.typ = RegisterType(
Register(block_registers.pop())
)


class RegisterAllocatorJRegs(AbstractRegisterAllocator):
idx: int

def __init__(self) -> None:
self.idx = 0

def allocate_registers(self, module: ModuleOp) -> None:
"""
Sets unallocated registers to an infinite set of `j` registers
"""
for op in module.walk():
if not isinstance(op, RISCVOp):
Expand All @@ -27,14 +89,27 @@ def allocate_registers(self, module: ModuleOp) -> None:
self.idx += 1


@dataclass
class RISCVRegisterAllocation(ModulePass):
"""
Allocates unallocated registers in the module. Currently sets them to an infinite set
of `j` registers.
Allocates unallocated registers in the module.
"""

name = "riscv-allocate-registers"

allocation_strategy: str = "GlobalJRegs"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
allocator = RegisterAllocator()
allocator_strategies = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"GlobalJRegs": RegisterAllocatorJRegs,
"BlockNaive": RegisterAllocatorBlockNaive,
}

if self.allocation_strategy not in allocator_strategies:
raise ValueError(
f"Unknown register allocation strategy {self.allocation_strategy}. "
f"Available allocation types: {allocator_strategies.keys()}"
)

allocator = allocator_strategies[self.allocation_strategy]()
allocator.allocate_registers(op)