diff --git a/docs/ABI/Mangling.rst b/docs/ABI/Mangling.rst index 02674faeaff59..92f0ab6954a30 100644 --- a/docs/ABI/Mangling.rst +++ b/docs/ABI/Mangling.rst @@ -516,6 +516,10 @@ Types FUNCTION-KIND ::= 'C' // C function pointer type FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping) FUNCTION-KIND ::= 'E' // function type (noescape) + FUNCTION-KIND ::= 'F' // @differentiable function type + FUNCTION-KIND ::= 'G' // @differentiable function type (escaping) + FUNCTION-KIND ::= 'H' // @differentiable(linear) function type + FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping) function-signature ::= params-type params-type throws? // results and parameters diff --git a/include/swift/ABI/MetadataValues.h b/include/swift/ABI/MetadataValues.h index c6478a837f80b..1f32981b9ecc6 100644 --- a/include/swift/ABI/MetadataValues.h +++ b/include/swift/ABI/MetadataValues.h @@ -733,6 +733,14 @@ enum class FunctionMetadataConvention: uint8_t { CFunctionPointer = 3, }; +/// Differentiability kind for function type metadata. +/// Duplicates `DifferentiabilityKind` in AutoDiff.h. +enum class FunctionMetadataDifferentiabilityKind: uint8_t { + NonDifferentiable = 0b00, + Normal = 0b01, + Linear = 0b11 +}; + /// Flags in a function type metadata record. template class TargetFunctionTypeFlags { @@ -747,7 +755,8 @@ class TargetFunctionTypeFlags { ParamFlagsMask = 0x02000000U, EscapingMask = 0x04000000U, // SWIFT_ENABLE_TENSORFLOW - DifferentiableMask = 0x08000000U + DifferentiableMask = 0x08000000U, + LinearMask = 0x10000000U }; int_type Data; @@ -785,10 +794,14 @@ class TargetFunctionTypeFlags { } // SWIFT_ENABLE_TENSORFLOW - constexpr TargetFunctionTypeFlags - withDifferentiable(bool isDifferentiable) const { - return TargetFunctionTypeFlags((Data & ~DifferentiableMask) | - (isDifferentiable ? DifferentiableMask : 0)); + constexpr TargetFunctionTypeFlags withDifferentiabilityKind( + FunctionMetadataDifferentiabilityKind differentiability) const { + return TargetFunctionTypeFlags( + (Data & ~DifferentiableMask & ~LinearMask) | + (differentiability == FunctionMetadataDifferentiabilityKind::Normal + ? DifferentiableMask : 0) | + (differentiability == FunctionMetadataDifferentiabilityKind::Linear + ? LinearMask : 0)); } unsigned getNumParameters() const { return Data & NumParametersMask; } @@ -807,7 +820,15 @@ class TargetFunctionTypeFlags { // SWIFT_ENABLE_TENSORFLOW bool isDifferentiable() const { - return bool (Data & DifferentiableMask); + return getDifferentiabilityKind() >= + FunctionMetadataDifferentiabilityKind::Normal; + } + FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const { + if (bool(Data & DifferentiableMask)) + return FunctionMetadataDifferentiabilityKind::Normal; + if (bool(Data & LinearMask)) + return FunctionMetadataDifferentiabilityKind::Linear; + return FunctionMetadataDifferentiabilityKind::NonDifferentiable; } bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); } diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index b858fcddc6606..d1a9e40a40d2a 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -23,7 +23,7 @@ #include "swift/Basic/Range.h" namespace swift { - + enum class DifferentiabilityKind: uint8_t { NonDifferentiable = 0b00, Normal = 0b01, diff --git a/include/swift/Demangling/DemangleNodes.def b/include/swift/Demangling/DemangleNodes.def index b34a04eccd281..93a8ef211ceaa 100644 --- a/include/swift/Demangling/DemangleNodes.def +++ b/include/swift/Demangling/DemangleNodes.def @@ -68,6 +68,12 @@ NODE(DependentProtocolConformanceInherited) NODE(DependentProtocolConformanceAssociated) CONTEXT_NODE(Destructor) CONTEXT_NODE(DidSet) +// SWIFT_ENABLE_TENSORFLOW +NODE(DifferentiableFunctionType) +NODE(EscapingDifferentiableFunctionType) +NODE(LinearFunctionType) +NODE(EscapingLinearFunctionType) +// SWIFT_ENABLE_TENSORFLOW END NODE(Directness) NODE(DynamicAttribute) NODE(DirectMethodReferenceAttribute) diff --git a/include/swift/Demangling/TypeDecoder.h b/include/swift/Demangling/TypeDecoder.h index f1ffcf9bd3e7c..00fb316727bed 100644 --- a/include/swift/Demangling/TypeDecoder.h +++ b/include/swift/Demangling/TypeDecoder.h @@ -494,6 +494,12 @@ class TypeDecoder { case NodeKind::NoEscapeFunctionType: case NodeKind::AutoClosureType: case NodeKind::EscapingAutoClosureType: + // SWIFT_ENABLE_TENSORFLOW + case NodeKind::DifferentiableFunctionType: + case NodeKind::EscapingDifferentiableFunctionType: + case NodeKind::LinearFunctionType: + case NodeKind::EscapingLinearFunctionType: + // SWIFT_ENABLE_TENSORFLOW END case NodeKind::FunctionType: { if (Node->getNumChildren() < 2) return BuiltType(); @@ -508,6 +514,17 @@ class TypeDecoder { } else if (Node->getKind() == NodeKind::ThinFunctionType) { flags = flags.withConvention(FunctionMetadataConvention::Thin); } + // SWIFT_ENABLE_TENSORFLOW + else if (Node->getKind() == NodeKind::DifferentiableFunctionType || + Node->getKind() == + NodeKind::EscapingDifferentiableFunctionType) { + flags = flags.withDifferentiabilityKind( + FunctionMetadataDifferentiabilityKind::Normal); + } else if (Node->getKind() == NodeKind::LinearFunctionType || + Node->getKind() == NodeKind::EscapingLinearFunctionType) { + flags = flags.withDifferentiabilityKind( + FunctionMetadataDifferentiabilityKind::Linear); + } bool isThrow = Node->getChild(0)->getKind() == NodeKind::ThrowsAnnotation; @@ -527,7 +544,12 @@ class TypeDecoder { .withEscaping( Node->getKind() == NodeKind::FunctionType || Node->getKind() == NodeKind::EscapingAutoClosureType || - Node->getKind() == NodeKind::EscapingObjCBlock); + Node->getKind() == NodeKind::EscapingObjCBlock || + // SWIFT_ENABLE_TENSORFLOW + Node->getKind() == + NodeKind::EscapingDifferentiableFunctionType || + Node->getKind() == + NodeKind::EscapingLinearFunctionType); auto result = decodeMangledType(Node->getChild(isThrow ? 2 : 1)); if (!result) return BuiltType(); diff --git a/lib/AST/ASTDemangler.cpp b/lib/AST/ASTDemangler.cpp index 28e2cd781b262..5e46bf55c0455 100644 --- a/lib/AST/ASTDemangler.cpp +++ b/lib/AST/ASTDemangler.cpp @@ -388,8 +388,18 @@ Type ASTBuilder::createFunctionType( } // SWIFT_ENABLE_TENSORFLOW - if (flags.isDifferentiable()) + switch (flags.getDifferentiabilityKind()) { + case FunctionMetadataDifferentiabilityKind::NonDifferentiable: + einfo = + einfo.withDifferentiabilityKind(DifferentiabilityKind::NonDifferentiable); + break; + case FunctionMetadataDifferentiabilityKind::Normal: einfo = einfo.withDifferentiabilityKind(DifferentiabilityKind::Normal); + break; + case FunctionMetadataDifferentiabilityKind::Linear: + einfo = einfo.withDifferentiabilityKind(DifferentiabilityKind::Linear); + break; + } // The result type must be materializable. if (!output->isMaterializable()) return Type(); diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index 4af1b3643cf94..16f2ab7e63223 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -1926,6 +1926,19 @@ void ASTMangler::appendFunctionType(AnyFunctionType *fn, bool isAutoClosure, case AnyFunctionType::Representation::Thin: return appendOperator("Xf"); case AnyFunctionType::Representation::Swift: + // SWIFT_ENABLE_TENSORFLOW + if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Normal) { + if (fn->isNoEscape()) + return appendOperator("XF"); + else + return appendOperator("XG"); + } + if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Linear) { + if (fn->isNoEscape()) + return appendOperator("XH"); + else + return appendOperator("XI"); + } if (isAutoClosure) { if (fn->isNoEscape()) return appendOperator("XK"); diff --git a/lib/Demangling/Demangler.cpp b/lib/Demangling/Demangler.cpp index 416be168f03da..efcd0a985500e 100644 --- a/lib/Demangling/Demangler.cpp +++ b/lib/Demangling/Demangler.cpp @@ -2715,6 +2715,18 @@ NodePointer Demangler::demangleSpecialType() { return popFunctionType(Node::Kind::ObjCBlock); case 'C': return popFunctionType(Node::Kind::CFunctionPointer); + // SWIFT_ENABLE_TENSORFLOW + case 'F': + return popFunctionType(Node::Kind::DifferentiableFunctionType); + // SWIFT_ENABLE_TENSORFLOW + case 'G': + return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType); + // SWIFT_ENABLE_TENSORFLOW + case 'H': + return popFunctionType(Node::Kind::LinearFunctionType); + // SWIFT_ENABLE_TENSORFLOW + case 'I': + return popFunctionType(Node::Kind::EscapingLinearFunctionType); case 'o': return createType(createWithChild(Node::Kind::Unowned, popNode(Node::Kind::Type))); diff --git a/lib/Demangling/NodePrinter.cpp b/lib/Demangling/NodePrinter.cpp index a6e9db014595b..0eb0036cf1258 100644 --- a/lib/Demangling/NodePrinter.cpp +++ b/lib/Demangling/NodePrinter.cpp @@ -345,6 +345,12 @@ class NodePrinter { case Node::Kind::DependentPseudogenericSignature: case Node::Kind::Destructor: case Node::Kind::DidSet: + // SWIFT_ENABLE_TENSORFLOW + case Node::Kind::DifferentiableFunctionType: + case Node::Kind::EscapingDifferentiableFunctionType: + case Node::Kind::LinearFunctionType: + case Node::Kind::EscapingLinearFunctionType: + // SWIFT_ENABLE_TENSORFLOW END case Node::Kind::DirectMethodReferenceAttribute: case Node::Kind::Directness: case Node::Kind::DynamicAttribute: @@ -1189,6 +1195,26 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) { Printer << "@convention(thin) "; printFunctionType(nullptr, Node); return nullptr; + // SWIFT_ENABLE_TENSORFLOW + case Node::Kind::DifferentiableFunctionType: + Printer << "@differentiable "; + printFunctionType(nullptr, Node); + return nullptr; + // SWIFT_ENABLE_TENSORFLOW + case Node::Kind::EscapingDifferentiableFunctionType: + Printer << "@escaping @differentiable "; + printFunctionType(nullptr, Node); + return nullptr; + // SWIFT_ENABLE_TENSORFLOW + case Node::Kind::LinearFunctionType: + Printer << "@differentiable(linear) "; + printFunctionType(nullptr, Node); + return nullptr; + // SWIFT_ENABLE_TENSORFLOW + case Node::Kind::EscapingLinearFunctionType: + Printer << "@escaping @differentiable(linear) "; + printFunctionType(nullptr, Node); + return nullptr; case Node::Kind::FunctionType: case Node::Kind::UncurriedFunctionType: printFunctionType(nullptr, Node); @@ -2455,6 +2481,16 @@ void NodePrinter::printEntityType(NodePointer Entity, NodePointer type, Printer << ' '; type = dependentType->getFirstChild(); } + // SWIFT_ENABLE_TENSORFLOW + if (type->getKind() == Node::Kind::DifferentiableFunctionType) + Printer << "@differentiable "; + else if (type->getKind() == Node::Kind::EscapingDifferentiableFunctionType) + Printer << "@escaping @differentiable "; + else if (type->getKind() == Node::Kind::LinearFunctionType) + Printer << "@differentiable(linear) "; + else if (type->getKind() == Node::Kind::EscapingLinearFunctionType) + Printer << "@escaping @differentiable(linear) "; + // SWIFT_ENABLE_TENSORFLOW END printFunctionType(labelList, type); } else { print(type); diff --git a/lib/Demangling/OldRemangler.cpp b/lib/Demangling/OldRemangler.cpp index 94abb07518337..dd3621a4dce46 100644 --- a/lib/Demangling/OldRemangler.cpp +++ b/lib/Demangling/OldRemangler.cpp @@ -1159,6 +1159,30 @@ void Remangler::mangleThinFunctionType(Node *node) { mangleChildNodes(node); // argument tuple, result type } +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleDifferentiableFunctionType(Node *node) { + Buffer << "XF"; + mangleChildNodes(node); // argument tuple, result type +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) { + Buffer << "XG"; + mangleChildNodes(node); // argument tuple, result type +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleLinearFunctionType(Node *node) { + Buffer << "XH"; + mangleChildNodes(node); // argument tuple, result type +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleEscapingLinearFunctionType(Node *node) { + Buffer << "XI"; + mangleChildNodes(node); // argument tuple, result type +} + void Remangler::mangleArgumentTuple(Node *node) { mangleSingleChildNode(node); } diff --git a/lib/Demangling/Remangler.cpp b/lib/Demangling/Remangler.cpp index c910bb11d4c29..bf60e1a7ddda8 100644 --- a/lib/Demangling/Remangler.cpp +++ b/lib/Demangling/Remangler.cpp @@ -1197,6 +1197,30 @@ void Remangler::mangleFunctionType(Node *node) { Buffer << 'c'; } +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleDifferentiableFunctionType(Node *node) { + mangleFunctionSignature(node); + Buffer << "XF"; +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) { + mangleFunctionSignature(node); + Buffer << "XG"; +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleLinearFunctionType(Node *node) { + mangleFunctionSignature(node); + Buffer << "XH"; +} + +// SWIFT_ENABLE_TENSORFLOW +void Remangler::mangleEscapingLinearFunctionType(Node *node) { + mangleFunctionSignature(node); + Buffer << "XI"; +} + void Remangler::mangleGenericProtocolWitnessTable(Node *node) { mangleSingleChildNode(node); Buffer << "WG"; diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index c303e3886c492..c47b621f18735 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1065,6 +1065,23 @@ namespace { break; } + // SWIFT_ENABLE_TENSORFLOW + FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind; + switch (type->getDifferentiabilityKind()) { + case DifferentiabilityKind::NonDifferentiable: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::NonDifferentiable; + break; + case DifferentiabilityKind::Normal: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Normal; + break; + case DifferentiabilityKind::Linear: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Linear; + break; + } + auto flagsVal = FunctionTypeFlags() .withNumParameters(numParams) .withConvention(metadataConvention) @@ -1072,7 +1089,8 @@ namespace { .withParameterFlags(hasFlags) // SWIFT_ENABLE_TENSORFLOW .withEscaping(isEscaping) - .withDifferentiable(type->isDifferentiable()); + .withDifferentiabilityKind( + metadataDifferentiabilityKind); auto flags = llvm::ConstantInt::get(IGF.IGM.SizeTy, flagsVal.getIntValue()); diff --git a/test/TypeDecoder/structural_types.swift b/test/TypeDecoder/structural_types.swift index 450cfa678ee30..799f4b4a87d9e 100644 --- a/test/TypeDecoder/structural_types.swift +++ b/test/TypeDecoder/structural_types.swift @@ -131,6 +131,45 @@ do { blackHole(b) } +// SWIFT_ENABLE_TENSORFLOW +do { + let f: @differentiable (Float) -> Float = { $0 } + // FIXME(TF-123): `@differentiable` function type + opaque abstraction + // pattern bug. + // blackHole(f) + _ = f +} + +// SWIFT_ENABLE_TENSORFLOW +do { + let f: (@escaping @differentiable (Float) -> Float) -> () = { _ in } + // FIXME(TF-123): `@differentiable` function type + opaque abstraction + // pattern bug. + // blackHole(f) + _ = f +} + +// TODO: Uncomment when `@differentiable(linear)` function types are enabled. +/* +// SWIFT_ENABLE_TENSORFLOW +do { + let f: @differentiable(linear) (Float) -> Float = { $0 } + // FIXME(TF-123): `@differentiable` function type + opaque abstraction + // pattern bug. + // blackHole(f) + _ = f +} + +// SWIFT_ENABLE_TENSORFLOW +do { + let f: (@escaping @differentiable(linear) (Float) -> Float) -> () = { _ in } + // FIXME(TF-123): `@differentiable` function type + opaque abstraction + // pattern bug. + // blackHole(f) + _ = f +} +*/ + // DEMANGLE: $syycD // DEMANGLE: $sySSzcD // DEMANGLE: $sySSncD @@ -149,6 +188,10 @@ do { // DEMANGLE: $syyyccD // DEMANGLE: $sSayyyXCGD // DEMANGLE: $sSayyyyXL_yyXBtcGD +// DEMANGLE: $sS2fXFD +// DEMANGLE: $sS2fXGD +// DEMANGLE: $sS2fXHD +// DEMANGLE: $sS2fXID // CHECK: () -> () // CHECK: (inout String) -> () @@ -168,6 +211,10 @@ do { // CHECK: (@escaping () -> ()) -> () // CHECK: Array<@convention(c) () -> ()> // CHECK: Array<(@escaping @convention(block) () -> (), @convention(block) () -> ()) -> ()> +// CHECK: @differentiable (Float) -> Float +// CHECK: @differentiable (Float) -> Float +// CHECK: @differentiable(linear) (Float) -> Float +// CHECK: @differentiable(linear) (Float) -> Float // DEMANGLE: $sSimD // DEMANGLE: $syycmD @@ -188,6 +235,10 @@ do { // DEMANGLE: $syyyccmD // DEMANGLE: $sSayyyXCGmD // DEMANGLE: $sSayyyyXL_yyXBtcGmD +// DEMANGLE: $sS2fXFmD +// DEMANGLE: $sS2fXGmD +// DEMANGLE: $sS2fXHmD +// DEMANGLE: $sS2fXImD // CHECK: Int.Type // CHECK: ((inout String) -> ()).Type @@ -207,3 +258,7 @@ do { // CHECK: ((@escaping () -> ()) -> ()).Type // CHECK: Array<@convention(c) () -> ()>.Type // CHECK: Array<(@escaping @convention(block) () -> (), @convention(block) () -> ()) -> ()>.Type +// CHECK: (@differentiable (Float) -> Float).Type +// CHECK: (@differentiable (Float) -> Float).Type +// CHECK: (@differentiable(linear) (Float) -> Float).Type +// CHECK: (@differentiable(linear) (Float) -> Float).Type