From 871f91698e41f974e393ec7c6199d794b9f532ab Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 17 Mar 2019 13:08:29 -0700 Subject: [PATCH 1/3] SILDifferentiableFunctionType --- include/swift/AST/ASTContext.h | 3 + include/swift/AST/Attr.def | 1 + include/swift/AST/Attr.h | 13 ++ include/swift/AST/AutoDiff.h | 46 ++-- include/swift/AST/DiagnosticsParse.def | 18 ++ include/swift/AST/DiagnosticsSema.def | 15 ++ include/swift/AST/KnownStdlibTypes.def | 3 + include/swift/AST/TypeMatcher.h | 2 + include/swift/AST/TypeNodes.def | 2 + include/swift/AST/TypeRepr.h | 50 +++++ include/swift/AST/TypeReprNodes.def | 2 + include/swift/AST/Types.h | 118 +++++++++- include/swift/Parse/Parser.h | 6 +- include/swift/SIL/SILInstruction.h | 12 +- .../Serialization/DeclTypeRecordNodes.def | 2 + include/swift/Serialization/ModuleFormat.h | 12 + lib/AST/ASTContext.cpp | 43 +++- lib/AST/ASTDumper.cpp | 31 +++ lib/AST/ASTMangler.cpp | 11 + lib/AST/ASTPrinter.cpp | 22 ++ lib/AST/ASTWalker.cpp | 11 + lib/AST/AutoDiff.cpp | 8 +- lib/AST/NameLookup.cpp | 2 + lib/AST/Type.cpp | 197 +++++++++++++++- lib/AST/TypeRepr.cpp | 27 +++ lib/AST/TypeWalker.cpp | 10 + lib/IRGen/GenDiffFunc.cpp | 210 +++++++++++++++--- lib/IRGen/GenFunc.cpp | 2 +- lib/IRGen/GenType.cpp | 4 + lib/IRGen/GenType.h | 4 +- lib/IRGen/IRGenDebugInfo.cpp | 2 + lib/IRGen/MetadataRequest.cpp | 25 +++ lib/Parse/ParseDecl.cpp | 42 +++- lib/Parse/ParseType.cpp | 55 +++++ lib/ParseSIL/ParseSIL.cpp | 13 +- lib/SIL/SILFunctionType.cpp | 13 -- lib/SIL/SILInstructions.cpp | 15 +- lib/SIL/TypeLowering.cpp | 2 + .../Mandatory/Differentiation.cpp | 25 +-- lib/Sema/TypeCheckType.cpp | 169 +++++++++++++- lib/Serialization/Deserialization.cpp | 46 ++++ lib/Serialization/Serialization.cpp | 25 +++ .../sil_differentiable_function_type.sil | 22 ++ 43 files changed, 1227 insertions(+), 114 deletions(-) create mode 100644 test/AutoDiff/sil_differentiable_function_type.sil diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 17090a2d1602e..e1d1899435a3e 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -494,6 +494,9 @@ class ASTContext final { /// has been imported. Otherwise, this returns null. StructDecl *getTensorDataTypeDecl() const; + /// Retrieve the type for Swift.AnyDerivative. + CanType getAnyDerivativeType() const; + /// Retrieve the type Swift.Never. CanType getNeverType() const; diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 8b8d5f5414ba9..0a760505e03e1 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -53,6 +53,7 @@ TYPE_ATTR(noescape) TYPE_ATTR(escaping) // SWIFT_ENABLE_TENSORFLOW TYPE_ATTR(differentiable) +TYPE_ATTR(sil_differentiable) TYPE_ATTR(autodiff) TYPE_ATTR(nondiff) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 032e70c0b3724..dc9793f8b9c04 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -68,6 +68,10 @@ class TypeAttributes { Optional convention = None; Optional conventionWitnessMethodProtocol = None; + // SWIFT_ENABLE_TENSORFLOW + Optional> + differentiabilityReprKindAndOrder = None; + // For an opened existential type, the known ID. Optional OpenedID; @@ -126,6 +130,15 @@ class TypeAttributes { bool hasConvention() const { return convention.hasValue(); } StringRef getConvention() const { return *convention; } + // SWIFT_ENABLE_TENSORFLOW + bool hasDifferentiabilityRepresentationKindAndOrder() const { + return differentiabilityReprKindAndOrder.hasValue(); + } + std::pair + getDifferentiabilityRepresentationKindAndOrder() const { + return *differentiabilityReprKindAndOrder; + } + bool hasOwnership() const { return getOwnership() != ReferenceOwnership::Strong; } diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 7a2522a617c8d..fcccc59cbe48b 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -301,7 +301,7 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode { SmallBitVector indicesBitVec(capacity, false); for (auto index : indices) indicesBitVec.set(index); - return AutoDiffIndexSubset::get(ctx, indicesBitVec); + return get(ctx, indicesBitVec); } static AutoDiffIndexSubset *getDefault(ASTContext &ctx, unsigned capacity, @@ -557,6 +557,31 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { } }; +/// The kind of ABI used to represent a differentiable function. +enum class DifferentiabilityRepresentationKind : unsigned { + /// The function is linear and is represented as a bundle of the original + /// function and its transpose. Its differential is the function itself. Its + /// pullback is its transpose. + /// + /// For original function `(T...) -> U`, there are a few typing invariants: + /// 1. T = T.TangentVector = T.CotangentVector + /// 2. U = U.TangentVector = U.CotangentVector + /// + /// |----------------------| + /// | Original | Transpose | + /// |----------------------| + Linear = 0, + + /// The function is represented as a bundle of the original function and + /// JVP functions at every order. JVP functions must be thin. + /// + /// 1 2 ... n + /// |----------------------------------------| + /// | Original | JVP@1 | JVP@2 | ... | JVP@n | + /// |----------------------------------------| + Normal = 1 +}; + /// Automatic differentiation utility namespace. namespace autodiff { @@ -606,8 +631,8 @@ class VectorSpace { Vector, /// A product of vector spaces as a tuple. Tuple, - /// A function type whose innermost result conforms to `AdditiveArithmetic`. - Function + /// An existential `AdditiveArithmetic` type. + Existential }; private: @@ -617,16 +642,12 @@ class VectorSpace { Type vectorType; // Tuple TupleType *tupleType; - // Function - AnyFunctionType *functionType; Value(Type vectorType) : vectorType(vectorType) {} Value(TupleType *tupleType) : tupleType(tupleType) {} - Value(AnyFunctionType *functionType) : functionType(functionType) {} } value; - VectorSpace(Kind kind, Value value) - : kind(kind), value(value) {} + VectorSpace(Kind kind, Value value) : kind(kind), value(value) {} public: VectorSpace() = delete; @@ -637,12 +658,11 @@ class VectorSpace { static VectorSpace getTuple(TupleType *tupleTy) { return {Kind::Tuple, tupleTy}; } - static VectorSpace getFunction(AnyFunctionType *fnTy) { - return {Kind::Function, fnTy}; - } + static VectorSpace getExistential(ASTContext &ctx); bool isVector() const { return kind == Kind::Vector; } bool isTuple() const { return kind == Kind::Tuple; } + bool isExistential() const { return kind == Kind::Existential; } Kind getKind() const { return kind; } Type getVector() const { @@ -653,10 +673,6 @@ class VectorSpace { assert(kind == Kind::Tuple); return value.tupleType; } - AnyFunctionType *getFunction() const { - assert(kind == Kind::Function); - return value.functionType; - } Type getType() const; CanType getCanonicalType() const; diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 60b3e1afa8017..584b6a08f1ee3 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1393,6 +1393,24 @@ ERROR(convention_attribute_witness_method_expected_colon,none, ERROR(convention_attribute_witness_method_expected_protocol,none, "expected protocol name in 'witness_method' 'convention' attribute", ()) +// sil_differentiable +ERROR(sil_differentiable_attribute_expected_lparen,none, + "expected '(' after 'sil_differentiable' attribute", ()) +ERROR(sil_differentiable_attribute_expected_max_order,none, + "expected a max differentiation order in 'sil_differentiable' attribute", ()) +ERROR(sil_differentiable_attribute_expected_rparen,none, + "expected ')' after convention name for 'sil_differentiable' attribute", ()) +ERROR(sil_differentiable_attribute_expected_lbrace,none, + "expected '{' in a '@sil_differentiable' type", ()) +ERROR(sil_differentiable_attribute_expected_differential,none, + "expected 'differential:'", ()) +ERROR(sil_differentiable_attribute_expected_pullback,none, + "expected 'pullback:' ", ()) +ERROR(sil_differentiable_attribute_expected_transpose,none, + "expected 'transpose:' ", ()) +ERROR(sil_differentiable_attribute_expected_rbrace,none, + "expected '}' to end '@sil_differentiable' type", ()) + // objc ERROR(attr_objc_missing_colon,none, "missing ':' after selector piece in @objc attribute", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index d7d1ee160f551..7953958a4bb70 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3793,6 +3793,21 @@ ERROR(sil_metatype_multiple_reprs,none, "metatypes in SIL can only be one of @thin, @thick, or @objc_metatype", ()) +// SWIFT_ENABLE_TENSORFLOW +// @sil_differentiable types +ERROR(sil_differentiable_attr_not_applicable,none, + "'@sil_differentiable' is not applicable to this type", ()) +ERROR(sil_differentiable_required_original_function_field,none, + "an original function type field is required in a '@sil_differentiable'", ()) +ERROR(sil_differentiable_required_field,none, + "a '%0' function type field is required in a '@sil_differentiable'", (StringRef)) +ERROR(sil_differentiable_fields_must_be_function_type,none, + "fields in a '@sil_differentiable' type must be function types", ()) +ERROR(sil_differentiable_invalid_field,none, + "invalid field for the specified '@sil_differentiable' representation kind", ()) +ERROR(sil_differentiable_field_cannot_be_generic,none, + "'@sil_differentiable' field type cannot be generic", ()) + //------------------------------------------------------------------------------ // MARK: @objc and @nonobjc //------------------------------------------------------------------------------ diff --git a/include/swift/AST/KnownStdlibTypes.def b/include/swift/AST/KnownStdlibTypes.def index d2a0122b89fb8..434d971abe7ae 100644 --- a/include/swift/AST/KnownStdlibTypes.def +++ b/include/swift/AST/KnownStdlibTypes.def @@ -84,4 +84,7 @@ KNOWN_STDLIB_TYPE_DECL(KeyedEncodingContainer, NominalTypeDecl, 1) KNOWN_STDLIB_TYPE_DECL(KeyedDecodingContainer, NominalTypeDecl, 1) KNOWN_STDLIB_TYPE_DECL(RangeReplaceableCollection, ProtocolDecl, 1) +// SWIFT_ENABLE_TENSORFLOW +KNOWN_STDLIB_TYPE_DECL(AnyDerivative, StructDecl, 0) + #undef KNOWN_STDLIB_TYPE_DECL diff --git a/include/swift/AST/TypeMatcher.h b/include/swift/AST/TypeMatcher.h index 28582a054a001..3a09bf8d14e64 100644 --- a/include/swift/AST/TypeMatcher.h +++ b/include/swift/AST/TypeMatcher.h @@ -239,6 +239,8 @@ class TypeMatcher { TRIVIAL_CASE(SILFunctionType) TRIVIAL_CASE(SILBlockStorageType) TRIVIAL_CASE(SILBoxType) + // SWIFT_ENABLE_TENSORFLOW + TRIVIAL_CASE(SILDifferentiableFunctionType) TRIVIAL_CASE(ProtocolCompositionType) bool visitLValueType(CanLValueType firstLValue, Type secondType, diff --git a/include/swift/AST/TypeNodes.def b/include/swift/AST/TypeNodes.def index 1e15e580f2ff9..349d771a620e8 100644 --- a/include/swift/AST/TypeNodes.def +++ b/include/swift/AST/TypeNodes.def @@ -148,6 +148,8 @@ ARTIFICIAL_TYPE(SILFunction, Type) ARTIFICIAL_TYPE(SILBlockStorage, Type) ARTIFICIAL_TYPE(SILBox, Type) ARTIFICIAL_TYPE(SILToken, Type) +// SWIFT_ENABLE_TENSORFLOW +ARTIFICIAL_TYPE(SILDifferentiableFunction, Type) TYPE(ProtocolComposition, Type) TYPE(LValue, Type) TYPE(InOut, Type) diff --git a/include/swift/AST/TypeRepr.h b/include/swift/AST/TypeRepr.h index dbeca8717497a..57b40f4cc1591 100644 --- a/include/swift/AST/TypeRepr.h +++ b/include/swift/AST/TypeRepr.h @@ -1150,6 +1150,8 @@ inline bool TypeRepr::isSimple() const { case TypeReprKind::InOut: case TypeReprKind::Composition: case TypeReprKind::OpaqueReturn: + // SWIFT_ENABLE_TENSORFLOW + case TypeReprKind::SILDifferentiableFunction: return false; case TypeReprKind::SimpleIdent: case TypeReprKind::GenericIdent: @@ -1170,6 +1172,54 @@ inline bool TypeRepr::isSimple() const { llvm_unreachable("bad TypeRepr kind"); } +// SWIFT_ENABLE_TENSORFLOW +class SILDifferentiableFunctionTypeRepr final : public TypeRepr { + GenericParamList *GenericParams; + GenericEnvironment *GenericEnv = nullptr; + TypeRepr *Original; + TypeRepr *Differential; + TypeRepr *Pullback; + TypeRepr *Transpose; + SourceRange Braces; + +public: + SILDifferentiableFunctionTypeRepr( + GenericParamList *genericParams, TypeRepr *original, + TypeRepr *differential, TypeRepr *pullback, TypeRepr *transpose, + SourceRange braces) + : TypeRepr(TypeReprKind::SILDifferentiableFunction), + GenericParams(genericParams), Original(original), + Differential(differential), Pullback(pullback), Transpose(transpose), + Braces(braces) {} + + GenericParamList *getGenericParams() const { return GenericParams; }; + GenericEnvironment *getGenericEnvironment() const { return GenericEnv; }; + void setGenericEnvironment(GenericEnvironment *env) { + assert(GenericEnv == nullptr); + GenericEnv = env; + } + TypeRepr *getOriginal() const { return Original; } + TypeRepr *getDifferential() const { return Differential; } + TypeRepr *getPullback() const { return Pullback; } + TypeRepr *getTranspose() const { return Transpose; } + + SourceRange getBraces() const { return Braces; } + + static bool classof(const TypeRepr *T) { + return T->getKind() == TypeReprKind::SILDifferentiableFunction; + } + + static bool classof(const SILDifferentiableFunctionTypeRepr *T) { + return true; + } + +private: + SourceLoc getStartLocImpl() const { return Braces.Start; } + SourceLoc getEndLocImpl() const { return Braces.End; } + void printImpl(ASTPrinter &Printer, const PrintOptions &Opts) const; + friend class TypeRepr; +}; + } // end namespace swift namespace llvm { diff --git a/include/swift/AST/TypeReprNodes.def b/include/swift/AST/TypeReprNodes.def index 653af0b565540..1913661bb0408 100644 --- a/include/swift/AST/TypeReprNodes.def +++ b/include/swift/AST/TypeReprNodes.def @@ -60,6 +60,8 @@ ABSTRACT_TYPEREPR(Specifier, TypeRepr) TYPEREPR(Owned, SpecifierTypeRepr) TYPEREPR(Fixed, TypeRepr) TYPEREPR(SILBox, TypeRepr) +// SWIFT_ENABLE_TENSORFLOW +TYPEREPR(SILDifferentiableFunction, TypeRepr) LAST_TYPEREPR(SILBox) #undef ABSTRACT_TYPEREPR diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 3057b27f99cd8..98147bfb49fd4 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4155,12 +4155,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, SILModule &module, LookupConformanceFn lookupConformance, CanGenericSignature whereClauseGenericSignature = nullptr); - /// Returns a bit vector that specifices which parameters you can + /// Returns an index subset that specifices which parameters you can /// differentiate with respect to for this differentiable function type. (e.g. - /// which parameters are not @nondiff). The function type must be - /// differentiable. + /// which parameters are not `@nondiff`). AutoDiffIndexSubset *getDifferentiationParameterIndices(); + /// Returns an index subset that specifices which results you can + /// differentiate for this differentiable function type. (e.g. which + /// parameters are not `@nondiff`). + AutoDiffIndexSubset *getDifferentiationResultIndices(); + /// If this is a @convention(witness_method) function with a class /// constrained self parameter, return the class constraint for the /// Self type. @@ -4387,6 +4391,112 @@ class SILTokenType final : public TypeBase { }; DEFINE_EMPTY_CAN_TYPE_WRAPPER(SILTokenType, Type) +// SWIFT_ENABLE_TENSORFLOW +class SILDifferentiableFunctionType; +typedef CanTypeWrapper + CanSILDifferentiableFunctionType; + +/// The SIL-only type for differentiable functions, which represent a bundle of +/// the original function and autodiff-asssociated functions. A SIL +/// differentiable function type stores 3 types: the original function type, the +/// type of the first-order differential and the type of the first-order +/// pullback. These types help keep track of the abstraction pattern of all +/// autodiff-associated functions. The generic signature of this function +/// redirects to that of the original function, which serves as the source of +/// truth. +class SILDifferentiableFunctionType final + : public TypeBase, public llvm::FoldingSetNode { + int maxOrder; + DifferentiabilityRepresentationKind representationKind; + CanGenericSignature genericSig; + AutoDiffIndexSubset *parameterIndices, *resultIndices; + CanSILFunctionType originalType, differentialType, pullbackType; + + SILDifferentiableFunctionType( + ASTContext &C, int maxOrder, DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, CanSILFunctionType originalType, + CanSILFunctionType differentialType, CanSILFunctionType pullbackType); + +public: + static CanSILDifferentiableFunctionType get( + ASTContext &C, int maxOrder, DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, CanSILFunctionType originalType, + CanSILFunctionType differentialType, CanSILFunctionType pullbackType); + + static CanSILDifferentiableFunctionType getLinear( + ASTContext &C, CanGenericSignature genericSig, + AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + CanSILFunctionType originalType, CanSILFunctionType transposeType); + + CanGenericSignature getGenericSignature() const { + return genericSig; + } + + /// Returns the maximum order that this function can be differentiated at. + /// `-1` if the function can be differentiated at any order. + int getMaxOrder() const { + return maxOrder; + } + + DifferentiabilityRepresentationKind getRepresentationKind() const { + return representationKind; + } + + AutoDiffIndexSubset *getParameterIndices() const { + return parameterIndices; + } + + AutoDiffIndexSubset *getResultIndices() const { + return resultIndices; + } + + CanSILFunctionType getOriginalFunctionType() const { + return originalType; + } + + // Returns the original function type that reflects differentiation parameter + // indices and result indices with '@nondiff' attributes on parameters and + // results. This is used primarily for printing. + CanSILFunctionType getOriginalFunctionTypeWithDifferentiabilityFlags(); + + CanSILFunctionType getDifferentialType() const { + return differentialType; + } + + CanSILFunctionType getPullbackType() const { + return pullbackType; + } + + CanSILFunctionType + getAssociatedFunctionType(AutoDiffAssociatedFunctionKind kind, + unsigned order) const; + + CanSILDifferentiableFunctionType getLinearTransposeType() const; + + static void Profile(llvm::FoldingSetNodeID &id, + int maxOrder, + DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, + AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, + CanSILFunctionType originalType, + CanSILFunctionType differentialType, + CanSILFunctionType pullbackType); + + void Profile(llvm::FoldingSetNodeID &id) { + return Profile(id, maxOrder, representationKind, genericSig, + parameterIndices, resultIndices, originalType, + differentialType, pullbackType); + } + + static bool classof(const TypeBase *T) { + return T->getKind() == TypeKind::SILDifferentiableFunction; + } +}; +DEFINE_EMPTY_CAN_TYPE_WRAPPER(SILDifferentiableFunctionType, Type) + /// A type with a special syntax that is always sugar for a library type. The /// library type may have multiple base types. For unary syntax sugar, see /// UnarySyntaxSugarType. @@ -5738,6 +5848,8 @@ inline bool TypeBase::hasSimpleTypeRepr() const { switch (getKind()) { case TypeKind::Function: case TypeKind::GenericFunction: + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: return false; case TypeKind::Metatype: diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index 4f35036513285..a98f8131462ad 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -1114,7 +1114,11 @@ class Parser { ParserResult parseSILBoxType(GenericParamList *generics, const TypeAttributes &attrs, Optional &GenericsScope); - + // SWIFT_ENABLE_TENSORFLOW + ParserResult parseSILDifferentiableFunctionType( + GenericParamList *generics, const TypeAttributes &attrs, + Optional &GenericsScope); + ParserResult parseTypeTupleBody(); ParserResult parseTypeArray(TypeRepr *Base); diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 5af3dd3f5030a..36bfaf2567181 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7766,6 +7766,9 @@ class AutoDiffFunctionInst final : AutoDiffIndexSubset *parameterIndices; /// The order of differentiation. unsigned differentiationOrder; + /// Whether this instruction produces a legacy '@differentiable' function. + /// When false, this instruction produces a `@sil_differentiable` type. + bool useNewSILDiffFuncType; /// The number of operands. The first operand is always the original function. /// The rest of operands determined by the order of differentiation and whether /// this is the new AD model or the legacy reverse-mode AD model. @@ -7775,7 +7778,8 @@ class AutoDiffFunctionInst final : AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, - ArrayRef associatedFunctions); + ArrayRef associatedFunctions, + bool useNewSILDiffFuncType = false); public: static AutoDiffFunctionInst *create(SILModule &module, @@ -7783,11 +7787,13 @@ class AutoDiffFunctionInst final : AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, - ArrayRef associatedFunctions); + ArrayRef associatedFunctions, + bool useNewSILDiffFuncType = false); static SILType getAutoDiffType(SILValue original, unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices); + AutoDiffIndexSubset *parameterIndices, + bool useNewSILDiffFuncType); /// Returns the original function. SILValue getOriginalFunction() const { return getAllOperands()[0].get(); } diff --git a/include/swift/Serialization/DeclTypeRecordNodes.def b/include/swift/Serialization/DeclTypeRecordNodes.def index 46d4b167a0608..eea645431b1b4 100644 --- a/include/swift/Serialization/DeclTypeRecordNodes.def +++ b/include/swift/Serialization/DeclTypeRecordNodes.def @@ -102,6 +102,8 @@ TYPE(REFERENCE_STORAGE) TYPE(UNBOUND_GENERIC) TYPE(OPTIONAL) TYPE(SIL_FUNCTION) +// SWIFT_ENABLE_TENSORFLOW +TYPE(SIL_DIFFERENTIABLE_FUNCTION) TYPE(DYNAMIC_SELF) TYPE(OPENED_EXISTENTIAL) TYPE(EXISTENTIAL_METATYPE) diff --git a/include/swift/Serialization/ModuleFormat.h b/include/swift/Serialization/ModuleFormat.h index 1646b861b0236..a805b62722995 100644 --- a/include/swift/Serialization/ModuleFormat.h +++ b/include/swift/Serialization/ModuleFormat.h @@ -878,6 +878,18 @@ namespace decls_block { // followed by error result type/convention // Optionally a protocol conformance (for witness_methods) >; + + // SWIFT_ENABLE_TENSORFLOW + using SILDifferentiableFunctionTypeLayout = BCRecordLayout< + SIL_DIFFERENTIABLE_FUNCTION_TYPE, + BCVBR<4>, // max order + BCVBR<4>, // representation kind + GenericSignatureIDField, // generic signature + TypeIDField, // original function type + TypeIDField, // differential type + TypeIDField, // pullback type + BCArray> // parameter indices and result indices + >; using SILBlockStorageTypeLayout = BCRecordLayout< SIL_BLOCK_STORAGE_TYPE, diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index cf5608c8abe60..f8180f72edeb4 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -382,6 +382,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL) llvm::FoldingSet SILFunctionTypes; llvm::DenseMap SILBlockStorageTypes; llvm::FoldingSet SILBoxTypes; + // SWIFT_ENABLE_TENSORFLOW + llvm::FoldingSet + SILDifferentiableFunctionTypes; llvm::DenseMap IntegerTypes; llvm::FoldingSet BuiltinVectorTypes; llvm::FoldingSet CompoundNames; @@ -828,6 +831,12 @@ StructDecl *ASTContext::getTensorDataTypeDecl() const { return nullptr; } +CanType ASTContext::getAnyDerivativeType() const { + if (auto *anyDerivativeDecl = getAnyDerivativeDecl()) + return anyDerivativeDecl->getDeclaredType()->getCanonicalType(); + return CanType(); +} + CanType ASTContext::getNeverType() const { auto neverDecl = getNeverDecl(); if (!neverDecl) @@ -3274,15 +3283,6 @@ SILFunctionType::SILFunctionType(GenericSignature *genericSig, ExtInfo ext, "Cannot return an @noescape function type"); } } - - // SWIFT_ENABLE_TENSORFLOW - // Make sure that NotDifferentiable parameters only exist on differentiable - // functions. - if (!ext.isDifferentiable()) - for (auto param : getParameters()) - assert(param.getDifferentiability() == - SILParameterDifferentiability::DifferentiableOrNotApplicable && - "non-differentiable function has NotDifferentiable parameter"); #endif } @@ -4413,6 +4413,31 @@ CanSILBoxType SILBoxType::get(CanType boxedType) { return get(boxedType->getASTContext(), layout, subMap); } +// SWIFT_ENABLE_TENSORFLOW +CanSILDifferentiableFunctionType SILDifferentiableFunctionType::get( + ASTContext &C, int maxOrder, DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, CanSILFunctionType originalType, + CanSILFunctionType differentialType, CanSILFunctionType pullbackType) { + void *insertPos = nullptr; + auto &types = C.getImpl().SILDifferentiableFunctionTypes; + llvm::FoldingSetNodeID id; + SILDifferentiableFunctionType::Profile(id, maxOrder, reprKind, genericSig, + parameterIndices, resultIndices, + originalType, differentialType, + pullbackType); + if (auto existing = types.FindNodeOrInsertPos(id, insertPos)) + return CanSILDifferentiableFunctionType(existing); + + auto newFn = new (C, AllocationArena::Permanent) + SILDifferentiableFunctionType(C, maxOrder, reprKind, genericSig, + parameterIndices, resultIndices, + originalType, differentialType, + pullbackType); + types.InsertNode(newFn, insertPos); + return CanSILDifferentiableFunctionType(newFn); +} + LayoutConstraint LayoutConstraint::getLayoutConstraint(LayoutConstraintKind Kind, ASTContext &C) { diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index b0dfb4e84e03d..e042ce372aea3 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -3614,6 +3614,37 @@ namespace { PrintWithColorRAII(OS, ParenthesisColor) << ')'; } + // SWIFT_ENABLE_TENSORFLOW + void visitSILDifferentiableFunctionType(SILDifferentiableFunctionType *T, + StringRef label) { + printCommon(label, "sil_differentiable_function_type"); + printField("max_order", T->getMaxOrder()); + StringRef reprKindStr; + switch (T->getRepresentationKind()) { + case DifferentiabilityRepresentationKind::Linear: + reprKindStr = "linear"; + break; + case DifferentiabilityRepresentationKind::Normal: + reprKindStr = "normal"; + break; + } + printField("representation_kind", reprKindStr); + auto getIndicesString = [&](AutoDiffIndexSubset *indices) { + std::string result = "("; + interleave(indices->getIndices(), [&result](unsigned index) { + result += llvm::utostr(index); + }, [&result] { result += ' '; }); + return result; + }; + printField("parameter_indices", + getIndicesString(T->getParameterIndices())); + printField("result_indices", getIndicesString(T->getResultIndices())); + printField("original", T->getOriginalFunctionType()->getString()); + printField("differential", T->getDifferentialType()->getString()); + printField("pullback", T->getPullbackType()->getString()); + PrintWithColorRAII(OS, ParenthesisColor) << ')'; + } + void visitArraySliceType(ArraySliceType *T, StringRef label) { printCommon(label, "array_slice_type"); printRec(T->getBaseType()); diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index d4e4bccd5f8d1..434756933de26 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -1099,6 +1099,17 @@ void ASTMangler::appendType(Type type, const ValueDecl *forDecl) { return; } + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: { + auto box = cast(tybase); + for (auto ty : {box->getOriginalFunctionType(), + box->getDifferentialType(), box->getPullbackType()}) { + appendType(ty); + } + appendOperator("df"); + return; + } + case TypeKind::SILBlockStorage: llvm_unreachable("should never be mangled"); } diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index 3c184db2fd2ca..a459b48b67b94 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -4071,6 +4071,28 @@ class TypePrinter : public TypeVisitor { if (totalResults != 1) Printer << ")"; } + // SWIFT_ENABLE_TENSORFLOW + void visitSILDifferentiableFunctionType(SILDifferentiableFunctionType *T) { + Printer << "@sil_differentiable("; + switch (T->getRepresentationKind()) { + case DifferentiabilityRepresentationKind::Linear: + Printer << "linear) {"; + visit(T->getOriginalFunctionTypeWithDifferentiabilityFlags()); + Printer << ", transpose: "; + visit(T->getPullbackType()); + break; + case DifferentiabilityRepresentationKind::Normal: + Printer << T->getMaxOrder() << ") {"; + visit(T->getOriginalFunctionType()); + Printer << ", differential: "; + visit(T->getDifferentialType()); + Printer << ", pullback: "; + visit(T->getDifferentialType()); + break; + } + Printer << '}'; + } + void visitSILBlockStorageType(SILBlockStorageType *T) { Printer << "@block_storage "; printWithParensIfNotSimple(T->getCaptureType()); diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index 8dd0ff12bbd95..c627b8e7ec6bc 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -1840,6 +1840,17 @@ bool Traversal::visitSILBoxTypeRepr(SILBoxTypeRepr *T) { return false; } +// SWIFT_ENABLE_TENSORFLOW +bool Traversal::visitSILDifferentiableFunctionTypeRepr( + SILDifferentiableFunctionTypeRepr *T) { + for (auto *repr : {T->getOriginal(), T->getDifferential(), + T->getPullback(), T->getTranspose()}) + if (repr) + if (doIt(repr)) + return true; + return false; +} + Expr *Expr::walk(ASTWalker &walker) { return Traversal(walker).doIt(this); } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index db5d790ffa5c4..225727715fd0a 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -345,14 +345,18 @@ void AutoDiffParameterIndicesBuilder::setParameters(unsigned lowerBound, parameters.set(lowerBound, upperBound); } +VectorSpace VectorSpace::getExistential(ASTContext &ctx) { + return {Kind::Existential, ctx.getAnyDerivativeType()}; +} + Type VectorSpace::getType() const { switch (kind) { case Kind::Vector: return value.vectorType; case Kind::Tuple: return value.tupleType; - case Kind::Function: - return value.functionType; + case Kind::Existential: + return value.vectorType; } } diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 126b1d2c3f797..a57607c71f7a7 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -1953,6 +1953,8 @@ directReferencesForTypeRepr(Evaluator &evaluator, case TypeReprKind::Protocol: case TypeReprKind::Shared: case TypeReprKind::SILBox: + // SWIFT_ENABLE_TENSORFLOW + case TypeReprKind::SILDifferentiableFunction: return { }; case TypeReprKind::OpaqueReturn: diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index f83edbc980380..6696c18cbb5f4 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -221,6 +221,8 @@ bool CanType::isReferenceTypeImpl(CanType type, bool functionsCount) { case TypeKind::BoundGenericEnum: case TypeKind::BoundGenericStruct: case TypeKind::SILToken: + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: #define REF_STORAGE(Name, ...) \ case TypeKind::Name##Storage: #include "swift/AST/ReferenceStorage.def" @@ -1118,6 +1120,8 @@ CanType TypeBase::computeCanonicalType() { case TypeKind::SILBox: case TypeKind::SILFunction: case TypeKind::SILToken: + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: llvm_unreachable("SIL-only types are always canonical!"); case TypeKind::ProtocolComposition: { @@ -3650,6 +3654,23 @@ case TypeKind::Id: #endif return base; } + + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: { + auto type = cast(base); + CanSILFunctionType fnTypes[3] = { + type->getOriginalFunctionType(), + type->getDifferentialType(), + type->getPullbackType() + }; + for (Type &fnType : fnTypes) + fnType = fnType.transformRec(fn); + return SILDifferentiableFunctionType::get( + type->getASTContext(), type->getMaxOrder(), + type->getRepresentationKind(), type->getGenericSignature(), + type->getParameterIndices(), type->getResultIndices(), + fnTypes[0], fnTypes[1], fnTypes[2]); + } case TypeKind::SILFunction: { auto fnTy = cast(base); @@ -4279,6 +4300,8 @@ ReferenceCounting TypeBase::getReferenceCounting() { case TypeKind::Function: case TypeKind::GenericFunction: case TypeKind::SILFunction: + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: case TypeKind::SILBlockStorage: case TypeKind::Error: case TypeKind::Unresolved: @@ -4357,6 +4380,165 @@ Type TypeBase::openAnyExistentialType(OpenedArchetypeType *&opened) { return opened; } +// SWIFT_ENABLE_TENSORFLOW +AutoDiffIndexSubset * +SILFunctionType::getDifferentiationParameterIndices() { + SmallBitVector indices(getNumParameters(), true); + for (auto valueAndIndex : enumerate(getParameters())) + if (valueAndIndex.value().getDifferentiability() == + SILParameterDifferentiability::NotDifferentiable) + indices.reset(valueAndIndex.index()); + return AutoDiffIndexSubset::get(getASTContext(), indices); +} + +AutoDiffIndexSubset * +SILFunctionType::getDifferentiationResultIndices() { + SmallBitVector indices(getNumResults(), true); + // TODO(rxwei): Add result differentiability and compute the correct result + // indices. + return AutoDiffIndexSubset::get(getASTContext(), indices); +} + +SILDifferentiableFunctionType::SILDifferentiableFunctionType( + ASTContext &C, int maxOrder, DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, CanSILFunctionType originalType, + CanSILFunctionType differentialType, CanSILFunctionType pullbackType) + : TypeBase(TypeKind::SILDifferentiableFunction, &C, + originalType->getRecursiveProperties()), + maxOrder(maxOrder), representationKind(reprKind), genericSig(genericSig), + parameterIndices(parameterIndices), resultIndices(resultIndices), + originalType(originalType), differentialType(differentialType), + pullbackType(pullbackType) { + assert(parameterIndices->getCapacity() == originalType->getNumParameters()); + assert(resultIndices->getCapacity() == originalType->getNumResults()); + assert(!originalType->getGenericSignature()); + assert(!differentialType->getGenericSignature()); + assert(!pullbackType->getGenericSignature()); +#ifndef NDEBUG + auto hasNondiff = [](CanSILFunctionType fnType) -> bool { + return llvm::any_of(fnType->getParameters(), + [&](const SILParameterInfo ¶m) { + // TODO(rxwei): Handle result differentiability. + return param.getDifferentiability() == + SILParameterDifferentiability::NotDifferentiable; + }); + }; + assert(!hasNondiff(originalType) && + "Original function type should not have '@nondiff'"); + assert(!hasNondiff(differentialType) && + "Differential function type should not have '@nondiff'"); + assert(!hasNondiff(originalType) && + "Differential function type should not have '@nondiff'"); + + switch (reprKind) { + case DifferentiabilityRepresentationKind::Linear: + assert(originalType->isEqual(differentialType) && + "Linear functions' differential must equal themselves"); + assert(maxOrder == -1 && + "'maxOrder' must be -1 if the function is linear"); + break; + case DifferentiabilityRepresentationKind::Normal: + break; + } +#endif +} + +CanSILDifferentiableFunctionType +SILDifferentiableFunctionType::getLinear( + ASTContext &C, CanGenericSignature genericSig, + AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices, + CanSILFunctionType originalType, CanSILFunctionType transposeType) { + return get(C, /*maxOrder*/ -1, DifferentiabilityRepresentationKind::Linear, + genericSig, parameterIndices, resultIndices, originalType, + /*differnetialType*/ originalType, + /*pullbackType*/ transposeType); +} + +CanSILFunctionType SILDifferentiableFunctionType:: +getOriginalFunctionTypeWithDifferentiabilityFlags() { + SmallVector parametersWithFlags; + for (auto indexAndParam : enumerate(originalType->getParameters())) { + if (parameterIndices->contains(indexAndParam.index())) { + assert(indexAndParam.value().getDifferentiability() == + SILParameterDifferentiability::DifferentiableOrNotApplicable && + "Original function type should not have '@nondiff' parameters"); + parametersWithFlags.push_back(indexAndParam.value()); + } else { + parametersWithFlags.push_back( + indexAndParam.value().getWithDifferentiability( + SILParameterDifferentiability::NotDifferentiable)); + } + } + // TODO(rxwei): Handle result differentiability. + return SILFunctionType::get(originalType->getGenericSignature(), + originalType->getExtInfo(), + originalType->getCoroutineKind(), + originalType->getCalleeConvention(), + parametersWithFlags, originalType->getYields(), + originalType->getResults(), + originalType->getOptionalErrorResult(), + getASTContext()); +} + +CanSILFunctionType +SILDifferentiableFunctionType::getAssociatedFunctionType( + AutoDiffAssociatedFunctionKind kind, unsigned order) const { + SmallVector resultsWithLinearMap( + originalType->getResults().begin(), originalType->getResults().end()); + auto getResultConventionForFunction = [&](CanSILFunctionType type) { + if (type->getExtInfo().getRepresentation() == + SILFunctionTypeRepresentation::Thin) + return ResultConvention::Unowned; + return ResultConvention::Owned; + }; + switch (kind) { + case AutoDiffAssociatedFunctionKind::JVP: + resultsWithLinearMap.push_back( + {differentialType, getResultConventionForFunction(differentialType)}); + break; + case AutoDiffAssociatedFunctionKind::VJP: { + resultsWithLinearMap.push_back( + {pullbackType, getResultConventionForFunction(pullbackType)}); + break; + } + } + return SILFunctionType::get(originalType->getGenericSignature(), + originalType->getExtInfo(), + originalType->getCoroutineKind(), + originalType->getCalleeConvention(), + originalType->getParameters(), + originalType->getYields(), resultsWithLinearMap, + originalType->getOptionalErrorResult(), + originalType->getASTContext()); +} + +CanSILDifferentiableFunctionType +SILDifferentiableFunctionType::getLinearTransposeType() const { + assert(representationKind == DifferentiabilityRepresentationKind::Linear); + return SILDifferentiableFunctionType::get(originalType->getASTContext(), + maxOrder, representationKind, + genericSig, parameterIndices, + resultIndices, originalType, + pullbackType, differentialType); +} + +void SILDifferentiableFunctionType::Profile( + llvm::FoldingSetNodeID &id, int maxOrder, + DifferentiabilityRepresentationKind reprKind, + CanGenericSignature genericSig, AutoDiffIndexSubset *parameterIndices, + AutoDiffIndexSubset *resultIndices, CanSILFunctionType originalType, + CanSILFunctionType differentialType, CanSILFunctionType pullbackType) { + id.AddInteger(maxOrder); + id.AddInteger((unsigned)reprKind); + id.AddPointer(genericSig.getPointer()); + id.AddPointer(parameterIndices); + id.AddPointer(resultIndices); + id.AddPointer(originalType.getPointer()); + id.AddPointer(differentialType.getPointer()); + id.AddPointer(pullbackType.getPointer()); +} + // SWIFT_ENABLE_TENSORFLOW // Makes a function with the same generic signature and extinfo as `copy`, but // with `params` parameters and `retTy` return type. @@ -4386,17 +4568,10 @@ Optional TypeBase::getAutoDiffAssociatedTangentSpace( return vs; }; - // Functions' tangent is the same function except the innermost return type - // being replaced by its tangent. - if (auto *fnTy = getAs()) { - auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedTangentSpace( - lookupConformance); - if (!resultSpace) - return cache(None); - return cache(VectorSpace::getFunction( - makeFunctionType(fnTy, fnTy->getParams(), resultSpace->getType(), - fnTy->getOptGenericSignature()))); - } + // Functions' tangent/cotangent is `AnyDerivative`. + // TODO: Change this to `AnyDerivative` when it is implemented. + if (is() || is()) + return VectorSpace::getExistential(ctx); // Tuples' tangent is a tuple of each element's Tangent. if (auto *tupleTy = getAs()) { diff --git a/lib/AST/TypeRepr.cpp b/lib/AST/TypeRepr.cpp index 56b7b930aae2e..fa6d34a16d27c 100644 --- a/lib/AST/TypeRepr.cpp +++ b/lib/AST/TypeRepr.cpp @@ -258,6 +258,13 @@ TypeRepr *CloneVisitor::visitSILBoxTypeRepr(SILBoxTypeRepr *type) { type->getArgumentRAngleLoc()); } +// SWIFT_ENABLE_TENSORFLOW +TypeRepr *CloneVisitor::visitSILDifferentiableFunctionTypeRepr( + SILDifferentiableFunctionTypeRepr *type) { + // TODO(rxwei): implement. + llvm_unreachable("Unimplemented"); +} + TypeRepr *CloneVisitor::visitOpaqueReturnTypeRepr(OpaqueReturnTypeRepr *type) { return new (Ctx) OpaqueReturnTypeRepr(type->getOpaqueLoc(), visit(type->getConstraint())); @@ -593,6 +600,26 @@ void SILBoxTypeRepr::printImpl(ASTPrinter &Printer, Printer.printKeyword("sil_box", Opts); } +// SWIFT_ENABLE_TENSORFLOW +void SILDifferentiableFunctionTypeRepr::printImpl( + ASTPrinter &Printer, const PrintOptions &Opts) const { + Printer << '{'; + printTypeRepr(getOriginal(), Printer, Opts); + if (auto *differential = getDifferential()) { + Printer << "differential: "; + printTypeRepr(differential, Printer, Opts); + } + if (auto *pullback = getPullback()) { + Printer << "pullback: "; + printTypeRepr(pullback, Printer, Opts); + } + if (auto *transpose = getTranspose()) { + Printer << "transpose: "; + printTypeRepr(transpose, Printer, Opts); + } + Printer << '}'; +} + // See swift/Basic/Statistic.h for declaration: this enables tracing // TypeReprs, is defined here to avoid too much layering violation / circular // linkage dependency. diff --git a/lib/AST/TypeWalker.cpp b/lib/AST/TypeWalker.cpp index ba70105096670..68eb517b272da 100644 --- a/lib/AST/TypeWalker.cpp +++ b/lib/AST/TypeWalker.cpp @@ -197,6 +197,16 @@ class Traversal : public TypeVisitor return false; } + // SWIFT_ENABLE_TENSORFLOW + bool visitSILDifferentiableFunctionType(SILDifferentiableFunctionType *ty) { + for (Type type : {ty->getOriginalFunctionType(), + ty->getDifferentialType(), ty->getPullbackType()}) { + if (type && doIt(type)) + return true; + } + return false; + } + public: explicit Traversal(TypeWalker &walker) : Walker(walker) {} diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index 01093f98f3483..99f2922aa59bc 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -32,22 +32,24 @@ using namespace swift; using namespace irgen; -using DiffFuncIndex = +using LegacyDiffFuncIndex = std::pair; namespace { -class DiffFuncFieldInfo final : public RecordField { +class LegacyDiffFuncFieldInfo final + : public RecordField { public: - DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type, - AutoDiffIndexSubset *parameterIndices) - : RecordField(type), Index(index), ParameterIndices(parameterIndices) {} - /// The field index. - const DiffFuncIndex Index; + const LegacyDiffFuncIndex Index; /// The parameter indices. AutoDiffIndexSubset *ParameterIndices; + LegacyDiffFuncFieldInfo(LegacyDiffFuncIndex index, + AutoDiffIndexSubset *ParameterIndices, + const TypeInfo &type) + : RecordField(type), Index(index), ParameterIndices(ParameterIndices) {} + std::string getFieldName() const { auto extractee = std::get<0>(Index); auto differentiationOrder = std::get<1>(Index); @@ -75,22 +77,24 @@ class DiffFuncFieldInfo final : public RecordField { } }; -class DiffFuncTypeInfo final - : public RecordTypeInfo { +class LegacyDiffFuncTypeInfo final + : public RecordTypeInfo { using super = - RecordTypeInfo; + RecordTypeInfo; public: - DiffFuncTypeInfo(ArrayRef fields, unsigned explosionSize, - llvm::Type *ty, Size size, SpareBitVector &&spareBits, - Alignment align, IsPOD_t isPOD, - IsFixedSize_t alwaysFixedSize) + LegacyDiffFuncTypeInfo(ArrayRef fields, + unsigned explosionSize, + llvm::Type *ty, Size size, SpareBitVector &&spareBits, + Alignment align, IsPOD_t isPOD, + IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, ty, size, std::move(spareBits), align, isPOD, alwaysFixedSize) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, - const DiffFuncFieldInfo &field) const { + const LegacyDiffFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } @@ -114,46 +118,47 @@ class DiffFuncTypeInfo final } }; -class DiffFuncTypeBuilder - : public RecordTypeBuilder { +class LegacyDiffFuncTypeBuilder + : public RecordTypeBuilder { SILFunctionType *origFnTy; AutoDiffIndexSubset *parameterIndices; public: - DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) + LegacyDiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) : RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiationParameterIndices()) { assert(fnTy->isDifferentiable()); } - TypeInfo *createFixed(ArrayRef fields, + TypeInfo *createFixed(ArrayRef fields, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } - DiffFuncTypeInfo *createLoadable(ArrayRef fields, - StructLayout &&layout, - unsigned explosionSize) { - return DiffFuncTypeInfo::create( + LegacyDiffFuncTypeInfo *createLoadable( + ArrayRef fields, StructLayout &&layout, + unsigned explosionSize) { + return LegacyDiffFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), layout.isAlwaysFixedSize()); } - TypeInfo *createNonFixed(ArrayRef fields, + TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } - DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field, - const TypeInfo &fieldTI) { - return DiffFuncFieldInfo(field, fieldTI, parameterIndices); + LegacyDiffFuncFieldInfo getFieldInfo( + unsigned index, LegacyDiffFuncIndex field, const TypeInfo &fieldTI) { + return LegacyDiffFuncFieldInfo(field, parameterIndices, fieldTI); } - SILType getType(DiffFuncIndex field) { + SILType getType(LegacyDiffFuncIndex field) { if (std::get<0>(field) == AutoDiffFunctionExtractInst::Extractee::Original) return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType()); auto differentiationOrder = std::get<1>(field); @@ -172,10 +177,10 @@ class DiffFuncTypeBuilder } // end anonymous namespace const TypeInfo * -TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) { +TypeConverter::convertLegacyDifferentiableFunctionType(SILFunctionType *type) { assert(type->isDifferentiable()); - DiffFuncTypeBuilder builder(IGM, type); - SmallVector fields; + LegacyDiffFuncTypeBuilder builder(IGM, type); + SmallVector fields; fields.push_back( std::make_pair(AutoDiffFunctionExtractInst::Extractee::Original, 0)); fields.push_back( @@ -184,3 +189,142 @@ TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) { std::make_pair(AutoDiffFunctionExtractInst::Extractee::VJP, 1)); return builder.layout(fields); } + +// New differnetiable function type. + +namespace { +class DiffFuncFieldInfo final : public RecordField { +public: + DiffFuncFieldInfo(unsigned index, const TypeInfo &type) + : RecordField(type), Index(index) {} + + /// The field index. + const unsigned Index; + + std::string getFieldName() const { + if (Index == 0) + return "original"; + return "assoc_" + llvm::utostr(Index); + } + + SILType getType(IRGenModule &IGM, SILType t) const { + auto diffFnTy = t.castTo(); + if (Index == 0) + return SILType::getPrimitiveObjectType( + diffFnTy->getOriginalFunctionType()); + switch (diffFnTy->getRepresentationKind()) { + case DifferentiabilityRepresentationKind::Normal: + assert((int)Index <= diffFnTy->getMaxOrder()); + return SILType::getPrimitiveObjectType( + diffFnTy->getAssociatedFunctionType( + AutoDiffAssociatedFunctionKind::JVP, Index)); + case DifferentiabilityRepresentationKind::Linear: + assert(Index == 1); + return SILType::getPrimitiveObjectType(diffFnTy->getPullbackType()); + } + } +}; + +class DiffFuncTypeInfo final + : public RecordTypeInfo { + using super = + RecordTypeInfo; + +public: + DiffFuncTypeInfo(ArrayRef fields, + unsigned explosionSize, + llvm::Type *ty, Size size, SpareBitVector &&spareBits, + Alignment align, IsPOD_t isPOD, + IsFixedSize_t alwaysFixedSize) + : super(fields, explosionSize, ty, size, std::move(spareBits), align, + isPOD, alwaysFixedSize) {} + + Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, + const DiffFuncFieldInfo &field) const { + return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); + } + + void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, + SILType T, bool isOutlined) const override { + llvm_unreachable("unexploded @differentiable function as argument?"); + } + + void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, + Size offset) const override { + for (auto &field : getFields()) { + auto fieldOffset = offset + field.getFixedByteOffset(); + cast(field.getTypeInfo()) + .addToAggLowering(IGM, lowering, fieldOffset); + } + } + + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { + return None; + } +}; + +class DiffFuncTypeBuilder + : public RecordTypeBuilder { + SILDifferentiableFunctionType *type; + + public: + DiffFuncTypeBuilder(IRGenModule &IGM, SILDifferentiableFunctionType *diffFnTy) + : RecordTypeBuilder(IGM), type(diffFnTy) { + } + + TypeInfo *createFixed(ArrayRef fields, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + DiffFuncTypeInfo *createLoadable( + ArrayRef fields, StructLayout &&layout, + unsigned explosionSize) { + return DiffFuncTypeInfo::create( + fields, explosionSize, layout.getType(), layout.getSize(), + std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), + layout.isAlwaysFixedSize()); + } + + TypeInfo *createNonFixed(ArrayRef fields, + FieldsAreABIAccessible_t fieldsAccessible, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + DiffFuncFieldInfo getFieldInfo(unsigned, unsigned index, + const TypeInfo &fieldTI) { + return DiffFuncFieldInfo(index, fieldTI); + } + + SILType getType(unsigned index) { + if (index == 0) + return SILType::getPrimitiveObjectType(type->getOriginalFunctionType()); + switch (type->getRepresentationKind()) { + case DifferentiabilityRepresentationKind::Normal: + assert((int)index <= type->getMaxOrder()); + return SILType::getPrimitiveObjectType( + type->getAssociatedFunctionType( + AutoDiffAssociatedFunctionKind::JVP, index)); + case DifferentiabilityRepresentationKind::Linear: + assert(index == 1); + return SILType::getPrimitiveObjectType(type->getPullbackType()); + } + } + + StructLayout performLayout(ArrayRef fieldTypes) { + return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, + LayoutStrategy::Universal, fieldTypes); + } +}; +} // end anonymous namespace + +const TypeInfo *TypeConverter::convertDifferentiableFunctionType( + SILDifferentiableFunctionType *type) { + DiffFuncTypeBuilder builder(IGM, type); + return builder.layout({0, 1}); +} diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index f439c3400abb3..28f4df14b9dd4 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -477,7 +477,7 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF, const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) { // SWIFT_ENABLE_TENSORFLOW if (T->isDifferentiable()) - return convertDifferentiableFunctionType(T); + return convertLegacyDifferentiableFunctionType(T); switch (T->getRepresentation()) { case SILFunctionType::Representation::Block: diff --git a/lib/IRGen/GenType.cpp b/lib/IRGen/GenType.cpp index f20b971f75385..dfaff00b5edeb 100644 --- a/lib/IRGen/GenType.cpp +++ b/lib/IRGen/GenType.cpp @@ -1820,6 +1820,10 @@ const TypeInfo *TypeConverter::convertType(CanType ty) { llvm_unreachable("AST FunctionTypes should be lowered by SILGen"); case TypeKind::SILFunction: return convertFunctionType(cast(ty)); + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: + return convertDifferentiableFunctionType( + cast(ty)); case TypeKind::Protocol: return convertProtocolType(cast(ty)); case TypeKind::ProtocolComposition: diff --git a/lib/IRGen/GenType.h b/lib/IRGen/GenType.h index 8243d26ef54bb..9ea106c75b73c 100644 --- a/lib/IRGen/GenType.h +++ b/lib/IRGen/GenType.h @@ -135,7 +135,9 @@ class TypeConverter { const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D); const TypeInfo *convertFunctionType(SILFunctionType *T); // SWIFT_ENABLE_TENSORFLOW - const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T); + const TypeInfo * + convertDifferentiableFunctionType(SILDifferentiableFunctionType *T); + const TypeInfo *convertLegacyDifferentiableFunctionType(SILFunctionType *T); const TypeInfo *convertBlockStorageType(SILBlockStorageType *T); const TypeInfo *convertBoxType(SILBoxType *T); const TypeInfo *convertArchetypeType(ArchetypeType *T); diff --git a/lib/IRGen/IRGenDebugInfo.cpp b/lib/IRGen/IRGenDebugInfo.cpp index 5eadfea4ace5c..0aefe20512c9c 100644 --- a/lib/IRGen/IRGenDebugInfo.cpp +++ b/lib/IRGen/IRGenDebugInfo.cpp @@ -1530,6 +1530,8 @@ class IRGenDebugInfoImpl : public IRGenDebugInfo { case TypeKind::SILBlockStorage: case TypeKind::SILBox: case TypeKind::SILToken: + // SWIFT_ENABLE_TENSORFLOW: + case TypeKind::SILDifferentiableFunction: case TypeKind::BuiltinUnsafeValueBuffer: LLVM_DEBUG(llvm::errs() << "Unhandled type: "; DbgTy.getType()->dump(); diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index c303e3886c492..08c74db67a00c 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1323,6 +1323,13 @@ namespace { llvm_unreachable("should not be asking for metadata of a lowered SIL " "function type--SILGen should have used the AST type"); } + // SWIFT_ENABLE_TENSORFLOW + MetadataResponse visitSILDifferentiableFunctionType( + CanSILDifferentiableFunctionType type, DynamicMetadataRequest request) { + llvm_unreachable("should not be asking for metadata of a lowered SIL " + "differentiable function type--SILGen should have used " + "the AST type"); + } MetadataResponse visitSILTokenType(CanSILTokenType type, DynamicMetadataRequest request) { llvm_unreachable("should not be asking for metadata of a SILToken type"); @@ -2072,6 +2079,17 @@ namespace { llvm_unreachable("Not a valid SILFunctionType."); } + // SWIFT_ENABLE_TENSORFLOW + llvm::Value *visitSILDifferentiableFunctionType( + CanSILDifferentiableFunctionType type, + DynamicMetadataRequest request) { + // All differentiable function types look like () -> (). + // FIXME: It'd be nice not to have to call through the runtime here. + return IGF.emitTypeMetadataRef( + CanFunctionType::get({}, type->getASTContext().TheEmptyTupleType), + request).getMetadata(); + } + llvm::Value *visitAnyMetatypeType(CanAnyMetatypeType type, DynamicMetadataRequest request) { @@ -2282,6 +2300,13 @@ namespace { llvm_unreachable("Not a valid SILFunctionType."); } + // SWIFT_ENABLE_TENSORFLOW + llvm::Value *visitSILDifferentiableFunctionType( + CanSILDifferentiableFunctionType type, DynamicMetadataRequest request) { + return emitFromValueWitnessTable( + CanFunctionType::get({}, type->getASTContext().TheEmptyTupleType)); + } + llvm::Value *visitAnyMetatypeType(CanAnyMetatypeType type, DynamicMetadataRequest request) { diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 14a88e64aaa15..5809382452ae2 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -2335,6 +2335,39 @@ bool Parser::parseTypeAttribute(TypeAttributes &Attributes, bool justChecking) { LPLoc); } + // SWIFT_ENABLE_TENSORFLOW + DifferentiabilityRepresentationKind differentiabilityReprKind; + unsigned maxDifferentiabilityOrder; + if (attr == TAK_sil_differentiable) { + SourceLoc LPLoc; + if (!consumeIfNotAtStartOfLine(tok::l_paren)) { + if (!justChecking) + diagnose(Tok, diag::sil_differentiable_attribute_expected_lparen); + return true; + } + SourceLoc parenParamLoc; + if (Tok.is(tok::identifier) && Tok.getText() == "linear") { + consumeToken(tok::identifier); + differentiabilityReprKind = DifferentiabilityRepresentationKind::Linear; + maxDifferentiabilityOrder = -1; + } else { + differentiabilityReprKind = DifferentiabilityRepresentationKind::Normal; + if (parseUnsignedInteger( + maxDifferentiabilityOrder, parenParamLoc, + diag::sil_differentiable_attribute_expected_max_order)) { + return true; + } + } + // Parse the ')'. We can't use parseMatchingToken if we're in + // just-checking mode. + if (justChecking && Tok.isNot(tok::r_paren)) + return true; + SourceLoc RPLoc; + parseMatchingToken(tok::r_paren, RPLoc, + diag::convention_attribute_expected_rparen, + LPLoc); + } + // In just-checking mode, we only need to consume the tokens, and we don't // want to do any other analysis. if (justChecking) @@ -2429,7 +2462,14 @@ bool Parser::parseTypeAttribute(TypeAttributes &Attributes, bool justChecking) { Attributes.convention = conventionName; Attributes.conventionWitnessMethodProtocol = witnessMethodProtocol; break; - + + // SWIFT_ENABLE_TENSORFLOW + // @sil_differentiable attribute. + case TAK_sil_differentiable: + Attributes.differentiabilityReprKindAndOrder = + {differentiabilityReprKind, maxDifferentiabilityOrder}; + break; + case TAK__opaqueReturnTypeOf: { // Parse the mangled decl name and index. auto beginLoc = Tok.getLoc(); diff --git a/lib/Parse/ParseType.cpp b/lib/Parse/ParseType.cpp index 735cb60c11024..2f49512396683 100644 --- a/lib/Parse/ParseType.cpp +++ b/lib/Parse/ParseType.cpp @@ -351,6 +351,55 @@ ParserResult Parser::parseSILBoxType(GenericParamList *generics, SourceLoc())); } +// SWIFT_ENABLE_TENSORFLOW +ParserResult Parser::parseSILDifferentiableFunctionType( + GenericParamList *generics, const TypeAttributes &attrs, + Optional &GenericsScope) { + SyntaxParsingContext TypeParsingContext(SyntaxContext, + SyntaxContextKind::Type); + SourceLoc lBraceLoc, rBraceLoc; + TypeRepr *originalType = nullptr; + constexpr unsigned numAssocFns = 3; + std::array assocFnLabels( + {"differential", "pullback", "transpose"}); + std::array assocFnTypes({nullptr, nullptr, nullptr}); + + if (parseToken(tok::l_brace, lBraceLoc, + diag::sil_differentiable_attribute_expected_lbrace)) + return makeParserError(); + + auto originalTypeParseResult = parseType(); + if (originalTypeParseResult.isParseError()) + return makeParserError(); + originalType = originalTypeParseResult.get(); + + for (auto i : range(numAssocFns)) { + if (Tok.isNot(tok::comma) || peekToken().isNot(tok::identifier) || + peekToken().getText() != assocFnLabels[i]) + continue; + consumeToken(tok::comma); + consumeToken(tok::identifier); + if (parseToken(tok::colon, diag::expected_colon_after_label, + assocFnLabels[i])) + return makeParserError(); + auto parseResult = parseType(); + if (parseResult.isParseError()) + return makeParserError(); + assocFnTypes[i] = parseResult.get(); + } + + if (parseToken(tok::r_brace, rBraceLoc, + diag::sil_differentiable_attribute_expected_rbrace)) + return makeParserError(); + + auto *diffFnType = new (Context) SILDifferentiableFunctionTypeRepr( + generics, originalType, std::get<0>(assocFnTypes), + std::get<1>(assocFnTypes), std::get<2>(assocFnTypes), + SourceRange(lBraceLoc, rBraceLoc)); + return makeParserResult(applyAttributeToType(diffFnType, attrs, + VarDecl::Specifier::Default, + SourceLoc())); +} /// parseType /// type: @@ -383,6 +432,12 @@ ParserResult Parser::parseType(Diag<> MessageID, GenericsScope.emplace(this, ScopeKind::Generics); generics = maybeParseGenericParams().getPtrOrNull(); } + + // SWIFT_ENABLE_TENSORFLOW + // In SIL mode, parse differentiable function type. + if (isInSILMode() && attrs.has(TAK_sil_differentiable)) { + return parseSILDifferentiableFunctionType(generics, attrs, GenericsScope); + } // In SIL mode, parse box types { ... }. if (isInSILMode() && Tok.is(tok::l_brace)) { diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 4689851473b9e..8a12de0fea998 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -1374,6 +1374,13 @@ bool SILParser::parseSILType(SILType &Result, boxType->setGenericEnvironment(env); } } + // SWIFT_ENABLE_TENSORFLOW + if (auto diffFnType = dyn_cast(T)) { + if (auto generics = diffFnType->getGenericParams()) { + auto env = handleSILGenericParams(C, generics, SF); + diffFnType->setGenericEnvironment(env); + } + } return true; } }; @@ -1385,7 +1392,11 @@ bool SILParser::parseSILType(SILType &Result, if (auto fnType = dyn_cast(TyR.get())) if (auto env = fnType->getGenericEnvironment()) ParsedGenericEnv = env; - + // SWIFT_ENABLE_TENSORFLOW + if (auto diffFnType = dyn_cast(TyR.get())) + if (auto env = diffFnType->getGenericEnvironment()) + ParsedGenericEnv = env; + // Apply attributes to the type. TypeLoc Ty = P.applyAttributeToType(TyR.get(), attrs, specifier, specifierLoc); diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index dc722921f19a4..d3ed5d6945704 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -98,17 +98,6 @@ CanType SILFunctionType::getSelfInstanceType() const { } // SWIFT_ENABLE_TENSORFLOW -AutoDiffIndexSubset * -SILFunctionType::getDifferentiationParameterIndices() { - assert(isDifferentiable()); - SmallVector result; - for (auto valueAndIndex : enumerate(getParameters())) - if (valueAndIndex.value().getDifferentiability() != - SILParameterDifferentiability::NotDifferentiable) - result.push_back(valueAndIndex.index()); - return AutoDiffIndexSubset::get(getASTContext(), getNumParameters(), result); -} - CanSILFunctionType SILFunctionType::getWithDifferentiability( unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices) { // FIXME(rxwei): Handle differentiation order. @@ -133,8 +122,6 @@ CanSILFunctionType SILFunctionType::getWithDifferentiability( } CanSILFunctionType SILFunctionType::getWithoutDifferentiability() { - if (!isDifferentiable()) - return CanSILFunctionType(this); auto nondiffExtInfo = getExtInfo().withDifferentiable(false); SmallVector newParams; for (auto ¶m : getParameters()) diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index b4f9e0e228e8e..491a47242e0bd 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -569,7 +569,8 @@ TryApplyInst *TryApplyInst::create( SILType AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices) { + AutoDiffIndexSubset *parameterIndices, + bool useNewSILDiffFuncType) { auto fnTy = originalFunction->getType().castTo(); auto diffTy = fnTy->getWithDifferentiability(differentiationOrder, parameterIndices); @@ -579,28 +580,32 @@ AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, AutoDiffFunctionInst::AutoDiffFunctionInst( SILModule &module, SILDebugLocation debugLoc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, - SILValue originalFunction, ArrayRef associatedFunctions) + SILValue originalFunction, ArrayRef associatedFunctions, + bool useNewSILDiffFuncType) : InstructionBaseWithTrailingOperands( originalFunction, associatedFunctions, debugLoc, getAutoDiffType(originalFunction, differentiationOrder, - parameterIndices), + parameterIndices, useNewSILDiffFuncType), originalFunction.getOwnershipKind()), parameterIndices(parameterIndices), differentiationOrder(differentiationOrder), + useNewSILDiffFuncType(useNewSILDiffFuncType), numOperands(1 + associatedFunctions.size()) {} AutoDiffFunctionInst *AutoDiffFunctionInst::create( SILModule &module, SILDebugLocation debugLoc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, - ArrayRef associatedFunctions) { + ArrayRef associatedFunctions, + bool useNewSILDiffFuncType) { size_t size = totalSizeToAlloc(associatedFunctions.size() + 1); void *buffer = module.allocateInst(size, alignof(AutoDiffFunctionInst)); return ::new (buffer) AutoDiffFunctionInst(module, debugLoc, parameterIndices, differentiationOrder, originalFunction, - associatedFunctions); + associatedFunctions, + useNewSILDiffFuncType); } std::pair AutoDiffFunctionInst:: diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 0cd7316945105..28db2cd501edf 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -222,6 +222,8 @@ namespace { IMPL(BuiltinUnknownObject, Reference) IMPL(BuiltinVector, Trivial) IMPL(SILToken, Trivial) + // SWIFT_ENABLE_TENSORFLOW + IMPL(SILDifferentiableFunction, Trivial) IMPL(Class, Reference) IMPL(BoundGenericClass, Reference) IMPL(AnyMetatype, Trivial) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index c19db14ec985d..fd77aadfad575 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -5342,9 +5342,8 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, } return; } - case VectorSpace::Kind::Function: { - llvm_unreachable( - "Unimplemented: Emit thunks for abstracting zero initialization"); + case VectorSpace::Kind::Existential: { + llvm_unreachable("Unimplemented: 'AnyDerivative.zero'"); } } } @@ -5522,9 +5521,9 @@ SILValue AdjointEmitter::accumulateDirect(SILValue lhs, SILValue rhs) { } return builder.createTuple(loc, adjointTy, adjElements); } - case VectorSpace::Kind::Function: { + case VectorSpace::Kind::Existential: { llvm_unreachable( - "Unimplemented: Emit thunks for abstracting adjoint accumulation"); + "Unimplemented: 'AnyDerivative' accumulation"); } } } @@ -5583,9 +5582,8 @@ void AdjointEmitter::accumulateIndirect( } return; } - case VectorSpace::Kind::Function: { - llvm_unreachable( - "Unimplemented: Emit thunks for abstracting adjoint accumulation"); + case VectorSpace::Kind::Existential: { + llvm_unreachable("Unimplemented: 'AnyDerivative' accumulation"); } } } @@ -5637,9 +5635,9 @@ void AdjointEmitter::accumulateIndirect(SILValue lhsDestAccess, } return; } - case VectorSpace::Kind::Function: { + case VectorSpace::Kind::Existential: { llvm_unreachable( - "Unimplemented: Emit thunks for abstracting adjoint accumulation"); + "Unimplemented: 'AnyDerivative' accumulation"); } } } @@ -5982,10 +5980,11 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( case VectorSpace::Kind::Tuple: { llvm_unreachable( "Unimplemented: Handle zero initialization for tuples"); - } - case VectorSpace::Kind::Function: + case VectorSpace::Kind::Existential: { llvm_unreachable( - "Unimplemented: Emit thunks for abstracting zero initialization"); + "Unimplemented: Handle zero initialization for AnyDerivative"); + } + } } }; diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 285b26c8cc96d..105a8a3238a10 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -1867,6 +1867,10 @@ namespace { TypeResolutionOptions options); Type resolveSILBoxType(SILBoxTypeRepr *repr, TypeResolutionOptions options); + // SWIFT_ENABLE_TENSORFLOW + Type resolveSILDifferentiableFunctionType( + SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options, + DifferentiabilityRepresentationKind reprKind, unsigned maxOrder); Type buildMetatypeType(MetatypeTypeRepr *repr, Type instanceType, @@ -1958,6 +1962,11 @@ Type TypeResolver::resolveType(TypeRepr *repr, TypeResolutionOptions options) { assert((options & TypeResolutionFlags::SILType) && "SILBox repr in non-SIL type context?!"); return resolveSILBoxType(cast(repr), options); + // SWIFT_ENABLE_TENSORFLOW + case TypeReprKind::SILDifferentiableFunction: + llvm_unreachable("SILDifferentiableFunction is always attributed with " + "'@sil_differentiable'"); + case TypeReprKind::Array: return resolveArrayType(cast(repr), options); @@ -2144,10 +2153,26 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs, TAK_callee_guaranteed, TAK_noescape, TAK_yield_once, - TAK_yield_many}) { + TAK_yield_many, + // SWIFT_ENABLE_TENSORFLOW + TAK_sil_differentiable}) { checkUnsupportedAttr(silOnlyAttr); } - } + } + + // SWIFT_ENABLE_TENSORFLOW + if (attrs.has(TAK_sil_differentiable)) { + auto *diffFnTy = dyn_cast(repr); + if (!diffFnTy) { + diagnose(attrs.getLoc(TAK_sil_differentiable), + diag::sil_differentiable_attr_not_applicable); + return Type(); + } + auto reprKindAndOrder = + attrs.getDifferentiabilityRepresentationKindAndOrder(); + return resolveSILDifferentiableFunctionType( + diffFnTy, options, reprKindAndOrder.first, reprKindAndOrder.second); + } // Other function representation attributes are not normally supported at // source level, but we want to support them there in SIL files. @@ -2718,6 +2743,146 @@ Type TypeResolver::resolveSILBoxType(SILBoxTypeRepr *repr, return SILBoxType::get(Context, layout, subMap); } +Type TypeResolver::resolveSILDifferentiableFunctionType( + SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options, + DifferentiabilityRepresentationKind reprKind, unsigned maxOrder) { + // Resolve generic params using the function's generic environment, if it + // has one. + Optional resolveSILFunctionGenericParams; + Optional> useSILFunctionGenericEnv; + CanGenericSignature genericSig; + if (auto *genEnv = repr->getGenericEnvironment()) { + genericSig = genEnv->getGenericSignature()->getCanonicalSignature(); + resolveSILFunctionGenericParams = TypeResolution::forContextual(DC, genEnv); + useSILFunctionGenericEnv.emplace(resolution, + *resolveSILFunctionGenericParams); + } + + // Resolve the original type. + auto originalTypeRepr = repr->getOriginal(); + if (!originalTypeRepr) { + diagnose(repr->getLoc(), + diag::sil_differentiable_required_original_function_field); + return ErrorType::get(Context); + } + auto originalType = resolveType(repr->getOriginal(), options); + if (!originalType || originalType->hasError()) + return ErrorType::get(Context); + auto originalSILFnType = originalType->getAs(); + if (!originalSILFnType) { + diagnose(originalTypeRepr->getLoc(), + diag::sil_differentiable_fields_must_be_function_type); + return ErrorType::get(Context); + } + if (originalSILFnType->getGenericSignature()) { + diagnose(originalTypeRepr->getLoc(), + diag::sil_differentiable_field_cannot_be_generic); + return ErrorType::get(Context); + } + auto *parameterIndices = + originalSILFnType->getDifferentiationParameterIndices(); + auto *resultIndices = + originalSILFnType->getDifferentiationResultIndices(); + + auto checkAndDiagnoseInvalidField = [this, repr]( + TypeRepr *fieldRepr, StringRef fieldName) -> TypeRepr * { + if (fieldRepr) + return fieldRepr; + diagnose(repr->getLoc(), + diag::sil_differentiable_required_field, fieldName); + return nullptr; + }; + + switch (reprKind) { + case DifferentiabilityRepresentationKind::Normal: { + auto *differentialTypeRepr = + checkAndDiagnoseInvalidField(repr->getDifferential(), "differential:"); + auto *pullbackTypeRepr = + checkAndDiagnoseInvalidField(repr->getPullback(), "pullback:"); + if (!differentialTypeRepr || !pullbackTypeRepr) + return ErrorType::get(Context); + // The type should not specify transpose. + if (auto *transposeRepr = repr->getTranspose()) { + diagnose(transposeRepr->getLoc(), diag::sil_differentiable_invalid_field); + return ErrorType::get(Context); + } + + // Resolve differential type. + auto differentialType = resolveType(differentialTypeRepr, options); + if (!differentialType || differentialType->hasError()) + return ErrorType::get(Context); + auto *differentialSILFnType = differentialType->getAs(); + if (!differentialSILFnType) { + diagnose(differentialTypeRepr->getLoc(), + diag::sil_differentiable_fields_must_be_function_type); + return ErrorType::get(Context); + } + if (differentialSILFnType->getGenericSignature()) { + diagnose(differentialTypeRepr->getLoc(), + diag::sil_differentiable_field_cannot_be_generic); + return ErrorType::get(Context); + } + // Resolve pullback type. + auto pullbackType = resolveType(pullbackTypeRepr, options); + if (!pullbackType || pullbackType->hasError()) + return ErrorType::get(Context); + auto *pullbackSILFnType = pullbackType->getAs(); + if (!pullbackSILFnType) { + diagnose(pullbackTypeRepr->getLoc(), + diag::sil_differentiable_fields_must_be_function_type); + return ErrorType::get(Context); + } + if (pullbackSILFnType->getGenericSignature()) { + diagnose(pullbackTypeRepr->getLoc(), + diag::sil_differentiable_field_cannot_be_generic); + return ErrorType::get(Context); + } + + return SILDifferentiableFunctionType::get( + Context, maxOrder, DifferentiabilityRepresentationKind::Normal, + genericSig, parameterIndices, resultIndices, + originalSILFnType->getWithoutDifferentiability(), + CanSILFunctionType(differentialSILFnType), + CanSILFunctionType(pullbackSILFnType)); + } + + case DifferentiabilityRepresentationKind::Linear: { + auto *transposeTypeRepr = + checkAndDiagnoseInvalidField(repr->getTranspose(), "transpose:"); + if (!transposeTypeRepr) + return ErrorType::get(Context); + // The type should not specify differential or pullback. + for (auto *nonapplicableRepr : + {repr->getDifferential(), repr->getPullback()}) { + if (!nonapplicableRepr) + continue; + diagnose(nonapplicableRepr->getLoc(), + diag::sil_differentiable_invalid_field); + return ErrorType::get(Context); + } + // Resolve transpose type. + auto transposeType = resolveType(transposeTypeRepr, options); + if (!transposeType || transposeType->hasError()) + return ErrorType::get(Context); + auto *transposeSILFnType = transposeType->getAs(); + if (!transposeSILFnType) { + diagnose(transposeTypeRepr->getLoc(), + diag::sil_differentiable_fields_must_be_function_type); + return ErrorType::get(Context); + } + if (transposeSILFnType->getGenericSignature()) { + diagnose(transposeTypeRepr->getLoc(), + diag::sil_differentiable_field_cannot_be_generic); + return ErrorType::get(Context); + } + return SILDifferentiableFunctionType::getLinear( + Context, genericSig, parameterIndices, resultIndices, + originalSILFnType->getWithoutDifferentiability(), + CanSILFunctionType(transposeSILFnType)); + } + } +} + Type TypeResolver::resolveSILFunctionType(FunctionTypeRepr *repr, TypeResolutionOptions options, SILCoroutineKind coroutineKind, diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 958d15ef2404a..297aa77488cf8 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -5252,6 +5252,52 @@ class swift::TypeDeserializer { errorResult, ctx, witnessMethodConformance); } + // SWIFT_ENABLE_TENSORFLOW + Expected + deserializeSILDifferentiableFunctionType(ArrayRef scratch, + StringRef blobData) { + unsigned maxOrder, representationKind; + GenericSignatureID genericSigID; + TypeID originalTypeID, differentialTypeID, pullbackTypeID; + ArrayRef parameterAndResultIndices; + decls_block::SILDifferentiableFunctionTypeLayout::readRecord( + scratch, maxOrder, genericSigID, representationKind, originalTypeID, + differentialTypeID, pullbackTypeID, parameterAndResultIndices); + auto *genericSig = MF.getGenericSignature(genericSigID); + auto originalType = MF.getTypeChecked(originalTypeID); + if (!originalType) + return originalType.takeError(); + auto differentialType = MF.getTypeChecked(differentialTypeID); + if (!differentialType) + return differentialType.takeError(); + auto pullbackType = MF.getTypeChecked(pullbackTypeID); + if (!pullbackType) + return pullbackType.takeError(); + + // Convert parameter indices and result indices to bit vectors. + auto originalFnType = + CanSILFunctionType(originalType.get()->castTo()); + auto numParameters = originalFnType->getNumParameters(); + auto numResults = originalFnType->getNumResults(); + SmallBitVector parameterIndices(numParameters); + for (auto index : parameterAndResultIndices.take_front(numParameters)) + parameterIndices.set(index); + SmallBitVector resultIndices(numResults); + auto rawResultIndices = parameterAndResultIndices.drop_front(numParameters); + assert(rawResultIndices.size() == numResults); + for (auto index : rawResultIndices) + resultIndices.set(index); + return SILDifferentiableFunctionType::get( + ctx, maxOrder, + (DifferentiabilityRepresentationKind)representationKind, + genericSig ? genericSig->getCanonicalSignature() : nullptr, + AutoDiffIndexSubset::get(MF.getContext(), parameterIndices), + AutoDiffIndexSubset::get(MF.getContext(), resultIndices), + originalFnType, + CanSILFunctionType(differentialType.get()->castTo()), + CanSILFunctionType(pullbackType.get()->castTo())); + } + Expected deserializeArraySliceType(ArrayRef scratch, StringRef blobData) { TypeID baseID; diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 1b6f9e1dbc9a8..e81a6ebabc56a 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -4241,6 +4241,29 @@ void Serializer::writeType(Type ty) { break; } + + // SWIFT_ENABLE_TENSORFLOW + case TypeKind::SILDifferentiableFunction: { + auto fnTy = cast(ty.getPointer()); + auto abbrCode = + DeclTypeAbbrCodes[SILDifferentiableFunctionTypeLayout::Code]; + SmallVector parameterAndResultIndices; + AutoDiffIndexSubset *paramIndices = fnTy->getParameterIndices(); + for (auto i : paramIndices->getIndices()) + parameterAndResultIndices.push_back(i); + AutoDiffIndexSubset *resultIndices = fnTy->getResultIndices(); + for (auto i : resultIndices->getIndices()) + parameterAndResultIndices.push_back(i); + SILDifferentiableFunctionTypeLayout::emitRecord( + Out, ScratchRecord, abbrCode, fnTy->getMaxOrder(), + (unsigned)fnTy->getRepresentationKind(), + addGenericSignatureRef(fnTy->getGenericSignature()), + addTypeRef(fnTy->getOriginalFunctionType()), + addTypeRef(fnTy->getDifferentialType()), + addTypeRef(fnTy->getPullbackType()), + parameterAndResultIndices); + break; + } case TypeKind::ArraySlice: { auto sliceTy = cast(ty.getPointer()); @@ -4365,6 +4388,8 @@ void Serializer::writeAllDeclsAndTypes() { registerDeclTypeAbbr(); registerDeclTypeAbbr(); registerDeclTypeAbbr(); + // SWIFT_ENABLE_TENSORFLOW + registerDeclTypeAbbr(); registerDeclTypeAbbr(); registerDeclTypeAbbr(); registerDeclTypeAbbr(); diff --git a/test/AutoDiff/sil_differentiable_function_type.sil b/test/AutoDiff/sil_differentiable_function_type.sil new file mode 100644 index 0000000000000..25f7bdcdef52b --- /dev/null +++ b/test/AutoDiff/sil_differentiable_function_type.sil @@ -0,0 +1,22 @@ +// RUN: %target-swift-frontend -typecheck -verify %s + +sil_stage raw + +import Swift + +sil @foo1 : $(@sil_differentiable(1) {(Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float}) -> () +sil @foo2 : $(@sil_differentiable(linear) {(Float) -> Float, transpose: (Float) -> Float}) -> () +sil @foo3 : $(@sil_differentiable(linear) {(Float, @nondiff Float) -> Float, transpose: (Float) -> Float}) -> () +sil @foo4 : $(@sil_differentiable(2) {(T) -> T, differential: (T.TangentVector) -> T.TangentVector, pullback: (T.TangentVector) -> T.TangentVector}) -> () + +// expected-error @+1 {{invalid field for the specified '@sil_differentiable' representation kind}} +sil @foo5 : $(@sil_differentiable(1) {(Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float, transpose: (Float) -> Float}) -> () + +// expected-error @+1 {{a 'pullback:' function type field is required in a '@sil_differentiable'}} +sil @foo6 : $(@sil_differentiable(1) {(Float) -> Float, differential: (Float) -> Float}) -> () + +// expected-error @+1 {{a 'transpose:' function type field is required in a '@sil_differentiable'}} +sil @foo7 : $(@sil_differentiable(linear) {(Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float}) -> () + +// expected-error @+1 {{'@sil_differentiable' field type cannot be generic}} +sil @foo8 : $(@sil_differentiable(linear) {(Float) -> Float, transpose: (T) -> T}) -> () From 29008aab5946c3b0b412ed766a2cbdad7eb165cb Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 28 May 2019 03:32:52 -0700 Subject: [PATCH 2/3] Update autodiff_function inst. --- include/swift/SIL/SILBuilder.h | 7 ++++-- include/swift/SIL/SILInstruction.h | 16 ++++++++----- lib/ParseSIL/ParseSIL.cpp | 23 ++++++++++++++++++- lib/SIL/SILInstructions.cpp | 29 ++++++++++++------------ lib/SIL/SILVerifier.cpp | 16 +++++++++++++ test/AutoDiff/autodiff_function_inst.sil | 7 ++++++ 6 files changed, 74 insertions(+), 24 deletions(-) diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 8c5e2a29572b3..68311393daed3 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -506,13 +506,16 @@ class SILBuilder { AutoDiffFunctionInst *createAutoDiffFunction( SILLocation loc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue original, - ArrayRef associatedFunctions = {}) { + ArrayRef associatedFunctions = {}, + bool useNewSILDiffFuncType = false, SILType type = SILType()) { return insert(AutoDiffFunctionInst::create(getModule(), getSILDebugLocation(loc), parameterIndices, differentiationOrder, original, - associatedFunctions)); + associatedFunctions, + type, + useNewSILDiffFuncType)); } AutoDiffFunctionExtractInst *createAutoDiffFunctionExtract( diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 36bfaf2567181..c035a881e40e0 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7779,7 +7779,7 @@ class AutoDiffFunctionInst final : unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions, - bool useNewSILDiffFuncType = false); + SILType type, bool useNewSILDiffFuncType); public: static AutoDiffFunctionInst *create(SILModule &module, @@ -7788,12 +7788,16 @@ class AutoDiffFunctionInst final : unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions, - bool useNewSILDiffFuncType = false); + SILType type, + bool useNewSILDiffFuncType); - static SILType getAutoDiffType(SILValue original, - unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices, - bool useNewSILDiffFuncType); + static SILType getLegacyDifferentiableFunctionType( + SILValue original, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices); + + bool usesNewSILDiffFuncType() const { + return useNewSILDiffFuncType; + } /// Returns the original function. SILValue getOriginalFunction() const { return getAllOperands()[0].get(); } diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 8a12de0fea998..267451d6443cd 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -2889,6 +2889,20 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { SourceLoc lastLoc; SmallVector parameterIndices; unsigned order = 1; + bool useNewSILDiffFuncType = false; + SILType type; + // Parse optional `[sil_differentiable]`. + if (P.Tok.is(tok::l_square) && + P.peekToken().is(tok::identifier) && + P.peekToken().getText() == "sil_differentiable") { + P.consumeToken(tok::l_square); + P.consumeToken(tok::identifier); + if (P.parseToken(tok::r_square, + diag::sil_inst_autodiff_attr_expected_rsquare, + "'sil_differentiable' attribute")) + return true; + useNewSILDiffFuncType = true; + } // Parse optional `[wrt ...]` if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && @@ -2938,6 +2952,12 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { return true; } SmallVector associatedFunctions; + // Parse optional destination type `as `. + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "as") { + P.consumeToken(tok::identifier); + if (parseSILType(type)) + return true; + } // Parse optional operand lists `with { , }, ...`. if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { P.consumeToken(tok::identifier); @@ -2972,7 +2992,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { AutoDiffIndexSubset::get(P.Context, fnType->getNumParameters(), parameterIndices); ResultVal = B.createAutoDiffFunction(InstLoc, parameterIndicesSubset, order, - original, associatedFunctions); + original, associatedFunctions, + useNewSILDiffFuncType, type); break; } diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 491a47242e0bd..d0f579ce2031d 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -567,10 +567,9 @@ TryApplyInst *TryApplyInst::create( // SWIFT_ENABLE_TENSORFLOW SILType -AutoDiffFunctionInst::getAutoDiffType(SILValue originalFunction, - unsigned differentiationOrder, - AutoDiffIndexSubset *parameterIndices, - bool useNewSILDiffFuncType) { +AutoDiffFunctionInst::getLegacyDifferentiableFunctionType( + SILValue originalFunction, unsigned differentiationOrder, + AutoDiffIndexSubset *parameterIndices) { auto fnTy = originalFunction->getType().castTo(); auto diffTy = fnTy->getWithDifferentiability(differentiationOrder, parameterIndices); @@ -581,11 +580,9 @@ AutoDiffFunctionInst::AutoDiffFunctionInst( SILModule &module, SILDebugLocation debugLoc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, ArrayRef associatedFunctions, - bool useNewSILDiffFuncType) + SILType type, bool useNewSILDiffFuncType) : InstructionBaseWithTrailingOperands( - originalFunction, associatedFunctions, debugLoc, - getAutoDiffType(originalFunction, differentiationOrder, - parameterIndices, useNewSILDiffFuncType), + originalFunction, associatedFunctions, debugLoc, type, originalFunction.getOwnershipKind()), parameterIndices(parameterIndices), differentiationOrder(differentiationOrder), @@ -596,20 +593,22 @@ AutoDiffFunctionInst *AutoDiffFunctionInst::create( SILModule &module, SILDebugLocation debugLoc, AutoDiffIndexSubset *parameterIndices, unsigned differentiationOrder, SILValue originalFunction, - ArrayRef associatedFunctions, + ArrayRef associatedFunctions, SILType type, bool useNewSILDiffFuncType) { size_t size = totalSizeToAlloc(associatedFunctions.size() + 1); void *buffer = module.allocateInst(size, alignof(AutoDiffFunctionInst)); - return ::new (buffer) AutoDiffFunctionInst(module, debugLoc, - parameterIndices, - differentiationOrder, - originalFunction, - associatedFunctions, - useNewSILDiffFuncType); + return ::new (buffer) AutoDiffFunctionInst( + module, debugLoc, parameterIndices, differentiationOrder, + originalFunction, associatedFunctions, + type ? type : getLegacyDifferentiableFunctionType(originalFunction, + differentiationOrder, + parameterIndices), + useNewSILDiffFuncType); } std::pair AutoDiffFunctionInst:: getAssociatedFunctionPair(unsigned differentiationOrder) const { + assert(!useNewSILDiffFuncType); assert(differentiationOrder > 0 && differentiationOrder <= this->differentiationOrder); assert(!getAssociatedFunctions().empty() && "No associated functions. Maybe " diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 2ac5f00831807..218c4f26af470 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1474,6 +1474,22 @@ class SILVerifier : public SILVerifierBase { void checkAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { require(adfi->getDifferentiationOrder() > 0, "The differentiation order must be non-zero"); + + // Special case for the future 'sil_differentiable' mode. When full + // migration is done, the default mode will be this mode. + if (adfi->usesNewSILDiffFuncType()) { + auto destTy = adfi->getOriginalFunction()->getType() + .getAs(); + require(destTy, + "The original function must have a '@sil_differentiable' type"); + require(destTy->getParameterIndices() == adfi->getParameterIndices(), + "Parameter indices must be equal"); + require(destTy->getMaxOrder() == adfi->getDifferentiationOrder(), + "Differentiation order must be equal"); + // TODO(rxwei): Add more checks. + return; + } + auto origTy = adfi->getOriginalFunction()->getType().getAs(); require(origTy, "The original function must have a function type"); diff --git a/test/AutoDiff/autodiff_function_inst.sil b/test/AutoDiff/autodiff_function_inst.sil index 9fa7868b1e2e8..d664686d849b0 100644 --- a/test/AutoDiff/autodiff_function_inst.sil +++ b/test/AutoDiff/autodiff_function_inst.sil @@ -55,3 +55,10 @@ bb0: // CHECK: [[EXTRACTED_VJP:%.*]] = autodiff_function_extract [vjp] [order 1] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: [[EXTRACTED_ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float + +sil @make_new_diff_func : $@convention(thin) () -> @sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} { +bb0: + %orig = function_ref @foo : $@convention(thin) (Float) -> Float + %diffedFunc = autodiff_function [sil_differentiable] [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float as $@sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} + return %diffFunc : $@sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} +} From d750f331bcfc933e095a42f40576c5cae753e889 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 29 May 2019 02:37:30 -0700 Subject: [PATCH 3/3] Fix parsing. --- include/swift/AST/Attr.def | 1 - include/swift/AST/Attr.h | 13 ------- include/swift/AST/DiagnosticsParse.def | 11 +++--- include/swift/AST/DiagnosticsSema.def | 15 ++++---- include/swift/AST/TypeRepr.h | 19 ++++++---- include/swift/Parse/Parser.h | 5 ++- lib/AST/ASTPrinter.cpp | 2 +- lib/Parse/ParseDecl.cpp | 40 --------------------- lib/Parse/ParseType.cpp | 44 +++++++++++++++++++----- lib/ParseSIL/ParseSIL.cpp | 11 +++--- lib/SIL/SILPrinter.cpp | 1 + lib/SIL/SILVerifier.cpp | 7 ++-- lib/Sema/TypeCheckType.cpp | 33 +++++------------- test/AutoDiff/autodiff_function_inst.sil | 6 ++-- 14 files changed, 87 insertions(+), 121 deletions(-) diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 0a760505e03e1..8b8d5f5414ba9 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -53,7 +53,6 @@ TYPE_ATTR(noescape) TYPE_ATTR(escaping) // SWIFT_ENABLE_TENSORFLOW TYPE_ATTR(differentiable) -TYPE_ATTR(sil_differentiable) TYPE_ATTR(autodiff) TYPE_ATTR(nondiff) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index dc9793f8b9c04..032e70c0b3724 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -68,10 +68,6 @@ class TypeAttributes { Optional convention = None; Optional conventionWitnessMethodProtocol = None; - // SWIFT_ENABLE_TENSORFLOW - Optional> - differentiabilityReprKindAndOrder = None; - // For an opened existential type, the known ID. Optional OpenedID; @@ -130,15 +126,6 @@ class TypeAttributes { bool hasConvention() const { return convention.hasValue(); } StringRef getConvention() const { return *convention; } - // SWIFT_ENABLE_TENSORFLOW - bool hasDifferentiabilityRepresentationKindAndOrder() const { - return differentiabilityReprKindAndOrder.hasValue(); - } - std::pair - getDifferentiabilityRepresentationKindAndOrder() const { - return *differentiabilityReprKindAndOrder; - } - bool hasOwnership() const { return getOwnership() != ReferenceOwnership::Strong; } diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 584b6a08f1ee3..30e76e0bf35fc 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1395,13 +1395,14 @@ ERROR(convention_attribute_witness_method_expected_protocol,none, // sil_differentiable ERROR(sil_differentiable_attribute_expected_lparen,none, - "expected '(' after 'sil_differentiable' attribute", ()) + "expected '(' after 'sil_differentiable'", ()) ERROR(sil_differentiable_attribute_expected_max_order,none, - "expected a max differentiation order in 'sil_differentiable' attribute", ()) + "expected a max differentiation order in 'sil_differentiable(...)'", ()) ERROR(sil_differentiable_attribute_expected_rparen,none, - "expected ')' after convention name for 'sil_differentiable' attribute", ()) + "expected ')' after the representation kind or order for " + "'sil_differentiable'", ()) ERROR(sil_differentiable_attribute_expected_lbrace,none, - "expected '{' in a '@sil_differentiable' type", ()) + "expected '{' in a 'sil_differentiable' type", ()) ERROR(sil_differentiable_attribute_expected_differential,none, "expected 'differential:'", ()) ERROR(sil_differentiable_attribute_expected_pullback,none, @@ -1409,7 +1410,7 @@ ERROR(sil_differentiable_attribute_expected_pullback,none, ERROR(sil_differentiable_attribute_expected_transpose,none, "expected 'transpose:' ", ()) ERROR(sil_differentiable_attribute_expected_rbrace,none, - "expected '}' to end '@sil_differentiable' type", ()) + "expected '}' to end a 'sil_differentiable' type", ()) // objc ERROR(attr_objc_missing_colon,none, diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 7953958a4bb70..2ef2be580ab3f 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3796,17 +3796,20 @@ ERROR(sil_metatype_multiple_reprs,none, // SWIFT_ENABLE_TENSORFLOW // @sil_differentiable types ERROR(sil_differentiable_attr_not_applicable,none, - "'@sil_differentiable' is not applicable to this type", ()) + "'sil_differentiable' is not applicable to this type", ()) ERROR(sil_differentiable_required_original_function_field,none, - "an original function type field is required in a '@sil_differentiable'", ()) + "an original function type field is required in a 'sil_differentiable'", + ()) ERROR(sil_differentiable_required_field,none, - "a '%0' function type field is required in a '@sil_differentiable'", (StringRef)) + "a '%0' function type field is required in a 'sil_differentiable'", + (StringRef)) ERROR(sil_differentiable_fields_must_be_function_type,none, - "fields in a '@sil_differentiable' type must be function types", ()) + "fields in a 'sil_differentiable' type must be function types", ()) ERROR(sil_differentiable_invalid_field,none, - "invalid field for the specified '@sil_differentiable' representation kind", ()) + "invalid field for the specified '@sil_differentiable' representation " + "kind", ()) ERROR(sil_differentiable_field_cannot_be_generic,none, - "'@sil_differentiable' field type cannot be generic", ()) + "'sil_differentiable' field type cannot be generic", ()) //------------------------------------------------------------------------------ // MARK: @objc and @nonobjc diff --git a/include/swift/AST/TypeRepr.h b/include/swift/AST/TypeRepr.h index 57b40f4cc1591..1474f5a44dc3a 100644 --- a/include/swift/AST/TypeRepr.h +++ b/include/swift/AST/TypeRepr.h @@ -1175,6 +1175,8 @@ inline bool TypeRepr::isSimple() const { // SWIFT_ENABLE_TENSORFLOW class SILDifferentiableFunctionTypeRepr final : public TypeRepr { GenericParamList *GenericParams; + DifferentiabilityRepresentationKind reprKind; + int maxOrder; GenericEnvironment *GenericEnv = nullptr; TypeRepr *Original; TypeRepr *Differential; @@ -1184,13 +1186,14 @@ class SILDifferentiableFunctionTypeRepr final : public TypeRepr { public: SILDifferentiableFunctionTypeRepr( - GenericParamList *genericParams, TypeRepr *original, - TypeRepr *differential, TypeRepr *pullback, TypeRepr *transpose, - SourceRange braces) + GenericParamList *genericParams, + DifferentiabilityRepresentationKind reprKind, int maxOrder, + TypeRepr *original, TypeRepr *differential, TypeRepr *pullback, + TypeRepr *transpose, SourceRange braces) : TypeRepr(TypeReprKind::SILDifferentiableFunction), - GenericParams(genericParams), Original(original), - Differential(differential), Pullback(pullback), Transpose(transpose), - Braces(braces) {} + GenericParams(genericParams), reprKind(reprKind), maxOrder(maxOrder), + Original(original), Differential(differential), Pullback(pullback), + Transpose(transpose), Braces(braces) {} GenericParamList *getGenericParams() const { return GenericParams; }; GenericEnvironment *getGenericEnvironment() const { return GenericEnv; }; @@ -1198,6 +1201,10 @@ class SILDifferentiableFunctionTypeRepr final : public TypeRepr { assert(GenericEnv == nullptr); GenericEnv = env; } + DifferentiabilityRepresentationKind getRepresentationKind() const { + return reprKind; + } + int getMaxOrder() const { return maxOrder; } TypeRepr *getOriginal() const { return Original; } TypeRepr *getDifferential() const { return Differential; } TypeRepr *getPullback() const { return Pullback; } diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index a98f8131462ad..ea21e7b1e9a59 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -1115,9 +1115,8 @@ class Parser { const TypeAttributes &attrs, Optional &GenericsScope); // SWIFT_ENABLE_TENSORFLOW - ParserResult parseSILDifferentiableFunctionType( - GenericParamList *generics, const TypeAttributes &attrs, - Optional &GenericsScope); + ParserResult parseSILDifferentiableFunctionTypeBody( + GenericParamList *generics, Optional &GenericsScope); ParserResult parseTypeTupleBody(); ParserResult parseTypeArray(TypeRepr *Base); diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index a459b48b67b94..eb43adb1acfdf 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -4073,7 +4073,7 @@ class TypePrinter : public TypeVisitor { // SWIFT_ENABLE_TENSORFLOW void visitSILDifferentiableFunctionType(SILDifferentiableFunctionType *T) { - Printer << "@sil_differentiable("; + Printer << "sil_differentiable("; switch (T->getRepresentationKind()) { case DifferentiabilityRepresentationKind::Linear: Printer << "linear) {"; diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 5809382452ae2..45e30b2c252b0 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -2335,39 +2335,6 @@ bool Parser::parseTypeAttribute(TypeAttributes &Attributes, bool justChecking) { LPLoc); } - // SWIFT_ENABLE_TENSORFLOW - DifferentiabilityRepresentationKind differentiabilityReprKind; - unsigned maxDifferentiabilityOrder; - if (attr == TAK_sil_differentiable) { - SourceLoc LPLoc; - if (!consumeIfNotAtStartOfLine(tok::l_paren)) { - if (!justChecking) - diagnose(Tok, diag::sil_differentiable_attribute_expected_lparen); - return true; - } - SourceLoc parenParamLoc; - if (Tok.is(tok::identifier) && Tok.getText() == "linear") { - consumeToken(tok::identifier); - differentiabilityReprKind = DifferentiabilityRepresentationKind::Linear; - maxDifferentiabilityOrder = -1; - } else { - differentiabilityReprKind = DifferentiabilityRepresentationKind::Normal; - if (parseUnsignedInteger( - maxDifferentiabilityOrder, parenParamLoc, - diag::sil_differentiable_attribute_expected_max_order)) { - return true; - } - } - // Parse the ')'. We can't use parseMatchingToken if we're in - // just-checking mode. - if (justChecking && Tok.isNot(tok::r_paren)) - return true; - SourceLoc RPLoc; - parseMatchingToken(tok::r_paren, RPLoc, - diag::convention_attribute_expected_rparen, - LPLoc); - } - // In just-checking mode, we only need to consume the tokens, and we don't // want to do any other analysis. if (justChecking) @@ -2463,13 +2430,6 @@ bool Parser::parseTypeAttribute(TypeAttributes &Attributes, bool justChecking) { Attributes.conventionWitnessMethodProtocol = witnessMethodProtocol; break; - // SWIFT_ENABLE_TENSORFLOW - // @sil_differentiable attribute. - case TAK_sil_differentiable: - Attributes.differentiabilityReprKindAndOrder = - {differentiabilityReprKind, maxDifferentiabilityOrder}; - break; - case TAK__opaqueReturnTypeOf: { // Parse the mangled decl name and index. auto beginLoc = Tok.getLoc(); diff --git a/lib/Parse/ParseType.cpp b/lib/Parse/ParseType.cpp index 2f49512396683..c908eed541d7b 100644 --- a/lib/Parse/ParseType.cpp +++ b/lib/Parse/ParseType.cpp @@ -352,11 +352,37 @@ ParserResult Parser::parseSILBoxType(GenericParamList *generics, } // SWIFT_ENABLE_TENSORFLOW -ParserResult Parser::parseSILDifferentiableFunctionType( - GenericParamList *generics, const TypeAttributes &attrs, - Optional &GenericsScope) { +ParserResult Parser::parseSILDifferentiableFunctionTypeBody( + GenericParamList *generics, Optional &GenericsScope) { SyntaxParsingContext TypeParsingContext(SyntaxContext, SyntaxContextKind::Type); + // Parse representation kind and order. + if (!consumeIfNotAtStartOfLine(tok::l_paren)) { + diagnose(Tok, diag::sil_differentiable_attribute_expected_lparen); + return makeParserError(); + } + auto lParenLoc = PreviousLoc; + DifferentiabilityRepresentationKind reprKind; + int maxOrder; + SourceLoc parenParamLoc; + if (Tok.is(tok::identifier) && Tok.getText() == "linear") { + consumeToken(tok::identifier); + reprKind = DifferentiabilityRepresentationKind::Linear; + maxOrder = -1; + } else { + reprKind = DifferentiabilityRepresentationKind::Normal; + unsigned order; + if (parseUnsignedInteger(order, parenParamLoc, + diag::sil_differentiable_attribute_expected_max_order)) + return makeParserError(); + maxOrder = order; + } + SourceLoc rParenLoc; + if (parseMatchingToken(tok::r_paren, rParenLoc, + diag::sil_differentiable_attribute_expected_rparen, + lParenLoc)) + return makeParserError(); + // Parse field types. SourceLoc lBraceLoc, rBraceLoc; TypeRepr *originalType = nullptr; constexpr unsigned numAssocFns = 3; @@ -393,12 +419,10 @@ ParserResult Parser::parseSILDifferentiableFunctionType( return makeParserError(); auto *diffFnType = new (Context) SILDifferentiableFunctionTypeRepr( - generics, originalType, std::get<0>(assocFnTypes), + generics, reprKind, maxOrder, originalType, std::get<0>(assocFnTypes), std::get<1>(assocFnTypes), std::get<2>(assocFnTypes), SourceRange(lBraceLoc, rBraceLoc)); - return makeParserResult(applyAttributeToType(diffFnType, attrs, - VarDecl::Specifier::Default, - SourceLoc())); + return makeParserResult(diffFnType); } /// parseType @@ -435,8 +459,10 @@ ParserResult Parser::parseType(Diag<> MessageID, // SWIFT_ENABLE_TENSORFLOW // In SIL mode, parse differentiable function type. - if (isInSILMode() && attrs.has(TAK_sil_differentiable)) { - return parseSILDifferentiableFunctionType(generics, attrs, GenericsScope); + if (isInSILMode() && Tok.is(tok::identifier) && + Tok.getText() == "sil_differentiable") { + consumeToken(tok::identifier); + return parseSILDifferentiableFunctionTypeBody(generics, GenericsScope); } // In SIL mode, parse box types { ... }. diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 267451d6443cd..13e20ea063ede 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -1389,13 +1389,15 @@ bool SILParser::parseSILType(SILType &Result, ->walk(HandleSILGenericParamsWalker(P.Context, &P.SF)); // Save the top-level function generic environment if there was one. - if (auto fnType = dyn_cast(TyR.get())) + if (auto fnType = dyn_cast(TyR.get())) { if (auto env = fnType->getGenericEnvironment()) ParsedGenericEnv = env; + } // SWIFT_ENABLE_TENSORFLOW - if (auto diffFnType = dyn_cast(TyR.get())) + else if (auto diffFnType = dyn_cast(TyR.get())) { if (auto env = diffFnType->getGenericEnvironment()) ParsedGenericEnv = env; + } // Apply attributes to the type. TypeLoc Ty = P.applyAttributeToType(TyR.get(), attrs, specifier, specifierLoc); @@ -2891,6 +2893,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { unsigned order = 1; bool useNewSILDiffFuncType = false; SILType type; + SmallVector associatedFunctions; // Parse optional `[sil_differentiable]`. if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && @@ -2951,10 +2954,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { diag::sil_inst_autodiff_expected_function_type_operand); return true; } - SmallVector associatedFunctions; // Parse optional destination type `as `. - if (P.Tok.is(tok::identifier) && P.Tok.getText() == "as") { - P.consumeToken(tok::identifier); + if (P.consumeIf(tok::kw_as)) { if (parseSILType(type)) return true; } diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 04c90102ed669..52b4e3bfd07cf 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1167,6 +1167,7 @@ class SILPrinter : public SILInstructionVisitor { } *this << "[order " << adfi->getDifferentiationOrder() << "] "; *this << getIDAndType(adfi->getOriginalFunction()); + *this << " as " << adfi->getType() << ' '; if (!adfi->getAssociatedFunctions().empty()) { *this << " with "; interleave(range(1, adfi->getDifferentiationOrder() + 1), diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 218c4f26af470..6252612b8f0dd 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1478,13 +1478,12 @@ class SILVerifier : public SILVerifierBase { // Special case for the future 'sil_differentiable' mode. When full // migration is done, the default mode will be this mode. if (adfi->usesNewSILDiffFuncType()) { - auto destTy = adfi->getOriginalFunction()->getType() - .getAs(); + auto destTy = adfi->getType().getAs(); require(destTy, - "The original function must have a '@sil_differentiable' type"); + "The destination type must be a 'sil_differentiable' type"); require(destTy->getParameterIndices() == adfi->getParameterIndices(), "Parameter indices must be equal"); - require(destTy->getMaxOrder() == adfi->getDifferentiationOrder(), + require(destTy->getMaxOrder() == (int)adfi->getDifferentiationOrder(), "Differentiation order must be equal"); // TODO(rxwei): Add more checks. return; diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 105a8a3238a10..e92ab84e2810d 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -1869,8 +1869,7 @@ namespace { TypeResolutionOptions options); // SWIFT_ENABLE_TENSORFLOW Type resolveSILDifferentiableFunctionType( - SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options, - DifferentiabilityRepresentationKind reprKind, unsigned maxOrder); + SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options); Type buildMetatypeType(MetatypeTypeRepr *repr, Type instanceType, @@ -1964,8 +1963,8 @@ Type TypeResolver::resolveType(TypeRepr *repr, TypeResolutionOptions options) { // SWIFT_ENABLE_TENSORFLOW case TypeReprKind::SILDifferentiableFunction: - llvm_unreachable("SILDifferentiableFunction is always attributed with " - "'@sil_differentiable'"); + return resolveSILDifferentiableFunctionType( + cast(repr), options); case TypeReprKind::Array: return resolveArrayType(cast(repr), options); @@ -2153,27 +2152,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs, TAK_callee_guaranteed, TAK_noescape, TAK_yield_once, - TAK_yield_many, - // SWIFT_ENABLE_TENSORFLOW - TAK_sil_differentiable}) { + TAK_yield_many}) { checkUnsupportedAttr(silOnlyAttr); } } - // SWIFT_ENABLE_TENSORFLOW - if (attrs.has(TAK_sil_differentiable)) { - auto *diffFnTy = dyn_cast(repr); - if (!diffFnTy) { - diagnose(attrs.getLoc(TAK_sil_differentiable), - diag::sil_differentiable_attr_not_applicable); - return Type(); - } - auto reprKindAndOrder = - attrs.getDifferentiabilityRepresentationKindAndOrder(); - return resolveSILDifferentiableFunctionType( - diffFnTy, options, reprKindAndOrder.first, reprKindAndOrder.second); - } - // Other function representation attributes are not normally supported at // source level, but we want to support them there in SIL files. auto SF = DC->getParentSourceFile(); @@ -2744,8 +2727,7 @@ Type TypeResolver::resolveSILBoxType(SILBoxTypeRepr *repr, } Type TypeResolver::resolveSILDifferentiableFunctionType( - SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options, - DifferentiabilityRepresentationKind reprKind, unsigned maxOrder) { + SILDifferentiableFunctionTypeRepr *repr, TypeResolutionOptions options) { // Resolve generic params using the function's generic environment, if it // has one. Optional resolveSILFunctionGenericParams; @@ -2793,7 +2775,7 @@ Type TypeResolver::resolveSILDifferentiableFunctionType( return nullptr; }; - switch (reprKind) { + switch (repr->getRepresentationKind()) { case DifferentiabilityRepresentationKind::Normal: { auto *differentialTypeRepr = checkAndDiagnoseInvalidField(repr->getDifferential(), "differential:"); @@ -2839,7 +2821,8 @@ Type TypeResolver::resolveSILDifferentiableFunctionType( } return SILDifferentiableFunctionType::get( - Context, maxOrder, DifferentiabilityRepresentationKind::Normal, + Context, repr->getMaxOrder(), + DifferentiabilityRepresentationKind::Normal, genericSig, parameterIndices, resultIndices, originalSILFnType->getWithoutDifferentiability(), CanSILFunctionType(differentialSILFnType), diff --git a/test/AutoDiff/autodiff_function_inst.sil b/test/AutoDiff/autodiff_function_inst.sil index d664686d849b0..c48e45e3f6dcb 100644 --- a/test/AutoDiff/autodiff_function_inst.sil +++ b/test/AutoDiff/autodiff_function_inst.sil @@ -56,9 +56,9 @@ bb0: // CHECK: [[EXTRACTED_ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float -sil @make_new_diff_func : $@convention(thin) () -> @sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} { +sil @make_new_diff_func : $@convention(thin) () -> sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} { bb0: %orig = function_ref @foo : $@convention(thin) (Float) -> Float - %diffedFunc = autodiff_function [sil_differentiable] [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float as $@sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} - return %diffFunc : $@sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} + %diffedFunc = autodiff_function [sil_differentiable] [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float as $sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} + return %diffedFunc : $sil_differentiable(1) {@convention(thin) (Float) -> Float, differential: (Float) -> Float, pullback: (Float) -> Float} }