-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transformations: Add stencil-storage-materialization and first tests. (…
…#1111) Add the stencil-storage-materialization pass first implementation. The idea is to add `stencil.buffer` on temporary values used and not otherwise mapped to storage. This allows to lower chains of apply which are just chained by direct usage of another's result; can be a nice baseline to compare inlining to.
- Loading branch information
1 parent
fdbe68a
commit 83fedee
Showing
4 changed files
with
186 additions
and
1 deletion.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
tests/filecheck/transforms/stencil-storage-materialization.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// RUN: xdsl-opt %s -p stencil-storage-materialization | filecheck %s | ||
|
||
// This should not change with the pass applied. | ||
|
||
builtin.module{ | ||
func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { | ||
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
%outt = "stencil.apply"(%int) ({ | ||
^0(%inb : !stencil.temp<?xf64>): | ||
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
"stencil.return"(%v) : (f64) -> () | ||
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
"func.return"() : () -> () | ||
} | ||
|
||
// CHECK: func.func @copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { | ||
// CHECK-NEXT: %int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
// CHECK-NEXT: %outt = "stencil.apply"(%int) ({ | ||
// CHECK-NEXT: ^0(%inb : !stencil.temp<?xf64>): | ||
// CHECK-NEXT: %v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
// CHECK-NEXT: "stencil.return"(%v) : (f64) -> () | ||
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
// CHECK-NEXT: "stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
// CHECK-NEXT: "func.return"() : () -> () | ||
// CHECK-NEXT: } | ||
|
||
// Here we want to see a buffer added after the first apply. | ||
|
||
func.func @buffer_copy(%in : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { | ||
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
%midt = "stencil.apply"(%int) ({ | ||
^0(%inb : !stencil.temp<?xf64>): | ||
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
"stencil.return"(%v) : (f64) -> () | ||
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
%outt = "stencil.apply"(%midt) ({ | ||
^0(%midb : !stencil.temp<?xf64>): | ||
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
"stencil.return"(%v) : (f64) -> () | ||
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
"func.return"() : () -> () | ||
} | ||
|
||
//CHECK: func.func @buffer_copy(%in_1 : !stencil.field<[-4,68]xf64>, %out_1 : !stencil.field<[-4,68]xf64>) { | ||
//CHECK-NEXT: %int_1 = "stencil.load"(%in_1) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
//CHECK-NEXT: %midt = "stencil.apply"(%int_1) ({ | ||
//CHECK-NEXT: ^1(%0 : !stencil.temp<?xf64>): | ||
//CHECK-NEXT: %1 = "stencil.access"(%0) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
//CHECK-NEXT: "stencil.return"(%1) : (f64) -> () | ||
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
//CHECK-NEXT: %midt_1 = "stencil.buffer"(%midt) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
//CHECK-NEXT: %outt_1 = "stencil.apply"(%midt_1) ({ | ||
//CHECK-NEXT: ^2(%midb : !stencil.temp<?xf64>): | ||
//CHECK-NEXT: %v_1 = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
//CHECK-NEXT: "stencil.return"(%v_1) : (f64) -> () | ||
//CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
//CHECK-NEXT: "stencil.store"(%outt_1, %out_1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
//CHECK-NEXT: "func.return"() : () -> () | ||
//CHECK-NEXT: } | ||
|
||
// Here we don't want to see a buffer added after the apply, because the result is stored. | ||
|
||
func.func @stored_copy(%in : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out : !stencil.field<[-4,68]xf64>) { | ||
%int = "stencil.load"(%in) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
%midt = "stencil.apply"(%int) ({ | ||
^0(%inb : !stencil.temp<?xf64>): | ||
%v = "stencil.access"(%inb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
"stencil.return"(%v) : (f64) -> () | ||
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
"stencil.store"(%midt, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
%outt = "stencil.apply"(%midt) ({ | ||
^0(%midb : !stencil.temp<?xf64>): | ||
%v = "stencil.access"(%midb) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
"stencil.return"(%v) : (f64) -> () | ||
}) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
"stencil.store"(%outt, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
"func.return"() : () -> () | ||
} | ||
|
||
// CHECK: func.func @stored_copy(%in_2 : !stencil.field<[-4,68]xf64>, %midout : !stencil.field<[-4,68]xf64>, %out_2 : !stencil.field<[-4,68]xf64>) { | ||
// CHECK-NEXT: %int_2 = "stencil.load"(%in_2) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64> | ||
// CHECK-NEXT: %midt_2 = "stencil.apply"(%int_2) ({ | ||
// CHECK-NEXT: ^3(%inb_1 : !stencil.temp<?xf64>): | ||
// CHECK-NEXT: %v_2 = "stencil.access"(%inb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
// CHECK-NEXT: "stencil.return"(%v_2) : (f64) -> () | ||
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
// CHECK-NEXT: "stencil.store"(%midt_2, %midout) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
// CHECK-NEXT: %outt_2 = "stencil.apply"(%midt_2) ({ | ||
// CHECK-NEXT: ^4(%midb_1 : !stencil.temp<?xf64>): | ||
// CHECK-NEXT: %v_3 = "stencil.access"(%midb_1) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64 | ||
// CHECK-NEXT: "stencil.return"(%v_3) : (f64) -> () | ||
// CHECK-NEXT: }) : (!stencil.temp<?xf64>) -> !stencil.temp<?xf64> | ||
// CHECK-NEXT: "stencil.store"(%outt_2, %out_2) {"lb" = #stencil.index<0>, "ub" = #stencil.index<68>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> () | ||
// CHECK-NEXT: "func.return"() : () -> () | ||
// CHECK-NEXT: } | ||
} | ||
|
||
// CHECK: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
xdsl/transforms/experimental/StencilStorageMaterialization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from xdsl.dialects import builtin | ||
from xdsl.dialects.stencil import ( | ||
ApplyOp, | ||
BufferOp, | ||
StoreOp, | ||
) | ||
|
||
from xdsl.ir import MLContext, SSAValue | ||
from xdsl.ir.core import OpResult | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
GreedyRewritePatternApplier, | ||
PatternRewriteWalker, | ||
PatternRewriter, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
|
||
|
||
def should_materialize(temp: SSAValue): | ||
""" | ||
Predicates if a specific stencil.apply output should be buffered. | ||
It should if it is used by another stencil.apply and not already buffered or stored. | ||
""" | ||
return any(isinstance(u.operation, ApplyOp) for u in temp.uses) and not any( | ||
isinstance(u.operation, StoreOp | BufferOp) for u in temp.uses | ||
) | ||
|
||
|
||
class ApplyOpMaterialization(RewritePattern): | ||
""" | ||
Adds stencil.buffer to any used output of a stencil.apply that is not otherwised | ||
mapped to storage. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): | ||
clone = op.clone() | ||
new_res: list[OpResult] = [] | ||
buffers: list[BufferOp] = [] | ||
for i, out in enumerate(op.res): | ||
if should_materialize(out): | ||
buffer = BufferOp(clone.res[i]) | ||
buffers.append(buffer) | ||
new_res.append(buffer.res) | ||
else: | ||
new_res.append(out) | ||
if buffers: | ||
rewriter.replace_matched_op([clone, *buffers], new_res) | ||
|
||
|
||
class StencilStorageMaterializationPass(ModulePass): | ||
""" | ||
Pass adding stencil.buffer whenever necessary to lower a stencil dialect IR, | ||
by adding stencil.buffer on any used stencil.apply output not otherwise mapped | ||
to storage. | ||
""" | ||
|
||
name = "stencil-storage-materialization" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
PatternRewriteWalker( | ||
GreedyRewritePatternApplier( | ||
[ | ||
ApplyOpMaterialization(), | ||
] | ||
) | ||
).rewrite_module(op) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters