diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 2488e7ef94455..f792d7d04dddc 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -21,9 +21,10 @@ #define DEBUG_TYPE "differentiation" -#include "swift/AST/AnyFunctionRef.h" +#include "Differentiation.h" #include "swift/AST/ASTMangler.h" #include "swift/AST/ASTPrinter.h" +#include "swift/AST/AnyFunctionRef.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/Builtins.h" #include "swift/AST/DeclContext.h" @@ -388,9 +389,11 @@ class PullbackInfo { /// The original function. SILFunction *const original; - /// The pullback data structures. - DenseMap> - pullbackDataStructures; + /// Mapping from original basic blocks to pullback structs. + DenseMap pullbackStructs; + + /// Mapping from original basic blocks to predecessor enums. + DenseMap predecessorEnums; /// Mapping from `apply` and `struct_extract` instructions in the original /// function to the corresponding pullback declaration in the pullback struct. @@ -545,9 +548,8 @@ class PullbackInfo { /// Creates a struct declaration with the given VJP generic signature, for /// storing the pullback values and predecessor of the given original block. - std::pair + StructDecl * createPullbackStruct(SILBasicBlock *originalBB, SILAutoDiffIndices indices, - EnumDecl *predecessorEnum, CanGenericSignature vjpGenericSig) { auto *original = originalBB->getParent(); auto &astCtx = original->getASTContext(); @@ -575,15 +577,6 @@ class PullbackInfo { pullbackStruct->computeType(); assert(pullbackStruct->hasInterfaceType()); file.addVisibleDecl(pullbackStruct); - // Add predecessor field if not entry block. - VarDecl *predecessorEnumField = nullptr; - if (!originalBB->isEntry()) { - predecessorEnumField = addVarDecl( - pullbackStruct, astCtx.getIdentifier("predecessor").str(), - predecessorEnum->getDeclaredInterfaceType()); - pullbackStructPredecessorFields.insert( - {pullbackStruct, predecessorEnumField}); - } LLVM_DEBUG({ auto &s = getADDebugStream(); s << "Pullback struct created for function @" << original->getName() @@ -591,77 +584,51 @@ class PullbackInfo { pullbackStruct->print(s); s << '\n'; }); - return {pullbackStruct, predecessorEnumField}; + return pullbackStruct; } public: PullbackInfo(const PullbackInfo &) = delete; PullbackInfo &operator=(const PullbackInfo &) = delete; - explicit PullbackInfo(SILFunction *original, SILFunction *vjp, - const SILAutoDiffIndices &indices, - Lowering::TypeConverter &typeConverter) - : original(original), typeConverter(typeConverter) { - // Get VJP generic signature. - CanGenericSignature vjpGenSig = nullptr; - if (auto *vjpGenEnv = vjp->getGenericEnvironment()) - vjpGenSig = vjpGenEnv->getGenericSignature()->getCanonicalSignature(); - // Create predecessor enum and pullback struct for each original block. - // TODO: Adapt data structure generation to handle loops. - // - Generate all pullback structs before predecessor enums. - // - Mark appropriate predecessor enums as indirect. - for (auto &origBB : *original) { - auto *predEnum = createBasicBlockPredecessorEnum( - &origBB, indices, vjpGenSig); - StructDecl *pbStruct; - VarDecl *predecessorEnumField; - std::tie(pbStruct, predecessorEnumField) = createPullbackStruct( - &origBB, indices, predEnum, vjpGenSig); - pullbackDataStructures.insert({&origBB, {pbStruct, predEnum}}); - } - } - - /// Returns the pullback struct and predecessor enum associated with the - /// given original block. - std::pair - getPullbackDataStructures(SILBasicBlock *origBB) { - return pullbackDataStructures.lookup(origBB); - } + explicit PullbackInfo(ADContext &context, SILFunction *original, + SILFunction *vjp, const SILAutoDiffIndices &indices); /// Returns the pullback struct associated with the given original block. StructDecl *getPullbackStruct(SILBasicBlock *origBB) const { - return pullbackDataStructures.lookup(origBB).first; + return pullbackStructs.lookup(origBB); } /// Returns the lowered SIL type of the pullback struct associated with the /// given original block. SILType getPullbackStructLoweredType(SILBasicBlock *origBB) const { auto *pbStruct = getPullbackStruct(origBB); - auto pbStructType = pbStruct->getDeclaredInterfaceType() - ->getCanonicalType(); - return typeConverter.getLoweredType( - pbStructType, ResilienceExpansion::Minimal); + auto pbStructType = + pbStruct->getDeclaredInterfaceType()->getCanonicalType(); + return typeConverter.getLoweredType(pbStructType, + ResilienceExpansion::Minimal); } /// Returns the predecessor enum associated with the given original block. EnumDecl *getPredecessorEnum(SILBasicBlock *origBB) const { - return pullbackDataStructures.lookup(origBB).second; + return predecessorEnums.lookup(origBB); } /// Returns the lowered SIL type of the predecessor enum associated with the /// given original block. SILType getPredecessorEnumLoweredType(SILBasicBlock *origBB) const { auto *predEnum = getPredecessorEnum(origBB); - auto predEnumType = predEnum->getDeclaredInterfaceType() - ->getCanonicalType(); - return typeConverter.getLoweredType( - predEnumType, ResilienceExpansion::Minimal); + auto predEnumType = + predEnum->getDeclaredInterfaceType()->getCanonicalType(); + return typeConverter.getLoweredType(predEnumType, + ResilienceExpansion::Minimal); } - /// Returns the enum element in the given successor block's predecessor enum, + /// Returns the enum element in the given successor block's predecessor enum /// corresponding to the given predecessor block. - EnumElementDecl *lookUpPredecessorEnumElement( - SILBasicBlock *origPredBB, SILBasicBlock *origSuccBB) { + EnumElementDecl * + lookUpPredecessorEnumElement(SILBasicBlock *origPredBB, + SILBasicBlock *origSuccBB) const { assert(origPredBB->getParent() == original); return predecessorEnumCases.lookup({origPredBB, origSuccBB}); } @@ -675,7 +642,7 @@ class PullbackInfo { /// Returns the predecessor enum field for the pullback struct of the given /// original block. VarDecl *lookUpPullbackStructPredecessorField(SILBasicBlock *origBB) { - auto *pullbackStruct = getPullbackDataStructures(origBB).first; + auto *pullbackStruct = getPullbackStruct(origBB); return pullbackStructPredecessorFields.lookup(pullbackStruct); } @@ -1254,6 +1221,38 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, } } +PullbackInfo::PullbackInfo(ADContext &context, SILFunction *original, + SILFunction *vjp, const SILAutoDiffIndices &indices) + : original(original), typeConverter(context.getTypeConverter()) { + auto &astCtx = original->getASTContext(); + auto *loopAnalysis = context.getPassManager().getAnalysis(); + auto *loopInfo = loopAnalysis->get(original); + // Get VJP generic signature. + CanGenericSignature vjpGenSig = nullptr; + if (auto *vjpGenEnv = vjp->getGenericEnvironment()) + vjpGenSig = vjpGenEnv->getGenericSignature()->getCanonicalSignature(); + // Create predecessor enum and pullback struct for each original block. + for (auto &origBB : *original) { + auto *pbStruct = createPullbackStruct(&origBB, indices, vjpGenSig); + pullbackStructs.insert({&origBB, pbStruct}); + } + for (auto &origBB : *original) { + auto *pbStruct = getPullbackStruct(&origBB); + auto *predEnum = + createBasicBlockPredecessorEnum(&origBB, indices, vjpGenSig); + // If original block is in a loop, mark predecessor enum as indirect. + if (loopInfo->getLoopFor(&origBB)) + predEnum->getAttrs().add(new (astCtx) IndirectAttr(/*Implicit*/ true)); + predecessorEnums.insert({&origBB, predEnum}); + if (origBB.isEntry()) + continue; + auto *predEnumField = + addVarDecl(pbStruct, astCtx.getIdentifier("predecessor").str(), + predEnum->getDeclaredInterfaceType()); + pullbackStructPredecessorFields.insert({pbStruct, predEnumField}); + } +} + //===----------------------------------------------------------------------===// // Activity Analysis //===----------------------------------------------------------------------===// @@ -1737,6 +1736,7 @@ static void dumpActivityInfo(SILFunction &fn, for (auto &inst : bb) for (auto res : inst.getResults()) dumpActivityInfo(res, indices, activityInfo, s); + s << '\n'; } } @@ -1756,23 +1756,13 @@ static bool diagnoseNoReturn(ADContext &context, SILFunction *original, /// flow unsupported" error at appropriate source locations. Returns true if /// error is emitted. /// -/// Update as control flow support is added. Currently, loops and branching -/// terminators other than `br` and `cond_br` are not supported. +/// Update as control flow support is added. Currently, branching terminators +/// other than `br`, `cond_br`, `switch_enum` are not supported. static bool diagnoseUnsupportedControlFlow(ADContext &context, SILFunction *original, DifferentiationInvoker invoker) { if (original->getBlocks().size() <= 1) return false; - // Diagnose loops first, to provide a more specific diagnostic. - auto *loopAnalysis = context.getPassManager().getAnalysis(); - auto *loopInfo = loopAnalysis->get(original); - if (!loopInfo->empty()) { - auto *loop = *loopInfo->getTopLevelLoops().begin(); - context.emitNondifferentiabilityError( - loop->getHeader()->getTerminator(), invoker, - diag::autodiff_loops_not_supported); - return true; - } // Diagnose unsupported branching terminators. for (auto &bb : *original) { auto *term = bb.getTerminator(); @@ -2846,15 +2836,12 @@ class VJPEmitter final } public: - explicit VJPEmitter(ADContext &context, - SILFunction *original, - SILDifferentiableAttr *attr, - SILFunction *vjp, + explicit VJPEmitter(ADContext &context, SILFunction *original, + SILDifferentiableAttr *attr, SILFunction *vjp, DifferentiationInvoker invoker) : TypeSubstCloner(*vjp, *original, getSubstitutionMap(original, vjp)), context(context), original(original), attr(attr), vjp(vjp), - pullbackInfo( - original, vjp, attr->getIndices(), context.getTypeConverter()), + pullbackInfo(context, original, vjp, attr->getIndices()), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), vjp)) { // Create empty adjoint function. @@ -3055,15 +3042,31 @@ class VJPEmitter final /// Build a predecessor enum instance using the given builder for the given /// original predecessor/successor blocks and pullback struct value. - EnumInst *buildPredecessorEnumValue( - SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, - StructInst *pbStructVal) { + EnumInst *buildPredecessorEnumValue(SILBuilder &builder, + SILBasicBlock *predBB, + SILBasicBlock *succBB, + StructInst *pbStructVal) { auto loc = pbStructVal->getLoc(); auto *succEnum = pullbackInfo.getPredecessorEnum(succBB); auto enumLoweredTy = getNominalDeclLoweredType(succEnum); auto *enumEltDecl = pullbackInfo.lookUpPredecessorEnumElement(predBB, succBB); - return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy); + auto enumEltType = remapType( + enumLoweredTy.getEnumElementType(enumEltDecl, getModule())); + // If the enum element type does not have a box type (i.e. the enum case is + // not indirect), then directly create an enum. + auto boxType = dyn_cast(enumEltType.getASTType()); + if (!boxType) + return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy); + // Otherwise, box the pullback struct value and create an enum. + auto *allocBox = builder.createAllocBox(loc, boxType); + auto *projectBox = builder.createProjectBox(loc, allocBox, /*index*/ 0); + builder.createStore(loc, pbStructVal, projectBox, + getBufferSOQ(projectBox->getType().getASTType(), *vjp)); + // NOTE(TF-585): `fix_lifetime` is generated to avoid AllocBoxToStack crash + // for nested loop AD. + builder.createFixLifetime(loc, allocBox); + return builder.createEnum(loc, allocBox, enumEltDecl, enumLoweredTy); } public: @@ -3156,25 +3159,25 @@ class VJPEmitter final // constructs a predecessor enum value and branches to the VJP successor // block. auto createTrampolineBasicBlock = - [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { - auto *vjpSuccBB = getOpBasicBlock(origSuccBB); - // Create the trampoline block. - auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *arg : vjpSuccBB->getArguments().drop_back()) - trampolineBB->createPhiArgument(arg->getType(), - arg->getOwnershipKind()); - // Build predecessor enum value for successor block and branch to it. - SILBuilder trampolineBuilder(trampolineBB); - auto *succEnumVal = buildPredecessorEnumValue( - trampolineBuilder, origBB, origSuccBB, pbStructVal); - SmallVector forwardedArguments( - trampolineBB->getArguments().begin(), - trampolineBB->getArguments().end()); - forwardedArguments.push_back(succEnumVal); - trampolineBuilder.createBranch( - sei->getLoc(), vjpSuccBB, forwardedArguments); - return trampolineBB; - }; + [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { + auto *vjpSuccBB = getOpBasicBlock(origSuccBB); + // Create the trampoline block. + auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); + for (auto *arg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument(arg->getType(), + arg->getOwnershipKind()); + // Build predecessor enum value for successor block and branch to it. + SILBuilder trampolineBuilder(trampolineBB); + auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, + origSuccBB, pbStructVal); + SmallVector forwardedArguments( + trampolineBB->getArguments().begin(), + trampolineBB->getArguments().end()); + forwardedArguments.push_back(succEnumVal); + trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB, + forwardedArguments); + return trampolineBB; + }; // Create trampoline successor basic blocks. SmallVector, 4> caseBBs; @@ -3711,6 +3714,9 @@ class AdjointEmitter final : public SILInstructionVisitor { /// Dominance info for the original function. DominanceInfo *domInfo = nullptr; + /// Post-dominance info for the original function. + PostDominanceInfo *postDomInfo = nullptr; + /// Post-order info for the original function. PostOrderFunctionInfo *postOrderInfo = nullptr; @@ -3738,16 +3744,17 @@ class AdjointEmitter final : public SILInstructionVisitor { /// Mapping from original basic blocks to dominated active values. DenseMap> activeValues; - /// Local adjoint values to be cleaned up. This is populated when adjoint - /// emission is run on one basic block and cleaned before processing another - /// basic block. - SmallVector blockLocalAdjointValues; - /// Mapping from original basic blocks and original active values to /// corresponding adjoint block arguments. DenseMap, SILArgument *> activeValueAdjointBBArgumentMap; + /// Mapping from original basic blocks to local adjoint values to be cleaned + /// up. This is populated when adjoint emission is run on one basic block and + /// cleaned before processing another basic block. + DenseMap> + blockLocalAdjointValues; + /// Stack buffers allocated for storing local adjoint adjoint values. SmallVector functionLocalAllocations; @@ -3794,8 +3801,10 @@ class AdjointEmitter final : public SILInstructionVisitor { // Get dominance and post-order info for the original function. auto &passManager = getContext().getPassManager(); auto *domAnalysis = passManager.getAnalysis(); + auto *postDomAnalysis = passManager.getAnalysis(); auto *postOrderAnalysis = passManager.getAnalysis(); domInfo = domAnalysis->get(vjpEmitter.original); + postDomInfo = postDomAnalysis->get(vjpEmitter.original); postOrderInfo = postOrderAnalysis->get(vjpEmitter.original); } @@ -3966,7 +3975,7 @@ class AdjointEmitter final : public SILInstructionVisitor { valueMap.erase(it); auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue); initializeAdjointValue(origBB, originalValue, adjVal); - blockLocalAdjointValues.push_back(adjVal); + blockLocalAdjointValues[origBB].push_back(adjVal); } /// Get the adjoint block argument corresponding to the given original block @@ -4203,8 +4212,7 @@ class AdjointEmitter final : public SILInstructionVisitor { auto &bbActiveValues = activeValues[bb]; // If the current block has an immediate dominator, append the immediate // dominator block's active values to the current block's active values. - auto *domNode = domInfo->getNode(bb)->getIDom(); - if (domNode) { + if (auto *domNode = domInfo->getNode(bb)->getIDom()) { auto &domBBActiveValues = activeValues[domNode->getBlock()]; bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end()); @@ -4250,8 +4258,21 @@ class AdjointEmitter final : public SILInstructionVisitor { return true; // Create adjoint blocks and arguments, visiting original blocks in - // post-order. - for (auto *origBB : postOrderInfo->getPostOrder()) { + // post-order post-dominance order. + SmallVector postOrderPostDomOrder; + // Start from the root node, which may have a marker `nullptr` block if + // there are multiple roots. + PostOrderPostDominanceOrder postDomOrder(postDomInfo->getRootNode(), + postOrderInfo, original.size()); + while (auto *origNode = postDomOrder.getNext()) { + auto *origBB = origNode->getBlock(); + postDomOrder.pushChildren(origNode); + // If node is the `nullptr` marker basic block, do not push it. + if (!origBB) + continue; + postOrderPostDomOrder.push_back(origBB); + } + for (auto *origBB : postOrderPostDomOrder) { auto *adjointBB = adjoint.createBasicBlock(); adjointBBMap.insert({origBB, adjointBB}); auto pbStructLoweredType = @@ -4302,10 +4323,17 @@ class AdjointEmitter final : public SILInstructionVisitor { // adjoint original block, trampoline adjoint values of active values. for (auto *succBB : origBB->getSuccessorBlocks()) { auto *adjointTrampolineBB = adjoint.createBasicBlockBefore(adjointBB); - adjointTrampolineBBMap.insert( - {{origBB, succBB}, adjointTrampolineBB}); - adjointTrampolineBB->createPhiArgument( - pbStructLoweredType, ValueOwnershipKind::Guaranteed); + adjointTrampolineBBMap.insert({{origBB, succBB}, adjointTrampolineBB}); + // Get the enum element type (i.e. the pullback struct type). The enum + // element type may be boxed if the enum is indirect. + auto enumLoweredTy = + getPullbackInfo().getPredecessorEnumLoweredType(succBB); + auto *enumEltDecl = + getPullbackInfo().lookUpPredecessorEnumElement(origBB, succBB); + auto enumEltType = remapType( + enumLoweredTy.getEnumElementType(enumEltDecl, getModule())); + adjointTrampolineBB->createPhiArgument(enumEltType, + ValueOwnershipKind::Guaranteed); } } @@ -4356,7 +4384,7 @@ class AdjointEmitter final : public SILInstructionVisitor { // Visit original blocks blocks in post-order and perform differentiation // in corresponding adjoint blocks. - for (auto *bb : postOrderInfo->getPostOrder()) { + for (auto *bb : postOrderPostDomOrder) { if (errorOccurred) break; // Get the corresponding adjoint basic block. @@ -4392,10 +4420,7 @@ class AdjointEmitter final : public SILInstructionVisitor { // 1. Get the pullback struct adjoint bb argument. // 2. Extract the predecessor enum value from the pullback struct value. auto *pbStructVal = getAdjointBlockPullbackStructArgument(bb); - StructDecl *pbStruct; - EnumDecl *predEnum; - std::tie(pbStruct, predEnum) = - getPullbackInfo().getPullbackDataStructures(bb); + auto *predEnum = getPullbackInfo().getPredecessorEnum(bb); auto *predEnumField = getPullbackInfo().lookUpPullbackStructPredecessorField(bb); auto *predEnumVal = @@ -4482,19 +4507,30 @@ class AdjointEmitter final : public SILInstructionVisitor { } } // Propagate pullback struct argument. - trampolineArguments.push_back(adjointSuccBB->getArguments().front()); + auto *predPBStructVal = adjointTrampolineBB->getArguments().front(); + auto boxType = + dyn_cast(predPBStructVal->getType().getASTType()); + if (!boxType) { + trampolineArguments.push_back(predPBStructVal); + } else { + auto *projectBox = adjointTrampolineBBBuilder.createProjectBox( + adjLoc, predPBStructVal, /*index*/ 0); + auto *loadInst = adjointTrampolineBBBuilder.createLoad( + adjLoc, projectBox, + getBufferLOQ(projectBox->getType().getASTType(), adjoint)); + trampolineArguments.push_back(loadInst); + } // Branch from adjoint trampoline block to adjoint block. - adjointTrampolineBBBuilder.createBranch( - adjLoc, adjointBB, trampolineArguments); + adjointTrampolineBBBuilder.createBranch(adjLoc, adjointBB, + trampolineArguments); } auto *enumEltDecl = getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb); adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB}); } // Emit cleanups for all block-local adjoint values. - for (auto adjVal : blockLocalAdjointValues) + for (auto adjVal : blockLocalAdjointValues[bb]) emitCleanupForAdjointValue(adjVal); - blockLocalAdjointValues.clear(); // - If the original block has exactly one predecessor, then the adjoint // block has exactly one successor. Extract the pullback struct value // from the predecessor enum value using `unchecked_enum_data` and @@ -4563,9 +4599,8 @@ class AdjointEmitter final : public SILInstructionVisitor { for (auto i : getIndices().parameters->getIndices()) addRetElt(i); // Emit cleanups for all local values. - for (auto adjVal : blockLocalAdjointValues) + for (auto adjVal : blockLocalAdjointValues[origEntry]) emitCleanupForAdjointValue(adjVal); - blockLocalAdjointValues.clear(); // Disable cleanup for original indirect parameter adjoint buffers. // Copy them to adjoint indirect results. diff --git a/lib/SILOptimizer/Mandatory/Differentiation.h b/lib/SILOptimizer/Mandatory/Differentiation.h new file mode 100644 index 0000000000000..c15fc1d13f2d0 --- /dev/null +++ b/lib/SILOptimizer/Mandatory/Differentiation.h @@ -0,0 +1,93 @@ +//===--- Differentiation.h - SIL Automatic Differentiation ----*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// SWIFT_ENABLE_TENSORFLOW +// +// Reverse-mode automatic differentiation utilities. +// +// NOTE: Although the AD feature is developed as part of the Swift for +// TensorFlow project, it is completely independent from TensorFlow support. +// +// TODO: Move definitions here from Differentiation.cpp. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H +#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H + +#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" +#include "swift/SILOptimizer/Utils/Local.h" + +namespace swift { + +using llvm::DenseMap; +using llvm::SmallDenseMap; +using llvm::SmallDenseSet; +using llvm::SmallMapVector; +using llvm::SmallSet; + +/// Helper class for visiting basic blocks in post-order post-dominance order, +/// based on a worklist algorithm. +class PostOrderPostDominanceOrder { + SmallVector buffer; + PostOrderFunctionInfo *postOrderInfo; + size_t srcIdx = 0; + +public: + /// Constructor. + /// \p root The root of the post-dominator tree. + /// \p postOrderInfo The post-order info of the function. + /// \p capacity Should be the number of basic blocks in the dominator tree to + /// reduce memory allocation. + PostOrderPostDominanceOrder(DominanceInfoNode *root, + PostOrderFunctionInfo *postOrderInfo, + int capacity = 0) + : postOrderInfo(postOrderInfo) { + buffer.reserve(capacity); + buffer.push_back(root); + } + + /// Get the next block from the worklist. + DominanceInfoNode *getNext() { + if (srcIdx == buffer.size()) + return nullptr; + return buffer[srcIdx++]; + } + + /// Pushes the dominator children of a block onto the worklist in post-order. + void pushChildren(DominanceInfoNode *node) { + pushChildrenIf(node, [](SILBasicBlock *) { return true; }); + } + + /// Conditionally pushes the dominator children of a block onto the worklist + /// in post-order. + template + void pushChildrenIf(DominanceInfoNode *node, Pred pred) { + SmallVector children; + for (auto *child : *node) + children.push_back(child); + llvm::sort(children.begin(), children.end(), + [&](DominanceInfoNode *n1, DominanceInfoNode *n2) { + return postOrderInfo->getPONumber(n1->getBlock()) < + postOrderInfo->getPONumber(n2->getBlock()); + }); + for (auto *child : children) { + SILBasicBlock *childBB = child->getBlock(); + if (pred(childBB)) + buffer.push_back(child); + } + } +}; + +} // end namespace swift + +#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index ee26d8e18fbb2..9a5664ade86a2 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -114,7 +114,17 @@ extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magn } public static func *= (lhs: inout Tracked, rhs: Tracked) { - lhs = Tracked(lhs.value * rhs.value) + lhs = lhs * rhs + } +} + +extension Tracked where T : FloatingPoint { + public static func / (lhs: Tracked, rhs: Tracked) -> Tracked { + return Tracked(lhs.value / rhs.value) + } + + public static func /= (lhs: inout Tracked, rhs: Tracked) { + lhs = lhs / rhs } } @@ -181,6 +191,16 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, } } +extension Tracked where T : Differentiable & FloatingPoint, + T == T.AllDifferentiableVariables, T == T.TangentVector { + @usableFromInline + @differentiating(/) + internal static func _vjpDivide(lhs: Self, rhs: Self) + -> (value: Self, pullback: (Self) -> (Self, Self)) { + return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) + } +} + // Differential operators for `Tracked`. public extension Differentiable { @inlinable diff --git a/test/AutoDiff/control_flow.swift b/test/AutoDiff/control_flow.swift index d3754063f1fab..df6d49dcd9d7a 100644 --- a/test/AutoDiff/control_flow.swift +++ b/test/AutoDiff/control_flow.swift @@ -519,4 +519,112 @@ ControlFlowTests.test("Enums") { } } +ControlFlowTests.test("Loops") { + func for_loop(_ x: Float) -> Float { + var result = x + for _ in 1..<3 { + result = result * x + } + return result + } + expectEqual((8, 12), valueWithGradient(at: 2, in: for_loop)) + expectEqual((27, 27), valueWithGradient(at: 3, in: for_loop)) + + func while_loop(_ x: Float) -> Float { + var result = x + var i = 1 + while i < 3 { + result = result * x + i += 1 + } + return result + } + expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop)) + expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop)) + + func repeat_while_loop(_ x: Float) -> Float { + var result = x + var i = 1 + repeat { + result = result * x + i += 1 + } while i < 3 + return result + } + // FIXME(TF-584): Investigate incorrect (too big) gradient values + // for repeat-while loops. + // expectEqual((8, 12), valueWithGradient(at: 2, in: repeat_while_loop)) + // expectEqual((27, 27), valueWithGradient(at: 3, in: repeat_while_loop)) + expectEqual((8, 18), valueWithGradient(at: 2, in: repeat_while_loop)) + expectEqual((27, 36), valueWithGradient(at: 3, in: repeat_while_loop)) + + func loop_continue(_ x: Float) -> Float { + var result = x + for i in 1..<10 { + if i.isMultiple(of: 2) { + continue + } + result = result * x + } + return result + } + expectEqual((64, 192), valueWithGradient(at: 2, in: loop_continue)) + expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_continue)) + + func loop_break(_ x: Float) -> Float { + var result = x + for i in 1..<10 { + if i.isMultiple(of: 2) { + continue + } + result = result * x + } + return result + } + expectEqual((64, 192), valueWithGradient(at: 2, in: loop_break)) + expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_break)) + + func nested_loop1(_ x: Float) -> Float { + var outer = x + for _ in 1..<3 { + outer = outer * x + + var inner = outer + var i = 1 + while i < 3 { + inner = inner + x + i += 1 + } + outer = inner + } + return outer + } + expectEqual((20, 22), valueWithGradient(at: 2, in: nested_loop1)) + expectEqual((104, 66), valueWithGradient(at: 4, in: nested_loop1)) + + func nested_loop2(_ x: Float, count: Int) -> Float { + var outer = x + outerLoop: for _ in 1.. Float { } } +// Test loops. + +@differentiable +func for_loop(_ x: Float) -> Float { + var result: Float = x + for _ in 0..<3 { + result = result * x + } + return result +} + +@differentiable +func while_loop(_ x: Float) -> Float { + var result = x + var i = 1 + while i < 3 { + result = result * x + i += 1 + } + return result +} + +@differentiable +func nested_loop(_ x: Float) -> Float { + var outer = x + for _ in 1..<3 { + outer = outer * x + + var inner = outer + var i = 1 + while i < 3 { + inner = inner / x + i += 1 + } + outer = inner + } + return outer +} + +// Test `try_apply`. + +// expected-error @+1 {{function is not differentiable}} +@differentiable +// expected-note @+1 {{when differentiating this function definition}} +func withoutDerivative( + at x: T, in body: (T) throws -> R +) rethrows -> R { + // expected-note @+1 {{differentiating control flow is not yet supported}} + try body(x) +} + // Test unsupported differentiation of active enum values. // expected-error @+1 {{function is not differentiable}} @@ -91,16 +142,14 @@ enum Tree : Differentiable & AdditiveArithmetic { } } -// Test loops. - // expected-error @+1 {{function is not differentiable}} @differentiable // expected-note @+1 {{when differentiating this function definition}} -func loop(_ x: Float) -> Float { +func loop_array(_ array: [Float]) -> Float { var result: Float = 1 - // expected-note @+1 {{differentiating loops is not yet supported}} - for _ in 0..<3 { - result += x + // expected-note @+1 {{differentiating enum values is not yet supported}} + for x in array { + result = result * x } - return x + return result } diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift index 65b17c3ec1eae..48f6dc6a7affd 100644 --- a/test/AutoDiff/control_flow_sil.swift +++ b/test/AutoDiff/control_flow_sil.swift @@ -16,31 +16,31 @@ func cond(_ x: Float) -> Float { return x - x } -// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 { -// CHECK-DATA-STRUCTURES: } // CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 { // CHECK-DATA-STRUCTURES: } -// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 { -// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) -// CHECK-DATA-STRUCTURES: } // CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 { // CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set } // CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set } // CHECK-DATA-STRUCTURES: } -// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 { -// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) -// CHECK-DATA-STRUCTURES: } // CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 { // CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set } // CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set } // CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set } +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: } // CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 { // CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0) // CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0) // CHECK-DATA-STRUCTURES: } -// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 { -// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set } -// CHECK-DATA-STRUCTURES: } // CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float): @@ -137,6 +137,20 @@ func nested_cond_generic(_ x: T, _ y: T) -> return y } +@differentiable +@_silgen_name("loop_generic") +func loop_generic(_ x: T) -> T { + var result = x + for _ in 1..<3 { + var y = x + for _ in 1..<3 { + result = y + y = result + } + } + return result +} + // Test control flow + tuple buffer. // Verify that adjoint buffers are not allocated for address projections. diff --git a/test/AutoDiff/leakchecking.swift b/test/AutoDiff/leakchecking.swift index 6cb791bd5d305..bafeba61a21aa 100644 --- a/test/AutoDiff/leakchecking.swift +++ b/test/AutoDiff/leakchecking.swift @@ -190,6 +190,44 @@ LeakCheckingTests.test("ControlFlow") { expectEqual((-2674, 2), Tracked(-1337).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) })) } + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. + testWithLeakChecking(expectedLeakCount: 6) { + func for_loop(_ x: Tracked) -> Tracked { + var result = x + for _ in 1..<3 { + result = result * x + } + return result + } + expectEqual((8, 12), Tracked(2).valueWithGradient(in: for_loop)) + expectEqual((27, 27), Tracked(3).valueWithGradient(in: for_loop)) + } + + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. + testWithLeakChecking(expectedLeakCount: 20) { + func nested_loop(_ x: Tracked) -> Tracked { + var outer = x + for _ in 1..<3 { + outer = outer * x + + var inner = outer + var i = 1 + while i < 3 { + inner = inner / x + i += 1 + } + outer = inner + } + return outer + } + expectEqual((0.5, -0.25), Tracked(2).valueWithGradient(in: nested_loop)) + expectEqual((0.25, -0.0625), Tracked(4).valueWithGradient(in: nested_loop)) + } + // FIXME: Fix control flow AD memory leaks. // See related FIXME comments in adjoint value/buffer propagation in // lib/SILOptimizer/Mandatory/Differentiation.cpp. diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift index 0fade58c9d39a..124fb20110ef4 100644 --- a/test/AutoDiff/refcounting.swift +++ b/test/AutoDiff/refcounting.swift @@ -36,11 +36,11 @@ func testOwnedVector(_ x: Vector) -> Vector { } _ = pullback(at: Vector.zero, in: testOwnedVector) -// CHECK-LABEL: enum {{.*}}testOwnedVector{{.*}}__Pred__src_0_wrt_0 { -// CHECK-NEXT: } // CHECK-LABEL: struct {{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0 { // CHECK-NEXT: @_hasStorage var pullback_0: (Vector) -> (Vector, Vector) { get set } // CHECK-NEXT: } +// CHECK-LABEL: enum {{.*}}testOwnedVector{{.*}}__Pred__src_0_wrt_0 { +// CHECK-NEXT: } // CHECK-LABEL: sil hidden @{{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__adjoint_src_0_wrt_0_1 // CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1):