From 4036d53e48feba6f2c353bb3781a77cd7d00b727 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Fri, 22 Nov 2019 15:47:01 -0800 Subject: [PATCH 01/10] [AutoDiff] make diff_wit_fns instead of fn_refs --- include/swift/IRGen/Linking.h | 1 + include/swift/SIL/SILBuilder.h | 10 +- include/swift/SIL/SILInstruction.h | 8 +- include/swift/SIL/SILModule.h | 10 + lib/IRGen/GenDiffWitness.cpp | 22 +- lib/IRGen/LoadableByAddress.cpp | 37 ++++ lib/SIL/SILDifferentiabilityWitness.cpp | 4 + lib/SIL/SILInstructions.cpp | 8 +- lib/SIL/SILModule.cpp | 6 + lib/SILGen/SILGen.cpp | 12 +- lib/SILOptimizer/Mandatory/CMakeLists.txt | 1 + .../Mandatory/Differentiation.cpp | 190 +++++++----------- .../Differentiation/DerivativeLookup.cpp | 122 +++++++++++ .../Differentiation/DerivativeLookup.h | 72 +++++++ test/AutoDiff/closures.swift | 5 +- ...891-protocol-req-capture-propagation.swift | 2 +- .../differentiable_sil_attr_roundtrip.swift | 10 +- test/AutoDiff/forward_mode_sil.swift | 12 +- test/AutoDiff/refcounting.swift | 4 +- test/AutoDiff/subset_parameters_thunk.swift | 4 +- test/AutoDiff/witness_table_sil.swift | 13 +- 21 files changed, 377 insertions(+), 176 deletions(-) create mode 100644 lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp create mode 100644 lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h rename test/AutoDiff/{compiler_crashers => compiler_crashers_fixed}/tf891-protocol-req-capture-propagation.swift (99%) diff --git a/include/swift/IRGen/Linking.h b/include/swift/IRGen/Linking.h index 0bf6963106238..20c30bdfd077b 100644 --- a/include/swift/IRGen/Linking.h +++ b/include/swift/IRGen/Linking.h @@ -478,6 +478,7 @@ class LinkEntity { setForDifferentiabilityWitness(Kind kind, const SILDifferentiabilityWitness *witness) { Pointer = const_cast(static_cast(witness)); + SecondaryPointer = nullptr; Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind)); } // SWIFT_ENABLE_TENSORFLOW_END diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 3cb68afb4dcef..deeb464799708 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -554,12 +554,14 @@ class SILBuilder { NormalDifferentiableFunctionTypeComponent::Original, TheFunction)); } - DifferentiabilityWitnessFunctionInst * - createDifferentiabilityWitnessFunction( + /// Note: explicit function type may be specified only in lowered SIL. + DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction( SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind, - SILDifferentiabilityWitness *Witness) { + SILDifferentiabilityWitness *Witness, + Optional FunctionType = None) { return insert(new (getModule()) DifferentiabilityWitnessFunctionInst( - getModule(), getSILDebugLocation(Loc), WitnessKind, Witness)); + getModule(), getSILDebugLocation(Loc), WitnessKind, Witness, + FunctionType)); } // SWIFT_ENABLE_TENSORFLOW END diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 14b67f0e8af0b..fed8da5f7c786 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -8118,15 +8118,11 @@ class DifferentiabilityWitnessFunctionInst SILDifferentiabilityWitness *witness); public: + /// Note: explicit function type may be specified only in lowered SIL. DifferentiabilityWitnessFunctionInst( SILModule &module, SILDebugLocation loc, DifferentiabilityWitnessFunctionKind witnessKind, - SILDifferentiabilityWitness *witness); - - static DifferentiabilityWitnessFunctionInst *create( - SILModule &module, SILDebugLocation loc, - DifferentiabilityWitnessFunctionKind witnessKind, - SILDifferentiabilityWitness *witness); + SILDifferentiabilityWitness *witness, Optional FunctionType); DifferentiabilityWitnessFunctionKind getWitnessKind() const { return witnessKind; diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index 2d5ceaffd26a0..dd8318e633803 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -207,6 +207,11 @@ class SILModule { /// Lookup table for SIL differentiability witnesses, keyed by mangled name. llvm::StringMap DifferentiabilityWitnessMap; + /// Lookup table for SILDifferentiabilityWitnesses, keyed by original + /// function name. + llvm::StringMap> + DifferentiabilityWitnessesByFunction; + /// The list of SILDifferentiabilityWitnesses in the module. DifferentiabilityWitnessListType differentiabilityWitnesses; // SWIFT_ENABLE_TENSORFLOW END @@ -613,6 +618,11 @@ class SILModule { /// Look up the differentiability witness corresponding to the given key. SILDifferentiabilityWitness * lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key); + + /// Look up the differentiability witness corresponding to the given function. + llvm::ArrayRef + lookUpDifferentiabilityWitnessesForFunction(StringRef name); + // SWIFT_ENABLE_TENSORFLOW_END // Given a protocol, attempt to create a default witness table declaration diff --git a/lib/IRGen/GenDiffWitness.cpp b/lib/IRGen/GenDiffWitness.cpp index 9eec9a5eedd3b..6e20a2e97c480 100644 --- a/lib/IRGen/GenDiffWitness.cpp +++ b/lib/IRGen/GenDiffWitness.cpp @@ -37,21 +37,13 @@ void IRGenModule::emitSILDifferentiabilityWitness( ConstantInitBuilder builder(*this); auto diffWitnessContents = builder.beginStruct(); - // TODO(TF-894): When the differentiation transform canonicalizes all - // differentiability witnesses to have JVP/VJP functions, remove the nullptr - // cases and assert that JVP/VJP functions exist. - if (dw->getJVP()) { - diffWitnessContents.addBitCast( - getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy); - } else { - diffWitnessContents.addNullPointer(Int8PtrTy); - } - if (dw->getVJP()) { - diffWitnessContents.addBitCast( - getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy); - } else { - diffWitnessContents.addNullPointer(Int8PtrTy); - } + assert(dw->getJVP() && "diff witness should be canonicalized"); + assert(dw->getVJP() && "diff witness should be canonicalized"); + + diffWitnessContents.addBitCast( + getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy); + diffWitnessContents.addBitCast( + getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy); getAddrOfDifferentiabilityWitness( dw, diffWitnessContents.finishAndCreateFuture()); diff --git a/lib/IRGen/LoadableByAddress.cpp b/lib/IRGen/LoadableByAddress.cpp index 42d2fe24283a7..53adb45949616 100644 --- a/lib/IRGen/LoadableByAddress.cpp +++ b/lib/IRGen/LoadableByAddress.cpp @@ -1675,6 +1675,10 @@ class LoadableByAddress : public SILModuleTransform { bool fixStoreToBlockStorageInstr(SILInstruction &I, SmallVectorImpl &Delete); + // SWIFT_ENABLE_TENSORFLOW + bool recreateDifferentiabilityWitnessFunction( + SILInstruction &I, SmallVectorImpl &Delete); + private: llvm::SetVector modFuncs; llvm::SetVector conversionInstrs; @@ -2672,6 +2676,36 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr( return true; } +// SWIFT_ENABLE_TENSORFLOW +bool LoadableByAddress::recreateDifferentiabilityWitnessFunction( + SILInstruction &I, SmallVectorImpl &Delete) { + auto *instr = dyn_cast(&I); + if (!instr) + return false; + + // If the witness is a declaration, then LoadableByAddress cannot have changed + // the type because the function is in a different module. + if (instr->getWitness()->isDeclaration()) + return true; + + // Otherwise, update the instruction if the function type changed. + auto resultTy = instr->getType(); + auto *referencedFn = instr->getWitness()->getDerivative( + *instr->getWitnessKind().getAsDerivativeFunctionKind()); + assert(referencedFn && "diff witness should be canonicalized"); + auto newResultTy = referencedFn->getLoweredType(); + if (resultTy == newResultTy) + return true; + + SILBuilderWithScope builder(instr); + auto *newInstr = builder.createDifferentiabilityWitnessFunction( + instr->getLoc(), instr->getWitnessKind(), instr->getWitness(), + newResultTy); + instr->replaceAllUsesWith(newInstr); + Delete.push_back(instr); + return true; +} + bool LoadableByAddress::recreateTupleInstr( SILInstruction &I, SmallVectorImpl &Delete) { auto *tupleInstr = dyn_cast(&I); @@ -2994,6 +3028,9 @@ void LoadableByAddress::run() { continue; else if (recreateApply(I, Delete)) continue; + // SWIFT_ENABLE_TENSORFLOW + else if (recreateDifferentiabilityWitnessFunction(I, Delete)) + continue; else fixStoreToBlockStorageInstr(I, Delete); } diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index 407902ad41509..54d619a5e5128 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -35,6 +35,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration( assert(!module.DifferentiabilityWitnessMap.count(mangledKey) && "Cannot create duplicate differentiability witness in a module"); module.DifferentiabilityWitnessMap[mangledKey] = diffWitness; + module.DifferentiabilityWitnessesByFunction[originalFunction->getName()] + .push_back(diffWitness); module.getDifferentiabilityWitnessList().push_back(diffWitness); return diffWitness; } @@ -56,6 +58,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition( assert(!module.DifferentiabilityWitnessMap.count(mangledKey) && "Cannot create duplicate differentiability witness in a module"); module.DifferentiabilityWitnessMap[mangledKey] = diffWitness; + module.DifferentiabilityWitnessesByFunction[originalFunction->getName()] + .push_back(diffWitness); module.getDifferentiabilityWitnessList().push_back(diffWitness); return diffWitness; } diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 2bed7f979342c..f6207dfe47b48 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -781,9 +781,11 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType( DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst( SILModule &module, SILDebugLocation debugLoc, DifferentiabilityWitnessFunctionKind witnessKind, - SILDifferentiabilityWitness *witness) - : InstructionBase(debugLoc, getDifferentiabilityWitnessType( - module, witnessKind, witness)), + SILDifferentiabilityWitness *witness, Optional FunctionType) + : InstructionBase(debugLoc, FunctionType + ? *FunctionType + : getDifferentiabilityWitnessType( + module, witnessKind, witness)), witnessKind(witnessKind), witness(witness) {} // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/SIL/SILModule.cpp b/lib/SIL/SILModule.cpp index 29b609c1702d2..d0019ca9051aa 100644 --- a/lib/SIL/SILModule.cpp +++ b/lib/SIL/SILModule.cpp @@ -593,6 +593,12 @@ SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key) { mangler.mangleSILDifferentiabilityWitnessKey(key)); } +/// Look up the differentiability witness corresponding to the given indices. +llvm::ArrayRef +SILModule::lookUpDifferentiabilityWitnessesForFunction(StringRef name) { + return DifferentiabilityWitnessesByFunction[name]; +} + void SILModule::registerDeserializationNotificationHandler( std::unique_ptr &&handler) { deserializationNotificationHandlers.add(std::move(handler)); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index b7ac4f9beb195..fd862980a18c2 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -775,9 +775,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, if (auto *vjpDecl = diffAttr->getVJPFunction()) vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition); auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); - AutoDiffConfig config( - diffAttr->getParameterIndices(), resultIndices, - diffAttr->getDerivativeGenericSignature().getPointer()); + AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, + diffAttr->getDerivativeGenericSignature()); emitDifferentiabilityWitness(AFD, F, config, jvp, vjp); } }; @@ -819,17 +818,14 @@ void SILGenModule::emitDifferentiabilityWitness( }; bool reorderSelf = shouldReorderSelf(); - CanGenericSignature derivativeCanGenSig; - if (auto derivativeGenSig = config.derivativeGenericSignature) - derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); // Create new SIL differentiability witness. // Witness JVP and VJP are set below. // TODO(TF-919): Explore creating serialized differentiability witnesses. // Currently, differentiability witnesses are never serialized to avoid // deserialization issues where JVP/VJP functions cannot be found. auto *diffWitness = SILDifferentiabilityWitness::createDefinition( - M, originalFunction->getLinkage(), originalFunction, - loweredParamIndices, config.resultIndices, derivativeCanGenSig, + M, originalFunction->getLinkage(), originalFunction, loweredParamIndices, + config.resultIndices, config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false); // Set derivative function in differentiability witness. diff --git a/lib/SILOptimizer/Mandatory/CMakeLists.txt b/lib/SILOptimizer/Mandatory/CMakeLists.txt index deffbec8caf1f..31b1375426835 100644 --- a/lib/SILOptimizer/Mandatory/CMakeLists.txt +++ b/lib/SILOptimizer/Mandatory/CMakeLists.txt @@ -7,6 +7,7 @@ silopt_register_sources( DefiniteInitialization.cpp # SWIFT_ENABLE_TENSORFLOW Differentiation.cpp + Differentiation/DerivativeLookup.cpp DIMemoryUseCollector.cpp DataflowDiagnostics.cpp DiagnoseInfiniteRecursion.cpp diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 359a72b047d13..1944562e9874d 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -22,6 +22,8 @@ #define DEBUG_TYPE "differentiation" #include "Differentiation.h" +#include "Differentiation/DerivativeLookup.h" + #include "swift/AST/ASTMangler.h" #include "swift/AST/ASTPrinter.h" #include "swift/AST/AnyFunctionRef.h" @@ -345,8 +347,10 @@ static Inst *peerThroughFunctionConversions(SILValue value) { /// Finds the differentiability witness corresponding to `attr` in `module`. static SILDifferentiabilityWitness * findDifferentiabilityWitness(SILModule &module, SILDifferentiableAttr *attr) { - auto *resultIndices = - IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source}); + auto *resultIndices = IndexSubset::get( + module.getASTContext(), + attr->getOriginal()->getLoweredFunctionType()->getNumResults(), + {attr->getIndices().source}); AutoDiffConfig config(attr->getIndices().parameters, resultIndices, attr->getDerivativeGenericSignature()); return module.lookUpDifferentiabilityWitness( @@ -383,8 +387,10 @@ canonicalizeDifferentiabilityWitness(SILModule &module, static SILDifferentiabilityWitness * createDifferentiabilityWitness(SILModule &module, SILLinkage linkage, SILDifferentiableAttr *attr) { - auto *resultIndices = - IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source}); + auto *resultIndices = IndexSubset::get( + module.getASTContext(), + attr->getOriginal()->getLoweredFunctionType()->getNumResults(), + {attr->getIndices().source}); auto *witness = SILDifferentiabilityWitness::createDefinition( module, linkage, attr->getOriginal(), attr->getIndices().parameters, resultIndices, attr->getDerivativeGenericSignature(), /*jvp*/ nullptr, @@ -1083,78 +1089,6 @@ class ADContext { return nullptr; } - /// Finds the `[differentiable]` attribute on the specified original function - /// whose parameter indices are a minimal superset of the specified parameter - /// indices. Returns nullptr if no such attribute exists. - SILDifferentiableAttr *lookUpMinimalDifferentiableAttr( - SILFunction *original, const SILAutoDiffIndices &indices) const { - auto *minimalIndexSet = IndexSubset::getDefault( - getASTContext(), - original->getLoweredFunctionType()->getNumParameters(), false); - auto *indexSet = indices.parameters; - if (auto *exactAttr = lookUpDifferentiableAttr(original, indices)) - return exactAttr; - SILDifferentiableAttr *minimalAttr = nullptr; - for (auto *da : original->getDifferentiableAttrs()) { - if (da->getIndices().source != indices.source) - continue; - auto *daIndexSet = da->getIndices().parameters; - // If all indices in `indexSet` are in `daIndexSet`, and it has fewer - // indices than our current candidate and a primitive VJP, then `da` is - // our new candidate. - // - // NOTE(TF-642): `da` may come from a un-partial-applied function and - // have larger capacity than the desired indices. We expect this logic to - // go away when `partial_apply` supports `@differentiable` callees. - if (daIndexSet->isSupersetOf(indexSet->extendingCapacity( - getASTContext(), daIndexSet->getCapacity())) && - // fewer parameters than before - (minimalIndexSet->isEmpty() || - daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) { - minimalAttr = da; - minimalIndexSet = daIndexSet; - } - } - return minimalAttr; - } - - /// Finds the `@differentiable` attribute (and its parameter indices) on the - /// specified original function whose parameter indices are a minimal - /// superset of the specified parameter indices. Returns nullptr if no such - /// attribute exists. - std::pair - lookUpMinimalASTDifferentiableAttrAndIndexSubset( - SILDeclRef originalDeclRef, CanSILFunctionType originalFnType, - const SILAutoDiffIndices &indices) { - auto *original = originalDeclRef.getDecl(); - const DifferentiableAttr *minimalAttr = nullptr; - auto *minimalIndexSet = IndexSubset::getDefault( - getASTContext(), originalFnType->getNumParameters(), false); - auto *indexSet = indices.parameters; - for (auto *da : original->getAttrs().getAttributes()) { - auto *daParamIndices = da->getParameterIndices(); - auto *daIndexSet = autodiff::getLoweredParameterIndices( - daParamIndices, - original->getInterfaceType()->castTo()); - // If all indices in `indexSet` are in `daIndexSet`, and it has fewer - // indices than our current candidate and a primitive VJP, then `da` is - // our new candidate. - // - // NOTE(TF-642): `da` may come from a un-partial-applied function and - // have larger capacity than the desired indices. We expect this logic to - // go away when `partial_apply` supports `@differentiable` callees. - if (daIndexSet->isSupersetOf(indexSet->extendingCapacity(getASTContext(), - daIndexSet->getCapacity())) && - // fewer parameters than before - (minimalIndexSet->isEmpty() || - daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) { - minimalAttr = da; - minimalIndexSet = daIndexSet; - } - } - return std::make_pair(minimalAttr, minimalIndexSet); - } - /// Creates a `[differentiable]` attribute on the specified original function /// with the specified parameter indices. SILDifferentiableAttr *createDifferentiableAttr( @@ -2701,14 +2635,28 @@ emitDerivativeFunctionReference( peerThroughFunctionConversions(original)) { auto loc = originalFRI->getLoc(); auto *originalFn = originalFRI->getReferencedFunctionOrNull(); - // Attempt to look up a `[differentiable]` attribute that minimally - // satisfies the specified indices. - // TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally - // check whether `[differentiable]` attribute generic requirements are - // satisfied. - auto *minimalAttr = - context.lookUpMinimalDifferentiableAttr(originalFn, desiredIndices); - if (!minimalAttr) { + auto originalFnTy = originalFn->getLoweredFunctionType(); + auto *desiredResultIndices = + IndexSubset::get(context.getASTContext(), originalFnTy->getNumResults(), + {desiredIndices.source}); + auto *desiredParameterIndices = desiredIndices.parameters; + // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has + // parameters corresponding to captured variables. + // TODO: If posssible, change `autodiff::getLoweredParameterIndices` to + // take `CaptureInfo` into account. + if (originalFnTy->getNumParameters() > + desiredParameterIndices->getCapacity()) { + desiredParameterIndices = desiredParameterIndices->extendingCapacity( + context.getASTContext(), originalFnTy->getNumParameters()); + } + auto *minimalWitness = getExactDifferentiabilityWitness( + context.getModule(), originalFn, desiredParameterIndices, + desiredResultIndices); + if (!minimalWitness) + minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness( + context.getModule(), originalFn, desiredParameterIndices, + desiredResultIndices); + if (!minimalWitness) { // If the function is intentionally marked as being opaque to // differentiation, then we should not create a task for it. if (originalFn->hasSemanticsAttr("autodiff.opaque")) { @@ -2751,16 +2699,18 @@ emitDerivativeFunctionReference( contextualDerivativeGenSig = invoker.getIndirectDifferentiation().second ->getDerivativeGenericSignature(); auto *newAttr = context.getOrCreateDifferentiableAttr( - originalFn, desiredIndices, contextualDerivativeGenSig); + originalFn, + SILAutoDiffIndices(desiredIndices.source, desiredParameterIndices), + contextualDerivativeGenSig); if (context.processDifferentiableAttribute(originalFn, newAttr, invoker)) return None; - createDifferentiabilityWitness(context.getModule(), SILLinkage::Hidden, - newAttr); - minimalAttr = newAttr; + minimalWitness = createDifferentiabilityWitness( + context.getModule(), SILLinkage::Hidden, newAttr); } - assert(minimalAttr); + assert(minimalWitness); // TODO(TF-482): Move generic requirement checking logic to - // `lookUpMinimalDifferentiableAttr`. + // `getExactDifferentiabilityWitness` & + // `getOrCreateMinimalASTDifferentiabilityWitness`. // Get the substitution map for checking unmet generic requirements. // By default, use the forwarding substitution map of the original function. // If the original callee is a `partial_apply` or `apply` instruction, use @@ -2772,35 +2722,37 @@ emitDerivativeFunctionReference( substMap = ai->getSubstitutionMap(); } if (diagnoseUnsatisfiedRequirements( - context, minimalAttr->getDerivativeGenericSignature(), originalFn, - substMap, invoker, original.getLoc().getSourceLoc())) + context, minimalWitness->getDerivativeGenericSignature(), + originalFn, substMap, invoker, original.getLoc().getSourceLoc())) return None; - if (context.processDifferentiableAttribute( - originalFn, minimalAttr, invoker)) - return None; - SILFunction *derivativeFn = nullptr; + DifferentiabilityWitnessFunctionKind witnessKind; switch (kind) { case AutoDiffDerivativeFunctionKind::JVP: - assert(!minimalAttr->getJVPName().empty() && "Expected JVP name"); - derivativeFn = context.getModule().lookUpFunction(minimalAttr->getJVPName()); + witnessKind = DifferentiabilityWitnessFunctionKind::JVP; break; case AutoDiffDerivativeFunctionKind::VJP: - assert(!minimalAttr->getVJPName().empty() && "Expected VJP name"); - derivativeFn = context.getModule().lookUpFunction(minimalAttr->getVJPName()); + witnessKind = DifferentiabilityWitnessFunctionKind::VJP; break; } - auto *derivativeFnRef = builder.createFunctionRef(loc, derivativeFn); + auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction( + loc, witnessKind, minimalWitness); // FIXME(TF-201): Handle direct differentiation of reabstraction thunks. // Tentative solution: clone a new reabstraction thunk where function // argument has a `@differentiable` function type. if (originalFn->isThunk() == IsReabstractionThunk) { // Handle here. } - auto convertedRef = reapplyFunctionConversion( - derivativeFnRef, originalFRI, original, builder, loc, - newBuffersToDealloc, - derivativeFn->getLoweredFunctionType()->getSubstGenericSignature()); - return std::make_pair(convertedRef, minimalAttr->getIndices()); + auto convertedRef = + reapplyFunctionConversion(derivativeFnRef, originalFRI, original, + builder, loc, newBuffersToDealloc, + derivativeFnRef->getType() + .getASTType() + ->castTo() + ->getSubstGenericSignature()); + return std::make_pair( + convertedRef, + SILAutoDiffIndices(desiredIndices.source, + minimalWitness->getParameterIndices())); } // Find witness method retrieval. @@ -2808,8 +2760,7 @@ emitDerivativeFunctionReference( peerThroughFunctionConversions(original)) { auto loc = witnessMethod->getLoc(); auto requirementDeclRef = witnessMethod->getMember(); - auto *requirementDecl = requirementDeclRef.getDecl(); - auto witnessMethodType = witnessMethod->getType().castTo(); + auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl(); // If requirement declaration does not have any `@differentiable` // attributes, produce an error. if (!requirementDecl->getAttrs().hasAttribute()) { @@ -2818,11 +2769,9 @@ emitDerivativeFunctionReference( return None; } // Get the minimal `@differentiable` attribute and parameter index subset. - const DifferentiableAttr *minimalAttr; - IndexSubset *minimalParamIndexSet; - std::tie(minimalAttr, minimalParamIndexSet) = - context.lookUpMinimalASTDifferentiableAttrAndIndexSubset( - requirementDeclRef, witnessMethodType, desiredIndices); + IndexSubset *minimalParamIndexSet = nullptr; + const auto *minimalAttr = getMinimalASTDifferentiableAttr( + requirementDecl, desiredIndices.parameters, minimalParamIndexSet); SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); // If minimal `@differentiable` attribute does not exist, then no attribute // exists with a superset of the desired indices. Produce an error. @@ -2855,8 +2804,7 @@ emitDerivativeFunctionReference( peerThroughFunctionConversions(original)) { auto loc = classMethodInst->getLoc(); auto methodDeclRef = classMethodInst->getMember(); - auto *methodDecl = methodDeclRef.getDecl(); - auto classMethodType = classMethodInst->getType().castTo(); + auto *methodDecl = methodDeclRef.getAbstractFunctionDecl(); // If method declaration does not have any `@differentiable` attributes, // produce an error. if (!methodDecl->getAttrs().hasAttribute()) { @@ -2865,11 +2813,9 @@ emitDerivativeFunctionReference( return None; } // Get the minimal `@differentiable` attribute and parameter index subset. - const DifferentiableAttr *minimalAttr; - IndexSubset *minimalParamIndexSet; - std::tie(minimalAttr, minimalParamIndexSet) = - context.lookUpMinimalASTDifferentiableAttrAndIndexSubset( - methodDeclRef, classMethodType, desiredIndices); + IndexSubset *minimalParamIndexSet = nullptr; + const auto *minimalAttr = getMinimalASTDifferentiableAttr( + methodDecl, desiredIndices.parameters, minimalParamIndexSet); SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); // If minimal `@differentiable` attribute does not exist, then no attribute // exists with a superset of the desired indices. Produce an error. @@ -8618,6 +8564,10 @@ ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction( assocRef = builder.createClassMethod( loc, classOperand, assocMethodInst->getMember(), thunk->mapTypeIntoContext(assocMethodInst->getType())); + } else if (auto *diffWitFn = peerThroughFunctionConversions< + DifferentiabilityWitnessFunctionInst>(derivativeFn)) { + assocRef = builder.createDifferentiabilityWitnessFunction( + loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness()); } assert(assocRef && "Expected derivative function to be resolved"); diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp new file mode 100644 index 0000000000000..221be8e1c388a --- /dev/null +++ b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp @@ -0,0 +1,122 @@ +//===--- DerivativeLookup.cpp ---------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// SWIFT_ENABLE_TENSORFLOW +// +// Utilities for looking up derivatives of functions. +// +//===----------------------------------------------------------------------===// + +#include "DerivativeLookup.h" + +namespace swift { + +/// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one, +/// returns `nullptr`. +static AbstractFunctionDecl *getAFDOrNull(SILFunction *F) { + auto *DC = F->getDeclContext(); + if (!DC) + return nullptr; + + auto *D = DC->getAsDecl(); + if (!D) + return nullptr; + + return dyn_cast(D); +} + +SILDifferentiabilityWitness * +getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, + IndexSubset *parameterIndices, + IndexSubset *resultIndices) { + for (auto *w : module.lookUpDifferentiabilityWitnessesForFunction( + original->getName())) { + if (w->getParameterIndices() == parameterIndices && + w->getResultIndices() == resultIndices) + return w; + } + return nullptr; +} + +const DifferentiableAttr * +getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original, + IndexSubset *parameterIndices, + IndexSubset *&minimalParameterIndices) { + const DifferentiableAttr *minimalAttr = nullptr; + minimalParameterIndices = nullptr; + for (auto *attr : original->getAttrs().getAttributes()) { + auto *attrParameterIndices = autodiff::getLoweredParameterIndices( + attr->getParameterIndices(), + original->getInterfaceType()->castTo()); + // If all indices in `parameterIndices` are in `daParameterIndices`, and it + // has fewer indices than our current candidate and a primitive VJP, then + // `attr` is our new candidate. + // + // NOTE(TF-642): `attr` may come from a un-partial-applied function and + // have larger capacity than the desired indices. We expect this logic to + // go away when `partial_apply` supports `@differentiable` callees. + if (attrParameterIndices->isSupersetOf(parameterIndices->extendingCapacity( + original->getASTContext(), attrParameterIndices->getCapacity())) && + // fewer parameters than before + (!minimalParameterIndices || + attrParameterIndices->getNumIndices() < + minimalParameterIndices->getNumIndices())) { + minimalAttr = attr; + minimalParameterIndices = attrParameterIndices; + } + } + return minimalAttr; +} + +SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( + SILModule &module, SILFunction *original, IndexSubset *parameterIndices, + IndexSubset *resultIndices) { + // AST differentiability witnesses always have a single result. + if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0)) + return nullptr; + + // Explicit differentiability witnesses only exist on SILFunctions that come + // from AST functions. + auto *originalAFD = getAFDOrNull(original); + if (!originalAFD) + return nullptr; + + IndexSubset *minimalParameterIndices = nullptr; + const auto *minimalAttr = getMinimalASTDifferentiableAttr( + originalAFD, parameterIndices, minimalParameterIndices); + + // TODO(TF-835): This will also need to search all `@differentiating` + // attributes after we stop synthesizing `@differentiable` attributes for + // `@differentiating` attributes. + + if (!minimalAttr) + return nullptr; + + AutoDiffConfig minimalConfig(minimalParameterIndices, resultIndices, + minimalAttr->getDerivativeGenericSignature()); + + auto *existingWitness = module.lookUpDifferentiabilityWitness( + {original->getName(), minimalConfig}); + if (existingWitness) + return existingWitness; + + assert(original->isExternalDeclaration() && + "SILGen should create differentiability witnesses for all function " + "definitions with explicit differentiable attributes"); + + return SILDifferentiabilityWitness::createDeclaration( + module, SILLinkage::PublicExternal, original, + minimalConfig.parameterIndices, minimalConfig.resultIndices, + minimalConfig.derivativeGenericSignature); +} + +} // end namespace swift diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h new file mode 100644 index 0000000000000..53fa804d21c08 --- /dev/null +++ b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h @@ -0,0 +1,72 @@ +//===--- DerivativeLookup.h -----------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// SWIFT_ENABLE_TENSORFLOW +// +// Utilities for looking up derivatives of functions. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H +#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H + +#include "swift/AST/AutoDiff.h" +#include "swift/SIL/SILDeclRef.h" +#include "swift/SIL/SILModule.h" + +namespace swift { + +/// Returns a differentiability witness (definition or declaration) exactly +/// matching the specified indices. If none are found in the given `module`, +/// returns `nullptr`. +/// +/// \param parameterIndices must be lowered to SIL. +/// \param resultIndices must be lowered to SIL. +SILDifferentiabilityWitness * +getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, + IndexSubset *parameterIndices, + IndexSubset *resultIndices); + +/// Finds the "@differentiable" attribute on `original` whose parameter indices +/// are a minimal superset of the specified parameter indices. Returns `nullptr` +/// if no such attribute exists. +/// +/// \param parameterIndices must be lowered to SIL. +/// \param minimalParameterIndices is an output parameter that is set to the SIL +/// indices +/// of the minimal attribute, or to `nullptr` if no attribute exists. +const DifferentiableAttr * +getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original, + IndexSubset *parameterIndices, + IndexSubset *&minimalParameterIndices); + +/// Returns a differentiability witness for `original` whose parameter indices +/// are a minimal superset of the specified parameter indices and whose result +/// indices match the given result indices, out of all +/// differentiability witnesses that come from AST "@differentiable" or +/// "@differentiating" attributes. +/// +/// This function never creates new differentiability witness definitions. +/// However, this function may create new differentiability witness declarations +/// referring to definitions in other modules when these witnesses have not yet +/// been declared in the current module. +/// +/// \param module is the SILModule in which to get or create the witnesses. +/// \param parameterIndices must be lowered to SIL. +/// \param resultIndices must be lowered to SIL. +SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( + SILModule &module, SILFunction *original, IndexSubset *parameterIndices, + IndexSubset *resultIndices); + +} // end namespace swift + +#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H diff --git a/test/AutoDiff/closures.swift b/test/AutoDiff/closures.swift index ca8478707e06e..6650a0af305d7 100644 --- a/test/AutoDiff/closures.swift +++ b/test/AutoDiff/closures.swift @@ -26,11 +26,12 @@ struct InoutAliasableCapture { // CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () { // CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture): -// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: [[JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0] [results 0] @{{.*}}capturesMutableSelf{{.*}} : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> Float // CHECK-NOT: retain_value_addr [[SELF]] // CHECK-NOT: copy_addr [[SELF]] // CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: [[VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0] [results 0] @{{.*}}capturesMutableSelf{{.*}} : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> Float +// CHECK-NOT: retain_value_addr [[SELF]] // CHECK-NOT: retain_value_addr [[SELF]] // CHECK-NOT: copy_addr [[SELF]] // CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float) diff --git a/test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift b/test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift similarity index 99% rename from test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift rename to test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift index f146c02c5f7e9..d5ed545caaff0 100644 --- a/test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift +++ b/test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift @@ -1,4 +1,4 @@ -// RUN: not --crash %target-swift-frontend -O -emit-ir %s +// RUN: %target-swift-frontend -O -emit-ir %s // REQUIRES: asserts public protocol Protocol: Differentiable { diff --git a/test/AutoDiff/differentiable_sil_attr_roundtrip.swift b/test/AutoDiff/differentiable_sil_attr_roundtrip.swift index 38a8986fd315f..106d9113fc0dd 100644 --- a/test/AutoDiff/differentiable_sil_attr_roundtrip.swift +++ b/test/AutoDiff/differentiable_sil_attr_roundtrip.swift @@ -9,8 +9,16 @@ // Assertion failed: (newCapacity >= capacity), function extendingCapacity // ... ADContext::promoteToDifferentiableFunction +// NOTE: We cannot differentiate external functions in roundtrip SIL tests. +// Reason: When we print then parse the SIL we lose the information that the +// external function is associated with an AST decl. So the differentiation +// pass can't see the AST differentiable attrs, and the differentiation pass +// thinks that we're trying to differentiate an external function without +// explicit AST differentiable attrs. +// TODO(TF-988): This can probably be fixed. + @differentiable(wrt: x) func TF_656(_ x: Float, _ y: Float) -> Float { - return x + y + return 0 } _ = gradient(at: 1, in: { x in TF_656(x, 2) }) diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift index 8a490a0c4ee90..e84d36b1ae80a 100644 --- a/test/AutoDiff/forward_mode_sil.swift +++ b/test/AutoDiff/forward_mode_sil.swift @@ -22,15 +22,15 @@ func unary(_ x: Float) -> Float { // CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[X_ARG:%.*]] : $Float): // CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) @@ -67,8 +67,8 @@ func binary(x: Float, y: Float) -> Float { // CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__jvp_src_0_wrt_0_1 : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) { // CHECK-SIL: bb0([[X_ARG:%.*]] : $Float, [[Y_ARG:%.*]] : $Float): // CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift index 8fdb2c704c3fb..73bd597b51bee 100644 --- a/test/AutoDiff/refcounting.swift +++ b/test/AutoDiff/refcounting.swift @@ -85,8 +85,8 @@ _ = pullback(at: Vector.zero, in: testOwnedVector) // // CHECK-LABEL: sil hidden @{{.*}}testOwnedVector{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (@guaranteed Vector) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> @owned Vector) // CHECK: [[ADD:%.*]] = function_ref @Vector_plus -// CHECK: [[ADD_JVP:%.*]] = function_ref @{{.*}}Vector_plus__jvp_src_0_wrt_0_1{{.*}} -// CHECK: [[ADD_VJP:%.*]] = function_ref @{{.*}}Vector_plus__vjp_src_0_wrt_0_1{{.*}} +// CHECK: [[ADD_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @Vector_plus +// CHECK: [[ADD_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @Vector_plus // CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [parameters 0 1] [[ADD]] {{.*}} with_derivative {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}} // CHECK: [[ADD_AD_FUNC_EXTRACT:%.*]] = differentiable_function_extract [vjp] [[ADD_AD_FUNC]] // CHECK: [[ADD_VJP_RESULT:%.*]] = apply [[ADD_AD_FUNC_EXTRACT]]({{.*}}, {{.*}}, {{.*}}) : $@convention(method) (@guaranteed Vector, @guaranteed Vector, @thin Vector.Type) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)) diff --git a/test/AutoDiff/subset_parameters_thunk.swift b/test/AutoDiff/subset_parameters_thunk.swift index 649ed7476454e..d17dbdb1dc231 100644 --- a/test/AutoDiff/subset_parameters_thunk.swift +++ b/test/AutoDiff/subset_parameters_thunk.swift @@ -19,11 +19,11 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float { // CHECK: bb0 // CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 -// CHECK: [[FOO_JVP:%.*]] = function_ref @AD__{{.*}}foo{{.*}}__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, @in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) +// CHECK: [[FOO_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @{{.*}}foo{{.*}} : $@convention(thin) (@in_guaranteed T, @in_guaranteed T) -> @out T // CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, @in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) // CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_jvp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK: [[FOO_VJP:%.*]] = function_ref @AD__{{.*}}foo{{.*}}__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector)) +// CHECK: [[FOO_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @{{.*}}foo{{.*}} : $@convention(thin) (@in_guaranteed T, @in_guaranteed T) -> @out T // CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector)) // CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) diff --git a/test/AutoDiff/witness_table_sil.swift b/test/AutoDiff/witness_table_sil.swift index 3caec4edfb47b..cefe097cc0e87 100644 --- a/test/AutoDiff/witness_table_sil.swift +++ b/test/AutoDiff/witness_table_sil.swift @@ -1,3 +1,4 @@ +// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-skip-folding-differentiable-function-extraction %s // RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-skip-folding-differentiable-function-extraction %s | %FileCheck %s protocol Proto : Differentiable { @@ -24,7 +25,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_jvp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double) -> Float) { // CHECK: [[JVP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float - // CHECK: [[JVP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 + // CHECK: [[JVP1_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @{{.*}}function1{{.*}} // CHECK: [[JVP1_ADFUNC:%.*]] = differentiable_function [parameters 0 1] [[JVP1_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP1_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP1:%.*]] = differentiable_function_extract [jvp] [[JVP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[JVP1]] @@ -32,7 +33,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_vjp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double)) { // CHECK: [[VJP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float - // CHECK: [[VJP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 + // CHECK: [[VJP1_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @{{.*}}function1{{.*}} // CHECK: [[VJP1_ADFUNC:%.*]] = differentiable_function [parameters 0 1] [[VJP1_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP1_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP1:%.*]] = differentiable_function_extract [vjp] [[VJP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[VJP1]] @@ -45,7 +46,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_jvp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double, @in_guaranteed S) -> Float) { // CHECK: [[JVP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float - // CHECK: [[JVP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 + // CHECK: [[JVP2_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @{{.*}}function2{{.*}} // CHECK: [[JVP2_ADFUNC:%.*]] = differentiable_function [parameters 0 1 2] [[JVP2_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP2_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP2:%.*]] = differentiable_function_extract [jvp] [[JVP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[JVP2]] @@ -53,7 +54,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_vjp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double, @out S)) { // CHECK: [[VJP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float - // CHECK: [[VJP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 + // CHECK: [[VJP2_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @{{.*}}function2{{.*}} // CHECK: [[VJP2_ADFUNC:%.*]] = differentiable_function [parameters 0 1 2] [[VJP2_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP2_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP2:%.*]] = differentiable_function_extract [vjp] [[VJP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[VJP2]] @@ -66,7 +67,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_jvp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[JVP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double - // CHECK: [[JVP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 + // CHECK: [[JVP3_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 1] [results 0] @{{.*}}function3{{.*}} // CHECK: [[JVP3_ADFUNC:%.*]] = differentiable_function [parameters 1] [[JVP3_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP3_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP3:%.*]] = differentiable_function_extract [jvp] [[JVP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[JVP3]] @@ -74,7 +75,7 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_vjp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[VJP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double - // CHECK: [[VJP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 + // CHECK: [[VJP3_VJP_FNREF:%.*]] = differentiability_witness_function [vjp] [parameters 1] [results 0] @{{.*}}function3{{.*}} // CHECK: [[VJP3_ADFUNC:%.*]] = differentiable_function [parameters 1] [[VJP3_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP3_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP3:%.*]] = differentiable_function_extract [vjp] [[VJP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[VJP3]] From 2745db9fdb7d78dab21a7ded0686a2cd3757922c Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Fri, 22 Nov 2019 21:35:30 -0800 Subject: [PATCH 02/10] bail out when the list does not exist so that it does not segfault --- lib/Serialization/DeserializeSIL.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 3b616360ee896..8857ee19d5b7d 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -356,16 +356,14 @@ SILDeserializer::getSILDifferentiabilityWitnessForReference( StringRef mangledKey) { // Check to see if we have a witness under this key already. auto *witness = SILMod.lookUpDifferentiabilityWitness(mangledKey); - if (witness) { + if (witness) return witness; - } - // Otherwise, look for a witness under this key in the module. + if (!DifferentiabilityWitnessList) + return nullptr; auto iter = DifferentiabilityWitnessList->find(mangledKey); - if (iter == DifferentiabilityWitnessList->end()) { + if (iter == DifferentiabilityWitnessList->end()) return nullptr; - } - return readDifferentiabilityWitness(*iter); } From 2d984fb98f9b166653fb78b9252059bf9159a0d0 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Sat, 23 Nov 2019 14:40:50 -0800 Subject: [PATCH 03/10] simple testcase for the failure --- ...ability_witness_reference_serialization.sil | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 test/AutoDiff/sil_differentiability_witness_reference_serialization.sil diff --git a/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil b/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil new file mode 100644 index 0000000000000..c9fae96d3188e --- /dev/null +++ b/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil @@ -0,0 +1,18 @@ +// RUN: mkdir -p %t +// RUN: %target-swift-frontend -emit-module -emit-module-path %t/test.swiftmodule -module-name test %s +// RUN: %target-sil-opt %t/test.swiftmodule + +sil_stage raw + +import Swift +import Builtin + +sil_differentiability_witness [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + +sil @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + +sil [serialized] @test_serialized : $@convention(thin) () -> () { +bb0: + %referenced_from_serialized_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + return undef : $() +} From 12f9ff205a6da5e45123cfdcbb1e5b54db462870 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 24 Nov 2019 12:42:24 -0800 Subject: [PATCH 04/10] Minor edit. --- lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h index 53fa804d21c08..0b8ee7a07ce7f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h +++ b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h @@ -42,8 +42,7 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, /// /// \param parameterIndices must be lowered to SIL. /// \param minimalParameterIndices is an output parameter that is set to the SIL -/// indices -/// of the minimal attribute, or to `nullptr` if no attribute exists. +/// indices of the minimal attribute, or to `nullptr` if no attribute exists. const DifferentiableAttr * getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original, IndexSubset *parameterIndices, From bb932588856e23b74417368148bc84b0891b3252 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 24 Nov 2019 12:43:13 -0800 Subject: [PATCH 05/10] Minor edits. --- .../Mandatory/Differentiation/DerivativeLookup.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp index 221be8e1c388a..0be76eb6322e9 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp @@ -22,15 +22,13 @@ namespace swift { /// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one, /// returns `nullptr`. -static AbstractFunctionDecl *getAFDOrNull(SILFunction *F) { +static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) { auto *DC = F->getDeclContext(); if (!DC) return nullptr; - auto *D = DC->getAsDecl(); if (!D) return nullptr; - return dyn_cast(D); } @@ -86,7 +84,7 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( // Explicit differentiability witnesses only exist on SILFunctions that come // from AST functions. - auto *originalAFD = getAFDOrNull(original); + auto *originalAFD = findAbstractFunctionDecl(original); if (!originalAFD) return nullptr; From 655b1919ac0e9ed250e6568481f7cb926c43e438 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 24 Nov 2019 13:09:28 -0800 Subject: [PATCH 06/10] Move DerivativeLookup.h to include/swift/SILOptimizer/Utils. TF-993 tracks further organization of differentiation code. --- .../swift/SILOptimizer/Utils}/DerivativeLookup.h | 6 +++--- lib/SILOptimizer/Mandatory/CMakeLists.txt | 2 +- lib/SILOptimizer/Mandatory/Differentiation.cpp | 4 +++- lib/SILOptimizer/Utils/CMakeLists.txt | 3 +++ .../Differentiation => Utils}/DerivativeLookup.cpp | 2 +- 5 files changed, 11 insertions(+), 6 deletions(-) rename {lib/SILOptimizer/Mandatory/Differentiation => include/swift/SILOptimizer/Utils}/DerivativeLookup.h (93%) rename lib/SILOptimizer/{Mandatory/Differentiation => Utils}/DerivativeLookup.cpp (98%) diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h b/include/swift/SILOptimizer/Utils/DerivativeLookup.h similarity index 93% rename from lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h rename to include/swift/SILOptimizer/Utils/DerivativeLookup.h index 0b8ee7a07ce7f..3ee352d0edd0b 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.h +++ b/include/swift/SILOptimizer/Utils/DerivativeLookup.h @@ -16,8 +16,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H -#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H +#ifndef SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H +#define SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H #include "swift/AST/AutoDiff.h" #include "swift/SIL/SILDeclRef.h" @@ -68,4 +68,4 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( } // end namespace swift -#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_DERIVATIVELOOKUP_H +#endif // SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H diff --git a/lib/SILOptimizer/Mandatory/CMakeLists.txt b/lib/SILOptimizer/Mandatory/CMakeLists.txt index 31b1375426835..21efad85d5fb5 100644 --- a/lib/SILOptimizer/Mandatory/CMakeLists.txt +++ b/lib/SILOptimizer/Mandatory/CMakeLists.txt @@ -7,7 +7,7 @@ silopt_register_sources( DefiniteInitialization.cpp # SWIFT_ENABLE_TENSORFLOW Differentiation.cpp - Differentiation/DerivativeLookup.cpp + # SWIFT_ENABLE_TENSORFLOW END DIMemoryUseCollector.cpp DataflowDiagnostics.cpp DiagnoseInfiniteRecursion.cpp diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index db011ebf39e7a..62ceb88127ec8 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -17,12 +17,13 @@ // NOTE: Although the AD feature is developed as part of the Swift for // TensorFlow project, it is completely independent from TensorFlow support. // +// TODO(TF-993): Organize Differentiation.cpp into smaller files. +// //===----------------------------------------------------------------------===// #define DEBUG_TYPE "differentiation" #include "Differentiation.h" -#include "Differentiation/DerivativeLookup.h" #include "swift/AST/ASTMangler.h" #include "swift/AST/ASTPrinter.h" @@ -49,6 +50,7 @@ #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/Passes.h" #include "swift/SILOptimizer/PassManager/Transforms.h" +#include "swift/SILOptimizer/Utils/DerivativeLookup.h" #include "swift/SILOptimizer/Utils/LoopUtils.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/ADT/APSInt.h" diff --git a/lib/SILOptimizer/Utils/CMakeLists.txt b/lib/SILOptimizer/Utils/CMakeLists.txt index 36f43361f0aa0..d0473d57f609b 100644 --- a/lib/SILOptimizer/Utils/CMakeLists.txt +++ b/lib/SILOptimizer/Utils/CMakeLists.txt @@ -6,6 +6,9 @@ silopt_register_sources( CheckedCastBrJumpThreading.cpp ConstantFolding.cpp ConstExpr.cpp + # SWIFT_ENABLE_TENSORFLOW + DerivativeLookup.cpp + # SWIFT_ENABLE_TENSORFLOW END Devirtualize.cpp Existential.cpp GenericCloner.cpp diff --git a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp b/lib/SILOptimizer/Utils/DerivativeLookup.cpp similarity index 98% rename from lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp rename to lib/SILOptimizer/Utils/DerivativeLookup.cpp index 0be76eb6322e9..4cd2d0ba5173f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation/DerivativeLookup.cpp +++ b/lib/SILOptimizer/Utils/DerivativeLookup.cpp @@ -16,7 +16,7 @@ // //===----------------------------------------------------------------------===// -#include "DerivativeLookup.h" +#include "swift/SILOptimizer/Utils/DerivativeLookup.h" namespace swift { From cb8a51841590e4c46c8f4a6716f386cc9fb02b91 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Tue, 26 Nov 2019 12:52:06 -0800 Subject: [PATCH 07/10] the optimization made the crasher start crashing again --- .../tf891-protocol-req-capture-propagation.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/AutoDiff/{compiler_crashers_fixed => compiler_crashers}/tf891-protocol-req-capture-propagation.swift (99%) diff --git a/test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift b/test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift similarity index 99% rename from test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift rename to test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift index d5ed545caaff0..f146c02c5f7e9 100644 --- a/test/AutoDiff/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift +++ b/test/AutoDiff/compiler_crashers/tf891-protocol-req-capture-propagation.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-frontend -O -emit-ir %s +// RUN: not --crash %target-swift-frontend -O -emit-ir %s // REQUIRES: asserts public protocol Protocol: Differentiable { From b5beecd95e7b8eee25a896c163949bb576615a73 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Tue, 26 Nov 2019 18:57:29 -0800 Subject: [PATCH 08/10] fix duplicate linker symbol --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index a3b8a55ed53b0..c3ffcb8b601e9 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -393,10 +393,12 @@ createDifferentiabilityWitness(SILModule &module, SILLinkage linkage, module.getASTContext(), attr->getOriginal()->getLoweredFunctionType()->getNumResults(), {attr->getIndices().source}); + bool isSerialized = attr->getOriginal()->isSerialized(); auto *witness = SILDifferentiabilityWitness::createDefinition( - module, linkage, attr->getOriginal(), attr->getIndices().parameters, - resultIndices, attr->getDerivativeGenericSignature(), /*jvp*/ nullptr, - /*vjp*/ nullptr, /*isSerialized*/ false); + module, isSerialized ? SILLinkage::Shared : SILLinkage::Hidden, + attr->getOriginal(), attr->getIndices().parameters, resultIndices, + attr->getDerivativeGenericSignature(), /*jvp*/ nullptr, + /*vjp*/ nullptr, isSerialized); canonicalizeDifferentiabilityWitness(module, attr, witness); return witness; } From 5f93d6075378224ca45c2ab123a1aa6ab8baf270 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Wed, 27 Nov 2019 16:55:56 -0800 Subject: [PATCH 09/10] fix LoadableByAddress --- include/swift/SIL/SILInstruction.h | 3 ++ lib/IRGen/LoadableByAddress.cpp | 24 +++++------ lib/ParseSIL/ParseSIL.cpp | 11 ++++- lib/SIL/SILInstructions.cpp | 3 +- lib/SIL/SILPrinter.cpp | 4 ++ lib/Serialization/DeserializeSIL.cpp | 5 ++- lib/Serialization/ModuleFormat.h | 2 +- lib/Serialization/SerializeSIL.cpp | 6 ++- .../loadable_by_address_cross_module.swift | 16 +++++++ ...ifferentiability_witness_function_inst.sil | 9 ++++ .../loadable_by_address_cross_module.swift | 43 +++++++++++++++++++ 11 files changed, 106 insertions(+), 20 deletions(-) create mode 100644 test/AutoDiff/Inputs/loadable_by_address_cross_module.swift create mode 100644 test/AutoDiff/loadable_by_address_cross_module.swift diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index fed8da5f7c786..3b9fc77dacab0 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -8111,6 +8111,8 @@ class DifferentiabilityWitnessFunctionInst DifferentiabilityWitnessFunctionKind witnessKind; /// The referenced SIL differentiability witness. SILDifferentiabilityWitness *witness; + /// Whether the instruction has an explicit function type. + bool hasExplicitFunctionType; static SILType getDifferentiabilityWitnessType( SILModule &module, @@ -8128,6 +8130,7 @@ class DifferentiabilityWitnessFunctionInst return witnessKind; } SILDifferentiabilityWitness *getWitness() const { return witness; } + bool getHasExplicitFunctionType() const { return hasExplicitFunctionType; } ArrayRef getAllOperands() const { return {}; } MutableArrayRef getAllOperands() { return {}; } diff --git a/lib/IRGen/LoadableByAddress.cpp b/lib/IRGen/LoadableByAddress.cpp index 53adb45949616..23ecfcd9f2ff2 100644 --- a/lib/IRGen/LoadableByAddress.cpp +++ b/lib/IRGen/LoadableByAddress.cpp @@ -2683,24 +2683,22 @@ bool LoadableByAddress::recreateDifferentiabilityWitnessFunction( if (!instr) return false; - // If the witness is a declaration, then LoadableByAddress cannot have changed - // the type because the function is in a different module. - if (instr->getWitness()->isDeclaration()) - return true; - - // Otherwise, update the instruction if the function type changed. - auto resultTy = instr->getType(); - auto *referencedFn = instr->getWitness()->getDerivative( - *instr->getWitnessKind().getAsDerivativeFunctionKind()); - assert(referencedFn && "diff witness should be canonicalized"); - auto newResultTy = referencedFn->getLoweredType(); - if (resultTy == newResultTy) + // Check if we need to recreate the instruction. + auto *currIRMod = getIRGenModule()->IRGen.getGenModule(instr->getFunction()); + auto resultFnTy = instr->getType().castTo(); + auto genSig = resultFnTy->getSubstGenericSignature(); + GenericEnvironment *genEnv = nullptr; + if (genSig) + genEnv = genSig->getGenericEnvironment(); + auto newResultFnTy = + MapperCache.getNewSILFunctionType(genEnv, resultFnTy, *currIRMod); + if (resultFnTy == newResultFnTy) return true; SILBuilderWithScope builder(instr); auto *newInstr = builder.createDifferentiabilityWitnessFunction( instr->getLoc(), instr->getWitnessKind(), instr->getWitness(), - newResultTy); + SILType::getPrimitiveObjectType(newResultFnTy)); instr->replaceAllUsesWith(newInstr); Delete.push_back(instr); return true; diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 5a7759d3b6f2b..64c6fcb31e605 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -3217,8 +3217,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { P.diagnose(keyStartLoc, diag::sil_diff_witness_undefined); return true; } - ResultVal = B.createDifferentiabilityWitnessFunction( - InstLoc, witnessKind, witness); + // Parse an optional explicit function type. + Optional functionType = None; + if (P.consumeIf(tok::kw_as)) { + functionType = SILType(); + if (parseSILType(*functionType)) + return true; + } + ResultVal = B.createDifferentiabilityWitnessFunction(InstLoc, witnessKind, + witness, functionType); break; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index f6207dfe47b48..2fa1237462f4e 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -786,7 +786,8 @@ DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst( ? *FunctionType : getDifferentiabilityWitnessType( module, witnessKind, witness)), - witnessKind(witnessKind), witness(witness) {} + witnessKind(witnessKind), witness(witness), + hasExplicitFunctionType(FunctionType) {} // SWIFT_ENABLE_TENSORFLOW END FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind, diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index c7d6af4c7cfb6..fe33c0046c58a 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1299,6 +1299,10 @@ class SILPrinter : public SILInstructionVisitor { *this << " "; } printSILFunctionNameAndType(PrintState.OS, witness->getOriginalFunction()); + if (dwfi->getHasExplicitFunctionType()) { + *this << " as "; + *this << dwfi->getType(); + } } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 4e8f35ceb5f6d..31cdadccd1414 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1644,8 +1644,11 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto *witness = getSILDifferentiabilityWitnessForReference(mangledKey); assert(witness && "SILDifferentiabilityWitness not found"); DifferentiabilityWitnessFunctionKind witnessKind(Attr); + Optional explicitFnTy = None; + if (TyID) + explicitFnTy = getSILType(MF->getType(TyID), SILValueCategory::Object); ResultVal = Builder.createDifferentiabilityWitnessFunction( - Loc, witnessKind, witness); + Loc, witnessKind, witness, explicitFnTy); break; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 87821d962d789..99e972db359f2 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type +const uint16_t SWIFTMODULE_VERSION_MINOR = 525; // differentiability_witness_function explicit extractee type /// A standard hash seed used for all string hashes in a serialized module. /// diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 4d055cc7d6034..03e75da34c5c0 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -1072,11 +1072,13 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey( witness->getKey()); auto rawWitnessKind = (unsigned)dwfi->getWitnessKind(); + // We only store the type when the instruciton has an explicit type. + bool hasExplicitFnTy = dwfi->getHasExplicitFunctionType(); SILOneOperandLayout::emitRecord( Out, ScratchRecord, SILAbbrCodes[SILOneOperandLayout::Code], (unsigned)dwfi->getKind(), rawWitnessKind, - S.addTypeRef(dwfi->getType().getASTType()), - (unsigned)dwfi->getType().getCategory(), + hasExplicitFnTy ? S.addTypeRef(dwfi->getType().getASTType()) : TypeID(), + hasExplicitFnTy ? (unsigned)dwfi->getType().getCategory() : 0, S.addUniquedStringRef(mangledKey)); break; } diff --git a/test/AutoDiff/Inputs/loadable_by_address_cross_module.swift b/test/AutoDiff/Inputs/loadable_by_address_cross_module.swift new file mode 100644 index 0000000000000..2618acb00da01 --- /dev/null +++ b/test/AutoDiff/Inputs/loadable_by_address_cross_module.swift @@ -0,0 +1,16 @@ +public struct LargeLoadableType: AdditiveArithmetic, Differentiable { + public var a, b, c, d, e: Float + + public init(a: Float) { + self.a = a + self.b = 0 + self.c = 0 + self.d = 0 + self.e = 0 + } + + @differentiable + public func externalLBAModifiedFunction(_ x: Float) -> Float { + return a * x + } +} diff --git a/test/AutoDiff/differentiability_witness_function_inst.sil b/test/AutoDiff/differentiability_witness_function_inst.sil index 2d28b4c2f2788..6668f1399ea01 100644 --- a/test/AutoDiff/differentiability_witness_function_inst.sil +++ b/test/AutoDiff/differentiability_witness_function_inst.sil @@ -57,6 +57,10 @@ bb0: // Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`. %generic_vjp_wrt_0_1_dependent_req = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + + // Test explicit function types. + %explicit_fnty = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float) + return undef : $() } @@ -68,6 +72,8 @@ bb0: // CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) // CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 // CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float) // CHECK: } // IRGEN: @AD__foo_PSUURS = external global %swift.differentiability_witness, align 8 @@ -100,3 +106,6 @@ bb0: // IRGEN: [[PTR7:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRSs14DifferentiableRz13TangentVectorsAAPQzRszl, i32 0, i32 1), align 8 // IRGEN: [[FNPTR7:%.*]] = bitcast i8* [[PTR7]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)* + +// IRGEN: [[PTR8:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSUURS, i32 0, i32 0), align 8 +// IRGEN: [[FNPTR8:%.*]] = bitcast i8* [[PTR8]] to { float, i8*, %swift.refcounted* } (float, float, float)* diff --git a/test/AutoDiff/loadable_by_address_cross_module.swift b/test/AutoDiff/loadable_by_address_cross_module.swift new file mode 100644 index 0000000000000..7941309b598bb --- /dev/null +++ b/test/AutoDiff/loadable_by_address_cross_module.swift @@ -0,0 +1,43 @@ +// First, check that LBA actually modifies the function, so that this test is useful. + +// RUN: %target-swift-frontend -emit-sil %S/Inputs/loadable_by_address_cross_module.swift | %FileCheck %s -check-prefix=CHECK-MODULE-PRE-LBA +// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %S/Inputs/loadable_by_address_cross_module.swift 2>&1 | %FileCheck %s -check-prefix=CHECK-MODULE-POST-LBA + +// CHECK-MODULE-PRE-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) (Float, LargeLoadableType) -> Float +// CHECK-MODULE-POST-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) (Float, @in_constant LargeLoadableType) -> Float + +// Compile the module. + +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -c -parse-as-library -emit-module -module-name external -emit-module-path %t/external.swiftmodule -o %t/external.o %S/Inputs/loadable_by_address_cross_module.swift + +// Next, check that differentiability_witness_functions in the client get +// correctly modified by LBA. + +// RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA +// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA + +// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float +// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float + +// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float) +// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float) -> (Float, LargeLoadableType<τ_0_0>)) + +// Finally, execute the test. + +// RUN: %target-build-swift -I%t %s %t/external.o -o %t/a.out -lm +// RUN: %target-run %t/a.out + +// REQUIRES: executable_test + +import external +import StdlibUnittest + +var Tests = TestSuite("LoadableByAddressCrossModule") + +Tests.test("Correctness") { + let g = gradient(at: LargeLoadableType(a: 5), 10) { $0.externalLBAModifiedFunction($1) } + expectEqual((LargeLoadableType(a: 10), 5), g) +} + +runAllTests() From 44f2ac0561f1febd66d3804d560e3c72bb43d6ec Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 Nov 2019 05:50:42 -0500 Subject: [PATCH 10/10] Fix minor typo. --- lib/Serialization/SerializeSIL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 03e75da34c5c0..c5fa470dbc3af 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -1072,7 +1072,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey( witness->getKey()); auto rawWitnessKind = (unsigned)dwfi->getWitnessKind(); - // We only store the type when the instruciton has an explicit type. + // We only store the type when the instruction has an explicit type. bool hasExplicitFnTy = dwfi->getHasExplicitFunctionType(); SILOneOperandLayout::emitRecord( Out, ScratchRecord, SILAbbrCodes[SILOneOperandLayout::Code],