From ccd926a6ac92ee65190d80271ad4908704b460b1 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Thu, 1 Jun 2023 14:12:26 +0100 Subject: [PATCH] Split ApplyOp and AccessOp shape inference patterns, to naturally handle different input offsets. Add a test case for this. Now, shape inference walks op regions first; each AccessOp offsets its stencil.temp operand size individually. stencil.apply shape inference then just spreads these bounds to its operands (from its block arguments). --- .../transforms/stencil-shape-inference.mlir | 34 ++++++++++++ .../experimental/StencilShapeInference.py | 54 +++++++++++++------ 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/tests/filecheck/transforms/stencil-shape-inference.mlir b/tests/filecheck/transforms/stencil-shape-inference.mlir index 8af15b7bf7..3ff5261dc5 100644 --- a/tests/filecheck/transforms/stencil-shape-inference.mlir +++ b/tests/filecheck/transforms/stencil-shape-inference.mlir @@ -1,5 +1,39 @@ // RUN: xdsl-opt -p stencil-shape-inference --verify-diagnostics --split-input-file %s | filecheck %s +builtin.module { + func.func @different_input_offsets(%out : !stencil.field<[-4,68]xf64>, %left : !stencil.field<[-4,68]xf64>, %right : !stencil.field<[-4,68]xf64>) { + %tleft = "stencil.load"(%left) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %tright = "stencil.load"(%right) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp + %tout = "stencil.apply"(%tleft, %tright) ({ + ^0(%lefti : !stencil.temp, %righti : !stencil.temp): + %l = "stencil.access"(%lefti) {"offset" = #stencil.index<-1>} : (!stencil.temp) -> f64 + %r = "stencil.access"(%righti) {"offset" = #stencil.index< 1>} : (!stencil.temp) -> f64 + %o = arith.addf %l, %r : f64 + "stencil.return"(%o) : (f64) -> () + }) : (!stencil.temp, !stencil.temp) -> !stencil.temp + "stencil.store"(%tout, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<64>} : (!stencil.temp, !stencil.field<[-4,68]xf64>) -> () + "func.return"() : () -> () + } +} + +// CHECK: 
builtin.module { +// CHECK-NEXT: func.func @different_input_offsets(%out : !stencil.field<[-4,68]xf64>, %left : !stencil.field<[-4,68]xf64>, %right : !stencil.field<[-4,68]xf64>) { +// CHECK-NEXT: %tleft = "stencil.load"(%left) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<[-1,63]xf64> +// CHECK-NEXT: %tright = "stencil.load"(%right) : (!stencil.field<[-4,68]xf64>) -> !stencil.temp<[1,65]xf64> +// CHECK-NEXT: %tout = "stencil.apply"(%tleft, %tright) ({ +// CHECK-NEXT: ^0(%lefti : !stencil.temp<[-1,63]xf64>, %righti : !stencil.temp<[1,65]xf64>): +// CHECK-NEXT: %l = "stencil.access"(%lefti) {"offset" = #stencil.index<-1>} : (!stencil.temp<[-1,63]xf64>) -> f64 +// CHECK-NEXT: %r = "stencil.access"(%righti) {"offset" = #stencil.index<1>} : (!stencil.temp<[1,65]xf64>) -> f64 +// CHECK-NEXT: %o = arith.addf %l, %r : f64 +// CHECK-NEXT: "stencil.return"(%o) : (f64) -> () +// CHECK-NEXT: }) : (!stencil.temp<[-1,63]xf64>, !stencil.temp<[1,65]xf64>) -> !stencil.temp<[0,64]xf64> +// CHECK-NEXT: "stencil.store"(%tout, %out) {"lb" = #stencil.index<0>, "ub" = #stencil.index<64>} : (!stencil.temp<[0,64]xf64>, !stencil.field<[-4,68]xf64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + builtin.module { func.func @stencil_hdiff(%0 : !stencil.field, %1 : !stencil.field) { %2 = "stencil.cast"(%0) : (!stencil.field) -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> diff --git a/xdsl/transforms/experimental/StencilShapeInference.py b/xdsl/transforms/experimental/StencilShapeInference.py index ed6735c850..452f676e2b 100644 --- a/xdsl/transforms/experimental/StencilShapeInference.py +++ b/xdsl/transforms/experimental/StencilShapeInference.py @@ -11,7 +11,7 @@ TempType, ) -from xdsl.ir import Attribute, MLContext, Operation, SSAValue +from xdsl.ir import Attribute, BlockArgument, MLContext, Operation, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -84,6 +84,35 @@ def 
match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /): op.temp.typ = TempType(tuple(zip(temp_lb, temp_ub)), temp.element_type) +class AccessOpShapeInference(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter): + apply = op.parent_op() + assert isinstance(apply, ApplyOp) + assert isa(op.temp.typ, TempType[Attribute]) + assert isinstance(op.temp, BlockArgument) + assert op.temp.block.parent_op() is apply + assert isa(apply.res[0].typ, TempType[Attribute]), f"{apply.res[0]}" + + temp_typ = op.temp.typ + temp_lb = None + temp_ub = None + if isinstance(temp_typ.bounds, StencilBoundsAttr): + temp_lb = temp_typ.bounds.lb + temp_ub = temp_typ.bounds.ub + output_size = apply.res[0].typ.bounds + print(f"output_size: {output_size}") + assert isinstance(output_size, StencilBoundsAttr) + + lb = IndexAttr.min(output_size.lb + op.offset, temp_lb) + ub = IndexAttr.max(output_size.ub + op.offset, temp_ub) + ntyp = TempType(tuple(zip(lb, ub)), temp_typ.element_type) + + print(f"new access size: {ntyp.bounds}") + + op.temp.typ = ntyp + + class ApplyOpShapeInference(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): @@ -95,28 +124,16 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): ntyp = res_typ assert isinstance(ntyp.bounds, StencilBoundsAttr) - accesses = [a for a in op.walk() if isinstance(a, AccessOp)] - if not accesses: - return - for access in accesses: - temp = access.temp.typ - assert isa(temp, TempType[Attribute]) - - lb = IndexAttr.min(res_typ.bounds.lb + access.offset, ntyp.bounds.lb) - ub = IndexAttr.max(res_typ.bounds.ub + access.offset, ntyp.bounds.ub) - ntyp = TempType(tuple(zip(lb, ub)), temp.element_type) - assert isinstance(ntyp.bounds, StencilBoundsAttr) - - for i, arg in enumerate(op.args): + for i, arg in enumerate(op.region.block.args): if not isa(arg.typ, TempType[Attribute]): continue - arg.typ = 
ntyp - op.region.block.args[i].typ = ntyp + op.operands[i].typ = arg.typ ShapeInference = GreedyRewritePatternApplier( [ ApplyOpShapeInference(), + AccessOpShapeInference(), LoadOpShapeInference(), StoreOpShapeInference(), ] @@ -128,6 +145,9 @@ class StencilShapeInferencePass(ModulePass): def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: inference_walker = PatternRewriteWalker( - ShapeInference, apply_recursively=False, walk_reverse=True + ShapeInference, + apply_recursively=False, + walk_reverse=True, + walk_regions_first=True, ) inference_walker.rewrite_module(op)