Skip to content

Commit

Permalink
dialects: riscv-func add initial function call lowering (#929)
Browse files Browse the repository at this point in the history
together with prologue and epilogue register constraints compliant to
conventions.

Integrated this branch with reference to the sasha/workshop-all branch.

---------

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
eymay and superlopuh committed Jun 1, 2023
1 parent e1e8e2d commit 4343a8f
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 20 deletions.
88 changes: 76 additions & 12 deletions tests/filecheck/dialects/riscv_func/lower_riscv_func.mlir
Original file line number Diff line number Diff line change
@@ -1,17 +1,81 @@
// RUN: xdsl-opt -p lower-riscv-func %s | filecheck %s

"builtin.module"() ({
%file = "riscv.li"() {"immediate" = 0 : i32} : () -> !riscv.reg<s0>
%success = "riscv_func.syscall"(%file) {"syscall_num" = 64 : i32}: (!riscv.reg<s0>) -> !riscv.reg<s1>
// CHECK: %file = "riscv.li"() {"immediate" = 0 : i32} : () -> !riscv.reg<s0>
// CHECK-NEXT: %{{.+}} = "riscv.mv"(%{{.+}}) : (!riscv.reg<s0>) -> !riscv.reg<a0>
// CHECK-NEXT: %{{.+}} = "riscv.li"() {"immediate" = 64 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: "riscv.ecall"() : () -> ()
// CHECK-NEXT: %{{.+}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.+}} = "riscv.mv"(%{{.+}}) : (!riscv.reg<a0>) -> !riscv.reg<s1>

"riscv_func.syscall"() {"syscall_num" = 93 : i32} : () -> ()
// CHECK-NEXT: %{{.+}} = "riscv.li"() {"immediate" = 93 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: "riscv.ecall"() : () -> ()
// CHECK: builtin.module {

%file = "riscv.li"() {"immediate" = 0 : i32} : () -> !riscv.reg<s0>
%success = "riscv_func.syscall"(%file) {"syscall_num" = 64 : i32}: (!riscv.reg<s0>) -> !riscv.reg<s1>
// CHECK-NEXT: %file = "riscv.li"() {"immediate" = 0 : i32} : () -> !riscv.reg<s0>
// CHECK-NEXT: %{{.+}} = "riscv.mv"(%{{.+}}) : (!riscv.reg<s0>) -> !riscv.reg<a0>
// CHECK-NEXT: %{{.+}} = "riscv.li"() {"immediate" = 64 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: "riscv.ecall"() : () -> ()
// CHECK-NEXT: %{{.+}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.+}} = "riscv.mv"(%{{.+}}) : (!riscv.reg<a0>) -> !riscv.reg<s1>


"riscv_func.syscall"() {"syscall_num" = 93 : i32} : () -> ()
// CHECK-NEXT: %{{.+}} = "riscv.li"() {"immediate" = 93 : i32} : () -> !riscv.reg<a7>
// CHECK-NEXT: "riscv.ecall"() : () -> ()

"riscv_func.func"() ({
%0 = "riscv_func.call"() {"func_name" = "get_one"} : () -> !riscv.reg<>
%1 = "riscv_func.call"() {"func_name" = "get_one"} : () -> !riscv.reg<>
%2 = "riscv_func.call"(%0, %1) {"func_name" = "add"} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"riscv_func.call"(%2) {"func_name" = "my_print"} : (!riscv.reg<>) -> ()
"riscv_func.return"() : () -> ()
}) {"func_name" = "main"} : () -> ()

// CHECK-NEXT: "riscv.label"() ({
// CHECK-NEXT: "riscv.jal"() {"immediate" = #riscv.label<"get_one">} : () -> ()
// CHECK-NEXT: %{{.*}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<a0>) -> !riscv.reg<>
// CHECK-NEXT: "riscv.jal"() {"immediate" = #riscv.label<"get_one">} : () -> ()
// CHECK-NEXT: %{{.*}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<a0>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<>) -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<>) -> !riscv.reg<a1>
// CHECK-NEXT: "riscv.jal"() {"immediate" = #riscv.label<"add">} : () -> ()
// CHECK-NEXT: %{{.*}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<a0>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<>) -> !riscv.reg<a0>
// CHECK-NEXT: "riscv.jal"() {"immediate" = #riscv.label<"my_print">} : () -> ()
// CHECK-NEXT: "riscv.ret"() : () -> ()
// CHECK-NEXT: }) {"label" = #riscv.label<"main">} : () -> ()


"riscv_func.func"() ({
"riscv_func.return"() : () -> ()
}) {"func_name" = "my_print"} : () -> ()

// CHECK-NEXT: "riscv.label"() ({
// CHECK-NEXT: "riscv.ret"() : () -> ()
// CHECK-NEXT: }) {"label" = #riscv.label<"my_print">} : () -> ()

"riscv_func.func"() ({
%0 = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<>
"riscv_func.return"(%0) : (!riscv.reg<>) -> ()
}) {"func_name" = "get_one"} : () -> ()

// CHECK-NEXT: "riscv.label"() ({
// CHECK-NEXT: %{{.*}} = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<>) -> !riscv.reg<a0>
// CHECK-NEXT: "riscv.ret"() : () -> ()
// CHECK-NEXT: }) {"label" = #riscv.label<"get_one">} : () -> ()

"riscv_func.func"() ({
^0(%0 : !riscv.reg<>, %1 : !riscv.reg<>):
%2 = "riscv.add"(%0, %1) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"riscv_func.return"(%2) : (!riscv.reg<>) -> ()
}) {"func_name" = "add"} : () -> ()

// CHECK-NEXT: "riscv.label"() ({
// CHECK-NEXT: %{{.*}} = "riscv.get_register"() : () -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = "riscv.get_register"() : () -> !riscv.reg<a1>
// CHECK-NEXT: %{{.*}} = "riscv.add"(%{{.*}}, %{{.*}}) : (!riscv.reg<a0>, !riscv.reg<a1>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = "riscv.mv"(%{{.*}}) : (!riscv.reg<>) -> !riscv.reg<a0>
// CHECK-NEXT: "riscv.ret"() : () -> ()
// CHECK-NEXT: }) {"label" = #riscv.label<"add">} : () -> ()

}) : () -> ()

// CHECK-NEXT: }
96 changes: 93 additions & 3 deletions xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations

from typing import Annotated
from typing import Annotated, Sequence

from xdsl.ir import Operation, SSAValue, Dialect
from xdsl.ir import Operation, SSAValue, Dialect, Attribute, Region
from xdsl.traits import HasParent
from xdsl.utils.exceptions import VerifyException

from xdsl.irdl import (
IRDLOperation,
OptOpAttr,
OptOpResult,
VarOpResult,
VarOperand,
irdl_op_definition,
SingleBlockRegion,
OpAttr,
)
from xdsl.dialects.builtin import AnyIntegerAttr, IntegerAttr, IntegerType
from xdsl.dialects.builtin import AnyIntegerAttr, IntegerAttr, IntegerType, StringAttr
from xdsl.dialects import riscv


Expand Down Expand Up @@ -48,9 +52,95 @@ def verify_(self):
)


@irdl_op_definition
class CallOp(IRDLOperation):
"""RISC-V function call operation"""

name = "riscv_func.call"
args: Annotated[VarOperand, riscv.RegisterType]
func_name: OpAttr[StringAttr]
ress: Annotated[VarOpResult, riscv.RegisterType]

def __init__(
self,
func_name: StringAttr,
args: Sequence[Operation | SSAValue],
result_types: Sequence[riscv.RegisterType],
comment: StringAttr | None = None,
):
super().__init__(
operands=[args],
result_types=result_types,
attributes={
"func_name": func_name,
"comment": comment,
},
)

def verify_(self):
if len(self.args) >= 9:
raise VerifyException(
f"Function op has too many operands ({len(self.args)}), expected fewer than 9"
)

if len(self.results) >= 3:
raise VerifyException(
f"Function op has too many results ({len(self.results)}), expected fewer than 3"
)


@irdl_op_definition
class FuncOp(IRDLOperation):
"""RISC-V function definition operation"""

name = "riscv_func.func"
func_name: OpAttr[StringAttr]
func_body: SingleBlockRegion

def __init__(self, name: str, region: Region):
attributes: dict[str, Attribute] = {"func_name": StringAttr(name)}

super().__init__(attributes=attributes, regions=[region])


@irdl_op_definition
class ReturnOp(IRDLOperation):
"""RISC-V function return operation"""

name = "riscv_func.return"
values: Annotated[VarOperand, riscv.RegisterType]
comment: OptOpAttr[StringAttr]

traits = frozenset([HasParent(FuncOp)])

def __init__(
self,
values: Sequence[Operation | SSAValue],
*,
comment: str | StringAttr | None = None,
):
if isinstance(comment, str):
comment = StringAttr(comment)
super().__init__(
operands=[values],
attributes={
"comment": comment,
},
)

def verify_(self):
if len(self.results) >= 3:
raise VerifyException(
f"Function op has too many results ({len(self.results)}), expected fewer than 3"
)


RISCV_Func = Dialect(
[
SyscallOp,
CallOp,
FuncOp,
ReturnOp,
],
[],
)
73 changes: 68 additions & 5 deletions xdsl/transforms/lower_riscv_func.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from typing import cast
from xdsl.dialects.builtin import ModuleOp

from xdsl.ir import MLContext, Operation
from xdsl.ir import MLContext, OpResult, Operation
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
op_type_rewrite_pattern,
RewritePattern,
PatternRewriter,
)
from xdsl.dialects import riscv, riscv_func
from xdsl.transforms.dead_code_elimination import dce


class LowerSyscallOp(RewritePattern):
"""
Lower SSA version of syscall, storing the optional result to a0.
Different platforms have different calling conventions. This lowering assumes that
the inputs are stored in a0-a6, and the opcode is stored to a7. Upon return, the
a0 contains the result value. This is not the case for some kernels.
Expand Down Expand Up @@ -59,9 +59,72 @@ def match_and_rewrite(self, op: riscv_func.SyscallOp, rewriter: PatternRewriter)
rewriter.replace_matched_op(ops, new_results=new_results)


class LowerRISCVFuncOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_func.FuncOp, rewriter: PatternRewriter):
body = op.func_body.block
first_op = body.first_op
assert first_op is not None
while len(body.args):
# arguments are passed to riscv functions via a0, a1, ...
# replace arguments with `GetRegisterOp`s
index = len(body.args) - 1
last_arg = body.args[-1]
get_reg_op = riscv.GetRegisterOp(riscv.Register(f"a{index}"))
last_arg.replace_by(get_reg_op.res)
rewriter.insert_op_before(get_reg_op, first_op)
first_op = get_reg_op
rewriter.erase_block_argument(last_arg)

label_body = rewriter.move_region_contents_to_new_regions(op.func_body)

rewriter.replace_matched_op(riscv.LabelOp(op.func_name.data, region=label_body))


class LowerRISCVFuncReturnOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_func.ReturnOp, rewriter: PatternRewriter):
for i, value in enumerate(op.values):
rewriter.insert_op_before_matched_op(
riscv.MVOp(value, rd=riscv.Register(f"a{i}"))
)
rewriter.replace_matched_op(riscv.ReturnOp())


class LowerRISCVCallOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_func.CallOp, rewriter: PatternRewriter):
for i, arg in enumerate(op.operands):
# Load arguments into a0...
rewriter.insert_op_before_matched_op(
riscv.MVOp(arg, rd=riscv.Register(f"a{i}"))
)

ops: list[Operation] = [
riscv.JalOp(op.func_name.data),
]
new_results: list[OpResult] = []

for i in range(len(op.results)):
get_reg = riscv.GetRegisterOp(riscv.Register(f"a{i}"))
move_res = riscv.MVOp(get_reg)
ops.extend((get_reg, move_res))
new_results.append(move_res.rd)

rewriter.replace_matched_op(ops, new_results=new_results)


class LowerRISCVFunc(ModulePass):
name = "lower-riscv-func"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(LowerSyscallOp()).rewrite_module(op)
dce(op)
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
LowerRISCVFuncReturnOp(),
LowerRISCVFuncOp(),
LowerRISCVCallOp(),
LowerSyscallOp(),
]
)
).rewrite_module(op)

0 comments on commit 4343a8f

Please sign in to comment.