@@ -412,15 +412,23 @@ struct DifferentiationInvoker {
412412
413413class DifferentiableActivityInfo ;
414414
415- // / Information about the JVP/VJP function produced during JVP/VJP generation,
416- // / e.g. mappings from original values to corresponding values in the
417- // / pullback/differential struct.
415+ // / Linear map struct and branching trace enum information for an original
416+ // / function and and derivative function (JVP or VJP).
418417// /
419- // / A linear map struct is an aggregate value containing linear maps checkpointed
420- // / during the JVP/VJP computation. Linear map structs are generated for every
421- // / original function during JVP/VJP generation. Linear map struct values are
422- // / constructed by JVP/VJP functions and consumed by pullback/differential
423- // / functions.
418+ // / Linear map structs contain all callee linear maps produced in a JVP/VJP
419+ // / basic block. A linear map struct is created for each basic block in the
420+ // / original function, and a linear map struct field is created for every active
421+ // / `apply` in the original basic block.
422+ // /
423+ // / Branching trace enums model the control flow graph of the original function.
424+ // / A branching trace enum is created for each basic block in the original
425+ // / function, and a branching trace enum case is created for every basic block
426+ // / predecessor/successor. This supports control flow differentiation: JVP/VJP
427+ // / functions build branching trace enums to record an execution trace. Indirect
428+ // / branching trace enums are created for basic blocks that are in loops.
429+ // /
430+ // / Linear map struct values and branching trace enum values are constructed in
431+ // / JVP/VJP functions and consumed in pullback/differential functions.
424432class LinearMapInfo {
425433private:
426434 // / The linear map kind.
@@ -446,13 +454,12 @@ class LinearMapInfo {
446454 // / For differentials: these are successor enums.
447455 DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
448456
449- // / Mapping from `apply` and `struct_extract` instructions in the original
450- // / function to the corresponding linear map declaration in the linear map
451- // / struct.
452- DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;
457+ // / Mapping from `apply` instructions in the original function to the
458+ // / corresponding linear map field declaration in the linear map struct.
459+ DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
453460
454- // / Mapping from predecessor+ succcessor basic block pairs in original function
455- // / to the corresponding branching trace enum case.
461+ // / Mapping from predecessor- succcessor basic block pairs in the original
462+ // / function to the corresponding branching trace enum case.
456463 DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
457464 branchingTraceEnumCases;
458465
@@ -662,7 +669,7 @@ class LinearMapInfo {
662669 }
663670
664671 // / Add a linear map to the linear map struct.
665- VarDecl *addLinearMapDecl (SILInstruction *inst , SILType linearMapType) {
672+ VarDecl *addLinearMapDecl (ApplyInst *ai , SILType linearMapType) {
666673 // IRGen requires decls to have AST types (not `SILFunctionType`), so we
667674 // convert the `SILFunctionType` of the linear map to a `FunctionType` with
668675 // the same parameters and results.
@@ -678,24 +685,24 @@ class LinearMapInfo {
678685 astFnTy = FunctionType::get (
679686 params, silFnTy->getAllResultsInterfaceType ().getASTType ());
680687
681- auto *origBB = inst ->getParent ();
688+ auto *origBB = ai ->getParent ();
682689 auto *linMapStruct = getLinearMapStruct (origBB);
683690 std::string linearMapName;
684691 switch (kind) {
685692 case AutoDiffLinearMapKind::Differential:
686- linearMapName = " differential_" + llvm::itostr (linearMapValueMap .size ());
693+ linearMapName = " differential_" + llvm::itostr (linearMapFieldMap .size ());
687694 break ;
688695 case AutoDiffLinearMapKind::Pullback:
689- linearMapName = " pullback_" + llvm::itostr (linearMapValueMap .size ());
696+ linearMapName = " pullback_" + llvm::itostr (linearMapFieldMap .size ());
690697 break ;
691698 }
692699 auto *linearMapDecl = addVarDecl (linMapStruct, linearMapName, astFnTy);
693- linearMapValueMap .insert ({inst , linearMapDecl});
700+ linearMapFieldMap .insert ({ai , linearMapDecl});
694701 return linearMapDecl;
695702 }
696703
697- // / Given an `apply` instruction, conditionally adds its linear map function
698- // / to the linear map struct if it is active.
704+ // / Given an `apply` instruction, conditionally add a linear map struct field
705+ // / for its linear map function if it is active.
699706 void addLinearMapToStruct (ADContext &context, ApplyInst *ai,
700707 const SILAutoDiffIndices &indices);
701708
@@ -771,12 +778,13 @@ class LinearMapInfo {
771778 return linearMapStructEnumFields.lookup (linearMapStruct);
772779 }
773780
774- // / Finds the linear map declaration in the pullback struct for an `apply` or
775- // / `struct_extract` in the original function.
776- VarDecl *lookUpLinearMapDecl (SILInstruction *inst) {
777- auto lookup = linearMapValueMap.find (inst);
778- assert (lookup != linearMapValueMap.end () &&
779- " No linear map declaration corresponding to the given instruction" );
781+ // / Finds the linear map declaration in the pullback struct for the given
782+ // / `apply` instruction in the original function.
783+ VarDecl *lookUpLinearMapDecl (ApplyInst *ai) {
784+ assert (ai->getFunction () == original);
785+ auto lookup = linearMapFieldMap.find (ai);
786+ assert (lookup != linearMapFieldMap.end () &&
787+ " No linear map field corresponding to the given `apply`" );
780788 return lookup->getSecond ();
781789 }
782790};
@@ -1047,7 +1055,8 @@ class ADContext {
10471055 for (auto *da : original->getAttrs ().getAttributes <DifferentiableAttr>()) {
10481056 auto *daParamIndices = da->getParameterIndices ();
10491057 auto *daIndexSet = autodiff::getLoweredParameterIndices (
1050- daParamIndices, original->getInterfaceType ()->castTo <AnyFunctionType>());
1058+ daParamIndices,
1059+ original->getInterfaceType ()->castTo <AnyFunctionType>());
10511060 // If all indices in `indexSet` are in `daIndexSet`, and it has fewer
10521061 // indices than our current candidate and a primitive VJP, then `da` is
10531062 // our new candidate.
@@ -1646,8 +1655,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
16461655 return false ;
16471656}
16481657
1649- // / Takes an `apply` instruction and adds its linear map function to the
1650- // / linear map struct if it is active.
16511658void LinearMapInfo::addLinearMapToStruct (ADContext &context, ApplyInst *ai,
16521659 const SILAutoDiffIndices &indices) {
16531660 SmallVector<SILValue, 4 > allResults;
@@ -1696,15 +1703,15 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16961703 [&](CanSILFunctionType origFnTy) {
16971704 // Check non-differentiable arguments.
16981705 for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1699- auto remappedParamType =
1700- origFnTy-> getParameters ()[paramIndex] .getSILStorageInterfaceType ();
1706+ auto remappedParamType = origFnTy-> getParameters ()[paramIndex]
1707+ .getSILStorageInterfaceType ();
17011708 if (applyIndices.isWrtParameter (paramIndex) &&
17021709 !remappedParamType.isDifferentiable (derivative->getModule ()))
17031710 return true ;
17041711 }
17051712 // Check non-differentiable results.
1706- auto remappedResultType =
1707- origFnTy-> getResults ()[applyIndices. source ] .getSILStorageInterfaceType ();
1713+ auto remappedResultType = origFnTy-> getResults ()[applyIndices. source ]
1714+ .getSILStorageInterfaceType ();
17081715 if (!remappedResultType.isDifferentiable (derivative->getModule ()))
17091716 return true ;
17101717 return false ;
@@ -1713,13 +1720,13 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
17131720 return ;
17141721
17151722 AutoDiffDerivativeFunctionKind derivativeFnKind (kind);
1716- auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType (
1717- parameters, source, derivativeFnKind, context.getTypeConverter (),
1718- LookUpConformanceInModule (derivative->getModule ().getSwiftModule ()));
1723+ auto derivativeFnType =
1724+ remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType (
1725+ parameters, source, derivativeFnKind, context.getTypeConverter (),
1726+ LookUpConformanceInModule (derivative->getModule ().getSwiftModule ()));
17191727
17201728 auto derivativeFnResultTypes =
17211729 derivativeFnType->getAllResultsInterfaceType ().castTo <TupleType>();
1722- derivativeFnResultTypes->getElement (derivativeFnResultTypes->getElements ().size () - 1 );
17231730 auto linearMapSILType = SILType::getPrimitiveObjectType (
17241731 derivativeFnResultTypes
17251732 ->getElement (derivativeFnResultTypes->getElements ().size () - 1 )
@@ -1760,8 +1767,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
17601767 break ;
17611768 }
17621769 for (auto &origBB : *original) {
1763- auto *traceEnum =
1764- createBranchingTraceDecl ( &origBB, indices, derivativeFnGenSig, loopInfo);
1770+ auto *traceEnum = createBranchingTraceDecl (
1771+ &origBB, indices, derivativeFnGenSig, loopInfo);
17651772 branchingTraceDecls.insert ({&origBB, traceEnum});
17661773 if (origBB.isEntry ())
17671774 continue ;
@@ -5847,6 +5854,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
58475854 // / Record a temporary value for cleanup before its block's terminator.
58485855 SILValue recordTemporary (SILValue value) {
58495856 assert (value->getType ().isObject ());
5857+ assert (value->getFunction () == &getPullback ());
58505858 blockTemporaries[value->getParentBlock ()].push_back (value);
58515859 LLVM_DEBUG (getADDebugStream () << " Recorded temporary " << value);
58525860 auto insertion = blockTemporarySet.insert (value); (void )insertion;
@@ -5971,7 +5979,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
59715979 auto insertion = valueMap.try_emplace ({origBB, originalValue},
59725980 adjointValue);
59735981 LLVM_DEBUG (getADDebugStream ()
5974- << " The existing adjoint value will be replaced : "
5982+ << " The new adjoint value, replacing the existing one, is : "
59755983 << insertion.first ->getSecond ());
59765984 if (!insertion.second )
59775985 insertion.first ->getSecond () = adjointValue;
@@ -6000,7 +6008,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
60006008 assert (originalValue->getType ().isObject ());
60016009 assert (newAdjointValue.getType ().isObject ());
60026010 assert (originalValue->getFunction () == &getOriginal ());
6003- LLVM_DEBUG (getADDebugStream () << " Adding adjoint for " << originalValue);
6011+ LLVM_DEBUG (getADDebugStream () << " Adding adjoint value for "
6012+ << originalValue);
60046013 // The adjoint value must be in the tangent space.
60056014 assert (newAdjointValue.getType () ==
60066015 getRemappedTangentType (originalValue->getType ()));
@@ -6494,8 +6503,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
64946503 SmallVector<SILValue, 4 > directResults;
64956504 auto indirectResultIt = pullback.getIndirectResults ().begin ();
64966505 for (auto resultInfo : pullback.getLoweredFunctionType ()->getResults ()) {
6497- auto resultType =
6498- pullback. mapTypeIntoContext ( resultInfo.getInterfaceType ())->getCanonicalType ();
6506+ auto resultType = pullback. mapTypeIntoContext (
6507+ resultInfo.getInterfaceType ())->getCanonicalType ();
64996508 if (resultInfo.isFormalDirect ())
65006509 directResults.push_back (emitZeroDirect (resultType, pbLoc));
65016510 else
@@ -6539,7 +6548,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
65396548 auto &predBBActiveValues = activeValues[origPredBB];
65406549 for (auto activeValue : predBBActiveValues) {
65416550 LLVM_DEBUG (getADDebugStream ()
6542- << " Propagating active adjoint " << activeValue
6551+ << " Propagating adjoint of active value " << activeValue
65436552 << " to predecessors' pullback blocks\n " );
65446553 if (activeValue->getType ().isObject ()) {
65456554 auto activeValueAdj = getAdjointValue (origBB, activeValue);
@@ -7641,7 +7650,7 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
76417650 // TODO: Optimize for the case when lhs == rhs.
76427651 LLVM_DEBUG (getADDebugStream () <<
76437652 " Emitting adjoint accumulation for lhs: " << lhs <<
7644- " and rhs: " << rhs << " \n " );
7653+ " and rhs: " << rhs);
76457654 assert (lhs->getType () == rhs->getType () && " Adjoints must have equal types!" );
76467655 assert (lhs->getType ().isObject () && rhs->getType ().isObject () &&
76477656 " Adjoint types must be both object types!" );
0 commit comments