From bb912bfd54de3482e1c98238a1113ae55d3fe732 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 28 Jan 2019 06:58:55 -0500 Subject: [PATCH] [AutoDiff] Remove original parameters and results from adjoint type signature. Original arguments and results are not needed or used by the adjoint because the adjoint is just a pullback with an explicit closure context. Saves more memory. This has been on my list for quite some time. :) --- .../Mandatory/Differentiation.cpp | 144 +++--------------- test/AutoDiff/closures.swift | 2 +- test/AutoDiff/refcounting.swift | 4 +- 3 files changed, 25 insertions(+), 125 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 501d3e46ebe14..fec540142f652 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -2937,29 +2937,10 @@ class AdjointEmitter final : public SILInstructionVisitor { auto *origRetBB = &*original.findReturnBB(); adjointBBMap.insert({origRetBB, adjointEntry}); SILFunctionConventions origConv(origTy, getModule()); - // Initialize `originalParameters`, `primalValueAggregate`, - // `originalResults` and `seed`. + // The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]). auto adjParamArgs = getAdjoint().getArgumentsWithoutIndirectResults(); - auto origNumParams = origConv.getNumParameters(); - auto origNumResults = origTy->getNumResults(); - // The adjoint function has type - // (seed, pv0, ..., pvn, origres, arg0, ..., argn, [self]) - // -> ([self], [arg0], ..., [argn]). - // Square brackets denote [] elements that are not always in the signature: - // * "self" is present in the argument list when it's present in the - // original's argument list. - // * Results are present when the adjoint is with respect to the - // corresponding original argument. - // - // We get each range of arguments by shifting the `paramArgsData` pointer. - auto *paramArgsData = adjParamArgs.data(); - seed = *paramArgsData++; - primalValueAggregateInAdj = *paramArgsData++; - originalResultsInAdj = {paramArgsData, origNumResults}; - paramArgsData += origNumResults; - originalParametersInAdj = {paramArgsData, origNumParams}; - paramArgsData += origNumParams; - assert(paramArgsData == adjParamArgs.data() + adjParamArgs.size()); + seed = adjParamArgs[0]; + primalValueAggregateInAdj = adjParamArgs[1]; // Assign adjoint to the return value. // y = tuple (y0, ..., yn) @@ -3023,12 +3004,11 @@ class AdjointEmitter final : public SILInstructionVisitor { // `parameterIndex` into the `retElts` vector. auto addRetElt = [&](unsigned parameterIndex) -> void { auto origParam = origParams[parameterIndex]; - auto adjParam = originalParametersInAdj[parameterIndex]; auto adjVal = getAdjointValue(origParam); - if (adjParam->getType().isObject()) + if (origParam->getType().isObject()) retElts.push_back(materializeAdjointDirect(adjVal, adjLoc)); else - materializeAdjointIndirect(adjVal, adjParam); + llvm_unreachable("Unimplemented: Handle indirect pullback results"); }; // The original's self parameter, if present, is the last parameter. But we @@ -4079,22 +4059,11 @@ void DifferentiationTask::createEmptyAdjoint() { AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) ->getCanonicalType())); - // If there's a generated primal, accept a primal value struct in the adjoint - // parameter list. - if (auto *pi = getPrimalInfo()) { - auto pvType = pi->getPrimalValueStruct() - ->getDeclaredInterfaceType() - ->getCanonicalType(); - adjParams.push_back(getFormalParameterInfo(pvType)); - } - - // Add adjoint parameters for the original results. - for (auto &origRes : origTy->getResults()) - adjParams.push_back(getFormalParameterInfo(origRes.getType())); - - // Add adjoint parameters for the original parameters. - for (auto ¶m : origParams) - adjParams.push_back(param); + // Accept a primal value struct in the adjoint parameter list. This is the + // pullback's closure context. + auto pvType = getPrimalInfo()->getPrimalValueStruct() + ->getDeclaredInterfaceType()->getCanonicalType(); + adjParams.push_back(getFormalParameterInfo(pvType)); // Add adjoint result for the wrt self parameter, if it was requested. auto selfParamIndex = origParams.size() - 1; @@ -4279,7 +4248,6 @@ void DifferentiationTask::createVJP() { unsigned numOriginalResults = originalConv.getResults().size(); unsigned numCheckpoints = primalConv.getResults().size() - numOriginalResults; unsigned numSeeds = 1; - unsigned numWrt = getIndices().parameters.count(); LLVM_DEBUG(llvm::dbgs() << " numOriginalParameters: " << numOriginalParameters << "\n" @@ -4296,11 +4264,7 @@ void DifferentiationTask::createVJP() { "unexpected number of vjp parameters"); assert(vjpConv.getResults().size() == numOriginalResults + 1 && "unexpected number of vjp results"); - assert(adjointConv.getNumParameters() == - numSeeds + numCheckpoints + numOriginalResults + - numOriginalParameters && - "unexpected number of adjoint parameters"); - assert(adjointConv.getResults().size() == numWrt && + assert(adjointConv.getResults().size() == getIndices().parameters.count() && "unexpected number of adjoint results"); // We assume that primal result conventions (for all results but the optional @@ -4313,21 +4277,6 @@ void DifferentiationTask::createVJP() { assert(primalResultInfo == vjpResultInfo && "primal result info does not match vjp result info"); } - - // We assume that the primal result conventions for checkpoints and original - // results match the corresponding adjoint parameter conventions for - // checkpoints and original results, so check that assumption. - for (unsigned resultIndex : indices(primalConv.getResults())) { - auto &primalResultInfo = primalConv.getResults()[resultIndex]; - auto &adjointParameterInfo = - adjointConv.getParameters()[numSeeds + resultIndex]; - assert(primalResultInfo.isFormalIndirect() == - adjointParameterInfo.isFormalIndirect() && - "primal result directness does not match adjoint parameter " - "directness"); - assert(primalResultInfo.getType() == adjointParameterInfo.getType() && - "primal result type does not match adjoint parameter type"); - } #endif // Create VJP entry BB and arguments. @@ -4337,9 +4286,8 @@ void DifferentiationTask::createVJP() { SILBuilder builder(entry); auto loc = vjp->getLocation(); - // === Call primal with original arguments. === + // Call primal with original arguments. SmallVector primalArgs; - // Allocate space for indirect checkpoint results, and pass the addresses to // the primal. unsigned remainingIndirectCheckpointResults = @@ -4354,89 +4302,41 @@ void DifferentiationTask::createVJP() { stackAllocsToCleanUp.push_back(resultBuf); --remainingIndirectCheckpointResults; } - // Tell the primal to put its indirect results in the vjp indirect result // buffers. This assumes that the primal indirect results are exactly the vjp // indirect results, an assumption that we check in assertions above. for (auto indRes : vjp->getIndirectResults()) primalArgs.push_back(indRes); - // Add original parameters. for (auto arg : vjp->getArgumentsWithoutIndirectResults()) primalArgs.push_back(arg); - // Get and call the primal. auto *primalRef = builder.createFunctionRef(loc, primal); - auto vjpSubstMap = vjpGenericEnv ? vjpGenericEnv->getForwardingSubstitutionMap() : vjp->getForwardingSubstitutionMap(); - auto *primalApply = builder.createApply( loc, primalRef, vjpSubstMap, primalArgs, /*isNonThrowing*/ false); - // Collect the primal's direct results. - SmallVector primalDirectResults; - if (primalConv.getNumDirectSILResults() == 1) - primalDirectResults.push_back(primalApply); - else { - for (unsigned i : range(primalConv.getNumDirectSILResults())) { - auto val = builder.createTupleExtract(loc, primalApply, i); - primalDirectResults.push_back(val); - } - } - - // === Partially apply the adjoint. === - SmallVector partialAdjointArgs; - - // Add primal values and original results. - unsigned indResIdx = 0, dirResIdx = 0; - for (auto &resInfo : primalConv.getResults()) { - if (resInfo.isFormalDirect()) { - // This assumes that the primal direct results correspond exactly to the - // adjoint's direct parameters, an assumption that we check in assertions - // above. - // - // If this is not the primal value struct (index 0), it must be an - // original result. Retain it because we'll return original results - // alongside the pullback. - if (dirResIdx != 0) - builder.createRetainValue(loc, primalDirectResults[dirResIdx], - builder.getDefaultAtomicity()); - partialAdjointArgs.push_back(primalDirectResults[dirResIdx++]); - } else { - // This assumes that the primal indirect results correspond exactly to the - // adjoint's indirect parameters, an assumption that we check in - // assertions above. - partialAdjointArgs.push_back(primalArgs[indResIdx++]); - } - } - - // Add original parameters. - for (auto arg : vjp->getArgumentsWithoutIndirectResults()) { - if (arg->getType().isObject()) - builder.createRetainValue(loc, arg, builder.getDefaultAtomicity()); - // TODO: We need to copy address arguments into a Box, and give the Box to - // the adjoint. We'll need to wrap the adjoint in a function that projects - // Boxes to acheive this. - partialAdjointArgs.push_back(arg); - } + // Clean up the stack allocations for primal application. + for (auto alloc : reversed(stackAllocsToCleanUp)) + builder.createDeallocStack(loc, alloc); + // Collect the primal's direct results to prepare for creating a pullback + // and return original values and the pullback. + SmallVector primalDirectResults; + extractAllElements(primalApply, builder, primalDirectResults); + auto originalDirectResults = ArrayRef(primalDirectResults) + .take_back(originalConv.getNumDirectSILResults()); // Get and partially apply the adjoint. auto *adjointRef = builder.createFunctionRef(loc, adjoint); auto *adjointPartialApply = builder.createPartialApply( - loc, adjointRef, vjpSubstMap, partialAdjointArgs, + loc, adjointRef, vjpSubstMap, { primalDirectResults[0] }, ParameterConvention::Direct_Guaranteed); - // Clean up the stack allocations. - for (auto alloc : reversed(stackAllocsToCleanUp)) - builder.createDeallocStack(loc, alloc); - // Return the direct results. Note that indirect results have already been // filled in by the application of the primal. SmallVector directResults; - auto originalDirectResults = ArrayRef(primalDirectResults) - .take_back(originalConv.getNumDirectSILResults()); directResults.append(originalDirectResults.begin(), originalDirectResults.end()); directResults.push_back(adjointPartialApply); diff --git a/test/AutoDiff/closures.swift b/test/AutoDiff/closures.swift index 3df4412e73104..8fbcd482080c4 100644 --- a/test/AutoDiff/closures.swift +++ b/test/AutoDiff/closures.swift @@ -28,5 +28,5 @@ public func closureCaptureMutable() { // CHECK: [[PRIMAL:%.*]] = function_ref @AD__{{.*}}closureCaptureMutable{{.*}}___primal_src_0_wrt_0 // CHECK: {{.*}} = apply [[PRIMAL]]({{.*}}, [[BOXED_ARG]]) // CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___adjoint_src_0_wrt_0 -// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}}, {{.*}}, {{.*}}, [[BOXED_ARG]]) +// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}}) diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift index 4365d6980fa61..c8c86c7388a96 100644 --- a/test/AutoDiff/refcounting.swift +++ b/test/AutoDiff/refcounting.swift @@ -48,7 +48,7 @@ _ = pullback(at: Vector.zero, in: testOwnedVector) // The adjoint should not release primal values because they are passed in as @guaranteed. // // CHECK-VJP-LABEL: @{{.*}}testOwnedVector{{.*}}__adjoint_src_0_wrt_0 -// CHECK-VJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, {{%.*}} : $Vector, {{%.*}} : $Vector): +// CHECK-VJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0): // CHECK-VJP: [[PULLBACK0:%.*]] = struct_extract [[PRIMAL_VALUES]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0.pullback_0 // CHECK-VJP-NOT: release_value [[PULLBACK0]] // CHECK-VJP-NOT: release_value [[PRIMAL_VALUES]] @@ -62,7 +62,7 @@ _ = pullback(at: Vector.zero, in: testOwnedVector) // The adjoint should not release primal values because they are passed in as @guaranteed. // // CHECK-NOVJP-LABEL: @{{.*}}testOwnedVector{{.*}}__adjoint_src_0_wrt_0 -// CHECK-NOVJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, {{%.*}} : $Vector, {{%.*}} : $Vector): +// CHECK-NOVJP: bb0({{%.*}} : $Vector, [[PRIMAL_VALUES:%.*]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0): // CHECK-NOVJP: [[PV0:%.*]] = struct_extract [[PRIMAL_VALUES]] : ${{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__Type__src_0_wrt_0.v_0 // CHECK-NOVJP-NOT: release_value [[PV0]] // CHECK-NOVJP-NOT: release_value [[PRIMAL_VALUES]]