Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: riscv: Add basic register allocator strategy #995

Merged
merged 29 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d13dc82
dialects/riscv: Add LabelOp to RISCV dialect
compor May 24, 2023
18bb1bd
dialects/riscv: Add optional regions to LabelOp
compor May 24, 2023
e202937
dialects: risc: Add assembly writer case for LabelOp
compor May 24, 2023
3a956a0
dialects: risc: Add test case for LabelOp assembly writer
compor May 24, 2023
8125bd5
dialects: risc: Add test case for LabelOp with region
compor May 24, 2023
e57fdc0
dialects: risc: Split test cases for LabelOp
compor May 24, 2023
38dd1ef
dialects: riscv: Add a naive register allocator strategy
adutilleul May 24, 2023
37f37c0
dialects: riscv: fix some formatting for regalloc
adutilleul May 24, 2023
986653c
dialects: riscv: add some basic filecheck for regalloc
adutilleul May 24, 2023
f6d7f7b
dialects/riscv: remove ctor
compor May 24, 2023
16c236e
dialects: riscv: simplify RISCVRegisterAllocation
adutilleul May 25, 2023
109290f
dialects/riscv: Add LabelOp to RISCV dialect
compor May 24, 2023
a3366cf
dialects/riscv: Add optional regions to LabelOp
compor May 24, 2023
ca0f2ce
dialects: risc: Add assembly writer case for LabelOp
compor May 24, 2023
cd71742
dialects: risc: Add test case for LabelOp assembly writer
compor May 24, 2023
0550202
dialects: risc: Add test case for LabelOp with region
compor May 24, 2023
061da8e
dialects: risc: Split test cases for LabelOp
compor May 24, 2023
8f6aa4a
dialects/riscv: Add docstring to LabelOp
compor May 24, 2023
4a6283b
dialects/riscv: Add filecheck tests
compor May 25, 2023
39e0e4a
dialects/riscv: Remove trivial f-strings
compor May 25, 2023
fa25a79
dialects/riscv: Use utility method for comment generation
compor May 25, 2023
7b0ca72
Merge remote-tracking branch 'origin/christos/riscv/labelop' into adu…
adutilleul May 25, 2023
06c477b
dialects: riscv: misc nitpicks in regalloc
adutilleul May 29, 2023
bd842f4
dialects: riscv: remove riscv-emu dep for regalloc tests
adutilleul May 29, 2023
a7ec22a
dialects: riscv: stick to list for `available_registers` in regalloc
adutilleul May 29, 2023
6777ce5
Merge branch 'main' into adutilleul/riscv/regalloc
adutilleul May 30, 2023
bd3e7a1
dialects: riscv: fix formatting in `test_regalloc.py`
adutilleul May 30, 2023
c72eb63
Merge branch 'main' into adutilleul/riscv/regalloc
adutilleul Jun 1, 2023
b3fff0a
dialects: riscv: fix type issue in regalloc filecheck test
adutilleul Jun 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions docs/Toy/toy/emulator/emulator_iop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations
from typing import Optional

# pyright: reportMissingTypeStubs=false

from riscemu import RunConfig, UserModeCPU, RV32I, RV32M, AssemblyFileLoader, MMU
from riscemu.instructions import InstructionSet

from io import StringIO


def run_riscv(
code: str,
extensions: list[type[InstructionSet]] = [],
unlimited_regs: bool = False,
setup_stack: bool = False,
verbosity: int = 5,
) -> Optional[int]:
cfg = RunConfig(
debug_instruction=False,
verbosity=verbosity,
debug_on_exception=False,
unlimited_registers=unlimited_regs,
)

cpu = UserModeCPU([RV32I, RV32M, *extensions], cfg)

io = StringIO(code)

loader = AssemblyFileLoader.instantiate("example.asm", {})
assert isinstance(loader, AssemblyFileLoader)
cpu.load_program(loader.parse_io(io)) # pyright: ignore[reportUnknownMemberType]

mmu: MMU = getattr(cpu, "mmu")
try:
if setup_stack:
cpu.setup_stack(cfg.stack_size)
cpu.launch(mmu.programs[-1], verbosity > 1)
return cpu.exit_code
except Exception as ex:
print(ex)
return None
179 changes: 179 additions & 0 deletions docs/Toy/toy/tests/test_regalloc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from xdsl.builder import Builder
from xdsl.dialects import riscv
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.riscv_asm_writer import riscv_code

from xdsl.transforms.riscv_register_allocation import (
RISCVRegisterAllocation,
)

from ..emulator.emulator_iop import run_riscv

ALLOCATION_STRATEGIES = [
"GlobalJRegs",
"BlockNaive",
]


def context() -> MLContext:
ctx = MLContext()
return ctx


# Handwritten riscv dialect code to test register allocation


@ModuleOp
@Builder.implicit_region
def simple_branching_riscv():
"""
The following riscv dialect IR is generated from the following C code:

int main() {
int a = 5;
int b = 77777;
int c = 6;
if (a==b) {
c = a * a;
} else {
c = b + b;
}

return c;
}

The goal of this test is to check that the register allocator is able to handle very simple branching code with multiple basic blocks.
Morever it uses some reserved registers (ra, s0) to check that the register allocator does not use them.
"""

@Builder.implicit_region
def text_region():
@Builder.implicit_region
def main_region() -> None:
sp = riscv.GetRegisterOp(riscv.Registers.SP).res
riscv.AddiOp(sp, -32, rd=riscv.Registers.SP).rd
ra = riscv.GetRegisterOp(riscv.Registers.RA).res
riscv.SwOp(ra, sp, 28)
s0 = riscv.GetRegisterOp(riscv.Registers.S0).res
riscv.SwOp(s0, sp, 24)
a = riscv.AddiOp(sp, 32).rd
b = riscv.LiOp(5).rd
riscv.SwOp(b, a, -12)
c = riscv.LuiOp(19).rd
d = riscv.AddiOp(c, -47).rd
riscv.SwOp(d, a, -16)
e = riscv.LiOp(6).rd
riscv.SwOp(e, a, -20)
f = riscv.LwOp(a, -12).rd
g = riscv.LwOp(a, -16).rd
riscv.BneOp(f, g, riscv.LabelAttr("LBB0_2"))
riscv.JOp(riscv.LabelAttr("LBB0_1"))

@Builder.implicit_region
def true_branch() -> None:
f = riscv.LwOp(a, -12).rd
f = riscv.MulOp(f, f).rd
riscv.SwOp(f, a, -20)
riscv.JOp(riscv.LabelAttr("LBB0_3"))

riscv.LabelOp("LBB0_1", true_branch)

@Builder.implicit_region
def false_branch() -> None:
f = riscv.LwOp(a, -16).rd
f = riscv.AddOp(f, f).rd
riscv.SwOp(f, a, -20)
riscv.JOp(riscv.LabelAttr("LBB0_3"))

riscv.LabelOp("LBB0_2", false_branch)

@Builder.implicit_region
def merge_if() -> None:
f = riscv.LwOp(a, -20).rd
riscv.LwOp(sp, 28, rd=riscv.Registers.RA).rd
riscv.LwOp(sp, 24, rd=riscv.Registers.S0).rd
riscv.AddiOp(sp, 32, rd=riscv.Registers.SP).rd
riscv.MVOp(f, rd=riscv.Registers.A0)
zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res
riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd
riscv.EcallOp()

riscv.LabelOp("LBB0_3", merge_if)

riscv.LabelOp("main", main_region)

riscv.DirectiveOp(".text", None, text_region)


@ModuleOp
@Builder.implicit_region
def simple_linear_riscv():
"""
The following riscv dialect IR is generated from the following C code:

int main() {
int a = 1;
int b = 2;
int c = a + b;
int d = a - b * c;
int f = a * b + c + d;
int g = a - b + c * d + f;
int h = a * b * c * d * f * g;
return a + b * c - d + f * g;
}

The goal of this test is to check that the register allocator is able to handle very simple linear code with no branching.
"""

@Builder.implicit_region
def text_region():
@Builder.implicit_region
def main_region() -> None:
zero = riscv.GetRegisterOp(riscv.Registers.ZERO).res
v0 = riscv.AddiOp(zero, 1).rd
v1 = riscv.AddiOp(zero, 2).rd

v3 = riscv.AddOp(v1, v0).rd
v4 = riscv.MulOp(v3, v1).rd
v5 = riscv.SubOp(v0, v4).rd
v6 = riscv.MulOp(v1, v0).rd
v7 = riscv.AddOp(v6, v3).rd
v8 = riscv.AddOp(v7, v5).rd
v9 = riscv.SubOp(v0, v1).rd
v10 = riscv.MulOp(v5, v3).rd
v11 = riscv.AddOp(v9, v10).rd
v12 = riscv.AddOp(v11, v8).rd
v13 = riscv.AddOp(v4, v0).rd
v14 = riscv.SubOp(v13, v5).rd
v15 = riscv.MulOp(v12, v8).rd
v16 = riscv.AddOp(v14, v15).rd

riscv.MVOp(v16, rd=riscv.Registers.A0)
riscv.AddiOp(zero, 93, rd=riscv.Registers.A7).rd
riscv.EcallOp()

riscv.LabelOp("main", main_region)

riscv.DirectiveOp(".text", None, text_region)


def test_allocate_simple_branching():
for allocation_strategy in ALLOCATION_STRATEGIES:
RISCVRegisterAllocation(allocation_strategy).apply(
context(), simple_branching_riscv
)
code = riscv_code(simple_branching_riscv)
assert (
run_riscv(code, unlimited_regs=True, setup_stack=True, verbosity=1)
== 155554
)


def test_allocate_simple_linear():
for allocation_strategy in ALLOCATION_STRATEGIES:
RISCVRegisterAllocation(allocation_strategy).apply(
context(), simple_linear_riscv
)
code = riscv_code(simple_linear_riscv)
assert run_riscv(code, unlimited_regs=True, setup_stack=True, verbosity=1) == 12
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ nbval<0.11
filecheck<0.0.24
lit<17.0.0
pre-commit==3.3.2
git+https://github.com/antonlydike/riscemu.git@25d059da090760862f9143478524ed6daf0ab449#egg=riscemu
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't actually include this yet, or at least would need to do this in a separate PR where we can discuss the extra dependency separately. I don't think the register allocation requires this change, even if it's useful for debugging, so I'd recommend taking out all riscemu related changes for now.

# pyright has to be the last line and fixed with `==`. The CI parses this file
# in `.github/parse_pyright_version.py` and installs the according version for
# typechecking.
Expand Down
38 changes: 38 additions & 0 deletions tests/dialects/test_riscv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from xdsl.riscv_asm_writer import riscv_code
from xdsl.builder import Builder
from xdsl.utils.test_value import TestSSAValue
from xdsl.dialects import riscv

Expand Down Expand Up @@ -89,6 +90,43 @@ def test_comment_op():
assert code == " # my comment\n"


def test_label_op_without_comment():
label_str = "mylabel"
label_op = riscv.LabelOp(label_str)

assert label_op.label.data == f"{label_str}"

code = riscv_code(ModuleOp([label_op]))
assert code == f"{label_str}:\n"


def test_label_op_with_comment():
label_str = "mylabel"
label_op = riscv.LabelOp(f"{label_str}", comment="my label")

assert label_op.label.data == "mylabel"
assert label_op.label.data == f"{label_str}"

code = riscv_code(ModuleOp([label_op]))
assert code == f"{label_str}: # my label\n"


def test_label_op_with_region():
@Builder.implicit_region
def label_region():
a1_reg = TestSSAValue(riscv.RegisterType(riscv.Registers.A1))
a2_reg = TestSSAValue(riscv.RegisterType(riscv.Registers.A2))
riscv.AddOp(a1_reg, a2_reg, rd=riscv.Registers.A0)

label_str = "mylabel"
label_op = riscv.LabelOp(f"{label_str}", region=label_region)

assert label_op.label.data == f"{label_str}"

code = riscv_code(ModuleOp([label_op]))
assert code == f"{label_str}:\n add a0, a1, a2\n"


def test_return_op():
return_op = riscv.EbreakOp(comment="my comment")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: xdsl-opt -p riscv-allocate-registers{allocation_type=BlockNaive} %s --print-op-generic | filecheck %s

"builtin.module"() ({
%0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
%1 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<s0>
%2 = "riscv.add"(%0, %1) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%3 = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg<>
%4 = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg<>
%5 = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg<>
%6 = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg<>
%7 = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg<>
%8 = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg<>
%9 = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg<>
%10 = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg<>
%11 = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg<>
%12 = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg<>
%13 = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg<>
%14 = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg<>
%15 = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg<>
%16 = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg<>
%17 = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg<>
%18 = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg<>
%19 = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg<>
%20 = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg<>
%21 = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg<>
%22 = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg<>
%23 = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg<>
%24 = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg<>
%25 = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg<>
%26 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
%27 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<>
%28 = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg<>
%29 = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg<>
%30 = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg<>
%31 = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<>
}) : () -> ()

// CHECK: "builtin.module"() ({
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<t0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<s0>
// CHECK-NEXT: %{{\d+}} = "riscv.add"(%{{\d+}}, %{{\d+}}) : (!riscv.reg<t0>, !riscv.reg<s0>) -> !riscv.reg<t1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 29 : i32} : () -> !riscv.reg<t2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 28 : i32} : () -> !riscv.reg<s1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 27 : i32} : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 26 : i32} : () -> !riscv.reg<a1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 25 : i32} : () -> !riscv.reg<a2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 24 : i32} : () -> !riscv.reg<a3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 23 : i32} : () -> !riscv.reg<a4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 22 : i32} : () -> !riscv.reg<a5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 21 : i32} : () -> !riscv.reg<a6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 20 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 19 : i32} : () -> !riscv.reg<s2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 18 : i32} : () -> !riscv.reg<s3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 17 : i32} : () -> !riscv.reg<s4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 16 : i32} : () -> !riscv.reg<s5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 15 : i32} : () -> !riscv.reg<s6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 14 : i32} : () -> !riscv.reg<s7>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 13 : i32} : () -> !riscv.reg<s8>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 12 : i32} : () -> !riscv.reg<s9>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 11 : i32} : () -> !riscv.reg<s10>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 10 : i32} : () -> !riscv.reg<s11>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 9 : i32} : () -> !riscv.reg<t3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 8 : i32} : () -> !riscv.reg<t4>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 7 : i32} : () -> !riscv.reg<t5>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<t6>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg<j0>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 4 : i32} : () -> !riscv.reg<j1>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 3 : i32} : () -> !riscv.reg<j2>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 2 : i32} : () -> !riscv.reg<j3>
// CHECK-NEXT: %{{\d+}} = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<j4>
// CHECK-NEXT: }) : () -> ()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt -p riscv-allocate-registers %s --print-op-generic | filecheck %s
// RUN: xdsl-opt -p riscv-allocate-registers{allocation_type=GlobalJRegs} %s --print-op-generic | filecheck %s

"builtin.module"() ({
%0 = "riscv.li"() {"immediate" = 6 : i32} : () -> !riscv.reg<>
Expand Down