diff --git a/tests/backend/riscv/test_func_to_riscv_func.py b/tests/backend/riscv/test_func_to_riscv_func.py new file mode 100644 index 0000000000..009e18aac2 --- /dev/null +++ b/tests/backend/riscv/test_func_to_riscv_func.py @@ -0,0 +1,66 @@ +import pytest + +from xdsl.backend.riscv.lowering.lower_func_riscv_func import LowerFuncToRiscvFunc +from xdsl.builder import Builder, ImplicitBuilder +from xdsl.dialects import func +from xdsl.dialects.builtin import ModuleOp +from xdsl.dialects.test import TestType +from xdsl.ir import MLContext +from xdsl.utils.test_value import TestSSAValue + + +def test_lower_non_main_failure(): + @ModuleOp + @Builder.implicit_region + def non_main(): + with ImplicitBuilder(func.FuncOp("not_main", ((), ())).body): + func.Return() + + with pytest.raises( + NotImplementedError, match="Only support lowering main function for now" + ): + LowerFuncToRiscvFunc().apply(MLContext(), non_main) + + +def test_lower_with_args_failure(): + @ModuleOp + @Builder.implicit_region + def multiple_args(): + with ImplicitBuilder( + func.FuncOp("main", ((TestType("misc"),), (TestType("misc"),))).body + ): + func.Return() + + with pytest.raises( + NotImplementedError, match="Only support functions with no arguments for now" + ): + LowerFuncToRiscvFunc().apply(MLContext(), multiple_args) + + +def test_lower_with_non_empty_return_failure(): + @ModuleOp + @Builder.implicit_region + def non_empty_return(): + with ImplicitBuilder(func.FuncOp("main", ((), ())).body): + test_ssa = TestSSAValue(TestType("misc")) + func.Return(test_ssa) + + with pytest.raises( + NotImplementedError, match="Only support return with no arguments for now" + ): + LowerFuncToRiscvFunc().apply(MLContext(), non_empty_return) + + +def test_lower_function_call_failure(): + @ModuleOp + @Builder.implicit_region + def function_call(): + with ImplicitBuilder(func.FuncOp("main", ((), ())).body): + test_ssa = TestSSAValue(TestType("misc")) + func.Call("bar", (test_ssa,), ()) + func.Return() + + with pytest.raises( + NotImplementedError, match="Function call lowering not implemented yet" + ): + LowerFuncToRiscvFunc().apply(MLContext(), function_call) diff --git a/tests/filecheck/backend/riscv/riscv_func_to_riscv_func.mlir b/tests/filecheck/backend/riscv/riscv_func_to_riscv_func.mlir new file mode 100644 index 0000000000..fa7822cda7 --- /dev/null +++ b/tests/filecheck/backend/riscv/riscv_func_to_riscv_func.mlir @@ -0,0 +1,15 @@ +// RUN: xdsl-opt -p lower-func-to-riscv-func --split-input-file %s | filecheck %s + +builtin.module { + func.func @main() { + "test.op"() : () -> () + func.return + } +} + +// CHECK: builtin.module { +// CHECK-NEXT: "riscv_func.func"() ({ +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: "riscv_func.return"() : () -> () +// CHECK-NEXT: }) {"sym_name" = "main"} : () -> () +// CHECK-NEXT: } diff --git a/xdsl/backend/riscv/lowering/lower_func_riscv_func.py b/xdsl/backend/riscv/lowering/lower_func_riscv_func.py new file mode 100644 index 0000000000..9741ea0751 --- /dev/null +++ b/xdsl/backend/riscv/lowering/lower_func_riscv_func.py @@ -0,0 +1,62 @@ +from xdsl.dialects import func, riscv_func +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import MLContext +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) + + +class LowerFuncOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): + name = op.sym_name.data + + if name != "main": + raise NotImplementedError("Only support lowering main function for now") + + if op.body.blocks[0].args: + raise NotImplementedError( + "Only support functions with no arguments for now" + ) + + new_func = riscv_func.FuncOp( + op.sym_name.data, rewriter.move_region_contents_to_new_regions(op.body) + ) + + rewriter.replace_matched_op(new_func) + + +class LowerFuncCallOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.Call, rewriter: PatternRewriter) -> None: + raise NotImplementedError("Function call lowering not implemented yet") + + +class LowerReturnOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.Return, rewriter: PatternRewriter): + if op.arguments: + raise NotImplementedError("Only support return with no arguments for now") + + rewriter.replace_matched_op(riscv_func.ReturnOp(())) + + +class LowerFuncToRiscvFunc(ModulePass): + name = "lower-func-to-riscv-func" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + LowerFuncOp(), + LowerFuncCallOp(), + LowerReturnOp(), + ] + ), + apply_recursively=False, + ).rewrite_module(op) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index 6e8fa3d8f7..709bd6a374 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -3,6 +3,7 @@ import sys from typing import IO, Callable +from xdsl.backend.riscv.lowering.lower_func_riscv_func import LowerFuncToRiscvFunc from xdsl.backend.riscv.lowering.riscv_arith_lowering import RISCVLowerArith from xdsl.dialects.affine import Affine from xdsl.dialects.arith import Arith @@ -106,6 +107,7 @@ def get_all_passes() -> list[type[ModulePass]]: printf_to_llvm.PrintfToLLVM, riscv_register_allocation.RISCVRegisterAllocation, RISCVLowerArith, + LowerFuncToRiscvFunc, stencil_shape_inference.StencilShapeInferencePass, stencil_storage_materialization.StencilStorageMaterializationPass, reconcile_unrealized_casts.ReconcileUnrealizedCastsPass,