Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (riscv) add func to riscv_func lowering #1382

Merged
merged 4 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/backend/riscv/test_func_to_riscv_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

from xdsl.backend.riscv.lowering.lower_func_riscv_func import LowerFuncToRiscvFunc
from xdsl.builder import Builder, ImplicitBuilder
from xdsl.dialects import func
from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.test import TestType
from xdsl.ir import MLContext
from xdsl.utils.test_value import TestSSAValue


def test_lower_non_main_failure():
@ModuleOp
@Builder.implicit_region
def non_main():
with ImplicitBuilder(func.FuncOp("not_main", ((), ())).body):
func.Return()

with pytest.raises(
NotImplementedError, match="Only support lowering main function for now"
):
LowerFuncToRiscvFunc().apply(MLContext(), non_main)


def test_lower_with_args_failure():
@ModuleOp
@Builder.implicit_region
def multiple_args():
with ImplicitBuilder(
func.FuncOp("main", ((TestType("misc"),), (TestType("misc"),))).body
):
func.Return()

with pytest.raises(
NotImplementedError, match="Only support functions with no arguments for now"
):
LowerFuncToRiscvFunc().apply(MLContext(), multiple_args)


def test_lower_with_non_empty_return_failure():
@ModuleOp
@Builder.implicit_region
def non_empty_return():
with ImplicitBuilder(func.FuncOp("main", ((), ())).body):
test_ssa = TestSSAValue(TestType("misc"))
func.Return(test_ssa)

with pytest.raises(
NotImplementedError, match="Only support return with no arguments for now"
):
LowerFuncToRiscvFunc().apply(MLContext(), non_empty_return)


def test_lower_function_call_failure():
@ModuleOp
@Builder.implicit_region
def function_call():
with ImplicitBuilder(func.FuncOp("main", ((), ())).body):
test_ssa = TestSSAValue(TestType("misc"))
func.Call("bar", (test_ssa,), ())
func.Return()

with pytest.raises(
NotImplementedError, match="Function call lowering not implemented yet"
):
LowerFuncToRiscvFunc().apply(MLContext(), function_call)
15 changes: 15 additions & 0 deletions tests/filecheck/backend/riscv/riscv_func_to_riscv_func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: xdsl-opt -p lower-func-to-riscv-func --split-input-file %s | filecheck %s

builtin.module {
func.func @main() {
"test.op"() : () -> ()
func.return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: "riscv_func.func"() ({
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: "riscv_func.return"() : () -> ()
// CHECK-NEXT: }) {"sym_name" = "main"} : () -> ()
// CHECK-NEXT: }
62 changes: 62 additions & 0 deletions xdsl/backend/riscv/lowering/lower_func_riscv_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from xdsl.dialects import func, riscv_func
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)


class LowerFuncOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
name = op.sym_name.data

if name != "main":
raise NotImplementedError("Only support lowering main function for now")

if op.body.blocks[0].args:
raise NotImplementedError(
"Only support functions with no arguments for now"
)

new_func = riscv_func.FuncOp(
op.sym_name.data, rewriter.move_region_contents_to_new_regions(op.body)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know we had this API, TIL! Very nice!

)

rewriter.replace_matched_op(new_func)


class LowerFuncCallOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.Call, rewriter: PatternRewriter) -> None:
raise NotImplementedError("Function call lowering not implemented yet")
Comment on lines +34 to +37
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this is even in the PR, it just does absolutely nothing?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it, it's a nice stub for us to replace later, and then the failing unit tests are replaced with working unit tests, so we get a little win in the future also



class LowerReturnOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.Return, rewriter: PatternRewriter):
if op.arguments:
raise NotImplementedError("Only support return with no arguments for now")

rewriter.replace_matched_op(riscv_func.ReturnOp(()))
Comment on lines +40 to +46
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this is here but not tested, and also quite useless imo?



class LowerFuncToRiscvFunc(ModulePass):
name = "lower-func-to-riscv-func"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
LowerFuncOp(),
LowerFuncCallOp(),
LowerReturnOp(),
]
),
apply_recursively=False,
).rewrite_module(op)
2 changes: 2 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from typing import IO, Callable

from xdsl.backend.riscv.lowering.lower_func_riscv_func import LowerFuncToRiscvFunc
from xdsl.backend.riscv.lowering.riscv_arith_lowering import RISCVLowerArith
from xdsl.dialects.affine import Affine
from xdsl.dialects.arith import Arith
Expand Down Expand Up @@ -106,6 +107,7 @@ def get_all_passes() -> list[type[ModulePass]]:
printf_to_llvm.PrintfToLLVM,
riscv_register_allocation.RISCVRegisterAllocation,
RISCVLowerArith,
LowerFuncToRiscvFunc,
stencil_shape_inference.StencilShapeInferencePass,
stencil_storage_materialization.StencilStorageMaterializationPass,
reconcile_unrealized_casts.ReconcileUnrealizedCastsPass,
Expand Down