3333#include " swift/AST/Module.h"
3434#include " swift/AST/ParameterList.h"
3535#include " swift/AST/SubstitutionMap.h"
36- #include " swift/Serialization/SerializedSILLoader.h"
3736#include " swift/SIL/FormalLinkage.h"
3837#include " swift/SIL/LoopInfo.h"
3938#include " swift/SIL/SILBuilder.h"
@@ -856,10 +855,14 @@ class ADContext {
856855 void clearTask (DifferentiationTask *task) {
857856 LLVM_DEBUG (getADDebugStream () << " Clearing differentiation task for "
858857 << task->original ->getName () << ' \n ' );
859- transform.notifyWillDeleteFunction (task->primal );
860- module .eraseFunction (task->primal );
861- transform.notifyWillDeleteFunction (task->adjoint );
862- module .eraseFunction (task->adjoint );
858+ if (task->primal ) {
859+ transform.notifyWillDeleteFunction (task->primal );
860+ module .eraseFunction (task->primal );
861+ }
862+ if (task->adjoint ) {
863+ transform.notifyWillDeleteFunction (task->adjoint );
864+ module .eraseFunction (task->adjoint );
865+ }
863866 transform.notifyWillDeleteFunction (task->jvp );
864867 module .eraseFunction (task->jvp );
865868 transform.notifyWillDeleteFunction (task->vjp );
@@ -980,6 +983,14 @@ class ADContext {
980983 return differentiationTasks.back ().get ();
981984 }
982985
986+ // / Declare an external reference to an associated function of `original`,
987+ // / given a `[differentiable]` attribute of `original` and the associated
988+ // / function kind.
989+ SILFunction *
990+ declareExternalAssociatedFunction (SILFunction *original,
991+ SILDifferentiableAttr *attr,
992+ AutoDiffAssociatedFunctionKind kind);
993+
983994 template <typename ... T, typename ... U>
984995 InFlightDiagnostic diagnose (SourceLoc loc, Diag<T...> diag,
985996 U &&... args) const {
@@ -1006,11 +1017,7 @@ class ADContext {
10061017
10071018ADContext::ADContext (SILModuleTransform &transform)
10081019 : transform(transform), module(*transform.getModule()),
1009- passManager(*transform.getPassManager()) {
1010- // Note: `getSILLoader` performs important initialization and is necessary to
1011- // prevent test failures related to `lookUpFunctionInWitnessTable`.
1012- (void )module .getSILLoader ();
1013- }
1020+ passManager(*transform.getPassManager()) {}
10141021
10151022void ADContext::emitNondifferentiabilityError (SILValue value,
10161023 const DifferentiationTask *task,
@@ -2261,7 +2268,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22612268} // end anonymous namespace
22622269
22632270bool PrimalGen::performSynthesis (FunctionSynthesisItem item) {
2264- LLVM_DEBUG (getADDebugStream () << " Performing primal synthesis for original"
2271+ LLVM_DEBUG (getADDebugStream () << " Performing primal synthesis for original "
22652272 << item.original ->getName () << " and its corresponding primal "
22662273 << item.target ->getName () << ' \n ' );
22672274 // FIXME: If the original function has multiple basic blocks, bail out since
@@ -2314,8 +2321,8 @@ bool PrimalGen::run() {
23142321 auto synthesis = worklist.back ();
23152322 worklist.pop_back ();
23162323 if (performSynthesis (synthesis)) {
2317- context.clearTask (synthesis.task );
23182324 errorOccurred = true ;
2325+ continue ;
23192326 }
23202327 synthesis.task ->getPrimalInfo ()->computePrimalValueStructType ();
23212328 synthesis.task ->setPrimalSynthesisState (FunctionSynthesisState::Done);
@@ -2373,8 +2380,8 @@ bool AdjointGen::run() {
23732380 auto synthesis = worklist.back ();
23742381 worklist.pop_back ();
23752382 if (performSynthesis (synthesis)) {
2376- context.clearTask (synthesis.task );
23772383 errorOccurred = true ;
2384+ continue ;
23782385 }
23792386 synthesis.task ->setAdjointSynthesisState (FunctionSynthesisState::Done);
23802387 }
@@ -3301,8 +3308,6 @@ void AdjointEmitter::materializeZeroIndirect(CanType type,
33013308 // %wm = witness_method ...
33023309 auto *getter = builder.createWitnessMethod (loc, type, confRef,
33033310 accessorDeclRef, methodType);
3304- // Ensure that the witness table is linked.
3305- (void )getModule ().lookUpFunctionInWitnessTable (confRef, accessorDeclRef);
33063311 // %metatype = metatype $T
33073312 auto metatypeType = CanMetatypeType::get (type, MetatypeRepresentation::Thick);
33083313 auto metatype = builder.createMetatype (
@@ -3594,8 +3599,6 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
35943599 // %0 = witness_method @+
35953600 auto witnessMethod = builder.createWitnessMethod (loc, adjointASTTy,
35963601 confRef, declRef, silFnTy);
3597- // Ensure the witness method is linked.
3598- getModule ().lookUpFunctionInWitnessTable (confRef, declRef);
35993602 auto subMap =
36003603 SubstitutionMap::getProtocolSubstitutions (proto, adjointASTTy, confRef);
36013604 // %1 = metatype $T.Type
@@ -3623,7 +3626,7 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
36233626}
36243627
36253628bool AdjointGen::performSynthesis (FunctionSynthesisItem item) {
3626- LLVM_DEBUG (getADDebugStream () << " Performing adjoint synthesis for original"
3629+ LLVM_DEBUG (getADDebugStream () << " Performing adjoint synthesis for original "
36273630 << item.original ->getName () << " and its corresponding adjoint "
36283631 << item.target ->getName () << ' \n ' );
36293632 auto &passManager = context.getPassManager ();
@@ -3639,25 +3642,90 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
36393642// DifferentiationTask
36403643// ===----------------------------------------------------------------------===//
36413644
3645+ // Return the expected generic signature for autodiff associated functions given
3646+ // a SILDifferentiableAttr. The expected generic signature is built from the
3647+ // original generic signature and the attribute's requirements.
3648+ static CanGenericSignature
3649+ getAutoDiffAssociatedFunctionGenericSignature (SILDifferentiableAttr *attr,
3650+ SILFunction *original) {
3651+ auto originalGenSig =
3652+ original->getLoweredFunctionType ()->getGenericSignature ();
3653+ if (!originalGenSig)
3654+ return nullptr ;
3655+ GenericSignatureBuilder builder (original->getASTContext ());
3656+ // Add original generic signature.
3657+ builder.addGenericSignature (originalGenSig);
3658+ // Add where clause requirements.
3659+ auto source =
3660+ GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
3661+ for (auto &req : attr->getRequirements ())
3662+ builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
3663+ return std::move (builder)
3664+ .computeGenericSignature (SourceLoc (), /* allowConcreteGenericParams=*/ true )
3665+ ->getCanonicalSignature ();
3666+ }
3667+
3668+ SILFunction *
3669+ ADContext::declareExternalAssociatedFunction (
3670+ SILFunction *original, SILDifferentiableAttr *attr,
3671+ AutoDiffAssociatedFunctionKind kind) {
3672+ auto &module = getModule ();
3673+ auto &indices = attr->getIndices ();
3674+ auto originalTy = original->getLoweredFunctionType ();
3675+ auto originalLoc = original->getLocation ();
3676+ StringRef name;
3677+ switch (kind) {
3678+ case AutoDiffAssociatedFunctionKind::JVP:
3679+ name = attr->getJVPName ();
3680+ break ;
3681+ case AutoDiffAssociatedFunctionKind::VJP:
3682+ name = attr->getVJPName ();
3683+ break ;
3684+ }
3685+ auto assocGenSig =
3686+ getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3687+ auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType (
3688+ indices.parameters , indices.source , /* differentiationOrder*/ 1 , kind,
3689+ module , LookUpConformanceInModule (module .getSwiftModule ()), assocGenSig);
3690+ SILOptFunctionBuilder fb (getTransform ());
3691+ // Create external function declaration.
3692+ auto *assocFn =
3693+ fb.createFunction (SILLinkage::PublicExternal, name, assocFnTy,
3694+ /* GenericEnv*/ nullptr , originalLoc, original->isBare (),
3695+ IsNotTransparent, original->isSerialized ());
3696+ // NOTE: Setting debug scope is necessary to prevent crash in TFPartition.
3697+ assocFn->setDebugScope (new (module ) SILDebugScope (originalLoc, assocFn));
3698+ return assocFn;
3699+ }
3700+
36423701DifferentiationTask::DifferentiationTask (ADContext &context,
36433702 SILFunction *original,
36443703 SILDifferentiableAttr *&&attr,
36453704 DifferentiationInvoker invoker)
36463705 : context(context), original(original), attr(attr), invoker(invoker) {
3706+ auto &module = context.getModule ();
36473707 if (attr->hasJVP ()) {
3648- jvp = lookUpOrLinkFunction (attr->getJVPName (), context.getModule ());
3649- assert (jvp);
3708+ // If attribute specifies JVP name, try to look up JVP in current module.
3709+ // Otherwise, create an external reference.
3710+ jvp = module .lookUpFunction (attr->getJVPName ());
3711+ if (!jvp)
3712+ jvp = context.declareExternalAssociatedFunction (
3713+ original, attr, AutoDiffAssociatedFunctionKind::JVP);
36503714 }
36513715 if (attr->hasVJP ()) {
3652- vjp = lookUpOrLinkFunction (attr->getVJPName (), context.getModule ());
3653- assert (vjp);
3716+ // If attribute specifies VJP name, try to look up VJP in current module.
3717+ // Otherwise, create an external reference.
3718+ vjp = module .lookUpFunction (attr->getVJPName ());
3719+ if (!vjp)
3720+ vjp = context.declareExternalAssociatedFunction (
3721+ original, attr, AutoDiffAssociatedFunctionKind::VJP);
36543722 }
36553723
36563724 if (!jvp)
36573725 createJVP ();
36583726
36593727 if (vjp) {
3660- // If we already have the vjp , then we don't need to synthesize anything .
3728+ // If the VJP exists , then no synthesis is needed .
36613729 primalSynthesisState = FunctionSynthesisState::NotNeeded;
36623730 adjointSynthesisState = FunctionSynthesisState::NotNeeded;
36633731 return ;
@@ -3670,31 +3738,6 @@ DifferentiationTask::DifferentiationTask(ADContext &context,
36703738 createVJP ();
36713739}
36723740
3673- // Return the expected generic signature for autodiff associated functions given
3674- // a SILDifferentiableAttr. The expected generic signature is built from the
3675- // original generic signature and the attribute's requirements.
3676- static GenericSignature *
3677- getAutoDiffAssociatedFunctionGenericSignature (SILDifferentiableAttr *attr,
3678- SILFunction *original) {
3679- auto originalGenSig =
3680- original->getLoweredFunctionType ()->getGenericSignature ();
3681- if (!originalGenSig)
3682- return nullptr ;
3683- GenericSignatureBuilder builder (original->getASTContext ());
3684- // Add original generic signature.
3685- builder.addGenericSignature (originalGenSig);
3686- // Add where clause requirements.
3687- auto source =
3688- GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
3689- for (auto &req : attr->getRequirements ())
3690- builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
3691- auto canGenericSig = std::move (builder)
3692- .computeGenericSignature (
3693- SourceLoc (), /* allowConcreteGenericParams=*/ true )
3694- ->getCanonicalSignature ();
3695- return canGenericSig;
3696- }
3697-
36983741void DifferentiationTask::createEmptyPrimal () {
36993742 assert (primalSynthesisState == FunctionSynthesisState::Needed);
37003743 assert (!primalInfo);
@@ -3707,7 +3750,7 @@ void DifferentiationTask::createEmptyPrimal() {
37073750 .getIdentifier (" AD__" + original->getName ().str () +
37083751 " __primal_" + indices.mangle ())
37093752 .str ();
3710- auto * primalGenericSig =
3753+ auto primalGenericSig =
37113754 getAutoDiffAssociatedFunctionGenericSignature (attr, original);
37123755 StructDecl *primalValueStructDecl = context.createPrimalValueStruct (this );
37133756 primalInfo = std::unique_ptr<PrimalInfo>(
@@ -3846,7 +3889,7 @@ void DifferentiationTask::createEmptyAdjoint() {
38463889 .getIdentifier (" AD__" + original->getName ().str () +
38473890 " __adjoint_" + getIndices ().mangle ())
38483891 .str ();
3849- auto * adjGenericSig =
3892+ auto adjGenericSig =
38503893 getAutoDiffAssociatedFunctionGenericSignature (attr, original);
38513894 auto *adjGenericEnv = adjGenericSig
38523895 ? adjGenericSig->createGenericEnvironment ()
@@ -3887,7 +3930,7 @@ void DifferentiationTask::createJVP() {
38873930 .getIdentifier (" AD__" + original->getName ().str () +
38883931 " __jvp_" + getIndices ().mangle ())
38893932 .str ();
3890- auto * jvpGenericSig =
3933+ auto jvpGenericSig =
38913934 getAutoDiffAssociatedFunctionGenericSignature (attr, original);
38923935 auto *jvpGenericEnv = jvpGenericSig
38933936 ? jvpGenericSig->createGenericEnvironment ()
@@ -3946,7 +3989,7 @@ void DifferentiationTask::createVJP() {
39463989 .getIdentifier (" AD__" + original->getName ().str () +
39473990 " __vjp_" + getIndices ().mangle ())
39483991 .str ();
3949- auto * vjpGenericSig =
3992+ auto vjpGenericSig =
39503993 getAutoDiffAssociatedFunctionGenericSignature (attr, original);
39513994 auto *vjpGenericEnv = vjpGenericSig
39523995 ? vjpGenericSig->createGenericEnvironment ()
@@ -4299,22 +4342,33 @@ void Differentiation::run() {
42994342 for (auto *adfi : autodiffInsts)
43004343 errorProcessingAutoDiffInsts |= processAutoDiffFunctionInst (adfi, context);
43014344
4345+ auto cleanUp = [&]() {
4346+ for (auto &task : context.getDifferentiationTasks ())
4347+ context.clearTask (task.get ());
4348+ };
4349+
43024350 // Run primal generation for newly created differentiation tasks. If any error
43034351 // occurs, back out.
43044352 PrimalGen primalGen (context);
4305- if (primalGen.run ())
4353+ if (primalGen.run ()) {
4354+ cleanUp ();
43064355 return ;
4356+ }
43074357
43084358 // Run adjoint generation for differentiation tasks. If any error occurs, back
43094359 // out.
43104360 AdjointGen adjointGen (context);
4311- if (adjointGen.run ())
4361+ if (adjointGen.run ()) {
4362+ cleanUp ();
43124363 return ;
4364+ }
43134365
43144366 // If there was any error that occurred during `autodiff_function` instruction
43154367 // processing, back out.
4316- if (errorProcessingAutoDiffInsts)
4368+ if (errorProcessingAutoDiffInsts) {
4369+ cleanUp ();
43174370 return ;
4371+ }
43184372
43194373 LLVM_DEBUG (getADDebugStream () << " All differentiation finished\n " );
43204374}
0 commit comments