-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c17bd92
commit ec65671
Showing
3 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from __future__ import annotations | ||
from typing import IO, ClassVar | ||
|
||
# pyright: reportMissingTypeStubs=false | ||
|
||
from riscemu import RunConfig, UserModeCPU, RV32I, RV32M, AssemblyFileLoader, MMU | ||
from riscemu.instructions import InstructionSet, Instruction | ||
|
||
from io import StringIO | ||
|
||
|
||
class RV_Debug(InstructionSet): | ||
stream: ClassVar[IO[str] | None] = None | ||
|
||
# this instruction will dissappear into our emualtor soon-ish | ||
def instruction_print(self, ins: Instruction): | ||
reg = ins.get_reg(0) | ||
value = self.regs.get(reg) | ||
print(value, file=type(self).stream) | ||
|
||
def __eq__(self, __value: object) -> bool: | ||
if not isinstance(__value, RV_Debug): | ||
return False | ||
return self.stream is __value.stream | ||
|
||
def __hash__(self) -> int: | ||
return hash(id(self.stream)) | ||
|
||
|
||
def run_riscv( | ||
code: str, | ||
extensions: list[type[InstructionSet]] = [], | ||
unlimited_regs: bool = False, | ||
verbosity: int = 5, | ||
): | ||
cfg = RunConfig( | ||
debug_instruction=False, | ||
verbosity=verbosity, | ||
debug_on_exception=False, | ||
unlimited_registers=unlimited_regs, | ||
) | ||
|
||
cpu = UserModeCPU([RV32I, RV32M, RV_Debug, *extensions], cfg) | ||
|
||
io = StringIO(code) | ||
|
||
loader = AssemblyFileLoader.instantiate("example.asm", {}) | ||
assert isinstance(loader, AssemblyFileLoader) | ||
cpu.load_program(loader.parse_io(io)) # pyright: ignore[reportUnknownMemberType] | ||
|
||
mmu: MMU = getattr(cpu, "mmu") | ||
try: | ||
cpu.launch(mmu.programs[-1], verbosity > 1) | ||
except Exception as ex: | ||
print(ex) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from io import StringIO | ||
|
||
from xdsl.builder import Builder | ||
from xdsl.dialects.builtin import ModuleOp | ||
from xdsl.dialects import riscv | ||
from xdsl.transforms.riscv_register_allocation import RegisterAllocator | ||
|
||
|
||
from ..emulator.emulator_iop import RV_Debug, run_riscv | ||
from .utils import riscv_code | ||
|
||
|
||
def test_simple(): | ||
@ModuleOp | ||
@Builder.implicit_region | ||
def module(): | ||
six = riscv.LiOp(6).rd | ||
seven = riscv.LiOp(7).rd | ||
forty_two = riscv.MulOp(six, seven).rd | ||
riscv.CustomEmulatorInstructionOp("print", inputs=[forty_two], result_types=[]) | ||
|
||
RegisterAllocator().allocate_registers(module) | ||
code = riscv_code(module) | ||
|
||
stream = StringIO() | ||
RV_Debug.stream = stream | ||
run_riscv( | ||
code, | ||
extensions=[RV_Debug], | ||
unlimited_regs=True, | ||
verbosity=1, | ||
) | ||
assert "42\n" == stream.getvalue() | ||
|
||
|
||
def test_multiply_add(): | ||
@ModuleOp | ||
@Builder.implicit_region | ||
def module(): | ||
riscv.DirectiveOp(".bss", "") | ||
riscv.LabelOp("heap") | ||
riscv.DirectiveOp(".space", "1024") | ||
riscv.DirectiveOp(".text", "") | ||
riscv.LabelOp("main") | ||
heap = riscv.LiOp("heap").rd | ||
riscv.AddiOp( | ||
heap, | ||
1020, | ||
rd=riscv.Registers.SP, | ||
comment="stack grows from the top of the heap", | ||
) | ||
|
||
riscv.LiOp(3, rd=riscv.Registers.A0) | ||
riscv.LiOp(2, rd=riscv.Registers.A1) | ||
riscv.LiOp(1, rd=riscv.Registers.A2) | ||
|
||
riscv.JalOp("muladd") | ||
res = riscv.GetRegisterOp(riscv.Registers.A0).res | ||
riscv.CustomEmulatorInstructionOp("print", [res], []) | ||
|
||
riscv.LiOp(93, rd=riscv.Registers.A7) | ||
riscv.EcallOp() | ||
|
||
riscv.LabelOp("multiply") | ||
riscv.CommentOp("no extra registers needed, so no need to deal with stack") | ||
a0_multiply = riscv.GetRegisterOp(riscv.Registers.A0) | ||
a1_multiply = riscv.GetRegisterOp(riscv.Registers.A1) | ||
riscv.MulOp(a0_multiply, a1_multiply, rd=riscv.Registers.A0) | ||
riscv.ReturnOp() | ||
|
||
riscv.LabelOp("add") | ||
riscv.CommentOp("no extra registers needed, so no need to deal with stack") | ||
a0_add = riscv.GetRegisterOp(riscv.Registers.A0) | ||
a1_add = riscv.GetRegisterOp(riscv.Registers.A1) | ||
riscv.AddOp(a0_add, a1_add, rd=riscv.Registers.A0) | ||
riscv.ReturnOp() | ||
|
||
riscv.LabelOp("muladd") | ||
riscv.CommentOp("a0 <- a0 * a1 + a2") | ||
riscv.CommentOp("prologue") | ||
# get registers with the arguments to muladd | ||
a2_muladd = riscv.GetRegisterOp(riscv.Registers.A2) | ||
|
||
# get registers we'll use in this section | ||
sp_muladd = riscv.GetRegisterOp(riscv.Registers.SP) | ||
s0_muladd_0 = riscv.GetRegisterOp(riscv.Registers.S0) | ||
ra_muladd = riscv.GetRegisterOp(riscv.Registers.RA) | ||
riscv.CommentOp( | ||
"decrement stack pointer by number of register values we need to store for later" | ||
) | ||
riscv.AddiOp(sp_muladd, -8, rd=riscv.Registers.SP) | ||
riscv.CommentOp("save the s registers we'll use on the stack") | ||
riscv.SwOp(s0_muladd_0, sp_muladd, 0) | ||
riscv.CommentOp("save the return address we'll use on the stack") | ||
riscv.SwOp(ra_muladd, sp_muladd, 4) | ||
|
||
# store third parameter, in a2 to the temporary register s0 | ||
# guaranteed to be the same after call to multiply | ||
s0_muladd_1 = riscv.MVOp(a2_muladd, rd=riscv.Registers.S0) | ||
riscv.JalOp("multiply") | ||
|
||
# The product of a0 and a1 is stored in a0 | ||
# We now have to move s0 to a1, and call add | ||
|
||
riscv.MVOp(s0_muladd_1, rd=riscv.Registers.A1) | ||
|
||
riscv.JalOp("add") | ||
|
||
riscv.CommentOp("epilogue") | ||
riscv.CommentOp("store the old values back into the s registers") | ||
riscv.LwOp(sp_muladd, 0, rd=riscv.Registers.S0) | ||
|
||
riscv.CommentOp("store the return address back into the ra register") | ||
riscv.LwOp(sp_muladd, 4, rd=riscv.Registers.RA) | ||
|
||
riscv.CommentOp( | ||
"set the sp back to what it was at the start of the function call" | ||
) | ||
riscv.AddiOp(sp_muladd, 8, rd=riscv.Registers.SP) | ||
|
||
riscv.CommentOp("jump back to caller") | ||
riscv.ReturnOp() | ||
|
||
RegisterAllocator().allocate_registers(module) | ||
code = riscv_code(module) | ||
|
||
stream = StringIO() | ||
RV_Debug.stream = stream | ||
run_riscv( | ||
code, | ||
extensions=[RV_Debug], | ||
unlimited_regs=True, | ||
verbosity=1, | ||
) | ||
assert ( | ||
stream.getvalue() | ||
== """7 | ||
""" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from io import StringIO | ||
from xdsl.dialects import riscv | ||
from xdsl.dialects.builtin import ModuleOp | ||
|
||
|
||
def riscv_code(module: ModuleOp) -> str: | ||
stream = StringIO() | ||
riscv.print_assembly(module, stream) | ||
return stream.getvalue() |