Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

documentation: (toy) add initial lowering to riscv #1365

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/Toy/examples/tests/with-mlir/interpret.toy
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# RUN: python -m toy %s --emit=scf | filecheck %s
# RUN: python -m toy %s --emit=riscv | filecheck %s

# RUN: python -m toy %s --emit=scf --accelerate | filecheck %s
# RUN: python -m toy %s --emit=riscv --accelerate | filecheck %s

# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
Expand Down
20 changes: 18 additions & 2 deletions docs/Toy/toy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from xdsl.interpreters.func import FuncFunctions
from xdsl.interpreters.memref import MemrefFunctions
from xdsl.interpreters.printf import PrintfFunctions
from xdsl.interpreters.riscv_func import RiscvFuncFunctions
from xdsl.interpreters.scf import ScfFunctions
from xdsl.parser import Parser as IRParser
from xdsl.printer import Printer
Expand All @@ -30,9 +31,10 @@
"toy-infer-shapes",
"affine",
"scf",
"riscv",
],
default="toy-infer-shapes",
help="Action to perform on source file (default: toy-infer-shapes)",
default="riscv",
help="Action to perform on source file (default: riscv)",
)
parser.add_argument("--ir", dest="ir", action="store_true")
parser.add_argument("--print-op-generic", dest="print_generic", action="store_true")
Expand Down Expand Up @@ -84,6 +86,20 @@ def main(path: Path, emit: str, ir: bool, accelerate: bool, print_generic: bool)
if emit == "scf":
interpreter.register_implementations(ScfFunctions())
interpreter.register_implementations(BuiltinFunctions())

if accelerate and emit in ("riscv",):
# TODO: remove when we add lowering from Toy accelerator to custom riscv
interpreter.register_implementations(ToyAcceleratorFunctions())
if emit in ("riscv",):
interpreter.register_implementations(RiscvFuncFunctions())
interpreter.register_implementations(BuiltinFunctions())
# TODO: remove as we add lowerings to riscv
interpreter.register_implementations(ScfFunctions())
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(MemrefFunctions())
interpreter.register_implementations(PrintfFunctions())
interpreter.register_implementations(FuncFunctions())

interpreter.call_op("main", ())


Expand Down
33 changes: 32 additions & 1 deletion docs/Toy/toy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from pathlib import Path

from xdsl.dialects import affine, arith, func, memref, printf, scf
from xdsl.dialects import (
affine,
arith,
func,
memref,
printf,
riscv,
riscv_func,
scf,
)
from xdsl.dialects.builtin import Builtin, ModuleOp
from xdsl.ir import MLContext
from xdsl.transforms.dead_code_elimination import DeadCodeElimination
from xdsl.transforms.mlir_opt import MLIROptPass

from .dialects import toy
Expand All @@ -12,6 +22,7 @@
from .rewrites.lower_to_toy_accelerator import LowerToToyAccelerator
from .rewrites.lower_toy_affine import LowerToAffinePass
from .rewrites.optimise_toy import OptimiseToy
from .rewrites.setup_riscv_pass import SetupRiscvPass
from .rewrites.shape_inference import ShapeInferencePass


Expand All @@ -23,6 +34,8 @@ def context() -> MLContext:
ctx.register_dialect(func.Func)
ctx.register_dialect(memref.MemRef)
ctx.register_dialect(printf.Printf)
ctx.register_dialect(riscv_func.RISCV_Func)
ctx.register_dialect(riscv.RISCV)
ctx.register_dialect(scf.Scf)
ctx.register_dialect(toy.Toy)
return ctx
Expand Down Expand Up @@ -83,4 +96,22 @@ def transform(
if target == "scf":
return

# When the commented passes are uncommented, we can print RISC-V assembly

SetupRiscvPass().apply(ctx, module_op)
# LowerFuncToRiscvFunc().apply(ctx, module_op)
# LowerToyAccelerator().apply(ctx, module_op)
# LowerMemrefToRiscv().apply(ctx, module_op)
# LowerPrintfRiscvPass().apply(ctx, module_op)
# LowerArithRiscvPass().apply(ctx, module_op)
DeadCodeElimination().apply(ctx, module_op)
# ReconcileUnrealizedCastsPass().apply(ctx, module_op)

DeadCodeElimination().apply(ctx, module_op)
Comment on lines +107 to +110
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

twice?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes I guess not really necessary right now, I'll remove it in the next PR


module_op.verify()

if target == "riscv":
return

raise ValueError(f"Unknown target option {target}")
41 changes: 41 additions & 0 deletions docs/Toy/toy/rewrites/setup_riscv_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from xdsl.dialects import riscv
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir.core import Block, MLContext, Region
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)


class AddSections(RewritePattern):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With @AntonLydike's riscemu malloc and friends, the .bss section hack is not necessary, similarly we can probably get away with dropping the data section. I'd like to first merge the end-to-end flow as it's working today, and make it better later, if that's ok with everyone.

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter):
# bss stands for block starting symbol
heap_section = riscv.DirectiveOp(
".bss",
None,
Region(
Block(
[
riscv.LabelOp("heap"),
riscv.DirectiveOp(".space", f"{1024}"), # 1kb
]
)
),
)
data_section = riscv.DirectiveOp(".data", None, Region(Block()))
text_section = riscv.DirectiveOp(
".text", None, rewriter.move_region_contents_to_new_regions(op.regions[0])
)

op.body.add_block(Block([heap_section, data_section, text_section]))


class SetupRiscvPass(ModulePass):
name = "setup-lowering-to-riscv"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(AddSections()).rewrite_module(op)
26 changes: 26 additions & 0 deletions docs/Toy/toy/tests/test_add_sections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from xdsl.builder import ImplicitBuilder
from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.func import FuncOp
from xdsl.dialects.riscv import DirectiveOp, LabelOp
from xdsl.ir import Block, MLContext, Region

from ..rewrites.setup_riscv_pass import SetupRiscvPass

with ImplicitBuilder((input_module := ModuleOp([])).body):
FuncOp("main", ((), ()))

with ImplicitBuilder((output_module := ModuleOp([])).body):
bss = DirectiveOp(".bss", None, Region(Block()))
with ImplicitBuilder(bss.data):
LabelOp("heap")
DirectiveOp(".space", f"{1024}")
data = DirectiveOp(".data", None, Region(Block()))
text = DirectiveOp(".text", None, Region(Block()))
with ImplicitBuilder(text.data):
FuncOp("main", ((), ()))


def test_add_sections():
module = input_module.clone()
SetupRiscvPass().apply(MLContext(), module)
assert f"{module}" == f"{output_module}"