From 4b1e0f8d4060ea3b6a7f6ffb5ee6bab2778dc40c Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 9 Oct 2019 10:36:54 -0700 Subject: [PATCH] Add support for some multi-store cases in affine fusion This PR is a stepping stone towards supporting generic multi-store source loop nests in affine loop fusion. It extends the algorithm to support fusion of multi-store loop nests that: 1. have only one store that writes to a function-local live out, and 2. have remaining stores that are involved either in loop nest self-dependences or in no dependences within the function. Closes #162 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/162 from dcaballe:dcaballe/multi-output-fusion 7fb7dec6fe8b45f5ce176f018bfe37b256420c45 PiperOrigin-RevId: 273773907 --- .../mlir/lib/Transforms/LoopFusion.cpp | 100 ++++++++++++------ 1 file changed, 66 insertions(+), 34 deletions(-) diff --git a/third_party/mlir/lib/Transforms/LoopFusion.cpp b/third_party/mlir/lib/Transforms/LoopFusion.cpp index 188165b94e1d08..15dc36c9c136db 100644 --- a/third_party/mlir/lib/Transforms/LoopFusion.cpp +++ b/third_party/mlir/lib/Transforms/LoopFusion.cpp @@ -322,6 +322,44 @@ struct MemRefDependenceGraph { return false; } + // Returns the unique AffineStoreOp in `node` that meets all the following: + // *) store is the only one that writes to a function-local memref live out + // of `node`, + // *) store is not the source of a self-dependence on `node`. + // Otherwise, returns a null AffineStoreOp. + AffineStoreOp getUniqueOutgoingStore(Node *node) { + AffineStoreOp uniqueStore; + + // Return null if `node` doesn't have any outgoing edges. + auto outEdgeIt = outEdges.find(node->id); + if (outEdgeIt == outEdges.end()) + return nullptr; + + const auto &nodeOutEdges = outEdgeIt->second; + for (auto *op : node->stores) { + auto storeOp = cast(op); + auto *memref = storeOp.getMemRef(); + // Skip this store if there are no dependences on its memref.
This means + // that store either: + // *) writes to a memref that is only read within the same loop nest + // (self-dependence edges are not represented in graph at the moment), + // *) writes to a function live out memref (function parameter), or + // *) is dead. + if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { + return (edge.value != memref); + })) + continue; + + if (uniqueStore) + // Found multiple stores to function-local live-out memrefs. + return nullptr; + // Found first store to function-local live-out memref. + uniqueStore = storeOp; + } + + return uniqueStore; + } + // Returns true if node 'id' can be removed from the graph. Returns false // otherwise. A node can be removed from the graph iff the following // conditions are met: @@ -963,42 +1001,30 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, return newMemRef; } -// Checks if node 'srcId' (which writes to a live out memref), can be safely -// fused into node 'dstId'. Returns true if the following conditions are met: -// *) 'srcNode' only writes to live out 'memref'. -// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId'). -// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's -// write region to 'memref'. +// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' +// may write to multiple memrefs but it is required that only one of them, +// 'srcLiveOutStoreOp', have an output edge. +// Returns true if 'dstNode's read/write region to 'memref' is a super set of +// 'srcNode's write region to 'memref'. // TODO(andydavis) Generalize this to handle more live in/out cases. 
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, - Value *memref, + AffineStoreOp srcLiveOutStoreOp, MemRefDependenceGraph *mdg) { - auto *srcNode = mdg->getNode(srcId); + assert(srcLiveOutStoreOp && "Expected a valid store op"); + assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge"); auto *dstNode = mdg->getNode(dstId); + Value *memref = srcLiveOutStoreOp.getMemRef(); - // Gather all memrefs from 'srcNode' store ops. - DenseSet storeMemrefs; - for (auto *storeOpInst : srcNode->stores) { - storeMemrefs.insert(cast(storeOpInst).getMemRef()); - } - // Return false if any of the following are true: - // *) 'srcNode' writes to a live in/out memref other than 'memref'. - // *) 'srcNode' has more than one output edge on 'memref'. - // Check that all stores are to the same memref. - if (storeMemrefs.size() != 1 || - mdg->getOutEdgeCount(srcNode->id, memref) != 1) - return false; - // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. - auto *srcStoreOpInst = srcNode->stores.front(); - MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. + MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); + if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for source operation\n."); return false; } SmallVector srcShape; // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. - // by 'srcStoreOpInst' at depth 'dstLoopDepth'. + // by 'srcStoreOp' at depth 'dstLoopDepth'. Optional srcNumElements = srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); if (!srcNumElements.hasValue()) @@ -1491,17 +1517,25 @@ struct GreedyFusion { // Skip if 'srcNode' is not a loop nest. if (!isa(srcNode->op)) continue; - // Skip if 'srcNode' has more than one store to any memref. 
- // TODO(andydavis) Support fusing multi-output src loop nests. - if (srcNode->stores.size() != 1) + // Skip if 'srcNode' has more than one live-out store to a + // function-local memref. + // TODO(andydavis) Support more generic multi-output src loop nests + // fusion. + auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); + if (!srcStoreOp) continue; + // Unique outgoing store found must write to 'memref' since 'memref' + // is the one that established the producer-consumer relationship + // between 'srcNode' and 'dstNode'. + assert(srcStoreOp.getMemRef() == memref && + "Found store to unexpected memref"); // Skip if 'srcNode' writes to any live in or escaping memrefs, // and cannot be fused. bool writesToLiveInOrOut = mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); if (writesToLiveInOrOut && - !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg)) + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) continue; // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. @@ -1515,8 +1549,6 @@ struct GreedyFusion { if (insertPointInst == nullptr) continue; - // Get unique 'srcNode' store op. - auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) @@ -1526,8 +1558,8 @@ struct GreedyFusion { unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst, - dstLoadOpInsts, dstStoreOpInsts, &sliceState, + if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, &bestDstLoopDepth, maximalFusion)) continue; // TODO(andydavis) Remove the following test code when canFuseLoops @@ -1542,7 +1574,7 @@ struct GreedyFusion { } // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 
auto sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); + srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest) { LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" << *sliceLoopNest.getOperation() << "\n");