From f56b443e6af4200a33525d866da882fc974c8f08 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 27 Mar 2019 21:45:45 -0700 Subject: [PATCH 1/8] WIP --- include/swift/AST/AutoDiff.h | 90 ++++++++++++++++++++++++++++++++++++ lib/AST/ASTContext.cpp | 15 ++++++ lib/AST/AutoDiff.cpp | 75 ++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index f4e8806b8f7f9..1fe7c7accf436 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -20,6 +20,7 @@ #include "ASTContext.h" #include "llvm/ADT/SmallBitVector.h" +#include "swift/Basic/Range.h" namespace swift { @@ -219,6 +220,95 @@ class AutoDiffParameterIndicesBuilder { unsigned size() { return parameters.size(); } }; +class AutoDiffIndexSubset : public llvm::FoldingSetNode { +private: + using Byte = uint8_t; + + unsigned capacity; + unsigned size; + unsigned numBytes; + + static unsigned getNumBytesNeeded(unsigned largest) { + return (largest + 1) / sizeof(Byte) + 1; + } + + unsigned getNumBytes() const { + return numBytes; + } + + const Byte *getBytesData() const { + return reinterpret_cast(this + 1); + } + + ArrayRef getBytes() const { + return {getBytesData(), getNumBytes()}; + } + + MutableArrayRef getMutableBytes() { + return {const_cast(getBytesData()), getNumBytes()}; + } + + explicit AutoDiffIndexSubset(unsigned capacity, unsigned size, + unsigned numBytes, ArrayRef bytes); + +public: + AutoDiffIndexSubset() = delete; + AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete; + AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete; + + static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, + bool includeAll = false); + static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, + IntRange<> range); + static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, + ArrayRef indices); + + unsigned getCapacity() const { + return capacity; + } + + unsigned getSize() const { + return size; + } + + ArrayRef getIndices() const { + return indices; + } + + bool contains(unsigned index) const { + + } + + bool equals(const AutoDiffIndexSubset *other) const; + bool isProperSubsetOf(const AutoDiffIndexSubset *other) const; + bool isProperSupersetOf(const AutoDiffIndexSubset *other) const; + + AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; + AutoDiffIndexSubset *adding(AutoDiffIndexSubset *other, + ASTContext &ctx) const; + + static void Profile(llvm::FoldingSetNodeID &id, + const SmallBitVector &rawIndices) const; + void Profile(llvm::FoldingSetNodeID &id) const { + Profile(id, rawIndices); + } +}; + +class AutoDiffFunctionParameterSubset { +private: + AutoDiffIndexSubset *indexSubset; + bool isCurried; + +public: + explicit AutoDiffFunctionParameterSubset( + AutoDiffIndexSubset *indexSubset, bool isCurried) + : indexSubset(indexSubset), isCurried(isCurried) {} + + explicit AutoDiffFunctionParameterSubset( + ASTContext &ctx, AutoDiffIndexSubset *parameterSubset, + Optional isSelfIncluded); +}; + /// SIL-level automatic differentiation indices. Consists of a source index, /// i.e. index of the dependent result to differentiate from, and parameter /// indices, i.e. index of independent parameters to differentiate with diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index e44b5c5a56491..b60c9034f5bec 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -390,6 +390,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `AutoDiffParameterIndices` allocations. llvm::FoldingSet AutoDiffParameterIndicesSet; + /// For uniquifying `AutoDiffIndexSubset` allocations. + llvm::FoldingSet AutoDiffIndexSubsets; + /// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations. llvm::FoldingSet AutoDiffAssociatedFunctionIdentifiers; @@ -5245,6 +5248,18 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) { return newNode; } +AutoDiffIndexSubset *AutoDiffIndexSubset::get( + ASTContext &ctx, const SmallBitVector &&rawIndices) { + auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets; + llvm::FoldingSetNodeID id; + Profile(id, rawIndices); + void *insertPos = nullptr; + auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); + if (existing) + return existing; + +} + AutoDiffAssociatedFunctionIdentifier * AutoDiffAssociatedFunctionIdentifier::get( AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 3fb754ae6faf6..5dba9eb86955c 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -355,3 +355,78 @@ CanType VectorSpace::getCanonicalType() const { NominalTypeDecl *VectorSpace::getNominal() const { return getVector()->getNominalOrBoundGenericNominal(); } + +AutoDiffIndexSubset * +AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { + auto *buf = reinterpret_cast( + ctx.Allocate(sizeof(AutoDiffIndexSubset), alignof(AutoDiffIndexSubset))); + return new (buf) AutoDiffIndexSubset(SmallBitVector(capacity, includeAll)); +} + +AutoDiffIndexSubset * +AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &rawIndices) { + auto *buf = reinterpret_cast( + ctx.Allocate(sizeof(AutoDiffIndexSubset), alignof(AutoDiffIndexSubset))); + return new (buf) AutoDiffIndexSubset(std::move(rawIndices)); +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, + unsigned capacity, + IntRange<> range) { + auto *subset = get(ctx, capacity); + subset->rawIndices.set(range.front(), range.back()); + return subset; +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, + unsigned capacity, + ArrayRef indices) { + auto *subset = get(ctx, capacity); + for (auto i : indices) { + assert(i < capacity && "Index must be smaller than capacity"); + subset->rawIndices.set(i); + } + return subset; +} + +bool AutoDiffIndexSubset::equals(const AutoDiffIndexSubset *other) const { + return rawIndices == other->rawIndices; +} + +bool AutoDiffIndexSubset:: +isProperSubsetOf(const AutoDiffIndexSubset *other) const { + return getSize() == other->getSize() && other->rawIndices.test(rawIndices); +} + +bool AutoDiffIndexSubset:: +isProperSupersetOf(const AutoDiffIndexSubset *other) const { + return getSize() == other->getSize() && rawIndices.test(other->rawIndices); +} + +AutoDiffIndexSubset * +AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const { + assert(index < getCapacity()); + +} + +AutoDiffIndexSubset * +AutoDiffIndexSubset::adding(AutoDiffIndexSubset *other, ASTContext &ctx) const { + return get(ctx, rawIndices | other->rawIndices); +} + +void AutoDiffIndexSubset::Profile(llvm::FoldingSetNodeID &id, + const SmallBitVector &rawIndices) { + id.AddInteger(rawIndices.size()); + for (auto i : rawIndices.set_bits()) + id.AddInteger(i); +} + +AutoDiffFunctionParameterSubset:: +AutoDiffFunctionParameterSubset(ASTContext &ctx, + AutoDiffIndexSubset *parameterSubset, + Optional isSelfIncluded) { + if (isSelfIncluded.hasValue()) { + indexSubset = + AutoDiffIndexSubset::get(ctx, parameterSubset->getCapacity() + 1); + } +} From 69d2363af914eae2efaf121ce09751ff6e816b3c Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Thu, 4 Apr 2019 01:06:40 -0700 Subject: [PATCH 2/8] WIP --- include/swift/AST/AutoDiff.h | 138 ++++++++++++++++++++++++--------- lib/AST/ASTContext.cpp | 25 +++++- lib/AST/AutoDiff.cpp | 146 ++++++++++++++++++++++++++--------- 3 files changed, 233 insertions(+), 76 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 1fe7c7accf436..e44d0ff21b419 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -222,34 +222,44 @@ class AutoDiffParameterIndicesBuilder { class AutoDiffIndexSubset : public llvm::FoldingSetNode { private: - using Byte = uint8_t; + using BitWord = uint64_t; unsigned capacity; - unsigned size; - unsigned numBytes; + unsigned numBitWords; - static unsigned getNumBytesNeeded(unsigned largest) { - return (largest + 1) / sizeof(Byte) + 1; + static std::pair getBitWordIndexAndOffset(unsigned index); + static unsigned getNumBitWordsNeededForCapacity(unsigned capacity); + + unsigned getNumBitWords() const { + return numBitWords; + } + + BitWord *getBitWordsData() { + return reinterpret_cast(this + 1); + } + + const BitWord *getBitWordsData() const { + return reinterpret_cast(this + 1); } - unsigned getNumBytes() const { - return numBytes; + ArrayRef getBitWords() const { + return {getBitWordsData(), getNumBitWords()}; } - const Byte *getBytesData() const { - return reinterpret_cast(this + 1); + BitWord getBitWord(unsigned i) const { + return getBitWordsData()[i]; } - ArrayRef getBytes() const { - return {getBytesData(), getNumBytes()}; + BitWord &getBitWord(unsigned i) { + return getBitWordsData()[i]; } - MutableArrayRef getMutableBytes() { - return {const_cast(getBytesData()), getNumBytes()}; + MutableArrayRef getMutableBitWords() { + return {const_cast(getBitWordsData()), getNumBitWords()}; } - explicit AutoDiffIndexSubset(unsigned capacity, unsigned size, - unsigned numBytes, ArrayRef bytes); + explicit AutoDiffIndexSubset(unsigned capacity, unsigned numBitWords, + ArrayRef indices); public: AutoDiffIndexSubset() = delete; @@ -267,46 +277,104 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { return capacity; } - unsigned getSize() const { - return size; - } + class iterator; - ArrayRef getIndices() const { - return indices; - } + iterator begin() const; + iterator end() const; + iterator_range getIndices() const; + unsigned getNumIndices() const; bool contains(unsigned index) const { - + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index); + return getBitWord(bitWordIndex) & (1 << offset); } bool equals(const AutoDiffIndexSubset *other) const; - bool isProperSubsetOf(const AutoDiffIndexSubset *other) const; - bool isProperSupersetOf(const AutoDiffIndexSubset *other) const; + bool isSubsetOf(const AutoDiffIndexSubset *other) const; + bool isSupersetOf(const AutoDiffIndexSubset *other) const; AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; - AutoDiffIndexSubset *adding(AutoDiffIndexSubset *other, - ASTContext &ctx) const; - static void Profile(llvm::FoldingSetNodeID &id, - const SmallBitVector &rawIndices) const; - void Profile(llvm::FoldingSetNodeID &id) const { - Profile(id, rawIndices); - } + void Profile(llvm::FoldingSetNodeID &id) const; + +private: + int findNext(int startIndex) const; + int findFirst() const { return findNext(-1); } + int findPrevious(int endIndex) const; + int findLast() const { return findPrevious(capacity); } + +public: + class iterator { + typedef unsigned value_type; + typedef int difference_type; + typedef unsigned * pointer; + typedef unsigned & reference; + typedef std::forward_iterator_tag iterator_category; + private: + const AutoDiffIndexSubset *parent; + int current = 0; + + void advance() { + assert(current != -1 && "Trying to advance past end."); + current = parent->findNext(current); + } + + public: + iterator(const AutoDiffIndexSubset *parent, int current) + : parent(parent), current(current) {} + explicit iterator(const AutoDiffIndexSubset *parent) + : iterator(parent, parent->findFirst()) {} + iterator(const iterator &) = default; + + iterator operator++(int) { + auto prev = *this; + advance(); + return prev; + } + + iterator &operator++() { + advance(); + return *this; + } + + unsigned operator*() const { return current; } + + bool operator==(const iterator &other) const { + assert(&parent == &other.parent && + "Comparing iterators from different BitVectors"); + return current == other.current; + } + + bool operator!=(const iterator &other) const { + assert(&parent == &other.parent && + "Comparing iterators from different BitVectors"); + return current != other.current; + } + }; }; class AutoDiffFunctionParameterSubset { private: AutoDiffIndexSubset *indexSubset; - bool isCurried; + bool curried; public: explicit AutoDiffFunctionParameterSubset( AutoDiffIndexSubset *indexSubset, bool isCurried) - : indexSubset(indexSubset), isCurried(isCurried) {} + : indexSubset(indexSubset), curried(isCurried) {} explicit AutoDiffFunctionParameterSubset( ASTContext &ctx, AutoDiffIndexSubset *parameterSubset, Optional isSelfIncluded); + + AutoDiffIndexSubset *getIndexSubset() const { + return indexSubset; + } + + bool isCurried() const { + return curried; + } }; /// SIL-level automatic differentiation indices. Consists of a source index, @@ -414,10 +482,6 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { return parameterIndices; } - static AutoDiffAssociatedFunctionIdentifier *get( - AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, - AutoDiffParameterIndices *parameterIndices, ASTContext &C); - void Profile(llvm::FoldingSetNodeID &ID) { ID.AddInteger(kind); ID.AddInteger(differentiationOrder); diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index b60c9034f5bec..9fa85e227db86 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -5248,16 +5248,33 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) { return newNode; } -AutoDiffIndexSubset *AutoDiffIndexSubset::get( - ASTContext &ctx, const SmallBitVector &&rawIndices) { +AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, + unsigned capacity, + ArrayRef indices) { auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets; llvm::FoldingSetNodeID id; - Profile(id, rawIndices); + id.AddInteger(capacity); +#ifndef NDEBUG + int last = -1; +#endif + for (unsigned index : indices) { +#ifndef NDEBUG + assert((int)index > last && "Indices must be ascending"); + last = index; +#endif + id.AddInteger(index); + } void *insertPos = nullptr; auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); if (existing) return existing; - + auto numBitWords = sizeof(AutoDiffIndexSubset) + + getNumBitWordsNeededForCapacity(capacity); + auto *buf = reinterpret_cast( + ctx.Allocate(numBitWords, alignof(AutoDiffIndexSubset))); + auto *newNode = new (buf) AutoDiffIndexSubset(capacity, numBitWords, indices); + foldingSet.InsertNode(newNode, insertPos); + return newNode; } AutoDiffAssociatedFunctionIdentifier * diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 5dba9eb86955c..f898860d10631 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -356,69 +356,145 @@ NominalTypeDecl *VectorSpace::getNominal() const { return getVector()->getNominalOrBoundGenericNominal(); } -AutoDiffIndexSubset * -AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { - auto *buf = reinterpret_cast( - ctx.Allocate(sizeof(AutoDiffIndexSubset), alignof(AutoDiffIndexSubset))); - return new (buf) AutoDiffIndexSubset(SmallBitVector(capacity, includeAll)); +std::pair +AutoDiffIndexSubset::getBitWordIndexAndOffset(unsigned index) { + auto bitWordIndex = (index + 1) / sizeof(BitWord); + auto bitWordOffset = index - bitWordIndex * sizeof(BitWord); + return {bitWordIndex, bitWordOffset}; +} + +unsigned +AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(unsigned capacity) { + auto numBitWords = capacity / sizeof(BitWord); + if (capacity % sizeof(BitWord)) + numBitWords += 1; + return numBitWords; +} + +AutoDiffIndexSubset::AutoDiffIndexSubset(unsigned capacity, + unsigned numBitWords, + ArrayRef indices) + : capacity(capacity), numBitWords(numBitWords) { + std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0); + for (auto i : indices) { + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i); + getBitWord(bitWordIndex) |= (1 << offset); + } } AutoDiffIndexSubset * -AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &rawIndices) { - auto *buf = reinterpret_cast( - ctx.Allocate(sizeof(AutoDiffIndexSubset), alignof(AutoDiffIndexSubset))); - return new (buf) AutoDiffIndexSubset(std::move(rawIndices)); +AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { + return get(ctx, capacity, + SmallVector(capacity, (unsigned)includeAll)); } AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, IntRange<> range) { - auto *subset = get(ctx, capacity); - subset->rawIndices.set(range.front(), range.back()); - return subset; + return get(ctx, capacity, + SmallVector(range.begin(), range.end())); } -AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, - unsigned capacity, - ArrayRef indices) { - auto *subset = get(ctx, capacity); - for (auto i : indices) { - assert(i < capacity && "Index must be smaller than capacity"); - subset->rawIndices.set(i); - } - return subset; +void AutoDiffIndexSubset::Profile(llvm::FoldingSetNodeID &id) const { + id.AddInteger(capacity); + for (auto index : getIndices()) + id.AddInteger(index); +} + +AutoDiffIndexSubset::iterator AutoDiffIndexSubset::begin() const { + return iterator(this); +} + +AutoDiffIndexSubset::iterator AutoDiffIndexSubset::end() const { + return iterator(this, findLast() + 1); +} + +iterator_range +AutoDiffIndexSubset::getIndices() const { + return make_range(begin(), end()); +} + +unsigned AutoDiffIndexSubset::getNumIndices() const { + return accumulate(getIndices(), (unsigned)0, + [](unsigned total, BitWord bitWord) { + return total + llvm::countPopulation(bitWord); + }); } bool AutoDiffIndexSubset::equals(const AutoDiffIndexSubset *other) const { - return rawIndices == other->rawIndices; + return capacity == other->getCapacity() && + getBitWords().equals(other->getBitWords()); } bool AutoDiffIndexSubset:: -isProperSubsetOf(const AutoDiffIndexSubset *other) const { - return getSize() == other->getSize() && other->rawIndices.test(rawIndices); +isSubsetOf(const AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (getBitWord(index) & ~other->getBitWord(index)) + return false; + return true; } bool AutoDiffIndexSubset:: -isProperSupersetOf(const AutoDiffIndexSubset *other) const { - return getSize() == other->getSize() && rawIndices.test(other->rawIndices); +isSupersetOf(const AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (~getBitWord(index) & other->getBitWord(index)) + return false; + return true; } AutoDiffIndexSubset * AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const { assert(index < getCapacity()); - + SmallVector newIndices(begin(), end()); + newIndices.push_back(index); + llvm::sort(newIndices.begin(), newIndices.end()); + return get(ctx, capacity, newIndices); } -AutoDiffIndexSubset * -AutoDiffIndexSubset::adding(AutoDiffIndexSubset *other, ASTContext &ctx) const { - return get(ctx, rawIndices | other->rawIndices); +int AutoDiffIndexSubset::findNext(int startIndex) const { + if (numBitWords == 0) + return -1; + unsigned bitWordIndex = 0, offset = 0; + if (startIndex >= 0) + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(startIndex); + for (; bitWordIndex < getNumBitWords(); ++bitWordIndex) { + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) { + ++bitWordIndex; + continue; + } + while (offset < sizeof(BitWord)) { + if (bitWord & (1 << offset)) + return bitWordIndex * sizeof(BitWord) + offset; + ++offset; + } + offset = 0; + } + return -1; } -void AutoDiffIndexSubset::Profile(llvm::FoldingSetNodeID &id, - const SmallBitVector &rawIndices) { - id.AddInteger(rawIndices.size()); - for (auto i : rawIndices.set_bits()) - id.AddInteger(i); +int AutoDiffIndexSubset::findPrevious(int endIndex) const { + if (numBitWords == 0) + return -1; + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(endIndex); + for (; bitWordIndex >= 0; --bitWordIndex) { + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) { + --bitWordIndex; + continue; + } + while (offset >= 0) { + if (bitWord & (1 << offset)) + return bitWordIndex * sizeof(BitWord) + offset; + --offset; + } + offset = sizeof(BitWord) - 1; + } + return -1; } AutoDiffFunctionParameterSubset:: From 5e09d6a9e83018586f5afffe8d839c4dbb1196b8 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 10 May 2019 14:26:00 -0700 Subject: [PATCH 3/8] WIP --- include/swift/AST/AutoDiff.h | 39 ++++++----- include/swift/AST/Types.h | 2 +- lib/AST/AutoDiff.cpp | 69 +++++++++---------- lib/SIL/SILFunctionType.cpp | 7 +- .../Mandatory/Differentiation.cpp | 59 ++++++++-------- lib/Serialization/DeserializeSIL.cpp | 6 +- lib/Serialization/SerializeSIL.cpp | 4 +- 7 files changed, 96 insertions(+), 90 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index e44d0ff21b419..0c66fa1170660 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -74,6 +74,7 @@ class ParsedAutoDiffParameter { }; class AnyFunctionType; +class AutoDiffIndexSubset; class AutoDiffParameterIndicesBuilder; class Type; @@ -174,7 +175,8 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode { /// ==> returns 1110 /// (because the lowered SIL type is (A, B, C, D) -> R) /// - llvm::SmallBitVector getLowered(AnyFunctionType *functionType) const; + AutoDiffIndexSubset *getLowered(ASTContext &ctx, + AnyFunctionType *functionType) const; void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddInteger(parameters.size()); @@ -224,7 +226,10 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { private: using BitWord = uint64_t; + /// The total capacity of the index subset, which is `1` less than the largest + /// index. unsigned capacity; + /// The number of bit words in the index subset. in the index subset. unsigned numBitWords; static std::pair getBitWordIndexAndOffset(unsigned index); @@ -272,6 +277,8 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { IntRange<> range); static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, ArrayRef indices); + template + static AutoDiffIndexSubset *get(ASTContext &ctx, ArrayRef bits); unsigned getCapacity() const { return capacity; @@ -400,38 +407,33 @@ struct SILAutoDiffIndices { /// Function type: (A, B) -> (C, D) -> R /// Bits: [C][D][A][B] /// - llvm::SmallBitVector parameters; + AutoDiffIndexSubset *parameters; /// Creates a set of AD indices from the given source index and a bit vector /// representing parameter indices. /*implicit*/ SILAutoDiffIndices(unsigned source, - llvm::SmallBitVector parameters) + AutoDiffIndexSubset *parameters) : source(source), parameters(parameters) {} - /// Creates a set of AD indices from the given source index and an array of - /// parameter indices. Elements in `parameters` must be ascending integers. - /*implicit*/ SILAutoDiffIndices(unsigned source, - ArrayRef parameters); - bool operator==(const SILAutoDiffIndices &other) const; /// Queries whether the function's parameter with index `parameterIndex` is /// one of the parameters to differentiate with respect to. bool isWrtParameter(unsigned parameterIndex) const { - return parameterIndex < parameters.size() && - parameters.test(parameterIndex); + return parameterIndex < parameters->getCapacity() && + parameters->contains(parameterIndex); } void print(llvm::raw_ostream &s = llvm::outs()) const { s << "(source=" << source << " parameters=("; - interleave(parameters.set_bits(), + interleave(parameters->getIndices(), [&s](unsigned p) { s << p; }, [&s]{ s << ' '; }); s << "))"; } std::string mangle() const { std::string result = "src_" + llvm::utostr(source) + "_wrt_"; - interleave(parameters.set_bits(), + interleave(parameters->getIndices(), [&](unsigned idx) { result += llvm::utostr(idx); }, [&] { result += '_'; }); return result; @@ -482,6 +484,10 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { return parameterIndices; } + static AutoDiffAssociatedFunctionIdentifier *get( + AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder, + AutoDiffParameterIndices *parameterIndices, ASTContext &C); + void Profile(llvm::FoldingSetNodeID &ID) { ID.AddInteger(kind); ID.AddInteger(differentiationOrder); @@ -603,19 +609,18 @@ template struct DenseMapInfo; template<> struct DenseMapInfo { static SILAutoDiffIndices getEmptyKey() { - return { DenseMapInfo::getEmptyKey(), SmallBitVector() }; + return { DenseMapInfo::getEmptyKey(), nullptr }; } static SILAutoDiffIndices getTombstoneKey() { - return { DenseMapInfo::getTombstoneKey(), - SmallBitVector(sizeof(intptr_t), true) }; + return { DenseMapInfo::getTombstoneKey(), nullptr }; } static unsigned getHashValue(const SILAutoDiffIndices &Val) { - auto params = Val.parameters.set_bits(); unsigned combinedHash = hash_combine(~1U, DenseMapInfo::getHashValue(Val.source), - hash_combine_range(params.begin(), params.end())); + hash_combine_range(Val.parameters->begin(), + Val.parameters->end())); return combinedHash; } diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index d4e4d74b513af..ca8d60975773d 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4159,7 +4159,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// Returns the type of a differentiation function that is associated with /// a function of this type. CanSILFunctionType getAutoDiffAssociatedFunctionType( - const SmallBitVector ¶meterIndices, unsigned resultIndex, + AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, SILModule &module, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenericSignature = nullptr); diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 9af8c4306fd9e..8e609ff0435e6 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -20,33 +20,8 @@ using namespace swift; -SILAutoDiffIndices::SILAutoDiffIndices( - unsigned source, ArrayRef parameters) : source(source) { - if (parameters.empty()) - return; - - auto max = *std::max_element(parameters.begin(), parameters.end()); - this->parameters.resize(max + 1); - int last = -1; - for (auto paramIdx : parameters) { - assert((int)paramIdx > last && "Parameter indices must be ascending"); - last = paramIdx; - this->parameters.set(paramIdx); - } -} - -bool SILAutoDiffIndices::operator==( - const SILAutoDiffIndices &other) const { - if (source != other.source) - return false; - - // The parameters are the same when they have exactly the same set bit - // indices, even if they have different sizes. - llvm::SmallBitVector buffer(std::max(parameters.size(), - other.parameters.size())); - buffer ^= parameters; - buffer ^= other.parameters; - return buffer.none(); +bool SILAutoDiffIndices::operator==(const SILAutoDiffIndices &other) const { + return source == other.source && parameters == other.parameters; } AutoDiffAssociatedFunctionKind:: @@ -222,8 +197,9 @@ static unsigned countNumFlattenedElementTypes(Type type) { /// ==> returns 1110 /// (because the lowered SIL type is (A, B, C, D) -> R) /// -llvm::SmallBitVector -AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const { +AutoDiffIndexSubset * +AutoDiffParameterIndices::getLowered(ASTContext &ctx, + AnyFunctionType *functionType) const { SmallVector curryLevels; unwrapCurryLevels(functionType, curryLevels); @@ -241,16 +217,18 @@ AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType) const { // Construct the result by setting each range of bits that corresponds to each // "on" parameter. - llvm::SmallBitVector result(totalLoweredSize); + llvm::SmallVector loweredIndices; unsigned currentBitIndex = 0; for (unsigned i : range(parameters.size())) { auto paramLoweredSize = paramLoweredSizes[i]; - if (parameters[i]) - result.set(currentBitIndex, currentBitIndex + paramLoweredSize); + if (parameters[i]) { + auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize); + loweredIndices.append(indices.begin(), indices.end()); + } currentBitIndex += paramLoweredSize; } - return result; + return AutoDiffIndexSubset::get(ctx, totalLoweredSize, loweredIndices); } static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) { @@ -386,6 +364,17 @@ AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { SmallVector(capacity, (unsigned)includeAll)); } +template +AutoDiffIndexSubset * +AutoDiffIndexSubset::get(ASTContext &ctx, ArrayRef bits) { + SmallVector indices; + indices.reserve(bits.size()); + for (auto i : indices(bits)) + if (bits[i]) + indices.push_back(i); + return get(ctx, bits.size(), indices); +} + AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, IntRange<> range) { @@ -445,9 +434,17 @@ isSupersetOf(const AutoDiffIndexSubset *other) const { AutoDiffIndexSubset * AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const { assert(index < getCapacity()); - SmallVector newIndices(begin(), end()); - newIndices.push_back(index); - llvm::sort(newIndices.begin(), newIndices.end()); + SmallVector newIndices; + newIndices.reserve(capacity + 1); + bool inserted = false; + for (auto it = begin(); it != end(); ++it) { + auto curIndex = *it; + if (inserted && curIndex > index) { + newIndices.push_back(index); + inserted = false; + } + newIndices.push_back(curIndex); + } return get(ctx, capacity, newIndices); } diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 89f0b734acd38..387fcfb782eea 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -148,10 +148,9 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() { } CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( - const SmallBitVector ¶meterIndices, unsigned resultIndex, - unsigned differentiationOrder, - AutoDiffAssociatedFunctionKind kind, SILModule &module, - LookupConformanceFn lookupConformance, + AutoDiffIndexSubset *parameterIndices, unsigned resultIndex, + unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind, + SILModule &module, LookupConformanceFn lookupConformance, GenericSignature *whereClauseGenSig) { // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 2499810cc683b..713555b58c313 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -923,22 +923,20 @@ class ADContext { DifferentiationTask * lookUpMinimalDifferentiationTask(SILFunction *original, const SILAutoDiffIndices &indices) { - auto supersetParamIndices = llvm::SmallBitVector(); - const auto &indexSet = indices.parameters; + AutoDiffIndexSubset *superset = nullptr; + auto *indexSet = indices.parameters; if (auto *existingTask = lookUpDifferentiationTask(original, indices)) return existingTask; for (auto *rda : original->getDifferentiableAttrs()) { - const auto &rdaIndexSet = rda->getIndices().parameters; + auto *rdaIndexSet = rda->getIndices().parameters; // If all indices in indexSet are in rdaIndexSet, and it has fewer // indices than our current candidate and a primitive adjoint, rda is our // new candidate. - if (!indexSet.test(rdaIndexSet) && // all indexSet indices in rdaIndexSet - (supersetParamIndices.empty() || // fewer parameters than before - rdaIndexSet.count() < supersetParamIndices.count())) - supersetParamIndices = rda->getIndices().parameters; + if (rdaIndexSet->isSubsetOf(indexSet)) + superset = rda->getIndices().parameters; } - auto existing = enqueuedTaskIndices.find( - {original, {indices.source, supersetParamIndices}}); + auto existing = + enqueuedTaskIndices.find({original, {indices.source, superset}}); if (existing == enqueuedTaskIndices.end()) return nullptr; return differentiationTasks[existing->getSecond()].get(); @@ -948,7 +946,7 @@ class ADContext { /// function. SILValue promoteToDifferentiableFunction( SILBuilder &builder, SILLocation loc, SILValue origFnOperand, - const llvm::SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, DifferentiationInvoker invoker); /// For an autodiff_function instruction that is missing associated functions, @@ -1235,8 +1233,7 @@ class DifferentiableActivityInfo { bool isVaried(SILValue value, unsigned independentVariableIndex) const; bool isUseful(SILValue value, unsigned dependentVariableIndex) const; - bool isVaried(SILValue value, - const llvm::SmallBitVector ¶meterIndices) const; + bool isVaried(SILValue value, AutoDiffIndexSubset *parameterIndices) const; bool isActive(SILValue value, const SILAutoDiffIndices &indices) const; Activity getActivity(SILValue value, @@ -1506,8 +1503,8 @@ bool DifferentiableActivityInfo::isVaried( } bool DifferentiableActivityInfo::isVaried( - SILValue value, const llvm::SmallBitVector ¶meterIndices) const { - for (auto paramIdx : parameterIndices.set_bits()) + SILValue value, AutoDiffIndexSubset *parameterIndices) const { + for (auto paramIdx : parameterIndices->getIndices()) if (isVaried(value, paramIdx)) return true; return false; @@ -1792,7 +1789,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, if (auto diffableFnType = original->getType().castTo()) { if (diffableFnType->isDifferentiable()) { auto paramIndices = diffableFnType->getDifferentiationParameterIndices(); - for (auto i : desiredIndices.parameters.set_bits()) { + for (auto i : desiredIndices.parameters->getIndices()) { if (i >= paramIndices.size() || !paramIndices[i]) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_function_nondiff_parameter_not_differentiable); @@ -1944,12 +1941,13 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, // Check that the requirement indices are the same as the desired indices. auto *requirementParameterIndices = diffAttr->getParameterIndices(); auto loweredRequirementIndices = requirementParameterIndices->getLowered( + context.getASTContext(), requirementDecl->getInterfaceType()->castTo()); SILAutoDiffIndices requirementIndices(/*source*/ 0, loweredRequirementIndices); if (desiredIndices.source != requirementIndices.source || - desiredIndices.parameters.test(requirementIndices.parameters)) { + !desiredIndices.parameters->isSubsetOf(requirementIndices.parameters)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_protocol_member_subset_indices_not_differentiable); return None; @@ -2425,11 +2423,11 @@ ADContext::createPrimalValueStruct(const DifferentiationTask *task, /// indices, figure out whether the parent function is being differentiated with /// respect to this parameter, according to the indices. static bool isDifferentiationParameter(SILArgument *argument, - llvm::SmallBitVector indices) { + AutoDiffIndexSubset *indices) { if (!argument) return false; auto *function = argument->getFunction(); auto paramArgs = function->getArgumentsWithoutIndirectResults(); - for (unsigned i : indices.set_bits()) + for (unsigned i : indices->getIndices()) if (paramArgs[i] == argument) return true; return false; @@ -2674,7 +2672,9 @@ class PrimalGenCloner final errorOccurred = true; return; } - SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0}); + SILAutoDiffIndices indices( + /*source*/ 0, + /*parameters*/ AutoDiffIndexSubset::get(getASTContext(), 1)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( @@ -2956,7 +2956,11 @@ class PrimalGenCloner final return; } // Form expected indices by assuming there's only one result. - SILAutoDiffIndices indices(activeResultIndices.front(), activeParamIndices); + SILAutoDiffIndices indices(activeResultIndices.front(), + AutoDiffIndexSubset::get( + getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices)); // Emit the VJP. auto vjpAndVJPIndices = emitAssociatedFunctionReference( @@ -3968,7 +3972,7 @@ class AdjointEmitter final : public SILInstructionVisitor { task->getIndices().isWrtParameter(selfParamIndex)) addRetElt(selfParamIndex); // Add the non-self parameters that are differentiated with respect to. - for (auto i : task->getIndices().parameters.set_bits()) { + for (auto i : task->getIndices().parameters->getIndices()) { // Do not add the self parameter because we have already added it at the // beginning. if (origTy->hasSelfParam() && i == selfParamIndex) @@ -4178,7 +4182,7 @@ class AdjointEmitter final : public SILInstructionVisitor { } } // Accumulate adjoints for the remaining non-self original parameters. - for (unsigned i : applyInfo.actualIndices.parameters.set_bits()) { + for (unsigned i : applyInfo.actualIndices.parameters->getIndices()) { // Do not set the adjoint of the original self parameter because we // already added it at the beginning. if (ai->hasSelfArgument() && i == selfParamIndex) @@ -4187,8 +4191,8 @@ class AdjointEmitter final : public SILInstructionVisitor { auto cotan = *allResultsIt++; // If a cotangent value corresponds to a non-desired parameter, it won't // be used, so release it. - if (i >= applyInfo.desiredIndices.parameters.size() || - !applyInfo.desiredIndices.parameters[i]) { + if (i >= applyInfo.desiredIndices.parameters->getCapacity() || + !applyInfo.desiredIndices.parameters->contains(i)) { emitCleanup(builder, loc, cotan); continue; } @@ -5473,7 +5477,7 @@ void DifferentiationTask::createEmptyAdjoint() { } // Add adjoint results for the requested non-self wrt parameters. - for (auto i : getIndices().parameters.set_bits()) { + for (auto i : getIndices().parameters->getIndices()) { if (origTy->hasSelfParam() && i == selfParamIndex) continue; auto origParam = origParams[i]; @@ -5636,7 +5640,8 @@ void DifferentiationTask::createVJP(bool isExported) { "unexpected number of vjp parameters"); assert(vjpConv.getResults().size() == numOriginalResults + 1 && "unexpected number of vjp results"); - assert(adjointConv.getResults().size() == getIndices().parameters.count() && + assert(adjointConv.getResults().size() == + getIndices().parameters->getCapacity() && "unexpected number of adjoint results"); // We assume that primal result conventions (for all results but the optional @@ -5721,7 +5726,7 @@ class Differentiation : public SILModuleTransform { SILValue ADContext::promoteToDifferentiableFunction( SILBuilder &builder, SILLocation loc, SILValue origFnOperand, - const llvm::SmallBitVector ¶meterIndices, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, DifferentiationInvoker invoker) { if (auto *ai = dyn_cast(origFnOperand)) { if (auto *sourceFn = dyn_cast(ai->getCallee())) { diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index d9a7a3dcb1d13..88fc284fccb21 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -660,9 +660,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, llvm::SmallBitVector parametersBitVector(parameters.size()); StringRef jvpName = MF->getIdentifier(jvpNameId).str(); StringRef vjpName = MF->getIdentifier(vjpNameId).str(); - for (unsigned i : indices(parameters)) - parametersBitVector[i] = parameters[i]; - SILAutoDiffIndices indices(source, parametersBitVector); + auto *parameterIndexSubset = + AutoDiffIndexSubset::get(MF->getContext(), parameters); + SILAutoDiffIndices indices(source, parameterIndexSubset); MF->readGenericRequirements(requirements, SILCursor); auto *attr = SILDifferentiableAttr::create(SILMod, indices, requirements, diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 74aef167b4faa..6991d5ad76109 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -431,8 +431,8 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) { auto ¶mIndices = DA->getIndices(); SmallVector parameters; - for (unsigned i : indices(paramIndices.parameters)) - parameters.push_back(paramIndices.parameters[i]); + for (unsigned i : range(paramIndices.parameters->getCapacity())) + parameters.push_back(paramIndices.parameters->contains(i)); SILDifferentiableAttrLayout::emitRecord( Out, ScratchRecord, differentiableAttrAbbrCode, From e7d2d81258895417a217b62c5e9cd3675fbccb96 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 10 May 2019 19:19:02 -0700 Subject: [PATCH 4/8] WIP --- include/swift/AST/AutoDiff.h | 9 ++-- include/swift/AST/DiagnosticsParse.def | 4 +- include/swift/AST/Types.h | 4 +- include/swift/SIL/SILBuilder.h | 2 +- include/swift/SIL/SILFunction.h | 3 ++ include/swift/SIL/SILInstruction.h | 10 ++--- lib/AST/AutoDiff.cpp | 24 ++++++----- lib/IRGen/GenDiffFunc.cpp | 6 +-- lib/ParseSIL/ParseSIL.cpp | 42 ++++++++++++++----- lib/SIL/SILDeclRef.cpp | 2 +- lib/SIL/SILFunctionBuilder.cpp | 1 + lib/SIL/SILFunctionType.cpp | 29 ++++++------- lib/SIL/SILInstructions.cpp | 6 +-- lib/SIL/SILPrinter.cpp | 6 +-- lib/SIL/SILVerifier.cpp | 10 ++--- lib/SILGen/SILGenPoly.cpp | 1 + .../Mandatory/Differentiation.cpp | 5 ++- lib/Serialization/DeserializeSIL.cpp | 23 +++++----- lib/Serialization/SerializeSIL.cpp | 6 +-- 19 files changed, 115 insertions(+), 78 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 0c66fa1170660..fa3d561bd0bfe 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -277,8 +277,6 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { IntRange<> range); static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, ArrayRef indices); - template - static AutoDiffIndexSubset *get(ASTContext &ctx, ArrayRef bits); unsigned getCapacity() const { return capacity; @@ -297,11 +295,14 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { return getBitWord(bitWordIndex) & (1 << offset); } + bool isEmpty() const; bool equals(const AutoDiffIndexSubset *other) const; bool isSubsetOf(const AutoDiffIndexSubset *other) const; bool isSupersetOf(const AutoDiffIndexSubset *other) const; AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; + AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx, + unsigned newCapacity) const; void Profile(llvm::FoldingSetNodeID &id) const; @@ -349,13 +350,13 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { bool operator==(const iterator &other) const { assert(&parent == &other.parent && - "Comparing iterators from different BitVectors"); + "Comparing iterators from different AutoDiffIndexSubsets"); return current == other.current; } bool operator!=(const iterator &other) const { assert(&parent == &other.parent && - "Comparing iterators from different BitVectors"); + "Comparing iterators from different AutoDiffIndexSubsets"); return current != other.current; } }; diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index b5baf560dd687..f06386b1cf8c7 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1535,7 +1535,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken, ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken, "the number of operand lists does not match the order", ()) ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken, - "expects an assoiacted function kind attribute, e.g. '[jvp]'", ()) + "expected an assoiacted function kind attribute, e.g. '[jvp]'", ()) +ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, + "expected an operand of a function type", ()) //------------------------------------------------------------------------------ // MARK: Generics parsing diagnostics diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index a165d8531d7f0..cfd96963a9877 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4148,7 +4148,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, // SWIFT_ENABLE_TENSORFLOW CanSILFunctionType getWithDifferentiability( - unsigned differentiationOrder, const SmallBitVector ¶meterIndices); + unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices); CanSILFunctionType getWithoutDifferentiability(); @@ -4164,7 +4164,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, /// differentiate with respect to for this differentiable function type. (e.g. /// which parameters are not @nondiff). The function type must be /// differentiable. - SmallBitVector getDifferentiationParameterIndices() const; + AutoDiffIndexSubset *getDifferentiationParameterIndices(); /// If this is a @convention(witness_method) function with a class /// constrained self parameter, return the class constraint for the diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 3fe0747495cc8..8c5e2a29572b3 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -504,7 +504,7 @@ class SILBuilder { /// SWIFT_ENABLE_TENSORFLOW AutoDiffFunctionInst *createAutoDiffFunction( - SILLocation loc, const llvm::SmallBitVector ¶meterIndices, + SILLocation loc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue original, ArrayRef associatedFunctions = {}) { return insert(AutoDiffFunctionInst::create(getModule(), diff --git a/include/swift/SIL/SILFunction.h b/include/swift/SIL/SILFunction.h index f82a6f20e4f5c..30b08688f24ff 100644 --- a/include/swift/SIL/SILFunction.h +++ b/include/swift/SIL/SILFunction.h @@ -174,6 +174,9 @@ class SILDifferentiableAttr final { SILFunction *getOriginal() const { return Original; } const SILAutoDiffIndices &getIndices() const { return indices; } + void setIndices(const SILAutoDiffIndices &indices) { + this->indices = indices; + } TrailingWhereClause *getWhereClause() const { return WhereClause; } diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index be2c95918c1ee..db4f0dcf164b6 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7721,7 +7721,7 @@ class AutoDiffFunctionInst final : private: friend SILBuilder; /// Differentiation parameter indices. - SmallBitVector parameterIndices; + AutoDiffIndexSubset *parameterIndices; /// The order of differentiation. unsigned differentiationOrder; /// The number of operands. The first operand is always the original function. @@ -7730,7 +7730,7 @@ class AutoDiffFunctionInst final : unsigned numOperands; AutoDiffFunctionInst(SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions); @@ -7738,20 +7738,20 @@ class AutoDiffFunctionInst final : public: static AutoDiffFunctionInst *create(SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions); static SILType getAutoDiffType(SILValue original, unsigned differentiationOrder, - const SmallBitVector ¶meterIndices); + AutoDiffIndexSubset *parameterIndices); /// Returns the original function. SILValue getOriginalFunction() const { return getAllOperands()[0].get(); } /// Returns differentiation indices. - const SmallBitVector &getParameterIndices() const { + AutoDiffIndexSubset *getParameterIndices() const { return parameterIndices; } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 8e609ff0435e6..f2786a1daad1b 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -364,17 +364,6 @@ AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { SmallVector(capacity, (unsigned)includeAll)); } -template -AutoDiffIndexSubset * -AutoDiffIndexSubset::get(ASTContext &ctx, ArrayRef bits) { - SmallVector indices; - indices.reserve(bits.size()); - for (auto i : indices(bits)) - if (bits[i]) - indices.push_back(i); - return get(ctx, bits.size(), indices); -} - AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, IntRange<> range) { @@ -408,6 +397,10 @@ unsigned AutoDiffIndexSubset::getNumIndices() const { }); } +bool AutoDiffIndexSubset::isEmpty() const { + return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; }); +} + bool AutoDiffIndexSubset::equals(const AutoDiffIndexSubset *other) const { return capacity == other->getCapacity() && getBitWords().equals(other->getBitWords()); @@ -448,6 +441,15 @@ AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const { return get(ctx, capacity, newIndices); } +AutoDiffIndexSubset *AutoDiffIndexSubset::extendingCapacity( + ASTContext &ctx, unsigned newCapacity) const { + assert(newCapacity >= getCapacity()); + SmallVector indices; + for (auto index : getIndices()) + indices.push_back(index); + return AutoDiffIndexSubset::get(ctx, newCapacity, indices); +} + int AutoDiffIndexSubset::findNext(int startIndex) const { if (numBitWords == 0) return -1; diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index a486171487948..01093f98f3483 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -39,14 +39,14 @@ namespace { class DiffFuncFieldInfo final : public RecordField { public: DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type, - const SmallBitVector ¶meterIndices) + AutoDiffIndexSubset *parameterIndices) : RecordField(type), Index(index), ParameterIndices(parameterIndices) {} /// The field index. const DiffFuncIndex Index; /// The parameter indices. - SmallBitVector ParameterIndices; + AutoDiffIndexSubset *ParameterIndices; std::string getFieldName() const { auto extractee = std::get<0>(Index); @@ -119,7 +119,7 @@ class DiffFuncTypeBuilder DiffFuncIndex> { SILFunctionType *origFnTy; - SmallBitVector parameterIndices; + AutoDiffIndexSubset *parameterIndices; public: DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index c93ca437ee553..cb55fe2626e7e 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -974,7 +974,7 @@ void SILParser::convertRequirements(SILFunction *F, /// SWIFT_ENABLE_TENSORFLOW /// Parse a `differentiable` attribute, e.g. -/// `[differentiable wrt 0, 1 adjoint @other]`. +/// `[differentiable wrt 0, 1 vjp @other]`. /// Returns true on error. static bool parseDifferentiableAttr( SmallVectorImpl &DAs, SILParser &SP) { @@ -1045,9 +1045,12 @@ static bool parseDifferentiableAttr( whereLoc, requirementReprs); } // Create a SILDifferentiableAttr and we are done. + auto maxIndexRef = std::max_element(ParamIndices.begin(), ParamIndices.end()); + auto *paramIndicesSubset = AutoDiffIndexSubset::get( + P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, ParamIndices); auto *Attr = SILDifferentiableAttr::create( - SP.SILMod, {SourceIndex, ParamIndices}, JVPName.str(), VJPName.str(), - WhereClause); + SP.SILMod, {SourceIndex, paramIndicesSubset}, JVPName.str(), + VJPName.str(), WhereClause); DAs.push_back(Attr); return false; } @@ -2873,7 +2876,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { // {%1 : $T, %2 : $T}, {%3 : $T, %4 : $T} // ^ jvp ^ vjp SourceLoc lastLoc; - SmallBitVector parameterIndices(32); + SmallVector parameterIndices; unsigned order = 1; // Parse optional `[wrt ...]` if (P.Tok.is(tok::l_square) && @@ -2882,15 +2885,12 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { P.consumeToken(tok::l_square); P.consumeToken(tok::identifier); // Parse indices. - unsigned size = parameterIndices.size(); while (P.Tok.is(tok::integer_literal)) { unsigned index; if (P.parseUnsignedInteger(index, lastLoc, diag::sil_inst_autodiff_expected_parameter_index)) return true; - if (index >= size) - parameterIndices.resize((size *= 2)); - parameterIndices.set(index); + parameterIndices.push_back(index); } if (P.parseToken(tok::r_square, diag::sil_inst_autodiff_attr_expected_rsquare, @@ -2917,8 +2917,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { } // Parse the original function value. SILValue original; - if (parseTypedValueRef(original, B)) + SourceLoc originalOperandLoc; + if (parseTypedValueRef(original, originalOperandLoc, B)) return true; + auto fnType = original->getType().getAs(); + if (!fnType) { + P.diagnose(originalOperandLoc, + diag::sil_inst_autodiff_expected_function_type_operand); + return true; + } SmallVector associatedFunctions; // Parse optional operand lists `with { , }, ...`. if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { @@ -2950,7 +2957,10 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { } if (parseSILDebugLocation(InstLoc, B)) return true; - ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndices, order, + auto *parameterIndicesSubset = + AutoDiffIndexSubset::get(P.Context, fnType->getNumParameters(), + parameterIndices); + ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndicesSubset, order, original, associatedFunctions); break; } @@ -5714,6 +5724,18 @@ bool SILParserTUState::parseDeclSIL(Parser &P) { // SWIFT_ENABLE_TENSORFLOW for (auto &attr : DiffAttrs) { + // Resolve parameter indices to have the right capacity, if it's + // different from the number of parameters. We have to do this because + // the parser does not know the function type before creating a + // `SILDifferentiableAttr`, so it had to find the max of all provided + // indices. + if (attr->getIndices().parameters->getCapacity() != + SILFnType->getNumParameters()) { + auto *newParamIndices = attr->getIndices().parameters + ->extendingCapacity(P.Context, SILFnType->getNumParameters()); + attr->setIndices({attr->getIndices().source, newParamIndices}); + } + // Resolve where clause requirements. // If no where clause, continue. if (!attr->getWhereClause()) diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index 65f4fa1a20f88..bbd0a611525aa 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -683,7 +683,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { getDecl()->getInterfaceType()->castTo(); auto silParameterIndices = autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered( - functionTy); + functionTy->getASTContext(), functionTy); SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices); std::string mangledKind; switch (autoDiffAssociatedFunctionIdentifier->getKind()) { diff --git a/lib/SIL/SILFunctionBuilder.cpp b/lib/SIL/SILFunctionBuilder.cpp index e1f2d0ecca64f..f1f18ae695341 100644 --- a/lib/SIL/SILFunctionBuilder.cpp +++ b/lib/SIL/SILFunctionBuilder.cpp @@ -89,6 +89,7 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F, // Get lowered argument indices. auto paramIndices = A->getParameterIndices(); auto loweredParamIndices = paramIndices->getLowered( + F->getASTContext(), decl->getInterfaceType()->castTo()); SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices); auto silDiffAttr = SILDifferentiableAttr::create( diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 575304d6b6a58..c2d61c0fa61a2 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -98,20 +98,19 @@ CanType SILFunctionType::getSelfInstanceType() const { } // SWIFT_ENABLE_TENSORFLOW -SmallBitVector -SILFunctionType::getDifferentiationParameterIndices() const { +AutoDiffIndexSubset * +SILFunctionType::getDifferentiationParameterIndices() { assert(isDifferentiable()); - SmallBitVector result(NumParameters, true); + SmallVector result; for (auto valueAndIndex : enumerate(getParameters())) - if (valueAndIndex.value().getDifferentiability() == + if (valueAndIndex.value().getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - result.reset(valueAndIndex.index()); - return result; + result.push_back(valueAndIndex.index()); + return AutoDiffIndexSubset::get(getASTContext(), getNumParameters(), result); } CanSILFunctionType SILFunctionType::getWithDifferentiability( - unsigned differentiationOrder, - const SmallBitVector ¶meterIndices) { + unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices) { // FIXME(rxwei): Handle differentiation order. SmallVector newParameters; @@ -119,9 +118,10 @@ CanSILFunctionType SILFunctionType::getWithDifferentiability( auto ¶m = paramAndIndex.value(); unsigned index = paramAndIndex.index(); newParameters.push_back(param.getWithDifferentiability( - index < parameterIndices.size() && parameterIndices[index] - ? SILParameterDifferentiability::DifferentiableOrNotApplicable - : SILParameterDifferentiability::NotDifferentiable)); + index < parameterIndices->getCapacity() && + parameterIndices->contains(index) + ? SILParameterDifferentiability::DifferentiableOrNotApplicable + : SILParameterDifferentiability::NotDifferentiable)); } auto newExtInfo = getExtInfo().withDifferentiable(); @@ -212,7 +212,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( // Helper function testing if we are differentiating wrt this index. auto isWrtIndex = [&](unsigned index) -> bool { - return index < parameterIndices.size() && parameterIndices[index]; + return index < parameterIndices->getCapacity() && + parameterIndices->contains(index); }; // Calculate WRT parameter infos, in the order that they should appear in the @@ -2316,8 +2317,8 @@ const SILConstantInfo &TypeConverter::getConstantInfo(SILDeclRef constant) { if (auto *autoDiffFuncId = constant.autoDiffAssociatedFunctionIdentifier) { auto origFnConstantInfo = getConstantInfo(constant.asAutoDiffOriginalFunction()); - auto loweredIndices = - autoDiffFuncId->getParameterIndices()->getLowered(formalInterfaceType); + auto loweredIndices = autoDiffFuncId->getParameterIndices() + ->getLowered(Context, formalInterfaceType); silFnType = origFnConstantInfo.SILFnType->getAutoDiffAssociatedFunctionType( loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getDifferentiationOrder(), autoDiffFuncId->getKind(), M, diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 936720ca8ac0a..b78402b7ba92d 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -569,7 +569,7 @@ TryApplyInst *TryApplyInst::create( SILType AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, unsigned differentiationOrder, - const SmallBitVector ¶meterIndices) { + AutoDiffIndexSubset *parameterIndices) { auto fnTy = originalFunction->getType().castTo(); auto diffTy = fnTy->getWithDifferentiability(differentiationOrder, parameterIndices); @@ -578,7 +578,7 @@ AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, AutoDiffFunctionInst::AutoDiffFunctionInst( SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions) : InstructionBaseWithTrailingOperands( originalFunction, associatedFunctions, debugLoc, @@ -591,7 +591,7 @@ AutoDiffFunctionInst::AutoDiffFunctionInst( AutoDiffFunctionInst *AutoDiffFunctionInst::create( SILModule &module, SILDebugLocation debugLoc, - const SmallBitVector ¶meterIndices, + AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions) { size_t size = totalSizeToAlloc(associatedFunctions.size() + 1); diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 83d2edbecee99..a2a68b6070792 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1159,9 +1159,9 @@ class SILPrinter : public SILInstructionVisitor { // SWIFT_ENABLE_TENSORFLOW void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { - if (adfi->getParameterIndices().any()) { + if (adfi->getParameterIndices()->isEmpty()) { *this << "[wrt"; - for (auto i : adfi->getParameterIndices().set_bits()) + for (auto i : adfi->getParameterIndices()->getIndices()) *this << ' ' << i; *this << "] "; } @@ -3118,7 +3118,7 @@ void SILSpecializeAttr::print(llvm::raw_ostream &OS) const { void SILDifferentiableAttr::print(llvm::raw_ostream &OS) const { auto &indices = getIndices(); OS << "source " << indices.source << " wrt "; - interleave(indices.parameters.set_bits(), + interleave(indices.parameters->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << ", "; }); if (!JVPName.empty()) { diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 5fa78ce501809..48d9cad43b4bc 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -4671,7 +4671,7 @@ class SILVerifier : public SILVerifierBase { }; // Parameter indices must be specified. - require(!Attr.getIndices().parameters.empty(), + require(!Attr.getIndices().parameters->isEmpty(), "Parameter indices cannot be empty"); // JVP and VJP must be specified in canonical SIL. if (F->getModule().getStage() == SILStage::Canonical) @@ -4680,16 +4680,16 @@ class SILVerifier : public SILVerifierBase { // Verify if specified parameter indices are valid. auto numParams = countParams(F->getLoweredFunctionType()); int lastIndex = -1; - for (auto paramIdx : Attr.getIndices().parameters.set_bits()) { + for (auto paramIdx : Attr.getIndices().parameters->getIndices()) { require(paramIdx < numParams, "Parameter index out of bounds."); auto currentIdx = (int)paramIdx; require(currentIdx > lastIndex, "Parameter indices not ascending."); lastIndex = currentIdx; } - // TODO: Verify if the specified primal/adjoint function has the right - // signature. SIL function verification runs right after a function is + // TODO: Verify if the specified JVP/VJP function has the right signature. + // SIL function verification runs right after a function is // parsed. - // However, the adjoint function may come after the this function. Without + // However, the JVP/VJP function may come after the this function. Without // changing the compiler too much, is there a way to verify this at a module // level, after everything is parsed? } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 5e687f3bf5425..95bd4ca91dd18 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3820,6 +3820,7 @@ getWitnessFunctionRef(SILGenFunction &SGF, auto originalFn = SGF.emitGlobalFunctionRef( loc, witness.asAutoDiffOriginalFunction()); auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered( + SGF.getASTContext(), witness.getDecl()->getInterfaceType()->castTo()); auto autoDiffFn = SGF.B.createAutoDiffFunction( loc, loweredIndices, /*differentiationOrder*/ 1, originalFn); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 9af0ef6ec5267..220f4eaa40037 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1812,7 +1812,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, if (diffableFnType->isDifferentiable()) { auto paramIndices = diffableFnType->getDifferentiationParameterIndices(); for (auto i : desiredIndices.parameters->getIndices()) { - if (i >= paramIndices.size() || !paramIndices[i]) { + if (i >= paramIndices->getCapacity() || !paramIndices->contains(i)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_function_nondiff_parameter_not_differentiable); return None; @@ -2764,7 +2764,8 @@ class PrimalGenCloner final errorOccurred = true; return; } - SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0}); + SILAutoDiffIndices indices(/*source*/ 0, + /*parameters*/ AutoDiffIndexSubset::get(getASTContext(), 1)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 700a45652a88b..25bdcffd63eb9 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -654,17 +654,19 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, uint64_t jvpNameId; uint64_t vjpNameId; uint64_t source; - ArrayRef parameters; + ArrayRef rawParameterIndices; SmallVector requirements; SILDifferentiableAttrLayout::readRecord(scratch, jvpNameId, vjpNameId, - source, parameters); + source, rawParameterIndices); - llvm::SmallBitVector parametersBitVector(parameters.size()); StringRef jvpName = MF->getIdentifier(jvpNameId).str(); StringRef vjpName = MF->getIdentifier(vjpNameId).str(); - auto *parameterIndexSubset = - AutoDiffIndexSubset::get(MF->getContext(), parameters); + SmallVector parameterIndices(rawParameterIndices.begin(), + rawParameterIndices.end()); + auto *parameterIndexSubset = AutoDiffIndexSubset::get( + MF->getContext(), fn->getLoweredFunctionType()->getNumParameters(), + parameterIndices); SILAutoDiffIndices indices(source, parameterIndexSubset); MF->readGenericRequirements(requirements, SILCursor); @@ -1505,18 +1507,19 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, // SWIFT_ENABLE_TENSORFLOW case SILInstructionKind::AutoDiffFunctionInst: { auto numParamIndices = ListOfValues.size() - NumArguments * 3; - auto paramIndices = ListOfValues.take_front(numParamIndices); + auto rawParamIndices = + map>(ListOfValues.take_front(numParamIndices), + [&](uint64_t i) { return i; }); auto numParams = Attr2; - llvm::SmallBitVector paramIndicesBitVec(numParams); - for (unsigned idx : paramIndices) - paramIndicesBitVec.set(idx); + auto *paramIndices = + AutoDiffIndexSubset::get(MF->getContext(), numParams, rawParamIndices); SmallVector operands; for (auto i = numParamIndices; i < NumArguments * 3; i += 3) { auto astTy = MF->getType(ListOfValues[i]); auto silTy = getSILType(astTy, (SILValueCategory)ListOfValues[i+1]); operands.push_back(getLocalValue(ListOfValues[i+2], silTy)); } - ResultVal = Builder.createAutoDiffFunction(Loc, paramIndicesBitVec, + ResultVal = Builder.createAutoDiffFunction(Loc, paramIndices, /*differentiationOrder*/ Attr, operands[0], ArrayRef(operands).drop_front()); break; diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 851d3009e4f09..3ac5f9d5b2b81 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -971,8 +971,8 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { case SILInstructionKind::AutoDiffFunctionInst: { auto *adfi = cast(&SI); SmallVector trailingInfo; - auto ¶mIndices = adfi->getParameterIndices(); - for (unsigned idx : paramIndices.set_bits()) + auto *paramIndices = adfi->getParameterIndices(); + for (unsigned idx : paramIndices->getIndices()) trailingInfo.push_back(idx); for (auto &op : adfi->getAllOperands()) { auto val = op.get(); @@ -982,7 +982,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { } SILInstAutoDiffFunctionLayout::emitRecord(Out, ScratchRecord, SILAbbrCodes[SILInstAutoDiffFunctionLayout::Code], - adfi->getDifferentiationOrder(), (unsigned)paramIndices.size(), + adfi->getDifferentiationOrder(), paramIndices->getCapacity(), adfi->getNumOperands(), trailingInfo); break; } From 4f08e5724ab226b43abc3604798a5df6536c9edc Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 13 May 2019 09:23:10 -0700 Subject: [PATCH 5/8] [AutoDiff] Parameter indices data structure overhaul. --- include/swift/AST/AutoDiff.h | 203 ++++++++++++------ include/swift/Serialization/ModuleFormat.h | 2 +- lib/AST/ASTContext.cpp | 16 +- lib/AST/AutoDiff.cpp | 194 +++-------------- lib/SIL/SILPrinter.cpp | 2 +- .../Mandatory/Differentiation.cpp | 84 +++++--- lib/Serialization/DeserializeSIL.cpp | 3 +- lib/Serialization/SILFormat.h | 8 +- lib/Serialization/SerializeSIL.cpp | 10 +- test/AutoDiff/currying.swift | 2 +- unittests/AST/SILAutoDiffIndices.cpp | 158 +++++++++----- 11 files changed, 352 insertions(+), 330 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index fa3d561bd0bfe..014dd98398820 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -223,22 +223,31 @@ class AutoDiffParameterIndicesBuilder { }; class AutoDiffIndexSubset : public llvm::FoldingSetNode { -private: - using BitWord = uint64_t; +public: + typedef uint64_t BitWord; + + static constexpr unsigned bitWordSize = sizeof(BitWord); + static constexpr unsigned numBitsPerBitWord = bitWordSize * 8; + + static std::pair + getBitWordIndexAndOffset(unsigned index) { + auto bitWordIndex = index / numBitsPerBitWord; + auto bitWordOffset = index % numBitsPerBitWord; + return {bitWordIndex, bitWordOffset}; + } + static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) { + if (capacity == 0) return 0; + return capacity / numBitsPerBitWord + 1; + } + +private: /// The total capacity of the index subset, which is `1` less than the largest /// index. unsigned capacity; /// The number of bit words in the index subset. in the index subset. unsigned numBitWords; - static std::pair getBitWordIndexAndOffset(unsigned index); - static unsigned getNumBitWordsNeededForCapacity(unsigned capacity); - - unsigned getNumBitWords() const { - return numBitWords; - } - BitWord *getBitWordsData() { return reinterpret_cast(this + 1); } @@ -263,31 +272,67 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { return {const_cast(getBitWordsData()), getNumBitWords()}; } - explicit AutoDiffIndexSubset(unsigned capacity, unsigned numBitWords, - ArrayRef indices); + explicit AutoDiffIndexSubset(unsigned capacity, ArrayRef indices) + : capacity(capacity), + numBitWords(getNumBitWordsNeededForCapacity(capacity)) { + std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0); + for (auto i : indices) { + unsigned bitWordIndex, offset; + std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i); + getBitWord(bitWordIndex) |= (1 << offset); + } + } public: AutoDiffIndexSubset() = delete; AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete; AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete; - static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, - bool includeAll = false); - static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, - IntRange<> range); - static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity, + // Defined in ASTContext.h. + static AutoDiffIndexSubset *get(ASTContext &ctx, + unsigned capacity, ArrayRef indices); + static AutoDiffIndexSubset *getDefault(ASTContext &ctx, + unsigned capacity, + bool includeAll = false) { + if (includeAll) + return getFromRange(ctx, capacity, IntRange<>(capacity)); + return get(ctx, capacity, {}); + } + + static AutoDiffIndexSubset *getFromRange(ASTContext &ctx, + unsigned capacity, + IntRange<> range) { + return get(ctx, capacity, + SmallVector(range.begin(), range.end())); + } + + unsigned getNumBitWords() const { + return numBitWords; + } + unsigned getCapacity() const { return capacity; } class iterator; - iterator begin() const; - iterator end() const; - iterator_range getIndices() const; - unsigned getNumIndices() const; + iterator begin() const { + return iterator(this); + } + + iterator end() const { + return iterator(this, (int)capacity); + } + + iterator_range getIndices() const { + return make_range(begin(), end()); + } + + unsigned getNumIndices() const { + return (unsigned)std::distance(begin(), end()); + } bool contains(unsigned index) const { unsigned bitWordIndex, offset; @@ -295,30 +340,91 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { return getBitWord(bitWordIndex) & (1 << offset); } - bool isEmpty() const; - bool equals(const AutoDiffIndexSubset *other) const; - bool isSubsetOf(const AutoDiffIndexSubset *other) const; - bool isSupersetOf(const AutoDiffIndexSubset *other) const; + bool isEmpty() const { + return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; }); + } + + bool equals(const AutoDiffIndexSubset *other) const { + return capacity == other->getCapacity() && + getBitWords().equals(other->getBitWords()); + } + + bool isSubsetOf(const AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (getBitWord(index) & ~other->getBitWord(index)) + return false; + return true; + } + + bool isSupersetOf(const AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (~getBitWord(index) & other->getBitWord(index)) + return false; + return true; + } - AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; - AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx, - unsigned newCapacity) const; + AutoDiffIndexSubset *adding( + unsigned index, ASTContext &ctx) const { + assert(index < getCapacity()); + SmallVector newIndices; + newIndices.reserve(capacity + 1); + bool inserted = false; + for (auto curIndex : getIndices()) { + if (inserted && curIndex > index) { + newIndices.push_back(index); + inserted = false; + } + newIndices.push_back(curIndex); + } + return get(ctx, capacity, newIndices); + } - void Profile(llvm::FoldingSetNodeID &id) const; + AutoDiffIndexSubset *extendingCapacity( + ASTContext &ctx, unsigned newCapacity) const { + assert(newCapacity >= capacity); + if (newCapacity == capacity) + return const_cast(this); + SmallVector indices; + for (auto index : getIndices()) + indices.push_back(index); + return AutoDiffIndexSubset::get(ctx, newCapacity, indices); + } + + void Profile(llvm::FoldingSetNodeID &id) const { + id.AddInteger(capacity); + for (auto index : getIndices()) + id.AddInteger(index); + } + + void print(llvm::raw_ostream &s = llvm::outs()) const { + s << '{'; + interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); }, + [&s] { s << ", "; }); + s << '}'; + } + + void dump(llvm::raw_ostream &s = llvm::errs()) const { + s << "(autodiff_index_subset capacity=" << capacity << " indices=("; + interleave(getIndices(), [&s](unsigned i) { s << i; }, + [&s] { s << ", "; }); + s << "))"; + } -private: int findNext(int startIndex) const; int findFirst() const { return findNext(-1); } int findPrevious(int endIndex) const; int findLast() const { return findPrevious(capacity); } -public: class iterator { - typedef unsigned value_type; - typedef int difference_type; - typedef unsigned * pointer; - typedef unsigned & reference; - typedef std::forward_iterator_tag iterator_category; + public: + typedef unsigned value_type; + typedef unsigned difference_type; + typedef unsigned * pointer; + typedef unsigned & reference; + typedef std::forward_iterator_tag iterator_category; + private: const AutoDiffIndexSubset *parent; int current = 0; @@ -349,42 +455,19 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { unsigned operator*() const { return current; } bool operator==(const iterator &other) const { - assert(&parent == &other.parent && + assert(parent == other.parent && "Comparing iterators from different AutoDiffIndexSubsets"); return current == other.current; } bool operator!=(const iterator &other) const { - assert(&parent == &other.parent && + assert(parent == other.parent && "Comparing iterators from different AutoDiffIndexSubsets"); return current != other.current; } }; }; -class AutoDiffFunctionParameterSubset { -private: - AutoDiffIndexSubset *indexSubset; - bool curried; - -public: - explicit AutoDiffFunctionParameterSubset( - AutoDiffIndexSubset *indexSubset, bool isCurried) - : indexSubset(indexSubset), curried(isCurried) {} - - explicit AutoDiffFunctionParameterSubset( - ASTContext &ctx, AutoDiffIndexSubset *parameterSubset, - Optional isSelfIncluded); - - AutoDiffIndexSubset *getIndexSubset() const { - return indexSubset; - } - - bool isCurried() const { - return curried; - } -}; - /// SIL-level automatic differentiation indices. Consists of a source index, /// i.e. index of the dependent result to differentiate from, and parameter /// indices, i.e. index of independent parameters to differentiate with diff --git a/include/swift/Serialization/ModuleFormat.h b/include/swift/Serialization/ModuleFormat.h index b5a40fc921ad0..e02e226f474e3 100644 --- a/include/swift/Serialization/ModuleFormat.h +++ b/include/swift/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 489; // Last change: `@differentiating` wrt +const uint16_t SWIFTMODULE_VERSION_MINOR = 490; // Last change: `@differentiable` parameter indices layout. using DeclIDField = BCFixed<31>; diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 59de2bef59cd7..6b7026f23c9dd 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -4536,9 +4536,9 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) { return newNode; } -AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, - unsigned capacity, - ArrayRef indices) { +AutoDiffIndexSubset * +AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, + ArrayRef indices) { auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets; llvm::FoldingSetNodeID id; id.AddInteger(capacity); @@ -4548,7 +4548,7 @@ AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, for (unsigned index : indices) { #ifndef NDEBUG assert((int)index > last && "Indices must be ascending"); - last = index; + last = (int)index; #endif id.AddInteger(index); } @@ -4556,11 +4556,11 @@ AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); if (existing) return existing; - auto numBitWords = sizeof(AutoDiffIndexSubset) + - getNumBitWordsNeededForCapacity(capacity); + auto sizeToAlloc = sizeof(AutoDiffIndexSubset) + + getNumBitWordsNeededForCapacity(capacity); auto *buf = reinterpret_cast( - ctx.Allocate(numBitWords, alignof(AutoDiffIndexSubset))); - auto *newNode = new (buf) AutoDiffIndexSubset(capacity, numBitWords, indices); + ctx.Allocate(sizeToAlloc, alignof(AutoDiffIndexSubset))); + auto *newNode = new (buf) AutoDiffIndexSubset(capacity, indices); foldingSet.InsertNode(newNode, insertPos); return newNode; } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index f2786a1daad1b..11f77eebe524e 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -331,174 +331,48 @@ NominalTypeDecl *VectorSpace::getNominal() const { return getVector()->getNominalOrBoundGenericNominal(); } -std::pair -AutoDiffIndexSubset::getBitWordIndexAndOffset(unsigned index) { - auto bitWordIndex = (index + 1) / sizeof(BitWord); - auto bitWordOffset = index - bitWordIndex * sizeof(BitWord); - return {bitWordIndex, bitWordOffset}; -} - -unsigned -AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(unsigned capacity) { - auto numBitWords = capacity / sizeof(BitWord); - if (capacity % sizeof(BitWord)) - numBitWords += 1; - return numBitWords; -} - -AutoDiffIndexSubset::AutoDiffIndexSubset(unsigned capacity, - unsigned numBitWords, - ArrayRef indices) - : capacity(capacity), numBitWords(numBitWords) { - std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0); - for (auto i : indices) { - unsigned bitWordIndex, offset; - std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i); - getBitWord(bitWordIndex) |= (1 << offset); - } -} - -AutoDiffIndexSubset * -AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity, bool includeAll) { - return get(ctx, capacity, - SmallVector(capacity, (unsigned)includeAll)); -} - -AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx, - unsigned capacity, - IntRange<> range) { - return get(ctx, capacity, - SmallVector(range.begin(), range.end())); -} - -void AutoDiffIndexSubset::Profile(llvm::FoldingSetNodeID &id) const { - id.AddInteger(capacity); - for (auto index : getIndices()) - id.AddInteger(index); -} - -AutoDiffIndexSubset::iterator AutoDiffIndexSubset::begin() const { - return iterator(this); -} - -AutoDiffIndexSubset::iterator AutoDiffIndexSubset::end() const { - return iterator(this, findLast() + 1); -} - -iterator_range -AutoDiffIndexSubset::getIndices() const { - return make_range(begin(), end()); -} - -unsigned AutoDiffIndexSubset::getNumIndices() const { - return accumulate(getIndices(), (unsigned)0, - [](unsigned total, BitWord bitWord) { - return total + llvm::countPopulation(bitWord); - }); -} - -bool AutoDiffIndexSubset::isEmpty() const { - return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; }); -} - -bool AutoDiffIndexSubset::equals(const AutoDiffIndexSubset *other) const { - return capacity == other->getCapacity() && - getBitWords().equals(other->getBitWords()); -} - -bool AutoDiffIndexSubset:: -isSubsetOf(const AutoDiffIndexSubset *other) const { - assert(capacity == other->capacity); - for (auto index : range(numBitWords)) - if (getBitWord(index) & ~other->getBitWord(index)) - return false; - return true; -} - -bool AutoDiffIndexSubset:: -isSupersetOf(const AutoDiffIndexSubset *other) const { - assert(capacity == other->capacity); - for (auto index : range(numBitWords)) - if (~getBitWord(index) & other->getBitWord(index)) - return false; - return true; -} - -AutoDiffIndexSubset * -AutoDiffIndexSubset::adding(unsigned index, ASTContext &ctx) const { - assert(index < getCapacity()); - SmallVector newIndices; - newIndices.reserve(capacity + 1); - bool inserted = false; - for (auto it = begin(); it != end(); ++it) { - auto curIndex = *it; - if (inserted && curIndex > index) { - newIndices.push_back(index); - inserted = false; - } - newIndices.push_back(curIndex); - } - return get(ctx, capacity, newIndices); -} - -AutoDiffIndexSubset *AutoDiffIndexSubset::extendingCapacity( - ASTContext &ctx, unsigned newCapacity) const { - assert(newCapacity >= getCapacity()); - SmallVector indices; - for (auto index : getIndices()) - indices.push_back(index); - return AutoDiffIndexSubset::get(ctx, newCapacity, indices); -} - int AutoDiffIndexSubset::findNext(int startIndex) const { - if (numBitWords == 0) - return -1; + assert(startIndex < (int)capacity && "Start index cannot be past the end"); unsigned bitWordIndex = 0, offset = 0; - if (startIndex >= 0) - std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(startIndex); - for (; bitWordIndex < getNumBitWords(); ++bitWordIndex) { - auto bitWord = getBitWord(bitWordIndex); - if (!bitWord) { - ++bitWordIndex; - continue; - } - while (offset < sizeof(BitWord)) { - if (bitWord & (1 << offset)) - return bitWordIndex * sizeof(BitWord) + offset; - ++offset; + if (startIndex >= 0) { + auto indexAndOffset = getBitWordIndexAndOffset(startIndex); + bitWordIndex = indexAndOffset.first; + offset = indexAndOffset.second + 1; + } + for (; bitWordIndex < numBitWords; ++bitWordIndex, offset = 0) { + for (; offset < numBitsPerBitWord; ++offset) { + auto index = bitWordIndex * numBitsPerBitWord + offset; + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) + break; + if (index >= capacity) + return capacity; + if (bitWord & ((BitWord)1 << offset)) + return index; } - offset = 0; } - return -1; + return capacity; } int AutoDiffIndexSubset::findPrevious(int endIndex) const { - if (numBitWords == 0) - return -1; - unsigned bitWordIndex, offset; - std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(endIndex); - for (; bitWordIndex >= 0; --bitWordIndex) { - auto bitWord = getBitWord(bitWordIndex); - if (!bitWord) { - --bitWordIndex; - continue; - } - while (offset >= 0) { - if (bitWord & (1 << offset)) - return bitWordIndex * sizeof(BitWord) + offset; - --offset; + assert(endIndex >= 0 && "End index cannot be before the start"); + int bitWordIndex = numBitWords - 1, offset = numBitsPerBitWord - 1; + if (endIndex < (int)capacity) { + auto indexAndOffset = getBitWordIndexAndOffset(endIndex); + bitWordIndex = (int)indexAndOffset.first; + offset = (int)indexAndOffset.second - 1; + } + for (; bitWordIndex >= 0; --bitWordIndex, offset = numBitsPerBitWord - 1) { + for (; offset < (int)numBitsPerBitWord; --offset) { + auto index = bitWordIndex * (int)numBitsPerBitWord + offset; + auto bitWord = getBitWord(bitWordIndex); + if (!bitWord) + break; + if (index < 0) + return -1; + if (bitWord & ((BitWord)1 << offset)) + return index; } - offset = sizeof(BitWord) - 1; } return -1; } - -AutoDiffFunctionParameterSubset:: -AutoDiffFunctionParameterSubset(ASTContext &ctx, - AutoDiffIndexSubset *parameterSubset, - Optional isSelfIncluded) { - if (isSelfIncluded.hasValue()) { - indexSubset = - AutoDiffIndexSubset::get(ctx, parameterSubset->getCapacity() + 1); - } -} diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index a2a68b6070792..04c90102ed669 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1159,7 +1159,7 @@ class SILPrinter : public SILInstructionVisitor { // SWIFT_ENABLE_TENSORFLOW void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { - if (adfi->getParameterIndices()->isEmpty()) { + if (!adfi->getParameterIndices()->isEmpty()) { *this << "[wrt"; for (auto i : adfi->getParameterIndices()->getIndices()) *this << ' ' << i; diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 220f4eaa40037..a082c3910e229 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -213,9 +213,10 @@ static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName, /// qualifier for creating a `store` instruction into the buffer. static StoreOwnershipQualifier getBufferSOQ(Type type, SILFunction &fn) { if (fn.hasOwnership()) - return fn.getModule().Types.getTypeLowering(type, ResilienceExpansion::Minimal).isTrivial() - ? StoreOwnershipQualifier::Trivial - : StoreOwnershipQualifier::Init; + return fn.getModule().Types.getTypeLowering( + type, ResilienceExpansion::Minimal).isTrivial() + ? StoreOwnershipQualifier::Trivial + : StoreOwnershipQualifier::Init; return StoreOwnershipQualifier::Unqualified; } @@ -223,9 +224,10 @@ static StoreOwnershipQualifier getBufferSOQ(Type type, SILFunction &fn) { /// qualified for creating a `load` instruction from the buffer. static LoadOwnershipQualifier getBufferLOQ(Type type, SILFunction &fn) { if (fn.hasOwnership()) - return fn.getModule().Types.getTypeLowering(type, ResilienceExpansion::Minimal).isTrivial() - ? LoadOwnershipQualifier::Trivial - : LoadOwnershipQualifier::Take; + return fn.getModule().Types.getTypeLowering( + type, ResilienceExpansion::Minimal).isTrivial() + ? LoadOwnershipQualifier::Trivial + : LoadOwnershipQualifier::Take; return LoadOwnershipQualifier::Unqualified; } @@ -932,16 +934,27 @@ class ADContext { DifferentiationTask * lookUpMinimalDifferentiationTask(SILFunction *original, const SILAutoDiffIndices &indices) { - AutoDiffIndexSubset *superset = nullptr; + auto *superset = AutoDiffIndexSubset::getDefault( + getASTContext(), + original->getLoweredFunctionType()->getNumParameters(), false); auto *indexSet = indices.parameters; if (auto *existingTask = lookUpDifferentiationTask(original, indices)) return existingTask; for (auto *rda : original->getDifferentiableAttrs()) { auto *rdaIndexSet = rda->getIndices().parameters; - // If all indices in indexSet are in rdaIndexSet, and it has fewer - // indices than our current candidate and a primitive adjoint, rda is our + // If all indices in `indexSet` are in `rdaIndexSet`, and it has fewer + // indices than our current candidate and a primitive VJP, `rda` is our // new candidate. - if (rdaIndexSet->isSubsetOf(indexSet)) + // + // NOTE: `rda` may come from a un-partial-applied function, it may have + // more parameters than the desired indices. We expect this logic to go + // away when we support `@differentiable` partial apply. + if (rdaIndexSet->isSupersetOf( + indexSet->extendingCapacity(getASTContext(), + rdaIndexSet->getCapacity())) && + // fewer parameters than before + (superset->isEmpty() || + rdaIndexSet->getNumIndices() < superset->getNumIndices())) superset = rda->getIndices().parameters; } auto existing = @@ -1812,7 +1825,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, if (diffableFnType->isDifferentiable()) { auto paramIndices = diffableFnType->getDifferentiationParameterIndices(); for (auto i : desiredIndices.parameters->getIndices()) { - if (i >= paramIndices->getCapacity() || !paramIndices->contains(i)) { + if (!paramIndices->contains(i)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_function_nondiff_parameter_not_differentiable); return None; @@ -1968,8 +1981,15 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, SILAutoDiffIndices requirementIndices(/*source*/ 0, loweredRequirementIndices); + // NOTE: We need to extend the capacity of desired parameter indices to + // requirement parameter indices, because there's a argument count mismatch. + // When `@differentiable` partial apply is supported, this problem will go + // away. if (desiredIndices.source != requirementIndices.source || - !desiredIndices.parameters->isSubsetOf(requirementIndices.parameters)) { + !desiredIndices.parameters->extendingCapacity( + context.getASTContext(), + requirementIndices.parameters->getCapacity()) + ->isSubsetOf(requirementIndices.parameters)) { context.emitNondifferentiabilityError(original, parentTask, diag::autodiff_protocol_member_subset_indices_not_differentiable); return None; @@ -2602,7 +2622,8 @@ class PrimalGenCloner final auto &builder = getBuilder(); builder.setInsertionPoint(exit); auto structLoweredTy = - getContext().getTypeConverter().getLoweredType(structTy, ResilienceExpansion::Minimal); + getContext().getTypeConverter().getLoweredType( + structTy, ResilienceExpansion::Minimal); auto primValsVal = builder.createStruct(loc, structLoweredTy, primalValues); // If the original result was a tuple, return a tuple of all elements in the // original result tuple and the primal value struct value. @@ -2693,9 +2714,8 @@ class PrimalGenCloner final errorOccurred = true; return; } - SILAutoDiffIndices indices( - /*source*/ 0, - /*parameters*/ AutoDiffIndexSubset::get(getASTContext(), 1)); + SILAutoDiffIndices indices(/*source*/ 0, + AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( @@ -2765,7 +2785,7 @@ class PrimalGenCloner final return; } SILAutoDiffIndices indices(/*source*/ 0, - /*parameters*/ AutoDiffIndexSubset::get(getASTContext(), 1)); + AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); if (!task) { getContext().emitNondifferentiabilityError( @@ -4211,8 +4231,7 @@ class AdjointEmitter final : public SILInstructionVisitor { auto cotan = *allResultsIt++; // If a cotangent value corresponds to a non-desired parameter, it won't // be used, so release it. - if (i >= applyInfo.desiredIndices.parameters->getCapacity() || - !applyInfo.desiredIndices.parameters->contains(i)) { + if (!applyInfo.desiredIndices.parameters->contains(i)) { emitCleanup(builder, loc, cotan); continue; } @@ -4272,8 +4291,9 @@ class AdjointEmitter final : public SILInstructionVisitor { AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering(cotangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); + assert(!getModule().Types.getTypeLowering( + cotangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); auto *cotangentVectorDecl = cotangentVectorTy->getStructOrBoundGenericStruct(); assert(cotangentVectorDecl); @@ -4344,8 +4364,9 @@ class AdjointEmitter final : public SILInstructionVisitor { AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering(cotangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); + assert(!getModule().Types.getTypeLowering( + cotangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); auto cotangentVectorSILTy = SILType::getPrimitiveObjectType(cotangentVectorTy); auto *cotangentVectorDecl = @@ -4381,7 +4402,8 @@ class AdjointEmitter final : public SILInstructionVisitor { field->getModuleContext(), field); auto fieldTy = field->getType().subst(substMap); auto fieldSILTy = - getContext().getTypeConverter().getLoweredType(fieldTy, ResilienceExpansion::Minimal); + getContext().getTypeConverter().getLoweredType( + fieldTy, ResilienceExpansion::Minimal); assert(fieldSILTy.isObject()); eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } @@ -4937,7 +4959,8 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, } SILValue AdjointEmitter::emitZeroDirect(CanType type, SILLocation loc) { - auto silType = getModule().Types.getLoweredLoadableType(type, ResilienceExpansion::Minimal); + auto silType = getModule().Types.getLoweredLoadableType( + type, ResilienceExpansion::Minimal); auto *buffer = builder.createAllocStack(loc, silType); auto *initAccess = builder.createBeginAccess(loc, buffer, SILAccessKind::Init, SILAccessEnforcement::Static, @@ -5415,7 +5438,8 @@ void DifferentiationTask::createEmptyAdjoint() { // Given a type, returns its formal SIL parameter info. auto getCotangentParameterInfoForOriginalResult = [&]( CanType cotanType, ResultConvention origResConv) -> SILParameterInfo { - auto &tl = context.getTypeConverter().getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto &tl = context.getTypeConverter().getTypeLowering( + cotanType, ResilienceExpansion::Minimal); ParameterConvention conv; switch (origResConv) { case ResultConvention::Owned: @@ -5438,7 +5462,8 @@ void DifferentiationTask::createEmptyAdjoint() { // Given a type, returns its formal SIL result info. auto getCotangentResultInfoForOriginalParameter = [&]( CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo { - auto &tl = context.getTypeConverter().getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto &tl = context.getTypeConverter().getTypeLowering( + cotanType, ResilienceExpansion::Minimal); ResultConvention conv; switch (origParamConv) { case ParameterConvention::Direct_Owned: @@ -5661,7 +5686,7 @@ void DifferentiationTask::createVJP(bool isExported) { assert(vjpConv.getResults().size() == numOriginalResults + 1 && "unexpected number of vjp results"); assert(adjointConv.getResults().size() == - getIndices().parameters->getCapacity() && + getIndices().parameters->getNumIndices() && "unexpected number of adjoint results"); // We assume that primal result conventions (for all results but the optional @@ -5815,9 +5840,6 @@ SILValue ADContext::promoteToDifferentiableFunction( // Return nullptr. if (!assocFnAndIndices) return nullptr; - assert(assocFnAndIndices->second == desiredIndices && - "FIXME: We could emit a thunk that converts the VJP to have the " - "desired indices."); auto assocFn = assocFnAndIndices->first; builder.createRetainValue(loc, assocFn, builder.getDefaultAtomicity()); assocFns.push_back(assocFn); diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 25bdcffd63eb9..76ec9d911e1db 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -653,7 +653,7 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, uint64_t jvpNameId; uint64_t vjpNameId; - uint64_t source; + unsigned source; ArrayRef rawParameterIndices; SmallVector requirements; @@ -662,6 +662,7 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, StringRef jvpName = MF->getIdentifier(jvpNameId).str(); StringRef vjpName = MF->getIdentifier(vjpNameId).str(); + SmallVector parameterIndices(rawParameterIndices.begin(), rawParameterIndices.end()); auto *parameterIndexSubset = AutoDiffIndexSubset::get( diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index d9ec58038f2ef..bc0986b037629 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -317,10 +317,10 @@ namespace sil_block { // SWIFT_ENABLE_TENSORFLOW using SILDifferentiableAttrLayout = BCRecordLayout< SIL_DIFFERENTIABLE_ATTR, - IdentifierIDField, // JVP name. - IdentifierIDField, // VJP name. - BCFixed<32>, // Indices' source. - BCArray> // Indices' parameters bitvector. + IdentifierIDField, // JVP name. + IdentifierIDField, // VJP name. + BCVBR<8>, // Result index. + BCArray // Parameter indices. >; // Has an optional argument list where each argument is a typed valueref. diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 3ac5f9d5b2b81..ef05131b34092 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -437,11 +437,9 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) { assert(DA->hasJVP() && DA->hasVJP() && "JVP and VJP must exist in canonical SIL"); - auto ¶mIndices = DA->getIndices(); - SmallVector parameters; - for (unsigned i : range(paramIndices.parameters->getCapacity())) - parameters.push_back(paramIndices.parameters->contains(i)); - + auto &indices = DA->getIndices(); + SmallVector parameters(indices.parameters->begin(), + indices.parameters->end()); SILDifferentiableAttrLayout::emitRecord( Out, ScratchRecord, differentiableAttrAbbrCode, DA->hasJVP() @@ -450,7 +448,7 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) { DA->hasVJP() ? S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getVJPName())) : IdentifierID(), - paramIndices.source, parameters); + indices.source, parameters); S.writeGenericRequirements(DA->getRequirements(), SILAbbrCodes); } diff --git a/test/AutoDiff/currying.swift b/test/AutoDiff/currying.swift index 08bd469d40353..8cd47d997a94f 100644 --- a/test/AutoDiff/currying.swift +++ b/test/AutoDiff/currying.swift @@ -7,7 +7,7 @@ var CurryingAutodiffTests = TestSuite("CurryingAutodiff") CurryingAutodiffTests.test("StructMember") { struct A { @differentiable(wrt: (value)) - func v(_ value: Float) -> Float { return value * value } + func v(_ value: Float) -> Float { return value * value } } let a = A() diff --git a/unittests/AST/SILAutoDiffIndices.cpp b/unittests/AST/SILAutoDiffIndices.cpp index c71706abb219a..b26f9393a2280 100644 --- a/unittests/AST/SILAutoDiffIndices.cpp +++ b/unittests/AST/SILAutoDiffIndices.cpp @@ -12,71 +12,115 @@ // SWIFT_ENABLE_TENSORFLOW #include "swift/AST/AutoDiff.h" +#include "TestContext.h" #include "gtest/gtest.h" using namespace swift; +using namespace swift::unittest; -TEST(SILAutoDiffIndices, EqualityAndHash) { - using IndicesDenseMapInfo = llvm::DenseMapInfo; - - std::array empty; - // Each example is distinct. - SILAutoDiffIndices examples[] = { - {0, empty}, - {1, empty}, - {0, {0}}, - {0, {0, 1}}, - {0, {1}}, - {0, {1, 2}}, - {0, {100}}, - {0, {0, 100}}, - {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}} - }; - size_t exampleCount = std::extent::value; - for (size_t i = 0; i < exampleCount; ++i) { - auto example1 = examples[i]; - auto grownExample1 = example1; - grownExample1.parameters.resize(grownExample1.parameters.size() + 1); - - // Make sure that the grown example is actually grown. - EXPECT_TRUE(example1.parameters.size() < grownExample1.parameters.size()); - - // Test that the example is equal to itself and to the grown version of - // itself, using both operator== and IndicesDenseMapInfo::isEqual. - EXPECT_TRUE(example1 == example1); - EXPECT_TRUE(example1 == grownExample1); - EXPECT_TRUE(grownExample1 == example1); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(example1, example1)); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(example1, grownExample1)); - EXPECT_TRUE(IndicesDenseMapInfo::isEqual(grownExample1, example1)); +TEST(AutoDiffIndexSubset, NumBitWordsNeeded) { + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(0), 0u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(1), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity(5), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord - 1), 1u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord), 2u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord * 2 - 1), 2u); + EXPECT_EQ(AutoDiffIndexSubset::getNumBitWordsNeededForCapacity( + AutoDiffIndexSubset::numBitsPerBitWord * 2), 3u); +} - // Test that the grown version has the same hash as the original. - EXPECT_EQ(IndicesDenseMapInfo::getHashValue(example1), - IndicesDenseMapInfo::getHashValue(grownExample1)); +TEST(AutoDiffIndexSubset, BitWordIndexAndOffset) { + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(0), + std::make_pair(0u, 0u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(5), + std::make_pair(0u, 5u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset(8), + std::make_pair(0u, 8u)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset( + AutoDiffIndexSubset::numBitsPerBitWord - 1), + std::make_pair(0u, AutoDiffIndexSubset::numBitsPerBitWord - 1)); + EXPECT_EQ(AutoDiffIndexSubset::getBitWordIndexAndOffset( + AutoDiffIndexSubset::numBitsPerBitWord), + std::make_pair(1u, 0u)); +} - // Test that the example is not equal to any of the others. - for (size_t j = i + 1; j < exampleCount; ++j) { - auto example2 = examples[j]; - auto grownExample2 = example2; - grownExample2.parameters.resize(grownExample2.parameters.size() + 1); +TEST(AutoDiffIndexSubset, Bits) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + EXPECT_EQ(indices1->getNumBitWords(), 1u); + EXPECT_EQ(indices1->getCapacity(), 5u); + EXPECT_TRUE(indices1->contains(0)); + EXPECT_FALSE(indices1->contains(1)); + EXPECT_TRUE(indices1->contains(2)); + EXPECT_FALSE(indices1->contains(3)); + EXPECT_TRUE(indices1->contains(4)); - // Make sure that the grown example is actually grown. - EXPECT_TRUE(example2.parameters.size() < grownExample2.parameters.size()); + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {1, 3}); + EXPECT_EQ(indices2->getNumBitWords(), 1u); + EXPECT_EQ(indices2->getCapacity(), 5u); + EXPECT_FALSE(indices2->contains(0)); + EXPECT_TRUE(indices2->contains(1)); + EXPECT_FALSE(indices2->contains(2)); + EXPECT_TRUE(indices2->contains(3)); + EXPECT_FALSE(indices2->contains(4)); +} - EXPECT_FALSE(example1 == example2); - EXPECT_FALSE(example2 == example1); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example1, example2)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example2, example1)); +TEST(AutoDiffIndexSubset, Iteration) { + TestContext ctx; + // Test 1 + { + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + // Check forward iteration. + EXPECT_EQ(indices1->findFirst(), 0); + EXPECT_EQ(indices1->findNext(0), 2); + EXPECT_EQ(indices1->findNext(2), 4); + EXPECT_EQ(indices1->findNext(4), (int)indices1->getCapacity()); + // Check reverse iteration. + EXPECT_EQ(indices1->findLast(), 4); + EXPECT_EQ(indices1->findPrevious(4), 2); + EXPECT_EQ(indices1->findPrevious(2), 0); + EXPECT_EQ(indices1->findPrevious(0), -1); + // Check range. + unsigned indices1Elements[3] = {0, 2, 4}; + EXPECT_TRUE(std::equal(indices1->begin(), indices1->end(), + indices1Elements)); + } + // Test 2 + { + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {1, 3}); + // Check forward iteration. + EXPECT_EQ(indices2->findFirst(), 1); + EXPECT_EQ(indices2->findNext(1), 3); + EXPECT_EQ(indices2->findNext(3), (int)indices2->getCapacity()); + // Check reverse iteration. + EXPECT_EQ(indices2->findLast(), 3); + EXPECT_EQ(indices2->findPrevious(3), 1); + EXPECT_EQ(indices2->findPrevious(1), -1); + // Check range. + unsigned indices2Elements[2] = {1, 3}; + EXPECT_TRUE(std::equal(indices2->begin(), indices2->end(), + indices2Elements)); + } +} - EXPECT_FALSE(example1 == grownExample2); - EXPECT_FALSE(grownExample2 == example1); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example1, grownExample2)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(grownExample2, example1)); +TEST(AutoDiffIndexSubset, SupersetAndSubset) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}); + EXPECT_TRUE(indices1->isSupersetOf(indices1)); + EXPECT_TRUE(indices1->isSubsetOf(indices1)); + auto *indices2 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {2}); + EXPECT_TRUE(indices2->isSupersetOf(indices2)); + EXPECT_TRUE(indices2->isSubsetOf(indices2)); - EXPECT_FALSE(example2 == grownExample1); - EXPECT_FALSE(grownExample1 == example2); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(example2, grownExample1)); - EXPECT_FALSE(IndicesDenseMapInfo::isEqual(grownExample1, example2)); - } - } + EXPECT_TRUE(indices1->isSupersetOf(indices2)); + EXPECT_TRUE(indices2->isSubsetOf(indices1)); } From 0d8a6348326115a68942ffeb01cc8b750858a7a0 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 14 May 2019 04:44:22 -0700 Subject: [PATCH 6/8] Add equality test. --- unittests/AST/SILAutoDiffIndices.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/unittests/AST/SILAutoDiffIndices.cpp b/unittests/AST/SILAutoDiffIndices.cpp index b26f9393a2280..40be6b95ce308 100644 --- a/unittests/AST/SILAutoDiffIndices.cpp +++ b/unittests/AST/SILAutoDiffIndices.cpp @@ -47,6 +47,30 @@ TEST(AutoDiffIndexSubset, BitWordIndexAndOffset) { std::make_pair(1u, 0u)); } +TEST(AutoDiffIndexSubset, Equality) { + TestContext ctx; + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0})); + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0, 2, 4})); + EXPECT_EQ(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {})); + EXPECT_NE(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 1, + /*indices*/ {}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 0, + /*indices*/ {})); + EXPECT_NE(AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {0}), + AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, + /*indices*/ {})); +} + TEST(AutoDiffIndexSubset, Bits) { TestContext ctx; auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, /*capacity*/ 5, From f79d5396282bfd6e32a3746201ed9e7d4b741f4d Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 14 May 2019 05:02:46 -0700 Subject: [PATCH 7/8] Add tests for `adding` and fix bugs. --- include/swift/AST/AutoDiff.h | 48 ++++------------------------ lib/AST/AutoDiff.cpp | 45 ++++++++++++++++++++++++++ lib/Serialization/DeserializeSIL.cpp | 2 +- unittests/AST/SILAutoDiffIndices.cpp | 10 ++++++ 4 files changed, 62 insertions(+), 43 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 014dd98398820..3511787e20d3b 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -344,53 +344,17 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; }); } - bool equals(const AutoDiffIndexSubset *other) const { + bool equals(AutoDiffIndexSubset *other) const { return capacity == other->getCapacity() && getBitWords().equals(other->getBitWords()); } - bool isSubsetOf(const AutoDiffIndexSubset *other) const { - assert(capacity == other->capacity); - for (auto index : range(numBitWords)) - if (getBitWord(index) & ~other->getBitWord(index)) - return false; - return true; - } - - bool isSupersetOf(const AutoDiffIndexSubset *other) const { - assert(capacity == other->capacity); - for (auto index : range(numBitWords)) - if (~getBitWord(index) & other->getBitWord(index)) - return false; - return true; - } - - AutoDiffIndexSubset *adding( - unsigned index, ASTContext &ctx) const { - assert(index < getCapacity()); - SmallVector newIndices; - newIndices.reserve(capacity + 1); - bool inserted = false; - for (auto curIndex : getIndices()) { - if (inserted && curIndex > index) { - newIndices.push_back(index); - inserted = false; - } - newIndices.push_back(curIndex); - } - return get(ctx, capacity, newIndices); - } + bool isSubsetOf(AutoDiffIndexSubset *other) const; + bool isSupersetOf(AutoDiffIndexSubset *other) const; - AutoDiffIndexSubset *extendingCapacity( - ASTContext &ctx, unsigned newCapacity) const { - assert(newCapacity >= capacity); - if (newCapacity == capacity) - return const_cast(this); - SmallVector indices; - for (auto index : getIndices()) - indices.push_back(index); - return AutoDiffIndexSubset::get(ctx, newCapacity, indices); - } + AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const; + AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx, + unsigned newCapacity) const; void Profile(llvm::FoldingSetNodeID &id) const { id.AddInteger(capacity); diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 11f77eebe524e..755015902e2dd 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -331,6 +331,51 @@ NominalTypeDecl *VectorSpace::getNominal() const { return getVector()->getNominalOrBoundGenericNominal(); } +bool AutoDiffIndexSubset::isSubsetOf(AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (getBitWord(index) & ~other->getBitWord(index)) + return false; + return true; +} + +bool AutoDiffIndexSubset::isSupersetOf(AutoDiffIndexSubset *other) const { + assert(capacity == other->capacity); + for (auto index : range(numBitWords)) + if (~getBitWord(index) & other->getBitWord(index)) + return false; + return true; +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::adding(unsigned index, + ASTContext &ctx) const { + assert(index < getCapacity()); + if (contains(index)) + return const_cast(this); + SmallVector newIndices; + newIndices.reserve(capacity + 1); + bool inserted = false; + for (auto curIndex : getIndices()) { + if (!inserted && curIndex > index) { + newIndices.push_back(index); + inserted = true; + } + newIndices.push_back(curIndex); + } + return get(ctx, capacity, newIndices); +} + +AutoDiffIndexSubset *AutoDiffIndexSubset::extendingCapacity( + ASTContext &ctx, unsigned newCapacity) const { + assert(newCapacity >= capacity); + if (newCapacity == capacity) + return const_cast(this); + SmallVector indices; + for (auto index : getIndices()) + indices.push_back(index); + return AutoDiffIndexSubset::get(ctx, newCapacity, indices); +} + int AutoDiffIndexSubset::findNext(int startIndex) const { assert(startIndex < (int)capacity && "Start index cannot be past the end"); unsigned bitWordIndex = 0, offset = 0; diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 76ec9d911e1db..574c2db237d36 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1510,7 +1510,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto numParamIndices = ListOfValues.size() - NumArguments * 3; auto rawParamIndices = map>(ListOfValues.take_front(numParamIndices), - [&](uint64_t i) { return i; }); + [](uint64_t i) { return (unsigned)i; }); auto numParams = Attr2; auto *paramIndices = AutoDiffIndexSubset::get(MF->getContext(), numParams, rawParamIndices); diff --git a/unittests/AST/SILAutoDiffIndices.cpp b/unittests/AST/SILAutoDiffIndices.cpp index 40be6b95ce308..0dbaa63ec75d2 100644 --- a/unittests/AST/SILAutoDiffIndices.cpp +++ b/unittests/AST/SILAutoDiffIndices.cpp @@ -148,3 +148,13 @@ TEST(AutoDiffIndexSubset, SupersetAndSubset) { EXPECT_TRUE(indices1->isSupersetOf(indices2)); EXPECT_TRUE(indices2->isSubsetOf(indices1)); } + +TEST(AutoDiffIndexSubset, Insertion) { + TestContext ctx; + auto *indices1 = AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 2, 4}); + EXPECT_EQ(indices1->adding(0, ctx.Ctx), indices1); + EXPECT_EQ(indices1->adding(1, ctx.Ctx), + AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 1, 2, 4})); + EXPECT_EQ(indices1->adding(3, ctx.Ctx), + AutoDiffIndexSubset::get(ctx.Ctx, 5, {0, 2, 3, 4})); +} From 659c430c6d729c12f44f38ffcf3448c3c6320aeb Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 14 May 2019 05:06:13 -0700 Subject: [PATCH 8/8] Typo fixes in comments. --- include/swift/AST/AutoDiff.h | 2 +- include/swift/AST/DiagnosticsParse.def | 2 +- lib/SIL/SILVerifier.cpp | 3 +-- lib/SILOptimizer/Mandatory/Differentiation.cpp | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 3511787e20d3b..c47e0e835d71d 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -245,7 +245,7 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { /// The total capacity of the index subset, which is `1` less than the largest /// index. unsigned capacity; - /// The number of bit words in the index subset. in the index subset. + /// The number of bit words in the index subset. unsigned numBitWords; BitWord *getBitWordsData() { diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index f06386b1cf8c7..60b3e1afa8017 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1535,7 +1535,7 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken, ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken, "the number of operand lists does not match the order", ()) ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken, - "expected an assoiacted function kind attribute, e.g. '[jvp]'", ()) + "expected an associated function kind attribute, e.g. '[jvp]'", ()) ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, "expected an operand of a function type", ()) diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 64a5ad898ca75..1f2c461be29b1 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -4617,8 +4617,7 @@ class SILVerifier : public SILVerifierBase { lastIndex = currentIdx; } // TODO: Verify if the specified JVP/VJP function has the right signature. - // SIL function verification runs right after a function is - // parsed. + // SIL function verification runs right after a function is parsed. // However, the JVP/VJP function may come after the this function. Without // changing the compiler too much, is there a way to verify this at a module // level, after everything is parsed? diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index a082c3910e229..2985bd64e9390 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -943,8 +943,8 @@ class ADContext { for (auto *rda : original->getDifferentiableAttrs()) { auto *rdaIndexSet = rda->getIndices().parameters; // If all indices in `indexSet` are in `rdaIndexSet`, and it has fewer - // indices than our current candidate and a primitive VJP, `rda` is our - // new candidate. + // indices than our current candidate and a primitive VJP, then `rda` is + // our new candidate. // // NOTE: `rda` may come from a un-partial-applied function, it may have // more parameters than the desired indices. We expect this logic to go