Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Add stencil-storage-materialization and first tests. #1111

Merged
merged 8 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/filecheck/transforms/stencil-storage-materialization.mlir
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s

builtin.module{
func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%outt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

// CHECK: func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
// CHECK-NEXT: %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: %outt = "stencil.apply"(%int) ({
// CHECK-NEXT: ^0(%inb : !stencil.temp<?xf64>):
// CHECK-NEXT: %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }

func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%midt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
%outt = "stencil.apply"(%midt) ({
^0(%midb : !stencil.temp<?xf64>):
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

//CHECK: func.func @buffer_copy(%in_1 : !stencil.field<[-4,68]xf64>, %out_1 : !stencil.field<[-4,68]xf64>) {
//CHECK-NEXT: %int_1 = "stencil.load"(%in_1) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %midt = "stencil.apply"(%int_1) ({
//CHECK-NEXT: ^1(%0 : !stencil.temp<?xf64>):
//CHECK-NEXT: %1 = "stencil.access"(%0) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
//CHECK-NEXT: "stencil.return"(%1) : (f64) -> ()
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %midt_1 = "stencil.buffer"(%midt) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: %outt_1 = "stencil.apply"(%midt_1) ({
//CHECK-NEXT: ^2(%midb : !stencil.temp<?xf64>):
//CHECK-NEXT: %v_1 = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
//CHECK-NEXT: "stencil.return"(%v_1) : (f64) -> ()
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
//CHECK-NEXT: "stencil.store"(%outt_1, %out_1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }

func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) {
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%midt = "stencil.apply"(%int) ({
^0(%inb : !stencil.temp<?xf64>):
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%midt, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
%outt = "stencil.apply"(%midt) ({
^0(%midb : !stencil.temp<?xf64>):
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
"stencil.return"(%v) : (f64) -> ()
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}

// CHECK: func.func @stored_copy(%in_2 : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out_2 : !stencil.field<[-4,68]xf64>) {
// CHECK-NEXT: %int_2 = "stencil.load"(%in_2) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: %midt_2 = "stencil.apply"(%int_2) ({
// CHECK-NEXT: ^3(%inb_1 : !stencil.temp<?xf64>):
// CHECK-NEXT: %v_2 = "stencil.access"(%inb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v_2) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%midt_2, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: %outt_2 = "stencil.apply"(%midt_2) ({
// CHECK-NEXT: ^4(%midb_1 : !stencil.temp<?xf64>):
// CHECK-NEXT: %v_3 = "stencil.access"(%midb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
// CHECK-NEXT: "stencil.return"(%v_3) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64>
// CHECK-NEXT: "stencil.store"(%outt_2, %out_2) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }
}

// CHECK: }
15 changes: 14 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
)
from xdsl.dialects import memref

from xdsl.ir import OpResult, SSAValue, Operation, Attribute, Dialect, TypeAttribute
from xdsl.ir import (
Block,
OpResult,
Region,
SSAValue,
Operation,
Attribute,
Dialect,
TypeAttribute,
)
from xdsl.irdl import (
attr_def,
irdl_attr_definition,
Expand Down Expand Up @@ -572,6 +581,10 @@ class BufferOp(IRDLOperation):
temp: Operand = operand_def(TempType)
res: OpResult = result_def(TempType)

def __init__(self: IRDLOperation, temp: SSAValue | Operation):
temp = SSAValue.get(temp)
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(operands=[temp], result_types=[temp.typ])

def verify_(self) -> None:
if self.temp.typ != self.res.typ:
raise VerifyException(
Expand Down
68 changes: 68 additions & 0 deletions xdsl/transforms/experimental/StencilStorageMaterialization.py
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from xdsl.dialects import builtin
from xdsl.dialects.stencil import (
ApplyOp,
BufferOp,
StoreOp,
)

from xdsl.ir import MLContext, SSAValue
from xdsl.ir.core import OpResult
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)


def should_materialize(temp: SSAValue):
"""
Predicates if a specific stencil.apply output should be buffered.
It should if it is used by another stencil.apply and not already buffered or stored.
"""
return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any(
isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses
)


class ApplyOpMaterialization(RewritePattern):
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
"""
Adds stencil.buffer to any used output of a stencil.apply that is not otherwised
mapped to storage.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
clone = op.clone()
new_res: list[OpResult] = []
buffers: list[BufferOp] = []
for i, out in enumerate(op.res):
if should_materialize(out):
buffer = BufferOp(clone.res[i])
buffers.append(buffer)
new_res.append(buffer.res)
else:
new_res.append(out)
if buffers:
rewriter.replace_matched_op([clone, *buffers], new_res)


class StencilStorageMaterializationPass(ModulePass):
"""
Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR,
by adding stencil.buffer on any used stencil.apply output not otherwise mapped
to storage.
"""

name = "stencil-storage-materialization"
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ApplyOpMaterialization(),
]
)
).rewrite_module(op)
4 changes: 4 additions & 0 deletions xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

from xdsl.frontend.passes.desymref import DesymrefyPass
from xdsl.transforms.dead_code_elimination import DeadCodeElimination
from xdsl.transforms.experimental.StencilStorageMaterialization import (
StencilStorageMaterializationPass,
)
from xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation
from xdsl.transforms.lower_riscv_func import LowerRISCVFunc
from xdsl.transforms.lower_mpi import LowerMPIPass
Expand Down Expand Up @@ -106,6 +109,7 @@ def get_all_passes() -> list[type[ModulePass]]:
LowerSnitchRuntimePass,
RISCVRegisterAllocation,
StencilShapeInferencePass,
StencilStorageMaterializationPass,
]


Expand Down