Skip to content

Commit

Permalink
dialects: Add support for returning multiple results using stencil.re…
Browse files Browse the repository at this point in the history
…turn (#648)
  • Loading branch information
meshtag committed Apr 3, 2023
1 parent a78a113 commit ba826ec
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 92 deletions.
128 changes: 69 additions & 59 deletions tests/filecheck/dialects/stencil/hdiff.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

"builtin.module"() ({
"func.func"() ({
^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>):
^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %2 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>):
%3 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%4 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%5 = "stencil.cast"(%2) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>
%6 = "stencil.load"(%3) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>
%8 = "stencil.apply"(%6) ({
%7, %8 = "stencil.apply"(%6) ({
^1(%9 : !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>):
%10 = "stencil.access"(%9) {"offset" = #stencil.index<[-1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64
%11 = "stencil.access"(%9) {"offset" = #stencil.index<[1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64
Expand All @@ -20,76 +21,85 @@
%cst = "arith.constant"() {"value" = -4.0 : f64} : () -> f64
%18 = "arith.mulf"(%14, %cst) : (f64, f64) -> f64
%19 = "arith.addf"(%18, %17) : (f64, f64) -> f64
"stencil.return"(%19) : (!stencil.result<f64>) -> ()
}) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>
"stencil.store"(%8, %4) {"lb" = #stencil.index<[0 : i64, 0 : i64, 0: i64]>, "ub" = #stencil.index<[64 : i64, 64 : i64, 64 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> ()
"stencil.return"(%19, %18) : (f64, f64) -> ()
}) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>)
"stencil.store"(%7, %4) {"lb" = #stencil.index<[0 : i64, 0 : i64, 0 : i64]>, "ub" = #stencil.index<[64 : i64, 64 : i64, 64 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> ()
"stencil.store"(%8, %5) {"lb" = #stencil.index<[0 : i64, 0 : i64, 0 : i64]>, "ub" = #stencil.index<[64 : i64, 64 : i64, 64 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> ()
"func.return"() : () -> ()
}) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
}) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
}) : () -> ()


// CHECK-NEXT: "builtin.module"() ({
// CHECK-NEXT: "func.func"() ({
// CHECK-NEXT: ^0(%0 : memref<?x?x?xf64>, %1 : memref<?x?x?xf64>):
// CHECK-NEXT: %2 = "memref.cast"(%0) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %3 = "memref.cast"(%1) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %4 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %5 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %6 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: %7 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: ^0(%0 : memref<?x?x?xf64>, %1 : memref<?x?x?xf64>, %2 : memref<?x?x?xf64>):
// CHECK-NEXT: %3 = "memref.cast"(%0) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %4 = "memref.cast"(%1) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %5 = "memref.cast"(%2) : (memref<?x?x?xf64>) -> memref<72x72x72xf64>
// CHECK-NEXT: %6 = "arith.constant"() {"value" = 0 : index} : () -> index
// CHECK-NEXT: %7 = "arith.constant"() {"value" = 1 : index} : () -> index
// CHECK-NEXT: %8 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: "scf.parallel"(%4, %4, %4, %6, %7, %8, %5, %5, %5) ({
// CHECK-NEXT: ^1(%9 : index, %10 : index, %11 : index):
// CHECK-NEXT: %12 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %13 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %14 = "arith.constant"() {"value" = 3 : index} : () -> index
// CHECK-NEXT: %15 = "arith.addi"(%11, %12) : (index, index) -> index
// CHECK-NEXT: %16 = "arith.addi"(%10, %13) : (index, index) -> index
// CHECK-NEXT: %17 = "arith.addi"(%9, %14) : (index, index) -> index
// CHECK-NEXT: %18 = "memref.load"(%2, %15, %16, %17) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %19 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %20 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %21 = "arith.constant"() {"value" = 5 : index} : () -> index
// CHECK-NEXT: %22 = "arith.addi"(%11, %19) : (index, index) -> index
// CHECK-NEXT: %23 = "arith.addi"(%10, %20) : (index, index) -> index
// CHECK-NEXT: %24 = "arith.addi"(%9, %21) : (index, index) -> index
// CHECK-NEXT: %25 = "memref.load"(%2, %22, %23, %24) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %26 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %27 = "arith.constant"() {"value" = 5 : index} : () -> index
// CHECK-NEXT: %9 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: %10 = "arith.constant"() {"value" = 64 : index} : () -> index
// CHECK-NEXT: "scf.parallel"(%6, %6, %6, %8, %9, %10, %7, %7, %7) ({
// CHECK-NEXT: ^1(%11 : index, %12 : index, %13 : index):
// CHECK-NEXT: %14 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %15 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %16 = "arith.constant"() {"value" = 3 : index} : () -> index
// CHECK-NEXT: %17 = "arith.addi"(%13, %14) : (index, index) -> index
// CHECK-NEXT: %18 = "arith.addi"(%12, %15) : (index, index) -> index
// CHECK-NEXT: %19 = "arith.addi"(%11, %16) : (index, index) -> index
// CHECK-NEXT: %20 = "memref.load"(%3, %17, %18, %19) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %21 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %22 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %23 = "arith.constant"() {"value" = 5 : index} : () -> index
// CHECK-NEXT: %24 = "arith.addi"(%13, %21) : (index, index) -> index
// CHECK-NEXT: %25 = "arith.addi"(%12, %22) : (index, index) -> index
// CHECK-NEXT: %26 = "arith.addi"(%11, %23) : (index, index) -> index
// CHECK-NEXT: %27 = "memref.load"(%3, %24, %25, %26) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %28 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %29 = "arith.addi"(%11, %26) : (index, index) -> index
// CHECK-NEXT: %30 = "arith.addi"(%10, %27) : (index, index) -> index
// CHECK-NEXT: %31 = "arith.addi"(%9, %28) : (index, index) -> index
// CHECK-NEXT: %32 = "memref.load"(%2, %29, %30, %31) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %33 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %34 = "arith.constant"() {"value" = 3 : index} : () -> index
// CHECK-NEXT: %29 = "arith.constant"() {"value" = 5 : index} : () -> index
// CHECK-NEXT: %30 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %31 = "arith.addi"(%13, %28) : (index, index) -> index
// CHECK-NEXT: %32 = "arith.addi"(%12, %29) : (index, index) -> index
// CHECK-NEXT: %33 = "arith.addi"(%11, %30) : (index, index) -> index
// CHECK-NEXT: %34 = "memref.load"(%3, %31, %32, %33) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %35 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %36 = "arith.addi"(%11, %33) : (index, index) -> index
// CHECK-NEXT: %37 = "arith.addi"(%10, %34) : (index, index) -> index
// CHECK-NEXT: %38 = "arith.addi"(%9, %35) : (index, index) -> index
// CHECK-NEXT: %39 = "memref.load"(%2, %36, %37, %38) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %40 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %41 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %36 = "arith.constant"() {"value" = 3 : index} : () -> index
// CHECK-NEXT: %37 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %38 = "arith.addi"(%13, %35) : (index, index) -> index
// CHECK-NEXT: %39 = "arith.addi"(%12, %36) : (index, index) -> index
// CHECK-NEXT: %40 = "arith.addi"(%11, %37) : (index, index) -> index
// CHECK-NEXT: %41 = "memref.load"(%3, %38, %39, %40) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %42 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %43 = "arith.addi"(%11, %40) : (index, index) -> index
// CHECK-NEXT: %44 = "arith.addi"(%10, %41) : (index, index) -> index
// CHECK-NEXT: %45 = "arith.addi"(%9, %42) : (index, index) -> index
// CHECK-NEXT: %46 = "memref.load"(%2, %43, %44, %45) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %47 = "arith.addf"(%18, %25) : (f64, f64) -> f64
// CHECK-NEXT: %48 = "arith.addf"(%32, %39) : (f64, f64) -> f64
// CHECK-NEXT: %49 = "arith.addf"(%47, %48) : (f64, f64) -> f64
// CHECK-NEXT: %43 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %44 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %45 = "arith.addi"(%13, %42) : (index, index) -> index
// CHECK-NEXT: %46 = "arith.addi"(%12, %43) : (index, index) -> index
// CHECK-NEXT: %47 = "arith.addi"(%11, %44) : (index, index) -> index
// CHECK-NEXT: %48 = "memref.load"(%3, %45, %46, %47) : (memref<72x72x72xf64>, index, index, index) -> f64
// CHECK-NEXT: %49 = "arith.addf"(%20, %27) : (f64, f64) -> f64
// CHECK-NEXT: %50 = "arith.addf"(%34, %41) : (f64, f64) -> f64
// CHECK-NEXT: %51 = "arith.addf"(%49, %50) : (f64, f64) -> f64
// CHECK-NEXT: %cst = "arith.constant"() {"value" = -4.0 : f64} : () -> f64
// CHECK-NEXT: %50 = "arith.mulf"(%46, %cst) : (f64, f64) -> f64
// CHECK-NEXT: %51 = "arith.addf"(%50, %49) : (f64, f64) -> f64
// CHECK-NEXT: %52 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %53 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %52 = "arith.mulf"(%48, %cst) : (f64, f64) -> f64
// CHECK-NEXT: %53 = "arith.addf"(%52, %51) : (f64, f64) -> f64
// CHECK-NEXT: %54 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %55 = "arith.addi"(%11, %52) : (index, index) -> index
// CHECK-NEXT: %56 = "arith.addi"(%10, %53) : (index, index) -> index
// CHECK-NEXT: %57 = "arith.addi"(%9, %54) : (index, index) -> index
// CHECK-NEXT: "memref.store"(%51, %3, %55, %56, %57) : (f64, memref<72x72x72xf64>, index, index, index) -> ()
// CHECK-NEXT: %55 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %56 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %57 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %58 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %59 = "arith.constant"() {"value" = 4 : index} : () -> index
// CHECK-NEXT: %60 = "arith.addi"(%13, %54) : (index, index) -> index
// CHECK-NEXT: %61 = "arith.addi"(%12, %55) : (index, index) -> index
// CHECK-NEXT: %62 = "arith.addi"(%11, %56) : (index, index) -> index
// CHECK-NEXT: %63 = "arith.addi"(%13, %57) : (index, index) -> index
// CHECK-NEXT: %64 = "arith.addi"(%12, %58) : (index, index) -> index
// CHECK-NEXT: %65 = "arith.addi"(%11, %59) : (index, index) -> index
// CHECK-NEXT: "memref.store"(%53, %4, %60, %61, %62) : (f64, memref<72x72x72xf64>, index, index, index) -> ()
// CHECK-NEXT: "memref.store"(%52, %5, %63, %64, %65) : (f64, memref<72x72x72xf64>, index, index, index) -> ()
// CHECK-NEXT: "scf.yield"() : () -> ()
// CHECK-NEXT: }) {"operand_segment_sizes" = array<i32: 3, 3, 3, 0>} : (index, index, index, index, index, index, index, index, index) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }) {"function_type" = (memref<?x?x?xf64>, memref<?x?x?xf64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
// CHECK-NEXT: }) {"function_type" = (memref<?x?x?xf64>, memref<?x?x?xf64>, memref<?x?x?xf64>) -> (), "sym_name" = "stencil_hdiff"} : () -> ()
// CHECK-NEXT: }) : () -> ()
2 changes: 1 addition & 1 deletion xdsl/dialects/experimental/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ class ReturnOp(Operation):
stencil.return %0 : !stencil.result<f64>
"""
name: str = "stencil.return"
arg: Annotated[Operand, ResultType | AnyFloat]
arg: Annotated[VarOperand, ResultType | AnyFloat]

@staticmethod
def get(*res: SSAValue | Operation):
Expand Down
70 changes: 38 additions & 32 deletions xdsl/transforms/experimental/ConvertStencilToLLMLIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def GetMemRefFromFieldWithLBAndUB(memref_element_type: _TypeElement,
@dataclass
class CastOpToMemref(RewritePattern):

return_target: dict[ReturnOp, CastOp | memref.Cast]
return_target: dict[str, CastOp | memref.Cast]
gpu: bool = False

@op_type_rewrite_pattern
Expand Down Expand Up @@ -95,40 +95,47 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
@dataclass
class ReturnOpToMemref(RewritePattern):

return_target: dict[ReturnOp, CastOp | memref.Cast]
return_target: dict[str, CastOp | memref.Cast]

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /):
off_const_ops: list[arith.Constant] = []
off_sum_ops: list[arith.Addi] = []
load: list[memref.Store] = []

parallel = op.parent_op()
assert isinstance(parallel, scf.ParallelOp | gpu.LaunchOp)

cast = self.return_target[op]
for j in range(len(op.arg)):
cast = self.return_target[str(op) + "_result" + str(j)]
assert isinstance(cast, CastOp)

assert isinstance(cast, CastOp)

offsets = cast.lb
assert isinstance(offsets, IndexAttr)
offsets = cast.lb
assert isinstance(offsets, IndexAttr)

assert (block := op.parent_block()) is not None
assert (block := op.parent_block()) is not None

off_const_ops = [
arith.Constant.from_int_and_width(-x.value.data,
builtin.IndexType())
for x in offsets.array.data
]
off_const_ops.reverse()
off_const_ops_curr = [
arith.Constant.from_int_and_width(-x.value.data,
builtin.IndexType())
for x in offsets.array.data
]
off_const_ops_curr.reverse()
off_const_ops.extend(off_const_ops_curr)

args = list(block.args)
args.reverse()
args = list(block.args)
args.reverse()

off_sum_ops = [
arith.Addi.get(i, x) for i, x in zip(args, off_const_ops)
]
off_sum_ops_curr = [
arith.Addi.get(i, x) for i, x in zip(args, off_const_ops_curr)
]
off_sum_ops.extend(off_sum_ops_curr)

load = memref.Store.get(op.arg, cast.result, off_sum_ops)
load_curr = memref.Store.get(op.arg[j], cast.result,
off_sum_ops_curr)
load.append(load_curr)

rewriter.replace_matched_op([*off_const_ops, *off_sum_ops, load])
rewriter.replace_matched_op([*off_const_ops, *off_sum_ops, *load])


def verify_load_bounds(cast: CastOp, load: LoadOp):
Expand Down Expand Up @@ -332,7 +339,7 @@ def match_and_rewrite(self, op: ExternalStoreOp, rewriter: PatternRewriter,

def return_target_analysis(module: ModuleOp):

return_targets: dict[ReturnOp, CastOp | memref.Cast] = {}
return_targets: dict[str, CastOp | memref.Cast] = {}

def map_returns(op: Operation) -> None:
if not isinstance(op, ReturnOp):
Expand All @@ -341,18 +348,17 @@ def map_returns(op: Operation) -> None:
apply = op.parent_op()
assert isinstance(apply, ApplyOp)

res = list(apply.res)[0]

if (len(res.uses) > 1) or (not isinstance(
(store := list(res.uses)[0].operation), StoreOp)):
warn("Only single store result atm")
return
for i, res in enumerate(list(apply.res)):
if (len(res.uses) > 1) or (not isinstance(
(store := list(res.uses)[0].operation), StoreOp)):
warn("Only single store for a single return op result atm")
return

cast = store.field.owner
cast = store.field.owner

assert isinstance(cast, CastOp)
assert isinstance(cast, CastOp)

return_targets[op] = cast
return_targets[str(op) + "_result" + str(i)] = cast

module.walk(map_returns)

Expand All @@ -366,7 +372,7 @@ def map_returns(op: Operation) -> None:
])


def StencilConversion(return_targets: dict[ReturnOp, CastOp | memref.Cast],
def StencilConversion(return_targets: dict[str, CastOp | memref.Cast],
gpu: bool):
return GreedyRewritePatternApplier([
ApplyOpToParallel(),
Expand Down

0 comments on commit ba826ec

Please sign in to comment.