diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index b0e454a9b4a00..3697b101028a9 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1537,8 +1537,8 @@ class DifferentiableAttr final ParsedAutoDiffParameter> { friend TrailingObjects; - /// Whether this function is linear (optional). - bool linear; + /// Whether this function is linear. + bool Linear; /// The number of parsed parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; /// The JVP function. @@ -1621,7 +1621,7 @@ class DifferentiableAttr final return NumParsedParameters; } - bool isLinear() const { return linear; } + bool isLinear() const { return Linear; } TrailingWhereClause *getWhereClause() const { return WhereClause; } @@ -1676,8 +1676,8 @@ class DifferentiatingAttr final DeclNameWithLoc Original; /// The original function, resolved by the type checker. FuncDecl *OriginalFunction = nullptr; - /// Whether this function is linear (optional). - bool linear; + /// Whether this function is linear. + bool Linear; /// The number of parsed parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; /// The differentiation parameters' indices, resolved by the type checker. @@ -1706,7 +1706,7 @@ class DifferentiatingAttr final DeclNameWithLoc getOriginal() const { return Original; } - bool isLinear() const { return linear; } + bool isLinear() const { return Linear; } FuncDecl *getOriginalFunction() const { return OriginalFunction; } void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; } diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index bdb12fa3c3df5..7dd347fce023d 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1447,7 +1447,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, Optional vjp, TrailingWhereClause *clause) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), - linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)), + Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause) { std::copy(params.begin(), params.end(), getTrailingObjects()); @@ -1461,7 +1461,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, Optional vjp, GenericSignature *derivativeGenSig) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), - linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), + Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) { setDerivativeGenericSignature(context, derivativeGenSig); } @@ -1530,7 +1530,7 @@ DifferentiatingAttr::DifferentiatingAttr( DeclNameWithLoc original, bool linear, ArrayRef params) : DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit), - Original(std::move(original)), linear(linear), + Original(std::move(original)), Linear(linear), NumParsedParameters(params.size()) { std::copy(params.begin(), params.end(), getTrailingObjects()); @@ -1540,7 +1540,7 @@ DifferentiatingAttr::DifferentiatingAttr( ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange, DeclNameWithLoc original, bool linear, IndexSubset *indices) : DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit), - Original(std::move(original)), linear(linear), ParameterIndices(indices) { + Original(std::move(original)), Linear(linear), ParameterIndices(indices) { } DifferentiatingAttr * diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 0076ade2d4e00..f5b9fdf4411d5 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -770,7 +770,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, paramIndices, origFnType); SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices); assert(silDiffAttr->getIndices() == indices && - "Expected matching @differentiable and [differentiable]"); + "Expected matching @differentiable and [differentiable] indices"); auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule()); auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType( @@ -875,10 +875,6 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) { if (!hasFunction(thunk)) emitNativeToForeignThunk(thunk); } - - // TODO: Handle SILGen for `@differentiating` attribute. - // Tentative solution: SILGen derivative function normally but also emit - // mangled redirection thunk for retroactive differentiation. } void SILGenModule::emitFunction(FuncDecl *fd) { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 28424510635c6..7c5646392535c 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -353,8 +353,8 @@ struct DifferentiationInvoker { Kind kind; union Value { /// The instruction associated with the `DifferentiableFunctionInst` case. - DifferentiableFunctionInst *adFuncInst; - Value(DifferentiableFunctionInst *inst) : adFuncInst(inst) {} + DifferentiableFunctionInst *diffFuncInst; + Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {} /// The parent `apply` instruction and `[differentiable]` attribute /// associated with the `IndirectDifferentiation` case. @@ -385,7 +385,7 @@ struct DifferentiationInvoker { DifferentiableFunctionInst *getDifferentiableFunctionInst() const { assert(kind == Kind::DifferentiableFunctionInst); - return value.adFuncInst; + return value.diffFuncInst; } std::pair diff --git a/stdlib/public/core/DifferentiationSupport.swift b/stdlib/public/core/DifferentiationSupport.swift index 92c4e11f4d73e..510d20fc94b19 100644 --- a/stdlib/public/core/DifferentiationSupport.swift +++ b/stdlib/public/core/DifferentiationSupport.swift @@ -919,7 +919,7 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { @differentiating(+) @usableFromInline internal static func _jvpAdd( lhs: AnyDerivative, rhs: AnyDerivative - ) -> (value: AnyDerivative, + ) -> (value: AnyDerivative, differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)) { return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs }) }