From 9d6f60c61341e6ff3f821093f4c0ac0c3b8ed9d3 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Jun 2023 18:15:42 +0100 Subject: [PATCH 1/5] Add StencilSotrageMaterializationPass and first tests. --- .../stencil-storage-materialization.mlir | 94 +++++++++++++++++++ xdsl/dialects/stencil.py | 17 +++- .../StencilStorageMaterialization.py | 53 +++++++++++ xdsl/xdsl_opt_main.py | 4 + 4 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 tests/filecheck/transforms/stencil-storage-materialization.mlir create mode 100644 xdsl/transforms/experimental/StencilStorageMaterialization.py diff --git a/tests/filecheck/transforms/stencil-storage-materialization.mlir b/tests/filecheck/transforms/stencil-storage-materialization.mlir new file mode 100644 index 0000000000..056bbf76c4 --- /dev/null +++ b/tests/filecheck/transforms/stencil-storage-materialization.mlir @@ -0,0 +1,94 @@ +// RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s + +builtin.module{ + func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %outt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + +// CHECK: func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { +// CHECK-NEXT: %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp +// CHECK-NEXT: %outt = "stencil.apply"(%int) ({ +// CHECK-NEXT: ^0(%inb : !stencil.temp): +// CHECK-NEXT: %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: } + + func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %midt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + %outt = "stencil.apply"(%midt) ({ + ^0(%midb : !stencil.temp): + %v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + + //CHECK: func.func @buffer_copy(%in_1 : !stencil.field<[-4,68]xf64>, %out_1 : !stencil.field<[-4,68]xf64>) { + //CHECK-NEXT: %int_1 = "stencil.load"(%in_1) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + //CHECK-NEXT: %midt = "stencil.apply"(%int_1) ({ + //CHECK-NEXT: ^1(%0 : !stencil.temp): + //CHECK-NEXT: %1 = "stencil.access"(%0) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + //CHECK-NEXT: "stencil.return"(%1) : (f64) -> () + //CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: %midt_1 = "stencil.buffer"(%midt) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: %outt_1 = "stencil.apply"(%midt_1) ({ + //CHECK-NEXT: ^2(%midb : !stencil.temp): + //CHECK-NEXT: %v_1 = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + //CHECK-NEXT: "stencil.return"(%v_1) : (f64) -> () + //CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp + //CHECK-NEXT: "stencil.store"(%outt_1, %out_1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + //CHECK-NEXT: "func.return"() : () -> () + //CHECK-NEXT: } + + func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { + %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %midt = "stencil.apply"(%int) ({ + ^0(%inb : !stencil.temp): + %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%midt, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + %outt = "stencil.apply"(%midt) ({ + ^0(%midb : !stencil.temp): + %v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + "stencil.return"(%v) : (f64) -> () + }) : (!stencil.temp) -> !stencil.temp + "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } + +// CHECK: func.func @stored_copy(%in_2 : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out_2 : !stencil.field<[-4,68]xf64>) { +// CHECK-NEXT: %int_2 = "stencil.load"(%in_2) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp +// CHECK-NEXT: %midt_2 = "stencil.apply"(%int_2) ({ +// CHECK-NEXT: ^3(%inb_1 : !stencil.temp): +// CHECK-NEXT: %v_2 = "stencil.access"(%inb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v_2) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%midt_2, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: %outt_2 = "stencil.apply"(%midt_2) ({ +// CHECK-NEXT: ^4(%midb_1 : !stencil.temp): +// CHECK-NEXT: %v_3 = "stencil.access"(%midb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 +// CHECK-NEXT: "stencil.return"(%v_3) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: "stencil.store"(%outt_2, %out_2) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: } +} + +// CHECK: } diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index de5d2138b8..4a2fcab9ce 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -1,7 +1,7 @@ from __future__ import annotations from operator import add, lt, neg -from typing import Sequence, TypeVar, cast, Iterable, Iterator, Annotated +from typing import Mapping, Sequence, TypeVar, cast, Iterable, Iterator, Annotated from xdsl.dialects import builtin from xdsl.dialects.builtin import ( @@ -13,7 +13,16 @@ ) from xdsl.dialects import memref -from xdsl.ir import OpResult, SSAValue, Operation, Attribute, Dialect, TypeAttribute +from xdsl.ir import ( + Block, + OpResult, + Region, + SSAValue, + Operation, + Attribute, + Dialect, + TypeAttribute, +) from xdsl.irdl import ( irdl_attr_definition, irdl_op_definition, @@ -495,6 +504,10 @@ class BufferOp(IRDLOperation): temp: Annotated[Operand, TempType] res: Annotated[OpResult, TempType] + def __init__(self: IRDLOperation, temp: SSAValue | Operation): + temp = SSAValue.get(temp) + super().__init__(operands=[temp], result_types=[temp.typ]) + def verify_(self) -> None: if self.temp.typ != self.res.typ: raise VerifyException( diff --git a/xdsl/transforms/experimental/StencilStorageMaterialization.py b/xdsl/transforms/experimental/StencilStorageMaterialization.py new file mode 100644 index 0000000000..ee506f5d47 --- /dev/null +++ b/xdsl/transforms/experimental/StencilStorageMaterialization.py @@ -0,0 +1,53 @@ +from xdsl.dialects import builtin +from xdsl.dialects.stencil import ( + ApplyOp, + BufferOp, + StoreOp, +) + +from xdsl.ir import MLContext, SSAValue +from xdsl.ir.core import OpResult +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) + + +def should_materialize(temp: SSAValue): + return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any( + isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses + ) + + +class ApplyOpMaterialization(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): + clone = op.clone() + new_res: list[OpResult] = [] + buffers: list[BufferOp] = [] + for i, out in enumerate(op.res): + if should_materialize(out): + buffer = BufferOp(clone.res[i]) + buffers.append(buffer) + new_res.append(buffer.res) + else: + new_res.append(out) + if buffers: + rewriter.replace_matched_op([clone, *buffers], new_res) + + +class StencilStorageMaterializationPass(ModulePass): + name = "stencil-storage-materialization" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + ApplyOpMaterialization(), + ] + ) + ).rewrite_module(op) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 5554e0b857..4d9cc7b369 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -36,6 +36,9 @@ from xdsl.frontend.passes.desymref import DesymrefyPass from xdsl.transforms.dead_code_elimination import DeadCodeElimination +from xdsl.transforms.experimental.StencilStorageMaterialization import ( + StencilStorageMaterializationPass, +) from xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation from xdsl.transforms.lower_riscv_func import LowerRISCVFunc from xdsl.transforms.lower_mpi import LowerMPIPass @@ -104,6 +107,7 @@ def get_all_passes() -> list[type[ModulePass]]: LowerSnitchRuntimePass, RISCVRegisterAllocation, StencilShapeInferencePass, + StencilStorageMaterializationPass, ] From 45bc079ca4681b19c93553f001e800830c59a5ee Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Jun 2023 18:23:03 +0100 Subject: [PATCH 2/5] pyright --- xdsl/dialects/stencil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 4a2fcab9ce..34b464babc 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -1,7 +1,7 @@ from __future__ import annotations from operator import add, lt, neg -from typing import Mapping, Sequence, TypeVar, cast, Iterable, Iterator, Annotated +from typing import Sequence, TypeVar, cast, Iterable, Iterator, Annotated from xdsl.dialects import builtin from xdsl.dialects.builtin import ( From ef3329a679ff24d28b80ffff78c51a59b3910699 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 12 Jun 2023 18:42:40 +0100 Subject: [PATCH 3/5] Add docstrings. --- .../experimental/StencilStorageMaterialization.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/xdsl/transforms/experimental/StencilStorageMaterialization.py b/xdsl/transforms/experimental/StencilStorageMaterialization.py index ee506f5d47..472344e807 100644 --- a/xdsl/transforms/experimental/StencilStorageMaterialization.py +++ b/xdsl/transforms/experimental/StencilStorageMaterialization.py @@ -24,6 +24,11 @@ def should_materialize(temp: SSAValue): class ApplyOpMaterialization(RewritePattern): + """ + Adds stencil.buffer to any output of a stencil.apply that is not otherwised mapped + to storage. + """ + @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): clone = op.clone() @@ -41,6 +46,12 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): class StencilStorageMaterializationPass(ModulePass): + """ + Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR, + by adding stencil.buffer on any stencil.apply output not otherwise mapped + to storage. + """ + name = "stencil-storage-materialization" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: From c5136bb288e56a87c24620e4508899c93921f24b Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 12 Jun 2023 18:44:19 +0100 Subject: [PATCH 4/5] More docstrings. --- .../experimental/StencilStorageMaterialization.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xdsl/transforms/experimental/StencilStorageMaterialization.py b/xdsl/transforms/experimental/StencilStorageMaterialization.py index 472344e807..d952bab2e0 100644 --- a/xdsl/transforms/experimental/StencilStorageMaterialization.py +++ b/xdsl/transforms/experimental/StencilStorageMaterialization.py @@ -18,6 +18,10 @@ def should_materialize(temp: SSAValue): + """ + Predicates if a specific stencil.apply output should be buffered. + It should if it is used by another stencil.apply and not already buffered or stored. + """ return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any( isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses ) @@ -25,8 +29,8 @@ def should_materialize(temp: SSAValue): class ApplyOpMaterialization(RewritePattern): """ - Adds stencil.buffer to any output of a stencil.apply that is not otherwised mapped - to storage. + Adds stencil.buffer to any used output of a stencil.apply that is not otherwised + mapped to storage. """ @op_type_rewrite_pattern @@ -48,7 +52,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): class StencilStorageMaterializationPass(ModulePass): """ Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR, - by adding stencil.buffer on any stencil.apply output not otherwise mapped + by adding stencil.buffer on any used stencil.apply output not otherwise mapped to storage. """ From 46178004a1d5eb4fc544517dd48a99fb5ceb3b73 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Wed, 14 Jun 2023 15:41:01 +0100 Subject: [PATCH 5/5] Add filecheck comments. --- .../transforms/stencil-storage-materialization.mlir | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/filecheck/transforms/stencil-storage-materialization.mlir b/tests/filecheck/transforms/stencil-storage-materialization.mlir index 056bbf76c4..e3abab8238 100644 --- a/tests/filecheck/transforms/stencil-storage-materialization.mlir +++ b/tests/filecheck/transforms/stencil-storage-materialization.mlir @@ -1,5 +1,7 @@ // RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s +// This should not change with the pass applied. + builtin.module{ func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp @@ -23,6 +25,8 @@ builtin.module{ // CHECK-NEXT: "func.return"() : () -> () // CHECK-NEXT: } + // Here we want to see a buffer added after the first apply. + func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp %midt = "stencil.apply"(%int) ({ @@ -56,6 +60,8 @@ builtin.module{ //CHECK-NEXT: "func.return"() : () -> () //CHECK-NEXT: } + // Here we don't want to see a buffer added after the apply, because the result is stored. + func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp %midt = "stencil.apply"(%int) ({