Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ Types
FUNCTION-KIND ::= 'C' // C function pointer type
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
FUNCTION-KIND ::= 'E' // function type (noescape)
FUNCTION-KIND ::= 'F' // @differentiable function type
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you differentiate between @differentiable @convention(thin) and @differentiable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Currently, modifiers like @convention(thin) and @differentiable have their own function type demangle node variant:

Modifier Demangle node kind
- NoEscapeFunctionType
@escaping FunctionType (escaping)
@autoclosure AutoClosureType
@escaping @autoclosure EscapingAutoClosureType
... ...

Here's what it looks like:

$ $SWIFT_BIN/swift-demangle --expand sS2fXfD
Demangling for $sS2fXfD
kind=Global
  kind=TypeMangling
    kind=Type
      kind=ThinFunctionType
        kind=ArgumentTuple
          kind=Type
            kind=Structure
              kind=Module, text="Swift"
              kind=Identifier, text="Float"
        kind=ReturnType
          kind=Type
            kind=Structure
              kind=Module, text="Swift"
              kind=Identifier, text="Float"
$sS2fXfD ---> @convention(thin) (Swift.Float) -> Swift.Float

This leads to an explosion of demangle node kinds for combinations of modifiers.

A more reflexible representation would be to have a single FunctionType demangle node kind, with modifier nodes as children in a list. Dummy view:

      kind=FunctionType
        kind=FunctionTypeModifiers
          kind=ThinFunction
          kind=EscapingFunction
          kind=DifferentiableFunction
        kind=ArgumentTuple
          kind=Type
            kind=Structure
              kind=Module, text="Swift"
              kind=Identifier, text="Float"
        kind=ReturnType
          kind=Type
            kind=Structure
              kind=Module, text="Swift"
              kind=Identifier, text="Float"

Changing the mangling scheme certainly seems breaking though. The significance of type mangling isn't exactly clear to me: I imagine LLDB uses this mangling to show types of variables. Perhaps it's not important to capture all combinations of function type modifiers, only the semantically important ones (@differentiable is probably more semantically heavy than @escaping).

Started a discussion at https://forums.swift.org/t/questions-about-function-type-mangling/26360.


function-signature ::= params-type params-type throws? // results and parameters

Expand Down
33 changes: 27 additions & 6 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,14 @@ enum class FunctionMetadataConvention: uint8_t {
CFunctionPointer = 3,
};

/// Differentiability kind for function type metadata.
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Linear = 0b11
};

/// Flags in a function type metadata record.
template <typename int_type>
class TargetFunctionTypeFlags {
Expand All @@ -747,7 +755,8 @@ class TargetFunctionTypeFlags {
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
// SWIFT_ENABLE_TENSORFLOW
DifferentiableMask = 0x08000000U
DifferentiableMask = 0x08000000U,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using a single DifferentiabilityMask, similar to the one in AnyFunctionType::ExtInfo:

      DifferentiabilityMask  = 0xF8000000U,
      DifferentiabilityShift = 24U

But I didn't figure out how to make it work all the way. Things are tricky because the mask is F8 instead of FF, which requires extra bitshifting.

LinearMask = 0x10000000U
};
int_type Data;

Expand Down Expand Up @@ -785,10 +794,14 @@ class TargetFunctionTypeFlags {
}

// SWIFT_ENABLE_TENSORFLOW
constexpr TargetFunctionTypeFlags<int_type>
withDifferentiable(bool isDifferentiable) const {
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
(isDifferentiable ? DifferentiableMask : 0));
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind differentiability) const {
return TargetFunctionTypeFlags<int_type>(
(Data & ~DifferentiableMask & ~LinearMask) |
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
? DifferentiableMask : 0) |
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
? LinearMask : 0));
}

unsigned getNumParameters() const { return Data & NumParametersMask; }
Expand All @@ -807,7 +820,15 @@ class TargetFunctionTypeFlags {

// SWIFT_ENABLE_TENSORFLOW
bool isDifferentiable() const {
return bool (Data & DifferentiableMask);
return getDifferentiabilityKind() >=
FunctionMetadataDifferentiabilityKind::Normal;
}
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
if (bool(Data & DifferentiableMask))
return FunctionMetadataDifferentiabilityKind::Normal;
if (bool(Data & LinearMask))
return FunctionMetadataDifferentiabilityKind::Linear;
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
}

bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "swift/Basic/Range.h"

namespace swift {

enum class DifferentiabilityKind: uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Expand Down
6 changes: 6 additions & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ NODE(DependentProtocolConformanceInherited)
NODE(DependentProtocolConformanceAssociated)
CONTEXT_NODE(Destructor)
CONTEXT_NODE(DidSet)
// SWIFT_ENABLE_TENSORFLOW
NODE(DifferentiableFunctionType)
NODE(EscapingDifferentiableFunctionType)
NODE(LinearFunctionType)
NODE(EscapingLinearFunctionType)
// SWIFT_ENABLE_TENSORFLOW END
NODE(Directness)
NODE(DynamicAttribute)
NODE(DirectMethodReferenceAttribute)
Expand Down
24 changes: 23 additions & 1 deletion include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,12 @@ class TypeDecoder {
case NodeKind::NoEscapeFunctionType:
case NodeKind::AutoClosureType:
case NodeKind::EscapingAutoClosureType:
// SWIFT_ENABLE_TENSORFLOW
case NodeKind::DifferentiableFunctionType:
case NodeKind::EscapingDifferentiableFunctionType:
case NodeKind::LinearFunctionType:
case NodeKind::EscapingLinearFunctionType:
// SWIFT_ENABLE_TENSORFLOW END
case NodeKind::FunctionType: {
if (Node->getNumChildren() < 2)
return BuiltType();
Expand All @@ -508,6 +514,17 @@ class TypeDecoder {
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
flags = flags.withConvention(FunctionMetadataConvention::Thin);
}
// SWIFT_ENABLE_TENSORFLOW
else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Normal);
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Linear);
}

bool isThrow =
Node->getChild(0)->getKind() == NodeKind::ThrowsAnnotation;
Expand All @@ -527,7 +544,12 @@ class TypeDecoder {
.withEscaping(
Node->getKind() == NodeKind::FunctionType ||
Node->getKind() == NodeKind::EscapingAutoClosureType ||
Node->getKind() == NodeKind::EscapingObjCBlock);
Node->getKind() == NodeKind::EscapingObjCBlock ||
// SWIFT_ENABLE_TENSORFLOW
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingLinearFunctionType);

auto result = decodeMangledType(Node->getChild(isThrow ? 2 : 1));
if (!result) return BuiltType();
Expand Down
12 changes: 11 additions & 1 deletion lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,18 @@ Type ASTBuilder::createFunctionType(
}

// SWIFT_ENABLE_TENSORFLOW
if (flags.isDifferentiable())
switch (flags.getDifferentiabilityKind()) {
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
einfo =
einfo.withDifferentiabilityKind(DifferentiabilityKind::NonDifferentiable);
break;
case FunctionMetadataDifferentiabilityKind::Normal:
einfo = einfo.withDifferentiabilityKind(DifferentiabilityKind::Normal);
break;
case FunctionMetadataDifferentiabilityKind::Linear:
einfo = einfo.withDifferentiabilityKind(DifferentiabilityKind::Linear);
break;
}

// The result type must be materializable.
if (!output->isMaterializable()) return Type();
Expand Down
13 changes: 13 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,19 @@ void ASTMangler::appendFunctionType(AnyFunctionType *fn, bool isAutoClosure,
case AnyFunctionType::Representation::Thin:
return appendOperator("Xf");
case AnyFunctionType::Representation::Swift:
// SWIFT_ENABLE_TENSORFLOW
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Normal) {
if (fn->isNoEscape())
return appendOperator("XF");
else
return appendOperator("XG");
}
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
if (fn->isNoEscape())
return appendOperator("XH");
else
return appendOperator("XI");
}
if (isAutoClosure) {
if (fn->isNoEscape())
return appendOperator("XK");
Expand Down
12 changes: 12 additions & 0 deletions lib/Demangling/Demangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2715,6 +2715,18 @@ NodePointer Demangler::demangleSpecialType() {
return popFunctionType(Node::Kind::ObjCBlock);
case 'C':
return popFunctionType(Node::Kind::CFunctionPointer);
// SWIFT_ENABLE_TENSORFLOW
case 'F':
return popFunctionType(Node::Kind::DifferentiableFunctionType);
// SWIFT_ENABLE_TENSORFLOW
case 'G':
return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType);
// SWIFT_ENABLE_TENSORFLOW
case 'H':
return popFunctionType(Node::Kind::LinearFunctionType);
// SWIFT_ENABLE_TENSORFLOW
case 'I':
return popFunctionType(Node::Kind::EscapingLinearFunctionType);
case 'o':
return createType(createWithChild(Node::Kind::Unowned,
popNode(Node::Kind::Type)));
Expand Down
36 changes: 36 additions & 0 deletions lib/Demangling/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ class NodePrinter {
case Node::Kind::DependentPseudogenericSignature:
case Node::Kind::Destructor:
case Node::Kind::DidSet:
// SWIFT_ENABLE_TENSORFLOW
case Node::Kind::DifferentiableFunctionType:
case Node::Kind::EscapingDifferentiableFunctionType:
case Node::Kind::LinearFunctionType:
case Node::Kind::EscapingLinearFunctionType:
// SWIFT_ENABLE_TENSORFLOW END
case Node::Kind::DirectMethodReferenceAttribute:
case Node::Kind::Directness:
case Node::Kind::DynamicAttribute:
Expand Down Expand Up @@ -1189,6 +1195,26 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
Printer << "@convention(thin) ";
printFunctionType(nullptr, Node);
return nullptr;
// SWIFT_ENABLE_TENSORFLOW
case Node::Kind::DifferentiableFunctionType:
Printer << "@differentiable ";
printFunctionType(nullptr, Node);
return nullptr;
// SWIFT_ENABLE_TENSORFLOW
case Node::Kind::EscapingDifferentiableFunctionType:
Printer << "@escaping @differentiable ";
printFunctionType(nullptr, Node);
return nullptr;
// SWIFT_ENABLE_TENSORFLOW
case Node::Kind::LinearFunctionType:
Printer << "@differentiable(linear) ";
printFunctionType(nullptr, Node);
return nullptr;
// SWIFT_ENABLE_TENSORFLOW
case Node::Kind::EscapingLinearFunctionType:
Printer << "@escaping @differentiable(linear) ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::FunctionType:
case Node::Kind::UncurriedFunctionType:
printFunctionType(nullptr, Node);
Expand Down Expand Up @@ -2455,6 +2481,16 @@ void NodePrinter::printEntityType(NodePointer Entity, NodePointer type,
Printer << ' ';
type = dependentType->getFirstChild();
}
// SWIFT_ENABLE_TENSORFLOW
if (type->getKind() == Node::Kind::DifferentiableFunctionType)
Printer << "@differentiable ";
else if (type->getKind() == Node::Kind::EscapingDifferentiableFunctionType)
Printer << "@escaping @differentiable ";
else if (type->getKind() == Node::Kind::LinearFunctionType)
Printer << "@differentiable(linear) ";
else if (type->getKind() == Node::Kind::EscapingLinearFunctionType)
Printer << "@escaping @differentiable(linear) ";
// SWIFT_ENABLE_TENSORFLOW END
printFunctionType(labelList, type);
} else {
print(type);
Expand Down
24 changes: 24 additions & 0 deletions lib/Demangling/OldRemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,30 @@ void Remangler::mangleThinFunctionType(Node *node) {
mangleChildNodes(node); // argument tuple, result type
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleDifferentiableFunctionType(Node *node) {
Buffer << "XF";
mangleChildNodes(node); // argument tuple, result type
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) {
Buffer << "XG";
mangleChildNodes(node); // argument tuple, result type
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleLinearFunctionType(Node *node) {
Buffer << "XH";
mangleChildNodes(node); // argument tuple, result type
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleEscapingLinearFunctionType(Node *node) {
Buffer << "XI";
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleArgumentTuple(Node *node) {
mangleSingleChildNode(node);
}
Expand Down
24 changes: 24 additions & 0 deletions lib/Demangling/Remangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,30 @@ void Remangler::mangleFunctionType(Node *node) {
Buffer << 'c';
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleDifferentiableFunctionType(Node *node) {
mangleFunctionSignature(node);
Buffer << "XF";
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) {
mangleFunctionSignature(node);
Buffer << "XG";
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleLinearFunctionType(Node *node) {
mangleFunctionSignature(node);
Buffer << "XH";
}

// SWIFT_ENABLE_TENSORFLOW
void Remangler::mangleEscapingLinearFunctionType(Node *node) {
mangleFunctionSignature(node);
Buffer << "XI";
}

void Remangler::mangleGenericProtocolWitnessTable(Node *node) {
mangleSingleChildNode(node);
Buffer << "WG";
Expand Down
20 changes: 19 additions & 1 deletion lib/IRGen/MetadataRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,14 +1065,32 @@ namespace {
break;
}

// SWIFT_ENABLE_TENSORFLOW
FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind;
switch (type->getDifferentiabilityKind()) {
case DifferentiabilityKind::NonDifferentiable:
metadataDifferentiabilityKind =
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
break;
case DifferentiabilityKind::Normal:
metadataDifferentiabilityKind =
FunctionMetadataDifferentiabilityKind::Normal;
break;
case DifferentiabilityKind::Linear:
metadataDifferentiabilityKind =
FunctionMetadataDifferentiabilityKind::Linear;
break;
}

auto flagsVal = FunctionTypeFlags()
.withNumParameters(numParams)
.withConvention(metadataConvention)
.withThrows(type->throws())
.withParameterFlags(hasFlags)
// SWIFT_ENABLE_TENSORFLOW
.withEscaping(isEscaping)
.withDifferentiable(type->isDifferentiable());
.withDifferentiabilityKind(
metadataDifferentiabilityKind);

auto flags = llvm::ConstantInt::get(IGF.IGM.SizeTy,
flagsVal.getIntValue());
Expand Down
Loading