Merged
74 changes: 72 additions & 2 deletions include/swift/AST/Types.h
@@ -4303,12 +4303,82 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

CanSILFunctionType getWithoutDifferentiability();

/// Returns the type of the derivative function.
/// Returns the type of the derivative function for the given parameter
/// indices, result index, derivative function kind, derivative function
/// generic signature (optional), and other auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - The result corresponding to the result index must conform to
/// `Differentiable`.
///
/// Typing rules, given:
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
///
/// Terminology:
/// - The derivative of a `Differentiable`-conforming type has the
/// `TangentVector` associated type. `TangentVector` is abbreviated as `Tan`
/// below.
/// - "wrt" parameters refers to parameters indicated by the parameter
/// indices.
/// - "wrt" result refers to the result indicated by the result index.
///
/// JVP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a differential function, which
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
///
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
/// original results | derivatives wrt params | derivative wrt result
///
/// VJP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a pullback function, which
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
///
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
/// original results | derivative wrt result | derivatives wrt params
///
/// The JVP/VJP generic signature is a "constrained" version of the given
/// `derivativeFunctionGenericSignature` if specified. Otherwise, it is a
/// "constrained" version of the original generic signature. A "constrained"
/// generic signature requires all "wrt" parameters to conform to
/// `Differentiable`; this is important for correctness.
///
/// Other properties of the original function type are copied exactly:
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error
/// result, witness method conformance, etc.
///
/// Special cases:
/// - Reabstraction thunks have special derivative type calculation. The
/// original function-typed last parameter is transformed into a
/// `@differentiable` function-typed parameter in the derivative type. This
/// is necessary for the differentiation transform to support reabstraction
/// thunk differentiation because the function argument is opaque and cannot
/// be differentiated. Instead, the argument is made `@differentiable` and
/// reabstraction thunk JVP/VJP callers are responsible for passing a
/// `@differentiable` function.
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
/// derivative approaches. The last argument can simply be a
/// corresponding derivative function, instead of a `@differentiable`
/// function - this is more direct. It may be possible to implement
/// reabstraction thunk derivatives using "reabstraction thunks for
/// the original function's derivative", avoiding extra code generation.
///
/// Caveats:
/// - We may support multiple result indices instead of a single result index
/// eventually. At the SIL level, this enables differentiating wrt multiple
/// function results. At the Swift level, this enables differentiating wrt
/// multiple tuple elements for tuple-returning functions.
CanSILFunctionType getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
bool isReabstractionThunk = false);
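The JVP and VJP typing rules in the doc comment above can be pictured with a small standalone C++ sketch. This is not compiler code: it assumes a toy function `mul(x, y) = x * y` on `double`, where the tangent type (`Tan`) is `double` itself, and the names `mul`, `mulJVP`, `mulVJP`, `Differential`, and `Pullback` are hypothetical and chosen only for illustration.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Toy stand-ins (assumption): for `double`, `Tan` is `double` itself.
// Differential: (T0.Tan, T1.Tan) -> R0.Tan
using Differential = std::function<double(double, double)>;
// Pullback: (R0.Tan) -> (T0.Tan, T1.Tan)
using Pullback = std::function<std::pair<double, double>(double)>;

// Original function: (T0, T1) -> R0.
inline double mul(double x, double y) { return x * y; }

// JVP shape: (T0, T1) -> (R0, (T0.Tan, T1.Tan) -> R0.Tan).
// Returns the original result followed by a differential.
inline std::pair<double, Differential> mulJVP(double x, double y) {
  return {x * y, [x, y](double dx, double dy) { return dx * y + x * dy; }};
}

// VJP shape: (T0, T1) -> (R0, (R0.Tan) -> (T0.Tan, T1.Tan)).
// Returns the original result followed by a pullback.
inline std::pair<double, Pullback> mulVJP(double x, double y) {
  return {x * y,
          [x, y](double seed) { return std::make_pair(seed * y, seed * x); }};
}
```

Both derivative functions take the original parameters and return the original result first. Only the trailing closure differs: the differential maps parameter tangents to a result tangent, and the pullback maps a result tangent back to parameter tangents.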

It would be helpful to have a documentation comment explaining what the derivative function of a reabstraction thunk is.

If I understand correctly, it is "a function that, when partially applied to the original function, becomes the derivative function of the reabstracted function." Or maybe instead it "becomes a reabstraction of the derivative function". I haven't thought carefully enough to figure out which one it is, or if there is a difference, but maybe there are situations where the difference is important.

Another question I have is: Why does the derivative function of the reabstraction function take the original function as argument rather than the derivative function? That might be a possible alternative with advantages or disadvantages. I haven't thought enough about it to determine whether it's actually possible or whether it's a better alternative, I'm just wondering if you have. A comment somewhere about why we decided one way or the other would be useful.

Contributor Author

> It would be helpful to have a documentation comment explaining what the derivative function of a reabstraction thunk is.

Added doc comment for SILFunctionType::getAutoDiffDerivativeFunctionType in 6eb7a07:

  /// Special cases:
  /// - Reabstraction thunks have special derivative type calculation. The
  ///   original function-typed last parameter is transformed into a
  ///   `@differentiable` function-typed parameter in the derivative type. This
  ///   is necessary for the differentiation transform to support reabstraction
  ///   thunk differentiation because the function argument is opaque and cannot
  ///   be differentiated. Instead, the argument is made `@differentiable` and
  ///   reabstraction thunk JVP/VJP callers are responsible for passing a
  ///   `@differentiable` function.

Also added comment in reapplyFunctionConversion in Differentiation.cpp explaining partial_apply reapplication for reabstraction thunk derivatives.

> If I understand correctly, it is "a function that, when partially applied to the original function, becomes the derivative function of the reabstracted function." Or maybe instead it "becomes a reabstraction of the derivative function". I haven't thought carefully enough to figure out which one it is, or if there is a difference, but maybe there are situations where the difference is important.

I believe your understanding is correct. I made a short write-up about the reabstraction thunk differentiation approach.

> Another question I have is: Why does the derivative function of the reabstraction thunk take the original function as argument rather than the derivative function?

That's a good point! A reabstraction thunk derivative only needs the original function's corresponding derivative, not any other @differentiable function components. I think that approach is more efficient and semantically clear, so I'll pursue it now.

Contributor Author

Actually, to focus on retrodiff as a priority, how about deferring optimizations and using the current approach for now? TF-1036 tracks reabstraction thunk derivative generation optimizations.


sounds good

nice doc comment!
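As a rough illustration of the approach discussed in this thread, the sketch below models a reabstraction thunk and its VJP in plain C++. It is a toy under stated assumptions, not the compiler's representation: `DifferentiableFn` is a hypothetical stand-in for a `@differentiable` function value (the original function bundled with its VJP), and `thunk`, `thunkVJP`, and `makeSquare` are invented names. The "reabstraction" here is simply calling an indirect-argument function with a direct argument.

```cpp
#include <cassert>
#include <functional>
#include <utility>

using Pullback = std::function<double(double)>;

// Hypothetical stand-in for a `@differentiable` function value: the original
// function bundled with its VJP. Both take their argument indirectly.
struct DifferentiableFn {
  std::function<double(const double &)> original;
  std::function<std::pair<double, Pullback>(const double &)> vjp;
};

// Toy reabstraction thunk: takes its argument directly and forwards it to a
// function-typed last parameter that expects an indirect argument.
inline double thunk(double x,
                    const std::function<double(const double &)> &fn) {
  return fn(x);
}

// Toy thunk VJP: the function-typed last parameter becomes a differentiable
// bundle. The thunk cannot differentiate the opaque `fn` itself, so the
// caller is responsible for supplying the derivative.
inline std::pair<double, Pullback> thunkVJP(double x,
                                            const DifferentiableFn &fn) {
  return fn.vjp(x);
}

// Example bundle for f(x) = x * x, with pullback seed -> 2 * x * seed.
inline DifferentiableFn makeSquare() {
  return {[](const double &x) { return x * x; },
          [](const double &x) {
            return std::make_pair(
                x * x, Pullback([x](double seed) { return 2 * x * seed; }));
          }};
}
```

In this sketch `thunkVJP` only ever touches `fn.vjp`, which matches the reviewer's observation that the thunk derivative needs just the corresponding derivative function. Passing only that component is the optimization deferred to TF-1036.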


/// Returns the type of the transpose function.
CanSILFunctionType getAutoDiffTransposeFunctionType(
3 changes: 2 additions & 1 deletion include/swift/SILOptimizer/Utils/Differentiation/ADContext.h
@@ -136,7 +136,8 @@ class ADContext {
}

/// Adds the given `differentiable_function` instruction to the worklist.
void addDifferentiableFunctionInst(DifferentiableFunctionInst* dfi) {
void
addDifferentiableFunctionInstToWorklist(DifferentiableFunctionInst *dfi) {
differentiableFunctionInsts.push_back(dfi);
}

16 changes: 10 additions & 6 deletions include/swift/SILOptimizer/Utils/Differentiation/Common.h
@@ -94,12 +94,16 @@ template<class Inst>
Inst *peerThroughFunctionConversions(SILValue value) {
if (auto *inst = dyn_cast<Inst>(value))
return inst;
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
return peerThroughFunctionConversions<Inst>(thinToThick->getOperand());
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
return peerThroughFunctionConversions<Inst>(convertFn->getOperand());
if (auto *partialApply = dyn_cast<PartialApplyInst>(value))
return peerThroughFunctionConversions<Inst>(partialApply->getCallee());
if (auto *cvi = dyn_cast<CopyValueInst>(value))
return peerThroughFunctionConversions<Inst>(cvi->getOperand());
if (auto *bbi = dyn_cast<BeginBorrowInst>(value))
return peerThroughFunctionConversions<Inst>(bbi->getOperand());
if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(value))
return peerThroughFunctionConversions<Inst>(tttfi->getOperand());
if (auto *cfi = dyn_cast<ConvertFunctionInst>(value))
return peerThroughFunctionConversions<Inst>(cfi->getOperand());
if (auto *pai = dyn_cast<PartialApplyInst>(value))
return peerThroughFunctionConversions<Inst>(pai->getCallee());
return nullptr;
}
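The recursion pattern in `peerThroughFunctionConversions` can be shown outside of SIL with a toy value graph. The sketch below is an assumption-laden model, not the SIL API: `Node`, `Node::Kind`, and `peerThroughConversions` are hypothetical names, and a single `operand` pointer stands in for both `getOperand()` and `getCallee()`.

```cpp
#include <cassert>

// Toy value node (assumption): each wrapper has exactly one underlying value.
struct Node {
  enum class Kind {
    FunctionRef,
    CopyValue,
    BeginBorrow,
    ThinToThick,
    ConvertFunction,
    PartialApply
  };
  Kind kind;
  const Node *operand;
};

// Returns the first node of kind `target` reachable by looking through
// copy, borrow, and function-conversion wrappers, or nullptr if none exists.
inline const Node *peerThroughConversions(const Node *value,
                                          Node::Kind target) {
  if (!value)
    return nullptr;
  // Check for a match first, mirroring the leading dyn_cast<Inst> above.
  if (value->kind == target)
    return value;
  switch (value->kind) {
  case Node::Kind::CopyValue:
  case Node::Kind::BeginBorrow:
  case Node::Kind::ThinToThick:
  case Node::Kind::ConvertFunction:
  case Node::Kind::PartialApply:
    return peerThroughConversions(value->operand, target);
  default:
    // Not a wrapper this helper knows how to look through.
    return nullptr;
  }
}
```

Because the match is tested before unwrapping, asking for a wrapper kind such as `CopyValue` returns the wrapper itself rather than looking through it.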

28 changes: 22 additions & 6 deletions lib/SIL/SILFunctionType.cpp
@@ -216,7 +216,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFnGenSig) {
CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) {
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
@@ -341,6 +341,22 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
}
}

SmallVector<SILParameterInfo, 4> newParameters;
newParameters.reserve(getNumParameters());
newParameters.append(getParameters().begin(), getParameters().end());
// Reabstraction thunks have a function-typed parameter (the function to
// reabstract) as their last parameter. Reabstraction thunk JVPs/VJPs have a
// `@differentiable` function-typed last parameter instead.
if (isReabstractionThunk) {
assert(!parameterIndices->contains(getNumParameters() - 1) &&
"Function-typed parameter should not be wrt");
auto fnParam = newParameters.back();
auto fnParamType = dyn_cast<SILFunctionType>(fnParam.getInterfaceType());
assert(fnParamType);
auto diffFnType = fnParamType->getWithDifferentiability(
DifferentiabilityKind::Normal, parameterIndices);
newParameters.back() = fnParam.getWithInterfaceType(diffFnType);
}
SmallVector<SILResultInfo, 4> newResults;
newResults.reserve(getNumResults() + 1);
for (auto &result : getResults()) {
@@ -350,11 +366,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
}
newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
ResultConvention::Owned});
return SILFunctionType::get(derivativeFnGenSig, getExtInfo(),
getCoroutineKind(), getCalleeConvention(),
getParameters(), getYields(), newResults,
getOptionalErrorResult(), getSubstitutions(), isGenericSignatureImplied(), ctx,
getWitnessMethodConformanceOrInvalid());
return SILFunctionType::get(
derivativeFnGenSig, getExtInfo(), getCoroutineKind(),
getCalleeConvention(), newParameters, getYields(), newResults,
getOptionalErrorResult(), getSubstitutions(), isGenericSignatureImplied(),
ctx, getWitnessMethodConformanceOrInvalid());
}

CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
7 changes: 5 additions & 2 deletions lib/SIL/SILInstructions.cpp
@@ -768,9 +768,12 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
auto *parameterIndices = witness->getParameterIndices();
auto *resultIndices = witness->getResultIndices();
if (auto derivativeKind = witnessKind.getAsDerivativeFunctionKind()) {
bool isReabstractionThunk =
witness->getOriginalFunction()->isThunk() == IsReabstractionThunk;
auto diffFnTy = fnTy->getAutoDiffDerivativeFunctionType(
parameterIndices, *resultIndices->begin(), *derivativeKind, module.Types,
LookUpConformanceInModule(module.getSwiftModule()), witnessCanGenSig);
parameterIndices, *resultIndices->begin(), *derivativeKind,
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
witnessCanGenSig, isReabstractionThunk);
return SILType::getPrimitiveObjectType(diffFnTy);
}
assert(witnessKind ==
10 changes: 7 additions & 3 deletions lib/SIL/SILVerifier.cpp
@@ -5412,7 +5412,9 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
// parameter/result conventions in lowered SIL.
if (M.getStage() == SILStage::Lowered)
return;
auto origFnType = getOriginalFunction()->getLoweredFunctionType();
auto *origFn = getOriginalFunction();
auto origFnType = origFn->getLoweredFunctionType();
bool origIsReabstractionThunk = origFn->isThunk() == IsReabstractionThunk;
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = getDerivativeGenericSignature())
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
@@ -5438,7 +5440,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
AutoDiffDerivativeFunctionKind::JVP, M.Types,
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig,
origIsReabstractionThunk);
requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
"JVP type does not match expected JVP type");
}
@@ -5448,7 +5451,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
AutoDiffDerivativeFunctionKind::VJP, M.Types,
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig,
origIsReabstractionThunk);
requireSameType(vjp->getLoweredFunctionType(), expectedVJPType,
"VJP type does not match expected VJP type");
}