Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 22 additions & 122 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2937,29 +2937,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
auto *origRetBB = &*original.findReturnBB();
adjointBBMap.insert({origRetBB, adjointEntry});
SILFunctionConventions origConv(origTy, getModule());
// Initialize `originalParameters`, `primalValueAggregate`,
// `originalResults` and `seed`.
// The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]).
auto adjParamArgs = getAdjoint().getArgumentsWithoutIndirectResults();
auto origNumParams = origConv.getNumParameters();
auto origNumResults = origTy->getNumResults();
// The adjoint function has type
// (seed, pv0, ..., pvn, origres, arg0, ..., argn, [self])
// -> ([self], [arg0], ..., [argn]).
// Square brackets [] denote elements that are not always in the signature:
// * "self" is present in the argument list when it's present in the
// original's argument list.
// * Results are present when the adjoint is with respect to the
// corresponding original argument.
//
// We get each range of arguments by shifting the `paramArgsData` pointer.
auto *paramArgsData = adjParamArgs.data();
seed = *paramArgsData++;
primalValueAggregateInAdj = *paramArgsData++;
originalResultsInAdj = {paramArgsData, origNumResults};
paramArgsData += origNumResults;
originalParametersInAdj = {paramArgsData, origNumParams};
paramArgsData += origNumParams;
assert(paramArgsData == adjParamArgs.data() + adjParamArgs.size());
seed = adjParamArgs[0];
primalValueAggregateInAdj = adjParamArgs[1];

// Assign adjoint to the return value.
// y = tuple (y0, ..., yn)
Expand Down Expand Up @@ -3023,12 +3004,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
// `parameterIndex` into the `retElts` vector.
auto addRetElt = [&](unsigned parameterIndex) -> void {
auto origParam = origParams[parameterIndex];
auto adjParam = originalParametersInAdj[parameterIndex];
auto adjVal = getAdjointValue(origParam);
if (adjParam->getType().isObject())
if (origParam->getType().isObject())
retElts.push_back(materializeAdjointDirect(adjVal, adjLoc));
else
materializeAdjointIndirect(adjVal, adjParam);
llvm_unreachable("Unimplemented: Handle indirect pullback results");
};

// The original's self parameter, if present, is the last parameter. But we
Expand Down Expand Up @@ -4079,22 +4059,11 @@ void DifferentiationTask::createEmptyAdjoint() {
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
->getCanonicalType()));

// If there's a generated primal, accept a primal value struct in the adjoint
// parameter list.
if (auto *pi = getPrimalInfo()) {
auto pvType = pi->getPrimalValueStruct()
->getDeclaredInterfaceType()
->getCanonicalType();
adjParams.push_back(getFormalParameterInfo(pvType));
}

// Add adjoint parameters for the original results.
for (auto &origRes : origTy->getResults())
adjParams.push_back(getFormalParameterInfo(origRes.getType()));

// Add adjoint parameters for the original parameters.
for (auto &param : origParams)
adjParams.push_back(param);
// Accept a primal value struct in the adjoint parameter list. This is the
// pullback's closure context.
auto pvType = getPrimalInfo()->getPrimalValueStruct()
->getDeclaredInterfaceType()->getCanonicalType();
adjParams.push_back(getFormalParameterInfo(pvType));

// Add adjoint result for the wrt self parameter, if it was requested.
auto selfParamIndex = origParams.size() - 1;
Expand Down Expand Up @@ -4279,7 +4248,6 @@ void DifferentiationTask::createVJP() {
unsigned numOriginalResults = originalConv.getResults().size();
unsigned numCheckpoints = primalConv.getResults().size() - numOriginalResults;
unsigned numSeeds = 1;
unsigned numWrt = getIndices().parameters.count();

LLVM_DEBUG(llvm::dbgs() << " numOriginalParameters: "
<< numOriginalParameters << "\n"
Expand All @@ -4296,11 +4264,7 @@ void DifferentiationTask::createVJP() {
"unexpected number of vjp parameters");
assert(vjpConv.getResults().size() == numOriginalResults + 1 &&
"unexpected number of vjp results");
assert(adjointConv.getNumParameters() ==
numSeeds + numCheckpoints + numOriginalResults +
numOriginalParameters &&
"unexpected number of adjoint parameters");
assert(adjointConv.getResults().size() == numWrt &&
assert(adjointConv.getResults().size() == getIndices().parameters.count() &&
"unexpected number of adjoint results");

// We assume that primal result conventions (for all results but the optional
Expand All @@ -4313,21 +4277,6 @@ void DifferentiationTask::createVJP() {
assert(primalResultInfo == vjpResultInfo &&
"primal result info does not match vjp result info");
}

// We assume that the primal result conventions for checkpoints and original
// results match the corresponding adjoint parameter conventions for
// checkpoints and original results, so check that assumption.
for (unsigned resultIndex : indices(primalConv.getResults())) {
auto &primalResultInfo = primalConv.getResults()[resultIndex];
auto &adjointParameterInfo =
adjointConv.getParameters()[numSeeds + resultIndex];
assert(primalResultInfo.isFormalIndirect() ==
adjointParameterInfo.isFormalIndirect() &&
"primal result directness does not match adjoint parameter "
"directness");
assert(primalResultInfo.getType() == adjointParameterInfo.getType() &&
"primal result type does not match adjoint parameter type");
}
#endif

// Create VJP entry BB and arguments.
Expand All @@ -4337,9 +4286,8 @@ void DifferentiationTask::createVJP() {
SILBuilder builder(entry);
auto loc = vjp->getLocation();

// === Call primal with original arguments. ===
// Call primal with original arguments.
SmallVector<SILValue, 8> primalArgs;

// Allocate space for indirect checkpoint results, and pass the addresses to
// the primal.
unsigned remainingIndirectCheckpointResults =
Expand All @@ -4354,89 +4302,41 @@ void DifferentiationTask::createVJP() {
stackAllocsToCleanUp.push_back(resultBuf);
--remainingIndirectCheckpointResults;
}

// Tell the primal to put its indirect results in the vjp indirect result
// buffers. This assumes that the primal indirect results are exactly the vjp
// indirect results, an assumption that we check in assertions above.
for (auto indRes : vjp->getIndirectResults())
primalArgs.push_back(indRes);

// Add original parameters.
for (auto arg : vjp->getArgumentsWithoutIndirectResults())
primalArgs.push_back(arg);

// Get and call the primal.
auto *primalRef = builder.createFunctionRef(loc, primal);

auto vjpSubstMap = vjpGenericEnv
? vjpGenericEnv->getForwardingSubstitutionMap()
: vjp->getForwardingSubstitutionMap();

auto *primalApply = builder.createApply(
loc, primalRef, vjpSubstMap, primalArgs, /*isNonThrowing*/ false);

// Collect the primal's direct results.
SmallVector<SILValue, 8> primalDirectResults;
if (primalConv.getNumDirectSILResults() == 1)
primalDirectResults.push_back(primalApply);
else {
for (unsigned i : range(primalConv.getNumDirectSILResults())) {
auto val = builder.createTupleExtract(loc, primalApply, i);
primalDirectResults.push_back(val);
}
}

// === Partially apply the adjoint. ===
SmallVector<SILValue, 8> partialAdjointArgs;

// Add primal values and original results.
unsigned indResIdx = 0, dirResIdx = 0;
for (auto &resInfo : primalConv.getResults()) {
if (resInfo.isFormalDirect()) {
// This assumes that the primal direct results correspond exactly to the
// adjoint's direct parameters, an assumption that we check in assertions
// above.
//
// If this is not the primal value struct (index 0), it must be an
// original result. Retain it because we'll return original results
// alongside the pullback.
if (dirResIdx != 0)
builder.createRetainValue(loc, primalDirectResults[dirResIdx],
builder.getDefaultAtomicity());
partialAdjointArgs.push_back(primalDirectResults[dirResIdx++]);
} else {
// This assumes that the primal indirect results correspond exactly to the
// adjoint's indirect parameters, an assumption that we check in
// assertions above.
partialAdjointArgs.push_back(primalArgs[indResIdx++]);
}
}

// Add original parameters.
for (auto arg : vjp->getArgumentsWithoutIndirectResults()) {
if (arg->getType().isObject())
builder.createRetainValue(loc, arg, builder.getDefaultAtomicity());
// TODO: We need to copy address arguments into a Box, and give the Box to
// the adjoint. We'll need to wrap the adjoint in a function that projects
// Boxes to achieve this.
partialAdjointArgs.push_back(arg);
}
// Clean up the stack allocations for primal application.
for (auto alloc : reversed(stackAllocsToCleanUp))
builder.createDeallocStack(loc, alloc);

// Collect the primal's direct results to prepare for creating a pullback
// and return original values and the pullback.
SmallVector<SILValue, 8> primalDirectResults;
extractAllElements(primalApply, builder, primalDirectResults);
auto originalDirectResults = ArrayRef<SILValue>(primalDirectResults)
.take_back(originalConv.getNumDirectSILResults());
// Get and partially apply the adjoint.
auto *adjointRef = builder.createFunctionRef(loc, adjoint);
auto *adjointPartialApply = builder.createPartialApply(
loc, adjointRef, vjpSubstMap, partialAdjointArgs,
loc, adjointRef, vjpSubstMap, { primalDirectResults[0] },
ParameterConvention::Direct_Guaranteed);

// Clean up the stack allocations.
for (auto alloc : reversed(stackAllocsToCleanUp))
builder.createDeallocStack(loc, alloc);

// Return the direct results. Note that indirect results have already been
// filled in by the application of the primal.
SmallVector<SILValue, 8> directResults;
auto originalDirectResults = ArrayRef<SILValue>(primalDirectResults)
.take_back(originalConv.getNumDirectSILResults());
directResults.append(originalDirectResults.begin(),
originalDirectResults.end());
directResults.push_back(adjointPartialApply);
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ public func closureCaptureMutable() {
// CHECK: [[PRIMAL:%.*]] = function_ref @AD__{{.*}}closureCaptureMutable{{.*}}___primal_src_0_wrt_0
// CHECK: {{.*}} = apply [[PRIMAL]]({{.*}}, [[BOXED_ARG]])
// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___adjoint_src_0_wrt_0
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}}, {{.*}}, {{.*}}, [[BOXED_ARG]])
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})

4 changes: 2 additions & 2 deletions test/AutoDiff/refcounting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
// The adjoint should not release primal values because they are passed in as @guaranteed.
//
// CHECK-VJP-LABEL: @{{.*}}testOwnedVector{{.*}}__adjoint_src_0_wrt_0
// CHECK-VJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, {{%.*}} : $Vector, {{%.*}} : $Vector):
// CHECK-VJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0):
// CHECK-VJP: [[PULLBACK0:%.*]] = struct_extract [[PRIMAL_VALUES]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0.pullback_0
// CHECK-VJP-NOT: release_value [[PULLBACK0]]
// CHECK-VJP-NOT: release_value [[PRIMAL_VALUES]]
Expand All @@ -62,7 +62,7 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
// The adjoint should not release primal values because they are passed in as @guaranteed.
//
// CHECK-NOVJP-LABEL: @{{.*}}testOwnedVector{{.*}}__adjoint_src_0_wrt_0
// CHECK-NOVJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, {{%.*}} : $Vector, {{%.*}} : $Vector):
// CHECK-NOVJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0):
// CHECK-NOVJP: [[PV0:%.*]] = struct_extract [[PRIMAL_VALUES]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0.v_0
// CHECK-NOVJP-NOT: release_value [[PV0]]
// CHECK-NOVJP-NOT: release_value [[PRIMAL_VALUES]]