Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

dialects: Update support for returning multiple results using stencil.return #659

Merged
merged 12 commits into from Apr 17, 2023
43 changes: 43 additions & 0 deletions tests/dialects/test_stencil.py
@@ -0,0 +1,43 @@
from xdsl.dialects.builtin import FloatAttr, f32
from xdsl.dialects.experimental.stencil import ReturnOp, ResultType

from xdsl.utils.test_value import TestSSAValue


def test_stencil_return_single_float():
    """Check that one float-typed value given to ``ReturnOp.get`` is its first ``arg``."""
    operand = TestSSAValue(FloatAttr(4.0, f32))

    op = ReturnOp.get([operand])

    # Identity, not equality: the op must hold the very same SSA value.
    assert op.arg[0] is operand


def test_stencil_return_multiple_floats():
    """Check that several float-typed values given to ``ReturnOp.get`` keep their order in ``arg``."""
    operands = [TestSSAValue(FloatAttr(value, f32)) for value in (4.0, 5.0, 6.0)]

    op = ReturnOp.get(operands)

    # Each position of the variadic operand must be the very same SSA value
    # that was supplied at that position.
    for position, expected in enumerate(operands):
        assert op.arg[position] is expected


def test_stencil_return_single_ResultType():
    """Check that one ``ResultType``-typed value given to ``ReturnOp.get`` is its first ``arg``."""
    operand = TestSSAValue(ResultType.from_type(f32))

    op = ReturnOp.get([operand])

    # Identity, not equality: the op must hold the very same SSA value.
    assert op.arg[0] is operand


def test_stencil_return_multiple_ResultType():
    """Check that several ``ResultType``-typed values given to ``ReturnOp.get`` keep their order in ``arg``."""
    operands = [TestSSAValue(ResultType.from_type(f32)) for _ in range(3)]

    op = ReturnOp.get(operands)

    # Each position of the variadic operand must be the very same SSA value
    # that was supplied at that position.
    for position, expected in enumerate(operands):
        assert op.arg[position] is expected
116 changes: 59 additions & 57 deletions tests/filecheck/dialects/stencil/hdiff.mlir
Expand Up @@ -2,11 +2,12 @@

"builtin.module"() ({
"func.func"() ({
^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>):
^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %2 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>):
%3 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%4 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%5 = "stencil.cast"(%2) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%6 = "stencil.load"(%3) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>
%8 = "stencil.apply"(%6) ({
%7, %8 = "stencil.apply"(%6) ({
^1(%9 : !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>):
%10 = "stencil.access"(%9) {"offset" = #stencil.index<[-1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64
%11 = "stencil.access"(%9) {"offset" = #stencil.index<[1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64
Expand All @@ -19,79 +20,80 @@
%cst = "arith.constant"() {"value" = -4.0 : f64} : () -> f64
%18 = "arith.mulf"(%14, %cst) : (f64, f64) -> f64
%19 = "arith.addf"(%18, %17) : (f64, f64) -> f64
"stencil.return"(%19) : (!stencil.result<f64>) -> ()
}) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>
"stencil.store"(%8, %4) {"lb" = #stencil.index<[0 : i64, 0 : i64, 0: i64]>, "ub" = #stencil.index<[64 : i64, 64 : i64, 64 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> ()
"stencil.return"(%19, %18) : (f64, f64) -> ()
}) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>)
"stencil.store"(%7, %4) {"lb" = #stencil.index<[0 : i64, 0 : i64, 0 : i64]>, "ub" = #stencil.index<[64 : i64, 64 : i64, 64 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> ()
"func.return"() : () -> ()
}) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
}) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
}) : () -> ()

// CHECK: "builtin.module"() ({
// CHECK-NEXT: "func.func"() ({
// CHECK-NEXT: ^0(%0 : memref<?x?x?xf64>, %1 : memref<?x?x?xf64>):
// CHECK-NEXT: %2 = "memref.cast"(%0) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %3 = "memref.cast"(%1) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %4 = "memref.subview"(%3) {"static_offsets" = array<i64: 4, 4, 4>, "static_sizes" = array<i64: 64, 64, 64>, "static_strides" = array<i64: 1, 1, 1>, "operand_segment_sizes" = array<i32: 1, 0, 0, 0>} : (memref<72x72x72xf64>) -> memref<64x64x64xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: %5 = "memref.subview"(%2) {"static_offsets" = array<i64: 3, 3, 4>, "static_sizes" = array<i64: 66, 66, 64>, "static_strides" = array<i64: 1, 1, 1>, "operand_segment_sizes" = array<i32: 1, 0, 0, 0>} : (memref<72x72x72xf64>) -> memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>
// CHECK-NEXT: %6 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %7 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %8 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: %9 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: ^0(%0 : memref<?x?x?xf64>, %1 : memref<?x?x?xf64>, %2 : memref<?x?x?xf64>):
// CHECK-NEXT: %3 = "memref.cast"(%0) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %4 = "memref.cast"(%1) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %5 = "memref.subview"(%4) {"static_offsets" = array<i64: 4, 4, 4>, "static_sizes" = array<i64: 64, 64, 64>, "static_strides" = array<i64: 1, 1, 1>, "operand_segment_sizes" = array<i32: 1, 0, 0, 0>} : (memref<72x72x72xf64>) -> memref<64x64x64xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: %6 = "memref.cast"(%2) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %7 = "memref.subview"(%3) {"static_offsets" = array<i64: 3, 3, 4>, "static_sizes" = array<i64: 66, 66, 64>, "static_strides" = array<i64: 1, 1, 1>, "operand_segment_sizes" = array<i32: 1, 0, 0, 0>} : (memref<72x72x72xf64>) -> memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>
// CHECK-NEXT: %8 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %9 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %10 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: "scf.parallel"(%6, %8, %7) ({
// CHECK-NEXT: ^1(%11 : index):
// CHECK-NEXT: "scf.for"(%6, %9, %7) ({
// CHECK-NEXT: ^2(%12 : index):
// CHECK-NEXT: "scf.for"(%6, %10, %7) ({
// CHECK-NEXT: ^3(%13 : index):
// CHECK-NEXT: %14 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %15 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %11 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: %12 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: "scf.parallel"(%8, %10, %9) ({
// CHECK-NEXT: ^1(%13 : index):
// CHECK-NEXT: "scf.for"(%8, %11, %9) ({
// CHECK-NEXT: ^2(%14 : index):
// CHECK-NEXT: "scf.for"(%8, %12, %9) ({
// CHECK-NEXT: ^3(%15 : index):
// CHECK-NEXT: %16 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %17 = "arith.addi"(%11, %14) : (index, index) -> index
// CHECK-NEXT: %18 = "arith.addi"(%12, %15) : (index, index) -> index
// CHECK-NEXT: %17 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %18 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %19 = "arith.addi"(%13, %16) : (index, index) -> index
// CHECK-NEXT: %20 = "memref.load"(%5, %17, %18, %19) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %21 = "arith.constant"() {"value" = 2 : index} : () -> index
// CHECK-NEXT: %22 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %23 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %24 = "arith.addi"(%11, %21) : (index, index) -> index
// CHECK-NEXT: %25 = "arith.addi"(%12, %22) : (index, index) -> index
// CHECK-NEXT: %20 = "arith.addi"(%14, %17) : (index, index) -> index
// CHECK-NEXT: %21 = "arith.addi"(%15, %18) : (index, index) -> index
// CHECK-NEXT: %22 = "memref.load"(%7, %19, %20, %21) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %23 = "arith.constant"() {"value" = 2 : index} : () -> index
// CHECK-NEXT: %24 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %25 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %26 = "arith.addi"(%13, %23) : (index, index) -> index
// CHECK-NEXT: %27 = "memref.load"(%5, %24, %25, %26) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %28 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %29 = "arith.constant"() {"value" = 2 : index} : () -> index
// CHECK-NEXT: %30 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %31 = "arith.addi"(%11, %28) : (index, index) -> index
// CHECK-NEXT: %32 = "arith.addi"(%12, %29) : (index, index) -> index
// CHECK-NEXT: %27 = "arith.addi"(%14, %24) : (index, index) -> index
// CHECK-NEXT: %28 = "arith.addi"(%15, %25) : (index, index) -> index
// CHECK-NEXT: %29 = "memref.load"(%7, %26, %27, %28) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %30 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %31 = "arith.constant"() {"value" = 2 : index} : () -> index
// CHECK-NEXT: %32 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %33 = "arith.addi"(%13, %30) : (index, index) -> index
// CHECK-NEXT: %34 = "memref.load"(%5, %31, %32, %33) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %35 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %36 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %37 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %38 = "arith.addi"(%11, %35) : (index, index) -> index
// CHECK-NEXT: %39 = "arith.addi"(%12, %36) : (index, index) -> index
// CHECK-NEXT: %34 = "arith.addi"(%14, %31) : (index, index) -> index
// CHECK-NEXT: %35 = "arith.addi"(%15, %32) : (index, index) -> index
// CHECK-NEXT: %36 = "memref.load"(%7, %33, %34, %35) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %37 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %38 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %39 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %40 = "arith.addi"(%13, %37) : (index, index) -> index
// CHECK-NEXT: %41 = "memref.load"(%5, %38, %39, %40) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %42 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %43 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %44 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %45 = "arith.addi"(%11, %42) : (index, index) -> index
// CHECK-NEXT: %46 = "arith.addi"(%12, %43) : (index, index) -> index
// CHECK-NEXT: %41 = "arith.addi"(%14, %38) : (index, index) -> index
// CHECK-NEXT: %42 = "arith.addi"(%15, %39) : (index, index) -> index
// CHECK-NEXT: %43 = "memref.load"(%7, %40, %41, %42) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %44 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %45 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %46 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %47 = "arith.addi"(%13, %44) : (index, index) -> index
// CHECK-NEXT: %48 = "memref.load"(%5, %45, %46, %47) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %49 = "arith.addf"(%20, %27) : (f64, f64) -> f64
// CHECK-NEXT: %50 = "arith.addf"(%34, %41) : (f64, f64) -> f64
// CHECK-NEXT: %51 = "arith.addf"(%49, %50) : (f64, f64) -> f64
// CHECK-NEXT: %48 = "arith.addi"(%14, %45) : (index, index) -> index
// CHECK-NEXT: %49 = "arith.addi"(%15, %46) : (index, index) -> index
// CHECK-NEXT: %50 = "memref.load"(%7, %47, %48, %49) : (memref<66x66x64xf64, strided<[5184, 72, 1], offset: 15772>>, index, index, index) -> f64
// CHECK-NEXT: %51 = "arith.addf"(%22, %29) : (f64, f64) -> f64
// CHECK-NEXT: %52 = "arith.addf"(%36, %43) : (f64, f64) -> f64
// CHECK-NEXT: %53 = "arith.addf"(%51, %52) : (f64, f64) -> f64
// CHECK-NEXT: %cst = "arith.constant"() {"value" = -4.0 : f64} : () -> f64
// CHECK-NEXT: %52 = "arith.mulf"(%48, %cst) : (f64, f64) -> f64
// CHECK-NEXT: %53 = "arith.addf"(%52, %51) : (f64, f64) -> f64
// CHECK-NEXT: "memref.store"(%53, %4, %11, %12, %13) : (f64, memref<64x64x64xf64, strided<[5184, 72, 1], offset: 21028>>, index, index, index) -> ()
// CHECK-NEXT: %54 = "arith.mulf"(%50, %cst) : (f64, f64) -> f64
// CHECK-NEXT: %55 = "arith.addf"(%54, %53) : (f64, f64) -> f64
// CHECK-NEXT: "memref.store"(%55, %5, %13, %14, %15) : (f64, memref<64x64x64xf64, strided<[5184, 72, 1], offset: 21028>>, index, index, index) -> ()
// CHECK-NEXT: "scf.yield"() : () -> ()
// CHECK-NEXT: }) : (index, index, index) -> ()
// CHECK-NEXT: "scf.yield"() : () -> ()
// CHECK-NEXT: }) : (index, index, index) -> ()
// CHECK-NEXT: "scf.yield"() : () -> ()
// CHECK-NEXT: }) {"operand_segment_sizes" = array<i32: 1, 1, 1, 0>} : (index, index, index) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }) {"function_type" = (memref<?x?x?xf64>, memref<?x?x?xf64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
// CHECK-NEXT: }) {"function_type" = (memref<?x?x?xf64>, memref<?x?x?xf64>, memref<?x?x?xf64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
// CHECK-NEXT: }) : () -> ()
6 changes: 3 additions & 3 deletions xdsl/dialects/experimental/stencil.py
Expand Up @@ -481,11 +481,11 @@ class ReturnOp(Operation):
stencil.return %0 : !stencil.result<f64>
"""
name: str = "stencil.return"
arg: Annotated[Operand, ResultType | AnyFloat]
arg: Annotated[VarOperand, ResultType | AnyFloat]

@staticmethod
def get(*res: SSAValue | Operation):
return ReturnOp.build(operands=[*res])
def get(res: Sequence[SSAValue | Operation]):
return ReturnOp.build(operands=[list(res)])


@irdl_op_definition
Expand Down