diff --git a/docs/SIL.rst b/docs/SIL.rst index 98839126da255..d9a068a6ab13f 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5609,24 +5609,23 @@ differentiable_function sil-differentiable-function-parameter-indices? sil-differentiable-function-order? sil-value ':' sil-type - sil-differentiable-function-associated-functions-clause? + sil-differentiable-function-derivative-functions-clause? sil-differentiable-function-parameter-indices ::= '[' 'wrt' [0-9]+ (',', [0-9]+)* ']' sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']' - sil-differentiable-associated-functions-clause ::= - 'with' sil-differentiable-associated-function-list - (',' sil-differentiable-associated-function-list)* - sil-differentiable-function-associated-function-list ::= + sil-differentiable-derivative-functions-clause ::= + 'with' sil-differentiable-derivative-function-list + (',' sil-differentiable-derivative-function-list)* + sil-differentiable-function-derivative-function-list ::= '{' sil-value ',' sil-value '}' differentiable_function [wrt 0] %0 : $(T) -> T \ with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)} -Bundles a function with its associated differentiation functions into a -``@differentiable`` function. There are two associated functions: -a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP) -function. +Bundles a function with its derivative functions into a ``@differentiable`` +function. There are two derivative functions: a Jacobian-vector products (JVP) +function and a vector-Jacobian products (VJP) function. ``[wrt ...]`` specifies parameter indices that the original function is differentiable with respect to. When not specified, it defaults to all @@ -5634,10 +5633,10 @@ parameters. A ``with`` clause specifies the differentiation functions associated with the original function. When a ``with`` clause is not specified, the first -operand will be differentiated to produce associated functions, and a ``with`` +operand will be differentiated to produce derivative functions, and a ``with`` clause will be added to the instruction. -In raw SIL, it is optional to provide an associated function ``with`` clause. +In raw SIL, it is optional to provide a derivative function ``with`` clause. In canonical SIL, a ``with`` clause is mandatory. @@ -5660,7 +5659,7 @@ differentiable_function_extract differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T -Extracts the original function or an associated function from the given +Extracts the original function or a derivative function from the given ``@differentiable`` function. It must be provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``. diff --git a/include/swift/AST/ASTMangler.h b/include/swift/AST/ASTMangler.h index 171689f5d7368..6ea3807732d3d 100644 --- a/include/swift/AST/ASTMangler.h +++ b/include/swift/AST/ASTMangler.h @@ -155,12 +155,12 @@ class ASTMangler : public Mangler { ModuleDecl *Module); // SWIFT_ENABLE_TENSORFLOW - // Mangle the autodiff associated function (JVP/VJP) with the given: + // Mangle the derivative function (JVP/VJP) with the given: // - Mangled original function name. - // - Associated function kind. + // - Derivative function kind. // - Parameter/result indices. - std::string mangleAutoDiffAssociatedFunctionHelper( - StringRef name, AutoDiffAssociatedFunctionKind kind, + std::string mangleAutoDiffDerivativeFunctionHelper( + StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices); // SWIFT_ENABLE_TENSORFLOW diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 4a34d6e5db257..6bf3b282041a8 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1555,7 +1555,7 @@ class DifferentiableAttr final AutoDiffIndexSubset *ParameterIndices = nullptr; /// The trailing where clause (optional). TrailingWhereClause *WhereClause = nullptr; - /// The generic signature for autodiff associated functions. Resolved by the + /// The generic signature for autodiff derivative functions. Resolved by the /// type checker based on the original function's generic signature and the /// attribute's where clause requirements. This is set only if the attribute /// has a where clause. @@ -1650,10 +1650,10 @@ class DifferentiableAttr final // Print the attribute to the given stream. // If `omitWrtClause` is true, omit printing the `wrt:` clause. - // If `omitAssociatedFunctions` is true, omit printing associated functions. + // If `omitDerivativeFunctions` is true, omit printing derivative functions. void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false, - bool omitAssociatedFunctions = false) const; + bool omitDerivativeFunctions = false) const; static bool classof(const DeclAttribute *DA) { return DA->getKind() == DAK_Differentiable; diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 3534b87e35697..6334a6c5d05fc 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -431,8 +431,8 @@ struct AutoDiffLinearMapKind { operator innerty() const { return rawValue; } }; -/// The kind of an associated function. -struct AutoDiffAssociatedFunctionKind { +/// The kind of a derivative function. +struct AutoDiffDerivativeFunctionKind { enum innerty : uint8_t { // The Jacobian-vector products function. JVP = 0, @@ -440,11 +440,11 @@ struct AutoDiffAssociatedFunctionKind { VJP = 1 } rawValue; - AutoDiffAssociatedFunctionKind() = default; - AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {} - AutoDiffAssociatedFunctionKind(AutoDiffLinearMapKind linMapKind) + AutoDiffDerivativeFunctionKind() = default; + AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {} + AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind) : rawValue(static_cast(linMapKind.rawValue)) {} - explicit AutoDiffAssociatedFunctionKind(StringRef string); + explicit AutoDiffDerivativeFunctionKind(StringRef string); operator innerty() const { return rawValue; } AutoDiffLinearMapKind getLinearMapKind() { return (AutoDiffLinearMapKind::innerty)rawValue; @@ -452,27 +452,27 @@ struct AutoDiffAssociatedFunctionKind { }; /// In conjunction with the original function declaration, identifies an -/// autodiff associated function. +/// autodiff derivative function. /// /// Is uniquely allocated within an ASTContext so that it can be hashed and /// compared by opaque pointer value. -class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { - const AutoDiffAssociatedFunctionKind kind; +class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { + const AutoDiffDerivativeFunctionKind kind; AutoDiffIndexSubset *const parameterIndices; - AutoDiffAssociatedFunctionIdentifier( - AutoDiffAssociatedFunctionKind kind, + AutoDiffDerivativeFunctionIdentifier( + AutoDiffDerivativeFunctionKind kind, AutoDiffIndexSubset *parameterIndices) : kind(kind), parameterIndices(parameterIndices) {} public: - AutoDiffAssociatedFunctionKind getKind() const { return kind; } + AutoDiffDerivativeFunctionKind getKind() const { return kind; } AutoDiffIndexSubset *getParameterIndices() const { return parameterIndices; } - static AutoDiffAssociatedFunctionIdentifier *get( - AutoDiffAssociatedFunctionKind kind, + static AutoDiffDerivativeFunctionIdentifier *get( + AutoDiffDerivativeFunctionKind kind, AutoDiffIndexSubset *parameterIndices, ASTContext &C); void Profile(llvm::FoldingSetNodeID &ID) { @@ -520,15 +520,15 @@ AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices, /// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`. /// Returns true if the function name is parsed successfully. bool getBuiltinAutoDiffApplyConfig(StringRef operationName, - AutoDiffAssociatedFunctionKind &kind, + AutoDiffDerivativeFunctionKind &kind, unsigned &arity, bool &rethrows); -/// Computes the correct linkage for an associated function given the linkage of +/// Computes the correct linkage for a derivative function given the linkage of /// the original function. If the original linkage is not external and -/// `isAssocFnExported` is true, use the original function's linkage. Otherwise, -/// return hidden linkage. -SILLinkage getAutoDiffAssociatedFunctionLinkage(SILLinkage originalLinkage, - bool isAssocFnExported); +/// `isDerivativeFnExported` is true, use the original function's linkage. +/// Otherwise, return hidden linkage. +SILLinkage getAutoDiffDerivativeFunctionLinkage(SILLinkage originalLinkage, + bool isDerivativeFnExported); } // end namespace autodiff diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 68c88f1907dd0..1df71ebb700a6 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1597,15 +1597,15 @@ ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken, ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken, "expected the index of a parameter to differentiate with respect to", ()) ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken, - "expected '{' to start an associated function list", ()) + "expected '{' to start a derivative function list", ()) ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken, - "expected ',' between operands in an associated function list", ()) + "expected ',' between operands in a derivative function list", ()) ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken, - "expected '}' to start an associated function list", ()) + "expected '}' to start a derivative function list", ()) ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken, "the number of operand lists does not match the order", ()) ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken, - "expected an associated function kind attribute, e.g. '[jvp]'", ()) + "expected a derivative function kind attribute, e.g. '[jvp]'", ()) ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, "expected an operand of a function type", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 7a7c5ccf037d1..621329ef94db5 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2859,8 +2859,7 @@ ERROR(implements_attr_protocol_not_conformed_to,none, ERROR(differentiable_attr_void_result,none, "cannot differentiate void function %0", (DeclName)) ERROR(differentiable_attr_associated_function_protocol,none, - "cannot specify associated differentiation function on protocol " - "requirement", ()) + "cannot specify derivative functions on protocol requirements", ()) ERROR(differentiable_attr_overload_not_found,none, "%0 does not have expected type %1", (DeclName, Type)) ERROR(differentiable_attr_no_currying,none, @@ -2874,7 +2873,7 @@ NOTE(differentiable_attr_duplicate_note,none, ERROR(differentiable_attr_function_not_same_type_context,none, "%0 is not defined in the current type context", (DeclName)) ERROR(differentiable_attr_specified_not_function,none, - "%0 is not a function to be used as associated differentiation function", + "%0 is not a function to be used as derivative function", (DeclName)) ERROR(differentiable_attr_class_derivative_not_final,none, "class member derivative must be final", ()) @@ -2882,9 +2881,9 @@ ERROR(differentiable_attr_ambiguous_function_identifier,none, "ambiguous or overloaded identifier %0 cannot be used in '@differentiable' " "attribute", (DeclName)) ERROR(differentiable_attr_invalid_access,none, - "associated differentiation function %0 is required to either be public " - "or @usableFromInline because the original function %1 is public or " - "@usableFromInline", (DeclName, DeclName)) + "derivative function %0 is required to either be public or " + "'@usableFromInline' because the original function %1 is public or " + "'@usableFromInline'", (DeclName, DeclName)) ERROR(differentiable_attr_result_not_differentiable,none, "can only differentiate functions with results that conform to " "'Differentiable', but %0 does not conform to 'Differentiable'", (Type)) diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index b7084d437d654..3f40e111c2148 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3100,10 +3100,10 @@ class AnyFunctionType : public TypeBase { // SWIFT_ENABLE_TENSORFLOW /// Given `indices` and `kind`, calculates the type of the corresponding - /// autodiff associated function. + /// autodiff derivative function. /// /// By default, if the original type has a self parameter list and parameter - /// indices include self, the computed associated function type will return a + /// indices include self, the computed derivative function type will return a /// linear map taking/returning self's tangent *last* instead of first, for /// consistency with SIL. /// @@ -3114,18 +3114,18 @@ class AnyFunctionType : public TypeBase { /// \note The original function type (`self`) need not be `@differentiable`. /// The resulting function will preserve all `ExtInfo` of the original /// function, including `@differentiable`. - AnyFunctionType *getAutoDiffAssociatedFunctionType( + AnyFunctionType *getAutoDiffDerivativeFunctionType( AutoDiffIndexSubset *indices, unsigned resultIndex, - AutoDiffAssociatedFunctionKind kind, + AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenericSignature = nullptr, bool makeSelfParamFirst = false); - /// Given the type of an autodiff associated function, returns the + /// Given the type of an autodiff derivative function, returns the /// corresponding original function type. AnyFunctionType *getAutoDiffOriginalFunctionType(); - /// Given the type of a transposing associated function, returns the + /// Given the type of a transposing derivative function, returns the /// corresponding original function type. AnyFunctionType * getTransposeOriginalFunctionType(TransposingAttr *attr, @@ -4222,11 +4222,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// Returns the type of a differentiation function that is associated with /// a function of this type. - CanSILFunctionType getAutoDiffAssociatedFunctionType( + CanSILFunctionType getAutoDiffDerivativeFunctionType( AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, - AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC, + AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance, - CanGenericSignature associatedFunctionGenericSignature = nullptr); + CanGenericSignature derivativeFunctionGenericSignature = nullptr); /// Returns a bit vector that specifices which parameters you can /// differentiate with respect to for this differentiable function type. (e.g. diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 084aebf6eeee5..526f50ada16ad 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -973,7 +973,7 @@ void SILCloner::visitDifferentiableFunctionInst( Optional> derivativeFns = None; if (Inst->hasDerivativeFunctions()) derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()), - getOpValue(Inst->getVJPFunction())); + getOpValue(Inst->getVJPFunction())); recordClonedInstruction( Inst, getBuilder().createDifferentiableFunction( getOpLocation(Inst->getLoc()), Inst->getParameterIndices(), diff --git a/include/swift/SIL/SILDeclRef.h b/include/swift/SIL/SILDeclRef.h index 039aa42ff78e6..c009f9c8205ff 100644 --- a/include/swift/SIL/SILDeclRef.h +++ b/include/swift/SIL/SILDeclRef.h @@ -35,7 +35,7 @@ namespace swift { class AbstractFunctionDecl; class AbstractClosureExpr; // SWIFT_ENABLE_TENSORFLOW - class AutoDiffAssociatedFunctionIdentifier; + class AutoDiffDerivativeFunctionIdentifier; class ValueDecl; class FuncDecl; class ClosureExpr; @@ -157,22 +157,22 @@ struct SILDeclRef { // SWIFT_ENABLE_TENSORFLOW /// When this is non-null, it modifies the SILDeclRef to refer to the - /// corresponding autodiff associated function. - AutoDiffAssociatedFunctionIdentifier *autoDiffAssociatedFunctionIdentifier; + /// corresponding autodiff derivative function. + AutoDiffDerivativeFunctionIdentifier *autoDiffDerivativeFunctionIdentifier; /// Produces a null SILDeclRef. SILDeclRef() : loc(), kind(Kind::Func), isCurried(0), isForeign(0), isDirectReference(0), // SWIFT_ENABLE_TENSORFLOW defaultArgIndex(0), - autoDiffAssociatedFunctionIdentifier(nullptr) {} + autoDiffDerivativeFunctionIdentifier(nullptr) {} /// Produces a SILDeclRef of the given kind for the given decl. explicit SILDeclRef(ValueDecl *decl, Kind kind, bool isCurried = false, // SWIFT_ENABLE_TENSORFLOW bool isForeign = false, - AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId = + AutoDiffDerivativeFunctionIdentifier *autoDiffFuncId = nullptr); /// Produces a SILDeclRef for the given ValueDecl or @@ -308,8 +308,8 @@ struct SILDeclRef { && isDirectReference == rhs.isDirectReference // SWIFT_ENABLE_TENSORFLOW && defaultArgIndex == rhs.defaultArgIndex - && autoDiffAssociatedFunctionIdentifier == - rhs.autoDiffAssociatedFunctionIdentifier; + && autoDiffDerivativeFunctionIdentifier == + rhs.autoDiffDerivativeFunctionIdentifier; } bool operator!=(SILDeclRef rhs) const { return !(*this == rhs); @@ -330,7 +330,7 @@ struct SILDeclRef { curried, willBeDirect, willBeForeign, // SWIFT_ENABLE_TENSORFLOW defaultArgIndex, - autoDiffAssociatedFunctionIdentifier); + autoDiffDerivativeFunctionIdentifier); } /// Returns the foreign (or native) entry point corresponding to the same @@ -340,7 +340,7 @@ struct SILDeclRef { return SILDeclRef(loc.getOpaqueValue(), kind, // SWIFT_ENABLE_TENSORFLOW isCurried, isDirectReference, foreign, defaultArgIndex, - autoDiffAssociatedFunctionIdentifier); + autoDiffDerivativeFunctionIdentifier); } SILDeclRef asDirectReference(bool direct = true) const { @@ -354,20 +354,20 @@ struct SILDeclRef { // SWIFT_ENABLE_TENSORFLOW /// Returns the entry point for the corresponding autodiff associated /// function. - SILDeclRef asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier *id) const { - assert(!autoDiffAssociatedFunctionIdentifier); + SILDeclRef asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier *id) const { + assert(!autoDiffDerivativeFunctionIdentifier); SILDeclRef r = *this; - r.autoDiffAssociatedFunctionIdentifier = id; + r.autoDiffDerivativeFunctionIdentifier = id; return r; } /// Returns the entry point for the original function corresponding to an - /// autodiff associated function. + /// autodiff derivative function. SILDeclRef asAutoDiffOriginalFunction() const { - assert(autoDiffAssociatedFunctionIdentifier); + assert(autoDiffDerivativeFunctionIdentifier); SILDeclRef r = *this; - r.autoDiffAssociatedFunctionIdentifier = nullptr; + r.autoDiffDerivativeFunctionIdentifier = nullptr; return r; } @@ -454,14 +454,14 @@ struct SILDeclRef { bool isForeign, // SWIFT_ENABLE_TENSORFLOW unsigned defaultArgIndex, - AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId) + AutoDiffDerivativeFunctionIdentifier *autoDiffFuncId) : loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind), isCurried(isCurried), isForeign(isForeign), isDirectReference(isDirectReference), // SWIFT_ENABLE_TENSORFLOW defaultArgIndex(defaultArgIndex), - autoDiffAssociatedFunctionIdentifier(autoDiffFuncId) + autoDiffDerivativeFunctionIdentifier(autoDiffFuncId) {} }; @@ -503,7 +503,7 @@ template<> struct DenseMapInfo { unsigned h5 = UnsignedInfo::getHashValue(Val.isDirectReference); // SWIFT_ENABLE_TENSORFLOW unsigned h6 = - PointerInfo::getHashValue(Val.autoDiffAssociatedFunctionIdentifier); + PointerInfo::getHashValue(Val.autoDiffDerivativeFunctionIdentifier); return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11) ^ (h6 << 13); } static bool isEqual(swift::SILDeclRef const &LHS, diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 5e9dd651b5b45..6969bfa65b421 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7846,8 +7846,8 @@ class TryApplyInst final // SWIFT_ENABLE_TENSORFLOW /// `differentiable_function` - given a function and differentiation indices and -/// its associated differentiation functions, create an `@differentiable` -/// function that represents a bundle of these functions and configurations. +/// its derivative functions, create an `@differentiable` function that +/// represents a bundle of these functions and configurations. class DifferentiableFunctionInst final : public InstructionBaseWithTrailingOperands< SILInstructionKind::DifferentiableFunctionInst, @@ -7912,16 +7912,16 @@ class DifferentiableFunctionInst final : } /// Returns the derivative function (JVP or VJP) that matches the given kind. - SILValue getDerivativeFunction(AutoDiffAssociatedFunctionKind kind) const { + SILValue getDerivativeFunction(AutoDiffDerivativeFunctionKind kind) const { switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: return getJVPFunction(); - case AutoDiffAssociatedFunctionKind::VJP: return getVJPFunction(); + case AutoDiffDerivativeFunctionKind::JVP: return getJVPFunction(); + case AutoDiffDerivativeFunctionKind::VJP: return getVJPFunction(); } } }; /// `differentiable_function_extract` - given an `@differentiable` function -/// representing a bundle of the original function and associated functions, +/// representing a bundle of the original function and derivative functions, /// extract the specified function. class DifferentiableFunctionExtractInst : public InstructionBase< @@ -7937,12 +7937,12 @@ class DifferentiableFunctionExtractInst Extractee() = default; Extractee(innerty rawValue) : rawValue(rawValue) {} explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {} - Extractee(AutoDiffAssociatedFunctionKind kind); + Extractee(AutoDiffDerivativeFunctionKind kind); explicit Extractee(StringRef name); operator innerty() const { return rawValue; } - Optional - getExtracteeAsAssociatedFunction() const; + Optional + getExtracteeAsDerivativeFunction() const; }; private: @@ -7961,8 +7961,8 @@ class DifferentiableFunctionExtractInst Extractee getExtractee() const { return extractee; } - AutoDiffAssociatedFunctionKind getAssociatedFunctionKind() const { - auto kind = extractee.getExtracteeAsAssociatedFunction(); + AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const { + auto kind = extractee.getExtracteeAsDerivativeFunction(); assert(kind); return *kind; } diff --git a/include/swift/SIL/SILVTableVisitor.h b/include/swift/SIL/SILVTableVisitor.h index e44dc4a4ca656..73b21c8a6eb87 100644 --- a/include/swift/SIL/SILVTableVisitor.h +++ b/include/swift/SIL/SILVTableVisitor.h @@ -91,15 +91,15 @@ template class SILVTableVisitor { // SWIFT_ENABLE_TENSORFLOW for (auto *DA : fd->getAttrs().getAttributes()) { auto constant = SILDeclRef(fd, SILDeclRef::Kind::Func); - auto jvpConstant = constant.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, + auto jvpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, DA->getParameterIndices(), fd->getASTContext())); maybeAddEntry(jvpConstant); - auto vjpConstant = constant.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, + auto vjpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, DA->getParameterIndices(), fd->getASTContext())); maybeAddEntry(vjpConstant); } @@ -118,15 +118,15 @@ template class SILVTableVisitor { // SWIFT_ENABLE_TENSORFLOW for (auto *DA : cd->getAttrs().getAttributes()) { auto constant = SILDeclRef(cd, SILDeclRef::Kind::Allocator); - auto jvpConstant = constant.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, + auto jvpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, DA->getParameterIndices(), cd->getASTContext())); maybeAddEntry(jvpConstant); - auto vjpConstant = constant.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, + auto vjpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, DA->getParameterIndices(), cd->getASTContext())); maybeAddEntry(vjpConstant); } diff --git a/include/swift/SIL/SILWitnessVisitor.h b/include/swift/SIL/SILWitnessVisitor.h index 6bac8e8c1eb1f..c9f9d9a9c6c63 100644 --- a/include/swift/SIL/SILWitnessVisitor.h +++ b/include/swift/SIL/SILWitnessVisitor.h @@ -181,13 +181,13 @@ template class SILWitnessVisitor : public ASTVisitor { asDerived().addMethod(funcDeclRef); for (auto *DA : func->getAttrs().getAttributes()) { - asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, + asDerived().addMethod(funcDeclRef.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, DA->getParameterIndices(), func->getASTContext()))); - asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction( - AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, + asDerived().addMethod(funcDeclRef.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, DA->getParameterIndices(), func->getASTContext()))); } } diff --git a/include/swift/SILOptimizer/Analysis/BottomUpIPAnalysis.h b/include/swift/SILOptimizer/Analysis/BottomUpIPAnalysis.h index 2c2763a9b4cc2..26a758fee1fcc 100644 --- a/include/swift/SILOptimizer/Analysis/BottomUpIPAnalysis.h +++ b/include/swift/SILOptimizer/Analysis/BottomUpIPAnalysis.h @@ -43,7 +43,7 @@ class BottomUpIPAnalysis : public SILAnalysis { /// analysis information for a function. /// This base class stores the administrative information needed for /// invalidation and updating the analysis. - /// In the following "this function" refers to the associated function. + /// In the following "this function" refers to the derivative function. template class FunctionInfoBase { public: diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 8f2be19379006..cc85ed01e907f 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -448,9 +448,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `AutoDiffIndexSubset` allocations. llvm::FoldingSet AutoDiffIndexSubsets; - /// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations. - llvm::FoldingSet - AutoDiffAssociatedFunctionIdentifiers; + /// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations. + llvm::FoldingSet + AutoDiffDerivativeFunctionIdentifiers; /// A cache of information about whether particular nominal types /// are representable in a foreign language. @@ -4827,12 +4827,12 @@ AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) { return newNode; } -AutoDiffAssociatedFunctionIdentifier * -AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind kind, AutoDiffIndexSubset *parameterIndices, +AutoDiffDerivativeFunctionIdentifier * +AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind kind, AutoDiffIndexSubset *parameterIndices, ASTContext &C) { assert(parameterIndices); - auto &foldingSet = C.getImpl().AutoDiffAssociatedFunctionIdentifiers; + auto &foldingSet = C.getImpl().AutoDiffDerivativeFunctionIdentifiers; llvm::FoldingSetNodeID id; id.AddInteger((unsigned)kind); id.AddPointer(parameterIndices); @@ -4842,9 +4842,9 @@ AutoDiffAssociatedFunctionIdentifier::get( if (existing) return existing; - void *mem = C.Allocate(sizeof(AutoDiffAssociatedFunctionIdentifier), - alignof(AutoDiffAssociatedFunctionIdentifier)); - auto *newNode = ::new (mem) AutoDiffAssociatedFunctionIdentifier( + void *mem = C.Allocate(sizeof(AutoDiffDerivativeFunctionIdentifier), + alignof(AutoDiffDerivativeFunctionIdentifier)); + auto *newNode = ::new (mem) AutoDiffDerivativeFunctionIdentifier( kind, parameterIndices); foldingSet.InsertNode(newNode, insertPos); diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index f69b889aab113..efd460f158243 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -379,8 +379,8 @@ std::string ASTMangler::mangleReabstractionThunkHelper( return finalize(); } -std::string ASTMangler::mangleAutoDiffAssociatedFunctionHelper( - StringRef name, AutoDiffAssociatedFunctionKind kind, +std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper( + StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices) { // TODO(TF-20): Make the mangling scheme robust. // TODO(TF-680): Mangle `@differentiable` atttribute requirements as well. @@ -388,10 +388,10 @@ std::string ASTMangler::mangleAutoDiffAssociatedFunctionHelper( Buffer << "AD__" << name << '_'; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: Buffer << "_jvp_"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: Buffer << "_vjp_"; break; } diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 5c291c11caf9c..93caf6775a4a4 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -480,7 +480,7 @@ static std::string getTransposingParametersClauseString( static void printDifferentiableAttrArguments( const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options, const Decl *D, bool omitWrtClause = false, - bool omitAssociatedFunctions = false) { + bool omitDerivativeFunctions = false) { // Create a temporary string for the attribute argument text. std::string attrArgText; llvm::raw_string_ostream stream(attrArgText); @@ -520,8 +520,8 @@ static void printDifferentiableAttrArguments( stream << diffParamsString; } } - // Print associated function names, unless they are to be omitted. - if (!omitAssociatedFunctions) { + // Print derivative function names, unless they are to be omitted. + if (!omitDerivativeFunctions) { // Print jvp function name, if specified. if (auto jvp = attr->getJVP()) { printCommaIfNecessary(); @@ -1517,11 +1517,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment( void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause, - bool omitAssociatedFunctions) const { + bool omitDerivativeFunctions) const { StreamPrinter P(OS); P << "@" << getAttrName(); printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause, - omitAssociatedFunctions); + omitDerivativeFunctions); } // SWIFT_ENABLE_TENSORFLOW diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 5fa24f94b9e43..7656b142ccca8 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -25,8 +25,8 @@ bool SILAutoDiffIndices::operator==(const SILAutoDiffIndices &other) const { return source == other.source && parameters == other.parameters; } -AutoDiffAssociatedFunctionKind:: -AutoDiffAssociatedFunctionKind(StringRef string) { +AutoDiffDerivativeFunctionKind:: +AutoDiffDerivativeFunctionKind(StringRef string) { Optional result = llvm::StringSwitch>(string) .Case("jvp", JVP).Case("vjp", VJP); @@ -124,16 +124,16 @@ void autodiff::getSubsetParameterTypes(AutoDiffIndexSubset *subset, } bool autodiff::getBuiltinAutoDiffApplyConfig( - StringRef operationName, AutoDiffAssociatedFunctionKind &kind, + StringRef operationName, AutoDiffDerivativeFunctionKind &kind, unsigned &arity, bool &rethrows) { if (!operationName.startswith("autodiffApply_")) return false; operationName = operationName.drop_front(strlen("autodiffApply_")); // Parse 'jvp' or 'vjp'. if (operationName.startswith("jvp")) - kind = AutoDiffAssociatedFunctionKind::JVP; + kind = AutoDiffDerivativeFunctionKind::JVP; else if (operationName.startswith("vjp")) - kind = AutoDiffAssociatedFunctionKind::VJP; + kind = AutoDiffDerivativeFunctionKind::VJP; operationName = operationName.drop_front(3); // Parse '_arity'. if (operationName.startswith("_arity")) { @@ -156,27 +156,27 @@ bool autodiff::getBuiltinAutoDiffApplyConfig( return operationName.empty(); } -SILLinkage autodiff::getAutoDiffAssociatedFunctionLinkage( - SILLinkage originalLinkage, bool isAssocFnExported) { +SILLinkage autodiff::getAutoDiffDerivativeFunctionLinkage( + SILLinkage originalLinkage, bool isDerivativeFnExported) { // If the original is defined externally, then the AD pass is just generating - // associated functions for use in the current module and therefore these - // associated functions should not be visible outside the module. + // derivative functions for use in the current module and therefore these + // derivative functions should not be visible outside the module. if (isAvailableExternally(originalLinkage)) return SILLinkage::Hidden; // If the original is public, then external modules may need to link the - // associated function. Return the linkage of the original function, unless - // the associated function is not exported (i.e. differentiation is not + // derivative function. Return the linkage of the original function, unless + // the derivative function is not exported (i.e. differentiation is not // explicitly requested via a `[differentiable]` attribute on the original // function). if (originalLinkage == SILLinkage::Public || originalLinkage == SILLinkage::PublicNonABI || originalLinkage == SILLinkage::Shared) - return isAssocFnExported ? originalLinkage : SILLinkage::Hidden; + return isDerivativeFnExported ? originalLinkage : SILLinkage::Hidden; // Otherwise, the original function is defined and used only in the current // module, so external modules will never try to access the associated - // function. Make the associated function hidden. + // function. Make the derivative function hidden. return SILLinkage::Hidden; } diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index f11b2bb3621da..9d1578fb28ae9 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -994,8 +994,8 @@ static ValueDecl *getGetObjCTypeEncodingOperation(ASTContext &Context, } // SWIFT_ENABLE_TENSORFLOW -static ValueDecl *getAutoDiffApplyAssociatedFunction( - ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind, +static ValueDecl *getAutoDiffApplyDerivativeFunction( + ASTContext &Context, Identifier Id, AutoDiffDerivativeFunctionKind kind, unsigned arity, bool rethrows) { assert(arity >= 1); // JVP: @@ -1035,17 +1035,17 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction( } }; // Eagerly build the type of the first arg, then use that to compute the type - // of the associated function type. + // of the derivative function type. auto *origFnTy = firstArgGen.build(builder)->castTo(); origFnTy = origFnTy->getWithoutDifferentiability()->withExtInfo( origFnTy->getExtInfo().withNoEscape(false)); auto *paramIndices = AutoDiffIndexSubset::get( Context, SmallBitVector(origFnTy->getNumParams(), true)); - // Generator for the resultant function type, i.e. the AD associated function. + // Generator for the resultant function type, i.e. the AD derivative function. BuiltinGenericSignatureBuilder::LambdaGenerator resultGen{ [=, &Context](BuiltinGenericSignatureBuilder &builder) -> Type { - auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType( + auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( paramIndices, /*resultIndex*/ 0, kind, LookUpConformanceInModule(Context.TheBuiltinModule)); return derivativeFnTy->getResult(); @@ -1840,13 +1840,13 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { } // SWIFT_ENABLE_TENSORFLOW if (OperationName.startswith("autodiffApply_")) { - AutoDiffAssociatedFunctionKind kind; + AutoDiffDerivativeFunctionKind kind; unsigned arity; bool rethrows; if (!autodiff::getBuiltinAutoDiffApplyConfig(OperationName, kind, arity, rethrows)) return nullptr; - return getAutoDiffApplyAssociatedFunction(Context, Id, kind, arity, + return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity, rethrows); } auto BV = llvm::StringSwitch(OperationName) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 72766ae027d3b..a9f03bb0e5fe7 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4567,9 +4567,9 @@ Optional TypeBase::getAutoDiffAssociatedTangentSpace( return cache(None); } -AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( +AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType( AutoDiffIndexSubset *indices, unsigned resultIndex, - AutoDiffAssociatedFunctionKind kind, LookupConformanceFn lookupConformance, + AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenSig, bool makeSelfParamFirst) { // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) @@ -4606,7 +4606,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( // JVP or VJP. Type closure; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: { + case AutoDiffDerivativeFunctionKind::JVP: { // closure is the JVP "differential": // (T.TangentVector...) -> (R.TangentVector...) SmallVector differentialParams; @@ -4635,7 +4635,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( closure = FunctionType::get(differentialParams, differentialResult); break; } - case AutoDiffAssociatedFunctionKind::VJP: { + case AutoDiffDerivativeFunctionKind::VJP: { // closure is the VJP "pullback": // (R.TangentVector...) -> (T.TangentVector...) SmallVector pullbackParams; @@ -4673,22 +4673,22 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( retElts.push_back(originalResult); retElts.push_back(closure); auto retTy = TupleType::get(retElts, ctx); - auto *associatedFunction = makeFunctionType( + auto *derivativeFunction = makeFunctionType( curryLevels.back(), curryLevels.back()->getParams(), retTy, curryLevels.size() == 1 ? whereClauseGenSig : nullptr); - // Wrap the associated function type in additional curry levels. + // Wrap the derivative function type in additional curry levels. auto curryLevelsWithoutLast = ArrayRef(curryLevels).drop_back(1); for (auto pair : enumerate(reversed(curryLevelsWithoutLast))) { unsigned i = pair.index(); AnyFunctionType *curryLevel = pair.value(); - associatedFunction = makeFunctionType( - curryLevel, curryLevel->getParams(), associatedFunction, + derivativeFunction = makeFunctionType( + curryLevel, curryLevel->getParams(), derivativeFunction, i == curryLevelsWithoutLast.size() - 1 ? whereClauseGenSig : nullptr); } - return associatedFunction; + return derivativeFunction; } // SWIFT_ENABLE_TENSORFLOW @@ -4716,7 +4716,7 @@ AnyFunctionType::getAutoDiffOriginalFunctionType() { curryLevels.back(), curryLevels.back()->getParams(), originalResult, curryLevels.size() == 1 ? getOptGenericSignature() : nullptr); - // Wrap the associated function type in additional curry levels. + // Wrap the derivative function type in additional curry levels. auto curryLevelsWithoutLast = ArrayRef(curryLevels).drop_back(1); for (auto pair : enumerate(reversed(curryLevelsWithoutLast))) { diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index 9f86b7c2b61ea..7629ae9b74d02 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -63,8 +63,8 @@ class DiffFuncFieldInfo final : public RecordField { auto origFnTy = fnTy->getWithoutDifferentiability(); if (Index == DifferentiableFunctionExtractee::Original) return SILType::getPrimitiveObjectType(origFnTy); - auto kind = *Index.getExtracteeAsAssociatedFunction(); - auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( + auto kind = *Index.getExtracteeAsDerivativeFunction(); + auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( ParameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); @@ -152,8 +152,8 @@ class DiffFuncTypeBuilder SILType getType(DiffFuncIndex field) { if (field == DifferentiableFunctionExtractee::Original) return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType()); - auto kind = *field.getExtracteeAsAssociatedFunction(); - auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( + auto kind = *field.getExtracteeAsDerivativeFunction(); + auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index f875e1e0d6e54..082a299a25629 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -1872,7 +1872,7 @@ void IRGenSILFunction::visitSILBasicBlock(SILBasicBlock *BB) { // SWIFT_ENABLE_TENSORFLOW void IRGenSILFunction:: visitDifferentiableFunctionInst(DifferentiableFunctionInst *i) { - // The original function and associated functions can be thin or thick. + // The original function and derivative functions can be thin or thick. auto origExp = getLoweredExplosion(i->getOriginalFunction()); Explosion e; e.add(origExp.claimAll()); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index daf296a3f345d..3a1dc3355dadd 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -1545,7 +1545,7 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, unsigned uncurryLevel = 0; bool IsObjC = false; // SWIFT_ENABLE_TENSORFLOW - AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId = nullptr; + AutoDiffDerivativeFunctionIdentifier *autoDiffFuncId = nullptr; if (!P.consumeIf(tok::sil_exclamation)) { // Construct SILDeclRef. @@ -1635,13 +1635,13 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, // SWIFT_ENABLE_TENSORFLOW ParseState = 3; } else if (Id.str() == "jvp" || Id.str() == "vjp") { - AutoDiffAssociatedFunctionKind kind; + AutoDiffDerivativeFunctionKind kind; AutoDiffIndexSubset *parameterIndices = nullptr; if (Id.str() == "jvp") - kind = AutoDiffAssociatedFunctionKind::JVP; + kind = AutoDiffDerivativeFunctionKind::JVP; else if (Id.str() == "vjp") - kind = AutoDiffAssociatedFunctionKind::VJP; + kind = AutoDiffDerivativeFunctionKind::VJP; else llvm_unreachable("Should only have JVP and VJP here"); @@ -1658,7 +1658,7 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, } P.consumeToken(); - autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( + autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( kind, parameterIndices, SILMod.getASTContext()); break; @@ -2960,9 +2960,9 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { // Parse an optional operand list `with { , }`. if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { P.consumeToken(tok::identifier); - // Parse associated function values as an operand list. + // Parse derivative function values as an operand list. // FIXME(rxwei): Change this to *not* require a type signature once - // we can infer AD associated function types. + // we can infer derivative function types. SILValue derivFn1, derivFn2; if (P.parseToken(tok::l_brace, diag::sil_inst_autodiff_operand_list_expected_lbrace) || @@ -3003,7 +3003,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { diag::sil_inst_autodiff_expected_associated_function_kind_attr) || P.parseToken(tok::r_square, diag::sil_inst_autodiff_attr_expected_rsquare, - "associated function kind")) + "derivative function kind")) return true; if (parseTypedValueRef(functionOperand, B) || parseSILDebugLocation(InstLoc, B)) diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index f4703f10d6f73..a37b3ba8ce627 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -116,19 +116,19 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) { SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, // SWIFT_ENABLE_TENSORFLOW bool isCurried, bool isForeign, - AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId) + AutoDiffDerivativeFunctionIdentifier *autoDiffFuncId) : loc(vd), kind(kind), isCurried(isCurried), isForeign(isForeign), // SWIFT_ENABLE_TENSORFLOW isDirectReference(0), defaultArgIndex(0), - autoDiffAssociatedFunctionIdentifier(autoDiffFuncId) + autoDiffDerivativeFunctionIdentifier(autoDiffFuncId) {} SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool isCurried, bool asForeign) // SWIFT_ENABLE_TENSORFLOW : isCurried(isCurried), isDirectReference(0), defaultArgIndex(0), - autoDiffAssociatedFunctionIdentifier(nullptr) + autoDiffDerivativeFunctionIdentifier(nullptr) { if (auto *vd = baseLoc.dyn_cast()) { if (auto *fd = dyn_cast(vd)) { @@ -688,14 +688,14 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { ASTMangler mangler; // SWIFT_ENABLE_TENSORFLOW - if (autoDiffAssociatedFunctionIdentifier) { + if (autoDiffDerivativeFunctionIdentifier) { std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind); auto *silParameterIndices = autodiff::getLoweredParameterIndices( - autoDiffAssociatedFunctionIdentifier->getParameterIndices(), + autoDiffDerivativeFunctionIdentifier->getParameterIndices(), getDecl()->getInterfaceType()->castTo()); SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices); - auto derivativeFnKind = autoDiffAssociatedFunctionIdentifier->getKind(); - return mangler.mangleAutoDiffAssociatedFunctionHelper( + auto derivativeFnKind = autoDiffDerivativeFunctionIdentifier->getKind(); + return mangler.mangleAutoDiffDerivativeFunctionHelper( originalMangled, derivativeFnKind, indices); } @@ -825,8 +825,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { // SWIFT_ENABLE_TENSORFLOW // Returns true if the given JVP/VJP SILDeclRef requires a new vtable entry. -static bool autoDiffAssociatedFunctionRequiresNewVTableEntry(SILDeclRef ref) { - assert(ref.autoDiffAssociatedFunctionIdentifier); +static bool autoDiffDerivativeFunctionRequiresNewVTableEntry(SILDeclRef ref) { + assert(ref.autoDiffDerivativeFunctionIdentifier); auto overridden = ref.getOverridden(); if (!overridden) return false; @@ -835,16 +835,16 @@ static bool autoDiffAssociatedFunctionRequiresNewVTableEntry(SILDeclRef ref) { ref.getDecl()->getAttrs().getAttributes(), [&](const DifferentiableAttr *derivedAttr) { return derivedAttr->getParameterIndices() == - ref.autoDiffAssociatedFunctionIdentifier->getParameterIndices(); + ref.autoDiffDerivativeFunctionIdentifier->getParameterIndices(); }); assert(derivedDA && "Expected `@differentiable` attribute"); // If the derived `@differentiable` attribute specifies a JVP/VJP, - switch (ref.autoDiffAssociatedFunctionIdentifier->getKind()) { - case AutoDiffAssociatedFunctionKind::JVP: + switch (ref.autoDiffDerivativeFunctionIdentifier->getKind()) { + case AutoDiffDerivativeFunctionKind::JVP: if (!overridden.requiresNewVTableEntry() && derivedDA->getJVP()) return true; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: if (!overridden.requiresNewVTableEntry() && derivedDA->getVJP()) return true; break; @@ -853,7 +853,7 @@ static bool autoDiffAssociatedFunctionRequiresNewVTableEntry(SILDeclRef ref) { overridden.getDecl()->getAttrs().getAttributes(); for (auto *baseDA : baseDAs) { if (baseDA->getParameterIndices() == - ref.autoDiffAssociatedFunctionIdentifier->getParameterIndices()) + ref.autoDiffDerivativeFunctionIdentifier->getParameterIndices()) return false; } return true; @@ -861,8 +861,8 @@ static bool autoDiffAssociatedFunctionRequiresNewVTableEntry(SILDeclRef ref) { bool SILDeclRef::requiresNewVTableEntry() const { // SWIFT_ENABLE_TENSORFLOW - if (autoDiffAssociatedFunctionIdentifier) - if (autoDiffAssociatedFunctionRequiresNewVTableEntry(*this)) + if (autoDiffDerivativeFunctionIdentifier) + if (autoDiffDerivativeFunctionRequiresNewVTableEntry(*this)) return true; // SWIFT_ENABLE_TENSORFLOW END if (cast(getDecl())->needsNewVTableEntry()) @@ -887,7 +887,7 @@ SILDeclRef SILDeclRef::getOverridden() const { // SWIFT_ENABLE_TENSORFLOW return SILDeclRef(overridden, kind, isCurried, isForeign, - autoDiffAssociatedFunctionIdentifier); + autoDiffDerivativeFunctionIdentifier); } SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { @@ -942,12 +942,12 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { // SWIFT_ENABLE_TENSORFLOW // JVPs/VJPs are overridden only if the base declaration has a // `@differentiable` with the same parameter indices. - if (autoDiffAssociatedFunctionIdentifier) { + if (autoDiffDerivativeFunctionIdentifier) { auto overriddenAttrs = overridden.getDecl()->getAttrs().getAttributes(); if (llvm::none_of(overriddenAttrs, [&](const DifferentiableAttr *attr) { return attr->getParameterIndices() == - autoDiffAssociatedFunctionIdentifier->getParameterIndices(); + autoDiffDerivativeFunctionIdentifier->getParameterIndices(); })) { return SILDeclRef(); } @@ -964,7 +964,7 @@ SILDeclRef SILDeclRef::getOverriddenWitnessTableEntry() const { getOverriddenWitnessTableEntry(cast(getDecl())); // SWIFT_ENABLE_TENSORFLOW return SILDeclRef(bestOverridden, kind, isCurried, isForeign, - autoDiffAssociatedFunctionIdentifier); + autoDiffDerivativeFunctionIdentifier); } AbstractFunctionDecl *SILDeclRef::getOverriddenWitnessTableEntry( diff --git a/lib/SIL/SILFunctionBuilder.cpp b/lib/SIL/SILFunctionBuilder.cpp index 5ea0fdb488e4f..94bd9be02a81d 100644 --- a/lib/SIL/SILFunctionBuilder.cpp +++ b/lib/SIL/SILFunctionBuilder.cpp @@ -80,7 +80,7 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F, // - Thunks. Those are currently handled in SILGenThunk.cpp. if ((!isa(decl) || cast(decl)->isGetter()) && constant.kind != SILDeclRef::Kind::DefaultArgGenerator && - !constant.autoDiffAssociatedFunctionIdentifier && + !constant.autoDiffDerivativeFunctionIdentifier && !constant.isStoredPropertyInitializer() && !constant.isThunk()) { for (auto *A : Attrs.getAttributes()) { @@ -99,15 +99,15 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F, if (auto *jvpFn = A->getJVPFunction()) { Mangle::ASTMangler mangler; jvpName = ctx.getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - constant.mangle(), AutoDiffAssociatedFunctionKind::JVP, + mangler.mangleAutoDiffDerivativeFunctionHelper( + constant.mangle(), AutoDiffDerivativeFunctionKind::JVP, indices)).str(); } if (auto *vjpFn = A->getVJPFunction()) { Mangle::ASTMangler mangler; vjpName = ctx.getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - constant.mangle(), AutoDiffAssociatedFunctionKind::VJP, + mangler.mangleAutoDiffDerivativeFunctionHelper( + constant.mangle(), AutoDiffDerivativeFunctionKind::VJP, indices)).str(); } auto *silDiffAttr = SILDifferentiableAttr::create( diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index c04e71414a358..08118cef163f1 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -150,10 +150,10 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() { getOptionalErrorResult(), getASTContext()); } -// Returns the canonical generic signature for an autodiff associated function -// given an existing associated function generic signature. All differentiation +// Returns the canonical generic signature for an autodiff derivative function +// given an existing derivative function generic signature. All differentiation // parameters are constrained to conform to `Differentiable`. -static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature( +static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature( CanGenericSignature derivativeFnGenSig, ArrayRef originalParameters, AutoDiffIndexSubset *parameterIndices, ModuleDecl *module) { @@ -162,7 +162,7 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature( auto &ctx = module->getASTContext(); GenericSignatureBuilder builder(ctx); - // Add associated function generic signature. + // Add derivative function generic signature. builder.addGenericSignature(derivativeFnGenSig); // Constrain all wrt parameters to conform to `Differentiable`. auto source = @@ -179,9 +179,9 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature( ->getCanonicalSignature(); } -CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( +CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, - AutoDiffAssociatedFunctionKind kind, TypeConverter &TC, + AutoDiffDerivativeFunctionKind kind, TypeConverter &TC, LookupConformanceFn lookupConformance, CanGenericSignature derivativeFnGenSig) { // JVP: (T...) -> ((R...), @@ -203,10 +203,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( if (isWrtIndex(valueAndIndex.index())) wrtParams.push_back(valueAndIndex.value()); - // Get the canonical associated function generic signature. + // Get the canonical derivative function generic signature. if (!derivativeFnGenSig) derivativeFnGenSig = getGenericSignature(); - derivativeFnGenSig = getAutoDiffAssociatedFunctionGenericSignature( + derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature( derivativeFnGenSig, getParameters(), parameterIndices, &TC.M); Lowering::GenericContextScope genericContextScope(TC, derivativeFnGenSig); @@ -259,7 +259,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( CanSILFunctionType closureType; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: { + case AutoDiffDerivativeFunctionKind::JVP: { SmallVector differentialParams; for (auto ¶m : wrtParams) { auto paramTan = @@ -281,7 +281,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( differentialResults, None, ctx); break; } - case AutoDiffAssociatedFunctionKind::VJP: { + case AutoDiffDerivativeFunctionKind::VJP: { SmallVector pullbackParams; auto &origRes = getResults()[resultIndex]; auto resultTan = @@ -2334,23 +2334,23 @@ const SILConstantInfo &TypeConverter::getConstantInfo(SILDeclRef constant) { loweredInterfaceType); // SWIFT_ENABLE_TENSORFLOW - // In the case of autodiff associated functions, the above computations - // determine `silFnType` by first computing the associated function type at + // In the case of autodiff derivative functions, the above computations + // determine `silFnType` by first computing the derivative function type at // the AST level and then lowering that. Unfortunately, the actual // SILFunctionType for the function is determined by first lowering the - // function's AST type, and then computing the associated function type at the + // function's AST type, and then computing the derivative function type at the // SIL level. "Lowering" does not commute with "getting the autodiff // associated type", so these two computations produce different results. // Therefore `silFnType` is not the actual type of the function that // `constant` refers to. // // We hackily fix this problem by redoing the computation in the right order. - if (auto *autoDiffFuncId = constant.autoDiffAssociatedFunctionIdentifier) { + if (auto *autoDiffFuncId = constant.autoDiffDerivativeFunctionIdentifier) { auto origFnConstantInfo = getConstantInfo(constant.asAutoDiffOriginalFunction()); auto loweredIndices = autodiff::getLoweredParameterIndices( autoDiffFuncId->getParameterIndices(), formalInterfaceType); - silFnType = origFnConstantInfo.SILFnType->getAutoDiffAssociatedFunctionType( + silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType( loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(), *this, LookUpConformanceInModule(&M)); } diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index f7d75e08fcfc1..f90a1dded6597 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -625,12 +625,12 @@ DifferentiableFunctionInst *DifferentiableFunctionInst::create( } DifferentiableFunctionExtractInst::Extractee::Extractee( - AutoDiffAssociatedFunctionKind kind) { + AutoDiffDerivativeFunctionKind kind) { switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: rawValue = JVP; return; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: rawValue = VJP; return; } @@ -646,16 +646,16 @@ DifferentiableFunctionExtractInst::Extractee::Extractee(StringRef string) { rawValue = *result; } -Optional -DifferentiableFunctionExtractInst::Extractee::getExtracteeAsAssociatedFunction() +Optional +DifferentiableFunctionExtractInst::Extractee::getExtracteeAsDerivativeFunction() const { switch (rawValue) { case Original: return None; case JVP: - return {AutoDiffAssociatedFunctionKind::JVP}; + return {AutoDiffDerivativeFunctionKind::JVP}; case VJP: - return {AutoDiffAssociatedFunctionKind::VJP}; + return {AutoDiffDerivativeFunctionKind::VJP}; } } @@ -664,12 +664,12 @@ getExtracteeType(SILValue function, Extractee extractee, SILModule &module) { auto fnTy = function->getType().castTo(); assert(fnTy->getExtInfo().isDifferentiable()); auto originalFnTy = fnTy->getWithoutDifferentiability(); - auto kindOpt = extractee.getExtracteeAsAssociatedFunction(); + auto kindOpt = extractee.getExtracteeAsDerivativeFunction(); if (!kindOpt) { assert(extractee == Extractee::Original); return SILType::getPrimitiveObjectType(originalFnTy); } - auto resultFnTy = originalFnTy->getAutoDiffAssociatedFunctionType( + auto resultFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( fnTy->getDifferentiationParameterIndices(), /*resultIndex*/ 0, *kindOpt, module.Types, LookUpConformanceInModule(module.getSwiftModule())); diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 0aa45bf1300d7..dd830e41cfbab 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -349,15 +349,15 @@ void SILDeclRef::print(raw_ostream &OS) const { OS << ((isDot || uncurryLevel != 0) ? '.' : '!') << "direct"; // SWIFT_ENABLE_TENSORFLOW - if (autoDiffAssociatedFunctionIdentifier) { - auto *autoDiffFuncId = autoDiffAssociatedFunctionIdentifier; + if (autoDiffDerivativeFunctionIdentifier) { + auto *autoDiffFuncId = autoDiffDerivativeFunctionIdentifier; OS << ((isDot || uncurryLevel != 0 || isForeign || isDirectReference) ? '.' : '!'); switch (autoDiffFuncId->getKind()) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: OS << "jvp."; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: OS << "vjp."; break; } diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index b5e7d8b982c9d..42ae01866d9bc 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1505,9 +1505,9 @@ class SILVerifier : public SILVerifierBase { require(jvpType, "The JVP function must have a function type"); require(!jvpType->isDifferentiable(), "The JVP function must not be @differentiable"); - auto expectedJVPType = origTy->getAutoDiffAssociatedFunctionType( + auto expectedJVPType = origTy->getAutoDiffDerivativeFunctionType( dfi->getParameterIndices(), /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::JVP, TC, + AutoDiffDerivativeFunctionKind::JVP, TC, LookUpConformanceInModule(M)); requireSameType(SILType::getPrimitiveObjectType(jvpType), SILType::getPrimitiveObjectType(expectedJVPType), @@ -1517,9 +1517,9 @@ class SILVerifier : public SILVerifierBase { require(vjpType, "The VJP function must have a function type"); require(!vjpType->isDifferentiable(), "The VJP function must not be @differentiable"); - auto expectedVJPType = origTy->getAutoDiffAssociatedFunctionType( + auto expectedVJPType = origTy->getAutoDiffDerivativeFunctionType( dfi->getParameterIndices(), /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::VJP, TC, + AutoDiffDerivativeFunctionKind::VJP, TC, LookUpConformanceInModule(M)); requireSameType(SILType::getPrimitiveObjectType(vjpType), SILType::getPrimitiveObjectType(expectedVJPType), diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 55a9f94ac411d..118a157b72bd8 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -152,13 +152,13 @@ namespace { assert(type->isDifferentiable()); auto &M = TC.M; auto origTy = type->getWithoutDifferentiability(); - auto jvpTy = origTy->getAutoDiffAssociatedFunctionType( + auto jvpTy = origTy->getAutoDiffDerivativeFunctionType( type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::JVP, TC, + AutoDiffDerivativeFunctionKind::JVP, TC, LookUpConformanceInModule(&M)); - auto vjpTy = origTy->getAutoDiffAssociatedFunctionType( + auto vjpTy = origTy->getAutoDiffDerivativeFunctionType( type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::VJP, TC, + AutoDiffDerivativeFunctionKind::VJP, TC, LookUpConformanceInModule(&M)); RecursiveProperties props; props.addSubobject(classifyType(origTy, TC, Sig, Expansion)); @@ -888,18 +888,18 @@ namespace { void lowerChildren(TypeConverter &TC, SmallVectorImpl &children) const override { auto fnTy = getLoweredType().castTo(); - auto numAssocFns = 2; - children.reserve(numAssocFns + 1); + auto numDerivativeFns = 2; + children.reserve(numDerivativeFns + 1); auto origFnTy = fnTy->getWithoutDifferentiability(); auto paramIndices = fnTy->getDifferentiationParameterIndices(); children.push_back(Child{ DifferentiableFunctionExtractee::Original, TC.getTypeLowering(origFnTy, getResilienceExpansion()) }); - for (AutoDiffAssociatedFunctionKind kind : - {AutoDiffAssociatedFunctionKind::JVP, - AutoDiffAssociatedFunctionKind::VJP}) { - auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType( + for (AutoDiffDerivativeFunctionKind kind : + {AutoDiffDerivativeFunctionKind::JVP, + AutoDiffDerivativeFunctionKind::VJP}) { + auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( paramIndices, 0, kind, TC, LookUpConformanceInModule(&TC.M)); auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy); @@ -908,7 +908,7 @@ namespace { // was caused by implicit conversions from `unsigned` to // `DifferentiableFunctionExtractee` which resulted into a wrong // extractee. - assert(extractee.getExtracteeAsAssociatedFunction() == kind); + assert(extractee.getExtracteeAsDerivativeFunction() == kind); children.push_back(Child{ extractee, TC.getTypeLowering(silTy, getResilienceExpansion())}); } @@ -2007,10 +2007,10 @@ getFunctionInterfaceTypeWithCaptures(TypeConverter &TC, CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) { // SWIFT_ENABLE_TENSORFLOW - if (auto *autoDiffFuncId = c.autoDiffAssociatedFunctionIdentifier) { + if (auto *autoDiffFuncId = c.autoDiffDerivativeFunctionIdentifier) { auto originalFnTy = makeConstantInterfaceType(c.asAutoDiffOriginalFunction()); - auto *fnTy = originalFnTy->getAutoDiffAssociatedFunctionType( + auto *fnTy = originalFnTy->getAutoDiffDerivativeFunctionType( autoDiffFuncId->getParameterIndices(), /*resultIndex*/ 0, autoDiffFuncId->getKind(), LookUpConformanceInModule(&M)); return cast(fnTy->getCanonicalType()); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 36655ae74d840..0076ade2d4e00 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -540,7 +540,7 @@ static SILFunction *getFunctionToInsertAfter(SILGenModule &SGM, return nullptr; } -static bool haveProfiledAssociatedFunction(SILDeclRef constant) { +static bool haveProfiledDerivativeFunction(SILDeclRef constant) { return constant.isDefaultArgGenerator() || constant.isForeign || constant.isCurried; } @@ -552,7 +552,7 @@ static void setUpForProfiling(SILDeclRef constant, SILFunction *F, return; ASTNode profiledNode; - if (!haveProfiledAssociatedFunction(constant)) { + if (!haveProfiledDerivativeFunction(constant)) { if (constant.hasDecl()) { if (auto *fd = constant.getFuncDecl()) { if (fd->hasBody()) { @@ -773,12 +773,12 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, "Expected matching @differentiable and [differentiable]"); auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule()); - auto expectedJVPType = origSilFnType->getAutoDiffAssociatedFunctionType( + auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType( indices.parameters, indices.source, - AutoDiffAssociatedFunctionKind::JVP, Types, lookUpConformance); - auto expectedVJPType = origSilFnType->getAutoDiffAssociatedFunctionType( + AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance); + auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType( indices.parameters, indices.source, - AutoDiffAssociatedFunctionKind::VJP, Types, lookUpConformance); + AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance); // Self reordering is necessary if wrt at least two parameters, including // self. @@ -797,15 +797,15 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, SILFunction *jvpThunk; auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition); if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) { - jvpThunk = getOrCreateAutoDiffAssociatedFunctionThunk( - F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP, + jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk( + F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP, reorderSelf); } else { - auto *id = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, + auto *id = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, diffAttr->getParameterIndices(), AFD->getASTContext()); jvpThunk = getOrCreateAutoDiffThunk( - constant.asAutoDiffAssociatedFunction(id), jvpFn, + constant.asAutoDiffDerivativeFunction(id), jvpFn, expectedJVPType); } silDiffAttr->setJVPName(jvpThunk->getName()); @@ -815,15 +815,15 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, SILFunction *vjpThunk; auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition); if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) { - vjpThunk = getOrCreateAutoDiffAssociatedFunctionThunk( - F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP, + vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk( + F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP, reorderSelf); } else { - auto *id = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, + auto *id = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, diffAttr->getParameterIndices(), AFD->getASTContext()); vjpThunk = getOrCreateAutoDiffThunk( - constant.asAutoDiffAssociatedFunction(id), vjpFn, + constant.asAutoDiffDerivativeFunction(id), vjpFn, expectedVJPType); } silDiffAttr->setVJPName(vjpThunk->getName()); diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h index 99c27d72b8257..c69aabcdba131 100644 --- a/lib/SILGen/SILGen.h +++ b/lib/SILGen/SILGen.h @@ -148,15 +148,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { CanSILFunctionType constantTy); // SWIFT_ENABLE_TENSORFLOW - /// Get or create an autodiff associated function thunk for the given - /// SILDeclRef, SILFunction, and associated function type. + /// Get or create an autodiff derivative function thunk for the given + /// SILDeclRef, SILFunction, and derivative function type. SILFunction *getOrCreateAutoDiffThunk(SILDeclRef derivativeFnRef, SILFunction *derivativeFn, CanSILFunctionType derivativeFnTy); // SWIFT_ENABLE_TENSORFLOW - /// Get or create an autodiff associated function vtable entry thunk for the - /// given SILDeclRef and associated function type. + /// Get or create an autodiff derivative function vtable entry thunk for the + /// given SILDeclRef and derivative function type. SILFunction * getOrCreateAutoDiffClassMethodThunk(SILDeclRef derivativeFnRef, CanSILFunctionType derivativeFnTy); @@ -185,10 +185,10 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { /// If `reorderSelf` is true, reorder self so that it appears as: /// - The last parameter in the returned differential. /// - The last result in the returned pullback. - SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk( + SILFunction *getOrCreateAutoDiffDerivativeFunctionThunk( SILFunction *original, SILAutoDiffIndices &indices, SILFunction *derivativeFn, - AutoDiffAssociatedFunctionKind derivativeFnKind, bool reorderSelf); + AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf); /// Determine whether the given class has any instance variables that /// need to be destroyed. diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index 79a9f0d7b6e64..4f477e45773f3 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1031,8 +1031,8 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF, } // SWIFT_ENABLE_TENSORFLOW -static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction( - AutoDiffAssociatedFunctionKind kind, unsigned arity, +static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction( + AutoDiffDerivativeFunctionKind kind, unsigned arity, bool rethrows, SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, ArrayRef args, SGFContext C) { auto origFnVal = args[0].getValue(); @@ -1040,7 +1040,7 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction( for (auto& arg : args.drop_front(1)) origFnArgVals.push_back(arg.getValue()); - // Get the associated function. + // Get the derivative function. SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract( loc, kind, origFnVal); auto derivativeFnType = derivativeFn->getType().castTo(); @@ -1143,13 +1143,13 @@ static ManagedValue emitBuiltinAutoDiffApply(SILGenFunction &SGF, cast(callExpr->getDirectCallee())->getRHS()) ->getDecl()); auto builtinName = builtinDecl->getName().str(); - AutoDiffAssociatedFunctionKind kind; + AutoDiffDerivativeFunctionKind kind; unsigned arity; bool rethrows; auto successfullyParsed = autodiff::getBuiltinAutoDiffApplyConfig( builtinName, kind, arity, rethrows); assert(successfullyParsed); - return emitBuiltinAutoDiffApplyAssociatedFunction(kind, arity, + return emitBuiltinAutoDiffApplyDerivativeFunction(kind, arity, rethrows, SGF, loc, substitutions, args, C); } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 37a92badba502..d68fc63516144 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3292,16 +3292,16 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, auto *parameterIndices = AutoDiffIndexSubset::get( SGF.getASTContext(), parameterBits); - auto getAssocFnTy = - [&](CanAnyFunctionType fnTy, AutoDiffAssociatedFunctionKind kind) + auto getDerivativeFnTy = + [&](CanAnyFunctionType fnTy, AutoDiffDerivativeFunctionKind kind) -> CanAnyFunctionType { - auto assocTy = fnTy->getAutoDiffAssociatedFunctionType( + auto assocTy = fnTy->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, LookUpConformanceInModule(SGF.SGM.M.getSwiftModule())); return cast(assocTy->getCanonicalType()); }; - auto getAssocFnPattern = - [&](AbstractionPattern pattern, AutoDiffAssociatedFunctionKind kind) + auto getDerivativeFnPattern = + [&](AbstractionPattern pattern, AutoDiffDerivativeFunctionKind kind) -> AbstractionPattern { // If pattern does not store an `AnyFunctionType`, return original // pattern. This logic handles opaque abstraction patterns. @@ -3309,30 +3309,30 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, if (!patternType) return pattern; return AbstractionPattern( - pattern.getGenericSignature(), getAssocFnTy(patternType, kind)); + pattern.getGenericSignature(), getDerivativeFnTy(patternType, kind)); }; - auto createAssocFnThunk = - [&](AutoDiffAssociatedFunctionKind kind) -> ManagedValue { + auto createDerivativeFnThunk = + [&](AutoDiffDerivativeFunctionKind kind) -> ManagedValue { auto derivativeFnInputOrigType = - getAssocFnPattern(inputOrigTypeNotDiff, kind); - auto derivativeFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind); - auto derivativeFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff, + getDerivativeFnPattern(inputOrigTypeNotDiff, kind); + auto derivativeFnInputSubstType = getDerivativeFnTy(inputSubstTypeNotDiff, kind); + auto derivativeFnOutputOrigType = getDerivativeFnPattern(outputOrigTypeNotDiff, kind); auto derivativeFnOutputSubstType = - getAssocFnTy(outputSubstTypeNotDiff, kind); + getDerivativeFnTy(outputSubstTypeNotDiff, kind); auto &derivativeFnExpectedTL = SGF.getTypeLowering( derivativeFnOutputOrigType, derivativeFnOutputSubstType); SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract( loc, kind, borrowedFnValue.getValue()); derivativeFn = SGF.B.emitCopyValueOperation(loc, derivativeFn); - auto managedAssocFn = SGF.emitManagedRValueWithCleanup(derivativeFn); - return createThunk(SGF, loc, managedAssocFn, derivativeFnInputOrigType, + auto managedDerivativeFn = SGF.emitManagedRValueWithCleanup(derivativeFn); + return createThunk(SGF, loc, managedDerivativeFn, derivativeFnInputOrigType, derivativeFnInputSubstType, derivativeFnOutputOrigType, derivativeFnOutputSubstType, derivativeFnExpectedTL); }; - auto jvpThunk = createAssocFnThunk(AutoDiffAssociatedFunctionKind::JVP); - auto vjpThunk = createAssocFnThunk(AutoDiffAssociatedFunctionKind::VJP); + auto jvpThunk = createDerivativeFnThunk(AutoDiffDerivativeFunctionKind::JVP); + auto vjpThunk = createDerivativeFnThunk(AutoDiffDerivativeFunctionKind::VJP); SILValue convertedBundle = SGF.B.createDifferentiableFunction( loc, sourceType->getDifferentiationParameterIndices(), @@ -3666,9 +3666,9 @@ SILGenFunction::getThunkedAutoDiffLinearMap( // SWIFT_ENABLE_TENSORFLOW SILFunction * -SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( +SILGenModule::getOrCreateAutoDiffDerivativeFunctionThunk( SILFunction *original, SILAutoDiffIndices &indices, - SILFunction *derivativeFn, AutoDiffAssociatedFunctionKind derivativeFnKind, + SILFunction *derivativeFn, AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf) { auto derivativeFnType = derivativeFn->getLoweredFunctionType(); @@ -3676,7 +3676,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( // Do not simply reuse reabstraction thunk mangling. Mangle::ASTMangler mangler; auto name = getASTContext().getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( + mangler.mangleAutoDiffDerivativeFunctionHelper( original->getName(), derivativeFnKind, indices)).str(); Lowering::GenericContextScope genericContextScope( @@ -3686,18 +3686,18 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( : nullptr; auto origFnType = original->getLoweredFunctionType(); - auto origAssocFnType = origFnType->getAutoDiffAssociatedFunctionType( + auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType( indices.parameters, indices.source, derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()), derivativeFnType->getGenericSignature()); - assert(!origAssocFnType->getExtInfo().hasContext()); + assert(!origDerivativeFnType->getExtInfo().hasContext()); auto loc = derivativeFn->getLocation(); SILGenFunctionBuilder fb(*this); - auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( - original->getLinkage(), /*isAssocFnExported*/ true); + auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( + original->getLinkage(), /*isDerivativeFnExported*/ true); auto *thunk = fb.getOrCreateFunction( - loc, name, linkage, origAssocFnType, IsBare, IsNotTransparent, + loc, name, linkage, origDerivativeFnType, IsBare, IsNotTransparent, derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(), derivativeFn->getEntryCount(), derivativeFn->isThunk(), derivativeFn->getClassSubclassScope()); @@ -3743,7 +3743,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( derivativeFnRefType->getResults().back().getType()) ->getCanonicalType()); auto targetLinearMapFnType = thunk->mapTypeIntoContext( - origAssocFnType->getResults().back().getSILStorageType()) + origDerivativeFnType->getResults().back().getSILStorageType()) .castTo(); if (!reorderSelf && linearMapFnType == targetLinearMapFnType) { createReturn(apply); @@ -4304,7 +4304,7 @@ getWitnessFunctionRef(SILGenFunction &SGF, switch (witnessKind) { case WitnessDispatchKind::Static: // SWIFT_ENABLE_TENSORFLOW - if (auto *autoDiffFuncId = witness.autoDiffAssociatedFunctionIdentifier) { + if (auto *autoDiffFuncId = witness.autoDiffDerivativeFunctionIdentifier) { auto originalFn = SGF.emitGlobalFunctionRef( loc, witness.asAutoDiffOriginalFunction()); auto loweredIndices = autodiff::getLoweredParameterIndices( @@ -4320,7 +4320,7 @@ getWitnessFunctionRef(SILGenFunction &SGF, return SGF.emitGlobalFunctionRef(loc, witness); case WitnessDispatchKind::Dynamic: // SWIFT_ENABLE_TENSORFLOW - assert(!witness.autoDiffAssociatedFunctionIdentifier); + assert(!witness.autoDiffDerivativeFunctionIdentifier); return SGF.emitDynamicMethodRef(loc, witness, witnessFTy).getValue(); case WitnessDispatchKind::Witness: { auto typeAndConf = diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 578e22ae008d6..acddba5fb228a 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -76,15 +76,15 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef, SILFunction *derivativeFn, CanSILFunctionType derivativeFnTy) { auto *autoDiffFuncId = - derivativeFnDeclRef.autoDiffAssociatedFunctionIdentifier; + derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier; assert(autoDiffFuncId); auto *derivativeFnDecl = derivativeFnDeclRef.getDecl(); SILGenFunctionBuilder builder(*this); auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction(); auto originalLinkage = originalFn.getLinkage(ForDefinition); - auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( - originalLinkage, /*isAssocFnExported*/ true); + auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( + originalLinkage, /*isDerivativeFnExported*/ true); auto name = derivativeFnDeclRef.mangle(); auto *thunk = builder.getOrCreateFunction( derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsTransparent, @@ -100,11 +100,11 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef, auto loc = derivativeFnDeclRef.getAsRegularLocation(); SGF.collectThunkParams(loc, params); auto derivativeFnRef = SGF.B.createFunctionRef(loc, derivativeFn); - auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(derivativeFnTy); + auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(derivativeFnTy); SmallVector args(thunk->getArguments().begin(), thunk->getArguments().end()); auto apply = SGF.emitApplyWithRethrow( - loc, derivativeFnRef, autoDiffAssocFnSILTy, + loc, derivativeFnRef, autoDiffDerivativeFnSILTy, SGF.getForwardingSubstitutionMap(), args); SGF.B.createReturn(loc, apply); return thunk; @@ -114,15 +114,15 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef, SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) { auto *autoDiffFuncId = - derivativeFnDeclRef.autoDiffAssociatedFunctionIdentifier; + derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier; assert(autoDiffFuncId); auto *derivativeFnDecl = derivativeFnDeclRef.getDecl(); SILGenFunctionBuilder builder(*this); auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction(); auto originalLinkage = originalFn.getLinkage(ForDefinition); - auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( - originalLinkage, /*isAssocFnExported*/ true); + auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( + originalLinkage, /*isDerivativeFnExported*/ true); // TODO(TF-685): Use principled thunk mangling. // Do not simply reuse reabstraction thunk mangling. auto name = derivativeFnDeclRef.mangle() + "_vtable_entry_thunk"; @@ -145,13 +145,13 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( derivativeFnDecl->getInterfaceType()->castTo()); auto diffFn = SGF.B.createDifferentiableFunction( loc, loweredIndices, originalFnRef); - auto diffAssocFn = SGF.B.createDifferentiableFunctionExtract( + auto diffDerivativeFn = SGF.B.createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), diffFn); - auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy); + auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(constantTy); SmallVector args(thunk->getArguments().begin(), thunk->getArguments().end()); auto apply = SGF.emitApplyWithRethrow( - loc, diffAssocFn, autoDiffAssocFnSILTy, + loc, diffDerivativeFn, autoDiffDerivativeFnSILTy, SGF.getForwardingSubstitutionMap(), args); SGF.B.createReturn(loc, apply); return thunk; @@ -184,7 +184,7 @@ getNextUncurryLevelRef(SILGenFunction &SGF, SILLocation loc, SILDeclRef thunk, // SWIFT_ENABLE_TENSORFLOW SILDeclRef next = SILDeclRef(vd, thunk.kind, /*isCurried*/ false, /*isForeign*/ false, - thunk.autoDiffAssociatedFunctionIdentifier); + thunk.autoDiffDerivativeFunctionIdentifier); assert(!next.isCurried); auto constantInfo = SGF.SGM.Types.getConstantInfo(next); diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 65d32bbb75cfa..c015135902683 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -89,7 +89,7 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass, if (usesObjCDynamicDispatch) { implFn = getDynamicThunk(derived, Types.getConstantInfo(derived).SILFnType); // SWIFT_ENABLE_TENSORFLOW - } else if (auto *adafi = derived.autoDiffAssociatedFunctionIdentifier) { + } else if (auto *adafi = derived.autoDiffDerivativeFunctionIdentifier) { // For JVP/VJP methods, create a vtable entry thunk. The thunk contains an // `differentiable_function` instruction, which is later filled during the // differentiation transform. @@ -153,12 +153,12 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass, } // SWIFT_ENABLE_TENSORFLOW // TODO: Use proper mangling. - if (auto *adafi = derived.autoDiffAssociatedFunctionIdentifier) { + if (auto *adafi = derived.autoDiffDerivativeFunctionIdentifier) { switch (adafi->getKind()) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: name += "_jvp"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: name += "_vjp"; break; } @@ -684,13 +684,13 @@ SILFunction *SILGenModule::emitProtocolWitness( // SWIFT_ENABLE_TENSORFLOW // TODO: Proper mangling for autodiff witness thunks. if (auto *autoDiffFuncId = - requirement.autoDiffAssociatedFunctionIdentifier) { + requirement.autoDiffDerivativeFunctionIdentifier) { std::string kindString; switch (autoDiffFuncId->getKind()) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: kindString = "jvp"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: kindString = "vjp"; break; } diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 7f3290aa00d25..b9d5de842e2d1 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -878,10 +878,10 @@ class ADContext { /// Saved for deletion during cleanup. SmallVector generatedFunctions; - /// List of associated function references, generated via - /// `emitAssociatedFunctionReference`. + /// List of derivative function references, generated via + /// `emitDerivativeFunctionReference`. /// Saved for deletion during cleanup. - SmallVector generatedAssociatedFunctionReferences; + SmallVector generatedDerivativeFunctionReferences; /// The AdditiveArithmetic protocol in the standard library. ProtocolDecl *additiveArithmeticProtocol = @@ -933,8 +933,8 @@ class ADContext { return generatedFunctions; } - SmallVector &getGeneratedAssociatedFunctionReferences() { - return generatedAssociatedFunctionReferences; + SmallVector &getGeneratedDerivativeFunctionReferences() { + return generatedDerivativeFunctionReferences; } ProtocolDecl *getAdditiveArithmeticProtocol() const { @@ -969,9 +969,9 @@ class ADContext { original->removeDifferentiableAttr(attr); } // Delete all references to generated functions. - for (auto assocFn : generatedAssociatedFunctionReferences) { + for (auto derivativeFn : generatedDerivativeFunctionReferences) { if (auto *fnRef = - peerThroughFunctionConversions(assocFn)) { + peerThroughFunctionConversions(derivativeFn)) { fnRef->replaceAllUsesWithUndef(); fnRef->eraseFromParent(); } @@ -1109,9 +1109,9 @@ class ADContext { DifferentiableFunctionInst *createDifferentiableFunction( SILBuilder &builder, SILLocation loc, AutoDiffIndexSubset *parameterIndices, SILValue original, - Optional> associatedFunctions = None) { + Optional> derivativeFunctions = None) { auto *dfi = builder.createDifferentiableFunction( - loc, parameterIndices, original, associatedFunctions); + loc, parameterIndices, original, derivativeFunctions); processedDifferentiableFunctionInsts.erase(dfi); return dfi; } @@ -1131,7 +1131,7 @@ class ADContext { DifferentiationInvoker invoker); /// Process the given `differentiable_function` instruction, filling in - /// missing associated functions if necessary. + /// missing derivative functions if necessary. bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); /// Fold `differentiable_function_extract` users of the given @@ -1145,33 +1145,33 @@ class ADContext { /// purposes. void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source); - /// Get or create an associated function index subset thunk from - /// `actualIndices` to `desiredIndices` for the given associated function + /// Get or create a derivative function index subset thunk from + /// `actualIndices` to `desiredIndices` for the given derivative function /// value and original function operand. /// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear - /// map returned by the associated function. + /// map returned by the derivative function. std::pair - getOrCreateSubsetParametersThunkForAssociatedFunction( - SILValue origFnOperand, SILValue assocFn, - AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices, + getOrCreateSubsetParametersThunkForDerivativeFunction( + SILValue origFnOperand, SILValue derivativeFn, + AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); - /// Get or create an associated function index subset thunk from - /// `actualIndices` to `desiredIndices` for the given associated function + /// Get or create a derivative function index subset thunk from + /// `actualIndices` to `desiredIndices` for the given derivative function /// value and original function operand. SILFunction *getOrCreateSubsetParametersThunkForLinearMap( - SILFunction *assocFn, CanSILFunctionType linearMapType, - CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind, + SILFunction *derivativeFn, CanSILFunctionType linearMapType, + CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); public: - /// Declare an external reference to an associated function of `original`, + /// Declare an external reference to a derivative function of `original`, /// given a `[differentiable]` attribute of `original` and the associated /// function kind. SILFunction * - declareExternalAssociatedFunction(SILFunction *original, + declareExternalDerivativeFunction(SILFunction *original, SILDifferentiableAttr *attr, StringRef name, - AutoDiffAssociatedFunctionKind kind); + AutoDiffDerivativeFunctionKind kind); template InFlightDiagnostic diagnose(SourceLoc loc, Diag diag, @@ -1682,17 +1682,17 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy)) return; - AutoDiffAssociatedFunctionKind assocFnKind(kind); - auto assocFnType = remappedOrigFnSubstTy->getAutoDiffAssociatedFunctionType( - parameters, source, assocFnKind, context.getTypeConverter(), + AutoDiffDerivativeFunctionKind derivativeFnKind(kind); + auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType( + parameters, source, derivativeFnKind, context.getTypeConverter(), LookUpConformanceInModule(derivative->getModule().getSwiftModule())); - auto assocFnResultTypes = - assocFnType->getAllResultsType().castTo(); - assocFnResultTypes->getElement(assocFnResultTypes->getElements().size() - 1); + auto derivativeFnResultTypes = + derivativeFnType->getAllResultsType().castTo(); + derivativeFnResultTypes->getElement(derivativeFnResultTypes->getElements().size() - 1); auto linearMapSILType = SILType::getPrimitiveObjectType( - assocFnResultTypes - ->getElement(assocFnResultTypes->getElements().size() - 1) + derivativeFnResultTypes + ->getElement(derivativeFnResultTypes->getElements().size() - 1) .getType() ->getCanonicalType()); addLinearMapDecl(ai, linearMapSILType); @@ -1700,21 +1700,21 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, void LinearMapInfo::generateDifferentiationDataStructures( ADContext &context, const SILAutoDiffIndices &indices, - SILFunction *assocFn) { + SILFunction *derivativeFn) { auto &astCtx = original->getASTContext(); auto *loopAnalysis = context.getPassManager().getAnalysis(); auto *loopInfo = loopAnalysis->get(original); - // Get the associated function generic signature. - CanGenericSignature assocFnGenSig = nullptr; - if (auto *assocFnGenEnv = assocFn->getGenericEnvironment()) - assocFnGenSig = - assocFnGenEnv->getGenericSignature()->getCanonicalSignature(); + // Get the derivative function generic signature. + CanGenericSignature derivativeFnGenSig = nullptr; + if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment()) + derivativeFnGenSig = + derivativeFnGenEnv->getGenericSignature()->getCanonicalSignature(); // Create linear map struct for each original block. for (auto &origBB : *original) { auto *linearMapStruct = - createLinearMapStruct(&origBB, indices, assocFnGenSig); + createLinearMapStruct(&origBB, indices, derivativeFnGenSig); linearMapStructs.insert({&origBB, linearMapStruct}); } @@ -1731,7 +1731,7 @@ void LinearMapInfo::generateDifferentiationDataStructures( } for (auto &origBB : *original) { auto *traceEnum = - createBranchingTraceDecl(&origBB, indices, assocFnGenSig, loopInfo); + createBranchingTraceDecl(&origBB, indices, derivativeFnGenSig, loopInfo); branchingTraceDecls.insert({&origBB, traceEnum}); if (origBB.isEntry()) continue; @@ -1808,7 +1808,7 @@ class DifferentiableActivityCollection { PostDominanceInfo *postDomInfo; DifferentiableActivityInfo &getActivityInfo( - GenericSignature *assocGenSig, AutoDiffAssociatedFunctionKind kind) { + GenericSignature *assocGenSig, AutoDiffDerivativeFunctionKind kind) { auto activityInfoLookup = activityInfoMap.find(assocGenSig); if (activityInfoLookup != activityInfoMap.end()) return activityInfoLookup->getSecond(); @@ -2530,9 +2530,9 @@ static SubstitutionMap getSubstitutionMap( return substMap; } -/// Emits a reference to an associated function of `original`, differentiated +/// Emits a reference to a derivative function of `original`, differentiated /// with respect to a superset of `desiredIndices`. Returns the `SILValue` for -/// the associated function and the actual indices that the associated function +/// the derivative function and the actual indices that the derivative function /// is with respect to. /// /// Returns `None` on failure, signifying that a diagnostic has been emitted. @@ -2544,9 +2544,9 @@ static SubstitutionMap getSubstitutionMap( /// /// FIXME: This is too complicated and needs to be rewritten. static Optional> -emitAssociatedFunctionReference( +emitDerivativeFunctionReference( ADContext &context, SILBuilder &builder, SILAutoDiffIndices desiredIndices, - AutoDiffAssociatedFunctionKind kind, SILValue original, + AutoDiffDerivativeFunctionKind kind, SILValue original, DifferentiationInvoker invoker, SmallVectorImpl &newBuffersToDealloc) { @@ -2554,15 +2554,15 @@ emitAssociatedFunctionReference( // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind matches // the given kind and desired differentiation parameter indices, simply - // extract the associated function of its function operand, retain the - // associated function, and return it. + // extract the derivative function of its function operand, retain the + // derivative function, and return it. if (auto *inst = original->getDefiningInstruction()) if (auto *dfei = dyn_cast(inst)) if (dfei->getExtractee() == DifferentiableFunctionExtractee::Original) functionSource = dfei->getFunctionOperand(); // If `functionSource` is a `@differentiable` function, just extract the - // associated function. + // derivative function. if (auto diffableFnType = functionSource->getType().castTo()) { if (diffableFnType->isDifferentiable()) { @@ -2576,13 +2576,13 @@ emitAssociatedFunctionReference( } auto borrowedDiffFunc = builder.emitBeginBorrowOperation( functionSource.getLoc(), functionSource); - SILValue assocFn = builder.createDifferentiableFunctionExtract( + SILValue derivativeFn = builder.createDifferentiableFunctionExtract( borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc); - assocFn = - builder.emitCopyValueOperation(functionSource.getLoc(), assocFn); + derivativeFn = + builder.emitCopyValueOperation(functionSource.getLoc(), derivativeFn); builder.emitEndBorrowOperation(functionSource.getLoc(), borrowedDiffFunc); SILAutoDiffIndices indices(0, desiredIndices.parameters); - return std::make_pair(assocFn, indices); + return std::make_pair(derivativeFn, indices); } } @@ -2657,18 +2657,18 @@ emitAssociatedFunctionReference( if (context.processDifferentiableAttribute( originalFn, minimalAttr, invoker)) return None; - SILFunction *assocFn = nullptr; + SILFunction *derivativeFn = nullptr; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: assert(!minimalAttr->getJVPName().empty() && "Expected JVP name"); - assocFn = context.getModule().lookUpFunction(minimalAttr->getJVPName()); + derivativeFn = context.getModule().lookUpFunction(minimalAttr->getJVPName()); break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: assert(!minimalAttr->getVJPName().empty() && "Expected VJP name"); - assocFn = context.getModule().lookUpFunction(minimalAttr->getVJPName()); + derivativeFn = context.getModule().lookUpFunction(minimalAttr->getVJPName()); break; } - auto *assocFnRef = builder.createFunctionRef(loc, assocFn); + auto *derivativeFnRef = builder.createFunctionRef(loc, derivativeFn); // FIXME(TF-201): Handle direct differentiation of reabstraction thunks. // Tentative solution: clone a new reabstraction thunk where function // argument has a `@differentiable` function type. @@ -2676,9 +2676,9 @@ emitAssociatedFunctionReference( // Handle here. } auto convertedRef = reapplyFunctionConversion( - assocFnRef, originalFRI, original, builder, loc, + derivativeFnRef, originalFRI, original, builder, loc, newBuffersToDealloc, - assocFn->getLoweredFunctionType()->getGenericSignature()); + derivativeFn->getLoweredFunctionType()->getGenericSignature()); return std::make_pair(convertedRef, minimalAttr->getIndices()); } @@ -2711,17 +2711,17 @@ emitAssociatedFunctionReference( diag::autodiff_member_subset_indices_not_differentiable); return None; } - // Emit a `witness_method` instruction for the associated function. + // Emit a `witness_method` instruction for the derivative function. auto originalType = witnessMethod->getType().castTo(); - auto assocType = originalType->getAutoDiffAssociatedFunctionType( + auto assocType = originalType->getAutoDiffDerivativeFunctionType( minimalIndices.parameters, minimalIndices.source, kind, context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); - auto *autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( + auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( kind, minimalAttr->getParameterIndices(), context.getASTContext()); auto *ref = builder.createWitnessMethod( loc, witnessMethod->getLookupType(), witnessMethod->getConformance(), - requirementDeclRef.asAutoDiffAssociatedFunction(autoDiffFuncId), + requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), SILType::getPrimitiveObjectType(assocType)); auto convertedRef = reapplyFunctionConversion(ref, witnessMethod, original, builder, loc, @@ -2758,18 +2758,18 @@ emitAssociatedFunctionReference( diag::autodiff_member_subset_indices_not_differentiable); return None; } - // Emit a `class_method` instruction for the associated function. + // Emit a `class_method` instruction for the derivative function. auto originalType = classMethodInst->getType().castTo(); - auto assocType = originalType->getAutoDiffAssociatedFunctionType( + auto assocType = originalType->getAutoDiffDerivativeFunctionType( minimalIndices.parameters, minimalIndices.source, kind, context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); - auto *autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( + auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( kind, minimalAttr->getParameterIndices(), context.getASTContext()); auto *ref = builder.createClassMethod( loc, classMethodInst->getOperand(), - methodDeclRef.asAutoDiffAssociatedFunction(autoDiffFuncId), + methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), SILType::getPrimitiveObjectType(assocType)); auto convertedRef = reapplyFunctionConversion(ref, classMethodInst, original, builder, loc, @@ -3316,7 +3316,7 @@ class VJPEmitter final auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( vjp->getLoweredFunctionType()->getGenericSignature(), - AutoDiffAssociatedFunctionKind::VJP); + AutoDiffDerivativeFunctionKind::VJP); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; @@ -4238,7 +4238,7 @@ class JVPEmitter final auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( jvp->getLoweredFunctionType()->getGenericSignature(), - AutoDiffAssociatedFunctionKind::JVP); + AutoDiffDerivativeFunctionKind::JVP); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; @@ -7831,26 +7831,26 @@ bool VJPEmitter::run() { //===----------------------------------------------------------------------===// SILFunction * -ADContext::declareExternalAssociatedFunction( +ADContext::declareExternalDerivativeFunction( SILFunction *original, SILDifferentiableAttr *attr, StringRef name, - AutoDiffAssociatedFunctionKind kind) { + AutoDiffDerivativeFunctionKind kind) { auto &module = getModule(); auto &indices = attr->getIndices(); auto originalTy = original->getLoweredFunctionType(); auto originalLoc = original->getLocation(); auto assocGenSig = getDerivativeGenericSignature(attr, original); - auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType( + auto derivativeFnTy = originalTy->getAutoDiffDerivativeFunctionType( indices.parameters, indices.source, kind, module.Types, LookUpConformanceInModule(module.getSwiftModule()), assocGenSig); SILOptFunctionBuilder fb(getTransform()); // Create external function declaration. - auto *assocFn = fb.createFunction( - SILLinkage::PublicExternal, name, assocFnTy, + auto *derivativeFn = fb.createFunction( + SILLinkage::PublicExternal, name, derivativeFnTy, /*genericEnv*/ nullptr, originalLoc, original->isBare(), IsNotTransparent, original->isSerialized(), original->isDynamicallyReplaceable()); // Note: Setting debug scope prevents crashes during later transforms. - assocFn->setDebugScope(new (module) SILDebugScope(originalLoc, assocFn)); - return assocFn; + derivativeFn->setDebugScope(new (module) SILDebugScope(originalLoc, derivativeFn)); + return derivativeFn; } static SILFunction *createEmptyVJP( @@ -7869,8 +7869,8 @@ static SILFunction *createEmptyVJP( // === Create an empty VJP. === Mangle::ASTMangler mangler; auto vjpName = original->getASTContext().getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - original->getName(), AutoDiffAssociatedFunctionKind::VJP, indices)) + mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::VJP, indices)) .str(); auto vjpGenericSig = getDerivativeGenericSignature(attr, original); @@ -7883,13 +7883,13 @@ static SILFunction *createEmptyVJP( auto *vjpGenericEnv = vjpGenericSig ? vjpGenericSig->getGenericEnvironment() : nullptr; - auto vjpType = originalTy->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, AutoDiffAssociatedFunctionKind::VJP, + auto vjpType = originalTy->getAutoDiffDerivativeFunctionType( + indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::VJP, module.Types, LookUpConformanceInModule(module.getSwiftModule()), vjpGenericSig); SILOptFunctionBuilder fb(context.getTransform()); - auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( + auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( original->getLinkage(), isExported); auto *vjp = fb.createFunction(linkage, vjpName, vjpType, vjpGenericEnv, original->getLocation(), original->isBare(), @@ -7919,8 +7919,8 @@ static SILFunction *createEmptyJVP( // === Create an empty JVP. === Mangle::ASTMangler mangler; auto jvpName = original->getASTContext().getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - original->getName(), AutoDiffAssociatedFunctionKind::JVP, indices)) + mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::JVP, indices)) .str(); auto jvpGenericSig = getDerivativeGenericSignature(attr, original); @@ -7933,13 +7933,13 @@ static SILFunction *createEmptyJVP( auto *jvpGenericEnv = jvpGenericSig ? jvpGenericSig->getGenericEnvironment() : nullptr; - auto jvpType = originalTy->getAutoDiffAssociatedFunctionType( + auto jvpType = originalTy->getAutoDiffDerivativeFunctionType( indices.parameters, indices.source, - AutoDiffAssociatedFunctionKind::JVP, module.Types, + AutoDiffDerivativeFunctionKind::JVP, module.Types, LookUpConformanceInModule(module.getSwiftModule()), jvpGenericSig); SILOptFunctionBuilder fb(context.getTransform()); - auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( + auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( original->getLinkage(), isExported); auto *jvp = fb.createFunction(linkage, jvpName, jvpType, jvpGenericEnv, original->getLocation(), original->isBare(), @@ -7968,21 +7968,21 @@ bool ADContext::processDifferentiableAttribute( } else if (original->isExternalDeclaration()) { Mangle::ASTMangler mangler; jvpName = original->getASTContext().getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - original->getName(), AutoDiffAssociatedFunctionKind::JVP, + mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::JVP, attr->getIndices())).str(); } if (!jvpName.empty()) { jvp = module.lookUpFunction(jvpName); if (!jvp) - jvp = declareExternalAssociatedFunction( - original, attr, jvpName, AutoDiffAssociatedFunctionKind::JVP); + jvp = declareExternalDerivativeFunction( + original, attr, jvpName, AutoDiffDerivativeFunctionKind::JVP); attr->setJVPName(jvpName); } - // If differentiation is triggered by `[differentiable]`, associated function + // If differentiation is triggered by `[differentiable]`, derivative function // should share linkage of original function. - auto isAssocFnExported = + auto isDerivativeFnExported = invoker.getKind() == DifferentiationInvoker::Kind::SILDifferentiableAttribute; @@ -7996,15 +7996,15 @@ bool ADContext::processDifferentiableAttribute( } else if (original->isExternalDeclaration()) { Mangle::ASTMangler mangler; vjpName = original->getASTContext().getIdentifier( - mangler.mangleAutoDiffAssociatedFunctionHelper( - original->getName(), AutoDiffAssociatedFunctionKind::VJP, + mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::VJP, attr->getIndices())).str(); } if (!vjpName.empty()) { vjp = module.lookUpFunction(vjpName); if (!vjp) - vjp = declareExternalAssociatedFunction( - original, attr, vjpName, AutoDiffAssociatedFunctionKind::VJP); + vjp = declareExternalDerivativeFunction( + original, attr, vjpName, AutoDiffDerivativeFunctionKind::VJP); attr->setVJPName(vjpName); } @@ -8017,7 +8017,7 @@ bool ADContext::processDifferentiableAttribute( diagnoseUnsupportedControlFlow(*this, original, invoker))) return true; - jvp = createEmptyJVP(*this, original, attr, isAssocFnExported); + jvp = createEmptyJVP(*this, original, attr, isDerivativeFnExported); getGeneratedFunctions().push_back(jvp); // For now, only do JVP generation if the flag is enabled and if custom VJP @@ -8076,7 +8076,7 @@ bool ADContext::processDifferentiableAttribute( diagnoseUnsupportedControlFlow(*this, original, invoker)) return true; - vjp = createEmptyVJP(*this, original, attr, isAssocFnExported); + vjp = createEmptyVJP(*this, original, attr, isDerivativeFnExported); getGeneratedFunctions().push_back(vjp); VJPEmitter emitter(*this, original, attr, vjp, invoker); return emitter.run(); @@ -8101,7 +8101,7 @@ class Differentiation : public SILModuleTransform { SILFunction * ADContext::getOrCreateSubsetParametersThunkForLinearMap( SILFunction *parentThunk, CanSILFunctionType linearMapType, - CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind, + CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) { LLVM_DEBUG(getADDebugStream() << "Getting a subset parameters thunk for " << linearMapType @@ -8117,10 +8117,10 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( // TODO(TF-685): Use more principled mangling for thunks. std::string thunkName; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: thunkName = "differential"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: thunkName = "pullback"; } Mangle::ASTMangler mangler; @@ -8218,7 +8218,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( // - Thunk arguments (when parameter index is in both desired and actual // indices). // - Zeros (when parameter is not in desired indices). - case AutoDiffAssociatedFunctionKind::JVP: { + case AutoDiffDerivativeFunctionKind::JVP: { // Forward all indirect results. arguments.append(thunk->getIndirectResults().begin(), thunk->getIndirectResults().end()); @@ -8248,7 +8248,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( // actual indices). // - Zeros (when parameter is not in desired indices). // - All actual arguments. - case AutoDiffAssociatedFunctionKind::VJP: { + case AutoDiffDerivativeFunctionKind::VJP: { auto toIndirectResultsIter = thunk->getIndirectResults().begin(); auto useNextResult = [&]() { arguments.push_back(*toIndirectResultsIter++); @@ -8285,7 +8285,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( // If differential thunk, deallocate local allocations and directly return // `apply` result. - if (kind == AutoDiffAssociatedFunctionKind::JVP) { + if (kind == AutoDiffDerivativeFunctionKind::JVP) { for (auto *alloc : reversed(localAllocations)) builder.createDeallocStack(loc, alloc); builder.createReturn(loc, ai); @@ -8329,13 +8329,13 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( } std::pair -ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( - SILValue origFnOperand, SILValue assocFn, - AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices, +ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction( + SILValue origFnOperand, SILValue derivativeFn, + AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for associated function " - << assocFn << " of the original function " << origFnOperand + << "Getting a subset parameters thunk for derivative function " + << derivativeFn << " of the original function " << origFnOperand << " from " << actualIndices << " to " << desiredIndices << '\n'); auto origFnType = origFnOperand->getType().castTo(); @@ -8343,30 +8343,30 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); // Compute target type for thunking. - auto assocFnType = assocFn->getType().castTo(); - auto targetType = origFnType->getAutoDiffAssociatedFunctionType( + auto derivativeFnType = derivativeFn->getType().castTo(); + auto targetType = origFnType->getAutoDiffDerivativeFunctionType( desiredIndices.parameters, desiredIndices.source, kind, module.Types, lookupConformance); - auto *caller = assocFn->getFunction(); + auto *caller = derivativeFn->getFunction(); if (targetType->hasArchetype()) { auto substTargetType = caller->mapTypeIntoContext( targetType->mapTypeOutOfContext())->getCanonicalType(); targetType = SILType::getPrimitiveObjectType(substTargetType) .castTo(); } - assert(assocFnType->getNumParameters() == targetType->getNumParameters()); - assert(assocFnType->getNumResults() == targetType->getNumResults()); + assert(derivativeFnType->getNumParameters() == targetType->getNumParameters()); + assert(derivativeFnType->getNumResults() == targetType->getNumResults()); // Build thunk type. SubstitutionMap interfaceSubs; GenericEnvironment *genericEnv = nullptr; auto thunkType = buildThunkType( - assocFn->getFunction(), assocFnType, targetType, genericEnv, + derivativeFn->getFunction(), derivativeFnType, targetType, genericEnv, interfaceSubs, /*withoutActuallyEscaping*/ false, DifferentiationThunkKind::IndexSubset); // FIXME: The logic for resolving `assocRef` does not reapply function - // conversions, which is problematic if `assocFn` is a `partial_apply` + // conversions, which is problematic if `derivativeFn` is a `partial_apply` // instruction. StringRef origName; if (auto *origFnRef = @@ -8381,15 +8381,15 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( // TODO(TF-685): Use more principled mangling for thunks. std::string thunkName; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: thunkName = "jvp"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: thunkName = "vjp"; } Mangle::ASTMangler mangler; auto fromInterfaceType = - assocFnType->mapTypeOutOfContext()->getCanonicalType(); + derivativeFnType->mapTypeOutOfContext()->getCanonicalType(); auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType(); CanType dynamicSelfType; thunkName = "AD__orig_" + origName.str() + "_" + @@ -8415,25 +8415,25 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( createEntryArguments(thunk); SubstitutionMap assocSubstMap; - if (auto *partialApply = dyn_cast(assocFn)) + if (auto *partialApply = dyn_cast(derivativeFn)) assocSubstMap = partialApply->getSubstitutionMap(); // FIXME: The logic for resolving `assocRef` does not reapply function - // conversions, which is problematic if `assocFn` is a `partial_apply` + // conversions, which is problematic if `derivativeFn` is a `partial_apply` // instruction. SILValue assocRef; - if (auto *assocFnRef = - peerThroughFunctionConversions(assocFn)) { - auto *assoc = assocFnRef->getReferencedFunctionOrNull(); + if (auto *derivativeFnRef = + peerThroughFunctionConversions(derivativeFn)) { + auto *assoc = derivativeFnRef->getReferencedFunctionOrNull(); assocRef = builder.createFunctionRef(loc, assoc); } else if (auto *assocMethodInst = - peerThroughFunctionConversions(assocFn)) { + peerThroughFunctionConversions(derivativeFn)) { assocRef = builder.createWitnessMethod( loc, assocMethodInst->getLookupType(), assocMethodInst->getConformance(), assocMethodInst->getMember(), thunk->mapTypeIntoContext(assocMethodInst->getType())); } else if (auto *assocMethodInst = - peerThroughFunctionConversions(assocFn)) { + peerThroughFunctionConversions(derivativeFn)) { auto classOperand = thunk->getArgumentsWithoutIndirectResults().back(); auto classOperandType = assocMethodInst->getOperand()->getType(); assert(classOperand->getType() == classOperandType); @@ -8441,15 +8441,15 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( loc, classOperand, assocMethodInst->getMember(), thunk->mapTypeIntoContext(assocMethodInst->getType())); } - assert(assocRef && "Expected associated function to be resolved"); + assert(assocRef && "Expected derivative function to be resolved"); assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap()); - assocFnType = assocRef->getType().castTo(); + derivativeFnType = assocRef->getType().castTo(); SmallVector arguments; arguments.append(thunk->getArguments().begin(), thunk->getArguments().end()); - assert(arguments.size() == assocFnType->getNumParameters() + - assocFnType->getNumIndirectFormalResults()); + assert(arguments.size() == derivativeFnType->getNumParameters() + + derivativeFnType->getNumIndirectFormalResults()); auto *apply = builder.createApply( loc, assocRef, assocSubstMap, arguments, /*isNonThrowing*/ false); @@ -8576,28 +8576,28 @@ SILValue ADContext::promoteToDifferentiableFunction( } SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices); - SmallVector assocFns; + SmallVector derivativeFns; SmallVector newBuffersToDealloc; - for (auto assocFnKind : {AutoDiffAssociatedFunctionKind::JVP, - AutoDiffAssociatedFunctionKind::VJP}) { - auto assocFnAndIndices = emitAssociatedFunctionReference( - *this, builder, desiredIndices, assocFnKind, origFnOperand, invoker, + for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP, + AutoDiffDerivativeFunctionKind::VJP}) { + auto derivativeFnAndIndices = emitDerivativeFunctionReference( + *this, builder, desiredIndices, derivativeFnKind, origFnOperand, invoker, newBuffersToDealloc); // Show an error at the operator, highlight the argument, and show a note // at the definition site of the argument. - if (!assocFnAndIndices) + if (!derivativeFnAndIndices) return nullptr; - auto assocFn = assocFnAndIndices->first; - getGeneratedAssociatedFunctionReferences().push_back(assocFn); + auto derivativeFn = derivativeFnAndIndices->first; + getGeneratedDerivativeFunctionReferences().push_back(derivativeFn); // If desired indices are a subset of actual indices, create a "subset - // indices thunk" and destroy the emitted associated function reference. + // indices thunk" and destroy the emitted derivative function reference. // - For JVPs: the thunked JVP returns a differential taking fewer // parameters (using `.zero` for the dropped parameters). // - For VJPs: the thunked VJP returns a pullback that drops the unused // tangent values. - auto actualIndices = assocFnAndIndices->second; + auto actualIndices = derivativeFnAndIndices->second; // NOTE: `desiredIndices` may come from a partially-applied function and // have smaller capacity than `actualIndices`. We expect this logic to go // away when we support `@differentiable` partial apply. @@ -8606,9 +8606,9 @@ SILValue ADContext::promoteToDifferentiableFunction( getASTContext(), actualIndices.parameters->getCapacity()); if (actualIndices.source != desiredIndices.source || !actualIndices.parameters->equals(extendedDesiredIndices)) { - // Destroy the already emitted associated function reference because it + // Destroy the already emitted derivative function reference because it // is no longer used. - builder.emitDestroyValueOperation(loc, assocFn); + builder.emitDestroyValueOperation(loc, derivativeFn); // Check if underlying original function reference has been partially // applied with arguments. If so, produce an error: parameter subset // thunks do not yet support this case because partially applied arguments @@ -8633,42 +8633,42 @@ SILValue ADContext::promoteToDifferentiableFunction( SILFunction *thunk; SubstitutionMap interfaceSubs; std::tie(thunk, interfaceSubs) = - getOrCreateSubsetParametersThunkForAssociatedFunction( - origFnOperand, assocFn, assocFnKind, desiredIndices, + getOrCreateSubsetParametersThunkForDerivativeFunction( + origFnOperand, derivativeFn, derivativeFnKind, desiredIndices, actualIndices); auto *thunkFRI = builder.createFunctionRef(loc, thunk); if (auto genSig = thunk->getLoweredFunctionType()->getGenericSignature()) { - assocFn = builder.createPartialApply( + derivativeFn = builder.createPartialApply( loc, thunkFRI, interfaceSubs, {}, ParameterConvention::Direct_Guaranteed); } else { - assocFn = thunkFRI; + derivativeFn = thunkFRI; } } - auto expectedAssocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( - parameterIndices, resultIndex, assocFnKind, getTypeConverter(), + auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( + parameterIndices, resultIndex, derivativeFnKind, getTypeConverter(), LookUpConformanceInModule(getModule().getSwiftModule())); - // If `assocFn` is `@convention(thin)` but is expected to be + // If `derivativeFn` is `@convention(thin)` but is expected to be // `@convention(thick)`, emit a `thin_to_thick` instruction. - if (expectedAssocFnTy->getRepresentation() + if (expectedDerivativeFnTy->getRepresentation() == SILFunctionTypeRepresentation::Thick && - assocFn->getType().castTo()->getRepresentation() + derivativeFn->getType().castTo()->getRepresentation() == SILFunctionTypeRepresentation::Thin) { - assocFn = builder.createThinToThickFunction( - loc, assocFn, SILType::getPrimitiveObjectType(expectedAssocFnTy)); + derivativeFn = builder.createThinToThickFunction( + loc, derivativeFn, SILType::getPrimitiveObjectType(expectedDerivativeFnTy)); } - assocFns.push_back(assocFn); + derivativeFns.push_back(derivativeFn); } - // Deallocate temporary buffers used for creating associated functions. + // Deallocate temporary buffers used for creating derivative functions. for (auto *buf : reversed(newBuffersToDealloc)) builder.createDeallocStack(loc, buf); auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand); auto *newDFI = createDifferentiableFunction( builder, loc, parameterIndices, origFnCopy, - std::make_pair(assocFns[0], assocFns[1])); + std::make_pair(derivativeFns[0], derivativeFns[1])); resultIndices[dfi] = resultIndex; getDifferentiableFunctionInsts().push_back(dfi); @@ -8701,10 +8701,10 @@ void ADContext::foldDifferentiableFunctionExtraction( dfei->eraseFromParent(); continue; } - // Fold associated function extractors. - auto assocFnValue = - source->getDerivativeFunction(dfei->getAssociatedFunctionKind()); - dfei->replaceAllUsesWith(assocFnValue); + // Fold derivative function extractors. + auto derivativeFnValue = + source->getDerivativeFunction(dfei->getDerivativeFunctionKind()); + dfei->replaceAllUsesWith(derivativeFnValue); dfei->eraseFromParent(); } // If the `differentiable_function` instruction has no remaining uses, erase diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index e5fc655607e40..c5665f7a8f429 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -621,7 +621,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { // Now that this member is in the `TangentVector` type, it should be marked // `@differentiable` so that the differentiation transform will synthesize - // associated functions for it. We only add this to public stored + // derivative functions for it. We only add this to public stored // properties, because their access outside the module will go through a // call to the getter. if (member->getEffectiveAccess() > AccessLevel::Internal && diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index 6aaa58ac80c81..e2e6b51f8ad72 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -2248,7 +2248,7 @@ class VarDeclUsageChecker : public ASTWalker { }; /// An AST walker that determines the underlying type of an opaque return decl -/// from its associated function body. +/// from its derivative function body. class OpaqueUnderlyingTypeChecker : public ASTWalker { TypeChecker &TC; AbstractFunctionDecl *Implementation; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index d809a979f6592..c40645e51b099 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2727,7 +2727,7 @@ TypeChecker::inferDifferentiableParameters( } // SWIFT_ENABLE_TENSORFLOW -static FuncDecl *resolveAutoDiffAssociatedFunction( +static FuncDecl *resolveAutoDiffDerivativeFunction( TypeChecker &TC, DeclNameWithLoc specifier, AbstractFunctionDecl *original, Type expectedTy, std::function isValid) { auto nameLoc = specifier.Loc.getBaseNameLoc(); @@ -2750,9 +2750,9 @@ static FuncDecl *resolveAutoDiffAssociatedFunction( specifier.Name); }; - // Returns true if the original function and associated function candidate are + // Returns true if the original function and derivative function candidate are // defined in compatible type contexts. If the original function and the - // associated function have different parents, or if they both have no type + // derivative function have different parents, or if they both have no type // context and are in different modules, return false. std::function hasValidTypeContext = [&](FuncDecl *func) { // Check if both functions are top-level. @@ -2775,7 +2775,7 @@ static FuncDecl *resolveAutoDiffAssociatedFunction( }; // If the original function is exported (i.e. it is public or - // @usableFromInline), then the associated functions must also be exported. + // @usableFromInline), then the derivative functions must also be exported. // Returns true on error. auto checkAccessControl = [&](FuncDecl *func) { if (!isABIPublic(original)) @@ -3330,7 +3330,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // Handle 'where' clause, if it exists. // - Resolve attribute where clause requirements and store in the attribute // for serialization. - // - Compute generic signature for autodiff associated functions based on + // - Compute generic signature for autodiff derivative functions based on // the original function's generate signature and the attribute's where // clause requirements. GenericSignature *whereClauseGenSig = nullptr; @@ -3364,7 +3364,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return; } - // Build a new generic signature for autodiff associated functions. + // Build a new generic signature for autodiff derivative functions. GenericSignatureBuilder builder(ctx); // Add the original function's generic signature. builder.addGenericSignature(originalGenSig); @@ -3476,9 +3476,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // Resolve the JVP declaration, if it exists. if (attr->getJVP()) { AnyFunctionType *expectedJVPFnTy = - originalFnTy->getAutoDiffAssociatedFunctionType( + originalFnTy->getAutoDiffDerivativeFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::JVP, lookupConformance, + AutoDiffDerivativeFunctionKind::JVP, lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); auto isValidJVP = [&](FuncDecl *jvpCandidate) { @@ -3488,7 +3488,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { jvpCandidate->getInterfaceType()->getCanonicalType()); }; - FuncDecl *jvp = resolveAutoDiffAssociatedFunction( + FuncDecl *jvp = resolveAutoDiffDerivativeFunction( TC, attr->getJVP().getValue(), original, expectedJVPFnTy, isValidJVP); if (!jvp) { @@ -3502,9 +3502,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // Resolve the VJP declaration, if it exists. if (attr->getVJP()) { AnyFunctionType *expectedVJPFnTy = - originalFnTy->getAutoDiffAssociatedFunctionType( + originalFnTy->getAutoDiffDerivativeFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, - AutoDiffAssociatedFunctionKind::VJP, lookupConformance, + AutoDiffDerivativeFunctionKind::VJP, lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); auto isValidVJP = [&](FuncDecl *vjpCandidate) { @@ -3514,7 +3514,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { vjpCandidate->getInterfaceType()->getCanonicalType()); }; - FuncDecl *vjp = resolveAutoDiffAssociatedFunction( + FuncDecl *vjp = resolveAutoDiffDerivativeFunction( TC, attr->getVJP().getValue(), original, expectedVJPFnTy, isValidVJP); if (!vjp) { @@ -3593,8 +3593,8 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { } auto valueResultElt = derivativeResultTupleType->getElement(0); auto funcResultElt = derivativeResultTupleType->getElement(1); - // Get derivative kind and associated function identifier. - AutoDiffAssociatedFunctionKind kind; + // Get derivative kind and derivative function identifier. + AutoDiffDerivativeFunctionKind kind; if (valueResultElt.getName().str() != "value") { TC.diagnose(attr->getLocation(), diag::differentiating_attr_invalid_result_tuple_value_label); @@ -3602,9 +3602,9 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { return; } if (funcResultElt.getName().str() == "differential") { - kind = AutoDiffAssociatedFunctionKind::JVP; + kind = AutoDiffDerivativeFunctionKind::JVP; } else if (funcResultElt.getName().str() == "pullback") { - kind = AutoDiffAssociatedFunctionKind::VJP; + kind = AutoDiffDerivativeFunctionKind::VJP; } else { TC.diagnose(attr->getLocation(), diag::differentiating_attr_invalid_result_tuple_func_label); @@ -3773,7 +3773,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // Compute expected differential/pullback type. auto funcEltType = funcResultElt.getType(); Type expectedFuncEltType; - if (kind == AutoDiffAssociatedFunctionKind::JVP) { + if (kind == AutoDiffDerivativeFunctionKind::JVP) { auto diffParams = map>( diffParamElts, [&](TupleTypeElt elt) { return AnyFunctionType::Param(elt.getType()); @@ -3834,10 +3834,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { /*vjp*/ None, derivative->getGenericSignature()); switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: da->setJVPFunction(derivative); break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: da->setVJPFunction(derivative); break; } @@ -3861,7 +3861,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // `@differentiating` attribute. Otherwise, register the derivative in the // `@differentiable` attribute. switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffDerivativeFunctionKind::JVP: // If there's a different registered derivative, emit an error. if ((da->getJVP() && da->getJVP()->Name.getBaseName() != derivative->getBaseName()) || @@ -3873,7 +3873,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { } da->setJVPFunction(derivative); break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffDerivativeFunctionKind::VJP: // If there's a different registered derivative, emit an error. if ((da->getVJP() && da->getVJP()->Name.getBaseName() != derivative->getBaseName()) || diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 9b7becf61f3b3..8cade41303c71 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -650,7 +650,7 @@ static bool overridesDifferentiableAttribute(ValueDecl *derivedDecl, std::string baseDAString; llvm::raw_string_ostream stream(baseDAString); baseDA->print(stream, derivedDecl, omitWrtClause, - /*omitAssociatedFunctions*/ true); + /*omitDerivativeFunctions*/ true); diags.diagnose( derivedDecl, diag::overriding_decl_missing_differentiable_attr, StringRef(stream.str()).trim()); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index a0d7b2adc8f41..fb1bf4d75b880 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -2263,7 +2263,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, std::string reqDiffAttrString; llvm::raw_string_ostream stream(reqDiffAttrString); reqAttr->print(stream, req, omitWrtClause, - /*omitAssociatedFunctions*/ true); + /*omitDerivativeFunctions*/ true); diags.diagnose(match.Witness, diag::protocol_witness_missing_differentiable_attr, StringRef(stream.str()).trim()); diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index 083f93f7b5c64..9dee6ec6765bc 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -237,14 +237,14 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) { // function with a `@differentiable` attribute. auto diffAttrs = AFD->getAttrs().getAttributes(); for (auto *DA : diffAttrs) { - auto *jvpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), + auto *jvpId = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, DA->getParameterIndices(), AFD->getASTContext()); - addSymbol(SILDeclRef(AFD).asAutoDiffAssociatedFunction(jvpId)); - auto *vjpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), + addSymbol(SILDeclRef(AFD).asAutoDiffDerivativeFunction(jvpId)); + auto *vjpId = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, DA->getParameterIndices(), AFD->getASTContext()); - addSymbol(SILDeclRef(AFD).asAutoDiffAssociatedFunction(vjpId)); + addSymbol(SILDeclRef(AFD).asAutoDiffDerivativeFunction(vjpId)); } visitDefaultArguments(AFD, AFD->getParameters()); @@ -300,16 +300,16 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) { // var/subscript with a `@differentiable` attribute. auto diffAttrs = ASD->getAttrs().getAttributes(); for (auto *DA : diffAttrs) { - auto *jvpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), + auto *jvpId = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, DA->getParameterIndices(), ASD->getASTContext()); addSymbol(SILDeclRef(ASD->getAccessor(AccessorKind::Get)) - .asAutoDiffAssociatedFunction(jvpId)); - auto *vjpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), + .asAutoDiffDerivativeFunction(jvpId)); + auto *vjpId = AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, DA->getParameterIndices(), ASD->getASTContext()); addSymbol(SILDeclRef(ASD->getAccessor(AccessorKind::Get)) - .asAutoDiffAssociatedFunction(vjpId)); + .asAutoDiffDerivativeFunction(vjpId)); } // Explicitly look at each accessor here: see visitAccessorDecl. diff --git a/test/AutoDiff/differentiable_attr_access_control.swift b/test/AutoDiff/differentiable_attr_access_control.swift index c67866bda5888..1337edea8fb15 100644 --- a/test/AutoDiff/differentiable_attr_access_control.swift +++ b/test/AutoDiff/differentiable_attr_access_control.swift @@ -20,6 +20,6 @@ private func foo3(_ x: Float) -> Float { return 1 } private func dfoo3(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) } // Error: vjp not exported. -@differentiable(vjp: dbar1) // expected-error {{associated differentiation function 'dbar1' is required to either be public or @usableFromInline because the original function 'bar1' is public or @usableFromInline}} +@differentiable(vjp: dbar1) // expected-error {{derivative function 'dbar1' is required to either be public or '@usableFromInline' because the original function 'bar1' is public or '@usableFromInline'}} public func bar1(_ x: Float) -> Float { return 1 } private func dbar1(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) } diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index ff640ce45a71f..f982e17acf331 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -245,7 +245,7 @@ public func TF_688( reduction(x) } -// TF-697: Test generic requirements of generated AD associated function. +// TF-697: Test generic requirements of generated derivative function. protocol TF_697_Module: Differentiable { associatedtype Input associatedtype Output: Differentiable