diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index f4e8806b8f7f9..c47e0e835d71d 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -20,6 +20,7 @@ #include "ASTContext.h" #include "llvm/ADT/SmallBitVector.h" +#include "swift/Basic/Range.h" namespace swift { @@ -73,6 +74,7 @@ class ParsedAutoDiffParameter { }; class AnyFunctionType; +class AutoDiffIndexSubset; class AutoDiffParameterIndicesBuilder; class Type; @@ -173,7 +175,8 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode { /// ==> returns 1110 /// (because the lowered SIL type is (A, B, C, D) -> R) /// - llvm::SmallBitVector getLowered(AnyFunctionType *functionType) const; + AutoDiffIndexSubset *getLowered(ASTContext &ctx, + AnyFunctionType *functionType) const; void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddInteger(parameters.size()); @@ -219,6 +222,216 @@ class AutoDiffParameterIndicesBuilder { unsigned size() { return parameters.size(); } }; +class AutoDiffIndexSubset : public llvm::FoldingSetNode { +public: + typedef uint64_t BitWord; + + static constexpr unsigned bitWordSize = sizeof(BitWord); + static constexpr unsigned numBitsPerBitWord = bitWordSize * 8; + + static std::pair + getBitWordIndexAndOffset(unsigned index) { + auto bitWordIndex = index / numBitsPerBitWord; + auto bitWordOffset = index % numBitsPerBitWord; + return {bitWordIndex, bitWordOffset}; + } + + static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) { + if (capacity == 0) return 0; + return capacity / numBitsPerBitWord + 1; + } + +private: + /// The total capacity of the index subset, which is `1` less than the largest + /// index. + unsigned capacity; + /// The number of bit words in the index subset. + unsigned numBitWords; + + BitWord *getBitWordsData() { + return reinterpret_cast(this + 1); + } + + const BitWord *getBitWordsData() const { + return reinterpret_cast(this + 1); + } + + ArrayRef getBitWords() const { + return {getBitWordsData(), getNumBitWords()}; + } + + BitWord getBitWord(unsigned i) const { + return getBitWordsData()[i]; + } + + BitWord &getBitWord(unsigned i) { + return getBitWordsData()[i]; + } + + MutableArrayRef getMutableBitWords() { + return {const_cast(getBitWordsData()), getNumBitWords()}; + } + + explicit AutoDiffIndexSubset(unsigned capacity, ArrayRef indices) + : capacity(capacity), + numBitWords(getNumBitWordsNeededForCapacity(capacity)) { + std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0); + for (auto i : indices) { + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i); + getBitWord(bitWordIndex) |= (1 << offset); + } + } + +public: + AutoDiffIndexSubset() = delete; + AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete; + AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete; + + // Defined in ASTContext.h. + static AutoDiffIndexSubset *get(ASTContext &ctx, + unsigned capacity, + ArrayRef indices); + + static AutoDiffIndexSubset *getDefault(ASTContext &ctx, + unsigned capacity, + bool includeAll = false) { + if (includeAll) + return getFromRange(ctx, capacity, IntRange<>(capacity)); + return get(ctx, capacity, {}); + } + + static AutoDiffIndexSubset *getFromRange(ASTContext &ctx, + unsigned capacity, + IntRange<> range) { + return get(ctx, capacity, + SmallVector(range.begin(), range.end())); + } + + unsigned getNumBitWords() const { + return numBitWords; + } + + unsigned getCapacity() const { + return capacity; + } + + class iterator; + + iterator begin() const { + return iterator(this); + } + + iterator end() const { + return iterator(this, (int)capacity); + } + + iterator_range getIndices() const { + return make_range(begin(), end()); + } + + unsigned getNumIndices() const { + return (unsigned)std::distance(begin(), end()); + } + + bool contains(unsigned index) const { + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index); + return getBitWord(bitWordIndex) & (1 << offset); + } + + bool isEmpty() const { + return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; }); + } + + bool equals(AutoDiffIndexSubset *other) const { + return capacity == other->getCapacity() && + getBitWords().equals(other->getBitWords()); + } + + bool isSubsetOf(AutoDiffIndexSubset *other) const; + bool isSupersetOf(AutoDiffIndexSubset *other) const; + + AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; + AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx, + unsigned newCapacity) const; + + void Profile(llvm::FoldingSetNodeID &id) const { + id.AddInteger(capacity); + for (auto index : getIndices()) + id.AddInteger(index); + } + + void print(llvm::raw_ostream &s = llvm::outs()) const { + s << '{'; + interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); }, + [&s] { s << ", "; }); + s << '}'; + } + + void dump(llvm::raw_ostream &s = llvm::errs()) const { + s << "(autodiff_index_subset capacity=" << capacity << " indices=("; + interleave(getIndices(), [&s](unsigned i) { s << i; }, + [&s] { s << ", "; }); + s << "))"; + } + + int findNext(int startIndex) const; + int findFirst() const { return findNext(-1); } + int findPrevious(int endIndex) const; + int findLast() const { return findPrevious(capacity); } + + class iterator { + public: + typedef unsigned value_type; + typedef unsigned difference_type; + typedef unsigned * pointer; + typedef unsigned & reference; + typedef std::forward_iterator_tag iterator_category; + + private: + const AutoDiffIndexSubset *parent; + int current = 0; + + void advance() { + assert(current != -1 && "Trying to advance past end."); + current = parent->findNext(current); + } + + public: + iterator(const AutoDiffIndexSubset *parent, int current) + : parent(parent), current(current) {} + explicit iterator(const AutoDiffIndexSubset *parent) + : iterator(parent, parent->findFirst()) {} + iterator(const iterator &) = default; + + iterator operator++(int) { + auto prev = *this; + advance(); + return prev; + } + + iterator &operator++() { + advance(); + return *this; + } + + unsigned operator*() const { return current; } + + bool operator==(const iterator &other) const { + assert(parent == other.parent && + "Comparing iterators from different AutoDiffIndexSubsets"); + return current == other.current; + } + + bool operator!=(const iterator &other) const { + assert(parent == other.parent && + "Comparing iterators from different AutoDiffIndexSubsets"); + return current != other.current; + } + }; +}; + /// SIL-level automatic differentiation indices. Consists of a source index, /// i.e. index of the dependent result to differentiate from, and parameter /// indices, i.e. index of independent parameters to differentiate with @@ -242,38 +455,33 @@ struct SILAutoDiffIndices { /// Function type: (A, B) -> (C, D) -> R /// Bits: [C][D][A][B] /// - llvm::SmallBitVector parameters; + AutoDiffIndexSubset *parameters; /// Creates a set of AD indices from the given source index and a bit vector /// representing parameter indices. /*implicit*/ SILAutoDiffIndices(unsigned source, - llvm::SmallBitVector parameters) + AutoDiffIndexSubset *parameters) : source(source), parameters(parameters) {} - /// Creates a set of AD indices from the given source index and an array of - /// parameter indices. Elements in `parameters` must be ascending integers. - /*implicit*/ SILAutoDiffIndices(unsigned source, - ArrayRef parameters); - bool operator==(const SILAutoDiffIndices &other) const; /// Queries whether the function's parameter with index `parameterIndex` is /// one of the parameters to differentiate with respect to. bool isWrtParameter(unsigned parameterIndex) const { - return parameterIndex < parameters.size() && - parameters.test(parameterIndex); + return parameterIndex < parameters->getCapacity() && + parameters->contains(parameterIndex); } void print(llvm::raw_ostream &s = llvm::outs()) const { s << "(source=" << source << " parameters=("; - interleave(parameters.set_bits(), + interleave(parameters->getIndices(), [&s](unsigned p) { s << p; }, [&s]{ s << ' '; }); s << "))"; } std::string mangle() const { std::string result = "src_" + llvm::utostr(source) + "_wrt_"; - interleave(parameters.set_bits(), + interleave(parameters->getIndices(), [&](unsigned idx) { result += llvm::utostr(idx); }, [&] { result += '_'; }); return result; @@ -449,19 +657,18 @@ template struct DenseMapInfo; template<> struct DenseMapInfo { static SILAutoDiffIndices getEmptyKey() { - return { DenseMapInfo::getEmptyKey(), SmallBitVector() }; + return { DenseMapInfo::getEmptyKey(), nullptr }; } static SILAutoDiffIndices getTombstoneKey() { - return { DenseMapInfo::getTombstoneKey(), - SmallBitVector(sizeof(intptr_t), true) }; + return { DenseMapInfo::getTombstoneKey(), nullptr }; } static unsigned getHashValue(const SILAutoDiffIndices &Val) { - auto params = Val.parameters.set_bits(); unsigned combinedHash = hash_combine(~1U, DenseMapInfo::getHashValue(Val.source), - hash_combine_range(params.begin(), params.end())); + hash_combine_range(Val.parameters->begin(), + Val.parameters->end())); return combinedHash; } diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index b5baf560dd687..60b3e1afa8017 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1535,7 +1535,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken, ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken, "the number of operand lists does not match the order", ()) ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken, - "expects an assoiacted function kind attribute, e.g. '[jvp]'", ()) + "expected an associated function kind attribute, e.g. '[jvp]'", ()) +ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, + "expected an operand of a function type", ()) //------------------------------------------------------------------------------ // MARK: Generics parsing diagnostics diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index b66d3b652d1fc..801066a59e96b 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4132,14 +4132,14 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, // SWIFT_ENABLE_TENSORFLOW CanSILFunctionType getWithDifferentiability( - unsigned differentiationOrder, const SmallBitVector ¶meterIndices); + unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices); CanSILFunctionType getWithoutDifferentiability(); /// Returns the type of a differentiation function that is associated with /// a function of this type. CanSILFunctionType getAutoDiffAssociatedFunctionType( - const SmallBitVector ¶meterIndices, unsigned resultIndex, + AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, SILModule &module, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenericSignature = nullptr); @@ -4148,7 +4148,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// differentiate with respect to for this differentiable function type. (e.g. /// which parameters are not @nondiff). The function type must be /// differentiable. - SmallBitVector getDifferentiationParameterIndices() const; + AutoDiffIndexSubset *getDifferentiationParameterIndices(); /// If this is a @convention(witness_method) function with a class /// constrained self parameter, return the class constraint for the diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 3fe0747495cc8..8c5e2a29572b3 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -504,7 +504,7 @@ class SILBuilder { /// SWIFT_ENABLE_TENSORFLOW AutoDiffFunctionInst *createAutoDiffFunction( - SILLocation loc, const llvm::SmallBitVector ¶meterIndices, + SILLocation loc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue original, ArrayRef associatedFunctions = {}) { return insert(AutoDiffFunctionInst::create(getModule(), diff --git a/include/swift/SIL/SILFunction.h b/include/swift/SIL/SILFunction.h index f82a6f20e4f5c..30b08688f24ff 100644 --- a/include/swift/SIL/SILFunction.h +++ b/include/swift/SIL/SILFunction.h @@ -174,6 +174,9 @@ class SILDifferentiableAttr final { SILFunction *getOriginal() const { return Original; } const SILAutoDiffIndices &getIndices() const { return indices; } + void setIndices(const SILAutoDiffIndices &indices) { + this->indices = indices; + } TrailingWhereClause *getWhereClause() const { return WhereClause; } diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index be2c95918c1ee..db4f0dcf164b6 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7721,7 +7721,7 @@ class AutoDiffFunctionInst final : private: friend SILBuilder; /// Differentiation parameter indices. - SmallBitVector parameterIndices; + AutoDiffIndexSubset *parameterIndices; /// The order of differentiation. unsigned differentiationOrder; /// The number of operands. The first operand is always the original function. @@ -7730,7 +7730,7 @@ class AutoDiffFunctionInst final : unsigned numOperands; AutoDiffFunctionInst(SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions); @@ -7738,20 +7738,20 @@ class AutoDiffFunctionInst final : public: static AutoDiffFunctionInst *create(SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions); static SILType getAutoDiffType(SILValue original, unsigned differentiationOrder, - const SmallBitVector ¶meterIndices); + AutoDiffIndexSubset *parameterIndices); /// Returns the original function. SILValue getOriginalFunction() const { return getAllOperands()[0].get(); } /// Returns differentiation indices. - const SmallBitVector &getParameterIndices() const { + AutoDiffIndexSubset *getParameterIndices() const { return parameterIndices; } diff --git a/include/swift/Serialization/ModuleFormat.h b/include/swift/Serialization/ModuleFormat.h index b5a40fc921ad0..e02e226f474e3 100644 --- a/include/swift/Serialization/ModuleFormat.h +++ b/include/swift/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 489; // Last change: `@differentiating` wrt +const uint16_t SWIFTMODULE_VERSION_MINOR = 490; // Last change: `@differentiable` parameter indices layout. using DeclIDField = BCFixed<31>; diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 11f6698b43e20..6b7026f23c9dd 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -396,6 +396,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `AutoDiffParameterIndices` allocations. llvm::FoldingSet AutoDiffParameterIndicesSet; + /// For uniquifying `AutoDiffIndexSubset` allocations. + llvm::FoldingSet AutoDiffIndexSubsets; + /// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations. llvm::FoldingSet AutoDiffAssociatedFunctionIdentifiers; @@ -4533,6 +4536,35 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) { return newNode; } +AutoDiffIndexSubset * +AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, + ArrayRef indices) { + auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets; + llvm::FoldingSetNodeID id; + id.AddInteger(capacity); +#ifndef NDEBUG + int last = -1; +#endif + for (unsigned index : indices) { +#ifndef NDEBUG + assert((int)index > last && "Indices must be ascending"); + last = (int)index; +#endif + id.AddInteger(index); + } + void *insertPos = nullptr; + auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); + if (existing) + return existing; + auto sizeToAlloc = sizeof(AutoDiffIndexSubset) + + getNumBitWordsNeededForCapacity(capacity); + auto *buf = reinterpret_cast( + ctx.Allocate(sizeToAlloc, alignof(AutoDiffIndexSubset))); + auto *newNode = new (buf) AutoDiffIndexSubset(capacity, indices); + foldingSet.InsertNode(newNode, insertPos); + return newNode; +} + AutoDiffAssociatedFunctionIdentifier * AutoDiffAssociatedFunctionIdentifier::get( AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 9e87df3e8155b..755015902e2dd 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -20,33 +20,8 @@ using namespace swift; -SILAutoDiffIndices::SILAutoDiffIndices( - unsigned source, ArrayRef parameters) : source(source) { - if (parameters.empty()) - return; - - auto max = *std::max_element(parameters.begin(), parameters.end()); - this->parameters.resize(max + 1); - int last = -1; - for (auto paramIdx : parameters) { - assert((int)paramIdx > last && "Parameter indices must be ascending"); - last = paramIdx; - this->parameters.set(paramIdx); - } -} - -bool SILAutoDiffIndices::operator==( - const SILAutoDiffIndices &other) const { - if (source != other.source) - return false; - - // The parameters are the same when they have exactly the same set bit - // indices, even if they have different sizes. - llvm::SmallBitVector buffer(std::max(parameters.size(), - other.parameters.size())); - buffer ^= parameters; - buffer ^= other.parameters; - return buffer.none(); +bool SILAutoDiffIndices::operator==(const SILAutoDiffIndices &other) const { + return source == other.source && parameters == other.parameters; } AutoDiffAssociatedFunctionKind:: @@ -222,8 +197,9 @@ static unsigned countNumFlattenedElementTypes(Type type) { /// ==> returns 1110 /// (because the lowered SIL type is (A, B, C, D) -> R) /// -llvm::SmallBitVector -AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const { +AutoDiffIndexSubset * +AutoDiffParameterIndices::getLowered(ASTContext &ctx, + AnyFunctionType *functionType) const { SmallVector curryLevels; unwrapCurryLevels(functionType, curryLevels); @@ -241,16 +217,18 @@ AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const { // Construct the result by setting each range of bits that corresponds to each // "on" parameter. - llvm::SmallBitVector result(totalLoweredSize); + llvm::SmallVector loweredIndices; unsigned currentBitIndex = 0; for (unsigned i : range(parameters.size())) { auto paramLoweredSize = paramLoweredSizes[i]; - if (parameters[i]) - result.set(currentBitIndex, currentBitIndex + paramLoweredSize); + if (parameters[i]) { + auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize); + loweredIndices.append(indices.begin(), indices.end()); + } currentBitIndex += paramLoweredSize; } - return result; + return AutoDiffIndexSubset::get(ctx, totalLoweredSize, loweredIndices); } static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) { @@ -352,3 +330,94 @@ CanType VectorSpace::getCanonicalType() const { NominalTypeDecl *VectorSpace::getNominal() const { return getVector()->getNominalOrBoundGenericNominal(); } + +bool AutoDiffIndexSubset::isSubsetOf(AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (getBitWord(index) & ~other->getBitWord(index)) + return false; + return true; +} + +bool AutoDiffIndexSubset::isSupersetOf(AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (~getBitWord(index) & other->getBitWord(index)) + return false; + return true; +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::adding(unsigned index, + ASTContext &ctx) const { + assert(index < getCapacity()); + if (contains(index)) + return const_cast(this); + SmallVector newIndices; + newIndices.reserve(capacity + 1); + bool inserted = false; + for (auto curIndex : getIndices()) { + if (!inserted && curIndex > index) { + newIndices.push_back(index); + inserted = true; + } + newIndices.push_back(curIndex); + } + return get(ctx, capacity, newIndices); +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::extendingCapacity( + ASTContext &ctx, unsigned newCapacity) const { + assert(newCapacity >= capacity); + if (newCapacity == capacity) + return const_cast(this); + SmallVector indices; + for (auto index : getIndices()) + indices.push_back(index); + return AutoDiffIndexSubset::get(ctx, newCapacity, indices); +} + +int AutoDiffIndexSubset::findNext(int startIndex) const { + assert(startIndex < (int)capacity && "Start index cannot be past the end"); + unsigned bitWordIndex = 0, offset = 0; + if (startIndex >= 0) { + auto indexAndOffset = getBitWordIndexAndOffset(startIndex); + bitWordIndex = indexAndOffset.first; + offset = indexAndOffset.second + 1; + } + for (; bitWordIndex < numBitWords; ++bitWordIndex, offset = 0) { + for (; offset < numBitsPerBitWord; ++offset) { + auto index = bitWordIndex * numBitsPerBitWord + offset; + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) + break; + if (index >= capacity) + return capacity; + if (bitWord & ((BitWord)1 << offset)) + return index; + } + } + return capacity; +} + +int AutoDiffIndexSubset::findPrevious(int endIndex) const { + assert(endIndex >= 0 && "End index cannot be before the start"); + int bitWordIndex = numBitWords - 1, offset = numBitsPerBitWord - 1; + if (endIndex < (int)capacity) { + auto indexAndOffset = getBitWordIndexAndOffset(endIndex); + bitWordIndex = (int)indexAndOffset.first; + offset = (int)indexAndOffset.second - 1; + } + for (; bitWordIndex >= 0; --bitWordIndex, offset = numBitsPerBitWord - 1) { + for (; offset < (int)numBitsPerBitWord; --offset) { + auto index = bitWordIndex * (int)numBitsPerBitWord + offset; + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) + break; + if (index < 0) + return -1; + if (bitWord & ((BitWord)1 << offset)) + return index; + } + } + return -1; +} diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index a486171487948..01093f98f3483 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -39,14 +39,14 @@ namespace { class DiffFuncFieldInfo final : public RecordField { public: DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type, - const SmallBitVector ¶meterIndices) + AutoDiffIndexSubset *parameterIndices) : RecordField(type), Index(index), ParameterIndices(parameterIndices) {} /// The field index. const DiffFuncIndex Index; /// The parameter indices. - SmallBitVector ParameterIndices; + AutoDiffIndexSubset *ParameterIndices; std::string getFieldName() const { auto extractee = std::get<0>(Index); @@ -119,7 +119,7 @@ class DiffFuncTypeBuilder DiffFuncIndex> { SILFunctionType *origFnTy; - SmallBitVector parameterIndices; + AutoDiffIndexSubset *parameterIndices; public: DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index c93ca437ee553..cb55fe2626e7e 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -974,7 +974,7 @@ void SILParser::convertRequirements(SILFunction *F, /// SWIFT_ENABLE_TENSORFLOW /// Parse a `differentiable` attribute, e.g. -/// `[differentiable wrt 0, 1 adjoint @other]`. +/// `[differentiable wrt 0, 1 vjp @other]`. /// Returns true on error. static bool parseDifferentiableAttr( SmallVectorImpl &DAs, SILParser &SP) { @@ -1045,9 +1045,12 @@ static bool parseDifferentiableAttr( whereLoc, requirementReprs); } // Create a SILDifferentiableAttr and we are done. + auto maxIndexRef = std::max_element(ParamIndices.begin(), ParamIndices.end()); + auto *paramIndicesSubset = AutoDiffIndexSubset::get( + P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, ParamIndices); auto *Attr = SILDifferentiableAttr::create( - SP.SILMod, {SourceIndex, ParamIndices}, JVPName.str(), VJPName.str(), - WhereClause); + SP.SILMod, {SourceIndex, paramIndicesSubset}, JVPName.str(), + VJPName.str(), WhereClause); DAs.push_back(Attr); return false; } @@ -2873,7 +2876,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { // {%1 : $T, %2 : $T}, {%3 : $T, %4 : $T} // ^ jvp ^ vjp SourceLoc lastLoc; - SmallBitVector parameterIndices(32); + SmallVector parameterIndices; unsigned order = 1; // Parse optional `[wrt ...]` if (P.Tok.is(tok::l_square) && @@ -2882,15 +2885,12 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { P.consumeToken(tok::l_square); P.consumeToken(tok::identifier); // Parse indices. - unsigned size = parameterIndices.size(); while (P.Tok.is(tok::integer_literal)) { unsigned index; if (P.parseUnsignedInteger(index, lastLoc, diag::sil_inst_autodiff_expected_parameter_index)) return true; - if (index >= size) - parameterIndices.resize((size *= 2)); - parameterIndices.set(index); + parameterIndices.push_back(index); } if (P.parseToken(tok::r_square, diag::sil_inst_autodiff_attr_expected_rsquare, @@ -2917,8 +2917,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { } // Parse the original function value. SILValue original; - if (parseTypedValueRef(original, B)) + SourceLoc originalOperandLoc; + if (parseTypedValueRef(original, originalOperandLoc, B)) return true; + auto fnType = original->getType().getAs(); + if (!fnType) { + P.diagnose(originalOperandLoc, + diag::sil_inst_autodiff_expected_function_type_operand); + return true; + } SmallVector associatedFunctions; // Parse optional operand lists `with { , }, ...`. if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { @@ -2950,7 +2957,10 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { } if (parseSILDebugLocation(InstLoc, B)) return true; - ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndices, order, + auto *parameterIndicesSubset = + AutoDiffIndexSubset::get(P.Context, fnType->getNumParameters(), + parameterIndices); + ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndicesSubset, order, original, associatedFunctions); break; } @@ -5714,6 +5724,18 @@ bool SILParserTUState::parseDeclSIL(Parser &P) { // SWIFT_ENABLE_TENSORFLOW for (auto &attr : DiffAttrs) { + // Resolve parameter indices to have the right capacity, if it's + // different from the number of parameters. We have to do this because + // the parser does not know the function type before creating a + // `SILDifferentiableAttr`, so it had to find the max of all provided + // indices. + if (attr->getIndices().parameters->getCapacity() != + SILFnType->getNumParameters()) { + auto *newParamIndices = attr->getIndices().parameters + ->extendingCapacity(P.Context, SILFnType->getNumParameters()); + attr->setIndices({attr->getIndices().source, newParamIndices}); + } + // Resolve where clause requirements. // If no where clause, continue. if (!attr->getWhereClause()) diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index 65f4fa1a20f88..bbd0a611525aa 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -683,7 +683,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { getDecl()->getInterfaceType()->castTo(); auto silParameterIndices = autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered( - functionTy); + functionTy->getASTContext(), functionTy); SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices); std::string mangledKind; switch (autoDiffAssociatedFunctionIdentifier->getKind()) { diff --git a/lib/SIL/SILFunctionBuilder.cpp b/lib/SIL/SILFunctionBuilder.cpp index e1f2d0ecca64f..f1f18ae695341 100644 --- a/lib/SIL/SILFunctionBuilder.cpp +++ b/lib/SIL/SILFunctionBuilder.cpp @@ -89,6 +89,7 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F, // Get lowered argument indices. auto paramIndices = A->getParameterIndices(); auto loweredParamIndices = paramIndices->getLowered( + F->getASTContext(), decl->getInterfaceType()->castTo()); SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices); auto silDiffAttr = SILDifferentiableAttr::create( diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 241afe260c29f..928595ef5f10c 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -98,20 +98,19 @@ CanType SILFunctionType::getSelfInstanceType() const { } // SWIFT_ENABLE_TENSORFLOW -SmallBitVector -SILFunctionType::getDifferentiationParameterIndices() const { +AutoDiffIndexSubset * +SILFunctionType::getDifferentiationParameterIndices() { assert(isDifferentiable()); - SmallBitVector result(NumParameters, true); + SmallVector result; for (auto valueAndIndex : enumerate(getParameters())) - if (valueAndIndex.value().getDifferentiability() == + if (valueAndIndex.value().getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - result.reset(valueAndIndex.index()); - return result; + result.push_back(valueAndIndex.index()); + return AutoDiffIndexSubset::get(getASTContext(), getNumParameters(), result); } CanSILFunctionType SILFunctionType::getWithDifferentiability( - unsigned differentiationOrder, - const SmallBitVector ¶meterIndices) { + unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices) { // FIXME(rxwei): Handle differentiation order. SmallVector newParameters; @@ -119,9 +118,10 @@ CanSILFunctionType SILFunctionType::getWithDifferentiability( auto ¶m = paramAndIndex.value(); unsigned index = paramAndIndex.index(); newParameters.push_back(param.getWithDifferentiability( - index < parameterIndices.size() && parameterIndices[index] - ? SILParameterDifferentiability::DifferentiableOrNotApplicable - : SILParameterDifferentiability::NotDifferentiable)); + index < parameterIndices->getCapacity() && + parameterIndices->contains(index) + ? SILParameterDifferentiability::DifferentiableOrNotApplicable + : SILParameterDifferentiability::NotDifferentiable)); } auto newExtInfo = getExtInfo().withDifferentiable(); @@ -147,10 +147,9 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() { } CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( - const SmallBitVector ¶meterIndices, unsigned resultIndex, - unsigned differentiationOrder, - AutoDiffAssociatedFunctionKind kind, SILModule &module, - LookupConformanceFn lookupConformance, + AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, + unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, + SILModule &module, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenSig) { // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) @@ -213,7 +212,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( // Helper function testing if we are differentiating wrt this index. auto isWrtIndex = [&](unsigned index) -> bool { - return index < parameterIndices.size() && parameterIndices[index]; + return index < parameterIndices->getCapacity() && + parameterIndices->contains(index); }; // Calculate WRT parameter infos, in the order that they should appear in the @@ -2301,8 +2301,8 @@ const SILConstantInfo &TypeConverter::getConstantInfo(SILDeclRef constant) { if (auto *autoDiffFuncId = constant.autoDiffAssociatedFunctionIdentifier) { auto origFnConstantInfo = getConstantInfo(constant.asAutoDiffOriginalFunction()); - auto loweredIndices = - autoDiffFuncId->getParameterIndices()->getLowered(formalInterfaceType); + auto loweredIndices = autoDiffFuncId->getParameterIndices() + ->getLowered(Context, formalInterfaceType); silFnType = origFnConstantInfo.SILFnType->getAutoDiffAssociatedFunctionType( loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getDifferentiationOrder(), autoDiffFuncId->getKind(), M, diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 936720ca8ac0a..b78402b7ba92d 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -569,7 +569,7 @@ TryApplyInst *TryApplyInst::create( SILType AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, unsigned differentiationOrder, - const SmallBitVector ¶meterIndices) { + AutoDiffIndexSubset *parameterIndices) { auto fnTy = originalFunction->getType().castTo(); auto diffTy = fnTy->getWithDifferentiability(differentiationOrder, parameterIndices); @@ -578,7 +578,7 @@ AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, AutoDiffFunctionInst::AutoDiffFunctionInst( SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions) : InstructionBaseWithTrailingOperands( originalFunction, associatedFunctions, debugLoc, @@ -591,7 +591,7 @@ AutoDiffFunctionInst::AutoDiffFunctionInst( AutoDiffFunctionInst *AutoDiffFunctionInst::create( SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions) { size_t size = totalSizeToAlloc(associatedFunctions.size() + 1); diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 83d2edbecee99..04c90102ed669 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1159,9 +1159,9 @@ class SILPrinter : public SILInstructionVisitor { // SWIFT_ENABLE_TENSORFLOW void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { - if (adfi->getParameterIndices().any()) { + if (!adfi->getParameterIndices()->isEmpty()) { *this << "[wrt"; - for (auto i : adfi->getParameterIndices().set_bits()) + for (auto i : adfi->getParameterIndices()->getIndices()) *this << ' ' << i; *this << "] "; } @@ -3118,7 +3118,7 @@ void SILSpecializeAttr::print(llvm::raw_ostream &OS) const { void SILDifferentiableAttr::print(llvm::raw_ostream &OS) const { auto &indices = getIndices(); OS << "source " << indices.source << " wrt "; - interleave(indices.parameters.set_bits(), + interleave(indices.parameters->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << ", "; }); if (!JVPName.empty()) { diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 3356d3ec37fbd..1f2c461be29b1 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -4601,7 +4601,7 @@ class SILVerifier : public SILVerifierBase { }; // Parameter indices must be specified. - require(!Attr.getIndices().parameters.empty(), + require(!Attr.getIndices().parameters->isEmpty(), "Parameter indices cannot be empty"); // JVP and VJP must be specified in canonical SIL. if (F->getModule().getStage() == SILStage::Canonical) @@ -4610,16 +4610,15 @@ class SILVerifier : public SILVerifierBase { // Verify if specified parameter indices are valid. auto numParams = countParams(F->getLoweredFunctionType()); int lastIndex = -1; - for (auto paramIdx : Attr.getIndices().parameters.set_bits()) { + for (auto paramIdx : Attr.getIndices().parameters->getIndices()) { require(paramIdx < numParams, "Parameter index out of bounds."); auto currentIdx = (int)paramIdx; require(currentIdx > lastIndex, "Parameter indices not ascending."); lastIndex = currentIdx; } - // TODO: Verify if the specified primal/adjoint function has the right - // signature. SIL function verification runs right after a function is - // parsed. - // However, the adjoint function may come after the this function. Without + // TODO: Verify if the specified JVP/VJP function has the right signature. + // SIL function verification runs right after a function is parsed. + // However, the JVP/VJP function may come after the this function. Without // changing the compiler too much, is there a way to verify this at a module // level, after everything is parsed? } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 5e687f3bf5425..95bd4ca91dd18 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3820,6 +3820,7 @@ getWitnessFunctionRef(SILGenFunction &SGF, auto originalFn = SGF.emitGlobalFunctionRef( loc, witness.asAutoDiffOriginalFunction()); auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered( + SGF.getASTContext(), witness.getDecl()->getInterfaceType()->castTo()); auto autoDiffFn = SGF.B.createAutoDiffFunction( loc, loweredIndices, /*differentiationOrder*/ 1, originalFn); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index bed544710cc30..2985bd64e9390 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -213,9 +213,10 @@ static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName, /// qualifier for creating a `store` instruction into the buffer. static StoreOwnershipQualifier getBufferSOQ(Type type, SILFunction &fn) { if (fn.hasOwnership()) - return fn.getModule().Types.getTypeLowering(type, ResilienceExpansion::Minimal).isTrivial() - ? StoreOwnershipQualifier::Trivial - : StoreOwnershipQualifier::Init; + return fn.getModule().Types.getTypeLowering( + type, ResilienceExpansion::Minimal).isTrivial() + ? StoreOwnershipQualifier::Trivial + : StoreOwnershipQualifier::Init; return StoreOwnershipQualifier::Unqualified; } @@ -223,9 +224,10 @@ static StoreOwnershipQualifier getBufferSOQ(Type type, SILFunction &fn) { /// qualified for creating a `load` instruction from the buffer. static LoadOwnershipQualifier getBufferLOQ(Type type, SILFunction &fn) { if (fn.hasOwnership()) - return fn.getModule().Types.getTypeLowering(type, ResilienceExpansion::Minimal).isTrivial() - ? LoadOwnershipQualifier::Trivial - : LoadOwnershipQualifier::Take; + return fn.getModule().Types.getTypeLowering( + type, ResilienceExpansion::Minimal).isTrivial() + ? LoadOwnershipQualifier::Trivial + : LoadOwnershipQualifier::Take; return LoadOwnershipQualifier::Unqualified; } @@ -932,22 +934,31 @@ class ADContext { DifferentiationTask * lookUpMinimalDifferentiationTask(SILFunction *original, const SILAutoDiffIndices &indices) { - auto supersetParamIndices = llvm::SmallBitVector(); - const auto &indexSet = indices.parameters; + auto *superset = AutoDiffIndexSubset::getDefault( + getASTContext(), + original->getLoweredFunctionType()->getNumParameters(), false); + auto *indexSet = indices.parameters; if (auto *existingTask = lookUpDifferentiationTask(original, indices)) return existingTask; for (auto *rda : original->getDifferentiableAttrs()) { - const auto &rdaIndexSet = rda->getIndices().parameters; - // If all indices in indexSet are in rdaIndexSet, and it has fewer - // indices than our current candidate and a primitive adjoint, rda is our - // new candidate. - if (!indexSet.test(rdaIndexSet) && // all indexSet indices in rdaIndexSet - (supersetParamIndices.empty() || // fewer parameters than before - rdaIndexSet.count() < supersetParamIndices.count())) - supersetParamIndices = rda->getIndices().parameters; + auto *rdaIndexSet = rda->getIndices().parameters; + // If all indices in `indexSet` are in `rdaIndexSet`, and it has fewer + // indices than our current candidate and a primitive VJP, then `rda` is + // our new candidate. + // + // NOTE: `rda` may come from a un-partial-applied function, it may have + // more parameters than the desired indices. We expect this logic to go + // away when we support `@differentiable` partial apply. + if (rdaIndexSet->isSupersetOf( + indexSet->extendingCapacity(getASTContext(), + rdaIndexSet->getCapacity())) && + // fewer parameters than before + (superset->isEmpty() || + rdaIndexSet->getNumIndices() < superset->getNumIndices())) + superset = rda->getIndices().parameters; } - auto existing = enqueuedTaskIndices.find( - {original, {indices.source, supersetParamIndices}}); + auto existing = + enqueuedTaskIndices.find({original, {indices.source, superset}}); if (existing == enqueuedTaskIndices.end()) return nullptr; return differentiationTasks[existing->getSecond()].get(); @@ -957,7 +968,7 @@ class ADContext { /// function. SILValue promoteToDifferentiableFunction( SILBuilder &builder, SILLocation loc, SILValue origFnOperand, - const llvm::SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, DifferentiationInvoker invoker); /// For an autodiff_function instruction that is missing associated functions, @@ -1260,8 +1271,7 @@ class DifferentiableActivityInfo { bool isVaried(SILValue value, unsigned independentVariableIndex) const; bool isUseful(SILValue value, unsigned dependentVariableIndex) const; - bool isVaried(SILValue value, - const llvm::SmallBitVector ¶meterIndices) const; + bool isVaried(SILValue value, AutoDiffIndexSubset *parameterIndices) const; bool isActive(SILValue value, const SILAutoDiffIndices &indices) const; Activity getActivity(SILValue value, @@ -1531,8 +1541,8 @@ bool DifferentiableActivityInfo::isVaried( } bool DifferentiableActivityInfo::isVaried( - SILValue value, const llvm::SmallBitVector ¶meterIndices) const { - for (auto paramIdx : parameterIndices.set_bits()) + SILValue value, AutoDiffIndexSubset *parameterIndices) const { + for (auto paramIdx : parameterIndices->getIndices()) if (isVaried(value, paramIdx)) return true; return false; @@ -1814,8 +1824,8 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, if (auto diffableFnType = original->getType().castTo()) { if (diffableFnType->isDifferentiable()) { auto paramIndices = diffableFnType->getDifferentiationParameterIndices(); - for (auto i : desiredIndices.parameters.set_bits()) { - if (i >= paramIndices.size() || !paramIndices[i]) { + for (auto i : desiredIndices.parameters->getIndices()) { + if (!paramIndices->contains(i)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_function_nondiff_parameter_not_differentiable); return None; @@ -1966,12 +1976,20 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, // Check that the requirement indices are the same as the desired indices. auto *requirementParameterIndices = diffAttr->getParameterIndices(); auto loweredRequirementIndices = requirementParameterIndices->getLowered( + context.getASTContext(), requirementDecl->getInterfaceType()->castTo()); SILAutoDiffIndices requirementIndices(/*source*/ 0, loweredRequirementIndices); + // NOTE: We need to extend the capacity of desired parameter indices to + // requirement parameter indices, because there's a argument count mismatch. + // When `@differentiable` partial apply is supported, this problem will go + // away. if (desiredIndices.source != requirementIndices.source || - desiredIndices.parameters.test(requirementIndices.parameters)) { + !desiredIndices.parameters->extendingCapacity( + context.getASTContext(), + requirementIndices.parameters->getCapacity()) + ->isSubsetOf(requirementIndices.parameters)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_protocol_member_subset_indices_not_differentiable); return None; @@ -2447,11 +2465,11 @@ ADContext::createPrimalValueStruct(const DifferentiationTask *task, /// indices, figure out whether the parent function is being differentiated with /// respect to this parameter, according to the indices. static bool isDifferentiationParameter(SILArgument *argument, - llvm::SmallBitVector indices) { + AutoDiffIndexSubset *indices) { if (!argument) return false; auto *function = argument->getFunction(); auto paramArgs = function->getArgumentsWithoutIndirectResults(); - for (unsigned i : indices.set_bits()) + for (unsigned i : indices->getIndices()) if (paramArgs[i] == argument) return true; return false; @@ -2604,7 +2622,8 @@ class PrimalGenCloner final auto &builder = getBuilder(); builder.setInsertionPoint(exit); auto structLoweredTy = - getContext().getTypeConverter().getLoweredType(structTy, ResilienceExpansion::Minimal); + getContext().getTypeConverter().getLoweredType( + structTy, ResilienceExpansion::Minimal); auto primValsVal = builder.createStruct(loc, structLoweredTy, primalValues); // If the original result was a tuple, return a tuple of all elements in the // original result tuple and the primal value struct value. @@ -2695,7 +2714,8 @@ class PrimalGenCloner final errorOccurred = true; return; } - SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0}); + SILAutoDiffIndices indices(/*source*/ 0, + AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( @@ -2764,7 +2784,8 @@ class PrimalGenCloner final errorOccurred = true; return; } - SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0}); + SILAutoDiffIndices indices(/*source*/ 0, + AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( @@ -2977,7 +2998,11 @@ class PrimalGenCloner final return; } // Form expected indices by assuming there's only one result. - SILAutoDiffIndices indices(activeResultIndices.front(), activeParamIndices); + SILAutoDiffIndices indices(activeResultIndices.front(), + AutoDiffIndexSubset::get( + getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices)); // Emit the VJP. auto vjpAndVJPIndices = emitAssociatedFunctionReference( @@ -3988,7 +4013,7 @@ class AdjointEmitter final : public SILInstructionVisitor { task->getIndices().isWrtParameter(selfParamIndex)) addRetElt(selfParamIndex); // Add the non-self parameters that are differentiated with respect to. - for (auto i : task->getIndices().parameters.set_bits()) { + for (auto i : task->getIndices().parameters->getIndices()) { // Do not add the self parameter because we have already added it at the // beginning. if (origTy->hasSelfParam() && i == selfParamIndex) @@ -4197,7 +4222,7 @@ class AdjointEmitter final : public SILInstructionVisitor { } } // Accumulate adjoints for the remaining non-self original parameters. - for (unsigned i : applyInfo.actualIndices.parameters.set_bits()) { + for (unsigned i : applyInfo.actualIndices.parameters->getIndices()) { // Do not set the adjoint of the original self parameter because we // already added it at the beginning. if (ai->hasSelfArgument() && i == selfParamIndex) @@ -4206,8 +4231,7 @@ class AdjointEmitter final : public SILInstructionVisitor { auto cotan = *allResultsIt++; // If a cotangent value corresponds to a non-desired parameter, it won't // be used, so release it. - if (i >= applyInfo.desiredIndices.parameters.size() || - !applyInfo.desiredIndices.parameters[i]) { + if (!applyInfo.desiredIndices.parameters->contains(i)) { emitCleanup(builder, loc, cotan); continue; } @@ -4267,8 +4291,9 @@ class AdjointEmitter final : public SILInstructionVisitor { AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering(cotangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); + assert(!getModule().Types.getTypeLowering( + cotangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); auto *cotangentVectorDecl = cotangentVectorTy->getStructOrBoundGenericStruct(); assert(cotangentVectorDecl); @@ -4339,8 +4364,9 @@ class AdjointEmitter final : public SILInstructionVisitor { AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering(cotangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); + assert(!getModule().Types.getTypeLowering( + cotangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); auto cotangentVectorSILTy = SILType::getPrimitiveObjectType(cotangentVectorTy); auto *cotangentVectorDecl = @@ -4376,7 +4402,8 @@ class AdjointEmitter final : public SILInstructionVisitor { field->getModuleContext(), field); auto fieldTy = field->getType().subst(substMap); auto fieldSILTy = - getContext().getTypeConverter().getLoweredType(fieldTy, ResilienceExpansion::Minimal); + getContext().getTypeConverter().getLoweredType( + fieldTy, ResilienceExpansion::Minimal); assert(fieldSILTy.isObject()); eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } @@ -4932,7 +4959,8 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, } SILValue AdjointEmitter::emitZeroDirect(CanType type, SILLocation loc) { - auto silType = getModule().Types.getLoweredLoadableType(type, ResilienceExpansion::Minimal); + auto silType = getModule().Types.getLoweredLoadableType( + type, ResilienceExpansion::Minimal); auto *buffer = builder.createAllocStack(loc, silType); auto *initAccess = builder.createBeginAccess(loc, buffer, SILAccessKind::Init, SILAccessEnforcement::Static, @@ -5410,7 +5438,8 @@ void DifferentiationTask::createEmptyAdjoint() { // Given a type, returns its formal SIL parameter info. auto getCotangentParameterInfoForOriginalResult = [&]( CanType cotanType, ResultConvention origResConv) -> SILParameterInfo { - auto &tl = context.getTypeConverter().getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto &tl = context.getTypeConverter().getTypeLowering( + cotanType, ResilienceExpansion::Minimal); ParameterConvention conv; switch (origResConv) { case ResultConvention::Owned: @@ -5433,7 +5462,8 @@ void DifferentiationTask::createEmptyAdjoint() { // Given a type, returns its formal SIL result info. auto getCotangentResultInfoForOriginalParameter = [&]( CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo { - auto &tl = context.getTypeConverter().getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto &tl = context.getTypeConverter().getTypeLowering( + cotanType, ResilienceExpansion::Minimal); ResultConvention conv; switch (origParamConv) { case ParameterConvention::Direct_Owned: @@ -5492,7 +5522,7 @@ void DifferentiationTask::createEmptyAdjoint() { } // Add adjoint results for the requested non-self wrt parameters. - for (auto i : getIndices().parameters.set_bits()) { + for (auto i : getIndices().parameters->getIndices()) { if (origTy->hasSelfParam() && i == selfParamIndex) continue; auto origParam = origParams[i]; @@ -5655,7 +5685,8 @@ void DifferentiationTask::createVJP(bool isExported) { "unexpected number of vjp parameters"); assert(vjpConv.getResults().size() == numOriginalResults + 1 && "unexpected number of vjp results"); - assert(adjointConv.getResults().size() == getIndices().parameters.count() && + assert(adjointConv.getResults().size() == + getIndices().parameters->getNumIndices() && "unexpected number of adjoint results"); // We assume that primal result conventions (for all results but the optional @@ -5732,7 +5763,7 @@ class Differentiation : public SILModuleTransform { SILValue ADContext::promoteToDifferentiableFunction( SILBuilder &builder, SILLocation loc, SILValue origFnOperand, - const llvm::SmallBitVector ¶meterIndices, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, DifferentiationInvoker invoker) { if (auto *ai = dyn_cast(origFnOperand)) { if (auto *sourceFn = dyn_cast(ai->getCallee())) { @@ -5809,9 +5840,6 @@ SILValue ADContext::promoteToDifferentiableFunction( // Return nullptr. if (!assocFnAndIndices) return nullptr; - assert(assocFnAndIndices->second == desiredIndices && - "FIXME: We could emit a thunk that converts the VJP to have the " - "desired indices."); auto assocFn = assocFnAndIndices->first; builder.createRetainValue(loc, assocFn, builder.getDefaultAtomicity()); assocFns.push_back(assocFn); diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 0f16425fb23fd..574c2db237d36 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -653,19 +653,22 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, uint64_t jvpNameId; uint64_t vjpNameId; - uint64_t source; - ArrayRef parameters; + unsigned source; + ArrayRef rawParameterIndices; SmallVector requirements; SILDifferentiableAttrLayout::readRecord(scratch, jvpNameId, vjpNameId, - source, parameters); + source, rawParameterIndices); - llvm::SmallBitVector parametersBitVector(parameters.size()); StringRef jvpName = MF->getIdentifier(jvpNameId).str(); StringRef vjpName = MF->getIdentifier(vjpNameId).str(); - for (unsigned i : indices(parameters)) - parametersBitVector[i] = parameters[i]; - SILAutoDiffIndices indices(source, parametersBitVector); + + SmallVector parameterIndices(rawParameterIndices.begin(), + rawParameterIndices.end()); + auto *parameterIndexSubset = AutoDiffIndexSubset::get( + MF->getContext(), fn->getLoweredFunctionType()->getNumParameters(), + parameterIndices); + SILAutoDiffIndices indices(source, parameterIndexSubset); MF->readGenericRequirements(requirements, SILCursor); auto *attr = SILDifferentiableAttr::create(SILMod, indices, requirements, @@ -1505,18 +1508,19 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, // SWIFT_ENABLE_TENSORFLOW case SILInstructionKind::AutoDiffFunctionInst: { auto numParamIndices = ListOfValues.size() - NumArguments * 3; - auto paramIndices = ListOfValues.take_front(numParamIndices); + auto rawParamIndices = + map>(ListOfValues.take_front(numParamIndices), + [](uint64_t i) { return (unsigned)i; }); auto numParams = Attr2; - llvm::SmallBitVector paramIndicesBitVec(numParams); - for (unsigned idx : paramIndices) - paramIndicesBitVec.set(idx); + auto *paramIndices = + AutoDiffIndexSubset::get(MF->getContext(), numParams, rawParamIndices); SmallVector operands; for (auto i = numParamIndices; i < NumArguments * 3; i += 3) { auto astTy = MF->getType(ListOfValues[i]); auto silTy = getSILType(astTy, (SILValueCategory)ListOfValues[i+1]); operands.push_back(getLocalValue(ListOfValues[i+2], silTy)); } - ResultVal = Builder.createAutoDiffFunction(Loc, paramIndicesBitVec, + ResultVal = Builder.createAutoDiffFunction(Loc, paramIndices, /*differentiationOrder*/ Attr, operands[0], ArrayRef(operands).drop_front()); break; diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index d9ec58038f2ef..bc0986b037629 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -317,10 +317,10 @@ namespace sil_block { // SWIFT_ENABLE_TENSORFLOW using SILDifferentiableAttrLayout = BCRecordLayout< SIL_DIFFERENTIABLE_ATTR, - IdentifierIDField, // JVP name. - IdentifierIDField, // VJP name. - BCFixed<32>, // Indices' source. - BCArray> // Indices' parameters bitvector. + IdentifierIDField, // JVP name. + IdentifierIDField, // VJP name. + BCVBR<8>, // Result index. + BCArray // Parameter indices. >; // Has an optional argument list where each argument is a typed valueref. diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index c9ee0949ebf20..ef05131b34092 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -437,11 +437,9 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) { assert(DA->hasJVP() && DA->hasVJP() && "JVP and VJP must exist in canonical SIL"); - auto ¶mIndices = DA->getIndices(); - SmallVector parameters; - for (unsigned i : indices(paramIndices.parameters)) - parameters.push_back(paramIndices.parameters[i]); - + auto &indices = DA->getIndices(); + SmallVector parameters(indices.parameters->begin(), + indices.parameters->end()); SILDifferentiableAttrLayout::emitRecord( Out, ScratchRecord, differentiableAttrAbbrCode, DA->hasJVP() @@ -450,7 +448,7 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) { DA->hasVJP() ? S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getVJPName())) : IdentifierID(), - paramIndices.source, parameters); + indices.source, parameters); S.writeGenericRequirements(DA->getRequirements(), SILAbbrCodes); } @@ -971,8 +969,8 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { case SILInstructionKind::AutoDiffFunctionInst: { auto *adfi = cast(&SI); SmallVector trailingInfo; - auto ¶mIndices = adfi->getParameterIndices(); - for (unsigned idx : paramIndices.set_bits()) + auto *paramIndices = adfi->getParameterIndices(); + for (unsigned idx : paramIndices->getIndices()) trailingInfo.push_back(idx); for (auto &op : adfi->getAllOperands()) { auto val = op.get(); @@ -982,7 +980,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { } SILInstAutoDiffFunctionLayout::emitRecord(Out, ScratchRecord, SILAbbrCodes[SILInstAutoDiffFunctionLayout::Code], - adfi->getDifferentiationOrder(), (unsigned)paramIndices.size(), + adfi->getDifferentiationOrder(), paramIndices->getCapacity(), adfi->getNumOperands(), trailingInfo); break; } diff --git a/test/AutoDiff/currying.swift b/test/AutoDiff/currying.swift index 08bd469d40353..8cd47d997a94f 100644 --- a/test/AutoDiff/currying.swift +++ b/test/AutoDiff/currying.swift @@ -7,7 +7,7 @@ var CurryingAutodiffTests = TestSuite("CurryingAutodiff") CurryingAutodiffTests.test("StructMember") { struct A { @differentiable(wrt: (value)) - func v(_ value: Float) -> Float { return value * value } + func v(_ value: Float) -> Float { return value * value } } let a = A() diff --git a/unittests/AST/SILAutoDiffIndices.cpp b/unittests/AST/SILAutoDiffIndices.cpp index c71706abb219a..0dbaa63ec75d2 100644 --- a/unittests/AST/SILAutoDiffIndices.cpp +++ b/unittests/AST/SILAutoDiffIndices.cpp @@ -12,71 +12,149 @@ // SWIFT_ENABLE_TENSORFLOW #include "swift/AST/AutoDiff.h" +#include "TestContext.h" #include "gtest/gtest.h" using namespace swift; +using namespace swift::unittest; -TEST(SILAutoDiffIndices, EqualityAndHash) { - using IndicesDenseMapInfo = llvm::DenseMapInfo; - - std::array empty; - // Each example is distinct. - SILAutoDiffIndices examples[] = { - {0, empty}, - {1, empty}, - {0, {0}}, - {0, {0, 1}}, - {0, {1}}, - {0, {1, 2}}, - {0, {100}}, - {0, {0, 100}}, - {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}} - }; - size_t exampleCount = std::extent::value; - for (size_t i = 0; i < exampleCount; ++i) { - auto example1 = examples[i]; - auto grownExample1 = example1; - grownExample1.parameters.resize(grownExample1.parameters.size() + 1); +TEST(AutoDiffIndexSubset, NumBitWordsNeeded) { + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(0), 0u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(1), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(5), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord - 1), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord), 2u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord * 2 - 1), 2u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord * 2), 3u); +} - // Make sure that the grown example is actually grown. - EXPECT_TRUE(example1.parameters.size() < grownExample1.parameters.size()); +TEST(AutoDiffIndexSubset, BitWordIndexAndOffset) { + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(0), + std::make_pair(0u, 0u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(5), + std::make_pair(0u, 5u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(8), + std::make_pair(0u, 8u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset( + AutoDiffIndexSubset::numBitsPerBitWord - 1), + std::make_pair(0u, AutoDiffIndexSubset::numBitsPerBitWord - 1)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset( + AutoDiffIndexSubset::numBitsPerBitWord), + std::make_pair(1u, 0u)); +} - // Test that the example is equal to itself and to the grown version of - // itself, using both operator== and IndicesDenseMapInfo::isEqual. - EXPECT_TRUE(example1 == example1); - EXPECT_TRUE(example1 == grownExample1); - EXPECT_TRUE(grownExample1 == example1); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(example1, example1)); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(example1, grownExample1)); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(grownExample1, example1)); +TEST(AutoDiffIndexSubset, Equality) { + TestContext ctx; + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0})); + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4})); + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {})); + EXPECT_NE(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 1, + /*indices*/ {}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 0, + /*indices*/ {})); + EXPECT_NE(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {})); +} - // Test that the grown version has the same hash as the original. - EXPECT_EQ(IndicesDenseMapInfo::getHashValue(example1), - IndicesDenseMapInfo::getHashValue(grownExample1)); +TEST(AutoDiffIndexSubset, Bits) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + EXPECT_EQ(indices1->getNumBitWords(), 1u); + EXPECT_EQ(indices1->getCapacity(), 5u); + EXPECT_TRUE(indices1->contains(0)); + EXPECT_FALSE(indices1->contains(1)); + EXPECT_TRUE(indices1->contains(2)); + EXPECT_FALSE(indices1->contains(3)); + EXPECT_TRUE(indices1->contains(4)); - // Test that the example is not equal to any of the others. - for (size_t j = i + 1; j < exampleCount; ++j) { - auto example2 = examples[j]; - auto grownExample2 = example2; - grownExample2.parameters.resize(grownExample2.parameters.size() + 1); + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {1, 3}); + EXPECT_EQ(indices2->getNumBitWords(), 1u); + EXPECT_EQ(indices2->getCapacity(), 5u); + EXPECT_FALSE(indices2->contains(0)); + EXPECT_TRUE(indices2->contains(1)); + EXPECT_FALSE(indices2->contains(2)); + EXPECT_TRUE(indices2->contains(3)); + EXPECT_FALSE(indices2->contains(4)); +} - // Make sure that the grown example is actually grown. - EXPECT_TRUE(example2.parameters.size() < grownExample2.parameters.size()); +TEST(AutoDiffIndexSubset, Iteration) { + TestContext ctx; + // Test 1 + { + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + // Check forward iteration. + EXPECT_EQ(indices1->findFirst(), 0); + EXPECT_EQ(indices1->findNext(0), 2); + EXPECT_EQ(indices1->findNext(2), 4); + EXPECT_EQ(indices1->findNext(4), (int)indices1->getCapacity()); + // Check reverse iteration. + EXPECT_EQ(indices1->findLast(), 4); + EXPECT_EQ(indices1->findPrevious(4), 2); + EXPECT_EQ(indices1->findPrevious(2), 0); + EXPECT_EQ(indices1->findPrevious(0), -1); + // Check range. + unsigned indices1Elements[3] = {0, 2, 4}; + EXPECT_TRUE(std::equal(indices1->begin(), indices1->end(), + indices1Elements)); + } + // Test 2 + { + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {1, 3}); + // Check forward iteration. + EXPECT_EQ(indices2->findFirst(), 1); + EXPECT_EQ(indices2->findNext(1), 3); + EXPECT_EQ(indices2->findNext(3), (int)indices2->getCapacity()); + // Check reverse iteration. + EXPECT_EQ(indices2->findLast(), 3); + EXPECT_EQ(indices2->findPrevious(3), 1); + EXPECT_EQ(indices2->findPrevious(1), -1); + // Check range. + unsigned indices2Elements[2] = {1, 3}; + EXPECT_TRUE(std::equal(indices2->begin(), indices2->end(), + indices2Elements)); + } +} - EXPECT_FALSE(example1 == example2); - EXPECT_FALSE(example2 == example1); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example1, example2)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example2, example1)); +TEST(AutoDiffIndexSubset, SupersetAndSubset) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + EXPECT_TRUE(indices1->isSupersetOf(indices1)); + EXPECT_TRUE(indices1->isSubsetOf(indices1)); + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {2}); + EXPECT_TRUE(indices2->isSupersetOf(indices2)); + EXPECT_TRUE(indices2->isSubsetOf(indices2)); - EXPECT_FALSE(example1 == grownExample2); - EXPECT_FALSE(grownExample2 == example1); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example1, grownExample2)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(grownExample2, example1)); + EXPECT_TRUE(indices1->isSupersetOf(indices2)); + EXPECT_TRUE(indices2->isSubsetOf(indices1)); +} - EXPECT_FALSE(example2 == grownExample1); - EXPECT_FALSE(grownExample1 == example2); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example2, grownExample1)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(grownExample1, example2)); - } - } +TEST(AutoDiffIndexSubset, Insertion) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 2, 4}); + EXPECT_EQ(indices1->adding(0, ctx.Ctx), indices1); + EXPECT_EQ(indices1->adding(1, ctx.Ctx), + AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 1, 2, 4})); + EXPECT_EQ(indices1->adding(3, ctx.Ctx), + AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 2, 3, 4})); }