From bb9c2fe6a2847e7193dc687639a8e19a61550167 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 7 Oct 2019 12:32:32 -0700 Subject: [PATCH 1/4] [AutoDiff] Remove differentiation order from AD-related instructions. The differentiation order field in `differentiable_function` and `differentiable_function_extract` instructions is unsupported and will not be used by the current design. Quite a lot of dead code exists to try to handle `order`, but it is mostly incomplete and untested. This PR removes the differentiation order from the code base to simplify what we upstream to the 'master' branch. Changes include: * Remove `differentiationOrder` from `DifferentiableFunctionInst` and `DifferentiableFunctionExtractInst`. * Make `DifferentiableFunctionInst::DifferentiableFunctionInst` take an optional pair of JVP and VJP instead of a variable-size array. * Rename "associated functions" to "derivative functions" in `DifferentiableFunctionInst` to align better with [the design](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547). Filed task [TF-882](https://bugs.swift.org/browse/TF-882) to track the renaming of all other occurrences of "associated functions". Resolves [TF-880](https://bugs.swift.org/browse/TF-880). --- docs/SIL.rst | 21 ++-- include/swift/AST/AutoDiff.h | 29 +---- include/swift/AST/DiagnosticsParse.def | 4 - include/swift/AST/Types.h | 12 +- include/swift/SIL/SILBuilder.h | 25 ++-- include/swift/SIL/SILCloner.h | 12 +- include/swift/SIL/SILInstruction.h | 115 ++++++++---------- include/swift/SIL/SILVTableVisitor.h | 8 +- include/swift/SIL/SILWitnessVisitor.h | 4 +- lib/AST/ASTContext.cpp | 7 +- lib/AST/ASTPrinter.cpp | 2 - lib/AST/AutoDiff.cpp | 23 +--- lib/AST/Builtins.cpp | 11 +- lib/AST/Type.cpp | 4 +- lib/IRGen/GenDiffFunc.cpp | 38 +++--- lib/IRGen/IRGenSIL.cpp | 11 +- lib/IRGen/LoadableByAddress.cpp | 10 +- lib/ParseSIL/ParseSIL.cpp | 93 ++++---------- lib/SIL/SILFunctionType.cpp | 12 +- lib/SIL/SILInstructions.cpp | 98 +++++++-------- lib/SIL/SILPrinter.cpp | 17 +-- lib/SIL/SILVerifier.cpp | 63 ++++------ lib/SIL/TypeLowering.cpp | 53 ++++---- lib/SILGen/SILGen.cpp | 8 +- lib/SILGen/SILGenBuiltin.cpp | 10 +- lib/SILGen/SILGenExpr.cpp | 4 +- lib/SILGen/SILGenPoly.cpp | 16 ++- lib/SILGen/SILGenThunk.cpp | 5 +- .../Mandatory/Differentiation.cpp | 77 +++++------- lib/Sema/TypeCheckAttr.cpp | 8 +- lib/Serialization/DeserializeSIL.cpp | 27 ++-- lib/Serialization/ModuleFormat.h | 2 +- lib/Serialization/SILFormat.h | 6 +- lib/Serialization/SerializeSIL.cpp | 6 +- lib/TBDGen/TBDGen.cpp | 16 +-- test/AutoDiff/core_builtins.swift | 4 +- .../AutoDiff/differentiable_function_inst.sil | 12 +- .../differentiable_function_inst_irgen.sil | 4 +- .../differentiable_function_silgen.swift | 14 +-- ...differentiable_sil_function_type_parse.sil | 24 ++-- test/AutoDiff/forward_mode_sil.swift | 12 +- test/AutoDiff/refcounting.swift | 4 +- test/AutoDiff/sildeclref_parse.sil | 16 +-- test/AutoDiff/simple_real_vector.swift | 2 +- test/AutoDiff/subset_parameters_thunk.swift | 2 +- test/AutoDiff/vtable_sil.swift | 44 +++---- test/AutoDiff/witness_method_autodiff.sil | 16 +-- test/AutoDiff/witness_table_irgen.sil | 4 +- test/AutoDiff/witness_table_sil.swift | 36 +++--- 49 files changed, 424 insertions(+), 627 deletions(-) diff --git a/docs/SIL.rst b/docs/SIL.rst index c6d57f43c34bd..98839126da255 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5620,21 +5620,18 @@ differentiable_function sil-differentiable-function-associated-function-list ::= '{' sil-value ',' sil-value '}' - differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \ + differentiable_function [wrt 0] %0 : $(T) -> T \ with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)} -Bundles a function with its associated differentiation functions up to a -specified differentiation order into an ``@differentiable`` function. There are -two associated functions per differentiation order: a Jacobian-vector products -(JVP) function and a vector-Jacobian products (VJP) function. +Bundles a function with its associated differentiation functions into a +``@differentiable`` function. There are two associated functions: +a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP) +function. ``[wrt ...]`` specifies parameter indices that the original function is differentiable with respect to. When not specified, it defaults to all parameters. -``[order ...]`` specifies the maximum differentiation order for the resulting -function. The number of lists of associated functions is equal to the order. - A ``with`` clause specifies the differentiation functions associated with the original function. When a ``with`` clause is not specified, the first operand will be differentiated to produce associated functions, and a ``with`` @@ -5660,12 +5657,12 @@ differentiable_function_extract sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']' differentiable_function_extract [original] %0 : $@differentiable (T) -> T - differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T - differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T + differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T + differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T Extracts the original function or an associated function from the given -``@differentiable`` function at a specific differentiation order. It must be -provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``. +``@differentiable`` function. It must be provided with an extractee: +``[original]``, ``[jvp]`` or ``[vjp]``. Assertion configuration diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 745547a0e6ea8..3534b87e35697 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -458,29 +458,25 @@ struct AutoDiffAssociatedFunctionKind { /// compared by opaque pointer value. class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { const AutoDiffAssociatedFunctionKind kind; - const unsigned differentiationOrder; AutoDiffIndexSubset *const parameterIndices; AutoDiffAssociatedFunctionIdentifier( - AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, + AutoDiffAssociatedFunctionKind kind, AutoDiffIndexSubset *parameterIndices) : - kind(kind), differentiationOrder(differentiationOrder), - parameterIndices(parameterIndices) {} + kind(kind), parameterIndices(parameterIndices) {} public: AutoDiffAssociatedFunctionKind getKind() const { return kind; } - unsigned getDifferentiationOrder() const { return differentiationOrder; } AutoDiffIndexSubset *getParameterIndices() const { return parameterIndices; } static AutoDiffAssociatedFunctionIdentifier *get( - AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, + AutoDiffAssociatedFunctionKind kind, AutoDiffIndexSubset *parameterIndices, ASTContext &C); void Profile(llvm::FoldingSetNodeID &ID) { ID.AddInteger(kind); - ID.AddInteger(differentiationOrder); ID.AddPointer(parameterIndices); } }; @@ -520,29 +516,12 @@ void getSubsetParameterTypes(AutoDiffIndexSubset *indices, AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices, AnyFunctionType *type); -/// Returns the offset for an associated function at a specific differentiation -/// order. -/// This is used for both ordering in the `differentiable_function` instruction -/// and ABI layout. -/// -/// Order 1 Order 2 ... -/// |----------| |-----|-----| |-----|-----| ... -/// | Original | | JVP | VJP | | JVP | VJP | ... -/// |----------| |-----|-----| |-----|-----| ... -unsigned -getOffsetForAutoDiffAssociatedFunction(unsigned order, - AutoDiffAssociatedFunctionKind kind); - -unsigned -getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder); - /// Retrieve config from the function name of a variant of /// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`. /// Returns true if the function name is parsed successfully. bool getBuiltinAutoDiffApplyConfig(StringRef operationName, AutoDiffAssociatedFunctionKind &kind, - unsigned &arity, unsigned &order, - bool &rethrows); + unsigned &arity, bool &rethrows); /// Computes the correct linkage for an associated function given the linkage of /// the original function. If the original linkage is not external and diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index b5e528821236d..68c88f1907dd0 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1594,10 +1594,6 @@ ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken, // SIL autodiff ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken, "expected ']' to complete the %0", (StringRef)) -ERROR(sil_inst_autodiff_expected_order,PointsToFirstBadToken, - "expected an unsigned integer indicating the differentiation order", ()) -ERROR(sil_inst_autodiff_expected_nonzero_order,PointsToFirstBadToken, - "expected a non-zero differentiation order", ()) ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken, "expected the index of a parameter to differentiate with respect to", ()) ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken, diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index a3c1d5bcb53c6..b7084d437d654 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3099,8 +3099,8 @@ class AnyFunctionType : public TypeBase { } // SWIFT_ENABLE_TENSORFLOW - /// Given `indices`, `differentiationOrder`, and `kind`, calculates the type - /// of the corresponding autodiff associated function. + /// Given `indices` and `kind`, calculates the type of the corresponding + /// autodiff associated function. /// /// By default, if the original type has a self parameter list and parameter /// indices include self, the computed associated function type will return a @@ -3116,7 +3116,7 @@ class AnyFunctionType : public TypeBase { /// function, including `@differentiable`. AnyFunctionType *getAutoDiffAssociatedFunctionType( AutoDiffIndexSubset *indices, unsigned resultIndex, - unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, + AutoDiffAssociatedFunctionKind kind, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenericSignature = nullptr, bool makeSelfParamFirst = false); @@ -4216,7 +4216,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, // SWIFT_ENABLE_TENSORFLOW CanSILFunctionType getWithDifferentiability( - unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices); + AutoDiffIndexSubset *parameterIndices); CanSILFunctionType getWithoutDifferentiability(); @@ -4224,8 +4224,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// a function of this type. CanSILFunctionType getAutoDiffAssociatedFunctionType( AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, - unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, - Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance, + AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC, + LookupConformanceFn lookupConformance, CanGenericSignature associatedFunctionGenericSignature = nullptr); /// Returns a bit vector that specifices which parameters you can diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index cc061ea25055f..88aeb83dd9d18 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -511,28 +511,27 @@ class SILBuilder { /// SWIFT_ENABLE_TENSORFLOW DifferentiableFunctionInst *createDifferentiableFunction( - SILLocation loc, AutoDiffIndexSubset *parameterIndices, - unsigned differentiationOrder, SILValue original, - ArrayRef associatedFunctions = {}) { + SILLocation Loc, AutoDiffIndexSubset *ParameterIndices, + SILValue OriginalFunction, + Optional> JVPAndVJPFunctions = None) { return insert(DifferentiableFunctionInst::create( - getModule(), getSILDebugLocation(loc), parameterIndices, - differentiationOrder, original, associatedFunctions)); + getModule(), getSILDebugLocation(Loc), ParameterIndices, + OriginalFunction, JVPAndVJPFunctions, hasOwnership())); } DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract( - SILLocation loc, DifferentiableFunctionExtractee extractee, - unsigned differentiationOrder, SILValue theFunction) { + SILLocation Loc, DifferentiableFunctionExtractee Extractee, + SILValue TheFunction) { return insert(new (getModule()) DifferentiableFunctionExtractInst( - getModule(), getSILDebugLocation(loc), extractee, differentiationOrder, - theFunction)); + getModule(), getSILDebugLocation(Loc), Extractee, TheFunction)); } DifferentiableFunctionExtractInst * - createDifferentiableFunctionExtractOriginal(SILLocation loc, - SILValue theFunction) { + createDifferentiableFunctionExtractOriginal(SILLocation Loc, + SILValue TheFunction) { return insert(new (getModule()) DifferentiableFunctionExtractInst( - getModule(), getSILDebugLocation(loc), - DifferentiableFunctionExtractee::Original, 0, theFunction)); + getModule(), getSILDebugLocation(Loc), + DifferentiableFunctionExtractee::Original, TheFunction)); } BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy, diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 13e56b96c5f33..618b0eeeecaa2 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -970,15 +970,14 @@ template void SILCloner::visitDifferentiableFunctionInst( DifferentiableFunctionInst *Inst) { getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); - SmallVector mappedAssocFns; - mappedAssocFns.reserve(Inst->getNumAssociatedFunctions()); - for (auto &fn : Inst->getAssociatedFunctions()) - mappedAssocFns.push_back(getOpValue(fn.get())); + Optional> assocFns = None; + if (Inst->hasDerivativeFunctions()) + assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()), + getOpValue(Inst->getVJPFunction())); recordClonedInstruction( Inst, getBuilder().createDifferentiableFunction( getOpLocation(Inst->getLoc()), Inst->getParameterIndices(), - Inst->getDifferentiationOrder(), - getOpValue(Inst->getOriginalFunction()), mappedAssocFns)); + getOpValue(Inst->getOriginalFunction()), assocFns)); } template @@ -988,7 +987,6 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst) recordClonedInstruction( Inst, getBuilder().createDifferentiableFunctionExtract( getOpLocation(Inst->getLoc()), Inst->getExtractee(), - Inst->getDifferentiationOrder(), getOpValue(Inst->getFunctionOperand()))); } // SWIFT_ENABLE_TENSORFLOW END diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 6c3d34d322846..02db90e4fd597 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7855,60 +7855,69 @@ class DifferentiableFunctionInst final : private: friend SILBuilder; /// Differentiation parameter indices. - AutoDiffIndexSubset *parameterIndices; - /// The order of differentiation. - unsigned differentiationOrder; - /// The number of operands. The first operand is always the original function. - /// The rest of operands determined by the order of differentiation and whether - /// this is the new AD model or the legacy reverse-mode AD model. - unsigned numOperands; + AutoDiffIndexSubset *ParameterIndices; + /// Indicates whether derivative functions (JVP/VJP) exist. + bool HasDerivativeFunctions; - DifferentiableFunctionInst(SILModule &module, SILDebugLocation debugLoc, - AutoDiffIndexSubset *parameterIndices, - unsigned differentiationOrder, - SILValue originalFunction, - ArrayRef associatedFunctions); + DifferentiableFunctionInst( + SILDebugLocation DebugLoc, AutoDiffIndexSubset *ParameterIndices, + SILValue OriginalFunction, ArrayRef DerivativeFunctions, + bool HasOwnership); + + static SILType getDifferentiableFunctionType( + SILValue Original, AutoDiffIndexSubset *ParameterIndices); + + static ValueOwnershipKind getMergedOwnershipKind( + SILValue Original, ArrayRef DerivativeFunctions); public: static DifferentiableFunctionInst *create( - SILModule &module, SILDebugLocation debugLoc, - AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, - SILValue originalFunction, ArrayRef associatedFunctions); - - static SILType getAutoDiffType(SILValue original, - unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices); + SILModule &Module, SILDebugLocation DebugLoc, + AutoDiffIndexSubset *ParameterIndices, SILValue OriginalFunction, + Optional> VJPAndJVPFunctions, + bool HasOwnership); /// Returns the original function. - SILValue getOriginalFunction() const { return getAllOperands()[0].get(); } + SILValue getOriginalFunction() const { return getOperand(0); } /// Returns differentiation indices. - AutoDiffIndexSubset *getParameterIndices() const { - return parameterIndices; - } + AutoDiffIndexSubset *getParameterIndices() const { return ParameterIndices; } - /// Returns the differentiation order. - unsigned getDifferentiationOrder() const { - return differentiationOrder; - } + /// Returns true if derivative functions (JVP/VJP) exist. + bool hasDerivativeFunctions() const { return HasDerivativeFunctions; } - unsigned getNumAssociatedFunctions() const { - return numOperands - 1; + /// Returns the derivative functions, namely the JVP and VJP functions, if + /// they exist. Otherwise, return None. + Optional> + getOptionalDerivativeFunctionPair() const { + if (!HasDerivativeFunctions) + return None; + return std::make_pair(getOperand(1), getOperand(2)); } - bool hasAssociatedFunctions() const { - return numOperands > 1; + ArrayRef getDerivativeFunctionArray() const { + return getAllOperands().drop_front(); } - ArrayRef getAssociatedFunctions() const { - return getAllOperands().drop_front(); + /// Returns the JVP function. + SILValue getJVPFunction() const { + assert(HasDerivativeFunctions); + return getOperand(1); } - std::pair - getAssociatedFunctionPair(unsigned differentiationOrder) const; + /// Returns the VJP function. + SILValue getVJPFunction() const { + assert(HasDerivativeFunctions); + return getOperand(2); + } - SILValue getAssociatedFunction(unsigned differentiationOrder, - AutoDiffAssociatedFunctionKind kind) const; + /// Returns the derivative function (JVP or VJP) that matches the given kind. + SILValue getDerivativeFunction(AutoDiffAssociatedFunctionKind kind) const { + switch (kind) { + case AutoDiffAssociatedFunctionKind::JVP: return getJVPFunction(); + case AutoDiffAssociatedFunctionKind::VJP: return getVJPFunction(); + } + } }; /// `differentiable_function_extract` - given an `@differentiable` function @@ -7917,7 +7926,7 @@ class DifferentiableFunctionInst final : class DifferentiableFunctionExtractInst : public InstructionBase< SILInstructionKind::DifferentiableFunctionExtractInst, - OwnershipForwardingSingleValueInst> { + SingleValueInstruction> { public: struct Extractee { enum innerty : unsigned { @@ -7939,24 +7948,18 @@ class DifferentiableFunctionExtractInst private: /// The extractee. Extractee extractee; - /// The differentiation order. A zero value is only legal when the extractee - /// is the original function, and it is a private representation only. - unsigned differentiationOrder; /// The list containing the `@differentiable` function operand. FixedOperandList<1> operands; static SILType - getExtracteeType(SILValue function, Extractee extractee, - unsigned differentiationOrder, SILModule &module); + getExtracteeType(SILValue function, Extractee extractee, SILModule &module); public: explicit DifferentiableFunctionExtractInst( SILModule &module, SILDebugLocation debugLoc, Extractee extractee, - unsigned differentiationOrder, SILValue theFunction); + SILValue theFunction); - Extractee getExtractee() const { - return extractee; - } + Extractee getExtractee() const { return extractee; } AutoDiffAssociatedFunctionKind getAssociatedFunctionKind() const { auto kind = extractee.getExtracteeAsAssociatedFunction(); @@ -7964,21 +7967,9 @@ class DifferentiableFunctionExtractInst return *kind; } - SILValue getFunctionOperand() const { - return operands[0].get(); - } - - unsigned getDifferentiationOrder() const { - return differentiationOrder; - } - - ArrayRef getAllOperands() const { - return operands.asArray(); - } - - MutableArrayRef getAllOperands() { - return operands.asArray(); - } + SILValue getFunctionOperand() const { return operands[0].get(); } + ArrayRef getAllOperands() const { return operands.asArray(); } + MutableArrayRef getAllOperands() { return operands.asArray(); } }; typedef DifferentiableFunctionExtractInst::Extractee diff --git a/include/swift/SIL/SILVTableVisitor.h b/include/swift/SIL/SILVTableVisitor.h index f9849e5a47006..e44dc4a4ca656 100644 --- a/include/swift/SIL/SILVTableVisitor.h +++ b/include/swift/SIL/SILVTableVisitor.h @@ -93,13 +93,13 @@ template class SILVTableVisitor { auto constant = SILDeclRef(fd, SILDeclRef::Kind::Func); auto jvpConstant = constant.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), fd->getASTContext())); maybeAddEntry(jvpConstant); auto vjpConstant = constant.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), fd->getASTContext())); maybeAddEntry(vjpConstant); } @@ -120,13 +120,13 @@ template class SILVTableVisitor { auto constant = SILDeclRef(cd, SILDeclRef::Kind::Allocator); auto jvpConstant = constant.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), cd->getASTContext())); maybeAddEntry(jvpConstant); auto vjpConstant = constant.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), cd->getASTContext())); maybeAddEntry(vjpConstant); } diff --git a/include/swift/SIL/SILWitnessVisitor.h b/include/swift/SIL/SILWitnessVisitor.h index 212276afcd74a..6bac8e8c1eb1f 100644 --- a/include/swift/SIL/SILWitnessVisitor.h +++ b/include/swift/SIL/SILWitnessVisitor.h @@ -183,11 +183,11 @@ template class SILWitnessVisitor : public ASTVisitor { for (auto *DA : func->getAttrs().getAttributes()) { asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), func->getASTContext()))); asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction( AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), func->getASTContext()))); } } diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 67e029dd44b24..8f2be19379006 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -4829,13 +4829,12 @@ AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) { AutoDiffAssociatedFunctionIdentifier * AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices, ASTContext &C) { + AutoDiffAssociatedFunctionKind kind, AutoDiffIndexSubset *parameterIndices, + ASTContext &C) { assert(parameterIndices); auto &foldingSet = C.getImpl().AutoDiffAssociatedFunctionIdentifiers; llvm::FoldingSetNodeID id; id.AddInteger((unsigned)kind); - id.AddInteger(differentiationOrder); id.AddPointer(parameterIndices); void *insertPos; @@ -4846,7 +4845,7 @@ AutoDiffAssociatedFunctionIdentifier::get( void *mem = C.Allocate(sizeof(AutoDiffAssociatedFunctionIdentifier), alignof(AutoDiffAssociatedFunctionIdentifier)); auto *newNode = ::new (mem) AutoDiffAssociatedFunctionIdentifier( - kind, differentiationOrder, parameterIndices); + kind, parameterIndices); foldingSet.InsertNode(newNode, insertPos); return newNode; diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index c6a12f27b6609..99e94eb020977 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -3791,7 +3791,6 @@ class TypePrinter : public TypeVisitor { // SWIFT_ENABLE_TENSORFLOW if (!Options.excludeAttrKind(TAK_differentiable) && info.isDifferentiable()) { - // FIXME(rxwei): Print differentiation order. if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) { Printer << "@differentiable(linear) "; } else { @@ -3845,7 +3844,6 @@ class TypePrinter : public TypeVisitor { // SWIFT_ENABLE_TENSORFLOW if (!Options.excludeAttrKind(TAK_differentiable) && info.isDifferentiable()) { - // FIXME(rxwei): Print differentiation order. if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) { Printer << "@differentiable(linear) "; } else { diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index bdeec68bdaae5..5fa24f94b9e43 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -123,19 +123,9 @@ void autodiff::getSubsetParameterTypes(AutoDiffIndexSubset *subset, } } -unsigned autodiff::getOffsetForAutoDiffAssociatedFunction( - unsigned order, AutoDiffAssociatedFunctionKind kind) { - return (order - 1) * getNumAutoDiffAssociatedFunctions(order) + kind.rawValue; -} - -unsigned -autodiff::getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder) { - return differentiationOrder * 2; -} - bool autodiff::getBuiltinAutoDiffApplyConfig( StringRef operationName, AutoDiffAssociatedFunctionKind &kind, - unsigned &arity, unsigned &order, bool &rethrows) { + unsigned &arity, bool &rethrows) { if (!operationName.startswith("autodiffApply_")) return false; operationName = operationName.drop_front(strlen("autodiffApply_")); @@ -156,17 +146,6 @@ bool autodiff::getBuiltinAutoDiffApplyConfig( } else { arity = 1; } - // Parse '_order'. - if (operationName.startswith("_order")) { - operationName = operationName.drop_front(strlen("_order")); - auto orderStr = operationName.take_while(llvm::isDigit); - auto converted = llvm::to_integer(orderStr, order); - operationName = operationName.drop_front(orderStr.size()); - assert(converted); (void)converted; - assert(order > 0); - } else { - order = 1; - } // Parse '_rethrows'. if (operationName.startswith("_rethrows")) { operationName = operationName.drop_front(strlen("_rethrows")); diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index ec69ad75c820e..c899b78088213 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -996,9 +996,8 @@ static ValueDecl *getGetObjCTypeEncodingOperation(ASTContext &Context, // SWIFT_ENABLE_TENSORFLOW static ValueDecl *getAutoDiffApplyAssociatedFunction( ASTContext &Context, Identifier Id, AutoDiffAssociatedFunctionKind kind, - unsigned arity, unsigned order, bool rethrows) { + unsigned arity, bool rethrows) { assert(arity >= 1); - assert(order == 1 && "higher-order differentiation is not supported yet"); // JVP: // <...T...(arity), R> (@differentiable (...T) throws -> R, ...T) // rethrows -> (R, (...T.TangentVector) -> R.TangentVector) @@ -1047,7 +1046,7 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction( BuiltinGenericSignatureBuilder::LambdaGenerator resultGen{ [=, &Context](BuiltinGenericSignatureBuilder &builder) -> Type { auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( - paramIndices, /*resultIndex*/ 0, /*differentiationOrder*/ 1, kind, + paramIndices, /*resultIndex*/ 0, kind, LookUpConformanceInModule(Context.TheBuiltinModule)); return assocFnTy->getResult(); }}; @@ -1842,13 +1841,13 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { // SWIFT_ENABLE_TENSORFLOW if (OperationName.startswith("autodiffApply_")) { AutoDiffAssociatedFunctionKind kind; - unsigned arity, order; + unsigned arity; bool rethrows; if (!autodiff::getBuiltinAutoDiffApplyConfig(OperationName, kind, arity, - order, rethrows)) + rethrows)) return nullptr; return getAutoDiffApplyAssociatedFunction(Context, Id, kind, arity, - order, rethrows); + rethrows); } auto BV = llvm::StringSwitch(OperationName) #define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index f6b6def7d8cde..72766ae027d3b 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4569,8 +4569,7 @@ Optional TypeBase::getAutoDiffAssociatedTangentSpace( AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( AutoDiffIndexSubset *indices, unsigned resultIndex, - unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, - LookupConformanceFn lookupConformance, + AutoDiffAssociatedFunctionKind kind, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenSig, bool makeSelfParamFirst) { // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) @@ -4581,7 +4580,6 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( // "Closure" and then use common code to wrap "Closure" in the outer function // type. - assert(differentiationOrder == 1 && "only order 1 currently supported"); assert(!indices->isEmpty() && "there must be at least one wrt index"); auto &ctx = getASTContext(); diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index becb17a021784..9f86b7c2b61ea 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -32,8 +32,7 @@ using namespace swift; using namespace irgen; -/// A pair of `@differentiable` function extractee and differentiation order. -using DiffFuncIndex = std::pair; +using DiffFuncIndex = DifferentiableFunctionExtractee; namespace { class DiffFuncFieldInfo final : public RecordField { @@ -49,27 +48,24 @@ class DiffFuncFieldInfo final : public RecordField { AutoDiffIndexSubset *ParameterIndices; std::string getFieldName() const { - auto extractee = std::get<0>(Index); - auto differentiationOrder = std::get<1>(Index); - switch (extractee) { + switch (Index) { case DifferentiableFunctionExtractee::Original: return "original"; case DifferentiableFunctionExtractee::JVP: - return "jvp_" + llvm::itostr(differentiationOrder); + return "jvp"; case DifferentiableFunctionExtractee::VJP: - return "vjp_" + llvm::itostr(differentiationOrder); + return "vjp"; } } SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); - if (std::get<0>(Index) == DifferentiableFunctionExtractee::Original) + if (Index == DifferentiableFunctionExtractee::Original) return SILType::getPrimitiveObjectType(origFnTy); - auto differentiationOrder = std::get<1>(Index); - auto kind = *std::get<0>(Index).getExtracteeAsAssociatedFunction(); + auto kind = *Index.getExtracteeAsAssociatedFunction(); auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( - ParameterIndices, /*resultIndex*/ 0, differentiationOrder, kind, + ParameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); } @@ -154,13 +150,12 @@ class DiffFuncTypeBuilder } SILType getType(DiffFuncIndex field) { - if (std::get<0>(field) == DifferentiableFunctionExtractee::Original) + if (field == DifferentiableFunctionExtractee::Original) return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType()); - auto differentiationOrder = std::get<1>(field); - auto kind = *std::get<0>(field).getExtracteeAsAssociatedFunction(); + auto kind = *field.getExtracteeAsAssociatedFunction(); auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( - parameterIndices, /*resultIndex*/ 0, differentiationOrder, kind, - IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); + parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); } @@ -175,12 +170,7 @@ const TypeInfo * TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) { assert(type->isDifferentiable()); DiffFuncTypeBuilder builder(IGM, type); - SmallVector fields; - fields.push_back( - std::make_pair(DifferentiableFunctionExtractee::Original, 0)); - fields.push_back( - std::make_pair(DifferentiableFunctionExtractee::JVP, 1)); - fields.push_back( - std::make_pair(DifferentiableFunctionExtractee::VJP, 1)); - return builder.layout(fields); + return builder.layout({DifferentiableFunctionExtractee::Original, + DifferentiableFunctionExtractee::JVP, + DifferentiableFunctionExtractee::VJP}); } diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index c461efdaa65d8..f875e1e0d6e54 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -1876,17 +1876,15 @@ visitDifferentiableFunctionInst(DifferentiableFunctionInst *i) { auto origExp = getLoweredExplosion(i->getOriginalFunction()); Explosion e; e.add(origExp.claimAll()); - for (auto &assocFnOp : i->getAssociatedFunctions()) - e.add(getLoweredExplosion(assocFnOp.get()).claimAll()); + assert(i->hasDerivativeFunctions()); + for (auto &derivFnOperand : i->getDerivativeFunctionArray()) + e.add(getLoweredExplosion(derivFnOperand.get()).claimAll()); setLoweredExplosion(i, e); } void IRGenSILFunction:: visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *i) { - unsigned structFieldOffset = 0; - if (i->getExtractee() != DifferentiableFunctionExtractee::Original) - structFieldOffset = 1 + autodiff::getOffsetForAutoDiffAssociatedFunction( - i->getDifferentiationOrder(), i->getAssociatedFunctionKind()); + unsigned structFieldOffset = i->getExtractee().rawValue; unsigned fieldSize = 1; auto fnRepr = i->getFunctionOperand()->getType().getFunctionRepresentation(); if (fnRepr == SILFunctionTypeRepresentation::Thick) { @@ -1894,6 +1892,7 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *i) { fieldSize = 2; } auto diffFnExp = getLoweredExplosion(i->getFunctionOperand()); + assert(diffFnExp.size() == fieldSize * 3); Explosion e; e.add(diffFnExp.getRange(structFieldOffset, structFieldOffset + fieldSize)); (void)diffFnExp.claimAll(); diff --git a/lib/IRGen/LoadableByAddress.cpp b/lib/IRGen/LoadableByAddress.cpp index 63fe8331d7e10..a84e577891915 100644 --- a/lib/IRGen/LoadableByAddress.cpp +++ b/lib/IRGen/LoadableByAddress.cpp @@ -2736,20 +2736,16 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I, // SWIFT_ENABLE_TENSORFLOW case SILInstructionKind::DifferentiableFunctionInst: { auto instr = cast(convInstr); - SmallVector associatedFunctions; - for (auto &assocFn : instr->getAssociatedFunctions()) - associatedFunctions.push_back(assocFn.get()); newInstr = convBuilder.createDifferentiableFunction( instr->getLoc(), instr->getParameterIndices(), - instr->getDifferentiationOrder(), instr->getOriginalFunction(), - associatedFunctions); + instr->getOriginalFunction(), + instr->getOptionalDerivativeFunctionPair()); break; } case SILInstructionKind::DifferentiableFunctionExtractInst: { auto instr = cast(convInstr); newInstr = convBuilder.createDifferentiableFunctionExtract( - instr->getLoc(), instr->getExtractee(), - instr->getDifferentiationOrder(), instr->getFunctionOperand()); + instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand()); break; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 85a6797796ceb..daf296a3f345d 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -1636,7 +1636,6 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, ParseState = 3; } else if (Id.str() == "jvp" || Id.str() == "vjp") { AutoDiffAssociatedFunctionKind kind; - unsigned differentiationOrder; AutoDiffIndexSubset *parameterIndices = nullptr; if (Id.str() == "jvp") @@ -1651,15 +1650,6 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, return true; } - if (parseInteger(differentiationOrder, - diag::sil_const_expected_int_value)) - return true; - - if (!P.consumeIf(tok::period)) { - P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "."); - return true; - } - parameterIndices = AutoDiffIndexSubset::getFromString( SILMod.getASTContext(), P.Tok.getText()); if (!parameterIndices) { @@ -1669,8 +1659,7 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, P.consumeToken(); autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( - kind, differentiationOrder, parameterIndices, - SILMod.getASTContext()); + kind, parameterIndices, SILMod.getASTContext()); break; } else @@ -2937,7 +2926,6 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { // ^ jvp ^ vjp SourceLoc lastLoc; SmallVector parameterIndices; - unsigned order = 1; // Parse optional `[wrt ...]` if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && @@ -2957,24 +2945,6 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { "parameter index list")) return true; } - // Parse optional `[order ]`. - if (P.Tok.is(tok::l_square) && - P.peekToken().is(tok::identifier) && - P.peekToken().getText() == "order") { - P.consumeToken(tok::l_square); - P.consumeToken(tok::identifier); - // Parse an order. - if (P.parseUnsignedInteger(order, lastLoc, - diag::sil_inst_autodiff_expected_order) || - P.parseToken(tok::r_square, - diag::sil_inst_autodiff_attr_expected_rsquare, - "differentiation order")) - return true; - if (order == 0) { - P.diagnose(lastLoc, diag::sil_inst_autodiff_expected_nonzero_order); - return true; - } - } // Parse the original function value. SILValue original; SourceLoc originalOperandLoc; @@ -2986,29 +2956,24 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { diag::sil_inst_autodiff_expected_function_type_operand); return true; } - SmallVector associatedFunctions; - // Parse optional operand lists `with { , }, ...`. + Optional> derivativeFunctions = None; + // Parse an optional operand list `with { , }`. if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { P.consumeToken(tok::identifier); - // Parse associated function values as operand lists. There are as many - // operand lists as the differentiation order. - associatedFunctions.reserve(2 * order); - for (unsigned listIdx = 0; listIdx < order; ++listIdx) { - // FIXME(rxwei): Change this to *not* require a type signature once - // we can infer AD associated function types. - SILValue newAssocFn1, newAssocFn2; - if (P.parseToken(tok::l_brace, - diag::sil_inst_autodiff_operand_list_expected_lbrace) || - parseTypedValueRef(newAssocFn1, B) || - P.parseToken(tok::comma, - diag::sil_inst_autodiff_operand_list_expected_comma) || - parseTypedValueRef(newAssocFn2, B) || - P.parseToken(tok::r_brace, - diag::sil_inst_autodiff_operand_list_expected_rbrace)) - return true; - associatedFunctions.push_back(newAssocFn1); - associatedFunctions.push_back(newAssocFn2); - } + // Parse associated function values as an operand list. + // FIXME(rxwei): Change this to *not* require a type signature once + // we can infer AD associated function types. + SILValue derivFn1, derivFn2; + if (P.parseToken(tok::l_brace, + diag::sil_inst_autodiff_operand_list_expected_lbrace) || + parseTypedValueRef(derivFn1, B) || + P.parseToken(tok::comma, + diag::sil_inst_autodiff_operand_list_expected_comma) || + parseTypedValueRef(derivFn2, B) || + P.parseToken(tok::r_brace, + diag::sil_inst_autodiff_operand_list_expected_rbrace)) + return true; + derivativeFunctions = std::make_pair(derivFn1, derivFn2); if (P.Tok.is(tok::l_brace)) { P.diagnose(P.Tok, diag::sil_inst_autodiff_num_operand_list_order_mismatch); @@ -3021,16 +2986,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { AutoDiffIndexSubset::get(P.Context, fnType->getNumParameters(), parameterIndices); ResultVal = B.createDifferentiableFunction( - InstLoc, parameterIndicesSubset, order, original, associatedFunctions); + InstLoc, parameterIndicesSubset, original, derivativeFunctions); break; } case SILInstructionKind::DifferentiableFunctionExtractInst: { - // Parse the rest of the instruction: an extractee, a differentiation order, - // a differentiable function operand, and a debug location. + // Parse the rest of the instruction: an extractee, a differentiable + // function operand, and a debug location. DifferentiableFunctionExtractee extractee; StringRef extracteeNames[3] = {"original", "jvp", "vjp"}; - unsigned order = 0; SILValue functionOperand; SourceLoc lastLoc; if (P.parseToken(tok::l_square, @@ -3041,26 +3005,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { diag::sil_inst_autodiff_attr_expected_rsquare, "associated function kind")) return true; - if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && - P.peekToken().getText() == "order") { - P.consumeToken(tok::l_square); - P.consumeToken(tok::identifier); - if (P.parseUnsignedInteger(order, lastLoc, - diag::sil_inst_autodiff_expected_order) || - P.parseToken(tok::r_square, - diag::sil_inst_autodiff_attr_expected_rsquare, - "differentiation order")) - return true; - if (order == 0) { - P.diagnose(lastLoc, diag::sil_inst_autodiff_expected_nonzero_order); - return true; - } - } if (parseTypedValueRef(functionOperand, B) || parseSILDebugLocation(InstLoc, B)) return true; ResultVal = B.createDifferentiableFunctionExtract( - InstLoc, extractee, order, functionOperand); + InstLoc, extractee, functionOperand); break; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index caf865c2610de..0a43f3cf20fae 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -114,9 +114,7 @@ SILFunctionType::getDifferentiationParameterIndices() { } CanSILFunctionType SILFunctionType::getWithDifferentiability( - unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices) { - // FIXME(rxwei): Handle differentiation order. - + AutoDiffIndexSubset *parameterIndices) { SmallVector newParameters; for (auto paramAndIndex : enumerate(getParameters())) { auto ¶m = paramAndIndex.value(); @@ -183,9 +181,8 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature( CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, - unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, - TypeConverter &TC, LookupConformanceFn lookupConformance, - CanGenericSignature assocFnGenSig) { + AutoDiffAssociatedFunctionKind kind, TypeConverter &TC, + LookupConformanceFn lookupConformance, CanGenericSignature assocFnGenSig) { // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) // VJP: (T...) -> ((R...), @@ -2353,8 +2350,7 @@ const SILConstantInfo &TypeConverter::getConstantInfo(SILDeclRef constant) { auto loweredIndices = autodiff::getLoweredParameterIndices( autoDiffFuncId->getParameterIndices(), formalInterfaceType); silFnType = origFnConstantInfo.SILFnType->getAutoDiffAssociatedFunctionType( - loweredIndices, /*resultIndex*/ 0, - autoDiffFuncId->getDifferentiationOrder(), autoDiffFuncId->getKind(), + loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(), *this, LookUpConformanceInModule(&M)); } diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 0b836b4835261..f7d75e08fcfc1 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -578,61 +578,50 @@ TryApplyInst *TryApplyInst::create( } // SWIFT_ENABLE_TENSORFLOW -SILType DifferentiableFunctionInst::getAutoDiffType( - SILValue originalFunction, unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices) { +SILType DifferentiableFunctionInst::getDifferentiableFunctionType( + SILValue originalFunction, AutoDiffIndexSubset *parameterIndices) { auto fnTy = originalFunction->getType().castTo(); - auto diffTy = - fnTy->getWithDifferentiability(differentiationOrder, parameterIndices); + auto diffTy = fnTy->getWithDifferentiability(parameterIndices); return SILType::getPrimitiveObjectType(diffTy); } +ValueOwnershipKind DifferentiableFunctionInst::getMergedOwnershipKind( + SILValue original, ArrayRef derivativeFunctions) { + if (derivativeFunctions.empty()) + return original.getOwnershipKind(); + return *mergeSILValueOwnership( + {original, derivativeFunctions[0], derivativeFunctions[1]}); +} + DifferentiableFunctionInst::DifferentiableFunctionInst( - SILModule &module, SILDebugLocation debugLoc, - AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, - SILValue originalFunction, ArrayRef associatedFunctions) + SILDebugLocation DebugLoc, AutoDiffIndexSubset *ParameterIndices, + SILValue OriginalFunction, ArrayRef DerivativeFunctions, + bool HasOwnership) : InstructionBaseWithTrailingOperands( - originalFunction, associatedFunctions, debugLoc, - getAutoDiffType(originalFunction, differentiationOrder, - parameterIndices), - originalFunction.getOwnershipKind()), - parameterIndices(parameterIndices), - differentiationOrder(differentiationOrder), - numOperands(1 + associatedFunctions.size()) {} + OriginalFunction, DerivativeFunctions, DebugLoc, + getDifferentiableFunctionType(OriginalFunction, ParameterIndices), + HasOwnership + ? getMergedOwnershipKind(OriginalFunction, DerivativeFunctions) + : ValueOwnershipKind(ValueOwnershipKind::Any)), + ParameterIndices(ParameterIndices), + HasDerivativeFunctions(!DerivativeFunctions.empty()) { + assert(DerivativeFunctions.empty() || DerivativeFunctions.size() == 2); +} DifferentiableFunctionInst *DifferentiableFunctionInst::create( - SILModule &module, SILDebugLocation debugLoc, - AutoDiffIndexSubset *parameterIndices, - unsigned differentiationOrder, SILValue originalFunction, - ArrayRef associatedFunctions) { - size_t size = totalSizeToAlloc(associatedFunctions.size() + 1); - void *buffer = module.allocateInst(size, alignof(DifferentiableFunctionInst)); - return ::new (buffer) DifferentiableFunctionInst(module, debugLoc, - parameterIndices, - differentiationOrder, - originalFunction, - associatedFunctions); -} - -std::pair DifferentiableFunctionInst:: -getAssociatedFunctionPair(unsigned differentiationOrder) const { - assert(differentiationOrder > 0 && - differentiationOrder <= this->differentiationOrder); - assert(!getAssociatedFunctions().empty() && "No associated functions. Maybe " - "the differentiation pass has not run?"); - auto offset = (differentiationOrder - 1) * 2; - auto assocFns = getAssociatedFunctions(); - return {assocFns[offset].get(), assocFns[offset+1].get()}; -} - -SILValue DifferentiableFunctionInst:: -getAssociatedFunction(unsigned differentiationOrder, - AutoDiffAssociatedFunctionKind kind) const { - assert(differentiationOrder > 0 && - differentiationOrder <= this->differentiationOrder); - auto offset = autodiff::getOffsetForAutoDiffAssociatedFunction( - differentiationOrder, kind); - return getAssociatedFunctions()[offset].get(); + SILModule &Module, SILDebugLocation DebugLoc, + AutoDiffIndexSubset *ParameterIndices, SILValue OriginalFunction, + Optional> VJPAndJVPFunctions, + bool HasOwnership) { + auto derivativeFunctions = VJPAndJVPFunctions.hasValue() + ? ArrayRef( + reinterpret_cast(VJPAndJVPFunctions.getPointer()), 2) + : ArrayRef(); + size_t size = totalSizeToAlloc(1 + derivativeFunctions.size()); + void *buffer = Module.allocateInst(size, alignof(DifferentiableFunctionInst)); + return ::new (buffer) DifferentiableFunctionInst( + DebugLoc, ParameterIndices, OriginalFunction, derivativeFunctions, + HasOwnership); } DifferentiableFunctionExtractInst::Extractee::Extractee( @@ -671,33 +660,28 @@ DifferentiableFunctionExtractInst::Extractee::getExtracteeAsAssociatedFunction() } SILType DifferentiableFunctionExtractInst:: -getExtracteeType(SILValue function, Extractee extractee, - unsigned differentiationOrder, SILModule &module) { +getExtracteeType(SILValue function, Extractee extractee, SILModule &module) { auto fnTy = function->getType().castTo(); assert(fnTy->getExtInfo().isDifferentiable()); auto originalFnTy = fnTy->getWithoutDifferentiability(); auto kindOpt = extractee.getExtracteeAsAssociatedFunction(); if (!kindOpt) { assert(extractee == Extractee::Original); - assert(differentiationOrder == 0); return SILType::getPrimitiveObjectType(originalFnTy); } auto resultFnTy = originalFnTy->getAutoDiffAssociatedFunctionType( fnTy->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - differentiationOrder, *kindOpt, module.Types, + *kindOpt, module.Types, LookUpConformanceInModule(module.getSwiftModule())); return SILType::getPrimitiveObjectType(resultFnTy); } DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst( SILModule &module, SILDebugLocation debugLoc, Extractee extractee, - unsigned differentiationOrder, SILValue theFunction) + SILValue theFunction) : InstructionBase(debugLoc, - getExtracteeType(theFunction, extractee, - differentiationOrder, module), - theFunction.getOwnershipKind()), - extractee(extractee), differentiationOrder(differentiationOrder), - operands(this, theFunction) {} + getExtracteeType(theFunction, extractee, module)), + extractee(extractee), operands(this, theFunction) {} // SWIFT_ENABLE_TENSORFLOW END FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind, diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 13cf14af620b9..0aa45bf1300d7 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -361,8 +361,7 @@ void SILDeclRef::print(raw_ostream &OS) const { OS << "vjp."; break; } - OS << autoDiffFuncId->getDifferentiationOrder() << "." - << autoDiffFuncId->getParameterIndices()->getString(); + OS << autoDiffFuncId->getParameterIndices()->getString(); } } @@ -1169,16 +1168,11 @@ class SILPrinter : public SILInstructionVisitor { *this << ' ' << i; *this << "] "; } - *this << "[order " << dfi->getDifferentiationOrder() << "] "; *this << getIDAndType(dfi->getOriginalFunction()); - if (!dfi->getAssociatedFunctions().empty()) { + if (dfi->hasDerivativeFunctions()) { *this << " with "; - interleave(range(1, dfi->getDifferentiationOrder() + 1), - [&](unsigned order) { - auto pair = dfi->getAssociatedFunctionPair(order); - *this << '{' << getIDAndType(pair.first) << ", " - << getIDAndType(pair.second) << '}'; - }, [this] { *this << ", "; }); + *this << '{' << getIDAndType(dfi->getJVPFunction()) << ", " + << getIDAndType(dfi->getVJPFunction()) << '}'; } } @@ -1197,9 +1191,6 @@ class SILPrinter : public SILInstructionVisitor { break; } *this << "] "; - auto order = dfei->getDifferentiationOrder(); - if (order > 0) - *this << "[order " << order << "] "; *this << getIDAndType(dfei->getFunctionOperand()); } diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index cfdb0e408e098..b5e7d8b982c9d 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1493,59 +1493,48 @@ class SILVerifier : public SILVerifierBase { // SWIFT_ENABLE_TENSORFLOW void checkDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { - require(dfi->getDifferentiationOrder() > 0, - "The differentiation order must be non-zero"); auto origTy = dfi->getOriginalFunction()->getType().getAs(); require(origTy, "The original function must have a function type"); require(!origTy->isDifferentiable(), "The original function must not be @differentiable"); if (F.getModule().getStage() == SILStage::Canonical || - dfi->hasAssociatedFunctions()) { - for (auto order : range(1, dfi->getDifferentiationOrder() + 1)) { - auto pair = dfi->getAssociatedFunctionPair(order); - auto jvpType = pair.first->getType().getAs(); - require(jvpType, "The JVP function must have a function type"); - require(!jvpType->isDifferentiable(), - "The JVP function must not be @differentiable"); - auto expectedJVPType = origTy->getAutoDiffAssociatedFunctionType( - dfi->getParameterIndices(), /*resultIndex*/ 0, order, - AutoDiffAssociatedFunctionKind::JVP, TC, - LookUpConformanceInModule(M)); - requireSameType(SILType::getPrimitiveObjectType(jvpType), - SILType::getPrimitiveObjectType(expectedJVPType), - "JVP type does not match expected JVP type"); - auto vjpType = pair.second->getType().getAs(); - require(vjpType, "The VJP function must have a function type"); - require(!vjpType->isDifferentiable(), - "The VJP function must not be @differentiable"); - auto expectedVJPType = origTy->getAutoDiffAssociatedFunctionType( - dfi->getParameterIndices(), /*resultIndex*/ 0, order, - AutoDiffAssociatedFunctionKind::VJP, TC, - LookUpConformanceInModule(M)); - requireSameType(SILType::getPrimitiveObjectType(vjpType), - SILType::getPrimitiveObjectType(expectedVJPType), - "VJP type does not match expected VJP type"); - } + dfi->hasDerivativeFunctions()) { + auto jvp = dfi->getJVPFunction(); + auto jvpType = jvp->getType().getAs(); + require(jvpType, "The JVP function must have a function type"); + require(!jvpType->isDifferentiable(), + "The JVP function must not be @differentiable"); + auto expectedJVPType = origTy->getAutoDiffAssociatedFunctionType( + dfi->getParameterIndices(), /*resultIndex*/ 0, + AutoDiffAssociatedFunctionKind::JVP, TC, + LookUpConformanceInModule(M)); + requireSameType(SILType::getPrimitiveObjectType(jvpType), + SILType::getPrimitiveObjectType(expectedJVPType), + "JVP type does not match expected JVP type"); + auto vjp = dfi->getVJPFunction(); + auto vjpType = vjp->getType().getAs(); + require(vjpType, "The VJP function must have a function type"); + require(!vjpType->isDifferentiable(), + "The VJP function must not be @differentiable"); + auto expectedVJPType = origTy->getAutoDiffAssociatedFunctionType( + dfi->getParameterIndices(), /*resultIndex*/ 0, + AutoDiffAssociatedFunctionKind::VJP, TC, + LookUpConformanceInModule(M)); + requireSameType(SILType::getPrimitiveObjectType(vjpType), + SILType::getPrimitiveObjectType(expectedVJPType), + "VJP type does not match expected VJP type"); } } void checkDifferentiableFunctionExtractInst( DifferentiableFunctionExtractInst *dfei) { - if (dfei->getExtractee() == DifferentiableFunctionExtractee::Original) - require(dfei->getDifferentiationOrder() == 0, - "Differentiation order should not have been set when the " - "original function is being extracted"); - else - require(dfei->getDifferentiationOrder() > 0, - "Extraction of associated functions requires a differentiation " - "order"); auto fnTy = dfei->getFunctionOperand()->getType().getAs(); require(fnTy, "The function operand must have a function type"); require(fnTy->isDifferentiable(), "The function operand must be an '@differentiable' function"); } - // SWIFT_ENABLE_TENSORFLOW + // SWIFT_ENABLE_TENSORFLOW END void verifyLLVMIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID) { // Certain llvm intrinsic require constant values as their operands. diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index b5946c843eb72..54fb9ca468729 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -154,11 +154,11 @@ namespace { auto origTy = type->getWithoutDifferentiability(); auto jvpTy = origTy->getAutoDiffAssociatedFunctionType( type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP, TC, + AutoDiffAssociatedFunctionKind::JVP, TC, LookUpConformanceInModule(&M)); auto vjpTy = origTy->getAutoDiffAssociatedFunctionType( type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::VJP, TC, + AutoDiffAssociatedFunctionKind::VJP, TC, LookUpConformanceInModule(&M)); RecursiveProperties props; props.addSubobject(classifyType(origTy, TC, Sig, Expansion)); @@ -862,62 +862,52 @@ namespace { }; // SWIFT_ENABLE_TENSORFLOW - using DifferentiableSILFunctionTypeIndex = - std::pair; class DifferentiableSILFunctionTypeLowering final : public LoadableAggTypeLowering { + DifferentiableFunctionExtractee> { public: using LoadableAggTypeLowering::LoadableAggTypeLowering; SILValue emitRValueProject(SILBuilder &B, SILLocation loc, SILValue tupleValue, - DifferentiableSILFunctionTypeIndex index, + DifferentiableFunctionExtractee extractee, const TypeLowering &eltLowering) const { return B.createDifferentiableFunctionExtract( - loc, index.first, index.second, tupleValue); + loc, extractee, tupleValue); } SILValue rebuildAggregate(SILBuilder &B, SILLocation loc, ArrayRef values) const override { + assert(values.size() == 3); auto fnTy = getLoweredType().castTo(); auto paramIndices = fnTy->getDifferentiationParameterIndices(); - // TODO: Retrieve the differentiation order when that is properly stored - // in the function type. - unsigned maxOrder = 1; return B.createDifferentiableFunction( - loc, paramIndices, maxOrder, values.front(), values.drop_front()); + loc, paramIndices, values[0], std::make_pair(values[1], values[2])); } void lowerChildren(TypeConverter &TC, SmallVectorImpl &children) const override { auto fnTy = getLoweredType().castTo(); - // TODO: Retrieve the differentiation order when that is properly stored - // in the function type. - auto maxOrder = 1; - auto numAssocFns = autodiff::getNumAutoDiffAssociatedFunctions(maxOrder); + auto numAssocFns = 2; children.reserve(numAssocFns + 1); auto origFnTy = fnTy->getWithoutDifferentiability(); auto paramIndices = fnTy->getDifferentiationParameterIndices(); children.push_back(Child{ - {DifferentiableFunctionExtractee::Original, 0}, + DifferentiableFunctionExtractee::Original, TC.getTypeLowering(origFnTy, getResilienceExpansion()) }); - for (auto order : range(1, maxOrder + 1)) { - for (AutoDiffAssociatedFunctionKind kind - : {AutoDiffAssociatedFunctionKind::JVP, - AutoDiffAssociatedFunctionKind::VJP}) { - auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( - paramIndices, 0, order, kind, TC, - LookUpConformanceInModule(&TC.M)); - auto silTy = SILType::getPrimitiveObjectType(assocFnTy); - children.push_back(Child{ - {DifferentiableFunctionExtractee(kind), order}, - TC.getTypeLowering(silTy, getResilienceExpansion()) - }); - } + for (auto kind : {AutoDiffAssociatedFunctionKind::JVP, + AutoDiffAssociatedFunctionKind::VJP}) { + auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( + paramIndices, 0, kind, TC, + LookUpConformanceInModule(&TC.M)); + auto silTy = SILType::getPrimitiveObjectType(assocFnTy); + children.push_back(Child{ + DifferentiableFunctionExtractee(kind), + TC.getTypeLowering(silTy, getResilienceExpansion()) + }); } - assert(children.size() == numAssocFns + 1); + assert(children.size() == 3); } }; @@ -2017,8 +2007,7 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) { makeConstantInterfaceType(c.asAutoDiffOriginalFunction()); auto *fnTy = originalFnTy->getAutoDiffAssociatedFunctionType( autoDiffFuncId->getParameterIndices(), /*resultIndex*/ 0, - autoDiffFuncId->getDifferentiationOrder(), autoDiffFuncId->getKind(), - LookUpConformanceInModule(&M)); + autoDiffFuncId->getKind(), LookUpConformanceInModule(&M)); return cast(fnTy->getCanonicalType()); } diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 4fe13f361f150..36655ae74d840 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -774,10 +774,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule()); auto expectedJVPType = origSilFnType->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, + indices.parameters, indices.source, AutoDiffAssociatedFunctionKind::JVP, Types, lookUpConformance); auto expectedVJPType = origSilFnType->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, + indices.parameters, indices.source, AutoDiffAssociatedFunctionKind::VJP, Types, lookUpConformance); // Self reordering is necessary if wrt at least two parameters, including @@ -802,7 +802,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, reorderSelf); } else { auto *id = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::JVP, diffAttr->getParameterIndices(), AFD->getASTContext()); jvpThunk = getOrCreateAutoDiffThunk( constant.asAutoDiffAssociatedFunction(id), jvpFn, @@ -820,7 +820,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, reorderSelf); } else { auto *id = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, + AutoDiffAssociatedFunctionKind::VJP, diffAttr->getParameterIndices(), AFD->getASTContext()); vjpThunk = getOrCreateAutoDiffThunk( constant.asAutoDiffAssociatedFunction(id), vjpFn, diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index 11b0ee7672db1..1781ce45f5999 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1032,7 +1032,7 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF, // SWIFT_ENABLE_TENSORFLOW static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction( - AutoDiffAssociatedFunctionKind kind, unsigned arity, unsigned order, + AutoDiffAssociatedFunctionKind kind, unsigned arity, bool rethrows, SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, ArrayRef args, SGFContext C) { auto origFnVal = args[0].getValue(); @@ -1042,7 +1042,7 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction( // Get the associated function. SILValue assocFn = SGF.B.createDifferentiableFunctionExtract( - loc, kind, /*differentiationOrder*/ 1, origFnVal); + loc, kind, origFnVal); auto assocFnType = assocFn->getType().castTo(); // We don't need to destroy the original function or retain the `assocFn`, @@ -1144,12 +1144,12 @@ static ManagedValue emitBuiltinAutoDiffApply(SILGenFunction &SGF, ->getDecl()); auto builtinName = builtinDecl->getName().str(); AutoDiffAssociatedFunctionKind kind; - unsigned arity, order; + unsigned arity; bool rethrows; auto successfullyParsed = autodiff::getBuiltinAutoDiffApplyConfig( - builtinName, kind, arity, order, rethrows); + builtinName, kind, arity, rethrows); assert(successfullyParsed); - return emitBuiltinAutoDiffApplyAssociatedFunction(kind, arity, order, + return emitBuiltinAutoDiffApplyAssociatedFunction(kind, arity, rethrows, SGF, loc, substitutions, args, C); } diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index 8597611e67854..8bc6613a4b19c 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -5431,10 +5431,8 @@ RValue RValueEmitter::visitDifferentiableFunctionExpr( DifferentiableFunctionExpr *E, SGFContext C) { auto origFunc = SGF.emitRValueAsSingleValue(E->getSubExpr()); auto destTy = SGF.getLoweredType(E->getType()).castTo(); - // TODO(rxwei): Use the order specified in E's function type. auto *diffFunc = SGF.B.createDifferentiableFunction( - E, destTy->getDifferentiationParameterIndices(), /*order*/ 1, - origFunc.forward(SGF)); + E, destTy->getDifferentiationParameterIndices(), origFunc.forward(SGF)); return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc)); } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index b977af14cdd88..60dbaa5d781b4 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3269,8 +3269,7 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, outputSubstType->getWithoutDifferentiability()); auto &expectedTLNotDiff = SGF.getTypeLowering(outputOrigTypeNotDiff, outputSubstTypeNotDiff); - // `differentiable_function_extract` is consuming; copy `fn` before passing as - // operand. + // `differentiable_function_extract` takes `@guaranteed` values. auto borrowedFnValue = fn.borrow(SGF, loc); SILValue original = SGF.B.createDifferentiableFunctionExtractOriginal( loc, borrowedFnValue.getValue()); @@ -3297,7 +3296,7 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, [&](CanAnyFunctionType fnTy, AutoDiffAssociatedFunctionKind kind) -> CanAnyFunctionType { auto assocTy = fnTy->getAutoDiffAssociatedFunctionType( - parameterIndices, /*resultIndex*/ 0, /*differentiationOrder*/ 1, + parameterIndices, /*resultIndex*/ 0, kind, LookUpConformanceInModule(SGF.SGM.M.getSwiftModule())); return cast(assocTy->getCanonicalType()); }; @@ -3322,7 +3321,7 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType, assocFnOutputSubstType); SILValue assocFn = SGF.B.createDifferentiableFunctionExtract( - loc, kind, /*differentiationOrder*/ 1, borrowedFnValue.getValue()); + loc, kind, borrowedFnValue.getValue()); assocFn = SGF.B.emitCopyValueOperation(loc, assocFn); auto managedAssocFn = SGF.emitManagedRValueWithCleanup(assocFn); return createThunk(SGF, loc, managedAssocFn, assocFnInputOrigType, @@ -3335,9 +3334,8 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF, SILValue convertedBundle = SGF.B.createDifferentiableFunction( loc, sourceType->getDifferentiationParameterIndices(), - /*differentiationOrder*/ 1, originalThunk.forward(SGF), - {jvpThunk.forward(SGF), vjpThunk.forward(SGF)}); + std::make_pair(jvpThunk.forward(SGF), vjpThunk.forward(SGF))); return SGF.emitManagedRValueWithCleanup(convertedBundle); } @@ -3687,7 +3685,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( auto origFnType = original->getLoweredFunctionType(); auto origAssocFnType = origFnType->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, + indices.parameters, indices.source, assocFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()), assocFnType->getGenericSignature()); assert(!origAssocFnType->getExtInfo().hasContext()); @@ -4307,10 +4305,10 @@ getWitnessFunctionRef(SILGenFunction &SGF, autoDiffFuncId->getParameterIndices(), witness.getDecl()->getInterfaceType()->castTo()); auto autoDiffFn = SGF.B.createDifferentiableFunction( - loc, loweredIndices, /*differentiationOrder*/ 1, originalFn); + loc, loweredIndices, originalFn); return SGF.B.createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), - /*differentiationOrder*/ 1, autoDiffFn); + autoDiffFn); } return SGF.emitGlobalFunctionRef(loc, witness); diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 387fc8912416e..84b4d4662f8a9 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -140,10 +140,9 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( autoDiffFuncId->getParameterIndices(), assocFnDecl->getInterfaceType()->castTo()); auto diffFn = SGF.B.createDifferentiableFunction( - loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef); + loc, loweredIndices, originalFnRef); auto diffAssocFn = SGF.B.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), - /*differentiationOrder*/ 1, diffFn); + loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), diffFn); auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy); SmallVector args(thunk->getArguments().begin(), thunk->getArguments().end()); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 177bc7863d703..7f3290aa00d25 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1108,11 +1108,10 @@ class ADContext { /// pointer value as a previously processed and deleted instruction. DifferentiableFunctionInst *createDifferentiableFunction( SILBuilder &builder, SILLocation loc, - AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, - SILValue original, ArrayRef associatedFunctions = {}) { + AutoDiffIndexSubset *parameterIndices, SILValue original, + Optional> associatedFunctions = None) { auto *dfi = builder.createDifferentiableFunction( - loc, parameterIndices, differentiationOrder, original, - associatedFunctions); + loc, parameterIndices, original, associatedFunctions); processedDifferentiableFunctionInsts.erase(dfi); return dfi; } @@ -1685,8 +1684,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, AutoDiffAssociatedFunctionKind assocFnKind(kind); auto assocFnType = remappedOrigFnSubstTy->getAutoDiffAssociatedFunctionType( - parameters, source, /*differentiationOrder*/ 1, assocFnKind, - context.getTypeConverter(), + parameters, source, assocFnKind, context.getTypeConverter(), LookUpConformanceInModule(derivative->getModule().getSwiftModule())); auto assocFnResultTypes = @@ -2579,8 +2577,7 @@ emitAssociatedFunctionReference( auto borrowedDiffFunc = builder.emitBeginBorrowOperation( functionSource.getLoc(), functionSource); SILValue assocFn = builder.createDifferentiableFunctionExtract( - borrowedDiffFunc.getLoc(), kind, /*differentiationOrder*/ 1, - borrowedDiffFunc); + borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc); assocFn = builder.emitCopyValueOperation(functionSource.getLoc(), assocFn); builder.emitEndBorrowOperation(functionSource.getLoc(), borrowedDiffFunc); @@ -2718,11 +2715,10 @@ emitAssociatedFunctionReference( auto originalType = witnessMethod->getType().castTo(); auto assocType = originalType->getAutoDiffAssociatedFunctionType( minimalIndices.parameters, minimalIndices.source, - /*differentiationOrder*/ 1, kind, context.getTypeConverter(), + kind, context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto *autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( - kind, /*differentiationOrder*/ 1, minimalAttr->getParameterIndices(), - context.getASTContext()); + kind, minimalAttr->getParameterIndices(), context.getASTContext()); auto *ref = builder.createWitnessMethod( loc, witnessMethod->getLookupType(), witnessMethod->getConformance(), requirementDeclRef.asAutoDiffAssociatedFunction(autoDiffFuncId), @@ -2766,10 +2762,10 @@ emitAssociatedFunctionReference( auto originalType = classMethodInst->getType().castTo(); auto assocType = originalType->getAutoDiffAssociatedFunctionType( minimalIndices.parameters, minimalIndices.source, - /*differentiationOrder*/ 1, kind, context.getTypeConverter(), + kind, context.getTypeConverter(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto *autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get( - kind, /*differentiationOrder*/ 1, minimalAttr->getParameterIndices(), + kind, minimalAttr->getParameterIndices(), context.getASTContext()); auto *ref = builder.createClassMethod( loc, classMethodInst->getOperand(), @@ -3787,7 +3783,7 @@ class VJPEmitter final auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); vjpValue = builder.createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractInst::Extractee::VJP, - /*differentiationOrder*/ 1, borrowedDiffFunc); + borrowedDiffFunc); vjpValue = builder.emitCopyValueOperation(loc, vjpValue); } @@ -3858,8 +3854,7 @@ class VJPEmitter final } auto *diffFuncInst = context.createDifferentiableFunction( - getBuilder(), loc, indices.parameters, /*differentiationOrder*/ 1, - original); + getBuilder(), loc, indices.parameters, original); // Record the `differentiable_function` instruction. context.getDifferentiableFunctionInsts().push_back(diffFuncInst); @@ -3871,7 +3866,7 @@ class VJPEmitter final builder.emitBeginBorrowOperation(loc, diffFuncInst); auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractInst::Extractee::VJP, - /*differentiationOrder*/ 1, borrowedADFunc); + borrowedADFunc); vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); builder.emitEndBorrowOperation(loc, borrowedADFunc); builder.emitDestroyValueOperation(loc, diffFuncInst); @@ -5466,7 +5461,7 @@ class JVPEmitter final auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); jvpValue = builder.createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractInst::Extractee::JVP, - /*differentiationOrder*/ 1, borrowedDiffFunc); + borrowedDiffFunc); jvpValue = builder.emitCopyValueOperation(loc, jvpValue); } @@ -5533,8 +5528,7 @@ class JVPEmitter final return; auto *diffFuncInst = context.createDifferentiableFunction( - builder, loc, indices.parameters, /*differentiationOrder*/ 1, - original); + builder, loc, indices.parameters, original); // Record the `differentiable_function` instruction. context.getDifferentiableFunctionInsts().push_back(diffFuncInst); @@ -5546,7 +5540,7 @@ class JVPEmitter final builder.emitBeginBorrowOperation(loc, diffFuncInst); auto extractedJVP = builder.createDifferentiableFunctionExtract( loc, DifferentiableFunctionExtractInst::Extractee::JVP, - /*differentiationOrder*/ 1, borrowedADFunc); + borrowedADFunc); jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); builder.emitEndBorrowOperation(loc, borrowedADFunc); builder.emitDestroyValueOperation(loc, diffFuncInst); @@ -7846,9 +7840,8 @@ ADContext::declareExternalAssociatedFunction( auto originalLoc = original->getLocation(); auto assocGenSig = getDerivativeGenericSignature(attr, original); auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, kind, - module.Types, LookUpConformanceInModule(module.getSwiftModule()), - assocGenSig); + indices.parameters, indices.source, kind, module.Types, + LookUpConformanceInModule(module.getSwiftModule()), assocGenSig); SILOptFunctionBuilder fb(getTransform()); // Create external function declaration. auto *assocFn = fb.createFunction( @@ -7891,9 +7884,9 @@ static SILFunction *createEmptyVJP( ? vjpGenericSig->getGenericEnvironment() : nullptr; auto vjpType = originalTy->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, - AutoDiffAssociatedFunctionKind::VJP, module.Types, - LookUpConformanceInModule(module.getSwiftModule()), vjpGenericSig); + indices.parameters, indices.source, AutoDiffAssociatedFunctionKind::VJP, + module.Types, LookUpConformanceInModule(module.getSwiftModule()), + vjpGenericSig); SILOptFunctionBuilder fb(context.getTransform()); auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage( @@ -7941,7 +7934,7 @@ static SILFunction *createEmptyJVP( ? jvpGenericSig->getGenericEnvironment() : nullptr; auto jvpType = originalTy->getAutoDiffAssociatedFunctionType( - indices.parameters, indices.source, /*differentiationOrder*/ 1, + indices.parameters, indices.source, AutoDiffAssociatedFunctionKind::JVP, module.Types, LookUpConformanceInModule(module.getSwiftModule()), jvpGenericSig); @@ -8352,8 +8345,8 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( // Compute target type for thunking. auto assocFnType = assocFn->getType().castTo(); auto targetType = origFnType->getAutoDiffAssociatedFunctionType( - desiredIndices.parameters, desiredIndices.source, - /*differentiationOrder*/ 1, kind, module.Types, lookupConformance); + desiredIndices.parameters, desiredIndices.source, kind, module.Types, + lookupConformance); auto *caller = assocFn->getFunction(); if (targetType->hasArchetype()) { auto substTargetType = caller->mapTypeIntoContext( @@ -8501,7 +8494,6 @@ SILValue ADContext::promoteToDifferentiableFunction( auto origFnTy = origFnOperand->getType().castTo(); auto parameterIndices = dfi->getParameterIndices(); unsigned resultIndex = resultIndices[dfi]; - unsigned differentiationOrder = dfi->getDifferentiationOrder(); // Handle curry thunk applications specially. if (auto *ai = dyn_cast(origFnOperand)) { @@ -8550,7 +8542,6 @@ SILValue ADContext::promoteToDifferentiableFunction( SILBuilder thunkBuilder(retInst); auto *dfi = createDifferentiableFunction(thunkBuilder, loc, parameterIndices, - differentiationOrder, retInst->getOperand()); resultIndices[dfi] = resultIndex; thunkBuilder.createReturn(loc, dfi); @@ -8656,8 +8647,7 @@ SILValue ADContext::promoteToDifferentiableFunction( } } auto expectedAssocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( - parameterIndices, resultIndex, differentiationOrder, - assocFnKind, getTypeConverter(), + parameterIndices, resultIndex, assocFnKind, getTypeConverter(), LookUpConformanceInModule(getModule().getSwiftModule())); // If `assocFn` is `@convention(thin)` but is expected to be // `@convention(thick)`, emit a `thin_to_thick` instruction. @@ -8677,8 +8667,8 @@ SILValue ADContext::promoteToDifferentiableFunction( auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand); auto *newDFI = createDifferentiableFunction( - builder, loc, parameterIndices, differentiationOrder, origFnCopy, - assocFns); + builder, loc, parameterIndices, origFnCopy, + std::make_pair(assocFns[0], assocFns[1])); resultIndices[dfi] = resultIndex; getDifferentiableFunctionInsts().push_back(dfi); @@ -8712,8 +8702,8 @@ void ADContext::foldDifferentiableFunctionExtraction( continue; } // Fold associated function extractors. - auto assocFnValue = source->getAssociatedFunction( - dfei->getDifferentiationOrder(), dfei->getAssociatedFunctionKind()); + auto assocFnValue = + source->getDerivativeFunction(dfei->getAssociatedFunctionKind()); dfei->replaceAllUsesWith(assocFnValue); dfei->eraseFromParent(); } @@ -8721,8 +8711,8 @@ void ADContext::foldDifferentiableFunctionExtraction( // it. if (isInstructionTriviallyDead(source)) { SILBuilder builder(source); - for (auto &assocFn : source->getAssociatedFunctions()) - builder.emitDestroyAddrAndFold(source->getLoc(), assocFn.get()); + builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction()); + builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction()); source->eraseFromParent(); } // Mark `source` as processed so that it won't be reprocessed after deletion. @@ -8735,13 +8725,8 @@ bool ADContext::processDifferentiableFunctionInst( auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n"; dfi->printInContext(s); }); - - if (dfi->getNumAssociatedFunctions() == - autodiff::getNumAutoDiffAssociatedFunctions( - dfi->getDifferentiationOrder())) + if (dfi->hasDerivativeFunctions()) return false; - assert(dfi->getNumAssociatedFunctions() == 0 && - "some functions are already filled in but not all of them"); SILFunction *parent = dfi->getFunction(); auto loc = dfi->getLoc(); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 628f592607038..d809a979f6592 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3478,8 +3478,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AnyFunctionType *expectedJVPFnTy = originalFnTy->getAutoDiffAssociatedFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, - /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP, - lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); + AutoDiffAssociatedFunctionKind::JVP, lookupConformance, + whereClauseGenSig, /*makeSelfParamFirst*/ true); auto isValidJVP = [&](FuncDecl *jvpCandidate) { TC.validateDecl(jvpCandidate); @@ -3504,8 +3504,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AnyFunctionType *expectedVJPFnTy = originalFnTy->getAutoDiffAssociatedFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, - /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::VJP, - lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); + AutoDiffAssociatedFunctionKind::VJP, lookupConformance, + whereClauseGenSig, /*makeSelfParamFirst*/ true); auto isValidVJP = [&](FuncDecl *vjpCandidate) { TC.validateDecl(vjpCandidate); diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 29448db134a3e..e6ff6ad6b3c96 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1015,8 +1015,6 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, // SWIFT_ENABLE_TENSORFLOW Attr = 0, Attr2 = 0, NumSubs = 0, NumConformances = 0, IsNonThrowingApply = 0; - // SWIFT_ENABLE_TENSORFLOW - unsigned NumArguments = 0; ValueID ValID, ValID2, ValID3; TypeID TyID, TyID2, TyID3; TypeID ConcreteTyID; @@ -1120,13 +1118,13 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, // SWIFT_ENABLE_TENSORFLOW case SIL_INST_DIFFERENTIABLE_FUNCTION: SILInstDifferentiableFunctionLayout::readRecord( - scratch, /*order*/ Attr, /*numParams*/ Attr2, NumArguments, + scratch, /*numParams*/ Attr, /*hasDerivativeFunctions*/ Attr2, ListOfValues); RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionInst; break; case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT: SILInstDifferentiableFunctionExtractLayout::readRecord( - scratch, TyID, TyCategory, ValID, /*extractee*/ Attr, /*order*/ Attr2); + scratch, TyID, TyCategory, ValID, /*extractee*/ Attr); RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst; break; case SIL_INST_NO_OPERAND: @@ -1521,22 +1519,28 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, } // SWIFT_ENABLE_TENSORFLOW case SILInstructionKind::DifferentiableFunctionInst: { - auto numParamIndices = ListOfValues.size() - NumArguments * 3; + bool hasDerivativeFunctions = (bool)Attr2; + unsigned numOperands = hasDerivativeFunctions ? 3 : 1; + auto numParamIndices = ListOfValues.size() - numOperands * 3; + assert(ListOfValues.size() == numParamIndices + numOperands * 3); auto rawParamIndices = map>(ListOfValues.take_front(numParamIndices), [](uint64_t i) { return (unsigned)i; }); - auto numParams = Attr2; + auto numParams = Attr; auto *paramIndices = AutoDiffIndexSubset::get(MF->getContext(), numParams, rawParamIndices); - SmallVector operands; - for (auto i = numParamIndices; i < NumArguments * 3; i += 3) { + SmallVector operands; + for (auto i = numParamIndices; + i < numParamIndices + numOperands * 3; i += 3) { auto astTy = MF->getType(ListOfValues[i]); auto silTy = getSILType(astTy, (SILValueCategory)ListOfValues[i+1]); operands.push_back(getLocalValue(ListOfValues[i+2], silTy)); } + Optional> derivativeFunctions = None; + if (hasDerivativeFunctions) + derivativeFunctions = std::make_pair(operands[1], operands[2]); ResultVal = Builder.createDifferentiableFunction( - Loc, paramIndices, /*differentiationOrder*/ Attr, operands[0], - ArrayRef(operands).drop_front()); + Loc, paramIndices, operands[0], derivativeFunctions); break; } case SILInstructionKind::DifferentiableFunctionExtractInst: { @@ -1544,9 +1548,8 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto silTy = getSILType(astTy, SILValueCategory::Object); auto val = getLocalValue(ValID, silTy); DifferentiableFunctionExtractee extractee(Attr); - auto order = Attr2; ResultVal = - Builder.createDifferentiableFunctionExtract(Loc, extractee, order, val); + Builder.createDifferentiableFunctionExtract(Loc, extractee, val); break; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 9c54d54c2799f..a3af7336da7c0 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 520; // store generic signature in AST/SIL differentiable attributes +const uint16_t SWIFTMODULE_VERSION_MINOR = 521; // remove order from 'differentiation_function' layout /// A standard hash seed used for all string hashes in a serialized module. /// diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index a7857acc330da..9ee90b4e22e43 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -414,9 +414,8 @@ namespace sil_block { // SWIFT_ENABLE_TENSORFLOW using SILInstDifferentiableFunctionLayout = BCRecordLayout< SIL_INST_DIFFERENTIABLE_FUNCTION, - BCVBR<8>, // differentiation order BCVBR<8>, // number of function parameters - BCVBR<8>, // number of operands + BCFixed<1>, // has derivative functions? BCArray // parameter indices and operands >; @@ -425,8 +424,7 @@ namespace sil_block { TypeIDField, SILTypeCategoryField, ValueIDField, - BCFixed<2>, // extractee - BCVBR<8> // order + BCFixed<2> // extractee >; // SIL instructions with one type. (alloc_stack) diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index e2d67ba9d3413..67f6e8aeaba48 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -1000,8 +1000,8 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { } SILInstDifferentiableFunctionLayout::emitRecord(Out, ScratchRecord, SILAbbrCodes[SILInstDifferentiableFunctionLayout::Code], - dfi->getDifferentiationOrder(), paramIndices->getCapacity(), - dfi->getNumOperands(), trailingInfo); + paramIndices->getCapacity(), dfi->hasDerivativeFunctions(), + trailingInfo); break; } case SILInstructionKind::DifferentiableFunctionExtractInst: { @@ -1013,7 +1013,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { SILInstDifferentiableFunctionExtractLayout::emitRecord(Out, ScratchRecord, SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code], operandTypeRef, (unsigned)operandType.getCategory(), operandRef, - rawExtractee, dfei->getDifferentiationOrder()); + rawExtractee); break; } case SILInstructionKind::ApplyInst: { diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index e7e8cbca7e7dd..083f93f7b5c64 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -238,12 +238,12 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) { auto diffAttrs = AFD->getAttrs().getAttributes(); for (auto *DA : diffAttrs) { auto *jvpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, - DA->getParameterIndices(), AFD->getASTContext()); + AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), + AFD->getASTContext()); addSymbol(SILDeclRef(AFD).asAutoDiffAssociatedFunction(jvpId)); auto *vjpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, - DA->getParameterIndices(), AFD->getASTContext()); + AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), + AFD->getASTContext()); addSymbol(SILDeclRef(AFD).asAutoDiffAssociatedFunction(vjpId)); } @@ -301,13 +301,13 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) { auto diffAttrs = ASD->getAttrs().getAttributes(); for (auto *DA : diffAttrs) { auto *jvpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1, - DA->getParameterIndices(), ASD->getASTContext()); + AutoDiffAssociatedFunctionKind::JVP, DA->getParameterIndices(), + ASD->getASTContext()); addSymbol(SILDeclRef(ASD->getAccessor(AccessorKind::Get)) .asAutoDiffAssociatedFunction(jvpId)); auto *vjpId = AutoDiffAssociatedFunctionIdentifier::get( - AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1, - DA->getParameterIndices(), ASD->getASTContext()); + AutoDiffAssociatedFunctionKind::VJP, DA->getParameterIndices(), + ASD->getASTContext()); addSymbol(SILDeclRef(ASD->getAccessor(AccessorKind::Get)) .asAutoDiffAssociatedFunction(vjpId)); } diff --git a/test/AutoDiff/core_builtins.swift b/test/AutoDiff/core_builtins.swift index 996f19b301d0e..83528c00261d4 100644 --- a/test/AutoDiff/core_builtins.swift +++ b/test/AutoDiff/core_builtins.swift @@ -10,7 +10,7 @@ func evaldiff(_ f: @differentiable (T) -> // CHECK-SIL-LABEL: @{{.*}}evaldiff{{.*}} // CHECK-SIL: bb0([[ORIG_RES_BUF:%.*]] : $*U, [[ORIG_FN:%.*]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T) -> @out U, [[ORIG_FN_ARG:%.*]] : $*T): -// CHECK-SIL: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [order 1] [[ORIG_FN]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T) -> @out U +// CHECK-SIL: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ORIG_FN]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T) -> @out U // CHECK-SIL: [[JVP_RES_BUF:%.*]] = alloc_stack $(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector) // CHECK-SIL: [[JVP_RES_BUF_0:%.*]] = tuple_element_addr [[JVP_RES_BUF]] : $*(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector), 0 // CHECK-SIL: [[DIFFERENTIAL:%.*]] = apply [[JVP_FN]]([[JVP_RES_BUF_0]], [[ORIG_FN_ARG]]) : $@noescape @callee_guaranteed (@in_guaranteed T) -> (@out U, @owned @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector) @@ -30,4 +30,4 @@ func evaldiff2(_ f: @di // CHECK-LABEL: @{{.*}}evaldiff2{{.*}} // CHECK: bb0({{.*}} : $*V, [[DIFFED:%.*]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T, @in_guaranteed U) -> @out V, {{.*}} : $*T, {{.*}} : $*U): -// CHECK: differentiable_function_extract [jvp] [order 1] [[DIFFED]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T, @in_guaranteed U) -> @out V // user: %14 +// CHECK: differentiable_function_extract [jvp] [[DIFFED]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T, @in_guaranteed U) -> @out V // user: %14 diff --git a/test/AutoDiff/differentiable_function_inst.sil b/test/AutoDiff/differentiable_function_inst.sil index f03c6f94910f3..f6b3a4488cad4 100644 --- a/test/AutoDiff/differentiable_function_inst.sil +++ b/test/AutoDiff/differentiable_function_inst.sil @@ -39,19 +39,19 @@ bb0(%0 : $Float): sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float { bb0: %orig = function_ref @foo : $@convention(thin) (Float) -> Float - %undiffedFunc = differentiable_function [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float + %undiffedFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float %vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - %diffFunc = differentiable_function [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} - %extractedVJP = differentiable_function_extract [vjp] [order 1] %diffFunc : $@differentiable @convention(thin) (Float) -> Float + %diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + %extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float %extractedOriginal = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float return %undiffedFunc : $@differentiable @convention(thin) (Float) -> Float } // CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float // CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float -// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [wrt 0] [order 1] [[FOO]] : $@convention(thin) (Float) -> Float +// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float // CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [wrt 0] [order 1] [[FOO]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} -// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [order 1] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float +// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} +// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: [[EXTRACTED_ORIG:%.*]] = differentiable_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float diff --git a/test/AutoDiff/differentiable_function_inst_irgen.sil b/test/AutoDiff/differentiable_function_inst_irgen.sil index 05559b8cf2015..79cd8a6a9b702 100644 --- a/test/AutoDiff/differentiable_function_inst_irgen.sil +++ b/test/AutoDiff/differentiable_function_inst_irgen.sil @@ -36,9 +36,9 @@ sil @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Flo bb0: %orig = function_ref @foo : $@convention(thin) (Float) -> Float %vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - %diffFunc = differentiable_function [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + %diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} %extractedOrig = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float - %extractedVJP = differentiable_function_extract [vjp] [order 1] %diffFunc : $@differentiable @convention(thin) (Float) -> Float + %extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float %tuple = tuple (%extractedOrig : $@convention(thin) (Float) -> Float, %extractedVJP : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) return %tuple : $(@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) } diff --git a/test/AutoDiff/differentiable_function_silgen.swift b/test/AutoDiff/differentiable_function_silgen.swift index 16ad288847a36..843e2a70b4c48 100644 --- a/test/AutoDiff/differentiable_function_silgen.swift +++ b/test/AutoDiff/differentiable_function_silgen.swift @@ -49,9 +49,9 @@ func apply() { // CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}} // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float -// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [wrt 0] [order 1] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [wrt 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float -// CHECK-SIL: [[DIFFED:%.*]] = differentiable_function [wrt 0] [order 1] {{%.*}} : $@callee_guaranteed (Float) -> Float +// CHECK-SIL: [[DIFFED:%.*]] = differentiable_function [wrt 0] {{%.*}} : $@callee_guaranteed (Float) -> Float //===----------------------------------------------------------------------===// // Reabstraction @@ -75,14 +75,14 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) { // CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float // CHECK-SILGEN: [[REABS_ORIG:%.*]] = function_ref @$sS2fIegyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float // CHECK-SILGEN: [[NEW_ORIG:%.*]] = partial_apply [callee_guaranteed] [[REABS_ORIG]]([[ORIG_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float -// CHECK-SILGEN: [[JVP:%.*]] = differentiable_function_extract [jvp] [order 1] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[JVP:%.*]] = differentiable_function_extract [jvp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float // CHECK-SILGEN: [[JVP_COPY:%.*]] = copy_value [[JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-SILGEN: [[REABS_JVP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_JVP:%.*]] = partial_apply [callee_guaranteed] [[REABS_JVP]]([[JVP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK-SILGEN: [[VJP:%.*]] = differentiable_function_extract [vjp] [order 1] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float // CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [order 1] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} -// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector -// CHECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector +// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} +// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector +// CHECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector diff --git a/test/AutoDiff/differentiable_sil_function_type_parse.sil b/test/AutoDiff/differentiable_sil_function_type_parse.sil index 9c1e62ce18bb1..72103e6976e9b 100644 --- a/test/AutoDiff/differentiable_sil_function_type_parse.sil +++ b/test/AutoDiff/differentiable_sil_function_type_parse.sil @@ -13,23 +13,23 @@ sil @test : $@convention(thin) () -> () { bb0: %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float - %1 = differentiable_function [wrt 0 1 2] [order 1] %0 : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %2 = differentiable_function_extract [vjp] [order 1] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float - %2 = differentiable_function_extract [vjp] [order 1] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float - %3 = differentiable_function [wrt 0] [order 1] %0 : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %4 = differentiable_function_extract [vjp] [order 1] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float - %4 = differentiable_function_extract [vjp] [order 1] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float + %3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float + %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float - %6 = differentiable_function [wrt 0 1 2] [order 1] %5 : $@convention(method) (Float, Float, Float) -> Float - // CHECK: %7 = differentiable_function_extract [vjp] [order 1] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float - %7 = differentiable_function_extract [vjp] [order 1] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float + // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float - %8 = differentiable_function [wrt 0] [order 1] %5 : $@convention(method) (Float, Float, Float) -> Float - // CHECK: %9 = differentiable_function_extract [vjp] [order 1] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float - %9 = differentiable_function_extract [vjp] [order 1] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float + %8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float + // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float + %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float %ret = tuple () return %ret : $() diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift index 70bae68d5ccd3..81dace664f739 100644 --- a/test/AutoDiff/forward_mode_sil.swift +++ b/test/AutoDiff/forward_mode_sil.swift @@ -23,15 +23,15 @@ func unary(_ x: Float) -> Float { // CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [wrt 0 1] [order 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} -// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [order 1] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [wrt 0 1] [order 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} -// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [order 1] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT_2:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__unary_bb0__DF__src_0_wrt_0 ([[MULT_DIFF_1]] : $@callee_guaranteed (Float, Float) -> Float, [[MULT_DIFF_2]] : $@callee_guaranteed (Float, Float) -> Float) @@ -68,8 +68,8 @@ func binary(x: Float, y: Float) -> Float { // CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [wrt 0 1] [order 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} -// CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = differentiable_function_extract [jvp] [order 1] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT:%.*]], [[MULT_DIFF:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__binary_bb0__DF__src_0_wrt_0_1 ([[MULT_DIFF]] : $@callee_guaranteed (Float, Float) -> Float) diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift index 82f86b3afc75b..3927212225f47 100644 --- a/test/AutoDiff/refcounting.swift +++ b/test/AutoDiff/refcounting.swift @@ -86,8 +86,8 @@ _ = pullback(at: Vector.zero, in: testOwnedVector) // CHECK: [[ADD:%.*]] = function_ref @Vector_plus // CHECK: [[ADD_JVP:%.*]] = function_ref @{{.*}}Vector_plus__jvp_src_0_wrt_0_1{{.*}} // CHECK: [[ADD_VJP:%.*]] = function_ref @{{.*}}Vector_plus__vjp_src_0_wrt_0_1{{.*}} -// CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [wrt 0 1] [order 1] [[ADD]] {{.*}} with {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}} -// CHECK: [[ADD_AD_FUNC_EXTRACT:%.*]] = differentiable_function_extract [vjp] [order 1] [[ADD_AD_FUNC]] +// CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [wrt 0 1] [[ADD]] {{.*}} with {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}} +// CHECK: [[ADD_AD_FUNC_EXTRACT:%.*]] = differentiable_function_extract [vjp] [[ADD_AD_FUNC]] // CHECK: [[ADD_VJP_RESULT:%.*]] = apply [[ADD_AD_FUNC_EXTRACT]]({{.*}}, {{.*}}, {{.*}}) : $@convention(method) (@guaranteed Vector, @guaranteed Vector, @thin Vector.Type) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)) // CHECK: [[ADD_PULLBACK:%.*]] = tuple_extract [[ADD_VJP_RESULT]] : $(Vector, @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)), 1 // CHECK-NOT: release_value [[ADD_VJP_RESULT]] diff --git a/test/AutoDiff/sildeclref_parse.sil b/test/AutoDiff/sildeclref_parse.sil index d20082a96048e..cfd509e0d79c6 100644 --- a/test/AutoDiff/sildeclref_parse.sil +++ b/test/AutoDiff/sildeclref_parse.sil @@ -13,17 +13,17 @@ bb0(%0 : $*T): // CHECK: witness_method $T, #Proto.f!1 %1 = witness_method $T, #Proto.f!1 : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float - // CHECK: witness_method $T, #Proto.f!1.jvp.1.SSS - %2 = witness_method $T, #Proto.f!1.jvp.1.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + // CHECK: witness_method $T, #Proto.f!1.jvp.SSS + %2 = witness_method $T, #Proto.f!1.jvp.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float - // CHECK: witness_method $T, #Proto.f!1.jvp.1.UUS - %3 = witness_method $T, #Proto.f!1.jvp.1.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + // CHECK: witness_method $T, #Proto.f!1.jvp.UUS + %3 = witness_method $T, #Proto.f!1.jvp.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float - // CHECK: witness_method $T, #Proto.f!1.vjp.1.SSS - %4 = witness_method $T, #Proto.f!1.vjp.1.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + // CHECK: witness_method $T, #Proto.f!1.vjp.SSS + %4 = witness_method $T, #Proto.f!1.vjp.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float - // CHECK: witness_method $T, #Proto.f!1.vjp.1.UUS - %5 = witness_method $T, #Proto.f!1.vjp.1.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + // CHECK: witness_method $T, #Proto.f!1.vjp.UUS + %5 = witness_method $T, #Proto.f!1.vjp.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Proto) <τ_0_0 where τ_0_0 : Proto> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float %6 = tuple () return %6 : $() diff --git a/test/AutoDiff/simple_real_vector.swift b/test/AutoDiff/simple_real_vector.swift index b83c6d8fd98fe..6dc92fdf90f5d 100644 --- a/test/AutoDiff/simple_real_vector.swift +++ b/test/AutoDiff/simple_real_vector.swift @@ -46,7 +46,7 @@ public func test1() -> Vector { // CHECK-LABEL: @{{.*}}test1{{.*}} // CHECK: [[CLOSURE:%.*]] = function_ref @{{.*}}test1{{.*}}foo{{.*}} : $@convention(thin) (Vector) -> Float // CHECK: [[CLOSURE_THICK:%.*]] = thin_to_thick_function [[CLOSURE]] : $@convention(thin) (Vector) -> Float to $@callee_guaranteed (Vector) -> Float -// CHECK: [[CLOSURE_DIFF:%.*]] = differentiable_function [wrt 0] [order 1] [[CLOSURE_THICK]] : $@callee_guaranteed (Vector) -> Float +// CHECK: [[CLOSURE_DIFF:%.*]] = differentiable_function [wrt 0] [[CLOSURE_THICK]] : $@callee_guaranteed (Vector) -> Float // CHECK: [[CLOSURE_DIFF_NOESC:%.*]] = convert_escape_to_noescape [not_guaranteed] [[CLOSURE_DIFF]] : $@differentiable @callee_guaranteed (Vector) -> Float to $@differentiable @noescape @callee_guaranteed (Vector) -> Float // TF-189: `TF189` is a non-trivial type but `TF189.AllDifferentiableVariables` is trivial. diff --git a/test/AutoDiff/subset_parameters_thunk.swift b/test/AutoDiff/subset_parameters_thunk.swift index 8521faa117f20..2dc0308285faf 100644 --- a/test/AutoDiff/subset_parameters_thunk.swift +++ b/test/AutoDiff/subset_parameters_thunk.swift @@ -26,5 +26,5 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float { // CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector)) // CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [wrt 0] [order 1] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} +// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [wrt 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} // CHECK: } diff --git a/test/AutoDiff/vtable_sil.swift b/test/AutoDiff/vtable_sil.swift index f1d95399abe1d..50bb8fae60b7f 100644 --- a/test/AutoDiff/vtable_sil.swift +++ b/test/AutoDiff/vtable_sil.swift @@ -91,14 +91,14 @@ class SubSub : Sub {} // CHECK-NEXT: #Super._nontrivial!modify.1: (Super) -> () -> () : @$s10vtable_sil5SuperC11_nontrivialSaySfGvM // CHECK-NEXT: #Super.init!allocator.1: (Super.Type) -> (Float) -> Super : @$s10vtable_sil5SuperC4baseACSf_tcfC // CHECK-NEXT: #Super.property!getter.1: (Super) -> () -> Float : @$s10vtable_sil5SuperC8propertySfvg -// CHECK-NEXT: #Super.property!getter.1.jvp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil5SuperC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk -// CHECK-NEXT: #Super.property!getter.1.vjp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil5SuperC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.property!getter.1.jvp.S: (Super) -> () -> Float : @AD__$s10vtable_sil5SuperC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.property!getter.1.vjp.S: (Super) -> () -> Float : @AD__$s10vtable_sil5SuperC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk // CHECK-NEXT: #Super.f!1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil5SuperC1fyS2f_SftF -// CHECK-NEXT: #Super.f!1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk -// CHECK-NEXT: #Super.f!1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.f!1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.f!1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk // CHECK-NEXT: #Super.subscript!getter.1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil5SuperCyS2f_Sftcig -// CHECK-NEXT: #Super.subscript!getter.1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk -// CHECK-NEXT: #Super.subscript!getter.1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.subscript!getter.1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK-NEXT: #Super.subscript!getter.1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil5SuperCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk // CHECK-NEXT: #Super.move!1: (Super) -> (Super.TangentVector) -> () : @$s10vtable_sil5SuperC4move5alongyAC13TangentVectorV_tF // CHECK-NEXT: #Super.deinit!deallocator.1: @$s10vtable_sil5SuperCfD // CHECK-NEXT: } @@ -112,17 +112,17 @@ class SubSub : Sub {} // CHECK-NEXT: #Super._nontrivial!modify.1: (Super) -> () -> () : @$s10vtable_sil5SuperC11_nontrivialSaySfGvM [inherited] // CHECK-NEXT: #Super.init!allocator.1: (Super.Type) -> (Float) -> Super : @$s10vtable_sil3SubC4baseACSf_tcfC [override] // CHECK-NEXT: #Super.property!getter.1: (Super) -> () -> Float : @$s10vtable_sil3SubC8propertySfvg [override] -// CHECK-NEXT: #Super.property!getter.1.jvp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [override] -// CHECK-NEXT: #Super.property!getter.1.vjp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.property!getter.1.jvp.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.property!getter.1.vjp.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [override] // CHECK-NEXT: #Super.f!1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil3SubC1fyS2f_SftF [override] -// CHECK-NEXT: #Super.f!1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [override] -// CHECK-NEXT: #Super.f!1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.f!1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.f!1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [override] // CHECK-NEXT: #Super.subscript!getter.1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil3SubCyS2f_Sftcig [override] -// CHECK-NEXT: #Super.subscript!getter.1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [override] -// CHECK-NEXT: #Super.subscript!getter.1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.subscript!getter.1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK-NEXT: #Super.subscript!getter.1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [override] // CHECK-NEXT: #Super.move!1: (Super) -> (Super.TangentVector) -> () : @$s10vtable_sil5SuperC4move5alongyAC13TangentVectorV_tF [inherited] -// CHECK-NEXT: #Sub.f!1.jvp.1.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk -// CHECK-NEXT: #Sub.f!1.vjp.1.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk +// CHECK-NEXT: #Sub.f!1.jvp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk +// CHECK-NEXT: #Sub.f!1.vjp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk // CHECK-NEXT: #Sub.deinit!deallocator.1: @$s10vtable_sil3SubCfD // CHECK-NEXT: } @@ -135,16 +135,16 @@ class SubSub : Sub {} // CHECK-NEXT: #Super._nontrivial!modify.1: (Super) -> () -> () : @$s10vtable_sil5SuperC11_nontrivialSaySfGvM [inherited] // CHECK-NEXT: #Super.init!allocator.1: (Super.Type) -> (Float) -> Super : @$s10vtable_sil03SubC0C4baseACSf_tcfC [override] // CHECK-NEXT: #Super.property!getter.1: (Super) -> () -> Float : @$s10vtable_sil3SubC8propertySfvg [inherited] -// CHECK-NEXT: #Super.property!getter.1.jvp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] -// CHECK-NEXT: #Super.property!getter.1.vjp.1.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.property!getter.1.jvp.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.property!getter.1.vjp.S: (Super) -> () -> Float : @AD__$s10vtable_sil3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] // CHECK-NEXT: #Super.f!1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil3SubC1fyS2f_SftF [inherited] -// CHECK-NEXT: #Super.f!1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] -// CHECK-NEXT: #Super.f!1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.f!1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.f!1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] // CHECK-NEXT: #Super.subscript!getter.1: (Super) -> (Float, Float) -> Float : @$s10vtable_sil3SubCyS2f_Sftcig [inherited] -// CHECK-NEXT: #Super.subscript!getter.1.jvp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] -// CHECK-NEXT: #Super.subscript!getter.1.vjp.1.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.subscript!getter.1.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Super.subscript!getter.1.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] // CHECK-NEXT: #Super.move!1: (Super) -> (Super.TangentVector) -> () : @$s10vtable_sil5SuperC4move5alongyAC13TangentVectorV_tF [inherited] -// CHECK-NEXT: #Sub.f!1.jvp.1.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk [inherited] -// CHECK-NEXT: #Sub.f!1.vjp.1.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Sub.f!1.jvp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk [inherited] +// CHECK-NEXT: #Sub.f!1.vjp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s10vtable_sil3SubC1fyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk [inherited] // CHECK-NEXT: #SubSub.deinit!deallocator.1: @$s10vtable_sil03SubC0CfD // CHECK-NEXT: } diff --git a/test/AutoDiff/witness_method_autodiff.sil b/test/AutoDiff/witness_method_autodiff.sil index 4336605268d27..1c3603bb0c837 100644 --- a/test/AutoDiff/witness_method_autodiff.sil +++ b/test/AutoDiff/witness_method_autodiff.sil @@ -14,7 +14,7 @@ protocol DiffReq { sil @differentiateWitnessMethod : $@convention(thin) (@in_guaranteed T) -> () { bb0(%0 : $*T): %1 = witness_method $T, #DiffReq.f!1 : (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float - %2 = differentiable_function [wrt 0] [order 1] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float + %2 = differentiable_function [wrt 0] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float %ret = tuple () return %ret : $() @@ -22,16 +22,16 @@ bb0(%0 : $*T): // CHECK-LABEL: sil @differentiateWitnessMethod // CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f!1 -// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.1.SU -// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.1.SU -// CHECK: differentiable_function [wrt 0] [order 1] [[ORIG_REF]] : {{.*}} with {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}} +// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.SU +// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.SU +// CHECK: differentiable_function [wrt 0] [[ORIG_REF]] : {{.*}} with {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}} // CHECK: } // end sil function 'differentiateWitnessMethod' sil @differentiatePartiallyAppliedWitnessMethod : $@convention(thin) (@in_guaranteed T) -> () { bb0(%0 : $*T): %1 = witness_method $T, #DiffReq.f!1 : (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float %2 = partial_apply [callee_guaranteed] %1(%0) : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float - %3 = differentiable_function [wrt 0] [order 1] %2 : $@callee_guaranteed (Float) -> Float + %3 = differentiable_function [wrt 0] %2 : $@callee_guaranteed (Float) -> Float %ret = tuple () return %ret : $() @@ -45,11 +45,11 @@ bb0(%0 : $*T): // CHECK: [[ARGCOPY2:%.*]] = alloc_stack $T // CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY2]] : $*T // CHECK: [[ORIG_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_REF]](%0) -// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.1.SU +// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.SU // CHECK: [[JVP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_REF]]([[ARGCOPY1]]) -// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.1.SU +// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.SU // CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]([[ARGCOPY2]]) // CHECK: dealloc_stack [[ARGCOPY2]] // CHECK: dealloc_stack [[ARGCOPY1]] -// CHECK: differentiable_function [wrt 0] [order 1] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}} +// CHECK: differentiable_function [wrt 0] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}} // CHECK: } // end sil function 'differentiatePartiallyAppliedWitnessMethod' diff --git a/test/AutoDiff/witness_table_irgen.sil b/test/AutoDiff/witness_table_irgen.sil index f62a8a54551f2..114e74507ead9 100644 --- a/test/AutoDiff/witness_table_irgen.sil +++ b/test/AutoDiff/witness_table_irgen.sil @@ -65,8 +65,8 @@ bb0(%0 : $Float, %1 : $AD__$s23witness_tables_autodiff25DifferentiableConformanc sil_witness_table hidden DifferentiableConformance: DifferentiableRequirement module witness_tables_autodiff { method #DifferentiableRequirement.f!1: (Self) -> (Float) -> Float : @$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW // protocol witness for DifferentiableRequirement.f(_:) in conformance DifferentiableConformance - method #DifferentiableRequirement.f!1.jvp.1.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU - method #DifferentiableRequirement.f!1.vjp.1.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU + method #DifferentiableRequirement.f!1.jvp.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU + method #DifferentiableRequirement.f!1.vjp.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU } // CHECK: @"$s19witness_table_irgen25DifferentiableConformanceVAA0D11RequirementAAWP" = hidden constant [4 x i8*] [ diff --git a/test/AutoDiff/witness_table_sil.swift b/test/AutoDiff/witness_table_sil.swift index ff4b29088a061..8cd907ef86d95 100644 --- a/test/AutoDiff/witness_table_sil.swift +++ b/test/AutoDiff/witness_table_sil.swift @@ -25,16 +25,16 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_jvp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double) -> Float) { // CHECK: [[JVP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[JVP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 - // CHECK: [[JVP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [order 1] [[JVP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP1_VJP_FNREF]] : {{.*}}} - // CHECK: [[JVP1:%.*]] = differentiable_function_extract [jvp] [order 1] [[JVP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float + // CHECK: [[JVP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [[JVP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP1_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP1:%.*]] = differentiable_function_extract [jvp] [[JVP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[JVP1]] // CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_jvp_SSU' // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_vjp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double)) { // CHECK: [[VJP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[VJP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 - // CHECK: [[VJP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [order 1] [[VJP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP1_VJP_FNREF]] : {{.*}}} - // CHECK: [[VJP1:%.*]] = differentiable_function_extract [vjp] [order 1] [[VJP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float + // CHECK: [[VJP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [[VJP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP1_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP1:%.*]] = differentiable_function_extract [vjp] [[VJP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[VJP1]] // CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_vjp_SSU' @@ -46,16 +46,16 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_jvp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double, @in_guaranteed S) -> Float) { // CHECK: [[JVP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[JVP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 - // CHECK: [[JVP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [order 1] [[JVP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP2_VJP_FNREF]] : {{.*}}} - // CHECK: [[JVP2:%.*]] = differentiable_function_extract [jvp] [order 1] [[JVP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float + // CHECK: [[JVP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [[JVP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP2_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP2:%.*]] = differentiable_function_extract [jvp] [[JVP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[JVP2]] // CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_jvp_SSS' // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_vjp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double, @out S)) { // CHECK: [[VJP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[VJP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 - // CHECK: [[VJP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [order 1] [[VJP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP2_VJP_FNREF]] : {{.*}}} - // CHECK: [[VJP2:%.*]] = differentiable_function_extract [vjp] [order 1] [[VJP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float + // CHECK: [[VJP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [[VJP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP2_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP2:%.*]] = differentiable_function_extract [vjp] [[VJP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[VJP2]] // CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_vjp_SSS' @@ -67,16 +67,16 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_jvp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[JVP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double // CHECK: [[JVP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 - // CHECK: [[JVP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [order 1] [[JVP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP3_VJP_FNREF]] : {{.*}}} - // CHECK: [[JVP3:%.*]] = differentiable_function_extract [jvp] [order 1] [[JVP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double + // CHECK: [[JVP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [[JVP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP3_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP3:%.*]] = differentiable_function_extract [jvp] [[JVP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[JVP3]] // CHECK: } // end sil function 'AD__{{.*}}function3{{.*}}_jvp_USU' // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_vjp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[VJP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double // CHECK: [[VJP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 - // CHECK: [[VJP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [order 1] [[VJP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP3_VJP_FNREF]] : {{.*}}} - // CHECK: [[VJP3:%.*]] = differentiable_function_extract [vjp] [order 1] [[VJP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double + // CHECK: [[VJP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [[VJP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP3_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP3:%.*]] = differentiable_function_extract [vjp] [[VJP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[VJP3]] // CHECK: } // end sil function 'AD__{{.*}}function3{{.*}}_vjp_USU' } @@ -84,12 +84,12 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil_witness_table hidden S: Proto module witness_table_sil { // CHECK-NEXT: base_protocol _Differentiable: S: _Differentiable module witness_table_sil // CHECK-NEXT: method #Proto.function1!1: (Self) -> (Float, Double) -> Float : @{{.*}}function1 -// CHECK-NEXT: method #Proto.function1!1.jvp.1.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_jvp_SSU -// CHECK-NEXT: method #Proto.function1!1.vjp.1.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_vjp_SSU +// CHECK-NEXT: method #Proto.function1!1.jvp.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_jvp_SSU +// CHECK-NEXT: method #Proto.function1!1.vjp.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_vjp_SSU // CHECK-NEXT: method #Proto.function2!1: (Self) -> (Float, Double) -> Float : @{{.*}}function2 -// CHECK-NEXT: method #Proto.function2!1.jvp.1.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_jvp_SSS -// CHECK-NEXT: method #Proto.function2!1.vjp.1.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_vjp_SSS +// CHECK-NEXT: method #Proto.function2!1.jvp.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_jvp_SSS +// CHECK-NEXT: method #Proto.function2!1.vjp.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_vjp_SSS // CHECK-NEXT: method #Proto.function3!1: (Self) -> (Float, Double) -> Double : @{{.*}}function3 -// CHECK-NEXT: method #Proto.function3!1.jvp.1.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_jvp_USU -// CHECK-NEXT: method #Proto.function3!1.vjp.1.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_vjp_USU +// CHECK-NEXT: method #Proto.function3!1.jvp.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_jvp_USU +// CHECK-NEXT: method #Proto.function3!1.vjp.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_vjp_USU // CHECK-NEXT:} From 3faf47490f8da9fc95eba8b4e1839f49a919230c Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Wed, 9 Oct 2019 15:37:20 -0700 Subject: [PATCH 2/4] fix conversion problem --- include/swift/SIL/SILInstruction.h | 2 +- lib/SIL/TypeLowering.cpp | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 02db90e4fd597..5e9dd651b5b45 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7936,7 +7936,7 @@ class DifferentiableFunctionExtractInst } rawValue; Extractee() = default; Extractee(innerty rawValue) : rawValue(rawValue) {} - Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {} + explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {} Extractee(AutoDiffAssociatedFunctionKind kind); explicit Extractee(StringRef name); operator innerty() const { return rawValue; } diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 54fb9ca468729..b5d01b7139a75 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -896,16 +896,23 @@ namespace { DifferentiableFunctionExtractee::Original, TC.getTypeLowering(origFnTy, getResilienceExpansion()) }); - for (auto kind : {AutoDiffAssociatedFunctionKind::JVP, - AutoDiffAssociatedFunctionKind::VJP}) { + for (AutoDiffAssociatedFunctionKind kind : + {AutoDiffAssociatedFunctionKind::JVP, + AutoDiffAssociatedFunctionKind::VJP}) { auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( paramIndices, 0, kind, TC, LookUpConformanceInModule(&TC.M)); auto silTy = SILType::getPrimitiveObjectType(assocFnTy); + auto extractee = DifferentiableFunctionExtractee(kind); + + // A bug caused by implicit conversions caused us to get the wrong + // extractee, so assert that we have the right extractee to prevent + // reoccurrence of the bug. + assert(extractee.getExtracteeAsAssociatedFunction() == + Optional(kind)); + children.push_back(Child{ - DifferentiableFunctionExtractee(kind), - TC.getTypeLowering(silTy, getResilienceExpansion()) - }); + extractee, TC.getTypeLowering(silTy, getResilienceExpansion())}); } assert(children.size() == 3); } From e70d704e2f8bce50ade32ec80fa9b8d552a267f6 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 9 Oct 2019 16:13:47 -0700 Subject: [PATCH 3/4] Minor style changes. --- lib/SIL/TypeLowering.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index b5d01b7139a75..c7019cf15440c 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -897,20 +897,18 @@ namespace { TC.getTypeLowering(origFnTy, getResilienceExpansion()) }); for (AutoDiffAssociatedFunctionKind kind : - {AutoDiffAssociatedFunctionKind::JVP, - AutoDiffAssociatedFunctionKind::VJP}) { + {AutoDiffAssociatedFunctionKind::JVP, + AutoDiffAssociatedFunctionKind::VJP}) { auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType( paramIndices, 0, kind, TC, LookUpConformanceInModule(&TC.M)); auto silTy = SILType::getPrimitiveObjectType(assocFnTy); - auto extractee = DifferentiableFunctionExtractee(kind); - - // A bug caused by implicit conversions caused us to get the wrong - // extractee, so assert that we have the right extractee to prevent - // reoccurrence of the bug. - assert(extractee.getExtracteeAsAssociatedFunction() == - Optional(kind)); - + DifferentiableFunctionExtractee extractee(kind); + // Assert that we have the right extractee. A terrible bug in the past + // was caused by implicit conversions from `unsigned` to + // `DifferentiableFunctionExtractee` which resulted into a wrong + // extractee. + assert(extractee.getExtracteeAsAssociatedFunction() == kind); children.push_back(Child{ extractee, TC.getTypeLowering(silTy, getResilienceExpansion())}); } From 96b656465a74025445eb7add6462bd3fd69658aa Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 9 Oct 2019 16:30:06 -0700 Subject: [PATCH 4/4] Fix test. --- test/AutoDiff/differentiable_function_silgen.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/AutoDiff/differentiable_function_silgen.swift b/test/AutoDiff/differentiable_function_silgen.swift index 843e2a70b4c48..ef7d06cbed532 100644 --- a/test/AutoDiff/differentiable_function_silgen.swift +++ b/test/AutoDiff/differentiable_function_silgen.swift @@ -84,5 +84,5 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) { // CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} -// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector -// CHECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector +// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector +// HECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector