diff --git a/tests/filecheck/transforms/stencil-storage-materialization.mlir b/tests/filecheck/transforms/stencil-storage-materialization.mlir new file mode 100644 index 0000000000..e3abab8238 --- /dev/null +++ b/tests/filecheck/transforms/stencil-storage-materialization.mlir @@ -0,0 +1,100 @@ +// RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s + +// This should not change with the pass applied. + +builtin.module{ + func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %outt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + +// CHECK: func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { +// CHECK-NEXT: %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp +// CHECK-NEXT: %outt = "stencil.apply"(%int) ({ +// CHECK-NEXT: ^0(%inb : !stencil.temp): +// CHECK-NEXT: %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: } + + // Here we want to see a buffer added after the first apply. 
+ + func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %midt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + %outt = "stencil.apply"(%midt) ({ + ^0(%midb : !stencil.temp): + %v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + + //CHECK: func.func @buffer_copy(%in_1 : !stencil.field<[-4,68]xf64>, %out_1 : !stencil.field<[-4,68]xf64>) { + //CHECK-NEXT: %int_1 = "stencil.load"(%in_1) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + //CHECK-NEXT: %midt = "stencil.apply"(%int_1) ({ + //CHECK-NEXT: ^1(%0 : !stencil.temp): + //CHECK-NEXT: %1 = "stencil.access"(%0) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + //CHECK-NEXT: "stencil.return"(%1) : (f64) -> () + //CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: %midt_1 = "stencil.buffer"(%midt) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: %outt_1 = "stencil.apply"(%midt_1) ({ + //CHECK-NEXT: ^2(%midb : !stencil.temp): + //CHECK-NEXT: %v_1 = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + //CHECK-NEXT: "stencil.return"(%v_1) : (f64) -> () + //CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: "stencil.store"(%outt_1, %out_1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + //CHECK-NEXT: "func.return"() : () -> () + //CHECK-NEXT: } + + // Here we don't want to see a buffer added after the apply, because the 
result is stored. + + func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %midt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%midt, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + %outt = "stencil.apply"(%midt) ({ + ^0(%midb : !stencil.temp): + %v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + +// CHECK: func.func @stored_copy(%in_2 : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out_2 : !stencil.field<[-4,68]xf64>) { +// CHECK-NEXT: %int_2 = "stencil.load"(%in_2) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp +// CHECK-NEXT: %midt_2 = "stencil.apply"(%int_2) ({ +// CHECK-NEXT: ^3(%inb_1 : !stencil.temp): +// CHECK-NEXT: %v_2 = "stencil.access"(%inb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v_2) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%midt_2, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: %outt_2 = "stencil.apply"(%midt_2) ({ +// CHECK-NEXT: ^4(%midb_1 : !stencil.temp): +// CHECK-NEXT: %v_3 = "stencil.access"(%midb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v_3) : (f64) -> () +// CHECK-NEXT: }) : 
(!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%outt_2, %out_2) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: } +} + +// CHECK: } diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index c21390f6fa..681d6acf05 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -13,7 +13,16 @@ ) from xdsl.dialects import memref -from xdsl.ir import OpResult, SSAValue, Operation, Attribute, Dialect, TypeAttribute +from xdsl.ir import ( + Block, + OpResult, + Region, + SSAValue, + Operation, + Attribute, + Dialect, + TypeAttribute, +) from xdsl.irdl import ( attr_def, irdl_attr_definition, @@ -580,6 +589,10 @@ class BufferOp(IRDLOperation): temp: Operand = operand_def(TempType) res: OpResult = result_def(TempType) + def __init__(self: IRDLOperation, temp: SSAValue | Operation): + temp = SSAValue.get(temp) + super().__init__(operands=[temp], result_types=[temp.typ]) + def verify_(self) -> None: if self.temp.typ != self.res.typ: raise VerifyException( diff --git a/xdsl/transforms/experimental/StencilStorageMaterialization.py b/xdsl/transforms/experimental/StencilStorageMaterialization.py new file mode 100644 index 0000000000..d952bab2e0 --- /dev/null +++ b/xdsl/transforms/experimental/StencilStorageMaterialization.py @@ -0,0 +1,68 @@ +from xdsl.dialects import builtin +from xdsl.dialects.stencil import ( + ApplyOp, + BufferOp, + StoreOp, +) + +from xdsl.ir import MLContext, SSAValue +from xdsl.ir.core import OpResult +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) + + +def should_materialize(temp: SSAValue): + """ + Return whether a specific stencil.apply output should be buffered. + True iff another stencil.apply uses it and no stencil.store/buffer already does.
+ """ + return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any( + isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses + ) + + +class ApplyOpMaterialization(RewritePattern): + """ + Adds stencil.buffer to any used output of a stencil.apply that is not otherwise + mapped to storage. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): + clone = op.clone() + new_res: list[OpResult] = [] + buffers: list[BufferOp] = [] + for i, out in enumerate(op.res): + if should_materialize(out): + buffer = BufferOp(clone.res[i]) + buffers.append(buffer) + new_res.append(buffer.res) + else: + new_res.append(clone.res[i])  # op is erased; use the clone's result + if buffers: + rewriter.replace_matched_op([clone, *buffers], new_res) + + +class StencilStorageMaterializationPass(ModulePass): + """ + Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR, + by adding stencil.buffer on any used stencil.apply output not otherwise mapped + to storage. + """ + + name = "stencil-storage-materialization" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + ApplyOpMaterialization(), + ] + ) + ).rewrite_module(op) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index e0880cbef3..19614dcee3 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -37,6 +37,9 @@ from xdsl.frontend.passes.desymref import DesymrefyPass from xdsl.transforms.dead_code_elimination import DeadCodeElimination +from xdsl.transforms.experimental.StencilStorageMaterialization import ( + StencilStorageMaterializationPass, +) from xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation from xdsl.transforms.lower_riscv_func import LowerRISCVFunc from xdsl.transforms.lower_mpi import LowerMPIPass @@ -106,6 +109,7 @@ def get_all_passes() -> list[type[ModulePass]]: LowerSnitchRuntimePass, RISCVRegisterAllocation, StencilShapeInferencePass, +
StencilStorageMaterializationPass, ]