Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,8 @@ class DifferentiableAttr final
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The declaration on which the `@differentiable` attribute is declared.
Decl *OriginalDeclaration = nullptr;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
Expand Down Expand Up @@ -1573,7 +1575,7 @@ class DifferentiableAttr final
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(ASTContext &context, bool implicit,
explicit DifferentiableAttr(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Expand All @@ -1589,13 +1591,16 @@ class DifferentiableAttr final
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(ASTContext &context, bool implicit,
static DifferentiableAttr *create(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig);

Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
void setOriginalDeclaration(Decl *decl);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Expand Down
27 changes: 17 additions & 10 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1454,16 +1454,16 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
IndexSubset *indices,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
ParameterIndices(indices) {
setOriginalDeclaration(original);
setDerivativeGenericSignature(derivativeGenSig);
}

Expand All @@ -1483,19 +1483,26 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
}

DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig) {
void *mem = context.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
auto &ctx = original->getASTContext();
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
linear, indices, std::move(jvp),
std::move(vjp), derivativeGenSig);
}

void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
assert(decl && "Original declaration must be non-null");
assert(!OriginalDeclaration &&
"Original declaration cannot have already been set");
OriginalDeclaration = decl;
}

void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
JVPFunction = decl;
if (decl && !JVP)
Expand Down
27 changes: 27 additions & 0 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3272,6 +3272,13 @@ void Parser::delayParseFromBeginningToHere(ParserPosition BeginParserPosition,
consumeToken();
}

// SWIFT_ENABLE_TENSORFLOW
static void setOriginalFunctionInDifferentiableAttributes(
DeclAttributes Attributes, Decl *D) {
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
}

/// Parse a single syntactic declaration and return a list of decl
/// ASTs. This can return multiple results for var decls that bind to multiple
/// values, structs that define a struct decl and a constructor, etc.
Expand Down Expand Up @@ -3679,6 +3686,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
Decl *D = DeclResult.get();
if (!declWasHandledAlready(D)) {
Handler(D);
// SWIFT_ENABLE_TENSORFLOW
if (auto FD = dyn_cast<FuncDecl>(D)) {
if (auto attr = D->getAttrs().getAttribute<QuotedAttr>()) {
// TODO(TF-718): Properly mangle names for quote decls.
Expand Down Expand Up @@ -3716,7 +3724,11 @@ Parser::parseDecl(ParseDeclOptions Flags,
Handler(quoteDecl);
}
}
// SWIFT_ENABLE_TENSORFLOW END
}
// SWIFT_ENABLE_TENSORFLOW
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
// SWIFT_ENABLE_TENSORFLOW END
}

if (!DeclResult.isParseError()) {
Expand Down Expand Up @@ -5513,6 +5525,12 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,

accessors.record(*this, PrimaryVar, Invalid);

// SWIFT_ENABLE_TENSORFLOW
for (auto *accessor : accessors.Accessors)
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
accessor);
// SWIFT_ENABLE_TENSORFLOW END

return makeParserResult(PrimaryVar);
}

Expand Down Expand Up @@ -5773,6 +5791,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
pattern->forEachVariable([&](VarDecl *VD) {
VD->setStatic(StaticLoc.isValid());
VD->getAttrs() = Attributes;
// SWIFT_ENABLE_TENSORFLOW
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
// SWIFT_ENABLE_TENSORFLOW END
setLocalDiscriminator(VD);
Decls.push_back(VD);
if (hasOpaqueReturnTy && sf) {
Expand Down Expand Up @@ -7025,6 +7046,12 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,

accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));

// SWIFT_ENABLE_TENSORFLOW
for (auto *accessor : accessors.Accessors)
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
accessor);
// SWIFT_ENABLE_TENSORFLOW END

// No need to setLocalDiscriminator because subscripts cannot
// validly appear outside of type decls.
return makeParserResult(Status, Subscript);
Expand Down
5 changes: 3 additions & 2 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
derivativeGenSig = extDecl->getGenericSignature();
auto *diffableAttr = DifferentiableAttr::create(
C, /*implicit*/ true, SourceLoc(), SourceLoc(),
/*linear*/ false, {}, None, None, derivativeGenSig);
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
derivativeGenSig);
member->getAttrs().add(diffableAttr);
// Set getter `@differentiable` attribute parameter indices.
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));
Expand Down
17 changes: 10 additions & 7 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3274,6 +3274,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
return;
}

assert(attr->getOriginalDeclaration() &&
"`@differentiable` attribute should have original declaration set "
"during construction or parsing");
TC.resolveDeclSignature(original);
auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();
bool isMethod = original->hasImplicitSelfDecl();
Expand Down Expand Up @@ -3530,15 +3533,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
D->getAttrs().removeAttribute(attr);
// Transfer `@differentiable` attribute from storage declaration to
// getter accessor.
auto *getterDecl = asd->getAccessor(AccessorKind::Get);
auto *newAttr = DifferentiableAttr::create(
ctx, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(),
attr->getParameterIndices(), attr->getJVP(), attr->getVJP(),
attr->getDerivativeGenericSignature());
getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(),
attr->isLinear(), attr->getParameterIndices(), attr->getJVP(),
attr->getVJP(), attr->getDerivativeGenericSignature());
newAttr->setJVPFunction(attr->getJVPFunction());
newAttr->setVJPFunction(attr->getVJPFunction());
auto insertion = ctx.DifferentiableAttrs.try_emplace(
{asd->getAccessor(AccessorKind::Get), newAttr->getParameterIndices()},
newAttr);
{getterDecl, newAttr->getParameterIndices()}, newAttr);
// Valid `@differentiable` attributes are uniqued by their parameter
// indices. Reject duplicate attributes for the same decl and parameter
// indices pair.
Expand All @@ -3548,7 +3551,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
diag::differentiable_attr_duplicate_note);
return;
}
asd->getAccessor(AccessorKind::Get)->getAttrs().add(newAttr);
getterDecl->getAttrs().add(newAttr);
return;
}
auto insertion = ctx.DifferentiableAttrs.try_emplace(
Expand Down Expand Up @@ -3826,7 +3829,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
// If the original function does not have a `@differentiable` attribute with
// the same differentiation parameters, create one.
if (!da) {
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
da = DifferentiableAttr::create(originalFn, /*implicit*/ true, attr->AtLoc,
attr->getRange(), attr->isLinear(),
checkedWrtParamIndices, /*jvp*/ None,
/*vjp*/ None,
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,9 @@ swift::matchWitness(
if (!reqDiffAttrMatch) {
auto implicitDiffAttr = false;
if (reqDiffAttrSupersetMatch) {
auto *witnessAFD = cast<AbstractFunctionDecl>(witness);
auto *newAttr = DifferentiableAttr::create(
ctx, /*implicit*/ true, reqDiffAttr->AtLoc,
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
/*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature());
Expand Down
38 changes: 32 additions & 6 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2166,6 +2166,24 @@ static bool attributeChainContains(DeclAttribute *attr) {
return tempAttrs.hasAttribute<DERIVED>();
}

// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
//
// Serializing/deserializing the original declaration DeclID in
// `@differentiable` attributes does not work because it causes
// `@differentiable` attribute deserialization to enter an infinite loop.
//
// Instead, call this ad-hoc function after deserializing a declaration to set
// it as the original declaration in its `@differentiable` attributes.
static void setOriginalDeclarationInDifferentiableAttributes(
Decl *decl, DeclAttribute *attrs) {
DeclAttributes tempAttrs;
tempAttrs.setRawAttributeChain(attrs);
for (auto *attr : tempAttrs.getAttributes<DifferentiableAttr>())
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(decl);
}
// SWIFT_ENABLE_TENSORFLOW END

Decl *ModuleFile::getDecl(DeclID DID) {
Expected<Decl *> deserialized = getDeclChecked(DID);
if (!deserialized) {
Expand Down Expand Up @@ -4084,10 +4102,11 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
parametersBitVector[i] = parameters[i];
auto *indices = IndexSubset::get(ctx, parametersBitVector);

auto diffAttr =
DifferentiableAttr::create(ctx, isImplicit, SourceLoc(),
SourceRange(), linear, indices, jvp, vjp,
derivativeGenSig);
auto *diffAttr = DifferentiableAttr::create(
ctx, isImplicit, SourceLoc(), SourceRange(), linear,
/*parsedParameters*/ {}, jvp, vjp, /*trailingWhereClause*/ nullptr);
diffAttr->setParameterIndices(indices);
diffAttr->setDerivativeGenericSignature(derivativeGenSig);
diffAttr->setJVPFunction(jvpDecl);
diffAttr->setVJPFunction(vjpDecl);
Attr = diffAttr;
Expand Down Expand Up @@ -4237,9 +4256,16 @@ DeclDeserializer::getDeclCheckedImpl() {
&MF, declOrOffset, static_cast<decls_block::RecordKind>(recordID));

switch (recordID) {
// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
#define CASE(RECORD_NAME) \
case decls_block::RECORD_NAME##Layout::Code: \
return deserialize##RECORD_NAME(scratch, blobData);
case decls_block::RECORD_NAME##Layout::Code: {\
auto decl = deserialize##RECORD_NAME(scratch, blobData); \
if (decl) \
setOriginalDeclarationInDifferentiableAttributes(decl.get(), DAttrs); \
return decl; \
}
// SWIFT_ENABLE_TENSORFLOW END

CASE(TypeAlias)
CASE(GenericTypeParamDecl)
Expand Down
3 changes: 3 additions & 0 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2309,6 +2309,9 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
case DAK_Differentiable: {
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
auto *attr = cast<DifferentiableAttr>(DA);
assert(attr->getOriginalDeclaration() &&
"`@differentiable` attribute should have original declaration set "
"during construction or parsing");

IdentifierID jvpName = 0;
DeclID jvpRef = 0;
Expand Down
32 changes: 32 additions & 0 deletions test/AutoDiff/differentiable_attr_serialization.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend -emit-module %s -o %t/differentiable_attr_serialization.swiftmodule
// RUN: %target-swift-frontend -merge-modules -sil-merge-partial-modules -emit-module %t/differentiable_attr_serialization.swiftmodule

// Test round-trip `@differentiable` attribute AST serialization.

// Motivation: check that `@differentiable` attributes always have original
// declaration set.

struct Foo: Differentiable {
@differentiable
func method() -> Self { self }

@differentiable
init(_ x: Float) {}

@differentiable
var computedProperty: Float { 1 }

var computedPropertyGetter: Float {
@differentiable
get { 1 }
}

@differentiable
subscript() -> Float { 1 }

subscript(_ x: Float) -> Float {
@differentiable
get { 1 }
}
}