@@ -669,6 +669,24 @@ struct NestedApplyActivity {
669669 SILAutoDiffIndices indices;
670670};
671671
672+ // / Specifies how we should differentiate a `struct_extract` instruction.
673+ enum class StructExtractDifferentiationStrategy {
674+ // The `struct_extract` is not active, so do not differentiate it.
675+ Inactive,
676+
677+ // The `struct_extract` is extracting a field from a Differentiable struct
678+ // with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the
679+ // `struct_extract` by setting the adjoint to a vector in the cotangent space
680+ // that is zero except along the direction of the corresponding field.
681+ //
682+ // Fields correspond by matching name.
683+ FieldwiseProductSpace,
684+
685+ // Differentiate the `struct_extract` by looking up the corresponding getter
686+ // and using its VJP.
687+ Getter
688+ };
689+
672690// / A differentiation task, specifying the original function and the
673691// / `[differentiable]` attribute on the function. PrimalGen and AdjointGen
674692// / will synthesize the primal and the adjoint for this task, filling the primal
@@ -714,6 +732,10 @@ class DifferentiationTask {
714732 // / Note: This is only used when `DifferentiationUseVJP`.
715733 DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;
716734
735+ // / Mapping from original `struct_extract` instructions to their strategies.
736+ DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy>
737+ structExtractDifferentiationStrategies;
738+
717739 // / Cache for associated functions.
718740 SILFunction *primal = nullptr ;
719741 SILFunction *adjoint = nullptr ;
@@ -810,6 +832,11 @@ class DifferentiationTask {
810832 return nestedApplyActivities;
811833 }
812834
835+ DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy> &
836+ getStructExtractDifferentiationStrategies () {
837+ return structExtractDifferentiationStrategies;
838+ }
839+
813840 bool isEqual (const DifferentiationTask &other) const {
814841 return original == other.original && attr == other.attr ;
815842 }
@@ -2228,16 +2255,42 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22282255 }
22292256
22302257 void visitStructExtractInst (StructExtractInst *sei) {
2258+ auto &astCtx = getContext ().getASTContext ();
2259+ auto &structExtractDifferentiationStrategies =
2260+ getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
2261+
22312262 // Special handling logic only applies when the `struct_extract` is active.
22322263 // If not, just do standard cloning.
22332264 if (!activityInfo.isActive (sei, synthesis.indices )) {
22342265 LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *sei << ' \n ' );
2266+ structExtractDifferentiationStrategies.insert (
2267+ {sei, StructExtractDifferentiationStrategy::Inactive});
22352268 SILClonerWithScopes::visitStructExtractInst (sei);
22362269 return ;
22372270 }
22382271
2239- // This instruction is active. Replace it with a call to the corresponding
2240- // getter's VJP.
2272+ // This instruction is active. Determine the appropriate differentiation
2273+ // strategy, and use it.
2274+
2275+ // Use the FieldwiseProductSpace strategy, if appropriate.
2276+ auto *structDecl = sei->getStructDecl ();
2277+ auto aliasLookup = structDecl->lookupDirect (astCtx.Id_CotangentVector );
2278+ if (aliasLookup.size () >= 1 ) {
2279+ assert (aliasLookup.size () == 1 );
2280+ assert (isa<TypeAliasDecl>(aliasLookup[0 ]));
2281+ auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0 ]);
2282+ if (aliasDecl->getAttrs ().hasAttribute <FieldwiseProductSpaceAttr>()) {
2283+ structExtractDifferentiationStrategies.insert (
2284+ {sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
2285+ SILClonerWithScopes::visitStructExtractInst (sei);
2286+ return ;
2287+ }
2288+ }
2289+
2290+ // The FieldwiseProductSpace strategy is not appropriate, so use the Getter
2291+ // strategy.
2292+ structExtractDifferentiationStrategies.insert (
2293+ {sei, StructExtractDifferentiationStrategy::Getter});
22412294
22422295 // Find the corresponding getter and its VJP.
22432296 auto *getterDecl = sei->getField ()->getGetter ();
@@ -3596,17 +3649,103 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
35963649 }
35973650
35983651 void visitStructExtractInst (StructExtractInst *sei) {
3599- // Replace a `struct_extract` with a call to its pullback.
36003652 auto loc = remapLocation (sei->getLoc ());
3653+ auto &astCtx = getContext ().getASTContext ();
36013654
3602- // Get the pullback.
3603- auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3604- if (!pullbackField) {
3605- // Inactive `struct_extract` instructions don't need to be cloned into the
3606- // adjoint.
3655+ auto &differentiationStrategies =
3656+ getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
3657+ auto differentiationStrategyLookUp = differentiationStrategies.find (sei);
3658+ assert (differentiationStrategyLookUp != differentiationStrategies.end ());
3659+ auto differentiationStrategy = differentiationStrategyLookUp->second ;
3660+
3661+ if (differentiationStrategy ==
3662+ StructExtractDifferentiationStrategy::Inactive) {
36073663 assert (!activityInfo.isActive (sei, synthesis.indices ));
36083664 return ;
36093665 }
3666+
3667+ if (differentiationStrategy ==
3668+ StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
3669+ // Compute adjoint as follows:
3670+ // y = struct_extract <key>, x
3671+ // adj[x] = struct (0, ..., key': adj[y], ..., 0)
3672+ // where `key'` is the field in the cotangent space corresponding to
3673+ // `key`.
3674+
3675+ // Find the decl of the cotangent space type.
3676+ auto *structDecl = sei->getStructDecl ();
3677+ auto aliasLookup = structDecl->lookupDirect (astCtx.Id_CotangentVector );
3678+ assert (aliasLookup.size () == 1 );
3679+ assert (isa<TypeAliasDecl>(aliasLookup[0 ]));
3680+ auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0 ]);
3681+ assert (aliasDecl->getAttrs ().hasAttribute <FieldwiseProductSpaceAttr>());
3682+ auto cotangentVectorTy =
3683+ aliasDecl->getUnderlyingTypeLoc ().getType ()->getCanonicalType ();
3684+ assert (!getModule ()
3685+ .Types .getTypeLowering (cotangentVectorTy)
3686+ .isAddressOnly ());
3687+ auto cotangentVectorSILTy =
3688+ SILType::getPrimitiveObjectType (cotangentVectorTy);
3689+ auto *cotangentVectorDecl =
3690+ cotangentVectorTy->getStructOrBoundGenericStruct ();
3691+ assert (cotangentVectorDecl);
3692+
3693+ // Find the corresponding field in the cotangent space.
3694+ VarDecl *correspondingField = nullptr ;
3695+ if (cotangentVectorDecl == structDecl)
3696+ correspondingField = sei->getField ();
3697+ else {
3698+ auto correspondingFieldLookup =
3699+ cotangentVectorDecl->lookupDirect (sei->getField ()->getName ());
3700+ assert (correspondingFieldLookup.size () == 1 );
3701+ assert (isa<VarDecl>(correspondingFieldLookup[0 ]));
3702+ correspondingField = cast<VarDecl>(correspondingFieldLookup[0 ]);
3703+ }
3704+ assert (correspondingField);
3705+
3706+ #ifndef NDEBUG
3707+ unsigned numMatchingStoredProperties = 0 ;
3708+ for (auto *storedProperty : cotangentVectorDecl->getStoredProperties ())
3709+ if (storedProperty == correspondingField)
3710+ numMatchingStoredProperties += 1 ;
3711+ assert (numMatchingStoredProperties == 1 );
3712+ #endif
3713+
3714+ // Compute adjoint.
3715+ auto av = getAdjointValue (sei);
3716+ switch (av.getKind ()) {
3717+ case AdjointValue::Kind::Zero:
3718+ addAdjointValue (sei->getOperand (),
3719+ AdjointValue::getZero (cotangentVectorSILTy));
3720+ break ;
3721+ case AdjointValue::Kind::Materialized:
3722+ case AdjointValue::Kind::Aggregate: {
3723+ SmallVector<AdjointValue, 8 > eltVals;
3724+ for (auto *field : cotangentVectorDecl->getStoredProperties ()) {
3725+ if (field == correspondingField)
3726+ eltVals.push_back (av);
3727+ else
3728+ eltVals.push_back (
3729+ AdjointValue::getZero (SILType::getPrimitiveObjectType (
3730+ field->getType ()->getCanonicalType ())));
3731+ }
3732+ addAdjointValue (sei->getOperand (),
3733+ AdjointValue::getAggregate (cotangentVectorSILTy,
3734+ eltVals, allocator));
3735+ }
3736+ }
3737+
3738+ return ;
3739+ }
3740+
3741+ // The only remaining strategy is the getter strategy.
3742+ // Replace the `struct_extract` with a call to its pullback.
3743+ assert (differentiationStrategy ==
3744+ StructExtractDifferentiationStrategy::Getter);
3745+
3746+ // Get the pullback.
3747+ auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3748+ assert (pullbackField);
36103749 SILValue pullback = builder.createStructExtract (loc,
36113750 primalValueAggregateInAdj,
36123751 pullbackField);
0 commit comments