Skip to content

Commit bd7de6d

Browse files
bondhugulatensorflower-gardener
authored andcommitted
Add rewrite pattern to compose maps into affine load/stores
- add canonicalization pattern to compose maps into affine loads/stores; templatize the pattern and reuse it for affine.apply as well - rename getIndices -> getMapOperands() (getIndices is confusing since these are no longer the indices themselves but operands to the map whose results are the indices). This also makes the accessor uniform across affine.apply/load/store. Change arg names on the affine load/store builder to avoid confusion. Drop an unused confusing build method on AffineStoreOp. - update incomplete doc comment for canonicalizeMapAndOperands (this was missed from a previous update). Addresses issue tensorflow/mlir#121 Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#122 COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#122 from bondhugula:compose-load-store e71de1771e56a85c4282c10cb43f30cef0701c4f PiperOrigin-RevId: 269619540
1 parent 62e1faa commit bd7de6d

File tree

8 files changed

+130
-51
lines changed

8 files changed

+130
-51
lines changed

mlir/include/mlir/Dialect/AffineOps/AffineOps.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
8383

8484
static StringRef getOperationName() { return "affine.apply"; }
8585

86+
operand_range getMapOperands() { return getOperands(); }
87+
8688
// Hooks to customize behavior of this op.
8789
static ParseResult parse(OpAsmParser *parser, OperationState *result);
8890
void print(OpAsmPrinter *p);
@@ -400,9 +402,12 @@ class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
400402
/// Builds an affine load op with the specified map and operands.
401403
static void build(Builder *builder, OperationState *result, AffineMap map,
402404
ArrayRef<Value *> operands);
403-
/// Builds an affine load op an identify map and operands.
405+
/// Builds an affine load op with an identity map and operands.
404406
static void build(Builder *builder, OperationState *result, Value *memref,
405407
ArrayRef<Value *> indices = {});
408+
/// Builds an affine load op with the specified map and its operands.
409+
static void build(Builder *builder, OperationState *result, Value *memref,
410+
AffineMap map, ArrayRef<Value *> mapOperands);
406411

407412
/// Returns the operand index of the memref.
408413
unsigned getMemRefOperandIndex() { return 0; }
@@ -415,7 +420,7 @@ class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
415420
}
416421

417422
/// Get affine map operands.
418-
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
423+
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
419424

420425
/// Returns the affine map used to index the memref for this operation.
421426
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
@@ -462,14 +467,14 @@ class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
462467
public:
463468
using Op::Op;
464469

465-
/// Builds an affine store operation with the specified map and operands.
466-
static void build(Builder *builder, OperationState *result,
467-
Value *valueToStore, AffineMap map,
468-
ArrayRef<Value *> operands);
469-
/// Builds an affine store operation with an identity map and operands.
470+
/// Builds an affine store operation with the provided indices (identity map).
470471
static void build(Builder *builder, OperationState *result,
471472
Value *valueToStore, Value *memref,
472-
ArrayRef<Value *> operands);
473+
ArrayRef<Value *> indices);
474+
/// Builds an affine store operation with the specified map and its operands.
475+
static void build(Builder *builder, OperationState *result,
476+
Value *valueToStore, Value *memref, AffineMap map,
477+
ArrayRef<Value *> mapOperands);
473478

474479
/// Get value to be stored by store operation.
475480
Value *getValueToStore() { return getOperand(0); }
@@ -486,7 +491,7 @@ class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
486491
}
487492

488493
/// Get affine map operands.
489-
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
494+
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
490495

491496
/// Returns the affine map used to index the memref for this operation.
492497
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
@@ -521,6 +526,9 @@ bool isValidSymbol(Value *value);
521526
/// Modifies both `map` and `operands` in-place so as to:
522527
/// 1. drop duplicate operands
523528
/// 2. drop unused dims and symbols from map
529+
/// 3. promote valid symbols to symbolic operands in case they appeared as
530+
/// dimensional operands
531+
/// 4. propagate constant operands and drop them
524532
void canonicalizeMapAndOperands(AffineMap *map,
525533
llvm::SmallVectorImpl<Value *> *operands);
526534
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does

mlir/lib/Analysis/LoopAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
236236

237237
int uniqueVaryingIndexAlongIv = -1;
238238
auto accessMap = memoryOp.getAffineMap();
239-
SmallVector<Value *, 4> mapOperands(memoryOp.getIndices());
239+
SmallVector<Value *, 4> mapOperands(memoryOp.getMapOperands());
240240
unsigned numDims = accessMap.getNumDims();
241241
for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
242242
// Gather map operands used result expr 'i' in 'exprOperands'.

mlir/lib/Analysis/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
847847
opInst = loadOrStoreOpInst;
848848
auto loadMemrefType = loadOp.getMemRefType();
849849
indices.reserve(loadMemrefType.getRank());
850-
for (auto *index : loadOp.getIndices()) {
850+
for (auto *index : loadOp.getMapOperands()) {
851851
indices.push_back(index);
852852
}
853853
} else {
@@ -857,7 +857,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
857857
memref = storeOp.getMemRef();
858858
auto storeMemrefType = storeOp.getMemRefType();
859859
indices.reserve(storeMemrefType.getRank());
860-
for (auto *index : storeOp.getIndices()) {
860+
for (auto *index : storeOp.getMapOperands()) {
861861
indices.push_back(index);
862862
}
863863
}

mlir/lib/Dialect/AffineOps/AffineOps.cpp

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -698,30 +698,63 @@ void mlir::canonicalizeSetAndOperands(
698698
}
699699

700700
namespace {
701-
/// Simplify AffineApply operations.
701+
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
702+
/// maps that supply results into them.
702703
///
703-
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
704-
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
704+
template <typename AffineOpTy>
705+
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
706+
using OpRewritePattern<AffineOpTy>::OpRewritePattern;
705707

706-
PatternMatchResult matchAndRewrite(AffineApplyOp apply,
707-
PatternRewriter &rewriter) const override {
708-
auto map = apply.getAffineMap();
708+
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
709+
AffineMap map, ArrayRef<Value *> mapOperands) const;
709710

711+
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
712+
PatternRewriter &rewriter) const override {
713+
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
714+
std::is_same<AffineOpTy, AffineStoreOp>::value ||
715+
std::is_same<AffineOpTy, AffineApplyOp>::value,
716+
"affine load/store/apply op expected");
717+
auto map = affineOp.getAffineMap();
710718
AffineMap oldMap = map;
711-
SmallVector<Value *, 8> resultOperands(apply.getOperands());
719+
auto oldOperands = affineOp.getMapOperands();
720+
SmallVector<Value *, 8> resultOperands(oldOperands);
712721
composeAffineMapAndOperands(&map, &resultOperands);
713-
if (map == oldMap)
714-
return matchFailure();
722+
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
723+
resultOperands.begin()))
724+
return this->matchFailure();
715725

716-
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
717-
return matchSuccess();
726+
replaceAffineOp(rewriter, affineOp, map, resultOperands);
727+
return this->matchSuccess();
718728
}
719729
};
730+
731+
// Specialize the template to account for the different build signatures for
732+
// affine load, store, and apply ops.
733+
template <>
734+
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
735+
PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
736+
ArrayRef<Value *> mapOperands) const {
737+
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
738+
mapOperands);
739+
}
740+
template <>
741+
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
742+
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
743+
ArrayRef<Value *> mapOperands) const {
744+
rewriter.replaceOpWithNewOp<AffineStoreOp>(
745+
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
746+
}
747+
template <>
748+
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
749+
PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
750+
ArrayRef<Value *> mapOperands) const {
751+
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
752+
}
720753
} // end anonymous namespace.
721754

722755
void AffineApplyOp::getCanonicalizationPatterns(
723756
OwningRewritePatternList &results, MLIRContext *context) {
724-
results.insert<SimplifyAffineApply>(context);
757+
results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
725758
}
726759

727760
//===----------------------------------------------------------------------===//
@@ -1689,6 +1722,7 @@ void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
16891722

16901723
void AffineLoadOp::build(Builder *builder, OperationState *result,
16911724
AffineMap map, ArrayRef<Value *> operands) {
1725+
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
16921726
result->addOperands(operands);
16931727
if (map)
16941728
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
@@ -1697,17 +1731,25 @@ void AffineLoadOp::build(Builder *builder, OperationState *result,
16971731
}
16981732

16991733
void AffineLoadOp::build(Builder *builder, OperationState *result,
1700-
Value *memref, ArrayRef<Value *> indices) {
1734+
Value *memref, AffineMap map,
1735+
ArrayRef<Value *> mapOperands) {
1736+
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
17011737
result->addOperands(memref);
1702-
result->addOperands(indices);
1738+
result->addOperands(mapOperands);
1739+
auto memrefType = memref->getType().cast<MemRefType>();
1740+
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
1741+
result->types.push_back(memrefType.getElementType());
1742+
}
1743+
1744+
void AffineLoadOp::build(Builder *builder, OperationState *result,
1745+
Value *memref, ArrayRef<Value *> indices) {
17031746
auto memrefType = memref->getType().cast<MemRefType>();
17041747
auto rank = memrefType.getRank();
17051748
// Create identity map for memrefs with at least one dimension or () -> ()
17061749
// for zero-dimensional memrefs.
17071750
auto map = rank ? builder->getMultiDimIdentityMap(rank)
17081751
: builder->getEmptyAffineMap();
1709-
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
1710-
result->types.push_back(memrefType.getElementType());
1752+
build(builder, result, memref, map, indices);
17111753
}
17121754

17131755
ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -1733,7 +1775,7 @@ void AffineLoadOp::print(OpAsmPrinter *p) {
17331775
*p << "affine.load " << *getMemRef() << '[';
17341776
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
17351777
if (mapAttr) {
1736-
SmallVector<Value *, 2> operands(getIndices());
1778+
SmallVector<Value *, 2> operands(getMapOperands());
17371779
p->printAffineMapOfSSAIds(mapAttr, operands);
17381780
}
17391781
*p << ']';
@@ -1759,7 +1801,7 @@ LogicalResult AffineLoadOp::verify() {
17591801
"expects the number of subscripts to be equal to memref rank");
17601802
}
17611803

1762-
for (auto *idx : getIndices()) {
1804+
for (auto *idx : getMapOperands()) {
17631805
if (!idx->getType().isIndex())
17641806
return emitOpError("index to load must have 'index' type");
17651807
if (!isValidAffineIndexOperand(idx))
@@ -1772,34 +1814,34 @@ void AffineLoadOp::getCanonicalizationPatterns(
17721814
OwningRewritePatternList &results, MLIRContext *context) {
17731815
/// load(memrefcast) -> load
17741816
results.insert<MemRefCastFolder>(getOperationName(), context);
1817+
results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
17751818
}
17761819

17771820
//===----------------------------------------------------------------------===//
17781821
// AffineStoreOp
17791822
//===----------------------------------------------------------------------===//
17801823

17811824
void AffineStoreOp::build(Builder *builder, OperationState *result,
1782-
Value *valueToStore, AffineMap map,
1783-
ArrayRef<Value *> operands) {
1825+
Value *valueToStore, Value *memref, AffineMap map,
1826+
ArrayRef<Value *> mapOperands) {
1827+
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
17841828
result->addOperands(valueToStore);
1785-
result->addOperands(operands);
1786-
if (map)
1787-
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
1829+
result->addOperands(memref);
1830+
result->addOperands(mapOperands);
1831+
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
17881832
}
17891833

1834+
// Use identity map.
17901835
void AffineStoreOp::build(Builder *builder, OperationState *result,
17911836
Value *valueToStore, Value *memref,
1792-
ArrayRef<Value *> operands) {
1793-
result->addOperands(valueToStore);
1794-
result->addOperands(memref);
1795-
result->addOperands(operands);
1837+
ArrayRef<Value *> indices) {
17961838
auto memrefType = memref->getType().cast<MemRefType>();
17971839
auto rank = memrefType.getRank();
17981840
// Create identity map for memrefs with at least one dimension or () -> ()
17991841
// for zero-dimensional memrefs.
18001842
auto map = rank ? builder->getMultiDimIdentityMap(rank)
18011843
: builder->getEmptyAffineMap();
1802-
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
1844+
build(builder, result, valueToStore, memref, map, indices);
18031845
}
18041846

18051847
ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -1828,7 +1870,7 @@ void AffineStoreOp::print(OpAsmPrinter *p) {
18281870
*p << ", " << *getMemRef() << '[';
18291871
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
18301872
if (mapAttr) {
1831-
SmallVector<Value *, 2> operands(getIndices());
1873+
SmallVector<Value *, 2> operands(getMapOperands());
18321874
p->printAffineMapOfSSAIds(mapAttr, operands);
18331875
}
18341876
*p << ']';
@@ -1855,7 +1897,7 @@ LogicalResult AffineStoreOp::verify() {
18551897
"expects the number of subscripts to be equal to memref rank");
18561898
}
18571899

1858-
for (auto *idx : getIndices()) {
1900+
for (auto *idx : getMapOperands()) {
18591901
if (!idx->getType().isIndex())
18601902
return emitOpError("index to store must have 'index' type");
18611903
if (!isValidAffineIndexOperand(idx))
@@ -1868,6 +1910,7 @@ void AffineStoreOp::getCanonicalizationPatterns(
18681910
OwningRewritePatternList &results, MLIRContext *context) {
18691911
/// load(memrefcast) -> load
18701912
results.insert<MemRefCastFolder>(getOperationName(), context);
1913+
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
18711914
}
18721915

18731916
#define GET_OP_CLASSES

mlir/lib/Transforms/LowerAffine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
403403
virtual PatternMatchResult
404404
matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override {
405405
// Expand affine map from 'affineLoadOp'.
406-
SmallVector<Value *, 8> indices(op.getIndices());
406+
SmallVector<Value *, 8> indices(op.getMapOperands());
407407
auto maybeExpandedMap =
408408
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
409409
if (!maybeExpandedMap)
@@ -425,7 +425,7 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
425425
virtual PatternMatchResult
426426
matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override {
427427
// Expand affine map from 'affineStoreOp'.
428-
SmallVector<Value *, 8> indices(op.getIndices());
428+
SmallVector<Value *, 8> indices(op.getMapOperands());
429429
auto maybeExpandedMap =
430430
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
431431
if (!maybeExpandedMap)

mlir/lib/Transforms/Vectorize.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -814,14 +814,15 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
814814
// as needed by various targets.
815815
if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
816816
OpBuilder b(opInst);
817-
SmallVector<Value *, 4> mapOperands(load.getIndices());
817+
SmallVector<Value *, 4> mapOperands(load.getMapOperands());
818818
SmallVector<Value *, 8> indices;
819819
indices.reserve(load.getMemRefType().getRank());
820820
if (load.getAffineMap() !=
821821
b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
822822
computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
823823
} else {
824-
indices.append(load.getIndices().begin(), load.getIndices().end());
824+
indices.append(load.getMapOperands().begin(),
825+
load.getMapOperands().end());
825826
}
826827
auto permutationMap =
827828
makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
@@ -1038,15 +1039,16 @@ static Operation *vectorizeOneOperation(Operation *opInst,
10381039
auto *value = store.getValueToStore();
10391040
auto *vectorValue = vectorizeOperand(value, opInst, state);
10401041

1041-
SmallVector<Value *, 4> mapOperands(store.getIndices());
1042+
SmallVector<Value *, 4> mapOperands(store.getMapOperands());
10421043
SmallVector<Value *, 8> indices;
10431044
indices.reserve(store.getMemRefType().getRank());
10441045
if (store.getAffineMap() !=
10451046
b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
10461047
computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
10471048
indices);
10481049
} else {
1049-
indices.append(store.getIndices().begin(), store.getIndices().end());
1050+
indices.append(store.getMapOperands().begin(),
1051+
store.getMapOperands().end());
10501052
}
10511053

10521054
auto permutationMap =

mlir/test/AffineOps/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ func @fold_empty_loop() {
424424
}
425425
return
426426
}
427+
// CHECK: return
427428

428429
// -----
429430

@@ -476,3 +477,29 @@ func @canonicalize_bounds(%M : index, %N : index) {
476477
}
477478
return
478479
}
480+
481+
// -----
482+
483+
// Compose maps into affine load and store ops.
484+
485+
// CHECK-DAG: #map{{[0-9]+}} = (d0) -> (d0 + 1)
486+
487+
// CHECK-LABEL: @compose_into_affine_load_store
488+
func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
489+
%cf1 = constant 1.0 : f32
490+
// CHECK: affine.for %[[IV:.*]] = 0 to 1024
491+
affine.for %i = 0 to 1024 {
492+
// Make sure the unused operand (%u below) gets dropped as well.
493+
%idx = affine.apply (d0, d1) -> (d0 + 1) (%i, %u)
494+
affine.load %A[%idx] : memref<1024xf32>
495+
affine.store %cf1, %A[%idx] : memref<1024xf32>
496+
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]] + 1]
497+
// CHECK-NEXT: affine.store %cst, %{{.*}}[%[[IV]] + 1]
498+
499+
// Map remains the same, but operand changes on composition.
500+
%copy = affine.apply (d0) -> (d0) (%i)
501+
affine.load %A[%copy] : memref<1024xf32>
502+
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]]]
503+
}
504+
return
505+
}

0 commit comments

Comments
 (0)