@@ -1513,20 +1513,6 @@ class DifferentiableActivityInfo {
15131513 const SILAutoDiffIndices &indices) const ;
15141514};
15151515
1516- // / Given a parameter argument (not indirect result) and some differentiation
1517- // / indices, figure out whether the parent function is being differentiated with
1518- // / respect to this parameter, according to the indices.
1519- static bool isDifferentiationParameter (SILArgument *argument,
1520- IndexSubset *indices) {
1521- if (!argument) return false ;
1522- auto *function = argument->getFunction ();
1523- auto paramArgs = function->getArgumentsWithoutIndirectResults ();
1524- for (unsigned i : indices->getIndices ())
1525- if (paramArgs[i] == argument)
1526- return true ;
1527- return false ;
1528- }
1529-
15301516// / For an `apply` instruction with active results, compute:
15311517// / - The results of the `apply` instruction, in type order.
15321518// / - The set of minimal parameter and result indices for differentiating the
@@ -1543,9 +1529,7 @@ static void collectMinimalIndicesForFunctionCall(
15431529 // Record all parameter indices in type order.
15441530 unsigned currentParamIdx = 0 ;
15451531 for (auto applyArg : ai->getArgumentsWithoutIndirectResults ()) {
1546- if (activityInfo.isVaried (applyArg, parentIndices.parameters ) ||
1547- isDifferentiationParameter (dyn_cast<SILArgument>(applyArg),
1548- parentIndices.parameters ))
1532+ if (activityInfo.isActive (applyArg, parentIndices))
15491533 paramIndices.push_back (currentParamIdx);
15501534 ++currentParamIdx;
15511535 }
@@ -1566,13 +1550,12 @@ static void collectMinimalIndicesForFunctionCall(
15661550 if (res.isFormalDirect ()) {
15671551 results.push_back (directResults[dirResIdx]);
15681552 if (auto dirRes = directResults[dirResIdx])
1569- if (dirRes && activityInfo.isUseful (dirRes, parentIndices. source ))
1553+ if (dirRes && activityInfo.isActive (dirRes, parentIndices))
15701554 resultIndices.push_back (idx);
15711555 ++dirResIdx;
15721556 } else {
15731557 results.push_back (indirectResults[indResIdx]);
1574- if (activityInfo.isUseful (indirectResults[indResIdx],
1575- parentIndices.source ))
1558+ if (activityInfo.isActive (indirectResults[indResIdx], parentIndices))
15761559 resultIndices.push_back (idx);
15771560 ++indResIdx;
15781561 }
@@ -3813,10 +3796,12 @@ class VJPEmitter final
38133796 activeResultIndices.begin (), activeResultIndices.end (),
38143797 [&s](unsigned i) { s << i; }, [&s] { s << " , " ; });
38153798 s << " }\n " ;);
3816- // FIXME: We don't support multiple active results yet.
3799+ // Diagnose multiple active results.
3800+ // TODO(TF-983): Support multiple active results.
38173801 if (activeResultIndices.size () > 1 ) {
38183802 context.emitNondifferentiabilityError (
3819- ai, invoker, diag::autodiff_expression_not_differentiable_note);
3803+ ai, invoker,
3804+ diag::autodiff_cannot_differentiate_through_multiple_results);
38203805 errorOccurred = true ;
38213806 return ;
38223807 }
@@ -5492,13 +5477,16 @@ class JVPEmitter final
54925477 activeResultIndices.begin (), activeResultIndices.end (),
54935478 [&s](unsigned i) { s << i; }, [&s] { s << " , " ; });
54945479 s << " }\n " ;);
5495- // FIXME: We don't support multiple active results yet.
5480+ // Diagnose multiple active results.
5481+ // TODO(TF-983): Support multiple active results.
54965482 if (activeResultIndices.size () > 1 ) {
54975483 context.emitNondifferentiabilityError (
5498- ai, invoker, diag::autodiff_expression_not_differentiable_note);
5484+ ai, invoker,
5485+ diag::autodiff_cannot_differentiate_through_multiple_results);
54995486 errorOccurred = true ;
55005487 return ;
55015488 }
5489+
55025490 // Form expected indices, assuming there's only one result.
55035491 SILAutoDiffIndices indices (
55045492 activeResultIndices.front (),
@@ -6257,6 +6245,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62576245 // Adjoint values of dominated active values are passed as pullback block
62586246 // arguments.
62596247 DominanceOrder domOrder (original.getEntryBlock (), domInfo);
6248+ // Keep track of visited values.
6249+ SmallPtrSet<SILValue, 8 > visited;
62606250 while (auto *bb = domOrder.getNext ()) {
62616251 auto &bbActiveValues = activeValues[bb];
62626252 // If the current block has an immediate dominator, append the immediate
@@ -6266,13 +6256,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62666256 bbActiveValues.append (domBBActiveValues.begin (),
62676257 domBBActiveValues.end ());
62686258 }
6269- SmallPtrSet<SILValue, 8 > visited (bbActiveValues.begin (),
6270- bbActiveValues.end ());
6271- // Register a value as active if it has not yet been visited.
62726259 bool diagnosedActiveEnumValue = false ;
6273- auto addActiveValue = [&](SILValue v) {
6260+ // Mark the activity of a value if it has not yet been visited.
6261+ auto markValueActivity = [&](SILValue v) {
62746262 if (visited.count (v))
62756263 return ;
6264+ visited.insert (v);
62766265 // Diagnose active enum values. Differentiation of enum values requires
62776266 // special adjoint value handling and is not yet supported. Diagnose
62786267 // only the first active enum value to prevent too many diagnostics.
@@ -6288,17 +6277,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62886277 // become projections into their adjoint base buffer.
62896278 if (Projection::isAddressProjection (v))
62906279 return ;
6291- visited.insert (v);
62926280 bbActiveValues.push_back (v);
62936281 };
6294- // Register bb arguments and all instruction operands/results.
6282+ // Visit bb arguments and all instruction operands/results.
6283+ for (auto *arg : bb->getArguments ())
6284+ if (getActivityInfo ().isActive (arg, getIndices ()))
6285+ markValueActivity (arg);
62956286 for (auto &inst : *bb) {
62966287 for (auto op : inst.getOperandValues ())
62976288 if (getActivityInfo ().isActive (op, getIndices ()))
6298- addActiveValue (op);
6289+ markValueActivity (op);
62996290 for (auto result : inst.getResults ())
63006291 if (getActivityInfo ().isActive (result, getIndices ()))
6301- addActiveValue (result);
6292+ markValueActivity (result);
63026293 }
63036294 domOrder.pushChildren (bb);
63046295 if (errorOccurred)
0 commit comments