diff --git a/include/swift/AST/ASTMangler.h b/include/swift/AST/ASTMangler.h index 6ea3807732d3d..8aa1b8736e994 100644 --- a/include/swift/AST/ASTMangler.h +++ b/include/swift/AST/ASTMangler.h @@ -155,23 +155,31 @@ class ASTMangler : public Mangler { ModuleDecl *Module); // SWIFT_ENABLE_TENSORFLOW - // Mangle the derivative function (JVP/VJP) with the given: - // - Mangled original function name. - // - Derivative function kind. - // - Parameter/result indices. + /// Mangle the derivative function (JVP/VJP) with the given: + /// - Mangled original function name. + /// - Derivative function kind. + /// - Parameter/result indices. std::string mangleAutoDiffDerivativeFunctionHelper( StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices); - // SWIFT_ENABLE_TENSORFLOW - // Mangle the autodiff linear map (differential/pullback) with the given: - // - Mangled original function name. - // - Linear map kind. - // - Parameter/result indices. + /// Mangle the autodiff linear map (differential/pullback) with the given: + /// - Mangled original function name. + /// - Linear map kind. + /// - Parameter/result indices. std::string mangleAutoDiffLinearMapHelper( StringRef name, AutoDiffLinearMapKind kind, const SILAutoDiffIndices &indices); + /// Mangle a SIL differentiability witness key. + /// - Mangled original function name. + /// - Parameter indices. + /// - Result indices. + /// - Derivative generic signature (optional). + std::string mangleSILDifferentiabilityWitnessKey( + SILDifferentiabilityWitnessKey key); + // SWIFT_ENABLE_TENSORFLOW END + std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property, GenericSignature *signature, CanType baseType, diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 53ab7934fe7ae..f9092eff0a8dc 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -208,6 +208,16 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// Identifies an autodiff derivative function configuration: +/// - Parameter indices. +/// - Result indices. +/// - Derivative generic signature (optional). +struct AutoDiffConfig { + IndexSubset *parameterIndices; + IndexSubset *resultIndices; + GenericSignature *derivativeGenericSignature; +}; + /// In conjunction with the original function declaration, identifies an /// autodiff derivative function. /// @@ -218,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { IndexSubset *const parameterIndices; AutoDiffDerivativeFunctionIdentifier( - AutoDiffDerivativeFunctionKind kind, - IndexSubset *parameterIndices) : + AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) : kind(kind), parameterIndices(parameterIndices) {} public: @@ -238,6 +247,11 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { } }; +/// The key type used for uniquing `SILDifferentiabilityWitness` in +/// `SILModule`: original function name, parameter indices, result indices, and +/// derivative generic signature. +using SILDifferentiabilityWitnessKey = std::pair; + /// Automatic differentiation utility namespace. namespace autodiff { /// Appends the subset's parameter's types to `result`, in the order in @@ -363,10 +377,42 @@ class VectorSpace { namespace llvm { +using swift::AutoDiffConfig; +using swift::AutoDiffDerivativeFunctionKind; +using swift::GenericSignature; +using swift::IndexSubset; using swift::SILAutoDiffIndices; template struct DenseMapInfo; +template<> struct DenseMapInfo { + static AutoDiffConfig getEmptyKey() { + auto *ptr = llvm::DenseMapInfo::getEmptyKey(); + return {static_cast(ptr), static_cast(ptr), + static_cast(ptr)}; + } + + static AutoDiffConfig getTombstoneKey() { + auto *ptr = llvm::DenseMapInfo::getTombstoneKey(); + return {static_cast(ptr), static_cast(ptr), + static_cast(ptr)}; + } + + static unsigned getHashValue(const AutoDiffConfig &Val) { + unsigned combinedHash = hash_combine( + ~1U, DenseMapInfo::getHashValue(Val.parameterIndices), + DenseMapInfo::getHashValue(Val.resultIndices), + DenseMapInfo::getHashValue(Val.derivativeGenericSignature)); + return combinedHash; + } + + static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) { + return LHS.parameterIndices == RHS.parameterIndices && + LHS.resultIndices == RHS.resultIndices && + LHS.derivativeGenericSignature == RHS.derivativeGenericSignature; + } +}; + template<> struct DenseMapInfo { static SILAutoDiffIndices getEmptyKey() { return { DenseMapInfo::getEmptyKey(), nullptr }; diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 1df71ebb700a6..a53abe5c290ec 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -686,6 +686,16 @@ ERROR(sil_witness_assoc_conf_not_found,none, ERROR(sil_witness_protocol_conformance_not_found,none, "sil protocol conformance not found", ()) +// SIL differentiability witnesses +ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken, + "expected '%0' in differentiability witness", (StringRef)) +ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken, + "expected a space-separated list of indices, e.g. '0 1'", ()) +ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken, + "expected a parameter index to differentiate with respect to", ()) +ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken, + "expected a result index to differentiate with respect to", ()) + // SIL Coverage Map ERROR(sil_coverage_func_not_found, none, "sil function not found %0", (Identifier)) diff --git a/include/swift/Parse/ParseSILSupport.h b/include/swift/Parse/ParseSILSupport.h index 00eaba3e80caf..e3e2225c12df0 100644 --- a/include/swift/Parse/ParseSILSupport.h +++ b/include/swift/Parse/ParseSILSupport.h @@ -32,6 +32,9 @@ namespace swift { virtual bool parseSILGlobal(Parser &P) = 0; virtual bool parseSILWitnessTable(Parser &P) = 0; virtual bool parseSILDefaultWitnessTable(Parser &P) = 0; + // SWIFT_ENABLE_TENSORFLOW + virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0; + // SWIFT_ENABLE_TENSORFLOW END virtual bool parseSILCoverageMap(Parser &P) = 0; virtual bool parseSILProperty(Parser &P) = 0; virtual bool parseSILScope(Parser &P) = 0; diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h new file mode 100644 index 0000000000000..46d8e7adee70d --- /dev/null +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -0,0 +1,153 @@ +//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file defines the SILDifferentiabilityWitness class, which maps an +// original SILFunction and derivative configuration (parameter indices, result +// indices, derivative generic signature) to derivative functions (JVP and VJP). +// +// SIL differentiability witnesses are generated from the `@differentiable` +// and `@differentiating` attributes AST declaration attributes. +// Differentiability witnesses are canonicalized by the differentiation SIL +// transform, which fills in missing derivative functions. Canonical +// differentiability witnesses from other modules can be deserialized to look up +// derivative functions. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H +#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H + +#include "swift/AST/Attr.h" +#include "swift/AST/AutoDiff.h" +#include "swift/AST/GenericSignature.h" +#include "swift/SIL/SILAllocated.h" +#include "llvm/ADT/ilist_node.h" +#include "llvm/ADT/ilist.h" + +namespace swift { + +class SILPrintContext; + +class SILDifferentiabilityWitness + : public llvm::ilist_node, + public SILAllocated +{ +private: + /// The module which contains the differentiability witness. + SILModule &module; + /// The linkage of the differentiability witness. + SILLinkage linkage; + /// The original function. + SILFunction *originalFunction; + /// The parameter indices. + IndexSubset *parameterIndices; + /// The result indices. + IndexSubset *resultIndices; + /// The derivative generic signature (optional). + GenericSignature *derivativeGenericSignature; + /// The JVP (Jacobian-vector products) derivative function. + SILFunction *jvp; + /// The VJP (vector-Jacobian products) derivative function. + SILFunction *vjp; + /// Whether or not this differentiability witness is serialized, which allows + /// devirtualization from another module. + bool serialized; + /// The AST `@differentiable` or `@differentiating` attribute from which the + /// differentiability witness is generated. Used for diagnostics. + /// Null if the differentiability witness is parsed from SIL or if it is + /// deserialized. + DeclAttribute *attribute = nullptr; + + SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage, + SILFunction *originalFunction, + IndexSubset *parameterIndices, + IndexSubset *resultIndices, + GenericSignature *derivativeGenSig, + SILFunction *jvp, SILFunction *vjp, + bool isSerialized, DeclAttribute *attribute) + : module(module), linkage(linkage), originalFunction(originalFunction), + parameterIndices(parameterIndices), resultIndices(resultIndices), + derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp), + serialized(isSerialized), attribute(attribute) {} + +public: + static SILDifferentiabilityWitness *create( + SILModule &module, SILLinkage linkage, SILFunction *originalFunction, + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, + bool isSerialized, DeclAttribute *attribute = nullptr); + + SILDifferentiabilityWitnessKey getKey() const; + SILModule &getModule() const { return module; } + SILLinkage getLinkage() const { return linkage; } + SILFunction *getOriginalFunction() const { return originalFunction; } + IndexSubset *getParameterIndices() const { + return parameterIndices; + } + IndexSubset *getResultIndices() const { + return resultIndices; + } + GenericSignature *getDerivativeGenericSignature() const { + return derivativeGenericSignature; + } + SILFunction *getJVP() const { return jvp; } + SILFunction *getVJP() const { return vjp; } + SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: return jvp; + case AutoDiffDerivativeFunctionKind::VJP: return vjp; + } + } + void setJVP(SILFunction *jvp) { this->jvp = jvp; } + void setVJP(SILFunction *vjp) { this->vjp = vjp; } + void setDerivative(AutoDiffDerivativeFunctionKind kind, + SILFunction *derivative) { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break; + case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break; + } + } + bool isSerialized() const { return serialized; } + DeclAttribute *getAttribute() const { return attribute; } + + /// Verify that the differentiability witness is well-formed. + void verify(const SILModule &module) const; + + void print(llvm::raw_ostream &os, bool verbose = false) const; + void dump() const; +}; + +} // end namespace swift + +namespace llvm { + +//===----------------------------------------------------------------------===// +// ilist_traits for SILDifferentiabilityWitness +//===----------------------------------------------------------------------===// + +template <> +struct ilist_traits<::swift::SILDifferentiabilityWitness> + : public ilist_node_traits<::swift::SILDifferentiabilityWitness> { + using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness; + +public: + static void deleteNode(SILDifferentiabilityWitness *DW) { + DW->~SILDifferentiabilityWitness(); + } + +private: + void createNode(const SILDifferentiabilityWitness &); +}; + +} // namespace llvm + +#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index 802bb85041b32..dac005fd4e875 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -28,6 +28,8 @@ #include "swift/SIL/SILCoverageMap.h" #include "swift/SIL/SILDeclRef.h" #include "swift/SIL/SILDefaultWitnessTable.h" +// SWIFT_ENABLE_TENSORFLOW +#include "swift/SIL/SILDifferentiabilityWitness.h" #include "swift/SIL/SILFunction.h" #include "swift/SIL/SILGlobalVariable.h" #include "swift/SIL/SILPrintContext.h" @@ -113,6 +115,10 @@ class SILModule { using PropertyListType = llvm::ilist; using WitnessTableListType = llvm::ilist; using DefaultWitnessTableListType = llvm::ilist; + // SWIFT_ENABLE_TENSORFLOW + using DifferentiabilityWitnessListType = + llvm::ilist; + // SWIFT_ENABLE_TENSORFLOW END using CoverageMapCollectionType = llvm::MapVector; @@ -139,6 +145,9 @@ class SILModule { friend SILProperty; friend SILUndef; friend SILWitnessTable; + // SWIFT_ENABLE_TENSORFLOW + friend SILDifferentiabilityWitness; + // SWIFT_ENABLE_TENSORFLOW END friend Lowering::SILGenModule; friend Lowering::TypeConverter; class SerializationCallback; @@ -194,6 +203,17 @@ class SILModule { /// The list of SILDefaultWitnessTables in the module. DefaultWitnessTableListType defaultWitnessTables; + // SWIFT_ENABLE_TENSORFLOW + /// Lookup table for SIL differentiability witnesses from original functions. + /// Indexed by key type: original function, parameter indices, result indices, + /// and derivative generic signature. + llvm::DenseMap + DifferentiabilityWitnessMap; + + /// The list of SILDifferentiabilityWitnesses in the module. + DifferentiabilityWitnessListType differentiabilityWitnesses; + // SWIFT_ENABLE_TENSORFLOW END + /// Lookup table for SIL Global Variables. llvm::StringMap GlobalVariableMap; @@ -446,6 +466,27 @@ class SILModule { return {defaultWitnessTables.begin(), defaultWitnessTables.end()}; } + // SWIFT_ENABLE_TENSORFLOW + using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator; + using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator; + DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; } + const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; } + differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); } + differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); } + differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); } + differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); } + iterator_range + getDifferentiabilityWitnesses() { + return {differentiabilityWitnesses.begin(), + differentiabilityWitnesses.end()}; + } + iterator_range + getDifferentiabilityWitnesses() const { + return {differentiabilityWitnesses.begin(), + differentiabilityWitnesses.end()}; + } + // SWIFT_ENABLE_TENSORFLOW END + using sil_global_iterator = GlobalListType::iterator; using sil_global_const_iterator = GlobalListType::const_iterator; GlobalListType &getSILGlobalList() { return silGlobals; } diff --git a/include/swift/Serialization/SerializedSILLoader.h b/include/swift/Serialization/SerializedSILLoader.h index 41a7b794e6f66..955d3732ea69d 100644 --- a/include/swift/Serialization/SerializedSILLoader.h +++ b/include/swift/Serialization/SerializedSILLoader.h @@ -13,6 +13,9 @@ #ifndef SWIFT_SERIALIZATION_SILLOADER_H #define SWIFT_SERIALIZATION_SILLOADER_H +// SWIFT_ENABLE_TENSORFLOW +#include "swift/AST/AutoDiff.h" +// SWIFT_ENABLE_TENSORFLOW END #include "swift/AST/Decl.h" #include "swift/AST/Identifier.h" #include "swift/SIL/Notifications.h" @@ -32,6 +35,9 @@ class SILModule; class SILVTable; class SILWitnessTable; class SILDefaultWitnessTable; +// SWIFT_ENABLE_TENSORFLOW +class SILDifferentiabilityWitness; +// SWIFT_ENABLE_TENSORFLOW END /// Maintains a list of SILDeserializer, one for each serialized modules /// in ASTContext. It provides lookupSILFunction that will perform lookup @@ -64,6 +70,10 @@ class SerializedSILLoader { SILVTable *lookupVTable(const ClassDecl *C); SILWitnessTable *lookupWitnessTable(SILWitnessTable *C); SILDefaultWitnessTable *lookupDefaultWitnessTable(SILDefaultWitnessTable *C); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + lookupDifferentiabilityWitness(SILDifferentiabilityWitnessKey key); + // SWIFT_ENABLE_TENSORFLOW END /// Invalidate the cached entries for deserialized SILFunctions. void invalidateCaches(); @@ -99,6 +109,11 @@ class SerializedSILLoader { /// Deserialize all Properties in all SILModules. void getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + /// Deserialize all DifferentiabilityWitnesses in all SILModules. + void getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END + SerializedSILLoader(const SerializedSILLoader &) = delete; SerializedSILLoader(SerializedSILLoader &&) = delete; SerializedSILLoader &operator=(const SerializedSILLoader &) = delete; diff --git a/include/swift/Syntax/TokenKinds.def.gyb b/include/swift/Syntax/TokenKinds.def.gyb index 4627055832e9c..ee20521c099fe 100644 --- a/include/swift/Syntax/TokenKinds.def.gyb +++ b/include/swift/Syntax/TokenKinds.def.gyb @@ -165,6 +165,9 @@ SIL_KEYWORD(sil_vtable) SIL_KEYWORD(sil_global) SIL_KEYWORD(sil_witness_table) SIL_KEYWORD(sil_default_witness_table) +// SWIFT_ENABLE_TENSORFLOW +SIL_KEYWORD(sil_differentiability_witness) +// SWIFT_ENABLE_TENSORFLOW END SIL_KEYWORD(sil_coverage_map) SIL_KEYWORD(sil_scope) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 47d65c3809dba..27e9c342b5cf6 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -452,6 +452,7 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) /// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations. llvm::FoldingSet AutoDiffDerivativeFunctionIdentifiers; + // SWIFT_ENABLE_TENSORFLOW END /// A cache of information about whether particular nominal types /// are representable in a foreign language. diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index efd460f158243..2e542a3cb66d7 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -379,11 +379,12 @@ std::string ASTMangler::mangleReabstractionThunkHelper( return finalize(); } +// SWIFT_ENABLE_TENSORFLOW std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper( StringRef name, AutoDiffDerivativeFunctionKind kind, const SILAutoDiffIndices &indices) { // TODO(TF-20): Make the mangling scheme robust. - // TODO(TF-680): Mangle `@differentiable` atttribute requirements as well. + // TODO(TF-680): Mangle derivative generic signature as well. beginManglingWithoutPrefix(); Buffer << "AD__" << name << '_'; @@ -406,7 +407,7 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper( StringRef name, AutoDiffLinearMapKind kind, const SILAutoDiffIndices &indices) { // TODO(TF-20): Make the mangling scheme robust. - // TODO(TF-680): Mangle `@differentiable` atttribute requirements as well. + // TODO(TF-680): Mangle derivative generic signature as well. beginManglingWithoutPrefix(); Buffer << "AD__" << name << '_'; @@ -425,6 +426,28 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper( return result; } +std::string ASTMangler::mangleSILDifferentiabilityWitnessKey( + SILDifferentiabilityWitnessKey key) { + // TODO(TF-20): Make the mangling scheme robust. + beginManglingWithoutPrefix(); + + auto originalName = key.first; + auto *parameterIndices = key.second.parameterIndices; + auto *resultIndices = key.second.resultIndices; + auto *derivativeGenericSignature = key.second.derivativeGenericSignature; + + Buffer << "AD__" << originalName << '_'; + Buffer << "P" << parameterIndices->getString(); + Buffer << "R" << resultIndices->getString(); + if (derivativeGenericSignature) + appendGenericSignature(derivativeGenericSignature); + + auto result = Storage.str().str(); + Storage.clear(); + return result; +} +// SWIFT_ENABLE_TENSORFLOW END + std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) { PrettyStackTraceType prettyStackTrace(Ty->getASTContext(), "mangling type for debugger", Ty); diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 7eff6979dc668..19bf0db1d7e51 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -94,6 +94,9 @@ bool Parser::parseTopLevel() { CASE_SIL(sil_global, SILGlobal) CASE_SIL(sil_witness_table, SILWitnessTable) CASE_SIL(sil_default_witness_table, SILDefaultWitnessTable) + // SWIFT_ENABLE_TENSORFLOW + CASE_SIL(sil_differentiability_witness, SILDifferentiabilityWitness) + // SWIFT_ENABLE_TENSORFLOW END CASE_SIL(sil_coverage_map, SILCoverageMap) CASE_SIL(sil_property, SILProperty) CASE_SIL(sil_scope, SILScope) diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 281b2836b450e..a89de420a10f8 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -77,6 +77,9 @@ class SILParserTUState : public SILParserTUStateBase { bool parseSILGlobal(Parser &P) override; bool parseSILWitnessTable(Parser &P) override; bool parseSILDefaultWitnessTable(Parser &P) override; + // SWIFT_ENABLE_TENSORFLOW + bool parseSILDifferentiabilityWitness(Parser &P) override; + // SWIFT_ENABLE_TENSORFLOW END bool parseSILCoverageMap(Parser &P) override; bool parseSILProperty(Parser &P) override; bool parseSILScope(Parser &P) override; @@ -6742,6 +6745,252 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) { return false; } +// SWIFT_ENABLE_TENSORFLOW +// TODO(TF-893): Dedupe with `SILParser::convertRequirements` upstream. +// Currently, this utility is defined on `SILParser`, but SIL differentiability +// witness is defined on `SILParserTUState` and only has access to `Parser`. +// Consider redefining `SILParser::convertRequirements`as +// `Parser::convertRequirements`. +static void convertRequirements(Parser &P, SILFunction *F, + ArrayRef From, + SmallVectorImpl &To) { + if (From.empty()) { + To.clear(); + return; + } + + auto *GenericEnv = F->getLoweredFunctionType() + ->getGenericSignature() + ->getGenericEnvironment(); + assert(GenericEnv); + (void)GenericEnv; + + IdentTypeReprLookup PerformLookup(P); + // Use parser lexical scopes to resolve references + // to the generic parameters. + auto ResolveToInterfaceType = [&](TypeLoc Ty) -> Type { + Ty.getTypeRepr()->walk(PerformLookup); + swift::performTypeLocChecking(P.Context, Ty, /*isSILMode*/ true, + /*isSILType*/ true, GenericEnv, &P.SF); + assert(Ty.getType()); + return Ty.getType()->mapTypeOutOfContext(); + }; + + for (auto &Req : From) { + if (Req.getKind() == RequirementReprKind::SameType) { + auto FirstType = ResolveToInterfaceType(Req.getFirstTypeLoc()); + auto SecondType = ResolveToInterfaceType(Req.getSecondTypeLoc()); + Requirement ConvertedRequirement(RequirementKind::SameType, FirstType, + SecondType); + To.push_back(ConvertedRequirement); + continue; + } + + if (Req.getKind() == RequirementReprKind::TypeConstraint) { + auto Subject = ResolveToInterfaceType(Req.getSubjectLoc()); + auto Constraint = ResolveToInterfaceType(Req.getConstraintLoc()); + Requirement ConvertedRequirement(RequirementKind::Conformance, Subject, + Constraint); + To.push_back(ConvertedRequirement); + continue; + } + + if (Req.getKind() == RequirementReprKind::LayoutConstraint) { + auto Subject = ResolveToInterfaceType(Req.getSubjectLoc()); + Requirement ConvertedRequirement(RequirementKind::Layout, Subject, + Req.getLayoutConstraint()); + To.push_back(ConvertedRequirement); + continue; + } + llvm_unreachable("Unsupported requirement kind"); + } +} + +/// decl-sil-differentiability-witness ::= +/// 'sil_differentiability_witness' +/// ('[' 'serialized' ']')? +/// sil-linkage? +/// '[' 'parameters' index-subset ']' +/// '[' 'results' index-subset ']' +/// ('[' 'where' derivatve-generic-signature-requirements ']')? +/// sil-function-name ':' sil-type +/// '{' +/// ('jvp' sil-function-name ':' sil-type)? +/// ('vjp' sil-function-name ':' sil-type)? +/// '}' +/// +/// index-subset ::= +/// [0-9]+ (' ' [0-9]+)* +bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) { + P.consumeToken(tok::kw_sil_differentiability_witness); + SILParser State(P); + + // Parse the linkage. + Optional linkage; + if (parseSILLinkage(linkage, P)) + return true; + // Default to public linkage. + if (!linkage) + linkage = SILLinkage::Public; + + // Parse '[serialized]' flag (optional). + bool isSerialized = false; + if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && + P.peekToken().getText() == "serialized") { + isSerialized = true; + P.consumeToken(tok::l_square); + P.consumeToken(tok::identifier); + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) + return true; + } + + Scope scope(&P, ScopeKind::TopLevel); + Scope body(&P, ScopeKind::FunctionBody); + + // Parse a SIL function name. + auto parseFunctionName = [&](SILFunction *&fn) -> bool { + Identifier name; + SILType ty; + SourceLoc fnNameLoc = P.Tok.getLoc(); + // We need to turn on InSILBody to parse the function reference. + Lexer::SILBodyRAII tmp(*P.L); + GenericEnvironment *ignoredEnv; + if ((State.parseGlobalName(name)) || + P.parseToken(tok::colon, diag::expected_sil_colon_value_ref) || + State.parseSILType(ty, ignoredEnv, /*IsFuncDecl*/ true)) + return true; + + // The function doesn't exist yet. Create a zombie forward declaration. + auto fnType = ty.getAs(); + if (!fnType || !ty.isObject()) { + P.diagnose(fnNameLoc, diag::expected_sil_function_type); + return true; + } + fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true); + State.TUState.PotentialZombieFns.insert(fn); + return false; + }; + + SourceLoc lastLoc = P.getEndOfPreviousLoc(); + // Parse an index subset, prefaced with the given label. + auto parseIndexSubset = + [&](StringRef label, IndexSubset *& indexSubset) -> bool { + if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_token, "[")) + return true; + if (P.parseSpecificIdentifier( + label, diag::sil_diff_witness_expected_token, label)) + return true; + // Parse parameter index list. + SmallVector paramIndices; + // Function that parses an index into `paramIndices`. Returns true on error. + auto parseParam = [&]() -> bool { + unsigned index; + // TODO: Reject non-ascending index lists. + if (P.parseUnsignedInteger(index, lastLoc, + diag::sil_diff_witness_expected_index_list)) + return true; + paramIndices.push_back(index); + return false; + }; + // Parse first. + if (parseParam()) + return true; + // Parse rest. + while (P.Tok.isNot(tok::r_square)) + if (parseParam()) + return true; + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) + return true; + auto maxIndexRef = + std::max_element(paramIndices.begin(), paramIndices.end()); + indexSubset = IndexSubset::get( + P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices); + return false; + }; + // Parse parameter and result indices. + IndexSubset *parameterIndices = nullptr; + IndexSubset *resultIndices = nullptr; + if (parseIndexSubset("parameters", parameterIndices)) + return true; + if (parseIndexSubset("results", resultIndices)) + return true; + + // Parse a trailing 'where' clause (optional). + // This represents derivative generic signature requirements. + GenericSignature *derivativeGenSig = nullptr; + SourceLoc whereLoc; + SmallVector derivativeRequirementReprs; + if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::kw_where)) { + P.consumeToken(tok::l_square); + bool firstTypeInComplete; + P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs, + firstTypeInComplete, + /*AllowLayoutConstraints*/ false); + if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]")) + return true; + } + + // Parse original function name. + SILFunction *originalFn; + if (parseFunctionName(originalFn)) + return true; + + // Resolve derivative requirements. + if (!derivativeRequirementReprs.empty()) { + SmallVector requirements; + auto *whereClause = TrailingWhereClause::create( + originalFn->getModule().getASTContext(), whereLoc, + derivativeRequirementReprs); + convertRequirements(P, originalFn, whereClause->getRequirements(), + requirements); + assert(requirements.size() == derivativeRequirementReprs.size()); + derivativeGenSig = evaluateOrDefault( + P.Context.evaluator, + AbstractGenericSignatureRequest{ + originalFn->getLoweredFunctionType()->getGenericSignature(), + /*addedGenericParams=*/{}, + std::move(requirements)}, + nullptr); + } + + // Parse differentiability witness body. + SILFunction *jvp = nullptr; + SILFunction *vjp = nullptr; + if (P.Tok.is(tok::l_brace)) { + // Parse '{'. + SourceLoc lBraceLoc; + P.consumeIf(tok::l_brace, lBraceLoc); + // Parse JVP (optional). + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") { + P.consumeToken(tok::identifier); + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":")) + return true; + Scope body(&P, ScopeKind::FunctionBody); + if (parseFunctionName(jvp)) + return true; + } + // Parse VJP (optional). + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") { + P.consumeToken(tok::identifier); + if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":")) + return true; + Scope body(&P, ScopeKind::FunctionBody); + if (parseFunctionName(vjp)) + return true; + } + // Parse '}'. + if (P.parseMatchingToken(tok::r_brace, lastLoc, diag::expected_sil_rbrace, + lBraceLoc)) + return true; + } + + SILDifferentiabilityWitness::create( + M, *linkage, originalFn, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); + return false; +} +// SWIFT_ENABLE_TENSORFLOW END + llvm::Optional SILParser::parseSILCoverageExpr( llvm::coverage::CounterExpressionBuilder &Builder) { if (P.Tok.is(tok::integer_literal)) { diff --git a/lib/SIL/CMakeLists.txt b/lib/SIL/CMakeLists.txt index 2657fcfdfb29c..b65a01c26f683 100644 --- a/lib/SIL/CMakeLists.txt +++ b/lib/SIL/CMakeLists.txt @@ -26,6 +26,8 @@ add_swift_host_library(swiftSIL STATIC SILDebugScope.cpp SILDeclRef.cpp SILDefaultWitnessTable.cpp + # SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness.cpp SILFunction.cpp SILFunctionType.cpp SILGlobalVariable.cpp diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp new file mode 100644 index 0000000000000..36cf10e532b94 --- /dev/null +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -0,0 +1,40 @@ +//===--- SILDifferentiabilityWitness.cpp - Differentiability witnesses ----===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "sil-differentiability-witness" + +#include "swift/SIL/SILDifferentiabilityWitness.h" +#include "swift/SIL/SILModule.h" + +using namespace swift; + +SILDifferentiabilityWitness *SILDifferentiabilityWitness::create( + SILModule &module, SILLinkage linkage, SILFunction *originalFunction, + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp, + bool isSerialized, DeclAttribute *attribute) { + auto *diffWitness = new (module) SILDifferentiabilityWitness( + module, linkage, originalFunction, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized, attribute); + // Register the differentiability witness in the module. + assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) && + "Cannot create duplicate differentiability witness in a module"); + module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness; + module.getDifferentiabilityWitnessList().push_back(diffWitness); + return diffWitness; +} + +SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { + AutoDiffConfig config{parameterIndices, resultIndices, + derivativeGenericSignature}; + return std::make_pair(originalFunction->getName(), config); +} diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index dd830e41cfbab..a350176642887 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2695,6 +2695,33 @@ printSILDefaultWitnessTables(SILPrintContext &Ctx, wt->print(Ctx.OS(), Ctx.printVerbose()); } +// SWIFT_ENABLE_TENSORFLOW +static void printSILDifferentiabilityWitnesses( + SILPrintContext &Ctx, + const SILModule::DifferentiabilityWitnessListType &diffWitnesses) { + if (!Ctx.sortSIL()) { + for (auto &dw : diffWitnesses) + dw.print(Ctx.OS(), Ctx.printVerbose()); + return; + } + + std::vector sortedDiffWitnesses; + sortedDiffWitnesses.reserve(diffWitnesses.size()); + for (auto &dw : diffWitnesses) + sortedDiffWitnesses.push_back(&dw); + std::sort(sortedDiffWitnesses.begin(), sortedDiffWitnesses.end(), + [] (const SILDifferentiabilityWitness *w1, + const SILDifferentiabilityWitness *w2) -> bool { + // TODO(TF-893): Sort based on more criteria for deterministic ordering. + return w1->getOriginalFunction()->getName() + .compare(w2->getOriginalFunction()->getName()); + } + ); + for (auto *dw : sortedDiffWitnesses) + dw->print(Ctx.OS(), Ctx.printVerbose()); +} +// SWIFT_ENABLE_TENSORFLOW END + static void printSILCoverageMaps(SILPrintContext &Ctx, const SILModule::CoverageMapCollectionType &CoverageMaps) { @@ -2812,6 +2839,10 @@ void SILModule::print(SILPrintContext &PrintCtx, ModuleDecl *M, printSILVTables(PrintCtx, getVTableList()); printSILWitnessTables(PrintCtx, getWitnessTableList()); printSILDefaultWitnessTables(PrintCtx, getDefaultWitnessTableList()); + // SWIFT_ENABLE_TENSORFLOW + printSILDifferentiabilityWitnesses(PrintCtx, + getDifferentiabilityWitnessList()); + // SWIFT_ENABLE_TENSORFLOW END printSILCoverageMaps(PrintCtx, getCoverageMaps()); printSILProperties(PrintCtx, getPropertyList()); @@ -3026,6 +3057,74 @@ void SILDefaultWitnessTable::dump() const { print(llvm::errs()); } +// SWIFT_ENABLE_TENSORFLOW +void SILDifferentiabilityWitness::print( + llvm::raw_ostream &OS, bool verbose) const { + OS << "// differentiability witness for " + << demangleSymbol(originalFunction->getName()) << '\n'; + PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType(); + // sil_differentiability_witness (linkage)? + OS << "sil_differentiability_witness "; + printLinkage(OS, linkage, ForDefinition); + // ([serialized])? + if (isSerialized()) + OS << "[serialized] "; + // [parameters ...] + OS << "[parameters "; + interleave(getParameterIndices()->getIndices(), + [&](unsigned index) { OS << index; }, + [&] { OS << ' '; }); + // [results ...] + OS << "] [results "; + interleave(getResultIndices()->getIndices(), + [&](unsigned index) { OS << index; }, + [&] { OS << ' '; }); + OS << ']'; + // ([where ...])? + if (auto *derivativeGenSig = getDerivativeGenericSignature()) { + ArrayRef requirements; + SmallVector requirementsScratch; + auto *origGenEnv = originalFunction->getGenericEnvironment(); + if (derivativeGenSig) { + if (origGenEnv) { + requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy( + origGenEnv->getGenericSignature()); + requirements = requirementsScratch; + } else { + requirements = derivativeGenSig->getRequirements(); + } + } + if (!requirements.empty()) { + OS << " [where "; + auto subPrinter = PrintOptions::printSIL(); + interleave(requirements, + [&](Requirement req) { + req.print(OS, subPrinter); + }, + [&] { OS << ", "; }); + OS << ']'; + } + } + // @original-function-name : $original-sil-type + OS << " @" << originalFunction->getName() << " : " + << originalFunction->getLoweredType(); + // { + // jvp: @jvp-function-name : $jvp-sil-type + // vjp: @vjp-function-name : $vjp-sil-type + // } + OS << " {\n"; + if (jvp) + OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n'; + if (vjp) + OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n'; + OS << "}\n\n"; +} + +void SILDifferentiabilityWitness::dump() const { + print(llvm::errs()); +} +// SWIFT_ENABLE_TENSORFLOW END + void SILCoverageMap::print(SILPrintContext &PrintCtx) const { llvm::raw_ostream &OS = PrintCtx.OS(); OS << "sil_coverage_map " << QuotedString(getFile()) << " " diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 42ae01866d9bc..f290cc4810f45 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -5350,6 +5350,56 @@ void SILGlobalVariable::verify() const { } } +// SWIFT_ENABLE_TENSORFLOW +/// Verify that a differentiability witness follows invariants. +void SILDifferentiabilityWitness::verify(const SILModule &M) const { +#ifdef NDEBUG + if (!M.getOptions().VerifyAll) + return; +#endif + auto origFnType = originalFunction->getLoweredFunctionType(); + CanGenericSignature derivativeCanGenSig; + if (auto *derivativeGenSig = getDerivativeGenericSignature()) + derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); + auto requireSameType = + [&](CanSILFunctionType type1, CanSILFunctionType type2, + const Twine &complaint) { + if (type1 == type2) + return; + llvm::dbgs() << "SIL verification failed: " << complaint << "\n"; + llvm::dbgs() << " " << type1 << "\n " << type2 << "\n\n"; + llvm::dbgs() << "In differentiability witness:\n"; + print(llvm::dbgs()); + // We abort by default because we want to always crash in + // the debugger. + if (AbortOnFailure) + abort(); + else + exit(1); + }; + if (jvp) { + // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` + // to accept result indices. + auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType( + getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), + AutoDiffDerivativeFunctionKind::JVP, M.Types, + LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); + requireSameType(jvp->getLoweredFunctionType(), expectedJVPType, + "JVP type does not match expected JVP type"); + } + if (vjp) { + // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType` + // to result indices. + auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType( + getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(), + AutoDiffDerivativeFunctionKind::VJP, M.Types, + LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig); + requireSameType(vjp->getLoweredFunctionType(), expectedVJPType, + "VJP type does not match expected VJP type"); + } +} +// SWIFT_ENABLE_TENSORFLOW END + /// Verify the module. void SILModule::verify() const { #ifdef NDEBUG @@ -5433,6 +5483,22 @@ void SILModule::verify() const { } wt.verify(*this); } + + // SWIFT_ENABLE_TENSORFLOW + // Check all differentiability witnesses. + LLVM_DEBUG(llvm::dbgs() << + "*** Checking differentiability witnesses for duplicates ***\n"); + llvm::DenseSet diffWitnesses; + for (auto &dw : getDifferentiabilityWitnesses()) { + LLVM_DEBUG(llvm::dbgs() << "Differentiability Witness:\n"; dw.dump()); + if (!diffWitnesses.insert(dw.getKey()).second) { + llvm::errs() << "Differentiability witness redefined: "; + dw.dump(); + assert(false && "triggering standard assertion failure routine"); + } + dw.verify(*this); + } + // SWIFT_ENABLE_TENSORFLOW END // Check property descriptors. LLVM_DEBUG(llvm::dbgs() << "*** Checking property descriptors ***\n"); diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 7d293e673f8b2..7eb6c7d722b9f 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -164,9 +164,13 @@ SILDeserializer::SILDeserializer( kind == sil_index_block::SIL_GLOBALVAR_NAMES || kind == sil_index_block::SIL_WITNESS_TABLE_NAMES || kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES || - kind == sil_index_block::SIL_PROPERTY_OFFSETS)) && + kind == sil_index_block::SIL_PROPERTY_OFFSETS || +// SWIFT_ENABLE_TENSORFLOW + kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)) && "Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \ - SIL_WITNESS_TABLE_NAMES, or SIL_DEFAULT_WITNESS_TABLE_NAMES."); + SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \ + SIL_DIFFERENTIABILITY_WITNESS_NAMES, or SIL_PROPERTY_OFFSETS."); +// SWIFT_ENABLE_TENSORFLOW END (void)prevKind; if (kind == sil_index_block::SIL_FUNC_NAMES) @@ -179,6 +183,10 @@ SILDeserializer::SILDeserializer( WitnessTableList = readFuncTable(scratch, blobData); else if (kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES) DefaultWitnessTableList = readFuncTable(scratch, blobData); + // SWIFT_ENABLE_TENSORFLOW + else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) + DifferentiabilityWitnessList = readFuncTable(scratch, blobData); + // SWIFT_ENABLE_TENSORFLOW END else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) { // No matching 'names' block for property descriptors needed yet. MF->allocateBuffer(Properties, scratch); @@ -215,6 +223,14 @@ SILDeserializer::SILDeserializer( offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) && "Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record."); MF->allocateBuffer(DefaultWitnessTables, scratch); + // SWIFT_ENABLE_TENSORFLOW + } else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) { + assert((next.Kind == llvm::BitstreamEntry::Record && + offKind == + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) && + "Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record."); + MF->allocateBuffer(DifferentiabilityWitnesses, scratch); + // SWIFT_ENABLE_TENSORFLOW END } } } @@ -3280,6 +3296,99 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) { return Wt; } +// SWIFT_ENABLE_TENSORFLOW +SILDifferentiabilityWitness * +SILDeserializer::readDifferentiabilityWitness(DeclID DId) { + if (DId == 0) + return nullptr; + assert(DId <= DifferentiabilityWitnesses.size() && + "Invalid SILDifferentiabilityWitness ID"); + + auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId-1]; + if (diffWitnessOrOffset.isFullyDeserialized()) + return diffWitnessOrOffset.get(); + + BCOffsetRAII restoreOffset(SILCursor); + SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()); + auto entry = SILCursor.advance(AF_DontPopBlockAtEnd); + if (entry.Kind == llvm::BitstreamEntry::Error) { + LLVM_DEBUG(llvm::dbgs() + << "Cursor advance error in readDifferentiabilityWitness.\n"); + return nullptr; + } + + SmallVector scratch; + StringRef blobData; + unsigned kind = SILCursor.readRecord(entry.ID, scratch, &blobData); + assert(kind == SIL_DIFFERENTIABILITY_WITNESS && + "Expected sil_differentiability_witness"); + (void)kind; + + DeclID originalNameId, jvpNameId, vjpNameId; + unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices; + GenericSignatureID derivativeGenSigID; + ArrayRef rawParameterAndResultIndices; + + DifferentiabilityWitnessLayout::readRecord( + scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID, + jvpNameId, vjpNameId, numParameterIndices, numResultIndices, + rawParameterAndResultIndices); + + auto linkage = fromStableSILLinkage(rawLinkage); + assert(linkage && "Expected value linkage for sil_differentiability_witness"); + auto originalName = MF->getIdentifierText(originalNameId); + auto jvpName = MF->getIdentifierText(jvpNameId); + auto vjpName = MF->getIdentifierText(vjpNameId); + auto *original = getFuncForReference(originalName); + assert(original && "Original function must be found"); + auto *jvp = getFuncForReference(jvpName); + if (!jvpName.empty()) + assert(jvp && "JVP function must be found if JVP name is not empty"); + auto *vjp = getFuncForReference(vjpName); + if (!vjpName.empty()) + assert(vjp && "VJP function must be found if VJP name is not empty"); + auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID); + + SmallVector parameterAndResultIndices( + rawParameterAndResultIndices.begin(), + rawParameterAndResultIndices.end()); + assert(parameterAndResultIndices.size() == + numParameterIndices + numResultIndices && + "Parameter/result indices count mismatch"); + auto *parameterIndices = IndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumParameters(), + ArrayRef(parameterAndResultIndices) + .take_front(numParameterIndices)); + auto *resultIndices = IndexSubset::get( + MF->getContext(), original->getLoweredFunctionType()->getNumResults(), + ArrayRef(parameterAndResultIndices) + .take_back(numResultIndices)); + + auto *diffWitness = SILDifferentiabilityWitness::create( + SILMod, *linkage, original, parameterIndices, resultIndices, + derivativeGenSig, jvp, vjp, isSerialized); + diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); + return diffWitness; +} + +SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness( + StringRef mangledDiffWitnessKey) { + if (!DifferentiabilityWitnessList) + return nullptr; + auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey); + if (iter == DifferentiabilityWitnessList->end()) + return nullptr; + return readDifferentiabilityWitness(*iter); +} + +void SILDeserializer::getAllDifferentiabilityWitnesses() { + if (!DifferentiabilityWitnessList) + return; + for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I) + readDifferentiabilityWitness(I+1); +} +// SWIFT_ENABLE_TENSORFLOW END + SILDeserializer::~SILDeserializer() { // Drop our references to anything we've deserialized. for (auto &fnEntry : Funcs) { diff --git a/lib/Serialization/DeserializeSIL.h b/lib/Serialization/DeserializeSIL.h index c2790c01e48c2..76b118f30d0d6 100644 --- a/lib/Serialization/DeserializeSIL.h +++ b/lib/Serialization/DeserializeSIL.h @@ -58,6 +58,13 @@ namespace swift { MutableArrayRef> Properties; + // SWIFT_ENABLE_TENSORFLOW + std::unique_ptr DifferentiabilityWitnessList; + MutableArrayRef< + ModuleFile::PartiallySerialized> + DifferentiabilityWitnesses; + // SWIFT_ENABLE_TENSORFLOW END + /// A declaration will only llvm::DenseMap ConformanceToWitnessTableMap; @@ -126,6 +133,10 @@ namespace swift { SILDefaultWitnessTable * readDefaultWitnessTable(serialization::DeclID, SILDefaultWitnessTable *existingWt); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + readDifferentiabilityWitness(serialization::DeclID); + // SWIFT_ENABLE_TENSORFLOW END Optional readKeyPathComponent(ArrayRef ListOfValues, unsigned &nextValue); @@ -145,6 +156,10 @@ namespace swift { SILWitnessTable *lookupWitnessTable(SILWitnessTable *wt); SILDefaultWitnessTable * lookupDefaultWitnessTable(SILDefaultWitnessTable *wt); + // SWIFT_ENABLE_TENSORFLOW + SILDifferentiabilityWitness * + lookupDifferentiabilityWitness(StringRef mangledDiffWitnessKey); + // SWIFT_ENABLE_TENSORFLOW END /// Invalidate all cached SILFunctions. void invalidateFunctionCache(); @@ -169,6 +184,9 @@ namespace swift { getAllWitnessTables(); getAllDefaultWitnessTables(); getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END } /// Deserialize all SILFunctions inside the module and add them to SILMod. @@ -192,6 +210,12 @@ namespace swift { /// to SILMod. void getAllProperties(); + // SWIFT_ENABLE_TENSORFLOW + /// Deserialize all DifferentiabilityWitnesses inside the module and add + /// them to SILMod. + void getAllDifferentiabilityWitnesses(); + // SWIFT_ENABLE_TENSORFLOW END + SILDeserializer(ModuleFile *MF, SILModule &M, DeserializationNotificationHandlerSet *callback); diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index 9ee90b4e22e43..df862fcb5c607 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -122,6 +122,10 @@ namespace sil_index_block { SIL_DEFAULT_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_OFFSETS, SIL_PROPERTY_OFFSETS, + // SWIFT_ENABLE_TENSORFLOW + SIL_DIFFERENTIABILITY_WITNESS_NAMES, + SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, + // SWIFT_ENABLE_TENSORFLOW END }; using ListLayout = BCGenericRecordLayout< @@ -176,6 +180,8 @@ namespace sil_block { SIL_DIFFERENTIABLE_ATTR, SIL_INST_DIFFERENTIABLE_FUNCTION, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT, + SIL_DIFFERENTIABILITY_WITNESS, + // SWIFT_ENABLE_TENSORFLOW END // We also share these layouts from the decls block. Their enumerators must // not overlap with ours. @@ -280,6 +286,21 @@ namespace sil_block { DeclIDField >; + // SWIFT_ENABLE_TENSORFLOW + using DifferentiabilityWitnessLayout = BCRecordLayout< + SIL_DIFFERENTIABILITY_WITNESS, + DeclIDField, // Original function name + SILLinkageField, // Linkage + BCFixed<1>, // Is serialized? + GenericSignatureIDField, // Derivative function generic signature + DeclIDField, // JVP function name + DeclIDField, // VJP function name + BCVBR<8>, // Number of parameter indices + BCVBR<8>, // Number of result indices + BCArray // Parameter and result indices + >; + // SWIFT_ENABLE_TENSORFLOW END + using SILFunctionLayout = BCRecordLayout, // transparent diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index aac745d43601d..dcffb9d7848d6 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -826,6 +826,9 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_NAMES); BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_OFFSETS); BLOCK_RECORD(sil_index_block, SIL_PROPERTY_OFFSETS); + // SWIFT_ENABLE_TENSORFLOW + BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS); + // SWIFT_ENABLE_TENSORFLOW END #undef BLOCK #undef BLOCK_RECORD diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 67f6e8aeaba48..849e537638a38 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -183,6 +183,14 @@ namespace { /// Holds the list of Properties. std::vector PropertyOffset; + // SWIFT_ENABLE_TENSORFLOW + /// Maps differentiability witness identifier to an ID. + Table DifferentiabilityWitnessList; + /// Holds the list of SIL differentiability witnesses. + std::vector DifferentiabilityWitnessOffset; + uint32_t /*DeclID*/ NextDifferentiabilityWitnessID = 1; + // SWIFT_ENABLE_TENSORFLOW END + /// Give each SILBasicBlock a unique ID. llvm::DenseMap BasicBlockMap; @@ -232,6 +240,10 @@ namespace { void writeSILWitnessTable(const SILWitnessTable &wt); void writeSILWitnessTableEntry(const SILWitnessTable::Entry &entry); void writeSILDefaultWitnessTable(const SILDefaultWitnessTable &wt); + // SWIFT_ENABLE_TENSORFLOW + void writeSILDifferentiabilityWitness( + const SILDifferentiabilityWitness &dw); + // SWIFT_ENABLE_TENSORFLOW END void writeSILProperty(const SILProperty &prop); void writeSILBlock(const SILModule *SILMod); @@ -2215,7 +2227,13 @@ static void writeIndexTable(Serializer &S, kind == sil_index_block::SIL_VTABLE_NAMES || kind == sil_index_block::SIL_GLOBALVAR_NAMES || kind == sil_index_block::SIL_WITNESS_TABLE_NAMES || - kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES) && + kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES || + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-893): Update surrounding comment/assertion text when + // upstreaming code to master. Comments/assertions have not been + // updated to keep the code diff small. + kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) && + // SWIFT_ENABLE_TENSORFLOW END "SIL function table, global, vtable and (default) witness table " "are supported"); llvm::SmallString<4096> hashTableBlob; @@ -2273,11 +2291,22 @@ void SILSerializer::writeIndexTables() { DefaultWitnessTableOffset); } + // SWIFT_ENABLE_TENSORFLOW + if (!DifferentiabilityWitnessOffset.empty()) { + writeIndexTable(S, List, + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES, + DifferentiabilityWitnessList); + Offset.emit(ScratchRecord, + sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS, + DifferentiabilityWitnessOffset); + } + // SWIFT_ENABLE_TENSORFLOW END + if (!PropertyOffset.empty()) { Offset.emit(ScratchRecord, sil_index_block::SIL_PROPERTY_OFFSETS, PropertyOffset); } - + } void SILSerializer::writeSILGlobalVar(const SILGlobalVariable &g) { @@ -2469,6 +2498,58 @@ writeSILDefaultWitnessTable(const SILDefaultWitnessTable &wt) { } } +// SWIFT_ENABLE_TENSORFLOW +void SILSerializer:: +writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { + Mangle::ASTMangler mangler; + auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(dw.getKey()); + size_t nameLength = mangledKey.size(); + char *stringStorage = + static_cast(StringTable.Allocate(nameLength, 1)); + std::memcpy(stringStorage, mangledKey.data(), nameLength); + DifferentiabilityWitnessList[StringRef(stringStorage, nameLength)] = + NextDifferentiabilityWitnessID++; + DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo()); + + auto *original = dw.getOriginalFunction(); + addReferencedSILFunction(original, /*DeclOnly*/ true); + IdentifierID jvpID = 0; + IdentifierID vjpID = 0; + if (auto *jvp = dw.getJVP()) { + addReferencedSILFunction(jvp, /*DeclOnly*/ true); + jvpID = S.addUniquedStringRef(jvp->getName()); + } + if (auto *vjp = dw.getVJP()) { + addReferencedSILFunction(vjp, /*DeclOnly*/ true); + vjpID = S.addUniquedStringRef(vjp->getName()); + } + SmallVector parameterAndResultIndices( + dw.getParameterIndices()->begin(), dw.getParameterIndices()->end()); + parameterAndResultIndices.append(dw.getResultIndices()->begin(), + dw.getResultIndices()->end()); + auto originalFnType = original->getLoweredFunctionType(); + assert(originalFnType->getNumParameters() == + dw.getParameterIndices()->getCapacity() && + "Original function parameter count should match differentiability " + "witness parameter indices capacity"); + assert(originalFnType->getNumResults() == + dw.getResultIndices()->getCapacity() && + "Original function result count should match differentiability " + "witness result indices capacity"); + + DifferentiabilityWitnessLayout::emitRecord( + Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code], + S.addUniquedStringRef(original->getName()), + toStableSILLinkage(dw.getLinkage()), + dw.isSerialized(), + S.addGenericSignatureRef(dw.getDerivativeGenericSignature()), + jvpID, vjpID, + dw.getParameterIndices()->getNumIndices(), + dw.getResultIndices()->getNumIndices(), + parameterAndResultIndices); +} +// SWIFT_ENABLE_TENSORFLOW END + /// Helper function for whether to emit a function body. bool SILSerializer::shouldEmitFunctionBody(const SILFunction *F, bool isReference) { @@ -2529,6 +2610,9 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW + registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW END registerSILAbbr(); registerSILAbbr(); @@ -2537,6 +2621,7 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); + // SWIFT_ENABLE_TENSORFLOW END // Register the abbreviation codes so these layouts can exist in both // decl blocks and sil blocks. @@ -2589,6 +2674,17 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { writeSILDefaultWitnessTable(wt); } + // SWIFT_ENABLE_TENSORFLOW + // Write out differentiability witnesses. + for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) { + // TODO(TF-893): Consider checking + // `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP + // functions. + if ((ShouldSerializeAll || diffWitness.isSerialized())) + writeSILDifferentiabilityWitness(diffWitness); + } + // SWIFT_ENABLE_TENSORFLOW END + // Emit only declarations if it is a module with pre-specializations. // And only do it in optimized builds. bool emitDeclarationsForOnoneSupport = diff --git a/lib/Serialization/SerializedSILLoader.cpp b/lib/Serialization/SerializedSILLoader.cpp index 3e910c8e5324f..f9223c4fabf4c 100644 --- a/lib/Serialization/SerializedSILLoader.cpp +++ b/lib/Serialization/SerializedSILLoader.cpp @@ -127,6 +127,19 @@ lookupDefaultWitnessTable(SILDefaultWitnessTable *WT) { return nullptr; } +// SWIFT_ENABLE_TENSORFLOW +SILDifferentiabilityWitness * +SerializedSILLoader::lookupDifferentiabilityWitness( + SILDifferentiabilityWitnessKey key) { + Mangle::ASTMangler mangler; + std::string mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(key); + for (auto &Des : LoadedSILSections) + if (auto *diffWitness = Des->lookupDifferentiabilityWitness(mangledKey)) + return diffWitness; + return nullptr; +} +// SWIFT_ENABLE_TENSORFLOW END + void SerializedSILLoader::invalidateCaches() { for (auto &Des : LoadedSILSections) Des->invalidateFunctionCache(); @@ -185,3 +198,10 @@ void SerializedSILLoader::getAllProperties() { Des->getAllProperties(); } +// SWIFT_ENABLE_TENSORFLOW +/// Deserialize all DifferentiabilityWitnesses in all SILModules. +void SerializedSILLoader::getAllDifferentiabilityWitnesses() { + for (auto &Des : LoadedSILSections) + Des->getAllDifferentiabilityWitnesses(); +} +// SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Syntax/SyntaxSerialization.cpp.gyb b/lib/Syntax/SyntaxSerialization.cpp.gyb index 9c917abedc135..46f77413cbf22 100644 --- a/lib/Syntax/SyntaxSerialization.cpp.gyb +++ b/lib/Syntax/SyntaxSerialization.cpp.gyb @@ -48,6 +48,9 @@ uint8_t WrapperTypeTraits::numericValue(const tok &Value) { case tok::kw_sil_global: case tok::kw_sil_witness_table: case tok::kw_sil_default_witness_table: + // SWIFT_ENABLE_TENSORFLOW + case tok::kw_sil_differentiability_witness: + // SWIFT_ENABLE_TENSORFLOW END case tok::kw_sil_coverage_map: case tok::kw_sil_scope: case tok::sil_dollar: diff --git a/test/AutoDiff/sil_differentiability_witness.sil b/test/AutoDiff/sil_differentiability_witness.sil new file mode 100644 index 0000000000000..8f56d2480dcef --- /dev/null +++ b/test/AutoDiff/sil_differentiability_witness.sil @@ -0,0 +1,80 @@ +// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s + +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main +// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name main +// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck %s + +// Round-trip parsing/printing and serialization/deserialization test. +// NOTE: deserialization currently fails if public function bodies are removed +// so that they are only declarations. This may require investigation. + +sil_stage raw + +import Builtin +import Swift +import SwiftShims + +// Test public non-generic function. +// SIL differentiability witness: +// - Has public linkage (implicit). +// - Has no `where` clause. + +sil [ossa] @foo : $@convention(thin) (Float) -> Float { +bb0(%0 : $Float): + return %0 : $Float +} + +sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} + +sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} + +sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { + jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} + +// CHECK-LABEL: // differentiability witness for foo +// CHECK: sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { +// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: } + +// Test internal generic function. +// SIL differentiability witness: +// - Has hidden linkage. +// - Has `where` clause. + +sil hidden [ossa] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T { +bb0(%0 : $*T, %1 : $*T, %2 : $Float): + copy_addr %1 to [initialization] %0 : $*T + %void = tuple () + return %void : $() +} + +sil hidden @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector +} + +sil hidden @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) { +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) +} + +sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { + jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) + vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +} + +// CHECK-LABEL: // differentiability witness for generic +// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { +// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) +// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) +// CHECK: }