From 7e5d804ce7fa59277530cd2cc536dd8e889bf692 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 28 Jun 2019 02:28:30 -0700 Subject: [PATCH 01/26] Diff witness --- .../swift/SIL/SILDifferentibilityWitness.h | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 include/swift/SIL/SILDifferentibilityWitness.h diff --git a/include/swift/SIL/SILDifferentibilityWitness.h b/include/swift/SIL/SILDifferentibilityWitness.h new file mode 100644 index 0000000000000..93794ae63b820 --- /dev/null +++ b/include/swift/SIL/SILDifferentibilityWitness.h @@ -0,0 +1,108 @@ +//===--- SILProperty.h - Defines the SILProperty class ----------*- C++ -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file defines the SILProperty class, which is used to capture the +// metadata about a property definition necessary for it to be resiliently +// included in KeyPaths across modules. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H +#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H + +#include "swift/AST/AutoDiff.h" +#include "swift/AST/GenericSignature.h" +#include "swift/SIL/SILAllocated.h" +#include "swift/SIL/SILInstruction.h" +#include "llvm/ADT/ilist_node.h" +#include "llvm/ADT/ilist.h" + +namespace swift { + +class SILPrintContext; + +class SILDifferentiabilityWitness + : public llvm::ilist_node, + public SILAllocated +{ +private: + /// The module which contains the SILWitnessTable. + SILModule &module; + /// The original function. + SILFunction *originalFunction; + /// The parameter indieces. + AutoDiffIndexSubset *parameterIndices; + /// The result indieces. + AutoDiffIndexSubset *resultIndices; + /// The max differentiation order. + unsigned maxOrder; + /// Derivative functions. + MutableArrayRef derivatives; + /// True if serialized. + bool serialized; + + SILDifferentiabilityWitness(SILModule &module, + SILFunction *originalFunction, + AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, + bool isSerialized) + : moduel(module), originalFunction(originalFunction), + parameterIndices(parameterIndices), resultIndices(resultIndices), + serialized(isSerialized) {} + +public: + static SILProperty *create(SILModule &M, + bool Serialized, + AbstractStorageDecl *Decl, + Optional Component); + + bool isSerialized() const { return Serialized; } + + AbstractStorageDecl *getDecl() const { return Decl; } + + bool isTrivial() const { + return !Component.hasValue(); + } + + const Optional &getComponent() const { + return Component; + } + + void print(SILPrintContext &Ctx) const; + void dump() const; + + void verify(const SILModule &M) const; +}; + +} // end namespace swift + +namespace llvm { + +//===----------------------------------------------------------------------===// +// ilist_traits for SILProperty +//===----------------------------------------------------------------------===// + +template <> +struct ilist_traits<::swift::SILProperty> + : public ilist_node_traits<::swift::SILProperty> { + using SILProperty = ::swift::SILProperty; + +public: + static void deleteNode(SILProperty *VT) { VT->~SILProperty(); } + +private: + void createNode(const SILProperty &); +}; + +} // namespace llvm + +#endif From a5c1d13cc5501baf061af781d62f996482135ef8 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 1 Oct 2019 21:46:12 -0700 Subject: [PATCH 02/26] Rename SILDifferentiabilityWitness.h. --- ...SILDifferentibilityWitness.h => SILDifferentiabilityWitness.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/swift/SIL/{SILDifferentibilityWitness.h => SILDifferentiabilityWitness.h} (100%) diff --git a/include/swift/SIL/SILDifferentibilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h similarity index 100% rename from include/swift/SIL/SILDifferentibilityWitness.h rename to include/swift/SIL/SILDifferentiabilityWitness.h From 0ce7d9128676f14d0f8543991aa807ee0f88cf8b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 1 Oct 2019 21:48:09 -0700 Subject: [PATCH 03/26] Update SILDifferentiabilityWitness definition. --- .../swift/SIL/SILDifferentiabilityWitness.h | 95 +++++++++++-------- lib/SIL/CMakeLists.txt | 2 + lib/SIL/SILDifferentiabilityWitness.cpp | 32 +++++++ 3 files changed, 88 insertions(+), 41 deletions(-) create mode 100644 lib/SIL/SILDifferentiabilityWitness.cpp diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index 93794ae63b820..da71d78799e36 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -1,4 +1,4 @@ -//===--- SILProperty.h - Defines the SILProperty class ----------*- C++ -*-===// +//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===// // // This source file is part of the Swift.org open source project // @@ -10,9 +10,16 @@ // //===----------------------------------------------------------------------===// // -// This file defines the SILProperty class, which is used to capture the -// metadata about a property definition necessary for it to be resiliently -// included in KeyPaths across modules. +// This file defines the SILDifferentiabilityWitness class, which maps an +// original SILFunction and derivative configuration (parameter indices, result +// indices, derivative generic signature) to derivative functions (JVP and VJP). +// +// SIL differentiability witnesses are generated from the `@differentiable` +// and `@differentiating` attributes AST declaration attributes. +// Differentiability witnesses are canonicalized by the differentiation SIL +// transform, which fills in missing derivative functions. Canonical +// differentiability witnesses from other modules can be deserialized to look up +// derivative functions. // //===----------------------------------------------------------------------===// @@ -35,52 +42,56 @@ class SILDifferentiabilityWitness public SILAllocated { private: - /// The module which contains the SILWitnessTable. + /// The module which contains the SIL differentiability witness. SILModule &module; /// The original function. SILFunction *originalFunction; - /// The parameter indieces. + /// The parameter indices. AutoDiffIndexSubset *parameterIndices; - /// The result indieces. + /// The result indices. AutoDiffIndexSubset *resultIndices; - /// The max differentiation order. - unsigned maxOrder; - /// Derivative functions. - MutableArrayRef derivatives; - /// True if serialized. + /// The derivative generic signature (optional). + GenericSignature *derivativeGenericSignature; + /// The JVP (Jacobian-vector products) derivative function. + SILFunction *jvp; + /// The VJP (vector-Jacobian products) derivative function. + SILFunction *vjp; + /// Whether or not this differentiability witness is serialized, which allows + /// devirtualization from another module. bool serialized; - SILDifferentiabilityWitness(SILModule &module, - SILFunction *originalFunction, + SILDifferentiabilityWitness(SILModule &module, SILFunction *originalFunction, AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + GenericSignature *derivativeGenSig, + SILFunction *jvp, SILFunction *vjp, bool isSerialized) - : moduel(module), originalFunction(originalFunction), + : module(module), originalFunction(originalFunction), parameterIndices(parameterIndices), resultIndices(resultIndices), + derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), serialized(isSerialized) {} public: - static SILProperty *create(SILModule &M, - bool Serialized, - AbstractStorageDecl *Decl, - Optional Component); - - bool isSerialized() const { return Serialized; } - - AbstractStorageDecl *getDecl() const { return Decl; } - - bool isTrivial() const { - return !Component.hasValue(); + SILModule &getModule() const { return module; } + SILFunction *getOriginalFunction() const { return originalFunction; } + AutoDiffIndexSubset *getParameterIndices() const { + return parameterIndices; + } + AutoDiffIndexSubset *getResultIndices() const { + return resultIndices; } - - const Optional &getComponent() const { - return Component; + GenericSignature *getDerivativeGenericSignature() const { + return derivativeGenericSignature; } - - void print(SILPrintContext &Ctx) const; - void dump() const; - - void verify(const SILModule &M) const; + SILFunction *getJVP() const { return jvp; } + SILFunction *getVJP() const { return vjp; } + bool isSerialized() const { return serialized; } + + static SILDifferentiabilityWitness *create( + SILModule &module, SILFunction *originalFunction, + AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, + bool isSerialized); }; } // end namespace swift @@ -88,21 +99,23 @@ class SILDifferentiabilityWitness namespace llvm { //===----------------------------------------------------------------------===// -// ilist_traits for SILProperty +// ilist_traits for SILDifferentiabilityWitness //===----------------------------------------------------------------------===// template <> -struct ilist_traits<::swift::SILProperty> - : public ilist_node_traits<::swift::SILProperty> { - using SILProperty = ::swift::SILProperty; +struct ilist_traits<::swift::SILDifferentiabilityWitness> + : public ilist_node_traits<::swift::SILDifferentiabilityWitness> { + using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness; public: - static void deleteNode(SILProperty *VT) { VT->~SILProperty(); } + static void deleteNode(SILDifferentiabilityWitness *DW) { + DW->~SILDifferentiabilityWitness(); + } private: - void createNode(const SILProperty &); + void createNode(const SILDifferentiabilityWitness &); }; } // namespace llvm -#endif +#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H diff --git a/lib/SIL/CMakeLists.txt b/lib/SIL/CMakeLists.txt index 2657fcfdfb29c..b65a01c26f683 100644 --- a/lib/SIL/CMakeLists.txt +++ b/lib/SIL/CMakeLists.txt @@ -26,6 +26,8 @@ add_swift_host_library(swiftSIL STATIC SILDebugScope.cpp SILDeclRef.cpp SILDefaultWitnessTable.cpp + # SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness.cpp SILFunction.cpp SILFunctionType.cpp SILGlobalVariable.cpp diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp new file mode 100644 index 0000000000000..e0bb329aa569b --- /dev/null +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -0,0 +1,32 @@ +//===--- SILDifferentiabilityWitness.cpp - Differentiability witnesses ----===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "sil-differentiability-witness" + +#include "swift/SIL/SILDifferentiabilityWitness.h" +#include "swift/SIL/SILModule.h" + +using namespace swift; + +SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( + SILModule &module, SILFunction *originalFunction, + AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, + bool isSerialized) { + void *buf = module.allocate(sizeof(SILDifferentiabilityWitness), + alignof(SILDifferentiabilityWitness)); + SILDifferentiabilityWitness *dw = ::new (buf) + SILDifferentiabilityWitness(module, originalFunction, parameterIndices, + resultIndices, derivativeGenSig, jvp, vjp, + isSerialized); + return dw; +} From 471e660321883bc5eb3bad0b73d40cf154909061 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 2 Oct 2019 00:33:10 -0700 Subject: [PATCH 04/26] Add SILDifferentiabilityWitness to SILModule. --- include/swift/SIL/SILModule.h | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index 802bb85041b32..b6a3bab78e252 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -28,6 +28,8 @@ #include "swift/SIL/SILCoverageMap.h" #include "swift/SIL/SILDeclRef.h" #include "swift/SIL/SILDefaultWitnessTable.h" +// SWIFT_ENABLE_TENSORFLOW +#include "swift/SIL/SILDifferentiabilityWitness.h" #include "swift/SIL/SILFunction.h" #include "swift/SIL/SILGlobalVariable.h" #include "swift/SIL/SILPrintContext.h" @@ -113,6 +115,10 @@ class SILModule { using PropertyListType = llvm::ilist; using WitnessTableListType = llvm::ilist; using DefaultWitnessTableListType = llvm::ilist; + // SWIFT_ENABLE_TENSORFLOW + using DifferentiabilityWitnessListType = + llvm::ilist; + // SWIFT_ENABLE_TENSORFLOW END using CoverageMapCollectionType = llvm::MapVector; @@ -194,6 +200,19 @@ class SILModule { /// The list of SILDefaultWitnessTables in the module. DefaultWitnessTableListType defaultWitnessTables; + // SWIFT_ENABLE_TENSORFLOW + /// Lookup table for SIL differentiability witnesses from original functions. + /// Indexed by original function, parameter indices, result indices, and + /// derivative generic signature. + llvm::DenseMap, + SILDifferentiabilityWitness *> + DifferentiabilityWitnessMap; + + /// The list of SILDifferentiabilityWitnesses in the module. + DifferentiabilityWitnessListType differentiabilityWitnesses; + // SWIFT_ENABLE_TENSORFLOW END + /// Lookup table for SIL Global Variables. llvm::StringMap GlobalVariableMap; @@ -446,6 +465,27 @@ class SILModule { return {defaultWitnessTables.begin(), defaultWitnessTables.end()}; } + // SWIFT_ENABLE_TENSORFLOW + using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator; + using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator; + DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; } + const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; } + differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); } + differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); } + differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); } + differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); } + iterator_range + getDifferentiabilityWitnesses() { + return {differentiabilityWitnesses.begin(), + differentiabilityWitnesses.end()}; + } + iterator_range + getDifferentiabilityWitnesses() const { + return {differentiabilityWitnesses.begin(), + differentiabilityWitnesses.end()}; + } + // SWIFT_ENABLE_TENSORFLOW END + using sil_global_iterator = GlobalListType::iterator; using sil_global_const_iterator = GlobalListType::const_iterator; GlobalListType &getSILGlobalList() { return silGlobals; } From bf87d2e46eeb729810f695c03db5f77e0750c22f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 2 Oct 2019 00:33:39 -0700 Subject: [PATCH 05/26] [WIP] Start SILDifferentiabilityWitness parsing/printing. - Printing compiles but may require changes to be parseable. Namely, if there is no suitable utility for parsing standalone generic signatures, changes are needed. - Parsing is a stub. - Note: it is difficult to test parsing/printing without generating SILDifferentiabilityWitness instances. --- include/swift/Parse/ParseSILSupport.h | 3 + .../swift/SIL/SILDifferentiabilityWitness.h | 3 + include/swift/Syntax/TokenKinds.def.gyb | 3 + lib/ParseSIL/ParseSIL.cpp | 19 +++++ lib/SIL/SILPrinter.cpp | 74 +++++++++++++++++++ lib/Syntax/SyntaxSerialization.cpp.gyb | 3 + 6 files changed, 105 insertions(+) diff --git a/include/swift/Parse/ParseSILSupport.h b/include/swift/Parse/ParseSILSupport.h index 00eaba3e80caf..e3e2225c12df0 100644 --- a/include/swift/Parse/ParseSILSupport.h +++ b/include/swift/Parse/ParseSILSupport.h @@ -32,6 +32,9 @@ namespace swift { virtual bool parseSILGlobal(Parser &P) = 0; virtual bool parseSILWitnessTable(Parser &P) = 0; virtual bool parseSILDefaultWitnessTable(Parser &P) = 0; + // SWIFT_ENABLE_TENSORFLOW + virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0; + // SWIFT_ENABLE_TENSORFLOW END virtual bool parseSILCoverageMap(Parser &P) = 0; virtual bool parseSILProperty(Parser &P) = 0; virtual bool parseSILScope(Parser &P) = 0; diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index da71d78799e36..b1d65ba535cc1 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -92,6 +92,9 @@ class SILDifferentiabilityWitness AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized); + + void print(llvm::raw_ostream &OS, bool verbose = false) const; + void dump() const; }; } // end namespace swift diff --git a/include/swift/Syntax/TokenKinds.def.gyb b/include/swift/Syntax/TokenKinds.def.gyb index 4627055832e9c..ee20521c099fe 100644 --- a/include/swift/Syntax/TokenKinds.def.gyb +++ b/include/swift/Syntax/TokenKinds.def.gyb @@ -165,6 +165,9 @@ SIL_KEYWORD(sil_vtable) SIL_KEYWORD(sil_global) SIL_KEYWORD(sil_witness_table) SIL_KEYWORD(sil_default_witness_table) +// SWIFT_ENABLE_TENSORFLOW +SIL_KEYWORD(sil_differentiability_witness) +// SWIFT_ENABLE_TENSORFLOW END SIL_KEYWORD(sil_coverage_map) SIL_KEYWORD(sil_scope) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 3e83587606ec1..5541f3a70fc39 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -77,6 +77,9 @@ class SILParserTUState : public SILParserTUStateBase { bool parseSILGlobal(Parser &P) override; bool parseSILWitnessTable(Parser &P) override; bool parseSILDefaultWitnessTable(Parser &P) override; + // SWIFT_ENABLE_TENSORFLOW + bool parseSILDifferentiabilityWitness(Parser &P) override; + // SWIFT_ENABLE_TENSORFLOW END bool parseSILCoverageMap(Parser &P) override; bool parseSILProperty(Parser &P) override; bool parseSILScope(Parser &P) override; @@ -6793,6 +6796,22 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) { return false; } +/// decl-sil-differentiability-witness ::= +/// 'sil_differentiability_witness' +/// sil-function-name +/// 'wrt' autodiff-index-subset +/// 'sources' autodiff-index-subset +/// ('derivative_generic_signature' generic-signature)? +/// '{' ('jvp' sil-function-name)? ('vjp' sil-function-name)? '}' +/// +/// autodiff-index-subset ::= +/// [0-9]+ (',', [0-9]+)* +bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { + P.consumeToken(tok::kw_sil_differentiability_witness); + // TODO(TF-867): Implement parsing. Test round-tripping with printing. + return false; +} + llvm::Optional SILParser::parseSILCoverageExpr( llvm::coverage::CounterExpressionBuilder &Builder) { if (P.Tok.is(tok::integer_literal)) { diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 13cf14af620b9..232e9b78c2e8e 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2704,6 +2704,32 @@ printSILDefaultWitnessTables(SILPrintContext &Ctx, wt->print(Ctx.OS(), Ctx.printVerbose()); } +// SWIFT_ENABLE_TENSORFLOW +static void printSILDifferentiabilityWitnesses( + SILPrintContext &Ctx, + const SILModule::DifferentiabilityWitnessListType &diffWitnesses) { + if (!Ctx.sortSIL()) { + for (auto &dw : diffWitnesses) + dw.print(Ctx.OS(), Ctx.printVerbose()); + return; + } + + std::vector sortedDiffWitnesses; + sortedDiffWitnesses.reserve(diffWitnesses.size()); + for (auto &dw : diffWitnesses) + sortedDiffWitnesses.push_back(&dw); + std::sort(sortedDiffWitnesses.begin(), sortedDiffWitnesses.end(), + [] (const SILDifferentiabilityWitness *w1, + const SILDifferentiabilityWitness *w2) -> bool { + return w1->getOriginalFunction()->getName() + .compare(w2->getOriginalFunction()->getName()); + } + ); + for (auto *dw : sortedDiffWitnesses) + dw->print(Ctx.OS(), Ctx.printVerbose()); +} +// SWIFT_ENABLE_TENSORFLOW END + static void printSILCoverageMaps(SILPrintContext &Ctx, const SILModule::CoverageMapCollectionType &CoverageMaps) { @@ -2821,6 +2847,10 @@ void SILModule::print(SILPrintContext &PrintCtx, ModuleDecl *M, printSILVTables(PrintCtx, getVTableList()); printSILWitnessTables(PrintCtx, getWitnessTableList()); printSILDefaultWitnessTables(PrintCtx, getDefaultWitnessTableList()); + // SWIFT_ENABLE_TENSORFLOW + printSILDifferentiabilityWitnesses(PrintCtx, + getDifferentiabilityWitnessList()); + // SWIFT_ENABLE_TENSORFLOW END printSILCoverageMaps(PrintCtx, getCoverageMaps()); printSILProperties(PrintCtx, getPropertyList()); @@ -3035,6 +3065,50 @@ void SILDefaultWitnessTable::dump() const { print(llvm::errs()); } +// SWIFT_ENABLE_TENSORFLOW +void SILDifferentiabilityWitness::print( + llvm::raw_ostream &OS, bool verbose) const { + // TODO(TF-867): Test SIL differentiability witness round-trip printing and + // parsing. It is currently untested and certainly broken. + + // sil_differentiability_witness @original-function-name + PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); + OS << "sil_differentiability_witness @" << originalFunction->getName(); + // wrt 0, 1, ... + OS << " wrt "; + interleave(parameterIndices->getIndices(), + [&](unsigned index) { OS << index; }, + [&] { OS << ", "; }); + // sources 0, 1, ... + OS << " sources "; + interleave(resultIndices->getIndices(), + [&](unsigned index) { OS << index; }, + [&] { OS << ", "; }); + // wrt 0, 1, ... + if (derivativeGenericSignature) { + OS << " derivative_generic_signature "; + // NOTE: This needs to be changed if there is no utility for parsing + // generic signatures. Idea: we could instead print the type of the original + // function substituted into this generic signature. + derivativeGenericSignature->print(OS); + } + // { + // jvp: @jvp-function-name + // vjp: @vjp-function-name + // } + OS << " {\n"; + if (jvp) + OS << " jvp: @" << jvp->getName() << "\n"; + if (vjp) + OS << " vjp: @" << vjp->getName() << "\n"; + OS << "}\n\n"; +} + +void SILDifferentiabilityWitness::dump() const { + print(llvm::errs()); +} +// SWIFT_ENABLE_TENSORFLOW END + void SILCoverageMap::print(SILPrintContext &PrintCtx) const { llvm::raw_ostream &OS = PrintCtx.OS(); OS << "sil_coverage_map " << QuotedString(getFile()) << " " diff --git a/lib/Syntax/SyntaxSerialization.cpp.gyb b/lib/Syntax/SyntaxSerialization.cpp.gyb index 9c917abedc135..46f77413cbf22 100644 --- a/lib/Syntax/SyntaxSerialization.cpp.gyb +++ b/lib/Syntax/SyntaxSerialization.cpp.gyb @@ -48,6 +48,9 @@ uint8_t WrapperTypeTraits::numericValue(const tok &Value) { case tok::kw_sil_global: case tok::kw_sil_witness_table: case tok::kw_sil_default_witness_table: + // SWIFT_ENABLE_TENSORFLOW + case tok::kw_sil_differentiability_witness: + // SWIFT_ENABLE_TENSORFLOW END case tok::kw_sil_coverage_map: case tok::kw_sil_scope: case tok::sil_dollar: From 6b7868424491670963673212a44c111b406433eb Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 2 Oct 2019 09:39:06 -0700 Subject: [PATCH 06/26] Use improved syntax. `parameters (0, 1, ...) results (0, 1, ...) where <...>` --- lib/SIL/SILPrinter.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 232e9b78c2e8e..aab259ac47f6c 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2721,6 +2721,7 @@ static void printSILDifferentiabilityWitnesses( std::sort(sortedDiffWitnesses.begin(), sortedDiffWitnesses.end(), [] (const SILDifferentiabilityWitness *w1, const SILDifferentiabilityWitness *w2) -> bool { + // TODO: Sort based on more criteria. return w1->getOriginalFunction()->getName() .compare(w2->getOriginalFunction()->getName()); } @@ -3074,19 +3075,20 @@ void SILDifferentiabilityWitness::print( // sil_differentiability_witness @original-function-name PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); OS << "sil_differentiability_witness @" << originalFunction->getName(); - // wrt 0, 1, ... - OS << " wrt "; + // parameters (0, 1, ...) + OS << " parameters ("; interleave(parameterIndices->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << ", "; }); - // sources 0, 1, ... - OS << " sources "; + // results (0, 1, ...) + OS << ") results ("; interleave(resultIndices->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << ", "; }); + OS << ')'; // wrt 0, 1, ... if (derivativeGenericSignature) { - OS << " derivative_generic_signature "; + OS << " where "; // NOTE: This needs to be changed if there is no utility for parsing // generic signatures. Idea: we could instead print the type of the original // function substituted into this generic signature. From a2ae0f208985a6acdb6e09ad3a7ebfe71a84da56 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 20:10:21 -0700 Subject: [PATCH 07/26] Finish parsing/printing/serialization. --- include/swift/AST/DiagnosticsParse.def | 14 ++ .../swift/SIL/SILDifferentiabilityWitness.h | 24 +- include/swift/SIL/SILModule.h | 10 +- .../swift/Serialization/SerializedSILLoader.h | 5 + lib/Parse/ParseDecl.cpp | 3 + lib/ParseSIL/ParseSIL.cpp | 218 +++++++++++++++++- lib/SIL/SILDifferentiabilityWitness.cpp | 18 +- lib/SIL/SILPrinter.cpp | 45 +++- lib/Serialization/DeserializeSIL.cpp | 90 +++++++- lib/Serialization/DeserializeSIL.h | 19 ++ lib/Serialization/SILFormat.h | 20 ++ lib/Serialization/Serialization.cpp | 3 + lib/Serialization/SerializeSIL.cpp | 71 ++++++ lib/Serialization/SerializedSILLoader.cpp | 7 + 14 files changed, 509 insertions(+), 38 deletions(-) diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 1df71ebb700a6..2aa113ee7ddda 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -686,6 +686,20 @@ ERROR(sil_witness_assoc_conf_not_found,none, ERROR(sil_witness_protocol_conformance_not_found,none, "sil protocol conformance not found", ()) +// [differentiable ...] (sil-decl attr) +ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken, + "expected '%0' in differentiability witness", (StringRef)) +ERROR(sil_diff_witness_expected_parameter_list,PointsToFirstBadToken, + "expected an comma-separated list of parameter indices, e.g. (0, 1)", ()) +ERROR(sil_diff_witness_expected_rsquare,PointsToFirstBadToken, + "expected ']' to end 'differentiable' attribute", ()) +ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken, + "expected the index of a parameter to differentiate w.r.t.", ()) +ERROR(sil_diff_witness_expected_source_index,PointsToFirstBadToken, + "expected the index of a result to differentiate from", ()) + +// SIL differentiability witnesses + // SIL Coverage Map ERROR(sil_coverage_func_not_found, none, "sil function not found %0", (Identifier)) diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index b1d65ba535cc1..f1a44197548ed 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -2,7 +2,7 @@ // // This source file is part of the Swift.org open source project // -// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information @@ -42,8 +42,10 @@ class SILDifferentiabilityWitness public SILAllocated { private: - /// The module which contains the SIL differentiability witness. + /// The module which contains the differentiability witness. SILModule &module; + /// The linkage of the differentiability witness. + SILLinkage linkage; /// The original function. SILFunction *originalFunction; /// The parameter indices. @@ -60,19 +62,31 @@ class SILDifferentiabilityWitness /// devirtualization from another module. bool serialized; - SILDifferentiabilityWitness(SILModule &module, SILFunction *originalFunction, + SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage, + SILFunction *originalFunction, AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized) - : module(module), originalFunction(originalFunction), + : module(module), linkage(linkage), originalFunction(originalFunction), parameterIndices(parameterIndices), resultIndices(resultIndices), derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), serialized(isSerialized) {} public: + /// The key type, used for uniquing `SILDifferentiabilityWitness` in + /// `SILModule`, original function, parameter indices, result indices, and + /// derivative generic signature. + using Key = std::tuple; + Key getKey() { + return std::make_tuple(originalFunction, parameterIndices, resultIndices, + derivativeGenericSignature); + } + SILModule &getModule() const { return module; } + SILLinkage getLinkage() const { return linkage; } SILFunction *getOriginalFunction() const { return originalFunction; } AutoDiffIndexSubset *getParameterIndices() const { return parameterIndices; @@ -88,7 +102,7 @@ class SILDifferentiabilityWitness bool isSerialized() const { return serialized; } static SILDifferentiabilityWitness *create( - SILModule &module, SILFunction *originalFunction, + SILModule &module, SILLinkage linkage, SILFunction *originalFunction, AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized); diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index b6a3bab78e252..c434986ca97df 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -145,6 +145,9 @@ class SILModule { friend SILProperty; friend SILUndef; friend SILWitnessTable; + // SWIFT_ENABLE_TENSORFLOW + friend SILDifferentiabilityWitness; + // SWIFT_ENABLE_TENSORFLOW END friend Lowering::SILGenModule; friend Lowering::TypeConverter; class SerializationCallback; @@ -202,10 +205,9 @@ class SILModule { // SWIFT_ENABLE_TENSORFLOW /// Lookup table for SIL differentiability witnesses from original functions. - /// Indexed by original function, parameter indices, result indices, and - /// derivative generic signature. - llvm::DenseMap, + /// Indexed by key type: original function, parameter indices, result indices, + /// and derivative generic signature. + llvm::DenseMap DifferentiabilityWitnessMap; diff --git a/include/swift/Serialization/SerializedSILLoader.h b/include/swift/Serialization/SerializedSILLoader.h index 41a7b794e6f66..c78c446f16408 100644 --- a/include/swift/Serialization/SerializedSILLoader.h +++ b/include/swift/Serialization/SerializedSILLoader.h @@ -99,6 +99,11 @@ class SerializedSILLoader { /// Deserialize all Properties in all SILModules. void getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + /// Deserialize all DifferentiabilityWitnesses in all SILModules. + void getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END + SerializedSILLoader(const SerializedSILLoader &) = delete; SerializedSILLoader(SerializedSILLoader &&) = delete; SerializedSILLoader &operator=(const SerializedSILLoader &) = delete; diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 7eff6979dc668..19bf0db1d7e51 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -94,6 +94,9 @@ bool Parser::parseTopLevel() { CASE_SIL(sil_global, SILGlobal) CASE_SIL(sil_witness_table, SILWitnessTable) CASE_SIL(sil_default_witness_table, SILDefaultWitnessTable) + // SWIFT_ENABLE_TENSORFLOW + CASE_SIL(sil_differentiability_witness, SILDifferentiabilityWitness) + // SWIFT_ENABLE_TENSORFLOW END CASE_SIL(sil_coverage_map, SILCoverageMap) CASE_SIL(sil_property, SILProperty) CASE_SIL(sil_scope, SILScope) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 7ab56a4552b96..b0ed0b7759ec0 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6745,21 +6745,225 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) { return false; } +// SWIFT_ENABLE_TENSORFLOW +// TODO: Dedupe with `SILParser::convertRequirements` upstream. +// Consider defining this as `Parser::convertRequirements`. +static void convertRequirements(Parser &P, SILFunction *F, + ArrayRef From, + SmallVectorImpl &To) { + if (From.empty()) { + To.clear(); + return; + } + + auto *GenericEnv = F->getLoweredFunctionType() + ->getGenericSignature() + ->getGenericEnvironment(); + assert(GenericEnv); + (void)GenericEnv; + + IdentTypeReprLookup PerformLookup(P); + // Use parser lexical scopes to resolve references + // to the generic parameters. + auto ResolveToInterfaceType = [&](TypeLoc Ty) -> Type { + Ty.getTypeRepr()->walk(PerformLookup); + swift::performTypeLocChecking(P.Context, Ty, /*isSILMode*/ true, + /*isSILType*/ true, GenericEnv, &P.SF); + assert(Ty.getType()); + return Ty.getType()->mapTypeOutOfContext(); + }; + + for (auto &Req : From) { + if (Req.getKind() == RequirementReprKind::SameType) { + auto FirstType = ResolveToInterfaceType(Req.getFirstTypeLoc()); + auto SecondType = ResolveToInterfaceType(Req.getSecondTypeLoc()); + Requirement ConvertedRequirement(RequirementKind::SameType, FirstType, + SecondType); + To.push_back(ConvertedRequirement); + continue; + } + + if (Req.getKind() == RequirementReprKind::TypeConstraint) { + auto Subject = ResolveToInterfaceType(Req.getSubjectLoc()); + auto Constraint = ResolveToInterfaceType(Req.getConstraintLoc()); + Requirement ConvertedRequirement(RequirementKind::Conformance, Subject, + Constraint); + To.push_back(ConvertedRequirement); + continue; + } + + if (Req.getKind() == RequirementReprKind::LayoutConstraint) { + auto Subject = ResolveToInterfaceType(Req.getSubjectLoc()); + Requirement ConvertedRequirement(RequirementKind::Layout, Subject, + Req.getLayoutConstraint()); + To.push_back(ConvertedRequirement); + continue; + } + llvm_unreachable("Unsupported requirement kind"); + } +} + /// decl-sil-differentiability-witness ::= /// 'sil_differentiability_witness' -/// sil-function-name -/// 'wrt' autodiff-index-subset -/// 'sources' autodiff-index-subset -/// ('derivative_generic_signature' generic-signature)? -/// '{' ('jvp' sil-function-name)? ('vjp' sil-function-name)? '}' +/// sil-function-name ':' sil-type +/// 'parameters' autodiff-index-subset +/// 'results' autodiff-index-subset +/// ('where' generic-signature)? +/// '{' +/// ('jvp' sil-function-name ':' sil-type)? +/// ('vjp' sil-function-name ':' sil-type)? +/// '}' /// /// autodiff-index-subset ::= -/// [0-9]+ (',', [0-9]+)* +/// '(' [0-9]+ (',', [0-9]+)* ')' bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { P.consumeToken(tok::kw_sil_differentiability_witness); - // TODO(TF-867): Implement parsing. Test round-tripping with printing. + SILParser State(P); + + // Parse the linkage. + Optional linkage; + if (parseSILLinkage(linkage, P)) + return true; + if (!linkage) + linkage = SILLinkage::PublicExternal; + + Scope S(&P, ScopeKind::TopLevel); + Scope Body(&P, ScopeKind::FunctionBody); + + auto parseFunctionNameAndType = [&](SILFunction *&fn) -> bool { + Identifier name; + SILType ty; + SourceLoc fnNameLoc = P.Tok.getLoc(); + // We need to turn on InSILBody to parse the function reference. + Lexer::SILBodyRAII tmp(*P.L); + GenericEnvironment *ignoredEnv; + if ((State.parseGlobalName(name)) || + P.parseToken(tok::colon, diag::expected_sil_colon_value_ref) || + State.parseSILType(ty, ignoredEnv, /*IsFuncDecl*/ true)) + return true; + + // The function doesn't exist yet. Create a zombie forward declaration. + auto fnType = ty.getAs(); + if (!fnType || !ty.isObject()) { + P.diagnose(fnNameLoc, diag::expected_sil_function_type); + return true; + } + fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true); + State.TUState.PotentialZombieFns.insert(fn); + return false; + }; + + SourceLoc lastLoc = P.getEndOfPreviousLoc(); + + SILFunction *originalFn; + if (parseFunctionNameAndType(originalFn)) + return true; + + auto parseAutoDiffIndexSubset = + [&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool { + if (P.parseSpecificIdentifier( + label, diag::sil_diff_witness_expected_keyword, label)) + return true; + if (P.parseToken(tok::l_paren, diag::sil_diff_witness_expected_keyword, + "(")) + return true; + // Parse parameter index list. + SmallVector paramIndices; + // Function that parses an index into `paramIndices`. Returns true on error. + auto parseParam = [&]() -> bool { + unsigned index; + // TODO: Reject non-ascending parameter index lists. + if (P.parseUnsignedInteger(index, lastLoc, + diag::sil_diff_witness_expected_parameter_list)) + return true; + paramIndices.push_back(index); + return false; + }; + // Parse first. + if (parseParam()) + return true; + // Parse rest. + while (P.consumeIf(tok::comma)) + if (parseParam()) + return true; + if (P.parseToken(tok::r_paren, diag::sil_diff_witness_expected_keyword, + "(")) + return true; + auto maxIndexRef = + std::max_element(paramIndices.begin(), paramIndices.end()); + paramIndexSubset = AutoDiffIndexSubset::get( + P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices); + return false; + }; + AutoDiffIndexSubset *parameterIndices = nullptr; + AutoDiffIndexSubset *resultIndices = nullptr; + if (parseAutoDiffIndexSubset("parameters", parameterIndices)) + return true; + if (parseAutoDiffIndexSubset("results", resultIndices)) + return true; + + GenericSignature *derivativeGenSig = nullptr; + // Parse a trailing 'where' clause if any. + if (P.Tok.is(tok::kw_where)) { + SourceLoc whereLoc; + SmallVector requirementReprs; + bool firstTypeInComplete; + P.parseGenericWhereClause(whereLoc, requirementReprs, firstTypeInComplete, + /*AllowLayoutConstraints*/ false); + auto *whereClause = TrailingWhereClause::create( + originalFn->getModule().getASTContext(), whereLoc, requirementReprs); + SmallVector requirements; + convertRequirements(P, originalFn, whereClause->getRequirements(), + requirements); + assert(requirements.size() == requirementReprs.size()); + derivativeGenSig = evaluateOrDefault( + P.Context.evaluator, + AbstractGenericSignatureRequest{ + originalFn->getLoweredFunctionType()->getGenericSignature(), + /*addedGenericParams=*/{}, + std::move(requirements)}, + nullptr); + } + + SILFunction *jvp = nullptr; + SILFunction *vjp = nullptr; + if (P.Tok.is(tok::l_brace)) { + SourceLoc LBraceLoc = P.Tok.getLoc(); + P.consumeToken(tok::l_brace); + + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") { + P.consumeToken(tok::identifier); + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, + ":")) + return true; + Scope Body(&P, ScopeKind::FunctionBody); + if (parseFunctionNameAndType(jvp)) + return true; + } + + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") { + P.consumeToken(tok::identifier); + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, + ":")) + return true; + Scope Body(&P, ScopeKind::FunctionBody); + if (parseFunctionNameAndType(vjp)) + return true; + } + + if (P.parseMatchingToken(tok::r_brace, lastLoc, diag::expected_sil_rbrace, + LBraceLoc)) + return true; + } + + // TODO: Parse `isSerialized` flag. + bool isSerialized = false; + SILDifferentiabilityWitness::create( + M, *linkage, originalFn, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); return false; } +// SWIFT_ENABLE_TENSORFLOW END llvm::Optional SILParser::parseSILCoverageExpr( llvm::coverage::CounterExpressionBuilder &Builder) { diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index e0bb329aa569b..0437d961ae835 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -2,7 +2,7 @@ // // This source file is part of the Swift.org open source project // -// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information @@ -18,15 +18,19 @@ using namespace swift; SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( - SILModule &module, SILFunction *originalFunction, + SILModule &module, SILLinkage linkage, SILFunction *originalFunction, AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized) { void *buf = module.allocate(sizeof(SILDifferentiabilityWitness), alignof(SILDifferentiabilityWitness)); - SILDifferentiabilityWitness *dw = ::new (buf) - SILDifferentiabilityWitness(module, originalFunction, parameterIndices, - resultIndices, derivativeGenSig, jvp, vjp, - isSerialized); - return dw; + auto *diffWitness = ::new (buf) SILDifferentiabilityWitness( + module, linkage, originalFunction, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); + // Register the differentiability witness in the module. + assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) && + "Cannot create duplicate differentiability witness in a module"); + module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness; + module.getDifferentiabilityWitnessList().push_back(diffWitness); + return diffWitness; } diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index bd7d21771bf17..6fa826a10869a 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2712,7 +2712,7 @@ static void printSILDifferentiabilityWitnesses( std::sort(sortedDiffWitnesses.begin(), sortedDiffWitnesses.end(), [] (const SILDifferentiabilityWitness *w1, const SILDifferentiabilityWitness *w2) -> bool { - // TODO: Sort based on more criteria. + // TODO: Sort based on more criteria for deterministic ordering. return w1->getOriginalFunction()->getName() .compare(w2->getOriginalFunction()->getName()); } @@ -3060,12 +3060,12 @@ void SILDefaultWitnessTable::dump() const { // SWIFT_ENABLE_TENSORFLOW void SILDifferentiabilityWitness::print( llvm::raw_ostream &OS, bool verbose) const { - // TODO(TF-867): Test SIL differentiability witness round-trip printing and - // parsing. It is currently untested and certainly broken. - - // sil_differentiability_witness @original-function-name + // sil_differentiability_witness @original-function-name : $original-sil-type PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); - OS << "sil_differentiability_witness @" << originalFunction->getName(); + OS << "sil_differentiability_witness "; + printLinkage(OS, linkage, ForDefinition); + OS << "@" << originalFunction->getName() << " : " + << originalFunction->getLoweredType(); // parameters (0, 1, ...) OS << " parameters ("; interleave(parameterIndices->getIndices(), @@ -3079,21 +3079,42 @@ void SILDifferentiabilityWitness::print( OS << ')'; // wrt 0, 1, ... if (derivativeGenericSignature) { - OS << " where "; // NOTE: This needs to be changed if there is no utility for parsing // generic signatures. Idea: we could instead print the type of the original // function substituted into this generic signature. - derivativeGenericSignature->print(OS); + ArrayRef requirements; + SmallVector requirementsScratch; + auto *origGenEnv = originalFunction->getGenericEnvironment(); + if (derivativeGenericSignature) { + if (origGenEnv) { + requirementsScratch = + derivativeGenericSignature->requirementsNotSatisfiedBy( + origGenEnv->getGenericSignature()); + requirements = requirementsScratch; + } else { + requirements = derivativeGenericSignature->getRequirements(); + } + } + if (!requirements.empty()) { + OS << " where "; + auto SubPrinter = PrintOptions::printSIL(); + interleave(requirements, + [&](Requirement req) { + req.print(OS, SubPrinter); + return; + }, + [&] { OS << ", "; }); + } } // { - // jvp: @jvp-function-name - // vjp: @vjp-function-name + // jvp: @jvp-function-name : $jvp-sil-type + // vjp: @vjp-function-name : $vjp-sil-type // } OS << " {\n"; if (jvp) - OS << " jvp: @" << jvp->getName() << "\n"; + OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << "\n"; if (vjp) - OS << " vjp: @" << vjp->getName() << "\n"; + OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << "\n"; OS << "}\n\n"; } diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index e6ff6ad6b3c96..c8fb18f9bb549 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -148,7 +148,9 @@ SILDeserializer::SILDeserializer( // SIL_DEFAULT_WITNESS_TABLE_NAMES. But each one can be // omitted if no entries exist in the module file. unsigned kind = 0; - while (kind != sil_index_block::SIL_PROPERTY_OFFSETS) { +// SWIFT_ENABLE_TENSORFLOW + while (kind != sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) { +// SWIFT_ENABLE_TENSORFLOW END auto next = cursor.advance(); if (next.Kind == llvm::BitstreamEntry::EndBlock) return; @@ -164,9 +166,13 @@ SILDeserializer::SILDeserializer( kind == sil_index_block::SIL_GLOBALVAR_NAMES || kind == sil_index_block::SIL_WITNESS_TABLE_NAMES || kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES || - kind == sil_index_block::SIL_PROPERTY_OFFSETS)) && + kind == sil_index_block::SIL_PROPERTY_OFFSETS || +// SWIFT_ENABLE_TENSORFLOW + kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS)) && "Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \ - SIL_WITNESS_TABLE_NAMES, or SIL_DEFAULT_WITNESS_TABLE_NAMES."); + SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \ + SIL_PROPERTY_OFFSETS, or SIL_DIFFERENTIABILITY_WITNESS_OFFSETS."); +// SWIFT_ENABLE_TENSORFLOW END (void)prevKind; if (kind == sil_index_block::SIL_FUNC_NAMES) @@ -182,8 +188,14 @@ SILDeserializer::SILDeserializer( else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) { // No matching 'names' block for property descriptors needed yet. MF->allocateBuffer(Properties, scratch); +// SWIFT_ENABLE_TENSORFLOW + } + else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) { + // No matching 'names' block for differentiability witnesses needed yet. + MF->allocateBuffer(DifferentiabilityWitnesses, scratch); return; } +// SWIFT_ENABLE_TENSORFLOW END // Read SIL_FUNC|VTABLE|GLOBALVAR_OFFSETS record. next = cursor.advance(); @@ -2945,6 +2957,78 @@ void SILDeserializer::getAllProperties() { } } +// SWIFT_ENABLE_TENSORFLOW +SILDifferentiabilityWitness * +SILDeserializer::readDifferentiabilityWitness(DeclID DId) { + auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; + + if (diffWitnessOrOffset.isFullyDeserialized()) + return diffWitnessOrOffset.get(); + + BCOffsetRAII restoreOffset(SILCursor); + SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()); + auto entry = SILCursor.advance(AF_DontPopBlockAtEnd); + if (entry.Kind == llvm::BitstreamEntry::Error) { + LLVM_DEBUG(llvm::dbgs() + << "Cursor advance error in readDifferentiabilityWitness.\n"); + return nullptr; + } + + SmallVector scratch; + StringRef blobData; + unsigned kind = SILCursor.readRecord(entry.ID, scratch, &blobData); + assert(kind == SIL_DIFFERENTIABILITY_WITNESS && + "Expected sil_differentiability_witness"); + (void)kind; + + DeclID originalNameId, jvpNameId, vjpNameId; + unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices; + GenericSignatureID derivativeGenSigID; + ArrayRef rawParameterAndResultIndices; + + DifferentiabilityWitnessLayout::readRecord( + scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID, + jvpNameId, vjpNameId, numParameterIndices, numResultIndices, + rawParameterAndResultIndices); + + auto linkage = fromStableSILLinkage(rawLinkage); + assert(linkage && "Expected value linkage for sil_differentiability_witness"); + auto originalName = MF->getIdentifier(originalNameId).str(); + auto jvpName = MF->getIdentifier(jvpNameId).str(); + auto vjpName = MF->getIdentifier(vjpNameId).str(); + auto *original = getFuncForReference(originalName); + auto *jvp = getFuncForReference(jvpName); + auto *vjp = getFuncForReference(vjpName); + auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID); + + SmallVector parameterAndResultIndices( + rawParameterAndResultIndices.begin(), + rawParameterAndResultIndices.end()); + assert(parameterAndResultIndices.size() == + numParameterIndices + numResultIndices && + "Parameter/result indices count mismatch"); + auto *parameterIndices = AutoDiffIndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumParameters(), + ArrayRef(parameterAndResultIndices) + .take_front(numParameterIndices)); + auto *resultIndices = AutoDiffIndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumResults(), + ArrayRef(parameterAndResultIndices) + .take_back(numResultIndices)); + + auto *diffWitness = SILDifferentiabilityWitness::create( + SILMod, *linkage, original, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); + diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); + return diffWitness; +} + +void SILDeserializer::getAllDifferentiabilityWitnesses() { + for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I) + readDifferentiabilityWitness(I+1); +} +// SWIFT_ENABLE_TENSORFLOW END + void SILDeserializer::readWitnessTableEntries( llvm::BitstreamEntry &entry, std::vector &witnessEntries, diff --git a/lib/Serialization/DeserializeSIL.h b/lib/Serialization/DeserializeSIL.h index c2790c01e48c2..2a31f713c0064 100644 --- a/lib/Serialization/DeserializeSIL.h +++ b/lib/Serialization/DeserializeSIL.h @@ -58,6 +58,12 @@ namespace swift { MutableArrayRef> Properties; + // SWIFT_ENABLE_TENSORFLOW + MutableArrayRef< + ModuleFile::PartiallySerialized> + DifferentiabilityWitnesses; + // SWIFT_ENABLE_TENSORFLOW END + /// A declaration will only llvm::DenseMap ConformanceToWitnessTableMap; @@ -126,6 +132,10 @@ namespace swift { SILDefaultWitnessTable * readDefaultWitnessTable(serialization::DeclID, SILDefaultWitnessTable *existingWt); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + readDifferentiabilityWitness(serialization::DeclID); + // SWIFT_ENABLE_TENSORFLOW Optional readKeyPathComponent(ArrayRef ListOfValues, unsigned &nextValue); @@ -169,6 +179,9 @@ namespace swift { getAllWitnessTables(); getAllDefaultWitnessTables(); getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END } /// Deserialize all SILFunctions inside the module and add them to SILMod. @@ -192,6 +205,12 @@ namespace swift { /// to SILMod. void getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + /// Deserialize all DifferentiabilityWitnesses inside the module and add + /// them to SILMod. + void getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END + SILDeserializer(ModuleFile *MF, SILModule &M, DeserializationNotificationHandlerSet *callback); diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index 9ee90b4e22e43..92b40929fb932 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -122,6 +122,9 @@ namespace sil_index_block { SIL_DEFAULT_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_OFFSETS, SIL_PROPERTY_OFFSETS, + // SWIFT_ENABLE_TENSORFLOW + SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, + // SWIFT_ENABLE_TENSORFLOW END }; using ListLayout = BCGenericRecordLayout< @@ -176,6 +179,8 @@ namespace sil_block { SIL_DIFFERENTIABLE_ATTR, SIL_INST_DIFFERENTIABLE_FUNCTION, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT, + SIL_DIFFERENTIABILITY_WITNESS, + // SWIFT_ENABLE_TENSORFLOW END // We also share these layouts from the decls block. Their enumerators must // not overlap with ours. @@ -280,6 +285,21 @@ namespace sil_block { DeclIDField >; + // SWIFT_ENABLE_TENSORFLOW + using DifferentiabilityWitnessLayout = BCRecordLayout< + SIL_DIFFERENTIABILITY_WITNESS, + DeclIDField, // Original function name + SILLinkageField, // Linkage + BCFixed<1>, // Is serialized? + GenericSignatureIDField, // Derivative function generic signature + DeclIDField, // JVP function name + DeclIDField, // VJP function name + BCVBR<8>, // Number of parameter indices + BCVBR<8>, // Number of result indices + BCArray // Parameter and result indices + >; + // SWIFT_ENABLE_TENSORFLOW END + using SILFunctionLayout = BCRecordLayout, // transparent diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index aac745d43601d..dcffb9d7848d6 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -826,6 +826,9 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_NAMES); BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_OFFSETS); BLOCK_RECORD(sil_index_block, SIL_PROPERTY_OFFSETS); + // SWIFT_ENABLE_TENSORFLOW + BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS); + // SWIFT_ENABLE_TENSORFLOW END #undef BLOCK #undef BLOCK_RECORD diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 67f6e8aeaba48..40bf6c0e601a9 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -183,6 +183,11 @@ namespace { /// Holds the list of Properties. std::vector PropertyOffset; + // SWIFT_ENABLE_TENSORFLOW + /// Holds the list of SIL differentiability witnesses. + std::vector DifferentiabilityWitnessOffset; + // SWIFT_ENABLE_TENSORFLOW END + /// Give each SILBasicBlock a unique ID. llvm::DenseMap BasicBlockMap; @@ -232,6 +237,10 @@ namespace { void writeSILWitnessTable(const SILWitnessTable &wt); void writeSILWitnessTableEntry(const SILWitnessTable::Entry &entry); void writeSILDefaultWitnessTable(const SILDefaultWitnessTable &wt); + // SWIFT_ENABLE_TENSORFLOW + void writeSILDifferentiabilityWitness( + const SILDifferentiabilityWitness &dw); + // SWIFT_ENABLE_TENSORFLOW END void writeSILProperty(const SILProperty &prop); void writeSILBlock(const SILModule *SILMod); @@ -2278,6 +2287,13 @@ void SILSerializer::writeIndexTables() { PropertyOffset); } + // SWIFT_ENABLE_TENSORFLOW + if (!DifferentiabilityWitnessOffset.empty()) { + Offset.emit(ScratchRecord, + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, + DifferentiabilityWitnessOffset); + } + // SWIFT_ENABLE_TENSORFLOW END } void SILSerializer::writeSILGlobalVar(const SILGlobalVariable &g) { @@ -2469,6 +2485,46 @@ writeSILDefaultWitnessTable(const SILDefaultWitnessTable &wt) { } } +// SWIFT_ENABLE_TENSORFLOW +void SILSerializer:: +writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { + DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo()); + + auto *original = dw.getOriginalFunction(); + addReferencedSILFunction(original, /*DeclOnly*/ true); + IdentifierID jvpID = 0; + IdentifierID vjpID = 0; + if (auto *jvp = dw.getJVP()) { + addReferencedSILFunction(jvp, /*DeclOnly*/ true); + jvpID = S.addUniquedStringRef(jvp->getName()); + } + if (auto *vjp = dw.getVJP()) { + addReferencedSILFunction(vjp, /*DeclOnly*/ true); + vjpID = S.addUniquedStringRef(vjp->getName()); + } + SmallVector parameterAndResultIndices( + dw.getParameterIndices()->begin(), dw.getParameterIndices()->end()); + parameterAndResultIndices.append(dw.getResultIndices()->begin(), + dw.getResultIndices()->end()); + auto originalFnType = original->getLoweredFunctionType(); + assert(originalFnType->getNumParameters() == + dw.getParameterIndices()->getCapacity()); + assert(originalFnType->getNumResults() == + dw.getResultIndices()->getCapacity()); + + DifferentiabilityWitnessLayout::emitRecord( + Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code], + S.addUniquedStringRef(original->getName()), + toStableSILLinkage(dw.getLinkage()), + dw.isSerialized(), + S.addGenericSignatureRef(dw.getDerivativeGenericSignature()), + jvpID, vjpID, + dw.getParameterIndices()->getNumIndices(), + dw.getResultIndices()->getNumIndices(), + parameterAndResultIndices); +} +// SWIFT_ENABLE_TENSORFLOW END + /// Helper function for whether to emit a function body. bool SILSerializer::shouldEmitFunctionBody(const SILFunction *F, bool isReference) { @@ -2529,6 +2585,9 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW + registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW END registerSILAbbr(); registerSILAbbr(); @@ -2537,6 +2596,7 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW END // Register the abbreviation codes so these layouts can exist in both // decl blocks and sil blocks. @@ -2589,6 +2649,17 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { writeSILDefaultWitnessTable(wt); } + // SWIFT_ENABLE_TENSORFLOW + // Write out differentiability witnesses. + for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) { + // TODO: Consider checking + // `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP + // functions. + if ((ShouldSerializeAll || diffWitness.isSerialized())) + writeSILDifferentiabilityWitness(diffWitness); + } + // SWIFT_ENABLE_TENSORFLOW END + // Emit only declarations if it is a module with pre-specializations. // And only do it in optimized builds. bool emitDeclarationsForOnoneSupport = diff --git a/lib/Serialization/SerializedSILLoader.cpp b/lib/Serialization/SerializedSILLoader.cpp index 3e910c8e5324f..0227ec22cbbb2 100644 --- a/lib/Serialization/SerializedSILLoader.cpp +++ b/lib/Serialization/SerializedSILLoader.cpp @@ -185,3 +185,10 @@ void SerializedSILLoader::getAllProperties() { Des->getAllProperties(); } +// SWIFT_ENABLE_TENSORFLOW +/// Deserialize all DifferentiabilityWitnesses in all SILModules. +void SerializedSILLoader::getAllDifferentiabilityWitnesses() { + for (auto &Des : LoadedSILSections) + Des->getAllDifferentiabilityWitnesses(); +} +// SWIFT_ENABLE_TENSORFLOW END From 419eea2b513521a83b8bfc3cedb712b224bbd009 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 21:25:30 -0700 Subject: [PATCH 08/26] Revamp serialization to enable lookup by key. Use mangling to support string key lookup. --- include/swift/AST/ASTMangler.h | 26 ++- include/swift/AST/AutoDiff.h | 8 + .../swift/SIL/SILDifferentiabilityWitness.h | 11 +- include/swift/SIL/SILModule.h | 5 +- .../swift/Serialization/SerializedSILLoader.h | 10 + lib/AST/ASTMangler.cpp | 28 ++- lib/SIL/SILDifferentiabilityWitness.cpp | 5 + lib/Serialization/DeserializeSIL.cpp | 178 ++++++++++-------- lib/Serialization/DeserializeSIL.h | 7 +- lib/Serialization/SILFormat.h | 1 + lib/Serialization/SerializeSIL.cpp | 24 ++- lib/Serialization/SerializedSILLoader.cpp | 13 ++ 12 files changed, 205 insertions(+), 111 deletions(-) diff --git a/include/swift/AST/ASTMangler.h b/include/swift/AST/ASTMangler.h index 6ea3807732d3d..8aa1b8736e994 100644 --- a/include/swift/AST/ASTMangler.h +++ b/include/swift/AST/ASTMangler.h @@ -155,23 +155,31 @@ class ASTMangler : public Mangler { ModuleDecl *Module); // SWIFT_ENABLE_TENSORFLOW - // Mangle the derivative function (JVP/VJP) with the given: - // - Mangled original function name. - // - Derivative function kind. - // - Parameter/result indices. + /// Mangle the derivative function (JVP/VJP) with the given: + /// - Mangled original function name. + /// - Derivative function kind. + /// - Parameter/result indices. std::string mangleAutoDiffDerivativeFunctionHelper( StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices); - // SWIFT_ENABLE_TENSORFLOW - // Mangle the autodiff linear map (differential/pullback) with the given: - // - Mangled original function name. - // - Linear map kind. - // - Parameter/result indices. + /// Mangle the autodiff linear map (differential/pullback) with the given: + /// - Mangled original function name. + /// - Linear map kind. + /// - Parameter/result indices. std::string mangleAutoDiffLinearMapHelper( StringRef name, AutoDiffLinearMapKind kind, const SILAutoDiffIndices &indices); + /// Mangle a SIL differentiability witness key. + /// - Mangled original function name. + /// - Parameter indices. + /// - Result indices. + /// - Derivative generic signature (optional). + std::string mangleSILDifferentiabilityWitnessKey( + SILDifferentiabilityWitnessKey key); + // SWIFT_ENABLE_TENSORFLOW END + std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property, GenericSignature *signature, CanType baseType, diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 6334a6c5d05fc..6db6898cd6101 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -481,6 +481,14 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { } }; +/// The key type used for uniquing `SILDifferentiabilityWitness` in +/// `SILModule`: original function name, parameter indices, result indices, and +/// derivative generic signature. +// TODO: Unify with `AutoDiffDerivativeFunctionIdentifier`. +using SILDifferentiabilityWitnessKey = +std::tuple; + /// Automatic differentiation utility namespace. namespace autodiff { /// Appends the subset's parameter's types to `result`, in the order in diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index f1a44197548ed..3b5d44365212d 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -75,16 +75,7 @@ class SILDifferentiabilityWitness serialized(isSerialized) {} public: - /// The key type, used for uniquing `SILDifferentiabilityWitness` in - /// `SILModule`, original function, parameter indices, result indices, and - /// derivative generic signature. - using Key = std::tuple; - Key getKey() { - return std::make_tuple(originalFunction, parameterIndices, resultIndices, - derivativeGenericSignature); - } - + SILDifferentiabilityWitnessKey getKey() const; SILModule &getModule() const { return module; } SILLinkage getLinkage() const { return linkage; } SILFunction *getOriginalFunction() const { return originalFunction; } diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index c434986ca97df..dac005fd4e875 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -207,9 +207,8 @@ class SILModule { /// Lookup table for SIL differentiability witnesses from original functions. /// Indexed by key type: original function, parameter indices, result indices, /// and derivative generic signature. - llvm::DenseMap - DifferentiabilityWitnessMap; + llvm::DenseMap + DifferentiabilityWitnessMap; /// The list of SILDifferentiabilityWitnesses in the module. DifferentiabilityWitnessListType differentiabilityWitnesses; diff --git a/include/swift/Serialization/SerializedSILLoader.h b/include/swift/Serialization/SerializedSILLoader.h index c78c446f16408..955d3732ea69d 100644 --- a/include/swift/Serialization/SerializedSILLoader.h +++ b/include/swift/Serialization/SerializedSILLoader.h @@ -13,6 +13,9 @@ #ifndef SWIFT_SERIALIZATION_SILLOADER_H #define SWIFT_SERIALIZATION_SILLOADER_H +// SWIFT_ENABLE_TENSORFLOW +#include "swift/AST/AutoDiff.h" +// SWIFT_ENABLE_TENSORFLOW END #include "swift/AST/Decl.h" #include "swift/AST/Identifier.h" #include "swift/SIL/Notifications.h" @@ -32,6 +35,9 @@ class SILModule; class SILVTable; class SILWitnessTable; class SILDefaultWitnessTable; +// SWIFT_ENABLE_TENSORFLOW +class SILDifferentiabilityWitness; +// SWIFT_ENABLE_TENSORFLOW END /// Maintains a list of SILDeserializer, one for each serialized modules /// in ASTContext. It provides lookupSILFunction that will perform lookup @@ -64,6 +70,10 @@ class SerializedSILLoader { SILVTable *lookupVTable(const ClassDecl *C); SILWitnessTable *lookupWitnessTable(SILWitnessTable *C); SILDefaultWitnessTable *lookupDefaultWitnessTable(SILDefaultWitnessTable *C); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + lookupDifferentiabilityWitness(SILDifferentiabilityWitnessKey key); + // SWIFT_ENABLE_TENSORFLOW END /// Invalidate the cached entries for deserialized SILFunctions. void invalidateCaches(); diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index efd460f158243..8e8850ec4ebae 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -379,11 +379,12 @@ std::string ASTMangler::mangleReabstractionThunkHelper( return finalize(); } +// SWIFT_ENABLE_TENSORFLOW std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper( StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices) { // TODO(TF-20): Make the mangling scheme robust. - // TODO(TF-680): Mangle `@differentiable` atttribute requirements as well. + // TODO(TF-680): Mangle derivative generic signature as well. beginManglingWithoutPrefix(); Buffer << "AD__" << name << '_'; @@ -406,7 +407,7 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper( StringRef name, AutoDiffLinearMapKind kind, const SILAutoDiffIndices &indices) { // TODO(TF-20): Make the mangling scheme robust. - // TODO(TF-680): Mangle `@differentiable` atttribute requirements as well. + // TODO(TF-680): Mangle derivative generic signature as well. beginManglingWithoutPrefix(); Buffer << "AD__" << name << '_'; @@ -425,6 +426,29 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper( return result; } +std::string ASTMangler::mangleSILDifferentiabilityWitnessKey( + SILDifferentiabilityWitnessKey key) { + // TODO(TF-20): Make the mangling scheme robust. + // TODO(TF-680): Mangle derivative generic signature as well. + beginManglingWithoutPrefix(); + + auto originalName = std::get<0>(key); + auto *parameterIndices = std::get<1>(key); + auto *resultIndices = std::get<2>(key); + auto *derivativeGenericSignature = std::get<3>(key); + + Buffer << "AD__" << originalName << '_'; + Buffer << "P" << parameterIndices->getString(); + Buffer << "R" << resultIndices->getString(); + if (derivativeGenericSignature) + appendGenericSignature(derivativeGenericSignature); + + auto result = Storage.str().str(); + Storage.clear(); + return result; +} +// SWIFT_ENABLE_TENSORFLOW END + std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) { PrettyStackTraceType prettyStackTrace(Ty->getASTContext(), "mangling type for debugger", Ty); diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index 0437d961ae835..a18f287b53a7d 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -34,3 +34,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( module.getDifferentiabilityWitnessList().push_back(diffWitness); return diffWitness; } + +SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { + return std::make_tuple(originalFunction->getName(), parameterIndices, + resultIndices, derivativeGenericSignature); +} diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index c8fb18f9bb549..06ab9c8f2cdb3 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -149,7 +149,7 @@ SILDeserializer::SILDeserializer( // omitted if no entries exist in the module file. unsigned kind = 0; // SWIFT_ENABLE_TENSORFLOW - while (kind != sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) { + while (kind != sil_index_block::SIL_PROPERTY_OFFSETS) { // SWIFT_ENABLE_TENSORFLOW END auto next = cursor.advance(); if (next.Kind == llvm::BitstreamEntry::EndBlock) @@ -168,10 +168,10 @@ SILDeserializer::SILDeserializer( kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES || kind == sil_index_block::SIL_PROPERTY_OFFSETS || // SWIFT_ENABLE_TENSORFLOW - kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS)) && + kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)) && "Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \ SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \ - SIL_PROPERTY_OFFSETS, or SIL_DIFFERENTIABILITY_WITNESS_OFFSETS."); + SIL_DIFFERENTIABILITY_WITNESS_NAMES, or SIL_PROPERTY_OFFSETS."); // SWIFT_ENABLE_TENSORFLOW END (void)prevKind; @@ -185,17 +185,15 @@ SILDeserializer::SILDeserializer( WitnessTableList = readFuncTable(scratch, blobData); else if (kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES) DefaultWitnessTableList = readFuncTable(scratch, blobData); + // SWIFT_ENABLE_TENSORFLOW + else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) + DifferentiabilityWitnessList = readFuncTable(scratch, blobData); + // SWIFT_ENABLE_TENSORFLOW END else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) { // No matching 'names' block for property descriptors needed yet. MF->allocateBuffer(Properties, scratch); -// SWIFT_ENABLE_TENSORFLOW - } - else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) { - // No matching 'names' block for differentiability witnesses needed yet. - MF->allocateBuffer(DifferentiabilityWitnesses, scratch); return; } -// SWIFT_ENABLE_TENSORFLOW END // Read SIL_FUNC|VTABLE|GLOBALVAR_OFFSETS record. next = cursor.advance(); @@ -227,6 +225,12 @@ SILDeserializer::SILDeserializer( offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) && "Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record."); MF->allocateBuffer(DefaultWitnessTables, scratch); + } else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) { + assert((next.Kind == llvm::BitstreamEntry::Record && + offKind == + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) && + "Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record."); + MF->allocateBuffer(DifferentiabilityWitnesses, scratch); } } } @@ -2957,78 +2961,6 @@ void SILDeserializer::getAllProperties() { } } -// SWIFT_ENABLE_TENSORFLOW -SILDifferentiabilityWitness * -SILDeserializer::readDifferentiabilityWitness(DeclID DId) { - auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; - - if (diffWitnessOrOffset.isFullyDeserialized()) - return diffWitnessOrOffset.get(); - - BCOffsetRAII restoreOffset(SILCursor); - SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()); - auto entry = SILCursor.advance(AF_DontPopBlockAtEnd); - if (entry.Kind == llvm::BitstreamEntry::Error) { - LLVM_DEBUG(llvm::dbgs() - << "Cursor advance error in readDifferentiabilityWitness.\n"); - return nullptr; - } - - SmallVector scratch; - StringRef blobData; - unsigned kind = SILCursor.readRecord(entry.ID, scratch, &blobData); - assert(kind == SIL_DIFFERENTIABILITY_WITNESS && - "Expected sil_differentiability_witness"); - (void)kind; - - DeclID originalNameId, jvpNameId, vjpNameId; - unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices; - GenericSignatureID derivativeGenSigID; - ArrayRef rawParameterAndResultIndices; - - DifferentiabilityWitnessLayout::readRecord( - scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID, - jvpNameId, vjpNameId, numParameterIndices, numResultIndices, - rawParameterAndResultIndices); - - auto linkage = fromStableSILLinkage(rawLinkage); - assert(linkage && "Expected value linkage for sil_differentiability_witness"); - auto originalName = MF->getIdentifier(originalNameId).str(); - auto jvpName = MF->getIdentifier(jvpNameId).str(); - auto vjpName = MF->getIdentifier(vjpNameId).str(); - auto *original = getFuncForReference(originalName); - auto *jvp = getFuncForReference(jvpName); - auto *vjp = getFuncForReference(vjpName); - auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID); - - SmallVector parameterAndResultIndices( - rawParameterAndResultIndices.begin(), - rawParameterAndResultIndices.end()); - assert(parameterAndResultIndices.size() == - numParameterIndices + numResultIndices && - "Parameter/result indices count mismatch"); - auto *parameterIndices = AutoDiffIndexSubset::get( - MF->getContext(), original->getLoweredFunctionType()->getNumParameters(), - ArrayRef(parameterAndResultIndices) - .take_front(numParameterIndices)); - auto *resultIndices = AutoDiffIndexSubset::get( - MF->getContext(), original->getLoweredFunctionType()->getNumResults(), - ArrayRef(parameterAndResultIndices) - .take_back(numResultIndices)); - - auto *diffWitness = SILDifferentiabilityWitness::create( - SILMod, *linkage, original, parameterIndices, resultIndices, - derivativeGenSig, jvp, vjp, isSerialized); - diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); - return diffWitness; -} - -void SILDeserializer::getAllDifferentiabilityWitnesses() { - for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I) - readDifferentiabilityWitness(I+1); -} -// SWIFT_ENABLE_TENSORFLOW END - void SILDeserializer::readWitnessTableEntries( llvm::BitstreamEntry &entry, std::vector &witnessEntries, @@ -3364,6 +3296,90 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) { return Wt; } +// SWIFT_ENABLE_TENSORFLOW +SILDifferentiabilityWitness * +SILDeserializer::readDifferentiabilityWitness(DeclID DId) { + auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; + + if (diffWitnessOrOffset.isFullyDeserialized()) + return diffWitnessOrOffset.get(); + + BCOffsetRAII restoreOffset(SILCursor); + SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()); + auto entry = SILCursor.advance(AF_DontPopBlockAtEnd); + if (entry.Kind == llvm::BitstreamEntry::Error) { + LLVM_DEBUG(llvm::dbgs() + << "Cursor advance error in readDifferentiabilityWitness.\n"); + return nullptr; + } + + SmallVector scratch; + StringRef blobData; + unsigned kind = SILCursor.readRecord(entry.ID, scratch, &blobData); + assert(kind == SIL_DIFFERENTIABILITY_WITNESS && + "Expected sil_differentiability_witness"); + (void)kind; + + DeclID originalNameId, jvpNameId, vjpNameId; + unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices; + GenericSignatureID derivativeGenSigID; + ArrayRef rawParameterAndResultIndices; + + DifferentiabilityWitnessLayout::readRecord( + scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID, + jvpNameId, vjpNameId, numParameterIndices, numResultIndices, + rawParameterAndResultIndices); + + auto linkage = fromStableSILLinkage(rawLinkage); + assert(linkage && "Expected value linkage for sil_differentiability_witness"); + auto originalName = MF->getIdentifier(originalNameId).str(); + auto jvpName = MF->getIdentifier(jvpNameId).str(); + auto vjpName = MF->getIdentifier(vjpNameId).str(); + auto *original = getFuncForReference(originalName); + auto *jvp = getFuncForReference(jvpName); + auto *vjp = getFuncForReference(vjpName); + auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID); + + SmallVector parameterAndResultIndices( + rawParameterAndResultIndices.begin(), + rawParameterAndResultIndices.end()); + assert(parameterAndResultIndices.size() == + numParameterIndices + numResultIndices && + "Parameter/result indices count mismatch"); + auto *parameterIndices = AutoDiffIndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumParameters(), + ArrayRef(parameterAndResultIndices) + .take_front(numParameterIndices)); + auto *resultIndices = AutoDiffIndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumResults(), + ArrayRef(parameterAndResultIndices) + .take_back(numResultIndices)); + + auto *diffWitness = SILDifferentiabilityWitness::create( + SILMod, *linkage, original, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); + diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); + return diffWitness; +} + +SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness( + StringRef mangledDiffWitnessKey) { + if (!DifferentiabilityWitnessList) + return nullptr; + auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey); + if (iter == DifferentiabilityWitnessList->end()) + return nullptr; + + auto *diffWitness = readDifferentiabilityWitness(*iter); + return diffWitness; +} + +void SILDeserializer::getAllDifferentiabilityWitnesses() { + for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I) + readDifferentiabilityWitness(I+1); +} +// SWIFT_ENABLE_TENSORFLOW END + SILDeserializer::~SILDeserializer() { // Drop our references to anything we've deserialized. for (auto &fnEntry : Funcs) { diff --git a/lib/Serialization/DeserializeSIL.h b/lib/Serialization/DeserializeSIL.h index 2a31f713c0064..76b118f30d0d6 100644 --- a/lib/Serialization/DeserializeSIL.h +++ b/lib/Serialization/DeserializeSIL.h @@ -59,6 +59,7 @@ namespace swift { Properties; // SWIFT_ENABLE_TENSORFLOW + std::unique_ptr DifferentiabilityWitnessList; MutableArrayRef< ModuleFile::PartiallySerialized> DifferentiabilityWitnesses; @@ -135,7 +136,7 @@ namespace swift { // SWIFT_ENABLE_TENSORFLOW SILDifferentiabilityWitness * readDifferentiabilityWitness(serialization::DeclID); - // SWIFT_ENABLE_TENSORFLOW + // SWIFT_ENABLE_TENSORFLOW END Optional readKeyPathComponent(ArrayRef ListOfValues, unsigned &nextValue); @@ -155,6 +156,10 @@ namespace swift { SILWitnessTable *lookupWitnessTable(SILWitnessTable *wt); SILDefaultWitnessTable * lookupDefaultWitnessTable(SILDefaultWitnessTable *wt); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + lookupDifferentiabilityWitness(StringRef mangledDiffWitnessKey); + // SWIFT_ENABLE_TENSORFLOW END /// Invalidate all cached SILFunctions. void invalidateFunctionCache(); diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index 92b40929fb932..df862fcb5c607 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -123,6 +123,7 @@ namespace sil_index_block { SIL_DEFAULT_WITNESS_TABLE_OFFSETS, SIL_PROPERTY_OFFSETS, // SWIFT_ENABLE_TENSORFLOW + SIL_DIFFERENTIABILITY_WITNESS_NAMES, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, // SWIFT_ENABLE_TENSORFLOW END }; diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 40bf6c0e601a9..161359587fafe 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -184,8 +184,11 @@ namespace { std::vector PropertyOffset; // SWIFT_ENABLE_TENSORFLOW + /// Maps differentiability witness identifier to an ID. + Table DifferentiabilityWitnessList; /// Holds the list of SIL differentiability witnesses. std::vector DifferentiabilityWitnessOffset; + uint32_t /*DeclID*/ NextDifferentiabilityWitnessID = 1; // SWIFT_ENABLE_TENSORFLOW END /// Give each SILBasicBlock a unique ID. @@ -2282,18 +2285,21 @@ void SILSerializer::writeIndexTables() { DefaultWitnessTableOffset); } - if (!PropertyOffset.empty()) { - Offset.emit(ScratchRecord, sil_index_block::SIL_PROPERTY_OFFSETS, - PropertyOffset); - } - // SWIFT_ENABLE_TENSORFLOW if (!DifferentiabilityWitnessOffset.empty()) { + writeIndexTable(S, List, + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES, + DifferentiabilityWitnessList); Offset.emit(ScratchRecord, sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, DifferentiabilityWitnessOffset); } // SWIFT_ENABLE_TENSORFLOW END + + if (!PropertyOffset.empty()) { + Offset.emit(ScratchRecord, sil_index_block::SIL_PROPERTY_OFFSETS, + PropertyOffset); + } } void SILSerializer::writeSILGlobalVar(const SILGlobalVariable &g) { @@ -2488,6 +2494,14 @@ writeSILDefaultWitnessTable(const SILDefaultWitnessTable &wt) { // SWIFT_ENABLE_TENSORFLOW void SILSerializer:: writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { + Mangle::ASTMangler mangler; + auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(dw.getKey()); + size_t nameLength = mangledKey.size(); + char *stringStorage = + static_cast(StringTable.Allocate(nameLength, 1)); + std::memcpy(stringStorage, mangledKey.data(), nameLength); + DifferentiabilityWitnessList[StringRef(stringStorage, nameLength)] = + NextDifferentiabilityWitnessID++; DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo()); auto *original = dw.getOriginalFunction(); diff --git a/lib/Serialization/SerializedSILLoader.cpp b/lib/Serialization/SerializedSILLoader.cpp index 0227ec22cbbb2..f9223c4fabf4c 100644 --- a/lib/Serialization/SerializedSILLoader.cpp +++ b/lib/Serialization/SerializedSILLoader.cpp @@ -127,6 +127,19 @@ lookupDefaultWitnessTable(SILDefaultWitnessTable *WT) { return nullptr; } +// SWIFT_ENABLE_TENSORFLOW +SILDifferentiabilityWitness * +SerializedSILLoader::lookupDifferentiabilityWitness( + SILDifferentiabilityWitnessKey key) { + Mangle::ASTMangler mangler; + std::string mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(key); + for (auto &Des : LoadedSILSections) + if (auto *diffWitness = Des->lookupDifferentiabilityWitness(mangledKey)) + return diffWitness; + return nullptr; +} +// SWIFT_ENABLE_TENSORFLOW END + void SerializedSILLoader::invalidateCaches() { for (auto &Des : LoadedSILSections) Des->invalidateFunctionCache(); From b6cd1d7727226455e349d33c5cecc8917295f159 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 21:51:41 -0700 Subject: [PATCH 09/26] Add SIL verification. --- .../swift/SIL/SILDifferentiabilityWitness.h | 13 +++-- lib/SIL/SILVerifier.cpp | 53 +++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index 3b5d44365212d..3b687f7889d91 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -75,6 +75,12 @@ class SILDifferentiabilityWitness serialized(isSerialized) {} public: + static SILDifferentiabilityWitness *create( + SILModule &module, SILLinkage linkage, SILFunction *originalFunction, + AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, + bool isSerialized); + SILDifferentiabilityWitnessKey getKey() const; SILModule &getModule() const { return module; } SILLinkage getLinkage() const { return linkage; } @@ -92,11 +98,8 @@ class SILDifferentiabilityWitness SILFunction *getVJP() const { return vjp; } bool isSerialized() const { return serialized; } - static SILDifferentiabilityWitness *create( - SILModule &module, SILLinkage linkage, SILFunction *originalFunction, - AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, - GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, - bool isSerialized); + /// Verify that the differentiability witness is well-formed. + void verify(const SILModule &M) const; void print(llvm::raw_ostream &OS, bool verbose = false) const; void dump() const; diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 42ae01866d9bc..9271faf0f511b 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5350,6 +5350,43 @@ void SILGlobalVariable::verify() const { } } +// SWIFT_ENABLE_TENSORFLOW +/// Verify that a differentiability witness follows invariants. +void SILDifferentiabilityWitness::verify(const SILModule &M) const { +#ifdef NDEBUG + if (!M.getOptions().VerifyAll) + return; +#endif + auto origFnType = originalFunction->getLoweredFunctionType(); + if (jvp) { + // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to + // accept result indices. + auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType( + getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), + AutoDiffDerivativeFunctionKind::JVP, M.Types, + LookUpConformanceInModule(M.getSwiftModule()), + getDerivativeGenericSignature()->getCanonicalSignature()); + SILVerifier(*jvp).requireSameType( + SILType::getPrimitiveObjectType(jvp->getLoweredFunctionType()), + SILType::getPrimitiveObjectType(expectedJVPType), + "JVP type does not match expected JVP type"); + } + if (vjp) { + // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to + // accept result indices. + auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType( + getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), + AutoDiffDerivativeFunctionKind::VJP, M.Types, + LookUpConformanceInModule(M.getSwiftModule()), + getDerivativeGenericSignature()->getCanonicalSignature()); + SILVerifier(*jvp).requireSameType( + SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()), + SILType::getPrimitiveObjectType(expectedVJPType), + "VJP type does not match expected VJP type"); + } +} +// SWIFT_ENABLE_TENSORFLOW END + /// Verify the module. void SILModule::verify() const { #ifdef NDEBUG @@ -5433,6 +5470,22 @@ void SILModule::verify() const { } wt.verify(*this); } + + // SWIFT_ENABLE_TENSORFLOW + // Check all differentiability witnesses. + LLVM_DEBUG(llvm::dbgs() << + "*** Checking differentiability witnesses for duplicates ***\n"); + llvm::DenseSet diffWitnesses; + for (auto &dw : getDifferentiabilityWitnesses()) { + LLVM_DEBUG(llvm::dbgs() << "Differentiability Witness:\n"; dw.dump()); + if (!diffWitnesses.insert(dw.getKey()).second) { + llvm::errs() << "Differentiability witness redefined: "; + dw.dump(); + assert(false && "triggering standard assertion failure routine"); + } + dw.verify(*this); + } + // SWIFT_ENABLE_TENSORFLOW END // Check property descriptors. LLVM_DEBUG(llvm::dbgs() << "*** Checking property descriptors ***\n"); From 573dd3efca1899c2aa5f2b377390228c583b4824 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 21:56:26 -0700 Subject: [PATCH 10/26] Add miscellaneous todo comments. --- include/swift/AST/AutoDiff.h | 2 +- lib/AST/ASTMangler.cpp | 1 - lib/ParseSIL/ParseSIL.cpp | 4 ++-- lib/SIL/SILPrinter.cpp | 2 +- lib/SIL/SILVerifier.cpp | 8 ++++---- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 6db6898cd6101..ac11ed06b52e1 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -484,7 +484,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { /// The key type used for uniquing `SILDifferentiabilityWitness` in /// `SILModule`: original function name, parameter indices, result indices, and /// derivative generic signature. -// TODO: Unify with `AutoDiffDerivativeFunctionIdentifier`. +// TODO(TF-893): Unify with `AutoDiffDerivativeFunctionIdentifier`. using SILDifferentiabilityWitnessKey = std::tuple; diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index 8e8850ec4ebae..e92dc011f8919 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -429,7 +429,6 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper( std::string ASTMangler::mangleSILDifferentiabilityWitnessKey( SILDifferentiabilityWitnessKey key) { // TODO(TF-20): Make the mangling scheme robust. - // TODO(TF-680): Mangle derivative generic signature as well. beginManglingWithoutPrefix(); auto originalName = std::get<0>(key); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index b0ed0b7759ec0..e6c64c32ac43e 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6746,7 +6746,7 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) { } // SWIFT_ENABLE_TENSORFLOW -// TODO: Dedupe with `SILParser::convertRequirements` upstream. +// TODO(TF-893): Dedupe with `SILParser::convertRequirements` upstream. // Consider defining this as `Parser::convertRequirements`. static void convertRequirements(Parser &P, SILFunction *F, ArrayRef From, @@ -6956,7 +6956,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { return true; } - // TODO: Parse `isSerialized` flag. + // TODO(TF-893): Parse `isSerialized` flag. bool isSerialized = false; SILDifferentiabilityWitness::create( M, *linkage, originalFn, parameterIndices, resultIndices, diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 6fa826a10869a..25ba219d04c09 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2712,7 +2712,7 @@ static void printSILDifferentiabilityWitnesses( std::sort(sortedDiffWitnesses.begin(), sortedDiffWitnesses.end(), [] (const SILDifferentiabilityWitness *w1, const SILDifferentiabilityWitness *w2) -> bool { - // TODO: Sort based on more criteria for deterministic ordering. + // TODO(TF-893): Sort based on more criteria for deterministic ordering. return w1->getOriginalFunction()->getName() .compare(w2->getOriginalFunction()->getName()); } diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 9271faf0f511b..6ce246081a749 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5359,8 +5359,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { #endif auto origFnType = originalFunction->getLoweredFunctionType(); if (jvp) { - // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to - // accept result indices. + // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` + // to accept result indices. auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType( getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), AutoDiffDerivativeFunctionKind::JVP, M.Types, @@ -5372,8 +5372,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { "JVP type does not match expected JVP type"); } if (vjp) { - // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to - // accept result indices. + // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` + // to result indices. auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType( getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), AutoDiffDerivativeFunctionKind::VJP, M.Types, From 843b6316cb425f824037d34bdc8f19c0c19cc6a3 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 22:06:39 -0700 Subject: [PATCH 11/26] Add round-trip parsing/printing test. --- .../sil_differentiability_witness_parse.sil | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 test/AutoDiff/sil_differentiability_witness_parse.sil diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness_parse.sil new file mode 100644 index 0000000000000..0bb6b9e57f581 --- /dev/null +++ b/test/AutoDiff/sil_differentiability_witness_parse.sil @@ -0,0 +1,126 @@ +// RUN: %target-sil-opt %s -module-name=sil_differentiability_witness_parse | %target-sil-opt -module-name=sil_differentiability_witness_parse | %FileCheck %s + +// Round-trip parsing and printing test. + +sil_stage raw + +import Builtin +import Swift +import SwiftShims + +@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable) +@_silgen_name("foo") +func foo(_ x: T, _ y: Float) -> T + +@_silgen_name("foo_jvp") +func foo_jvp(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable + +@_silgen_name("foo_vjp") +func foo_vjp(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable + +// main +sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { +bb0(%0 : $Int32, %1 : $UnsafeMutablePointer>>): + %2 = integer_literal $Builtin.Int32, 0 // user: %3 + %3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4 + return %3 : $Int32 // id: %4 +} // end sil function 'main' + +// foo +sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) (@in_guaranteed T, Float) -> @out T { +// %0 // user: %5 +// %1 // users: %5, %3 +// %2 // user: %4 +bb0(%0 : $*T, %1 : $*T, %2 : $Float): + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 + copy_addr %1 to [initialization] %0 : $*T // id: %5 + %6 = tuple () // user: %7 + return %6 : $() // id: %7 +} // end sil function 'foo' + +// foo_jvp +sil hidden [ossa] @foo_jvp : $@convention(thin) (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) { +// %0 // user: %5 +// %1 // users: %5, %3 +// %2 // user: %4 +bb0(%0 : $*T, %1 : $*T, %2 : $Float): + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 + copy_addr %1 to [initialization] %0 : $*T // id: %5 + // function_ref closure #1 in foo_jvp(_:_:) + %6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7 + %7 = partial_apply [callee_guaranteed] %6() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8 + return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8 +} // end sil function 'foo_jvp' + +// AD__foo__jvp_src_0_wrt_0_1 +sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { +// %0 // user: %4 +// %1 // user: %4 +// %2 // user: %4 +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): + // function_ref foo_jvp + %3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4 + %4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5 + return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5 +} // end sil function 'AD__foo__jvp_src_0_wrt_0_1' + +// closure #1 in foo_jvp(_:_:) +sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector { +// %0 // user: %5 +// %1 // users: %5, %3 +// %2 // user: %4 +bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float): + debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3 + debug_value %2 : $Float, let, name "dy", argno 2 // id: %4 + copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5 + %6 = tuple () // user: %7 + return %6 : $() // id: %7 +} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_' + +// foo_vjp +sil hidden [ossa] @foo_vjp : $@convention(thin) (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) { +// %0 // user: %5 +// %1 // users: %5, %3 +// %2 // user: %4 +bb0(%0 : $*T, %1 : $*T, %2 : $Float): + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 + copy_addr %1 to [initialization] %0 : $*T // id: %5 + // function_ref closure #1 in foo_vjp(_:_:) + %6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7 + %7 = partial_apply [callee_guaranteed] %6() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8 + return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8 +} // end sil function 'foo_vjp' + +// closure #1 in foo_vjp(_:_:) +sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) { +// %0 // user: %3 +// %1 // users: %3, %2 +bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector): + debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2 + copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3 + %4 = metatype $@thin Float.Type + %5 = alloc_stack $Float // users: %10, %9, %8 + %6 = metatype $@thick Float.Type // user: %8 + // function_ref static AdditiveArithmetic<>.zero.getter + %7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8 + %8 = apply %7(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 + %9 = load [trivial] %5 : $*Float // user: %11 + dealloc_stack %5 : $*Float // id: %10 + return %9 : $Float // id: %11 +} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_' + +// static AdditiveArithmetic<>.zero.getter +sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 + +sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { + jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) + vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +} + +// CHECK: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { +// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) +// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +// CHECK: } From 835f1c02f783d3770113c6a9cf2f3be2b6718296 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 22:20:35 -0700 Subject: [PATCH 12/26] Clean up. --- lib/ParseSIL/ParseSIL.cpp | 31 +++++++++++-------- .../sil_differentiability_witness_parse.sil | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index e6c64c32ac43e..b75b8d568386d 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6830,7 +6830,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { Scope S(&P, ScopeKind::TopLevel); Scope Body(&P, ScopeKind::FunctionBody); - auto parseFunctionNameAndType = [&](SILFunction *&fn) -> bool { + // Parse a SIL function name. + auto parseFunctionName = [&](SILFunction *&fn) -> bool { Identifier name; SILType ty; SourceLoc fnNameLoc = P.Tok.getLoc(); @@ -6852,13 +6853,13 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { State.TUState.PotentialZombieFns.insert(fn); return false; }; - - SourceLoc lastLoc = P.getEndOfPreviousLoc(); - + // Parse original function name. SILFunction *originalFn; - if (parseFunctionNameAndType(originalFn)) + if (parseFunctionName(originalFn)) return true; + SourceLoc lastLoc = P.getEndOfPreviousLoc(); + // Parse an index subset, prefaced with the given label. auto parseAutoDiffIndexSubset = [&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool { if (P.parseSpecificIdentifier( @@ -6895,6 +6896,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices); return false; }; + // Parse parameter and result indices. AutoDiffIndexSubset *parameterIndices = nullptr; AutoDiffIndexSubset *resultIndices = nullptr; if (parseAutoDiffIndexSubset("parameters", parameterIndices)) @@ -6902,8 +6904,9 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { if (parseAutoDiffIndexSubset("results", resultIndices)) return true; + // Parse a trailing 'where' clause (optional). + // This represents derivative generic signature requirements. GenericSignature *derivativeGenSig = nullptr; - // Parse a trailing 'where' clause if any. if (P.Tok.is(tok::kw_where)) { SourceLoc whereLoc; SmallVector requirementReprs; @@ -6925,34 +6928,36 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { nullptr); } + // Parse differentiability witness body. SILFunction *jvp = nullptr; SILFunction *vjp = nullptr; if (P.Tok.is(tok::l_brace)) { - SourceLoc LBraceLoc = P.Tok.getLoc(); + // Parse '{'. + SourceLoc lBraceLoc = P.Tok.getLoc(); P.consumeToken(tok::l_brace); - + // Parse JVP (optional). if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") { P.consumeToken(tok::identifier); if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, ":")) return true; Scope Body(&P, ScopeKind::FunctionBody); - if (parseFunctionNameAndType(jvp)) + if (parseFunctionName(jvp)) return true; } - + // Parse VJP (optional). if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") { P.consumeToken(tok::identifier); if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, ":")) return true; Scope Body(&P, ScopeKind::FunctionBody); - if (parseFunctionNameAndType(vjp)) + if (parseFunctionName(vjp)) return true; } - + // Parse '}'. if (P.parseMatchingToken(tok::r_brace, lastLoc, diag::expected_sil_rbrace, - LBraceLoc)) + lBraceLoc)) return true; } diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness_parse.sil index 0bb6b9e57f581..e85710a5c156a 100644 --- a/test/AutoDiff/sil_differentiability_witness_parse.sil +++ b/test/AutoDiff/sil_differentiability_witness_parse.sil @@ -120,7 +120,7 @@ sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_gua vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) } -// CHECK: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { +// CHECK-LABEL: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { // CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // CHECK: } From aca1abbf8eab86d59d9d47ace1aa58e8415e6b20 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 22:25:25 -0700 Subject: [PATCH 13/26] Add Swift source for parsing test. --- .../sil_differentiability_witness_parse.sil | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness_parse.sil index e85710a5c156a..46324f495324f 100644 --- a/test/AutoDiff/sil_differentiability_witness_parse.sil +++ b/test/AutoDiff/sil_differentiability_witness_parse.sil @@ -2,6 +2,24 @@ // Round-trip parsing and printing test. +// Swift source code (`-emit-silgen` output is below): +// +// @differentiable(jvp: foo_jvp where T: Differentiable) +// @_silgen_name("foo") +// func foo(_ x: T, _ y: Float) -> T { x } +// +// @_silgen_name("foo_jvp") +// func foo_jvp(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) { +// (x, { dx, dy in dx }) +// } +// +// @_silgen_name("foo_vjp") +// func foo_vjp(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) { +// (x, { ($0, .zero) }) +// } +// +// The `sil_differentiability_witness` at the end was manually written. + sil_stage raw import Builtin From 418770006d9a7b45e028af6c4ef9e9c52e9dc098 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 11 Oct 2019 10:24:55 -0700 Subject: [PATCH 14/26] `AutoDiffIndexSubset` -> `IndexSubset` --- include/swift/AST/AutoDiff.h | 3 +-- include/swift/SIL/SILDifferentiabilityWitness.h | 14 +++++++------- lib/ParseSIL/ParseSIL.cpp | 14 +++++++------- lib/SIL/SILDifferentiabilityWitness.cpp | 2 +- lib/Serialization/DeserializeSIL.cpp | 4 ++-- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index d36444239fb27..71c482e31c6b2 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -243,8 +243,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { /// derivative generic signature. // TODO(TF-893): Unify with `AutoDiffDerivativeFunctionIdentifier`. using SILDifferentiabilityWitnessKey = -std::tuple; + std::tuple; /// Automatic differentiation utility namespace. namespace autodiff { diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index 3b687f7889d91..f75de5aaf39b1 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -49,9 +49,9 @@ class SILDifferentiabilityWitness /// The original function. SILFunction *originalFunction; /// The parameter indices. - AutoDiffIndexSubset *parameterIndices; + IndexSubset *parameterIndices; /// The result indices. - AutoDiffIndexSubset *resultIndices; + IndexSubset *resultIndices; /// The derivative generic signature (optional). GenericSignature *derivativeGenericSignature; /// The JVP (Jacobian-vector products) derivative function. @@ -64,8 +64,8 @@ class SILDifferentiabilityWitness SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage, SILFunction *originalFunction, - AutoDiffIndexSubset *parameterIndices, - AutoDiffIndexSubset *resultIndices, + IndexSubset *parameterIndices, + IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized) @@ -77,7 +77,7 @@ class SILDifferentiabilityWitness public: static SILDifferentiabilityWitness *create( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, - AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized); @@ -85,10 +85,10 @@ class SILDifferentiabilityWitness SILModule &getModule() const { return module; } SILLinkage getLinkage() const { return linkage; } SILFunction *getOriginalFunction() const { return originalFunction; } - AutoDiffIndexSubset *getParameterIndices() const { + IndexSubset *getParameterIndices() const { return parameterIndices; } - AutoDiffIndexSubset *getResultIndices() const { + IndexSubset *getResultIndices() const { return resultIndices; } GenericSignature *getDerivativeGenericSignature() const { diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 22feb789ebb21..d7240366f36ea 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6860,8 +6860,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { SourceLoc lastLoc = P.getEndOfPreviousLoc(); // Parse an index subset, prefaced with the given label. - auto parseAutoDiffIndexSubset = - [&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool { + auto parseIndexSubset = + [&](StringRef label, IndexSubset *& indexSubset) -> bool { if (P.parseSpecificIdentifier( label, diag::sil_diff_witness_expected_keyword, label)) return true; @@ -6892,16 +6892,16 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { return true; auto maxIndexRef = std::max_element(paramIndices.begin(), paramIndices.end()); - paramIndexSubset = AutoDiffIndexSubset::get( + indexSubset = IndexSubset::get( P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices); return false; }; // Parse parameter and result indices. - AutoDiffIndexSubset *parameterIndices = nullptr; - AutoDiffIndexSubset *resultIndices = nullptr; - if (parseAutoDiffIndexSubset("parameters", parameterIndices)) + IndexSubset *parameterIndices = nullptr; + IndexSubset *resultIndices = nullptr; + if (parseIndexSubset("parameters", parameterIndices)) return true; - if (parseAutoDiffIndexSubset("results", resultIndices)) + if (parseIndexSubset("results", resultIndices)) return true; // Parse a trailing 'where' clause (optional). diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index a18f287b53a7d..e79ab0a45e1d3 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -19,7 +19,7 @@ using namespace swift; SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, - AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized) { void *buf = module.allocate(sizeof(SILDifferentiabilityWitness), diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index c8bbeaa0692c7..d7aa575cfd235 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3346,11 +3346,11 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) { assert(parameterAndResultIndices.size() == numParameterIndices + numResultIndices && "Parameter/result indices count mismatch"); - auto *parameterIndices = AutoDiffIndexSubset::get( + auto *parameterIndices = IndexSubset::get( MF->getContext(), original->getLoweredFunctionType()->getNumParameters(), ArrayRef(parameterAndResultIndices) .take_front(numParameterIndices)); - auto *resultIndices = AutoDiffIndexSubset::get( + auto *resultIndices = IndexSubset::get( MF->getContext(), original->getLoweredFunctionType()->getNumResults(), ArrayRef(parameterAndResultIndices) .take_back(numResultIndices)); From 873468ffc7ae2f5364696b76ac910503fa848738 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 11 Oct 2019 20:43:47 +0000 Subject: [PATCH 15/26] Minor fix. --- lib/SIL/SILVerifier.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 6ce246081a749..6930944bbc4ad 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5379,7 +5379,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { AutoDiffDerivativeFunctionKind::VJP, M.Types, LookUpConformanceInModule(M.getSwiftModule()), getDerivativeGenericSignature()->getCanonicalSignature()); - SILVerifier(*jvp).requireSameType( + SILVerifier(*vjp).requireSameType( SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()), SILType::getPrimitiveObjectType(expectedVJPType), "VJP type does not match expected VJP type"); From aea64d372aaa490a88233f8814938b403f2f0850 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 12 Oct 2019 23:58:45 +0000 Subject: [PATCH 16/26] Update differentiability witness syntax. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Print original function name in comment. ``` // differentiability witness for foo sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo ``` --- lib/ParseSIL/ParseSIL.cpp | 55 +++++++++++-------- lib/SIL/SILPrinter.cpp | 26 +++++---- lib/SIL/SILVerifier.cpp | 9 +-- .../sil_differentiability_witness_parse.sil | 5 +- 4 files changed, 56 insertions(+), 39 deletions(-) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index d7240366f36ea..08bff15625b69 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6805,17 +6805,17 @@ static void convertRequirements(Parser &P, SILFunction *F, /// decl-sil-differentiability-witness ::= /// 'sil_differentiability_witness' +/// '[' 'parameters' index-subset ']' +/// '[' 'results' index-subset ']' +/// ('[' 'where' derivatve-generic-signature-requirements ']')? /// sil-function-name ':' sil-type -/// 'parameters' autodiff-index-subset -/// 'results' autodiff-index-subset -/// ('where' generic-signature)? /// '{' /// ('jvp' sil-function-name ':' sil-type)? /// ('vjp' sil-function-name ':' sil-type)? /// '}' /// -/// autodiff-index-subset ::= -/// '(' [0-9]+ (',', [0-9]+)* ')' +/// index-subset ::= +/// [0-9]+ (' ' [0-9]+)* bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { P.consumeToken(tok::kw_sil_differentiability_witness); SILParser State(P); @@ -6853,21 +6853,17 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { State.TUState.PotentialZombieFns.insert(fn); return false; }; - // Parse original function name. - SILFunction *originalFn; - if (parseFunctionName(originalFn)) - return true; SourceLoc lastLoc = P.getEndOfPreviousLoc(); // Parse an index subset, prefaced with the given label. auto parseIndexSubset = [&](StringRef label, IndexSubset *& indexSubset) -> bool { + if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_keyword, + "[")) + return true; if (P.parseSpecificIdentifier( label, diag::sil_diff_witness_expected_keyword, label)) return true; - if (P.parseToken(tok::l_paren, diag::sil_diff_witness_expected_keyword, - "(")) - return true; // Parse parameter index list. SmallVector paramIndices; // Function that parses an index into `paramIndices`. Returns true on error. @@ -6884,11 +6880,11 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { if (parseParam()) return true; // Parse rest. - while (P.consumeIf(tok::comma)) + while (P.Tok.isNot(tok::r_square)) if (parseParam()) return true; - if (P.parseToken(tok::r_paren, diag::sil_diff_witness_expected_keyword, - "(")) + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword, + "]")) return true; auto maxIndexRef = std::max_element(paramIndices.begin(), paramIndices.end()); @@ -6907,18 +6903,33 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { // Parse a trailing 'where' clause (optional). // This represents derivative generic signature requirements. GenericSignature *derivativeGenSig = nullptr; - if (P.Tok.is(tok::kw_where)) { - SourceLoc whereLoc; - SmallVector requirementReprs; + SourceLoc whereLoc; + SmallVector derivativeRequirementReprs; + if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::kw_where)) { + P.consumeToken(tok::l_square); bool firstTypeInComplete; - P.parseGenericWhereClause(whereLoc, requirementReprs, firstTypeInComplete, + P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs, + firstTypeInComplete, /*AllowLayoutConstraints*/ false); - auto *whereClause = TrailingWhereClause::create( - originalFn->getModule().getASTContext(), whereLoc, requirementReprs); + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword, + "]")) + return true; + } + + // Parse original function name. + SILFunction *originalFn; + if (parseFunctionName(originalFn)) + return true; + + // Resolve derivative requirements. + if (!derivativeRequirementReprs.empty()) { SmallVector requirements; + auto *whereClause = TrailingWhereClause::create( + originalFn->getModule().getASTContext(), whereLoc, + derivativeRequirementReprs); convertRequirements(P, originalFn, whereClause->getRequirements(), requirements); - assert(requirements.size() == requirementReprs.size()); + assert(requirements.size() == derivativeRequirementReprs.size()); derivativeGenSig = evaluateOrDefault( P.Context.evaluator, AbstractGenericSignatureRequest{ diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 25ba219d04c09..6e32f8a6a1dda 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3060,24 +3060,24 @@ void SILDefaultWitnessTable::dump() const { // SWIFT_ENABLE_TENSORFLOW void SILDifferentiabilityWitness::print( llvm::raw_ostream &OS, bool verbose) const { + OS << "// differentiability witness for " + << demangleSymbol(originalFunction->getName()) << "\n"; // sil_differentiability_witness @original-function-name : $original-sil-type PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); OS << "sil_differentiability_witness "; printLinkage(OS, linkage, ForDefinition); - OS << "@" << originalFunction->getName() << " : " - << originalFunction->getLoweredType(); - // parameters (0, 1, ...) - OS << " parameters ("; + // [parameters 0 1 ...] + OS << "[parameters "; interleave(parameterIndices->getIndices(), [&](unsigned index) { OS << index; }, - [&] { OS << ", "; }); - // results (0, 1, ...) - OS << ") results ("; + [&] { OS << " "; }); + // [results 0 1 ...] + OS << "] [results "; interleave(resultIndices->getIndices(), [&](unsigned index) { OS << index; }, - [&] { OS << ", "; }); - OS << ')'; - // wrt 0, 1, ... + [&] { OS << " "; }); + OS << ']'; + // [where ...] if (derivativeGenericSignature) { // NOTE: This needs to be changed if there is no utility for parsing // generic signatures. Idea: we could instead print the type of the original @@ -3096,7 +3096,7 @@ void SILDifferentiabilityWitness::print( } } if (!requirements.empty()) { - OS << " where "; + OS << " [where "; auto SubPrinter = PrintOptions::printSIL(); interleave(requirements, [&](Requirement req) { @@ -3104,8 +3104,12 @@ void SILDifferentiabilityWitness::print( return; }, [&] { OS << ", "; }); + OS << ']'; } } + // original: @original-function-name : $original-sil-type + OS << " @" << originalFunction->getName() << " : " + << originalFunction->getLoweredType(); // { // jvp: @jvp-function-name : $jvp-sil-type // vjp: @vjp-function-name : $vjp-sil-type diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 6930944bbc4ad..2ffca5c970d67 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5358,14 +5358,16 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { return; #endif auto origFnType = originalFunction->getLoweredFunctionType(); + CanGenericSignature derivativeCanGenSig; + if (auto *derivativeGenSig = getDerivativeGenericSignature()) + derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); if (jvp) { // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` // to accept result indices. auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType( getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), AutoDiffDerivativeFunctionKind::JVP, M.Types, - LookUpConformanceInModule(M.getSwiftModule()), - getDerivativeGenericSignature()->getCanonicalSignature()); + LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); SILVerifier(*jvp).requireSameType( SILType::getPrimitiveObjectType(jvp->getLoweredFunctionType()), SILType::getPrimitiveObjectType(expectedJVPType), @@ -5377,8 +5379,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType( getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), AutoDiffDerivativeFunctionKind::VJP, M.Types, - LookUpConformanceInModule(M.getSwiftModule()), - getDerivativeGenericSignature()->getCanonicalSignature()); + LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); SILVerifier(*vjp).requireSameType( SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()), SILType::getPrimitiveObjectType(expectedVJPType), diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness_parse.sil index 46324f495324f..f2908c3024eba 100644 --- a/test/AutoDiff/sil_differentiability_witness_parse.sil +++ b/test/AutoDiff/sil_differentiability_witness_parse.sil @@ -133,12 +133,13 @@ bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector): // static AdditiveArithmetic<>.zero.getter sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 -sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { +sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) } -// CHECK-LABEL: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { +// CHECK-LABEL: // differentiability witness for foo +// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { // CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // CHECK: } From bb93af30770ef628583e030d67fefe0d49e8b611 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 00:18:05 +0000 Subject: [PATCH 17/26] Add `AutoDiffConfig` and use in `SILDifferentiabilityWitnessKey`. --- include/swift/AST/AutoDiff.h | 39 ++++++++++++++++-- .../swift/SIL/SILDifferentiabilityWitness.h | 40 +++++++++++++------ lib/AST/ASTContext.cpp | 25 ++++++++++++ lib/AST/ASTMangler.cpp | 9 +++-- lib/SIL/SILDifferentiabilityWitness.cpp | 10 ++++- lib/SIL/SILPrinter.cpp | 18 ++++----- lib/SIL/SILVerifier.cpp | 4 +- 7 files changed, 110 insertions(+), 35 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 71c482e31c6b2..9f9301af4d987 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -208,6 +208,41 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// Identifies an autodiff derivative function configuration: +/// - Parameter indices. +/// - Result indices. +/// - Derivative generic signature (optional). +// TODO(TF-893): Use `AutoDiffConfig` in `AutoDiffDerivativeFunctionIdentifier` +// to avoid duplication. +class AutoDiffConfig : public llvm::FoldingSetNode { + IndexSubset *const parameterIndices; + IndexSubset *const resultIndices; + GenericSignature *derivativeGenericSignature; + + AutoDiffConfig(IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *derivativeGenericSignature) + : parameterIndices(parameterIndices), resultIndices(resultIndices), + derivativeGenericSignature(derivativeGenericSignature) {} + +public: + IndexSubset *getParameterIndices() const { return parameterIndices; } + IndexSubset *getResultIndices() const { return resultIndices; } + GenericSignature *getDerivativeGenericSignature() const { + return derivativeGenericSignature; + } + + static AutoDiffConfig *get(IndexSubset *parameterIndices, + IndexSubset *resultIndices, + GenericSignature *derivativeGenericSignature, + ASTContext &C); + + void Profile(llvm::FoldingSetNodeID &ID) { + ID.AddPointer(parameterIndices); + ID.AddPointer(resultIndices); + ID.AddPointer(derivativeGenericSignature); + } +}; + /// In conjunction with the original function declaration, identifies an /// autodiff derivative function. /// @@ -241,9 +276,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { /// The key type used for uniquing `SILDifferentiabilityWitness` in /// `SILModule`: original function name, parameter indices, result indices, and /// derivative generic signature. -// TODO(TF-893): Unify with `AutoDiffDerivativeFunctionIdentifier`. -using SILDifferentiabilityWitnessKey = - std::tuple; +using SILDifferentiabilityWitnessKey = std::pair; /// Automatic differentiation utility namespace. namespace autodiff { diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index f75de5aaf39b1..0fb697edf871a 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -29,7 +29,6 @@ #include "swift/AST/AutoDiff.h" #include "swift/AST/GenericSignature.h" #include "swift/SIL/SILAllocated.h" -#include "swift/SIL/SILInstruction.h" #include "llvm/ADT/ilist_node.h" #include "llvm/ADT/ilist.h" @@ -48,12 +47,9 @@ class SILDifferentiabilityWitness SILLinkage linkage; /// The original function. SILFunction *originalFunction; - /// The parameter indices. - IndexSubset *parameterIndices; - /// The result indices. - IndexSubset *resultIndices; - /// The derivative generic signature (optional). - GenericSignature *derivativeGenericSignature; + /// The autodiff configuration: parameter indices, result indices, and + /// derivative generic signature (optional). + AutoDiffConfig *autoDiffConfig; /// The JVP (Jacobian-vector products) derivative function. SILFunction *jvp; /// The VJP (vector-Jacobian products) derivative function. @@ -62,6 +58,11 @@ class SILDifferentiabilityWitness /// devirtualization from another module. bool serialized; + static AutoDiffConfig * + getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices, + IndexSubset *resultIndices, + GenericSignature *derivativeGenSig); + SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage, SILFunction *originalFunction, IndexSubset *parameterIndices, @@ -70,9 +71,9 @@ class SILDifferentiabilityWitness SILFunction *jvp, SILFunction *vjp, bool isSerialized) : module(module), linkage(linkage), originalFunction(originalFunction), - parameterIndices(parameterIndices), resultIndices(resultIndices), - derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), - serialized(isSerialized) {} + autoDiffConfig(getAutoDiffConfig( + module, parameterIndices, resultIndices, derivativeGenSig)), + jvp(jvp), vjp(vjp), serialized(isSerialized) {} public: static SILDifferentiabilityWitness *create( @@ -86,16 +87,29 @@ class SILDifferentiabilityWitness SILLinkage getLinkage() const { return linkage; } SILFunction *getOriginalFunction() const { return originalFunction; } IndexSubset *getParameterIndices() const { - return parameterIndices; + return autoDiffConfig->getParameterIndices(); } IndexSubset *getResultIndices() const { - return resultIndices; + return autoDiffConfig->getResultIndices(); } GenericSignature *getDerivativeGenericSignature() const { - return derivativeGenericSignature; + return autoDiffConfig->getDerivativeGenericSignature(); } SILFunction *getJVP() const { return jvp; } SILFunction *getVJP() const { return vjp; } + SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: return jvp; + case AutoDiffDerivativeFunctionKind::VJP: return vjp; + } + } + void setDerivative(AutoDiffDerivativeFunctionKind kind, + SILFunction *derivative) { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break; + case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break; + } + } bool isSerialized() const { return serialized; } /// Verify that the differentiability witness is well-formed. diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 47d65c3809dba..6734ad40298e4 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -449,9 +449,13 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `IndexSubset` allocations. llvm::FoldingSet IndexSubsets; + /// For uniquifying `AutoDiffConfig` allocations. + llvm::FoldingSet AutoDiffConfigs; + /// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations. llvm::FoldingSet AutoDiffDerivativeFunctionIdentifiers; + // SWIFT_ENABLE_TENSORFLOW END /// A cache of information about whether particular nominal types /// are representable in a foreign language. @@ -4828,6 +4832,27 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) { return newNode; } +AutoDiffConfig *AutoDiffConfig::get( + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *derivativeGenericSignature, ASTContext &C) { + assert(parameterIndices); + assert(resultIndices); + auto &foldingSet = C.getImpl().AutoDiffConfigs; + llvm::FoldingSetNodeID id; + id.AddPointer(parameterIndices); + id.AddPointer(resultIndices); + id.AddPointer(derivativeGenericSignature); + void *insertPos; + auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); + if (existing) + return existing; + void *buf = C.Allocate(sizeof(AutoDiffConfig), alignof(AutoDiffConfig)); + auto *newNode = new (buf) AutoDiffConfig( + parameterIndices, resultIndices, derivativeGenericSignature); + foldingSet.InsertNode(newNode, insertPos); + return newNode; +} + AutoDiffDerivativeFunctionIdentifier * AutoDiffDerivativeFunctionIdentifier::get( AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices, diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index e92dc011f8919..542141d8f30f7 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -431,10 +431,11 @@ std::string ASTMangler::mangleSILDifferentiabilityWitnessKey( // TODO(TF-20): Make the mangling scheme robust. beginManglingWithoutPrefix(); - auto originalName = std::get<0>(key); - auto *parameterIndices = std::get<1>(key); - auto *resultIndices = std::get<2>(key); - auto *derivativeGenericSignature = std::get<3>(key); + auto originalName = key.first; + auto *autoDiffConfig = key.second; + auto *parameterIndices = autoDiffConfig->getParameterIndices(); + auto *resultIndices = autoDiffConfig->getResultIndices(); + auto *derivativeGenericSignature = autoDiffConfig->getDerivativeGenericSignature(); Buffer << "AD__" << originalName << '_'; Buffer << "P" << parameterIndices->getString(); diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index e79ab0a45e1d3..91becd839bdde 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -35,7 +35,13 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( return diffWitness; } +AutoDiffConfig *SILDifferentiabilityWitness::getAutoDiffConfig( + SILModule &module, IndexSubset *parameterIndices, + IndexSubset *resultIndices, GenericSignature *derivativeGenSig) { + return AutoDiffConfig::get(parameterIndices, resultIndices, derivativeGenSig, + module.getASTContext()); +} + SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { - return std::make_tuple(originalFunction->getName(), parameterIndices, - resultIndices, derivativeGenericSignature); + return std::make_pair(originalFunction->getName(), autoDiffConfig); } diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 6e32f8a6a1dda..085d21a3cf183 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3068,31 +3068,27 @@ void SILDifferentiabilityWitness::print( printLinkage(OS, linkage, ForDefinition); // [parameters 0 1 ...] OS << "[parameters "; - interleave(parameterIndices->getIndices(), + interleave(getParameterIndices()->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << " "; }); // [results 0 1 ...] OS << "] [results "; - interleave(resultIndices->getIndices(), + interleave(getResultIndices()->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << " "; }); OS << ']'; // [where ...] - if (derivativeGenericSignature) { - // NOTE: This needs to be changed if there is no utility for parsing - // generic signatures. Idea: we could instead print the type of the original - // function substituted into this generic signature. + if (auto *derivativeGenSig = getDerivativeGenericSignature()) { ArrayRef requirements; SmallVector requirementsScratch; auto *origGenEnv = originalFunction->getGenericEnvironment(); - if (derivativeGenericSignature) { + if (derivativeGenSig) { if (origGenEnv) { - requirementsScratch = - derivativeGenericSignature->requirementsNotSatisfiedBy( - origGenEnv->getGenericSignature()); + requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy( + origGenEnv->getGenericSignature()); requirements = requirementsScratch; } else { - requirements = derivativeGenericSignature->getRequirements(); + requirements = derivativeGenSig->getRequirements(); } } if (!requirements.empty()) { diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 2ffca5c970d67..4b62ef6301eac 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5365,7 +5365,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` // to accept result indices. auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType( - getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), + getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), AutoDiffDerivativeFunctionKind::JVP, M.Types, LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); SILVerifier(*jvp).requireSameType( @@ -5377,7 +5377,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` // to result indices. auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType( - getParameterIndices(), /*resultIndex*/ *resultIndices->begin(), + getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), AutoDiffDerivativeFunctionKind::VJP, M.Types, LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); SILVerifier(*vjp).requireSameType( From f240ed2287a6e4f3ec014c62257b445e0b162a85 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 01:11:10 +0000 Subject: [PATCH 18/26] Change `AutoDiffConfig` to a POD. The contents of `AutoDiffConfig` are all uniqued, so uniquing the product does not make sense. --- include/swift/AST/AutoDiff.h | 70 +++++++++++-------- .../swift/SIL/SILDifferentiabilityWitness.h | 21 +++--- lib/AST/ASTContext.cpp | 24 ------- lib/AST/ASTMangler.cpp | 7 +- lib/SIL/SILDifferentiabilityWitness.cpp | 11 +-- 5 files changed, 57 insertions(+), 76 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 9f9301af4d987..a31ec70d66040 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -212,35 +212,10 @@ struct AutoDiffDerivativeFunctionKind { /// - Parameter indices. /// - Result indices. /// - Derivative generic signature (optional). -// TODO(TF-893): Use `AutoDiffConfig` in `AutoDiffDerivativeFunctionIdentifier` -// to avoid duplication. -class AutoDiffConfig : public llvm::FoldingSetNode { - IndexSubset *const parameterIndices; - IndexSubset *const resultIndices; +struct AutoDiffConfig { + IndexSubset *parameterIndices; + IndexSubset *resultIndices; GenericSignature *derivativeGenericSignature; - - AutoDiffConfig(IndexSubset *parameterIndices, IndexSubset *resultIndices, - GenericSignature *derivativeGenericSignature) - : parameterIndices(parameterIndices), resultIndices(resultIndices), - derivativeGenericSignature(derivativeGenericSignature) {} - -public: - IndexSubset *getParameterIndices() const { return parameterIndices; } - IndexSubset *getResultIndices() const { return resultIndices; } - GenericSignature *getDerivativeGenericSignature() const { - return derivativeGenericSignature; - } - - static AutoDiffConfig *get(IndexSubset *parameterIndices, - IndexSubset *resultIndices, - GenericSignature *derivativeGenericSignature, - ASTContext &C); - - void Profile(llvm::FoldingSetNodeID &ID) { - ID.AddPointer(parameterIndices); - ID.AddPointer(resultIndices); - ID.AddPointer(derivativeGenericSignature); - } }; /// In conjunction with the original function declaration, identifies an @@ -253,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { IndexSubset *const parameterIndices; AutoDiffDerivativeFunctionIdentifier( - AutoDiffDerivativeFunctionKind kind, - IndexSubset *parameterIndices) : + AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) : kind(kind), parameterIndices(parameterIndices) {} public: @@ -276,7 +250,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { /// The key type used for uniquing `SILDifferentiabilityWitness` in /// `SILModule`: original function name, parameter indices, result indices, and /// derivative generic signature. -using SILDifferentiabilityWitnessKey = std::pair; +using SILDifferentiabilityWitnessKey = std::pair; /// Automatic differentiation utility namespace. namespace autodiff { @@ -403,10 +377,44 @@ class VectorSpace { namespace llvm { +using swift::AutoDiffConfig; +using swift::AutoDiffDerivativeFunctionKind; +using swift::GenericSignature; +using swift::IndexSubset; using swift::SILAutoDiffIndices; template struct DenseMapInfo; +template<> struct DenseMapInfo { + static AutoDiffConfig getEmptyKey() { + auto *ptr = llvm::DenseMapInfo::getEmptyKey(); + return {static_cast(ptr), + static_cast(ptr), + static_cast(ptr)}; + } + + static AutoDiffConfig getTombstoneKey() { + auto *ptr = llvm::DenseMapInfo::getTombstoneKey(); + return {static_cast(ptr), + static_cast(ptr), + static_cast(ptr)}; + } + + static unsigned getHashValue(const AutoDiffConfig &Val) { + unsigned combinedHash = hash_combine( + ~1U, DenseMapInfo::getHashValue(Val.parameterIndices), + DenseMapInfo::getHashValue(Val.resultIndices), + DenseMapInfo::getHashValue(Val.derivativeGenericSignature)); + return combinedHash; + } + + static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) { + return LHS.parameterIndices == RHS.parameterIndices && + LHS.resultIndices == RHS.resultIndices && + LHS.derivativeGenericSignature == RHS.derivativeGenericSignature; + } +}; + template<> struct DenseMapInfo { static SILAutoDiffIndices getEmptyKey() { return { DenseMapInfo::getEmptyKey(), nullptr }; diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index 0fb697edf871a..a62e7f67ff97d 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -47,9 +47,12 @@ class SILDifferentiabilityWitness SILLinkage linkage; /// The original function. SILFunction *originalFunction; - /// The autodiff configuration: parameter indices, result indices, and - /// derivative generic signature (optional). - AutoDiffConfig *autoDiffConfig; + /// The parameter indices. + IndexSubset *parameterIndices; + /// The result indices. + IndexSubset *resultIndices; + /// The derivative generic signature (optional). + GenericSignature *derivativeGenericSignature; /// The JVP (Jacobian-vector products) derivative function. SILFunction *jvp; /// The VJP (vector-Jacobian products) derivative function. @@ -71,9 +74,9 @@ class SILDifferentiabilityWitness SILFunction *jvp, SILFunction *vjp, bool isSerialized) : module(module), linkage(linkage), originalFunction(originalFunction), - autoDiffConfig(getAutoDiffConfig( - module, parameterIndices, resultIndices, derivativeGenSig)), - jvp(jvp), vjp(vjp), serialized(isSerialized) {} + parameterIndices(parameterIndices), resultIndices(resultIndices), + derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), + serialized(isSerialized) {} public: static SILDifferentiabilityWitness *create( @@ -87,13 +90,13 @@ class SILDifferentiabilityWitness SILLinkage getLinkage() const { return linkage; } SILFunction *getOriginalFunction() const { return originalFunction; } IndexSubset *getParameterIndices() const { - return autoDiffConfig->getParameterIndices(); + return parameterIndices; } IndexSubset *getResultIndices() const { - return autoDiffConfig->getResultIndices(); + return resultIndices; } GenericSignature *getDerivativeGenericSignature() const { - return autoDiffConfig->getDerivativeGenericSignature(); + return derivativeGenericSignature; } SILFunction *getJVP() const { return jvp; } SILFunction *getVJP() const { return vjp; } diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 6734ad40298e4..27e9c342b5cf6 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -449,9 +449,6 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `IndexSubset` allocations. llvm::FoldingSet IndexSubsets; - /// For uniquifying `AutoDiffConfig` allocations. - llvm::FoldingSet AutoDiffConfigs; - /// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations. llvm::FoldingSet AutoDiffDerivativeFunctionIdentifiers; @@ -4832,27 +4829,6 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) { return newNode; } -AutoDiffConfig *AutoDiffConfig::get( - IndexSubset *parameterIndices, IndexSubset *resultIndices, - GenericSignature *derivativeGenericSignature, ASTContext &C) { - assert(parameterIndices); - assert(resultIndices); - auto &foldingSet = C.getImpl().AutoDiffConfigs; - llvm::FoldingSetNodeID id; - id.AddPointer(parameterIndices); - id.AddPointer(resultIndices); - id.AddPointer(derivativeGenericSignature); - void *insertPos; - auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); - if (existing) - return existing; - void *buf = C.Allocate(sizeof(AutoDiffConfig), alignof(AutoDiffConfig)); - auto *newNode = new (buf) AutoDiffConfig( - parameterIndices, resultIndices, derivativeGenericSignature); - foldingSet.InsertNode(newNode, insertPos); - return newNode; -} - AutoDiffDerivativeFunctionIdentifier * AutoDiffDerivativeFunctionIdentifier::get( AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices, diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index 542141d8f30f7..2e542a3cb66d7 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -432,10 +432,9 @@ std::string ASTMangler::mangleSILDifferentiabilityWitnessKey( beginManglingWithoutPrefix(); auto originalName = key.first; - auto *autoDiffConfig = key.second; - auto *parameterIndices = autoDiffConfig->getParameterIndices(); - auto *resultIndices = autoDiffConfig->getResultIndices(); - auto *derivativeGenericSignature = autoDiffConfig->getDerivativeGenericSignature(); + auto *parameterIndices = key.second.parameterIndices; + auto *resultIndices = key.second.resultIndices; + auto *derivativeGenericSignature = key.second.derivativeGenericSignature; Buffer << "AD__" << originalName << '_'; Buffer << "P" << parameterIndices->getString(); diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index 91becd839bdde..496e13abc1516 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -35,13 +35,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( return diffWitness; } -AutoDiffConfig *SILDifferentiabilityWitness::getAutoDiffConfig( - SILModule &module, IndexSubset *parameterIndices, - IndexSubset *resultIndices, GenericSignature *derivativeGenSig) { - return AutoDiffConfig::get(parameterIndices, resultIndices, derivativeGenSig, - module.getASTContext()); -} - SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { - return std::make_pair(originalFunction->getName(), autoDiffConfig); + AutoDiffConfig config{parameterIndices, resultIndices, + derivativeGenericSignature}; + return std::make_pair(originalFunction->getName(), config); } From ad6b7aa43b6d67496e3130650d1a006a2d4dbb9c Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 01:12:58 +0000 Subject: [PATCH 19/26] Add `DeclAttribute *` to `SILDifferentiabilityWitness`. Unserialized, to be used for diagnostics. Will revisit later when revamping the differentiation transform. --- include/swift/SIL/SILDifferentiabilityWitness.h | 13 ++++++++++--- lib/SIL/SILDifferentiabilityWitness.cpp | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index a62e7f67ff97d..e0b6a68cda950 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -26,6 +26,7 @@ #ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H #define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H +#include "swift/AST/Attr.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/GenericSignature.h" #include "swift/SIL/SILAllocated.h" @@ -60,6 +61,11 @@ class SILDifferentiabilityWitness /// Whether or not this differentiability witness is serialized, which allows /// devirtualization from another module. bool serialized; + /// The AST `@differentiable` or `@differentiating` attribute from which the + /// differentiability witness is generated. Used for diagnostics. + /// Null if the differentiability witness is parsed from SIL or if it is + /// deserialized. + DeclAttribute *attribute = nullptr; static AutoDiffConfig * getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices, @@ -72,18 +78,18 @@ class SILDifferentiabilityWitness IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, - bool isSerialized) + bool isSerialized, DeclAttribute *attribute) : module(module), linkage(linkage), originalFunction(originalFunction), parameterIndices(parameterIndices), resultIndices(resultIndices), derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), - serialized(isSerialized) {} + serialized(isSerialized), attribute(attribute) {} public: static SILDifferentiabilityWitness *create( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, - bool isSerialized); + bool isSerialized, DeclAttribute *attribute = nullptr); SILDifferentiabilityWitnessKey getKey() const; SILModule &getModule() const { return module; } @@ -114,6 +120,7 @@ class SILDifferentiabilityWitness } } bool isSerialized() const { return serialized; } + DeclAttribute *getAttribute() const { return attribute; } /// Verify that the differentiability witness is well-formed. void verify(const SILModule &M) const; diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index 496e13abc1516..c663518504433 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -21,12 +21,12 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, - bool isSerialized) { + bool isSerialized, DeclAttribute *attribute) { void *buf = module.allocate(sizeof(SILDifferentiabilityWitness), alignof(SILDifferentiabilityWitness)); auto *diffWitness = ::new (buf) SILDifferentiabilityWitness( module, linkage, originalFunction, parameterIndices, resultIndices, - derivativeGenSig, jvp, vjp, isSerialized); + derivativeGenSig, jvp, vjp, isSerialized, attribute); // Register the differentiability witness in the module. assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) && "Cannot create duplicate differentiability witness in a module"); From 7c63d03f481b0283a39607ffa525777df06b9f11 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 01:19:14 +0000 Subject: [PATCH 20/26] Clean up. --- include/swift/AST/DiagnosticsParse.def | 4 +--- lib/Serialization/DeserializeSIL.cpp | 4 ++-- lib/Serialization/SerializeSIL.cpp | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 2aa113ee7ddda..2bebc478307de 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -686,7 +686,7 @@ ERROR(sil_witness_assoc_conf_not_found,none, ERROR(sil_witness_protocol_conformance_not_found,none, "sil protocol conformance not found", ()) -// [differentiable ...] (sil-decl attr) +// SIL differentiability witnesses ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken, "expected '%0' in differentiability witness", (StringRef)) ERROR(sil_diff_witness_expected_parameter_list,PointsToFirstBadToken, @@ -698,8 +698,6 @@ ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken, ERROR(sil_diff_witness_expected_source_index,PointsToFirstBadToken, "expected the index of a result to differentiate from", ()) -// SIL differentiability witnesses - // SIL Coverage Map ERROR(sil_coverage_func_not_found, none, "sil function not found %0", (Identifier)) diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index d7aa575cfd235..e5c911cb320f8 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -148,9 +148,7 @@ SILDeserializer::SILDeserializer( // SIL_DEFAULT_WITNESS_TABLE_NAMES. But each one can be // omitted if no entries exist in the module file. unsigned kind = 0; -// SWIFT_ENABLE_TENSORFLOW while (kind != sil_index_block::SIL_PROPERTY_OFFSETS) { -// SWIFT_ENABLE_TENSORFLOW END auto next = cursor.advance(); if (next.Kind == llvm::BitstreamEntry::EndBlock) return; @@ -225,12 +223,14 @@ SILDeserializer::SILDeserializer( offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) && "Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record."); MF->allocateBuffer(DefaultWitnessTables, scratch); + // SWIFT_ENABLE_TENSORFLOW } else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) { assert((next.Kind == llvm::BitstreamEntry::Record && offKind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) && "Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record."); MF->allocateBuffer(DifferentiabilityWitnesses, scratch); + // SWIFT_ENABLE_TENSORFLOW END } } } diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 161359587fafe..cff9f319525dc 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2666,7 +2666,7 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { // SWIFT_ENABLE_TENSORFLOW // Write out differentiability witnesses. for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) { - // TODO: Consider checking + // TODO(TF-893): Consider checking // `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP // functions. if ((ShouldSerializeAll || diffWitness.isSerialized())) From c3959adede66a8bec6879febfb04bf8719ac24bc Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 01:39:15 +0000 Subject: [PATCH 21/26] Address review feedback. --- include/swift/AST/AutoDiff.h | 8 +++----- include/swift/AST/DiagnosticsParse.def | 12 +++++------- lib/ParseSIL/ParseSIL.cpp | 4 ++-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index a31ec70d66040..f9092eff0a8dc 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -388,16 +388,14 @@ template struct DenseMapInfo; template<> struct DenseMapInfo { static AutoDiffConfig getEmptyKey() { auto *ptr = llvm::DenseMapInfo::getEmptyKey(); - return {static_cast(ptr), - static_cast(ptr), + return {static_cast(ptr), static_cast(ptr), static_cast(ptr)}; } static AutoDiffConfig getTombstoneKey() { auto *ptr = llvm::DenseMapInfo::getTombstoneKey(); - return {static_cast(ptr), - static_cast(ptr), - static_cast(ptr)}; + return {static_cast(ptr), static_cast(ptr), + static_cast(ptr)}; } static unsigned getHashValue(const AutoDiffConfig &Val) { diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 2bebc478307de..1906b28f134fb 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -689,14 +689,12 @@ ERROR(sil_witness_protocol_conformance_not_found,none, // SIL differentiability witnesses ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken, "expected '%0' in differentiability witness", (StringRef)) -ERROR(sil_diff_witness_expected_parameter_list,PointsToFirstBadToken, - "expected an comma-separated list of parameter indices, e.g. (0, 1)", ()) -ERROR(sil_diff_witness_expected_rsquare,PointsToFirstBadToken, - "expected ']' to end 'differentiable' attribute", ()) +ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken, + "expected a space-separated list of indices, e.g. '0 1'", ()) ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken, - "expected the index of a parameter to differentiate w.r.t.", ()) -ERROR(sil_diff_witness_expected_source_index,PointsToFirstBadToken, - "expected the index of a result to differentiate from", ()) + "expected a parameter index to differentiate with respect to.", ()) +ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken, + "expected a result index to differentiate with respect to", ()) // SIL Coverage Map ERROR(sil_coverage_func_not_found, none, diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 08bff15625b69..4de54f5dedd6b 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6869,9 +6869,9 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { // Function that parses an index into `paramIndices`. Returns true on error. auto parseParam = [&]() -> bool { unsigned index; - // TODO: Reject non-ascending parameter index lists. + // TODO: Reject non-ascending index lists. if (P.parseUnsignedInteger(index, lastLoc, - diag::sil_diff_witness_expected_parameter_list)) + diag::sil_diff_witness_expected_index_list)) return true; paramIndices.push_back(index); return false; From 2073f863826fb305c1fc105a520d84de1e1463ec Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 06:24:22 +0000 Subject: [PATCH 22/26] Address review feedback. --- include/swift/AST/DiagnosticsParse.def | 4 +-- .../swift/SIL/SILDifferentiabilityWitness.h | 11 +++--- lib/ParseSIL/ParseSIL.cpp | 34 +++++++++---------- lib/SIL/SILDifferentiabilityWitness.cpp | 4 +-- lib/SIL/SILPrinter.cpp | 24 ++++++------- lib/Serialization/DeserializeSIL.cpp | 4 +-- lib/Serialization/SerializeSIL.cpp | 3 +- 7 files changed, 38 insertions(+), 46 deletions(-) diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 1906b28f134fb..a53abe5c290ec 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -687,12 +687,12 @@ ERROR(sil_witness_protocol_conformance_not_found,none, "sil protocol conformance not found", ()) // SIL differentiability witnesses -ERROR(sil_diff_witness_expected_keyword,PointsToFirstBadToken, +ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken, "expected '%0' in differentiability witness", (StringRef)) ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken, "expected a space-separated list of indices, e.g. '0 1'", ()) ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken, - "expected a parameter index to differentiate with respect to.", ()) + "expected a parameter index to differentiate with respect to", ()) ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken, "expected a result index to differentiate with respect to", ()) diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index e0b6a68cda950..46d8e7adee70d 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -67,11 +67,6 @@ class SILDifferentiabilityWitness /// deserialized. DeclAttribute *attribute = nullptr; - static AutoDiffConfig * - getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices, - IndexSubset *resultIndices, - GenericSignature *derivativeGenSig); - SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage, SILFunction *originalFunction, IndexSubset *parameterIndices, @@ -112,6 +107,8 @@ class SILDifferentiabilityWitness case AutoDiffDerivativeFunctionKind::VJP: return vjp; } } + void setJVP(SILFunction *jvp) { this->jvp = jvp; } + void setVJP(SILFunction *vjp) { this->vjp = vjp; } void setDerivative(AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) { switch (kind) { @@ -123,9 +120,9 @@ class SILDifferentiabilityWitness DeclAttribute *getAttribute() const { return attribute; } /// Verify that the differentiability witness is well-formed. - void verify(const SILModule &M) const; + void verify(const SILModule &module) const; - void print(llvm::raw_ostream &OS, bool verbose = false) const; + void print(llvm::raw_ostream &os, bool verbose = false) const; void dump() const; }; diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 4de54f5dedd6b..c77459b6f925a 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6747,7 +6747,10 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) { // SWIFT_ENABLE_TENSORFLOW // TODO(TF-893): Dedupe with `SILParser::convertRequirements` upstream. -// Consider defining this as `Parser::convertRequirements`. +// Currently, this utility is defined on `SILParser`, but SIL differentiability +// witness is defined on `SILParserTUState` and only has access to `Parser`. +// Consider redefining `SILParser::convertRequirements`as +// `Parser::convertRequirements`. static void convertRequirements(Parser &P, SILFunction *F, ArrayRef From, SmallVectorImpl &To) { @@ -6827,8 +6830,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { if (!linkage) linkage = SILLinkage::PublicExternal; - Scope S(&P, ScopeKind::TopLevel); - Scope Body(&P, ScopeKind::FunctionBody); + Scope scope(&P, ScopeKind::TopLevel); + Scope body(&P, ScopeKind::FunctionBody); // Parse a SIL function name. auto parseFunctionName = [&](SILFunction *&fn) -> bool { @@ -6858,11 +6861,10 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { // Parse an index subset, prefaced with the given label. auto parseIndexSubset = [&](StringRef label, IndexSubset *& indexSubset) -> bool { - if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_keyword, - "[")) + if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_token, "[")) return true; if (P.parseSpecificIdentifier( - label, diag::sil_diff_witness_expected_keyword, label)) + label, diag::sil_diff_witness_expected_token, label)) return true; // Parse parameter index list. SmallVector paramIndices; @@ -6883,8 +6885,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { while (P.Tok.isNot(tok::r_square)) if (parseParam()) return true; - if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword, - "]")) + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) return true; auto maxIndexRef = std::max_element(paramIndices.begin(), paramIndices.end()); @@ -6911,8 +6912,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs, firstTypeInComplete, /*AllowLayoutConstraints*/ false); - if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword, - "]")) + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) return true; } @@ -6944,25 +6944,23 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { SILFunction *vjp = nullptr; if (P.Tok.is(tok::l_brace)) { // Parse '{'. - SourceLoc lBraceLoc = P.Tok.getLoc(); - P.consumeToken(tok::l_brace); + SourceLoc lBraceLoc; + P.consumeIf(tok::l_brace, lBraceLoc); // Parse JVP (optional). if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") { P.consumeToken(tok::identifier); - if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, - ":")) + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":")) return true; - Scope Body(&P, ScopeKind::FunctionBody); + Scope body(&P, ScopeKind::FunctionBody); if (parseFunctionName(jvp)) return true; } // Parse VJP (optional). if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") { P.consumeToken(tok::identifier); - if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_keyword, - ":")) + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":")) return true; - Scope Body(&P, ScopeKind::FunctionBody); + Scope body(&P, ScopeKind::FunctionBody); if (parseFunctionName(vjp)) return true; } diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index c663518504433..36cf10e532b94 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -22,9 +22,7 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized, DeclAttribute *attribute) { - void *buf = module.allocate(sizeof(SILDifferentiabilityWitness), - alignof(SILDifferentiabilityWitness)); - auto *diffWitness = ::new (buf) SILDifferentiabilityWitness( + auto *diffWitness = new (module) SILDifferentiabilityWitness( module, linkage, originalFunction, parameterIndices, resultIndices, derivativeGenSig, jvp, vjp, isSerialized, attribute); // Register the differentiability witness in the module. diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 085d21a3cf183..a4c4604a6908d 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3061,23 +3061,23 @@ void SILDefaultWitnessTable::dump() const { void SILDifferentiabilityWitness::print( llvm::raw_ostream &OS, bool verbose) const { OS << "// differentiability witness for " - << demangleSymbol(originalFunction->getName()) << "\n"; - // sil_differentiability_witness @original-function-name : $original-sil-type + << demangleSymbol(originalFunction->getName()) << '\n'; PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); + // sil_differentiability_witness (linkage)? OS << "sil_differentiability_witness "; printLinkage(OS, linkage, ForDefinition); - // [parameters 0 1 ...] + // [parameters ...] OS << "[parameters "; interleave(getParameterIndices()->getIndices(), [&](unsigned index) { OS << index; }, - [&] { OS << " "; }); - // [results 0 1 ...] + [&] { OS << ' '; }); + // [results ...] OS << "] [results "; interleave(getResultIndices()->getIndices(), [&](unsigned index) { OS << index; }, - [&] { OS << " "; }); + [&] { OS << ' '; }); OS << ']'; - // [where ...] + // ([where ...])? if (auto *derivativeGenSig = getDerivativeGenericSignature()) { ArrayRef requirements; SmallVector requirementsScratch; @@ -3093,17 +3093,17 @@ void SILDifferentiabilityWitness::print( } if (!requirements.empty()) { OS << " [where "; - auto SubPrinter = PrintOptions::printSIL(); + auto subPrinter = PrintOptions::printSIL(); interleave(requirements, [&](Requirement req) { - req.print(OS, SubPrinter); + req.print(OS, subPrinter); return; }, [&] { OS << ", "; }); OS << ']'; } } - // original: @original-function-name : $original-sil-type + // @original-function-name : $original-sil-type OS << " @" << originalFunction->getName() << " : " << originalFunction->getLoweredType(); // { @@ -3112,9 +3112,9 @@ void SILDifferentiabilityWitness::print( // } OS << " {\n"; if (jvp) - OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << "\n"; + OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n'; if (vjp) - OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << "\n"; + OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n'; OS << "}\n\n"; } diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index e5c911cb320f8..6e0f77bd28540 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3369,9 +3369,7 @@ SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness( auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey); if (iter == DifferentiabilityWitnessList->end()) return nullptr; - - auto *diffWitness = readDifferentiabilityWitness(*iter); - return diffWitness; + return readDifferentiabilityWitness(*iter); } void SILDeserializer::getAllDifferentiabilityWitnesses() { diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index cff9f319525dc..3e4641c5aa774 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2300,6 +2300,7 @@ void SILSerializer::writeIndexTables() { Offset.emit(ScratchRecord, sil_index_block::SIL_PROPERTY_OFFSETS, PropertyOffset); } + } void SILSerializer::writeSILGlobalVar(const SILGlobalVariable &g) { @@ -2517,7 +2518,7 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { vjpID = S.addUniquedStringRef(vjp->getName()); } SmallVector parameterAndResultIndices( - dw.getParameterIndices()->begin(), dw.getParameterIndices()->end()); + dw.getParameterIndices()->begin(), dw.getParameterIndices()->end()); parameterAndResultIndices.append(dw.getResultIndices()->begin(), dw.getResultIndices()->end()); auto originalFnType = original->getLoweredFunctionType(); From 27d7abc5cb7a9456b93a9f934a993aa0c26da5d0 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 06:35:10 +0000 Subject: [PATCH 23/26] Parse/print `[serialized]` flag. Manually verified parsing/printing. Chose not to add additional test for now to keep the test small. --- lib/ParseSIL/ParseSIL.cpp | 14 ++++++++++++-- lib/SIL/SILPrinter.cpp | 3 +++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index c77459b6f925a..52d7a12269286 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6808,6 +6808,7 @@ static void convertRequirements(Parser &P, SILFunction *F, /// decl-sil-differentiability-witness ::= /// 'sil_differentiability_witness' +/// ('[' 'serialized' ']')? /// '[' 'parameters' index-subset ']' /// '[' 'results' index-subset ']' /// ('[' 'where' derivatve-generic-signature-requirements ']')? @@ -6830,6 +6831,17 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { if (!linkage) linkage = SILLinkage::PublicExternal; + // Parse '[serialized]' flag (optional). + bool isSerialized = false; + if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && + P.peekToken().getText() == "serialized") { + isSerialized = true; + P.consumeToken(tok::l_square); + P.consumeToken(tok::identifier); + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) + return true; + } + Scope scope(&P, ScopeKind::TopLevel); Scope body(&P, ScopeKind::FunctionBody); @@ -6970,8 +6982,6 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { return true; } - // TODO(TF-893): Parse `isSerialized` flag. - bool isSerialized = false; SILDifferentiabilityWitness::create( M, *linkage, originalFn, parameterIndices, resultIndices, derivativeGenSig, jvp, vjp, isSerialized); diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index a4c4604a6908d..6debcb183b749 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3066,6 +3066,9 @@ void SILDifferentiabilityWitness::print( // sil_differentiability_witness (linkage)? OS << "sil_differentiability_witness "; printLinkage(OS, linkage, ForDefinition); + // ([serialized])? + if (isSerialized()) + OS << "[serialized] "; // [parameters ...] OS << "[parameters "; interleave(getParameterIndices()->getIndices(), From 69209be998fab96491d2f5963924779c667dfeac Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 07:34:37 +0000 Subject: [PATCH 24/26] Add parsing/printing tests, address review feedback. --- lib/ParseSIL/ParseSIL.cpp | 4 +- lib/SIL/SILPrinter.cpp | 1 - .../sil_differentiability_witness_parse.sil | 161 +++++------------- 3 files changed, 43 insertions(+), 123 deletions(-) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 52d7a12269286..a89de420a10f8 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -6809,6 +6809,7 @@ static void convertRequirements(Parser &P, SILFunction *F, /// decl-sil-differentiability-witness ::= /// 'sil_differentiability_witness' /// ('[' 'serialized' ']')? +/// sil-linkage? /// '[' 'parameters' index-subset ']' /// '[' 'results' index-subset ']' /// ('[' 'where' derivatve-generic-signature-requirements ']')? @@ -6828,8 +6829,9 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { Optional linkage; if (parseSILLinkage(linkage, P)) return true; + // Default to public linkage. if (!linkage) - linkage = SILLinkage::PublicExternal; + linkage = SILLinkage::Public; // Parse '[serialized]' flag (optional). bool isSerialized = false; diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 6debcb183b749..a350176642887 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3100,7 +3100,6 @@ void SILDifferentiabilityWitness::print( interleave(requirements, [&](Requirement req) { req.print(OS, subPrinter); - return; }, [&] { OS << ", "; }); OS << ']'; diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness_parse.sil index f2908c3024eba..e9b8e28eb0d1b 100644 --- a/test/AutoDiff/sil_differentiability_witness_parse.sil +++ b/test/AutoDiff/sil_differentiability_witness_parse.sil @@ -2,144 +2,63 @@ // Round-trip parsing and printing test. -// Swift source code (`-emit-silgen` output is below): -// -// @differentiable(jvp: foo_jvp where T: Differentiable) -// @_silgen_name("foo") -// func foo(_ x: T, _ y: Float) -> T { x } -// -// @_silgen_name("foo_jvp") -// func foo_jvp(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) { -// (x, { dx, dy in dx }) -// } -// -// @_silgen_name("foo_vjp") -// func foo_vjp(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) { -// (x, { ($0, .zero) }) -// } -// -// The `sil_differentiability_witness` at the end was manually written. - sil_stage raw import Builtin import Swift import SwiftShims -@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable) -@_silgen_name("foo") -func foo(_ x: T, _ y: Float) -> T - -@_silgen_name("foo_jvp") -func foo_jvp(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable +// Test public non-generic function. +// SIL differentiability witness: +// - Has public linkage (implicit). +// - Has no `where` clause. -@_silgen_name("foo_vjp") -func foo_vjp(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable +sil [ossa] @foo : $@convention(thin) (Float) -> Float -// main -sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { -bb0(%0 : $Int32, %1 : $UnsafeMutablePointer>>): - %2 = integer_literal $Builtin.Int32, 0 // user: %3 - %3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4 - return %3 : $Int32 // id: %4 -} // end sil function 'main' +sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// foo -sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) (@in_guaranteed T, Float) -> @out T { -// %0 // user: %5 -// %1 // users: %5, %3 -// %2 // user: %4 -bb0(%0 : $*T, %1 : $*T, %2 : $Float): - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 - copy_addr %1 to [initialization] %0 : $*T // id: %5 - %6 = tuple () // user: %7 - return %6 : $() // id: %7 -} // end sil function 'foo' +sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// foo_jvp -sil hidden [ossa] @foo_jvp : $@convention(thin) (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) { -// %0 // user: %5 -// %1 // users: %5, %3 -// %2 // user: %4 -bb0(%0 : $*T, %1 : $*T, %2 : $Float): - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 - copy_addr %1 to [initialization] %0 : $*T // id: %5 - // function_ref closure #1 in foo_jvp(_:_:) - %6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7 - %7 = partial_apply [callee_guaranteed] %6() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8 - return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8 -} // end sil function 'foo_jvp' +sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { + jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} -// AD__foo__jvp_src_0_wrt_0_1 -sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { -// %0 // user: %4 -// %1 // user: %4 -// %2 // user: %4 -bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): - // function_ref foo_jvp - %3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4 - %4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5 - return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5 -} // end sil function 'AD__foo__jvp_src_0_wrt_0_1' +// CHECK-LABEL: // differentiability witness for foo +// CHECK: sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { +// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: } -// closure #1 in foo_jvp(_:_:) -sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector { -// %0 // user: %5 -// %1 // users: %5, %3 -// %2 // user: %4 -bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float): - debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3 - debug_value %2 : $Float, let, name "dy", argno 2 // id: %4 - copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5 - %6 = tuple () // user: %7 - return %6 : $() // id: %7 -} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_' +// Test internal generic function. +// SIL differentiability witness: +// - Has hidden linkage. +// - Has `where` clause. -// foo_vjp -sil hidden [ossa] @foo_vjp : $@convention(thin) (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) { -// %0 // user: %5 -// %1 // users: %5, %3 -// %2 // user: %4 +sil hidden [ossa] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T { bb0(%0 : $*T, %1 : $*T, %2 : $Float): - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 - copy_addr %1 to [initialization] %0 : $*T // id: %5 - // function_ref closure #1 in foo_vjp(_:_:) - %6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7 - %7 = partial_apply [callee_guaranteed] %6() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8 - return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8 -} // end sil function 'foo_vjp' + copy_addr %1 to [initialization] %0 : $*T + %void = tuple () + return %void : $() +} -// closure #1 in foo_vjp(_:_:) -sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) { -// %0 // user: %3 -// %1 // users: %3, %2 -bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector): - debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2 - copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3 - %4 = metatype $@thin Float.Type - %5 = alloc_stack $Float // users: %10, %9, %8 - %6 = metatype $@thick Float.Type // user: %8 - // function_ref static AdditiveArithmetic<>.zero.getter - %7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8 - %8 = apply %7(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 - %9 = load [trivial] %5 : $*Float // user: %11 - dealloc_stack %5 : $*Float // id: %10 - return %9 : $Float // id: %11 -} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_' +sil hidden @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector +} -// static AdditiveArithmetic<>.zero.getter -sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 +sil hidden @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) { +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // id: %5 +} -sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { - jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) - vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { + jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) + vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) } -// CHECK-LABEL: // differentiability witness for foo -// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { -// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) -// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +// CHECK-LABEL: // differentiability witness for generic +// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { +// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) +// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // CHECK: } From df9cd497902873bb95b92fca184f51564c16ab7d Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 08:42:01 +0000 Subject: [PATCH 25/26] Fix serialization and add test. Note: deserialization does not work when SIL differentiability witness references bodyless function declarations. --- lib/Serialization/DeserializeSIL.cpp | 19 +++++++--- lib/Serialization/SerializeSIL.cpp | 16 +++++++-- ....sil => sil_differentiability_witness.sil} | 36 +++++++++++++------ 3 files changed, 54 insertions(+), 17 deletions(-) rename test/AutoDiff/{sil_differentiability_witness_parse.sil => sil_differentiability_witness.sil} (57%) diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 6e0f77bd28540..7eb6c7d722b9f 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3299,8 +3299,12 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) { // SWIFT_ENABLE_TENSORFLOW SILDifferentiabilityWitness * SILDeserializer::readDifferentiabilityWitness(DeclID DId) { - auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; + if (DId == 0) + return nullptr; + assert(DId <= DifferentiabilityWitnesses.size() && + "Invalid SILDifferentiabilityWitness ID"); + auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; if (diffWitnessOrOffset.isFullyDeserialized()) return diffWitnessOrOffset.get(); @@ -3332,12 +3336,17 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) { auto linkage = fromStableSILLinkage(rawLinkage); assert(linkage && "Expected value linkage for sil_differentiability_witness"); - auto originalName = MF->getIdentifier(originalNameId).str(); - auto jvpName = MF->getIdentifier(jvpNameId).str(); - auto vjpName = MF->getIdentifier(vjpNameId).str(); + auto originalName = MF->getIdentifierText(originalNameId); + auto jvpName = MF->getIdentifierText(jvpNameId); + auto vjpName = MF->getIdentifierText(vjpNameId); auto *original = getFuncForReference(originalName); + assert(original && "Original function must be found"); auto *jvp = getFuncForReference(jvpName); + if (!jvpName.empty()) + assert(jvp && "JVP function must be found if JVP name is not empty"); auto *vjp = getFuncForReference(vjpName); + if (!vjpName.empty()) + assert(vjp && "VJP function must be found if VJP name is not empty"); auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID); SmallVector parameterAndResultIndices( @@ -3373,6 +3382,8 @@ SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness( } void SILDeserializer::getAllDifferentiabilityWitnesses() { + if (!DifferentiabilityWitnessList) + return; for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I) readDifferentiabilityWitness(I+1); } diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 3e4641c5aa774..849e537638a38 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2227,7 +2227,13 @@ static void writeIndexTable(Serializer &S, kind == sil_index_block::SIL_VTABLE_NAMES || kind == sil_index_block::SIL_GLOBALVAR_NAMES || kind == sil_index_block::SIL_WITNESS_TABLE_NAMES || - kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES) && + kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES || + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-893): Update surrounding comment/assertion text when + // upstreaming code to master. Comments/assertions have not been + // updated to keep the code diff small. + kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) && + // SWIFT_ENABLE_TENSORFLOW END "SIL function table, global, vtable and (default) witness table " "are supported"); llvm::SmallString<4096> hashTableBlob; @@ -2523,9 +2529,13 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { dw.getResultIndices()->end()); auto originalFnType = original->getLoweredFunctionType(); assert(originalFnType->getNumParameters() == - dw.getParameterIndices()->getCapacity()); + dw.getParameterIndices()->getCapacity() && + "Original function parameter count should match differentiability " + "witness parameter indices capacity"); assert(originalFnType->getNumResults() == - dw.getResultIndices()->getCapacity()); + dw.getResultIndices()->getCapacity() && + "Original function result count should match differentiability " + "witness result indices capacity"); DifferentiabilityWitnessLayout::emitRecord( Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code], diff --git a/test/AutoDiff/sil_differentiability_witness_parse.sil b/test/AutoDiff/sil_differentiability_witness.sil similarity index 57% rename from test/AutoDiff/sil_differentiability_witness_parse.sil rename to test/AutoDiff/sil_differentiability_witness.sil index e9b8e28eb0d1b..8f56d2480dcef 100644 --- a/test/AutoDiff/sil_differentiability_witness_parse.sil +++ b/test/AutoDiff/sil_differentiability_witness.sil @@ -1,6 +1,13 @@ -// RUN: %target-sil-opt %s -module-name=sil_differentiability_witness_parse | %target-sil-opt -module-name=sil_differentiability_witness_parse | %FileCheck %s +// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s -// Round-trip parsing and printing test. +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main +// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name main +// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck %s + +// Round-trip parsing/printing and serialization/deserialization test. +// NOTE: deserialization currently fails if public function bodies are removed +// so that they are only declarations. This may require investigation. sil_stage raw @@ -13,11 +20,20 @@ import SwiftShims // - Has public linkage (implicit). // - Has no `where` clause. -sil [ossa] @foo : $@convention(thin) (Float) -> Float +sil [ossa] @foo : $@convention(thin) (Float) -> Float { +bb0(%0 : $Float): + return %0 : $Float +} -sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} -sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) @@ -49,16 +65,16 @@ bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): sil hidden @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) { bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): - return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // id: %5 + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) } sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { - jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) - vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) + jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) + vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) } // CHECK-LABEL: // differentiability witness for generic // CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { -// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) -// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) +// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // CHECK: } From d673b7056cd03dbdfb77c05a23ac0ad09a433315 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 13 Oct 2019 10:06:32 +0000 Subject: [PATCH 26/26] Fix verification. Verification assertion messages should appear on the differentiability witness. --- lib/SIL/SILVerifier.cpp | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 4b62ef6301eac..f290cc4810f45 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5361,6 +5361,22 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { CanGenericSignature derivativeCanGenSig; if (auto *derivativeGenSig = getDerivativeGenericSignature()) derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); + auto requireSameType = + [&](CanSILFunctionType type1, CanSILFunctionType type2, + const Twine &complaint) { + if (type1 == type2) + return; + llvm::dbgs() << "SIL verification failed: " << complaint << "\n"; + llvm::dbgs() << " " << type1 << "\n " << type2 << "\n\n"; + llvm::dbgs() << "In differentiability witness:\n"; + print(llvm::dbgs()); + // We abort by default because we want to always crash in + // the debugger. + if (AbortOnFailure) + abort(); + else + exit(1); + }; if (jvp) { // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` // to accept result indices. @@ -5368,10 +5384,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), AutoDiffDerivativeFunctionKind::JVP, M.Types, LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); - SILVerifier(*jvp).requireSameType( - SILType::getPrimitiveObjectType(jvp->getLoweredFunctionType()), - SILType::getPrimitiveObjectType(expectedJVPType), - "JVP type does not match expected JVP type"); + requireSameType(jvp->getLoweredFunctionType(), expectedJVPType, + "JVP type does not match expected JVP type"); } if (vjp) { // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` @@ -5380,10 +5394,8 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const { getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), AutoDiffDerivativeFunctionKind::VJP, M.Types, LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); - SILVerifier(*vjp).requireSameType( - SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()), - SILType::getPrimitiveObjectType(expectedVJPType), - "VJP type does not match expected VJP type"); + requireSameType(vjp->getLoweredFunctionType(), expectedVJPType, + "VJP type does not match expected VJP type"); } } // SWIFT_ENABLE_TENSORFLOW END