Skip to content

Commit

Permalink
transformations: Add stencil-storage-materialization and first tests. (
Browse files Browse the repository at this point in the history
…#1111)

Add the stencil-storage-materialization pass first implementation.

The idea is to add `stencil.buffer` on temporary values used and not
otherwise mapped to storage.
This allows to lower chains of apply which are just chained by direct
usage of another's result; can be a nice baseline to compare inlining
to.
  • Loading branch information
PapyChacal committed Jun 14, 2023
1 parent fdbe68a commit 83fedee
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 1 deletion.
100 changes: 100 additions & 0 deletions tests/filecheck/transforms/stencil-storage-materialization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s

// This should not change with the pass applied.

builtin.module{
func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%outt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

// CHECK: func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
// CHECK-NEXT: %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: %outt = "stencil.apply"(%int) ({
// CHECK-NEXT: ^0(%inb : !stencil.temp<?xf64>):
// CHECK-NEXT: %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }

// Here we want to see a buffer added after the first apply.

func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%midt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
%outt = "stencil.apply"(%midt) ({
^0(%midb : !stencil.temp<?xf64>):
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

//CHECK: func.func @buffer_copy(%in_1 : !stencil.field<[-4,68]xf64>, %out_1 : !stencil.field<[-4,68]xf64>) {
//CHECK-NEXT: %int_1 = "stencil.load"(%in_1) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %midt = "stencil.apply"(%int_1) ({
//CHECK-NEXT: ^1(%0 : !stencil.temp<?xf64>):
//CHECK-NEXT: %1 = "stencil.access"(%0) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
//CHECK-NEXT: "stencil.return"(%1) : (f64) -> ()
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %midt_1 = "stencil.buffer"(%midt) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %outt_1 = "stencil.apply"(%midt_1) ({
//CHECK-NEXT: ^2(%midb : !stencil.temp<?xf64>):
//CHECK-NEXT: %v_1 = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
//CHECK-NEXT: "stencil.return"(%v_1) : (f64) -> ()
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: "stencil.store"(%outt_1, %out_1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }

// Here we don't want to see a buffer added after the apply, because the result is stored.

func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%midt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%midt, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
%outt = "stencil.apply"(%midt) ({
^0(%midb : !stencil.temp<?xf64>):
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

// CHECK: func.func @stored_copy(%in_2 : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out_2 : !stencil.field<[-4,68]xf64>) {
// CHECK-NEXT: %int_2 = "stencil.load"(%in_2) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: %midt_2 = "stencil.apply"(%int_2) ({
// CHECK-NEXT: ^3(%inb_1 : !stencil.temp<?xf64>):
// CHECK-NEXT: %v_2 = "stencil.access"(%inb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v_2) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%midt_2, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: %outt_2 = "stencil.apply"(%midt_2) ({
// CHECK-NEXT: ^4(%midb_1 : !stencil.temp<?xf64>):
// CHECK-NEXT: %v_3 = "stencil.access"(%midb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v_3) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%outt_2, %out_2) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }
}

// CHECK: }
15 changes: 14 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
)
from xdsl.dialects import memref

from xdsl.ir import OpResult, SSAValue, Operation, Attribute, Dialect, TypeAttribute
from xdsl.ir import (
Block,
OpResult,
Region,
SSAValue,
Operation,
Attribute,
Dialect,
TypeAttribute,
)
from xdsl.irdl import (
attr_def,
irdl_attr_definition,
Expand Down Expand Up @@ -580,6 +589,10 @@ class BufferOp(IRDLOperation):
temp: Operand = operand_def(TempType)
res: OpResult = result_def(TempType)

def __init__(self: IRDLOperation, temp: SSAValue | Operation):
temp = SSAValue.get(temp)
super().__init__(operands=[temp], result_types=[temp.typ])

def verify_(self) -> None:
if self.temp.typ != self.res.typ:
raise VerifyException(
Expand Down
68 changes: 68 additions & 0 deletions xdsl/transforms/experimental/StencilStorageMaterialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from xdsl.dialects import builtin
from xdsl.dialects.stencil import (
ApplyOp,
BufferOp,
StoreOp,
)

from xdsl.ir import MLContext, SSAValue
from xdsl.ir.core import OpResult
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)


def should_materialize(temp: SSAValue):
"""
Predicates if a specific stencil.apply output should be buffered.
It should if it is used by another stencil.apply and not already buffered or stored.
"""
return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any(
isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses
)


class ApplyOpMaterialization(RewritePattern):
"""
Adds stencil.buffer to any used output of a stencil.apply that is not otherwised
mapped to storage.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
clone = op.clone()
new_res: list[OpResult] = []
buffers: list[BufferOp] = []
for i, out in enumerate(op.res):
if should_materialize(out):
buffer = BufferOp(clone.res[i])
buffers.append(buffer)
new_res.append(buffer.res)
else:
new_res.append(out)
if buffers:
rewriter.replace_matched_op([clone, *buffers], new_res)


class StencilStorageMaterializationPass(ModulePass):
"""
Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR,
by adding stencil.buffer on any used stencil.apply output not otherwise mapped
to storage.
"""

name = "stencil-storage-materialization"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ApplyOpMaterialization(),
]
)
).rewrite_module(op)
4 changes: 4 additions & 0 deletions xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

from xdsl.frontend.passes.desymref import DesymrefyPass
from xdsl.transforms.dead_code_elimination import DeadCodeElimination
from xdsl.transforms.experimental.StencilStorageMaterialization import (
StencilStorageMaterializationPass,
)
from xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation
from xdsl.transforms.lower_riscv_func import LowerRISCVFunc
from xdsl.transforms.lower_mpi import LowerMPIPass
Expand Down Expand Up @@ -106,6 +109,7 @@ def get_all_passes() -> list[type[ModulePass]]:
LowerSnitchRuntimePass,
RISCVRegisterAllocation,
StencilShapeInferencePass,
StencilStorageMaterializationPass,
]


Expand Down

0 comments on commit 83fedee

Please sign in to comment.