Skip to content

Commit

Permalink
Split ApplyOp and AccessOp shape inference patterns, to naturally han…
Browse files Browse the repository at this point in the history
…dle different input offsets.

Add a test case for this.
Now, shape inference walks ops' regions first; each AccessOp infers its stencil.temp operand's size individually.
stencil.apply shape inference then just spreads these bounds to its operands (from its block arguments).
  • Loading branch information
PapyChacal committed Jun 1, 2023
1 parent 4343a8f commit ccd926a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 17 deletions.
34 changes: 34 additions & 0 deletions tests/filecheck/transforms/stencil-shape-inference.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,39 @@
// RUN: xdsl-opt -p stencil-shape-inference --verify-diagnostics --split-input-file %s | filecheck %s

// Regression test: two inputs accessed at different offsets (-1 and +1) must
// get different inferred bounds ([-1,63] vs [1,65]) rather than a shared one,
// while the apply's output keeps the stored range [0,64).
builtin.module {
func.func @different_input_offsets(%out : !stencil.field<[-4,68]xf64>, %left : !stencil.field<[-4,68]xf64>, %right : !stencil.field<[-4,68]xf64>) {
%tleft = "stencil.load"(%left) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%tright = "stencil.load"(%right) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<?xf64>
%tout = "stencil.apply"(%tleft, %tright) ({
^0(%lefti : !stencil.temp<?xf64>, %righti : !stencil.temp<?xf64>):
%l = "stencil.access"(%lefti) {"offset" = #stencil.index<-1>} : (!stencil.temp<?xf64>) -> f64
%r = "stencil.access"(%righti) {"offset" = #stencil.index< 1>} : (!stencil.temp<?xf64>) -> f64
%o = arith.addf %l, %r : f64
"stencil.return"(%o) : (f64) -> ()
}) : (!stencil.temp<?xf64>, !stencil.temp<?xf64>) -> !stencil.temp<?xf64>
"stencil.store"(%tout, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<64>} : (!stencil.temp<?xf64>, !stencil.field<[-4,68]xf64>) -> ()
"func.return"() : () -> ()
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func @different_input_offsets(%out : !stencil.field<[-4,68]xf64>, %left : !stencil.field<[-4,68]xf64>, %right : !stencil.field<[-4,68]xf64>) {
// CHECK-NEXT: %tleft = "stencil.load"(%left) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<[-1,63]xf64>
// CHECK-NEXT: %tright = "stencil.load"(%right) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<[1,65]xf64>
// CHECK-NEXT: %tout = "stencil.apply"(%tleft, %tright) ({
// CHECK-NEXT: ^0(%lefti : !stencil.temp<[-1,63]xf64>, %righti : !stencil.temp<[1,65]xf64>):
// CHECK-NEXT: %l = "stencil.access"(%lefti) {"offset" = #stencil.index<-1>} : (!stencil.temp<[-1,63]xf64>) -> f64
// CHECK-NEXT: %r = "stencil.access"(%righti) {"offset" = #stencil.index<1>} : (!stencil.temp<[1,65]xf64>) -> f64
// CHECK-NEXT: %o = arith.addf %l, %r : f64
// CHECK-NEXT: "stencil.return"(%o) : (f64) -> ()
// CHECK-NEXT: }) : (!stencil.temp<[-1,63]xf64>, !stencil.temp<[1,65]xf64>) -> !stencil.temp<[0,64]xf64>
// CHECK-NEXT: "stencil.store"(%tout, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<64>} : (!stencil.temp<[0,64]xf64>, !stencil.field<[-4,68]xf64>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }

// -----

builtin.module {
func.func @stencil_hdiff(%0 : !stencil.field<?x?x?xf64>, %1 : !stencil.field<?x?x?xf64>) {
%2 = "stencil.cast"(%0) : (!stencil.field<?x?x?xf64>) -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
Expand Down
54 changes: 37 additions & 17 deletions xdsl/transforms/experimental/StencilShapeInference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TempType,
)

from xdsl.ir import Attribute, MLContext, Operation, SSAValue
from xdsl.ir import Attribute, BlockArgument, MLContext, Operation, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -84,6 +84,35 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
op.temp.typ = TempType(tuple(zip(temp_lb, temp_ub)), temp.element_type)


class AccessOpShapeInference(RewritePattern):
    """Shape inference for stencil.access.

    Grows the bounds of the accessed stencil.temp block argument so it
    covers the enclosing stencil.apply's output bounds, shifted by this
    access's offset, unioned with any bounds already inferred for it.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter):
        # An access must live directly inside a stencil.apply and read one of
        # its block arguments; the apply's result bounds must already be known
        # (the walker visits regions first, after StoreOp inference ran).
        apply = op.parent_op()
        assert isinstance(apply, ApplyOp)
        assert isa(op.temp.typ, TempType[Attribute])
        assert isinstance(op.temp, BlockArgument)
        assert op.temp.block.parent_op() is apply
        assert isa(apply.res[0].typ, TempType[Attribute]), f"{apply.res[0]}"

        temp_typ = op.temp.typ
        # Start from the operand's current bounds, if any were inferred by a
        # previously visited access; IndexAttr.min/max accept None as "no
        # bound yet".
        temp_lb = None
        temp_ub = None
        if isinstance(temp_typ.bounds, StencilBoundsAttr):
            temp_lb = temp_typ.bounds.lb
            temp_ub = temp_typ.bounds.ub
        output_size = apply.res[0].typ.bounds
        assert isinstance(output_size, StencilBoundsAttr)

        # The accessed value at `offset` for every output point means the
        # operand must cover the output bounds translated by `offset`.
        lb = IndexAttr.min(output_size.lb + op.offset, temp_lb)
        ub = IndexAttr.max(output_size.ub + op.offset, temp_ub)
        ntyp = TempType(tuple(zip(lb, ub)), temp_typ.element_type)

        # NOTE: removed leftover debug print() calls — a rewrite pattern must
        # not write to stdout; use logging if tracing is ever needed.
        op.temp.typ = ntyp


class ApplyOpShapeInference(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
Expand All @@ -95,28 +124,16 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
ntyp = res_typ
assert isinstance(ntyp.bounds, StencilBoundsAttr)

accesses = [a for a in op.walk() if isinstance(a, AccessOp)]
if not accesses:
return
for access in accesses:
temp = access.temp.typ
assert isa(temp, TempType[Attribute])

lb = IndexAttr.min(res_typ.bounds.lb + access.offset, ntyp.bounds.lb)
ub = IndexAttr.max(res_typ.bounds.ub + access.offset, ntyp.bounds.ub)
ntyp = TempType(tuple(zip(lb, ub)), temp.element_type)
assert isinstance(ntyp.bounds, StencilBoundsAttr)

for i, arg in enumerate(op.args):
for i, arg in enumerate(op.region.block.args):
if not isa(arg.typ, TempType[Attribute]):
continue
arg.typ = ntyp
op.region.block.args[i].typ = ntyp
op.operands[i].typ = arg.typ


ShapeInference = GreedyRewritePatternApplier(
[
ApplyOpShapeInference(),
AccessOpShapeInference(),
LoadOpShapeInference(),
StoreOpShapeInference(),
]
Expand All @@ -128,6 +145,9 @@ class StencilShapeInferencePass(ModulePass):

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
inference_walker = PatternRewriteWalker(
ShapeInference, apply_recursively=False, walk_reverse=True
ShapeInference,
apply_recursively=False,
walk_reverse=True,
walk_regions_first=True,
)
inference_walker.rewrite_module(op)

0 comments on commit ccd926a

Please sign in to comment.