diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index fc140c38926a8..8a50ee30d52a5 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1541,6 +1541,8 @@ class DifferentiableAttr final ParsedAutoDiffParameter> { friend TrailingObjects; + /// The declaration on which the `@differentiable` attribute is declared. + Decl *OriginalDeclaration = nullptr; /// Whether this function is linear. bool Linear; /// The number of parsed parameters specified in 'wrt:'. @@ -1573,7 +1575,7 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - explicit DifferentiableAttr(ASTContext &context, bool implicit, + explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, @@ -1589,13 +1591,16 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - static DifferentiableAttr *create(ASTContext &context, bool implicit, + static DifferentiableAttr *create(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig); + Decl *getOriginalDeclaration() const { return OriginalDeclaration; } + void setOriginalDeclaration(Decl *decl); + /// Get the optional 'jvp:' function name and location. /// Use this instead of `getJVPFunction` to check whether the attribute has a /// registered JVP. diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index d35bbeb0783d3..64136ae943db2 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1454,16 +1454,16 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, getTrailingObjects()); } -DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, +DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, - bool linear, - IndexSubset *indices, + bool linear, IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) { + setOriginalDeclaration(original); setDerivativeGenericSignature(derivativeGenSig); } @@ -1483,19 +1483,26 @@ DifferentiableAttr::create(ASTContext &context, bool implicit, } DifferentiableAttr * -DifferentiableAttr::create(ASTContext &context, bool implicit, - SourceLoc atLoc, SourceRange baseRange, - bool linear, IndexSubset *indices, - Optional jvp, +DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc, + SourceRange baseRange, bool linear, + IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig) { - void *mem = context.Allocate(sizeof(DifferentiableAttr), - alignof(DifferentiableAttr)); - return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange, + auto &ctx = original->getASTContext(); + void *mem = ctx.Allocate(sizeof(DifferentiableAttr), + alignof(DifferentiableAttr)); + return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange, linear, indices, std::move(jvp), std::move(vjp), derivativeGenSig); } +void DifferentiableAttr::setOriginalDeclaration(Decl *decl) { + assert(decl && "Original declaration must be non-null"); + assert(!OriginalDeclaration && + "Original declaration cannot have already been set"); + OriginalDeclaration = decl; +} + void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { JVPFunction = decl; if (decl && !JVP) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index b80618cfadc68..5e2768d0209c6 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -3272,6 +3272,13 @@ void Parser::delayParseFromBeginningToHere(ParserPosition BeginParserPosition, consumeToken(); } +// SWIFT_ENABLE_TENSORFLOW +static void setOriginalFunctionInDifferentiableAttributes( + DeclAttributes Attributes, Decl *D) { + for (auto *attr : Attributes.getAttributes()) + const_cast(attr)->setOriginalDeclaration(D); +} + /// Parse a single syntactic declaration and return a list of decl /// ASTs. This can return multiple results for var decls that bind to multiple /// values, structs that define a struct decl and a constructor, etc. @@ -3679,6 +3686,7 @@ Parser::parseDecl(ParseDeclOptions Flags, Decl *D = DeclResult.get(); if (!declWasHandledAlready(D)) { Handler(D); + // SWIFT_ENABLE_TENSORFLOW if (auto FD = dyn_cast(D)) { if (auto attr = D->getAttrs().getAttribute()) { // TODO(TF-718): Properly mangle names for quote decls. @@ -3716,7 +3724,11 @@ Parser::parseDecl(ParseDeclOptions Flags, Handler(quoteDecl); } } + // SWIFT_ENABLE_TENSORFLOW END } + // SWIFT_ENABLE_TENSORFLOW + setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D); + // SWIFT_ENABLE_TENSORFLOW END } if (!DeclResult.isParseError()) { @@ -5513,6 +5525,12 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags, accessors.record(*this, PrimaryVar, Invalid); + // SWIFT_ENABLE_TENSORFLOW + for (auto *accessor : accessors.Accessors) + setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(), + accessor); + // SWIFT_ENABLE_TENSORFLOW END + return makeParserResult(PrimaryVar); } @@ -5773,6 +5791,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags, pattern->forEachVariable([&](VarDecl *VD) { VD->setStatic(StaticLoc.isValid()); VD->getAttrs() = Attributes; + // SWIFT_ENABLE_TENSORFLOW + setOriginalFunctionInDifferentiableAttributes(Attributes, VD); + // SWIFT_ENABLE_TENSORFLOW END setLocalDiscriminator(VD); Decls.push_back(VD); if (hasOpaqueReturnTy && sf) { @@ -7025,6 +7046,12 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc, accessors.record(*this, Subscript, (Invalid || !Status.isSuccess())); + // SWIFT_ENABLE_TENSORFLOW + for (auto *accessor : accessors.Accessors) + setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(), + accessor); + // SWIFT_ENABLE_TENSORFLOW END + // No need to setLocalDiscriminator because subscripts cannot // validly appear outside of type decls. return makeParserResult(Status, Subscript); diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index dfd50b132b695..b641ca72cd1b5 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -641,8 +641,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { if (auto *extDecl = dyn_cast(parentDC->getAsDecl())) derivativeGenSig = extDecl->getGenericSignature(); auto *diffableAttr = DifferentiableAttr::create( - C, /*implicit*/ true, SourceLoc(), SourceLoc(), - /*linear*/ false, {}, None, None, derivativeGenSig); + member->getAccessor(AccessorKind::Get), /*implicit*/ true, + SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None, + derivativeGenSig); member->getAttrs().add(diffableAttr); // Set getter `@differentiable` attribute parameter indices. diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0})); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index eec7a9f7313f2..d02dc8c6d40db 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3274,6 +3274,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return; } + assert(attr->getOriginalDeclaration() && + "`@differentiable` attribute should have original declaration set " + "during construction or parsing"); TC.resolveDeclSignature(original); auto *originalFnTy = original->getInterfaceType()->castTo(); bool isMethod = original->hasImplicitSelfDecl(); @@ -3530,15 +3533,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { D->getAttrs().removeAttribute(attr); // Transfer `@differentiable` attribute from storage declaration to // getter accessor. + auto *getterDecl = asd->getAccessor(AccessorKind::Get); auto *newAttr = DifferentiableAttr::create( - ctx, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(), - attr->getParameterIndices(), attr->getJVP(), attr->getVJP(), - attr->getDerivativeGenericSignature()); + getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(), + attr->isLinear(), attr->getParameterIndices(), attr->getJVP(), + attr->getVJP(), attr->getDerivativeGenericSignature()); newAttr->setJVPFunction(attr->getJVPFunction()); newAttr->setVJPFunction(attr->getVJPFunction()); auto insertion = ctx.DifferentiableAttrs.try_emplace( - {asd->getAccessor(AccessorKind::Get), newAttr->getParameterIndices()}, - newAttr); + {getterDecl, newAttr->getParameterIndices()}, newAttr); // Valid `@differentiable` attributes are uniqued by their parameter // indices. Reject duplicate attributes for the same decl and parameter // indices pair. @@ -3548,7 +3551,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { diag::differentiable_attr_duplicate_note); return; } - asd->getAccessor(AccessorKind::Get)->getAttrs().add(newAttr); + getterDecl->getAttrs().add(newAttr); return; } auto insertion = ctx.DifferentiableAttrs.try_emplace( @@ -3826,7 +3829,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // If the original function does not have a `@differentiable` attribute with // the same differentiation parameters, create one. if (!da) { - da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc, + da = DifferentiableAttr::create(originalFn, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(), checkedWrtParamIndices, /*jvp*/ None, /*vjp*/ None, diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index bfc99aa5dea79..4d08fcece71d6 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -567,8 +567,9 @@ swift::matchWitness( if (!reqDiffAttrMatch) { auto implicitDiffAttr = false; if (reqDiffAttrSupersetMatch) { + auto *witnessAFD = cast(witness); auto *newAttr = DifferentiableAttr::create( - ctx, /*implicit*/ true, reqDiffAttr->AtLoc, + witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc, reqDiffAttr->getRange(), reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(), /*jvp*/ None, /*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature()); diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index d8e79bbef467d..1acf8dee433d9 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2166,6 +2166,24 @@ static bool attributeChainContains(DeclAttribute *attr) { return tempAttrs.hasAttribute(); } +// SWIFT_ENABLE_TENSORFLOW +// Set original declaration in `@differentiable` attributes. +// +// Serializing/deserializing the original declaration DeclID in +// `@differentiable` attributes does not work because it causes +// `@differentiable` attribute deserialization to enter an infinite loop. +// +// Instead, call this ad-hoc function after deserializing a declaration to set +// it as the original declaration in its `@differentiable` attributes. +static void setOriginalDeclarationInDifferentiableAttributes( + Decl *decl, DeclAttribute *attrs) { + DeclAttributes tempAttrs; + tempAttrs.setRawAttributeChain(attrs); + for (auto *attr : tempAttrs.getAttributes()) + const_cast(attr)->setOriginalDeclaration(decl); +} +// SWIFT_ENABLE_TENSORFLOW END + Decl *ModuleFile::getDecl(DeclID DID) { Expected deserialized = getDeclChecked(DID); if (!deserialized) { @@ -4084,10 +4102,11 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { parametersBitVector[i] = parameters[i]; auto *indices = IndexSubset::get(ctx, parametersBitVector); - auto diffAttr = - DifferentiableAttr::create(ctx, isImplicit, SourceLoc(), - SourceRange(), linear, indices, jvp, vjp, - derivativeGenSig); + auto *diffAttr = DifferentiableAttr::create( + ctx, isImplicit, SourceLoc(), SourceRange(), linear, + /*parsedParameters*/ {}, jvp, vjp, /*trailingWhereClause*/ nullptr); + diffAttr->setParameterIndices(indices); + diffAttr->setDerivativeGenericSignature(derivativeGenSig); diffAttr->setJVPFunction(jvpDecl); diffAttr->setVJPFunction(vjpDecl); Attr = diffAttr; @@ -4237,9 +4256,16 @@ DeclDeserializer::getDeclCheckedImpl() { &MF, declOrOffset, static_cast(recordID)); switch (recordID) { + // SWIFT_ENABLE_TENSORFLOW + // Set original declaration in `@differentiable` attributes. #define CASE(RECORD_NAME) \ - case decls_block::RECORD_NAME##Layout::Code: \ - return deserialize##RECORD_NAME(scratch, blobData); + case decls_block::RECORD_NAME##Layout::Code: {\ + auto decl = deserialize##RECORD_NAME(scratch, blobData); \ + if (decl) \ + setOriginalDeclarationInDifferentiableAttributes(decl.get(), DAttrs); \ + return decl; \ + } + // SWIFT_ENABLE_TENSORFLOW END CASE(TypeAlias) CASE(GenericTypeParamDecl) diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 9650360c9e2b9..97e6d6a273e7a 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2309,6 +2309,9 @@ class Serializer::DeclSerializer : public DeclVisitor { case DAK_Differentiable: { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); + assert(attr->getOriginalDeclaration() && + "`@differentiable` attribute should have original declaration set " + "during construction or parsing"); IdentifierID jvpName = 0; DeclID jvpRef = 0; diff --git a/test/AutoDiff/differentiable_attr_serialization.swift b/test/AutoDiff/differentiable_attr_serialization.swift new file mode 100644 index 0000000000000..9e634f44613db --- /dev/null +++ b/test/AutoDiff/differentiable_attr_serialization.swift @@ -0,0 +1,32 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -emit-module %s -o %t/differentiable_attr_serialization.swiftmodule +// RUN: %target-swift-frontend -merge-modules -sil-merge-partial-modules -emit-module %t/differentiable_attr_serialization.swiftmodule + +// Test round-trip `@differentiable` attribute AST serialization. + +// Motivation: check that `@differentiable` attributes always have original +// declaration set. + +struct Foo: Differentiable { + @differentiable + func method() -> Self { self } + + @differentiable + init(_ x: Float) {} + + @differentiable + var computedProperty: Float { 1 } + + var computedPropertyGetter: Float { + @differentiable + get { 1 } + } + + @differentiable + subscript() -> Float { 1 } + + subscript(_ x: Float) -> Float { + @differentiable + get { 1 } + } +}