From edbc0e30a9e587cee1189be023b9385adc2f239a Mon Sep 17 00:00:00 2001 From: srcarroll <50210727+srcarroll@users.noreply.github.com> Date: Wed, 3 Jul 2024 14:03:54 -0500 Subject: [PATCH] [mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607) The refactor had a bug where the fused loop was inserted in an incorrect location. This patch fixes the bug and relands the original PR https://github.com/llvm/llvm-project/pull/94391. This patch refactors code related to the LoopFuseSiblingOp transform in an attempt to reduce duplicated common code. The aim is to refactor as much as possible into functions on LoopLikeOpInterfaces, but this is still a work in progress. A full refactor will require more additions to the LoopLikeOpInterface. In addition, scf.parallel fusion support has been added. --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 +- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 20 ++ .../mlir/Interfaces/LoopLikeInterface.h | 20 ++ mlir/lib/Dialect/SCF/IR/SCF.cpp | 38 +++ .../SCF/TransformOps/SCFTransformOps.cpp | 140 ++------- .../SCF/Transforms/ParallelLoopFusion.cpp | 80 +---- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 279 +++++++++++------ mlir/lib/Interfaces/LoopLikeInterface.cpp | 59 ++++ .../SCF/transform-loop-fuse-sibling.mlir | 290 +++++++++++++++++- 9 files changed, 646 insertions(+), 283 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index f35ea962bea16..bf95fbe6721cf 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [ DeclareOpInterfaceMethods, + "replaceWithAdditionalYields", "promoteIfSingleIteration", + "yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 
de807c3e4e1f8..6a40304e2eeba 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -181,6 +181,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef sizes); void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, scf::ForOp root); +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +/// Check structural compatibility between two loops such as iteration space +/// and dominance. +bool checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source, + Diagnostic &diag); + /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of /// each other. @@ -202,6 +212,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter); +/// Given two scf.parallel loops, `target` and `source`, fuses `target` into +/// `source`. Assumes that the given loops are siblings and are independent of +/// each other. +/// +/// This function does not perform any legality checks and simply fuses the +/// loops. The caller is responsible for ensuring that the loops are legal to +/// fuse. +scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target, + scf::ParallelOp source, + RewriterBase &rewriter); } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index 9925fc6ce6ca9..d08e097a9b4af 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -90,4 +90,24 @@ struct JamBlockGatherer { /// Include the generated interface declarations. 
#include "mlir/Interfaces/LoopLikeInterface.h.inc" +namespace mlir { +/// A function that rewrites `target`'s terminator as a terminator obtained by +/// fusing `source` into `target`. +using FuseTerminatorFn = + function_ref; + +/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to +/// `target`. The `NewYieldValuesFn` callback is used to pass to the +/// `replaceWithAdditionalYields` interface method to replace the loop with a +/// new loop with (possibly) additional yields, while the `FuseTerminatorFn` +/// callback is responsible for updating the fused loop terminator. +LoopLikeOpInterface createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn); + +} // namespace mlir + #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 907d7f794593d..cb15e0ecebf05 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } +FailureOr ForallOp::replaceWithAdditionalYields( + RewriterBase &rewriter, ValueRange newInitOperands, + bool replaceInitOperandUsesInLoop, + const NewYieldValuesFn &newYieldValuesFn) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(getOperation()); + SmallVector inits(getOutputs()); + llvm::append_range(inits, newInitOperands); + scf::ForallOp newLoop = rewriter.create( + getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), + inits, getMapping(), + /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); + + // Move the loop body to the new op. 
+ rewriter.mergeBlocks(getBody(), newLoop.getBody(), + newLoop.getBody()->getArguments().take_front( + getBody()->getNumArguments())); + + if (replaceInitOperandUsesInLoop) { + // Replace all uses of `newInitOperands` with the corresponding basic block + // arguments. + for (auto &&[newOperand, oldOperand] : + llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( + newInitOperands.size()))) { + rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + } + + // Replace the old loop. + rewriter.replaceOp(getOperation(), + newLoop->getResults().take_front(getNumResults())); + return cast(newLoop.getOperation()); +} + /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 56ff2709a589e..41834fea3bb84 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); - std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); + std::optional ubConstant = + getConstantIntValue(forOp.getUpperBound()); + std::optional lbConstant = + getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects( // LoopFuseSiblingOp //===----------------------------------------------------------------------===// -/// Check if `target` and `source` are siblings, in the context that `target` -/// is being fused into `source`. 
-/// -/// This is a simple check that just checks if both operations are in the same -/// block and some checks to ensure that the fused IR does not violate -/// dominance. -static DiagnosedSilenceableFailure isOpSibling(Operation *target, - Operation *source) { - // Check if both operations are same. - if (target == source) - return emitSilenceableFailure(source) - << "target and source need to be different loops"; - - // Check if both operations are in the same block. - if (target->getBlock() != source->getBlock()) - return emitSilenceableFailure(source) - << "target and source are not in the same block"; - - // Check if fusion will violate dominance. - DominanceInfo domInfo(source); - if (target->isBeforeInBlock(source)) { - // Since `target` is before `source`, all users of results of `target` - // need to be dominated by `source`. - for (Operation *user : target->getUsers()) { - if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { - return emitSilenceableFailure(target) - << "user of results of target should be properly dominated by " - "source"; - } - } - } else { - // Since `target` is after `source`, all values used by `target` need - // to dominate `source`. - - // Check if operands of `target` are dominated by `source`. - for (Value operand : target->getOperands()) { - Operation *operandOp = operand.getDefiningOp(); - // Operands without defining operations are block arguments. When `target` - // and `source` occur in the same block, these operands dominate `source`. - if (!operandOp) - continue; - - // Operand's defining operation should properly dominate `source`. - if (!domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) - return emitSilenceableFailure(target) - << "operands of target should be properly dominated by source"; - } - - // Check if values used by `target` are dominated by `source`. 
- bool failed = false; - OpOperand *failedValue = nullptr; - visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - Operation *operandOp = operand->get().getDefiningOp(); - if (operandOp && !domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - // `operand` is not an argument of an enclosing block and the defining - // op of `operand` is outside `target` but does not dominate `source`. - failed = true; - failedValue = operand; - } - }); - - if (failed) - return emitSilenceableFailure(failedValue->getOwner()) - << "values used inside regions of target should be properly " - "dominated by source"; - } - - return DiagnosedSilenceableFailure::success(); -} - -/// Check if `target` scf.forall can be fused into `source` scf.forall. -/// -/// This simply checks if both loops have the same bounds, steps and mapping. -/// No attempt is made at checking that the side effects of `target` and -/// `source` are independent of each other. -static bool isForallWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); -} - -/// Check if `target` scf.for can be fused into `source` scf.for. -/// -/// This simply checks if both loops have the same bounds and steps. No attempt -/// is made at checking that the side effects of `target` and `source` are -/// independent of each other. 
-static bool isForWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); -} - DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - Operation *target = *targetOps.begin(); - Operation *source = *sourceOps.begin(); + auto target = dyn_cast(*targetOps.begin()); + auto source = dyn_cast(*sourceOps.begin()); + if (!target || !source) + return emitSilenceableFailure((*targetOps.begin())->getLoc()) + << "target or source is not a loop op"; - // Check if the target and source are siblings. - DiagnosedSilenceableFailure diag = isOpSibling(target, source); - if (!diag.succeeded()) - return diag; + // Check if loops can be fused + Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error); + if (!mlir::checkFusionStructuralLegality(target, source, diag)) + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); Operation *fusedLoop; - /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. - if (isForWithIdenticalConfiguration(target, source)) { + // TODO: Support fusion for loop-like ops besides scf.for, scf.forall + // and scf.parallel. 
+ if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); - } else if (isForallWithIdenticalConfiguration(target, source)) { + } else if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast(target), cast(source), rewriter); + } else if (isa(target) && isa(source)) { + fusedLoop = fuseIndependentSiblingParallelLoops( + cast(target), cast(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) - << "operations cannot be fused"; + << "unsupported loop type for fusion"; assert(fusedLoop && "failed to fuse operations"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 5934d85373b03..b775f988576e3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" @@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) { return walkResult.wasInterrupted(); } -/// Verify equal iteration spaces. -static bool equalIterationSpaces(ParallelOp firstPloop, - ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. 
- return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. @@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { + Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && + checkFusionStructuralLegality(firstPloop, secondPloop, diag) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } @@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, mayAlias)) return; - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. 
- for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - - ValueRange inits1 = firstPloop.getInitVals(); - ValueRange inits2 = secondPloop.getInitVals(); - - SmallVector newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - IRRewriter b(builder); - b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create( - secondPloop.getLoc(), secondPloop.getLowerBound(), - secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); - - Block *newBlock = newSecondPloop.getBody(); - auto term1 = cast(block1->getTerminator()); - auto term2 = cast(block2->getTerminator()); - - b.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - b.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = newSecondPloop.getResults(); - if (!results.empty()) { - b.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), - newRedBlock.getArguments()); - } - - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); - } - term1->erase(); - term2->erase(); - firstPloop.erase(); - secondPloop.erase(); - secondPloop = newSecondPloop; + IRRewriter rewriter(builder); + secondPloop = mlir::fuseIndependentSiblingParallelLoops( + firstPloop, secondPloop, rewriter); } void 
mlir::scf::naivelyFuseParallelOps( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index c0ee9d2afe91c..abfc9a1b4d444 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. +/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static bool isOpSibling(Operation *target, Operation *source, + Diagnostic &diag) { + // Check if both operations are same. + if (target == source) { + diag << "target and source need to be different loops"; + return false; + } + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) { + diag << "target and source are not in the same block"; + return false; + } + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. 
+ for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + diag << "user of results of target should " + "be properly dominated by source"; + return false; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` are dominated by `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + diag << "operands of target should be properly dominated by source"; + return false; + } + } + + // Check if values used by `target` are dominated by `source`. + bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. 
+ failed = true; + failedValue = operand; + } + }); + + if (failed) { + diag << "values used inside regions of target should be properly " + "dominated by source"; + diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation"; + return false; + } + } + + return true; +} + +bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source, + Diagnostic &diag) { + if (target->getName() != source->getName()) { + diag << "target and source must be same loop type"; + return false; + } + + bool iterSpaceEq = + target.getLoopLowerBounds() == source.getLoopLowerBounds() && + target.getLoopUpperBounds() == source.getLoopUpperBounds() && + target.getLoopSteps() == source.getLoopSteps(); + // TODO: Decouple checks on concrete loop types and move this function + // somewhere for general utility for `LoopLikeOpInterface` + if (auto forAllTarget = dyn_cast(*target)) + iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() == + cast(*source).getMapping(); + if (!iterSpaceEq) { + diag << "target and source iteration spaces must be equal"; + return false; + } + return isOpSibling(target, source, diag); +} + scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused shared_outs. - SmallVector fusedOuts; - llvm::append_range(fusedOuts, target.getOutputs()); - llvm::append_range(fusedOuts, source.getOutputs()); - - // Create a new scf.forall op after the source loop. - rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create( - source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), - source.getMixedStep(), fusedOuts, source.getMapping()); - - // Map control operands. 
- IRMapping mapping; - mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); - - // Map shared outs. - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Append everything except the terminator into the fused operation. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp targetTerm = target.getTerminator(); - scf::InParallelOp sourceTerm = source.getTerminator(); - scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); - for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, mapping); - for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, mapping); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForallOp fusedLoop = cast(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { + // `ForallOp` does not have yields, rather an `InParallelOp` terminator. 
+ return ValueRange{}; + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto sourceForall = cast(source); + auto targetForall = cast(target); + scf::InParallelOp fusedTerm = targetForall.getTerminator(); + b.setInsertionPointToEnd(fusedTerm.getBody()); + for (Operation &op : sourceForall.getTerminator().getYieldingOps()) + b.clone(op, mapping); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } @@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused init_args, with target's init_args before source's init_args. - SmallVector fusedInitArgs; - llvm::append_range(fusedInitArgs, target.getInitArgs()); - llvm::append_range(fusedInitArgs, source.getInitArgs()); - - // Create a new scf.for op after the source loop (with scf.yield terminator - // (without arguments) only in case its init_args is empty). - rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); - - // Map original induction variables and operands to those of the fused loop. - IRMapping mapping; - mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Merge target's body into the new (fused) for loop and then source's body. 
- rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - - // Build fused yield results by appropriately mapping original yield operands. - SmallVector yieldResults; - for (Value operand : target.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - for (Value operand : source.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - if (!yieldResults.empty()) - rewriter.create(source.getLoc(), yieldResults); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + scf::ForOp fusedLoop = cast(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { + return source.getYieldedValues(); + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto targetFor = cast(target); + auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); + b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); + })); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); + return fusedLoop; +} + +// TODO: Finish refactoring this a la the above, but likely requires additional +// interface methods. 
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( + scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + Block *block1 = target.getBody(); + Block *block2 = source.getBody(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); + + ValueRange inits1 = target.getInitVals(); + ValueRange inits2 = source.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + rewriter.setInsertionPoint(source); + auto fusedLoop = rewriter.create( + rewriter.getFusedLoc(target.getLoc(), source.getLoc()), + source.getLowerBound(), source.getUpperBound(), source.getStep(), + newInitVars); + Block *newBlock = fusedLoop.getBody(); + rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = fusedLoop.getResults(); + if (!results.empty()) { + rewriter.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = rewriter.create( + rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock, + newRedBlock.begin(), + newRedBlock.getArguments()); + } + } + rewriter.replaceOp(target, results.take_front(inits1.size())); + rewriter.replaceOp(source, results.take_back(inits2.size())); + rewriter.eraseOp(term1); + rewriter.eraseOp(term2); return fusedLoop; } diff --git 
a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 1e0e87b64e811..5a119a7cf2659 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -8,6 +8,8 @@ #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -113,3 +115,60 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { return success(); } + +LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn) { + auto targetIterArgs = target.getRegionIterArgs(); + std::optional> targetInductionVar = + target.getLoopInductionVars(); + SmallVector targetYieldOperands(target.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + std::optional> sourceInductionVar = + source.getLoopInductionVars(); + SmallVector sourceYieldOperands(source.getYieldedValues()); + auto sourceRegion = source.getLoopRegions().front(); + + FailureOr maybeFusedLoop = + target.replaceWithAdditionalYields(rewriter, source.getInits(), + /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + if (failed(maybeFusedLoop)) + llvm_unreachable("failed to replace loop"); + LoopLikeOpInterface fusedLoop = *maybeFusedLoop; + // Since the target op is rewritten at the original's location, we move it to + // the source op's location. + rewriter.moveOpBefore(fusedLoop, source); + + // Map control operands. 
+ IRMapping mapping; + std::optional> fusedInductionVar = + fusedLoop.getLoopInductionVars(); + if (fusedInductionVar) { + if (!targetInductionVar || !sourceInductionVar) + llvm_unreachable( + "expected target and source loops to have induction vars"); + mapping.map(*targetInductionVar, *fusedInductionVar); + mapping.map(*sourceInductionVar, *fusedInductionVar); + } + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); + // Append everything except the terminator into the fused operation. + rewriter.setInsertionPoint( + fusedLoop.getLoopRegions().front()->front().getTerminator()); + for (Operation &op : sourceRegion->front().without_terminator()) + rewriter.clone(op, mapping); + + // TODO: Replace with corresponding interface method if added + fuseTerminatorFn(rewriter, source, fusedLoop, mapping); + + return fusedLoop; +} diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 54dd2bdf953ca..f8246b74a5744 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -47,6 +47,169 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @fuse_two_parallel +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. 
+ %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + 
transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @fuse_two_parallel_reverse +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + 
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @fuse_reductions_two +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) +func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) +// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) +// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, 
%c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + return %res1, %res2 : f32, f32 +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index @@ -208,6 +371,62 @@ module attributes {transform.with_named_sequence} { } } + +// ----- + +// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32) +#map = affine_map<(d0) -> (d0 * 32)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> +module { + // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}} + func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) { + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16> + // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) 
shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) { + // CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]]) + // CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32> + // CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> + // CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16> + // CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}}) + // CHECK: scf.forall.in_parallel { + // CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32> + // CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16> + // CHECK-NEXT: } + // CHECK-NEXT: } {mapping = [#gpu.warp]} + // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1 + %0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) { + %3 = affine.apply #map(%arg4) + %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32> + } + } {mapping = [#gpu.warp]} + %1 = tensor.empty() : tensor<128x128xf16> + %2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) { + %3 = affine.apply #map(%arg4) + %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> + %extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16> + %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", 
"parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) { + ^bb0(%in: f32, %out: f16): + %5 = arith.truncf %in : f32 to f16 + linalg.yield %5 : f16 + } -> tensor<32x128xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16> + } + } {mapping = [#gpu.warp]} + return %0, %2 : tensor<128xf32>, tensor<128x128xf16> + } +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op + %loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + // ----- func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { @@ -282,8 +501,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> scf.yield %6 : tensor<128xf32> } - %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { // expected-error @below {{values used inside regions of target should be properly dominated by source}} + %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // expected-note @below {{see operation}} %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> @@ -328,6 +548,74 @@ module attributes {transform.with_named_sequence} { 
transform.yield } } + +// ----- + +func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + // expected-error @below {{target and source iteration spaces must be equal}} + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2xf32> + // expected-error @below {{target and source must be same loop type}} + scf.for %i = %c0 to %c2 step %c1 { + %B_elem = memref.load %B[%i] : memref<2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i] : 
memref<2xf32> + } + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %sum_elem = memref.load %sum[%i] : memref<2xf32> + %A_elem = memref.load %A[%i] : memref<2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i] : memref<2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + // ----- // CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}