@@ -698,30 +698,63 @@ void mlir::canonicalizeSetAndOperands(
698
698
}
699
699
700
700
namespace {
701
- // / Simplify AffineApply operations.
701
+ // / Simplify AffineApply, AffineLoad, and AffineStore operations by composing
702
+ // / maps that supply results into them.
702
703
// /
703
- struct SimplifyAffineApply : public OpRewritePattern <AffineApplyOp> {
704
- using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
704
+ template <typename AffineOpTy>
705
+ struct SimplifyAffineOp : public OpRewritePattern <AffineOpTy> {
706
+ using OpRewritePattern<AffineOpTy>::OpRewritePattern;
705
707
706
- PatternMatchResult matchAndRewrite (AffineApplyOp apply,
707
- PatternRewriter &rewriter) const override {
708
- auto map = apply.getAffineMap ();
708
+ void replaceAffineOp (PatternRewriter &rewriter, AffineOpTy affineOp,
709
+ AffineMap map, ArrayRef<Value *> mapOperands) const ;
709
710
711
+ PatternMatchResult matchAndRewrite (AffineOpTy affineOp,
712
+ PatternRewriter &rewriter) const override {
713
+ static_assert (std::is_same<AffineOpTy, AffineLoadOp>::value ||
714
+ std::is_same<AffineOpTy, AffineStoreOp>::value ||
715
+ std::is_same<AffineOpTy, AffineApplyOp>::value,
716
+ " affine load/store/apply op expected" );
717
+ auto map = affineOp.getAffineMap ();
710
718
AffineMap oldMap = map;
711
- SmallVector<Value *, 8 > resultOperands (apply.getOperands ());
719
+ auto oldOperands = affineOp.getMapOperands ();
720
+ SmallVector<Value *, 8 > resultOperands (oldOperands);
712
721
composeAffineMapAndOperands (&map, &resultOperands);
713
- if (map == oldMap)
714
- return matchFailure ();
722
+ if (map == oldMap && std::equal (oldOperands.begin (), oldOperands.end (),
723
+ resultOperands.begin ()))
724
+ return this ->matchFailure ();
715
725
716
- rewriter. replaceOpWithNewOp <AffineApplyOp>(apply , map, resultOperands);
717
- return matchSuccess ();
726
+ replaceAffineOp (rewriter, affineOp , map, resultOperands);
727
+ return this -> matchSuccess ();
718
728
}
719
729
};
730
+
731
+ // Specialize the template to account for the different build signatures for
732
+ // affine load, store, and apply ops.
733
+ template <>
734
+ void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
735
+ PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
736
+ ArrayRef<Value *> mapOperands) const {
737
+ rewriter.replaceOpWithNewOp <AffineLoadOp>(load, load.getMemRef (), map,
738
+ mapOperands);
739
+ }
740
+ template <>
741
+ void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
742
+ PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
743
+ ArrayRef<Value *> mapOperands) const {
744
+ rewriter.replaceOpWithNewOp <AffineStoreOp>(
745
+ store, store.getValueToStore (), store.getMemRef (), map, mapOperands);
746
+ }
747
+ template <>
748
+ void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
749
+ PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
750
+ ArrayRef<Value *> mapOperands) const {
751
+ rewriter.replaceOpWithNewOp <AffineApplyOp>(apply, map, mapOperands);
752
+ }
720
753
} // end anonymous namespace.
721
754
722
755
void AffineApplyOp::getCanonicalizationPatterns (
723
756
OwningRewritePatternList &results, MLIRContext *context) {
724
- results.insert <SimplifyAffineApply >(context);
757
+ results.insert <SimplifyAffineOp<AffineApplyOp> >(context);
725
758
}
726
759
727
760
// ===----------------------------------------------------------------------===//
@@ -1689,6 +1722,7 @@ void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1689
1722
1690
1723
void AffineLoadOp::build (Builder *builder, OperationState *result,
1691
1724
AffineMap map, ArrayRef<Value *> operands) {
1725
+ assert (operands.size () == 1 + map.getNumInputs () && " inconsistent operands" );
1692
1726
result->addOperands (operands);
1693
1727
if (map)
1694
1728
result->addAttribute (getMapAttrName (), builder->getAffineMapAttr (map));
@@ -1697,17 +1731,25 @@ void AffineLoadOp::build(Builder *builder, OperationState *result,
1697
1731
}
1698
1732
1699
1733
void AffineLoadOp::build (Builder *builder, OperationState *result,
1700
- Value *memref, ArrayRef<Value *> indices) {
1734
+ Value *memref, AffineMap map,
1735
+ ArrayRef<Value *> mapOperands) {
1736
+ assert (map.getNumInputs () == mapOperands.size () && " inconsistent index info" );
1701
1737
result->addOperands (memref);
1702
- result->addOperands (indices);
1738
+ result->addOperands (mapOperands);
1739
+ auto memrefType = memref->getType ().cast <MemRefType>();
1740
+ result->addAttribute (getMapAttrName (), builder->getAffineMapAttr (map));
1741
+ result->types .push_back (memrefType.getElementType ());
1742
+ }
1743
+
1744
+ void AffineLoadOp::build (Builder *builder, OperationState *result,
1745
+ Value *memref, ArrayRef<Value *> indices) {
1703
1746
auto memrefType = memref->getType ().cast <MemRefType>();
1704
1747
auto rank = memrefType.getRank ();
1705
1748
// Create identity map for memrefs with at least one dimension or () -> ()
1706
1749
// for zero-dimensional memrefs.
1707
1750
auto map = rank ? builder->getMultiDimIdentityMap (rank)
1708
1751
: builder->getEmptyAffineMap ();
1709
- result->addAttribute (getMapAttrName (), builder->getAffineMapAttr (map));
1710
- result->types .push_back (memrefType.getElementType ());
1752
+ build (builder, result, memref, map, indices);
1711
1753
}
1712
1754
1713
1755
ParseResult AffineLoadOp::parse (OpAsmParser *parser, OperationState *result) {
@@ -1733,7 +1775,7 @@ void AffineLoadOp::print(OpAsmPrinter *p) {
1733
1775
*p << " affine.load " << *getMemRef () << ' [' ;
1734
1776
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName ());
1735
1777
if (mapAttr) {
1736
- SmallVector<Value *, 2 > operands (getIndices ());
1778
+ SmallVector<Value *, 2 > operands (getMapOperands ());
1737
1779
p->printAffineMapOfSSAIds (mapAttr, operands);
1738
1780
}
1739
1781
*p << ' ]' ;
@@ -1759,7 +1801,7 @@ LogicalResult AffineLoadOp::verify() {
1759
1801
" expects the number of subscripts to be equal to memref rank" );
1760
1802
}
1761
1803
1762
- for (auto *idx : getIndices ()) {
1804
+ for (auto *idx : getMapOperands ()) {
1763
1805
if (!idx->getType ().isIndex ())
1764
1806
return emitOpError (" index to load must have 'index' type" );
1765
1807
if (!isValidAffineIndexOperand (idx))
@@ -1772,34 +1814,34 @@ void AffineLoadOp::getCanonicalizationPatterns(
1772
1814
OwningRewritePatternList &results, MLIRContext *context) {
1773
1815
// / load(memrefcast) -> load
1774
1816
results.insert <MemRefCastFolder>(getOperationName (), context);
1817
+ results.insert <SimplifyAffineOp<AffineLoadOp>>(context);
1775
1818
}
1776
1819
1777
1820
// ===----------------------------------------------------------------------===//
1778
1821
// AffineStoreOp
1779
1822
// ===----------------------------------------------------------------------===//
1780
1823
1781
1824
void AffineStoreOp::build (Builder *builder, OperationState *result,
1782
- Value *valueToStore, AffineMap map,
1783
- ArrayRef<Value *> operands) {
1825
+ Value *valueToStore, Value *memref, AffineMap map,
1826
+ ArrayRef<Value *> mapOperands) {
1827
+ assert (map.getNumInputs () == mapOperands.size () && " inconsistent index info" );
1784
1828
result->addOperands (valueToStore);
1785
- result->addOperands (operands );
1786
- if (map)
1787
- result->addAttribute (getMapAttrName (), builder->getAffineMapAttr (map));
1829
+ result->addOperands (memref );
1830
+ result-> addOperands (mapOperands);
1831
+ result->addAttribute (getMapAttrName (), builder->getAffineMapAttr (map));
1788
1832
}
1789
1833
1834
+ // Use identity map.
1790
1835
void AffineStoreOp::build (Builder *builder, OperationState *result,
1791
1836
Value *valueToStore, Value *memref,
1792
- ArrayRef<Value *> operands) {
1793
- result->addOperands (valueToStore);
1794
- result->addOperands (memref);
1795
- result->addOperands (operands);
1837
+ ArrayRef<Value *> indices) {
1796
1838
auto memrefType = memref->getType ().cast <MemRefType>();
1797
1839
auto rank = memrefType.getRank ();
1798
1840
// Create identity map for memrefs with at least one dimension or () -> ()
1799
1841
// for zero-dimensional memrefs.
1800
1842
auto map = rank ? builder->getMultiDimIdentityMap (rank)
1801
1843
: builder->getEmptyAffineMap ();
1802
- result-> addAttribute ( getMapAttrName (), builder-> getAffineMapAttr ( map) );
1844
+ build (builder, result, valueToStore, memref, map, indices );
1803
1845
}
1804
1846
1805
1847
ParseResult AffineStoreOp::parse (OpAsmParser *parser, OperationState *result) {
@@ -1828,7 +1870,7 @@ void AffineStoreOp::print(OpAsmPrinter *p) {
1828
1870
*p << " , " << *getMemRef () << ' [' ;
1829
1871
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName ());
1830
1872
if (mapAttr) {
1831
- SmallVector<Value *, 2 > operands (getIndices ());
1873
+ SmallVector<Value *, 2 > operands (getMapOperands ());
1832
1874
p->printAffineMapOfSSAIds (mapAttr, operands);
1833
1875
}
1834
1876
*p << ' ]' ;
@@ -1855,7 +1897,7 @@ LogicalResult AffineStoreOp::verify() {
1855
1897
" expects the number of subscripts to be equal to memref rank" );
1856
1898
}
1857
1899
1858
- for (auto *idx : getIndices ()) {
1900
+ for (auto *idx : getMapOperands ()) {
1859
1901
if (!idx->getType ().isIndex ())
1860
1902
return emitOpError (" index to store must have 'index' type" );
1861
1903
if (!isValidAffineIndexOperand (idx))
@@ -1868,6 +1910,7 @@ void AffineStoreOp::getCanonicalizationPatterns(
1868
1910
OwningRewritePatternList &results, MLIRContext *context) {
1869
1911
// / load(memrefcast) -> load
1870
1912
results.insert <MemRefCastFolder>(getOperationName (), context);
1913
+ results.insert <SimplifyAffineOp<AffineStoreOp>>(context);
1871
1914
}
1872
1915
1873
1916
#define GET_OP_CLASSES
0 commit comments