From dabd2d847147868421eaf7652465f4968a82165b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 30 Oct 2019 19:19:42 -0700 Subject: [PATCH 1/5] [AutoDiff] Store original declaration in `DifferentiableAttr`. Store original `AbstractFunctionDecl` in `DifferentiableAttr`. This is important for requestifying `DifferentiableAttr->getParameterIndices()`: we want the ability to resolve parameter indices without needing to pass the original `AbstractFunctionDecl` to `getParameterIndices`. --- include/swift/AST/Attr.h | 16 ++++++++++---- lib/AST/Attr.cpp | 22 +++++++++++++------ lib/Parse/ParseDecl.cpp | 13 +++++++++++ lib/Sema/DerivedConformanceDifferentiable.cpp | 5 +++-- lib/Sema/TypeCheckAttr.cpp | 16 ++++++++------ lib/Sema/TypeCheckProtocol.cpp | 3 ++- lib/Serialization/Deserialization.cpp | 8 ++++--- lib/Serialization/ModuleFormat.h | 3 ++- lib/Serialization/Serialization.cpp | 6 ++++- 9 files changed, 66 insertions(+), 26 deletions(-) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index fc140c38926a8..b823c8f573696 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1541,6 +1541,8 @@ class DifferentiableAttr final ParsedAutoDiffParameter> { friend TrailingObjects; + /// The declaration on which the `@differentiable` attribute is declared. + AbstractFunctionDecl *OriginalFunction = nullptr; /// Whether this function is linear. bool Linear; /// The number of parsed parameters specified in 'wrt:'. @@ -1573,7 +1575,7 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - explicit DifferentiableAttr(ASTContext &context, bool implicit, + explicit DifferentiableAttr(AbstractFunctionDecl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, @@ -1589,13 +1591,19 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - static DifferentiableAttr *create(ASTContext &context, bool implicit, - SourceLoc atLoc, SourceRange baseRange, - bool linear, IndexSubset *indices, + static DifferentiableAttr *create(AbstractFunctionDecl *original, + bool implicit, SourceLoc atLoc, + SourceRange baseRange, bool linear, + IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig); + AbstractFunctionDecl *getOriginalFunction() const { + return OriginalFunction; + } + void setOriginalFunction(AbstractFunctionDecl *decl); + /// Get the optional 'jvp:' function name and location. /// Use this instead of `getJVPFunction` to check whether the attribute has a /// registered JVP. diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index d35bbeb0783d3..39def50cacb08 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1454,9 +1454,9 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, getTrailingObjects()); } -DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, - SourceLoc atLoc, SourceRange baseRange, - bool linear, +DifferentiableAttr::DifferentiableAttr(AbstractFunctionDecl *original, + bool implicit, SourceLoc atLoc, + SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, Optional vjp, @@ -1464,6 +1464,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) { + setOriginalFunction(original); setDerivativeGenericSignature(derivativeGenSig); } @@ -1483,19 +1484,26 @@ DifferentiableAttr::create(ASTContext &context, bool implicit, } DifferentiableAttr * -DifferentiableAttr::create(ASTContext &context, bool implicit, +DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig) { - void *mem = context.Allocate(sizeof(DifferentiableAttr), - alignof(DifferentiableAttr)); - return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange, + auto &ctx = original->getASTContext(); + void *mem = ctx.Allocate(sizeof(DifferentiableAttr), + alignof(DifferentiableAttr)); + return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange, linear, indices, std::move(jvp), std::move(vjp), derivativeGenSig); } +void DifferentiableAttr::setOriginalFunction(AbstractFunctionDecl *decl) { + assert(!OriginalFunction && "Original function cannot have already been set"); + assert(decl && "Original function must be non-null"); + OriginalFunction = decl; +} + void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { JVPFunction = decl; if (decl && !JVP) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index b80618cfadc68..62c1dfde49dd9 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -3679,6 +3679,7 @@ Parser::parseDecl(ParseDeclOptions Flags, Decl *D = DeclResult.get(); if (!declWasHandledAlready(D)) { Handler(D); + // SWIFT_ENABLE_TENSORFLOW if (auto FD = dyn_cast(D)) { if (auto attr = D->getAttrs().getAttribute()) { // TODO(TF-718): Properly mangle names for quote decls. @@ -3716,6 +3717,18 @@ Parser::parseDecl(ParseDeclOptions Flags, Handler(quoteDecl); } } + + if (D->getAttrs().hasAttribute()) { + auto *AFD = dyn_cast(D); + if (auto *ASD = dyn_cast(D)) + AFD = ASD->getAccessor(AccessorKind::Get); + assert(AFD && "Must resolve '@differentiable' attribute declaration"); + for (auto *attr : D->getAttrs().getAttributes()) { + auto *diffAttr = const_cast(attr); + diffAttr->setOriginalFunction(AFD); + } + } + // SWIFT_ENABLE_TENSORFLOW END } } diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index dfd50b132b695..b641ca72cd1b5 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -641,8 +641,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { if (auto *extDecl = dyn_cast(parentDC->getAsDecl())) derivativeGenSig = extDecl->getGenericSignature(); auto *diffableAttr = DifferentiableAttr::create( - C, /*implicit*/ true, SourceLoc(), SourceLoc(), - /*linear*/ false, {}, None, None, derivativeGenSig); + member->getAccessor(AccessorKind::Get), /*implicit*/ true, + SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None, + derivativeGenSig); member->getAttrs().add(diffableAttr); // Set getter `@differentiable` attribute parameter indices. diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0})); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index eec7a9f7313f2..35b9983187349 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3274,6 +3274,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return; } + if (!attr->getOriginalFunction()) + attr->setOriginalFunction(original); TC.resolveDeclSignature(original); auto *originalFnTy = original->getInterfaceType()->castTo(); bool isMethod = original->hasImplicitSelfDecl(); @@ -3530,15 +3532,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { D->getAttrs().removeAttribute(attr); // Transfer `@differentiable` attribute from storage declaration to // getter accessor. + auto *getterDecl = asd->getAccessor(AccessorKind::Get); auto *newAttr = DifferentiableAttr::create( - ctx, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(), - attr->getParameterIndices(), attr->getJVP(), attr->getVJP(), - attr->getDerivativeGenericSignature()); + getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(), + attr->isLinear(), attr->getParameterIndices(), attr->getJVP(), + attr->getVJP(), attr->getDerivativeGenericSignature()); newAttr->setJVPFunction(attr->getJVPFunction()); newAttr->setVJPFunction(attr->getVJPFunction()); auto insertion = ctx.DifferentiableAttrs.try_emplace( - {asd->getAccessor(AccessorKind::Get), newAttr->getParameterIndices()}, - newAttr); + {getterDecl, newAttr->getParameterIndices()}, newAttr); // Valid `@differentiable` attributes are uniqued by their parameter // indices. Reject duplicate attributes for the same decl and parameter // indices pair. @@ -3548,7 +3550,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { diag::differentiable_attr_duplicate_note); return; } - asd->getAccessor(AccessorKind::Get)->getAttrs().add(newAttr); + getterDecl->getAttrs().add(newAttr); return; } auto insertion = ctx.DifferentiableAttrs.try_emplace( @@ -3826,7 +3828,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // If the original function does not have a `@differentiable` attribute with // the same differentiation parameters, create one. if (!da) { - da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc, + da = DifferentiableAttr::create(originalFn, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(), checkedWrtParamIndices, /*jvp*/ None, /*vjp*/ None, diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index bfc99aa5dea79..4d08fcece71d6 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -567,8 +567,9 @@ swift::matchWitness( if (!reqDiffAttrMatch) { auto implicitDiffAttr = false; if (reqDiffAttrSupersetMatch) { + auto *witnessAFD = cast(witness); auto *newAttr = DifferentiableAttr::create( - ctx, /*implicit*/ true, reqDiffAttr->AtLoc, + witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc, reqDiffAttr->getRange(), reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(), /*jvp*/ None, /*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature()); diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index d8e79bbef467d..f113f0d635810 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -4051,6 +4051,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { // SWIFT_ENABLE_TENSORFLOW case decls_block::Differentiable_DECL_ATTR: { bool isImplicit; + DeclID originalDeclId; bool linear; uint64_t jvpNameId; DeclID jvpDeclId; @@ -4060,9 +4061,10 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { ArrayRef parameters; serialization::decls_block::DifferentiableDeclAttrLayout::readRecord( - scratch, isImplicit, linear, jvpNameId, jvpDeclId, vjpNameId, - vjpDeclId, derivativeGenSigId, parameters); + scratch, isImplicit, originalDeclId, linear, jvpNameId, jvpDeclId, + vjpNameId, vjpDeclId, derivativeGenSigId, parameters); + FuncDecl *originalDecl = cast(MF.getDecl(originalDeclId)); Optional jvp; FuncDecl *jvpDecl = nullptr; if (jvpNameId != 0) @@ -4085,7 +4087,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { auto *indices = IndexSubset::get(ctx, parametersBitVector); auto diffAttr = - DifferentiableAttr::create(ctx, isImplicit, SourceLoc(), + DifferentiableAttr::create(originalDecl, isImplicit, SourceLoc(), SourceRange(), linear, indices, jvp, vjp, derivativeGenSig); diffAttr->setJVPFunction(jvpDecl); diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index b95f6272e8bb2..502e5e27103b4 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type +const uint16_t SWIFTMODULE_VERSION_MINOR = 525; // @differentiable attribute original declaration /// A standard hash seed used for all string hashes in a serialized module. /// @@ -1759,6 +1759,7 @@ namespace decls_block { using DifferentiableDeclAttrLayout = BCRecordLayout< Differentiable_DECL_ATTR, BCFixed<1>, // Implicit flag. + DeclIDField, // Original function declaration. BCFixed<1>, // Linear flag. IdentifierIDField, // JVP name. DeclIDField, // JVP function declaration. diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 9650360c9e2b9..c8cd4a5304978 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2310,6 +2310,10 @@ class Serializer::DeclSerializer : public DeclVisitor { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); + assert(attr->getOriginalFunction() && + "@differentiable attribute must have original function resolved"); + DeclID originalRef = S.addDeclRef(attr->getOriginalFunction()); + IdentifierID jvpName = 0; DeclID jvpRef = 0; if (auto jvp = attr->getJVP()) @@ -2331,7 +2335,7 @@ class Serializer::DeclSerializer : public DeclVisitor { indices.push_back(paramIndices->contains(i)); DifferentiableDeclAttrLayout::emitRecord( - S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), + S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), originalRef, attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, S.addGenericSignatureRef(attr->getDerivativeGenericSignature()), indices); From 4df03077b64db926a03d4551f4ce3599ca4bb54b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 30 Oct 2019 20:44:17 -0700 Subject: [PATCH 2/5] Avoid serializing original declaration in `@differentiable` attribute. Deserializing the original declaration DeclID in `@differentiable` attributes does not work because it causes `@differentiable` attribute deserialization to enter an infinite loop. Instead, use ad-hoc deserialization logic to set the original declaration in `@differentiable` attributes. Add round-trip `@differentiable` attribute AST serialization test. `@differentiable` attribute serialization asserts that the original declaration is set. --- lib/Parse/ParseDecl.cpp | 10 ++--- lib/Serialization/Deserialization.cpp | 43 +++++++++++++++---- lib/Serialization/ModuleFormat.h | 3 +- lib/Serialization/Serialization.cpp | 4 +- .../differentiable_attr_serialization.swift | 22 ++++++++++ 5 files changed, 64 insertions(+), 18 deletions(-) create mode 100644 test/AutoDiff/differentiable_attr_serialization.swift diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 62c1dfde49dd9..d653b1a6603e4 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -3718,11 +3718,11 @@ Parser::parseDecl(ParseDeclOptions Flags, } } - if (D->getAttrs().hasAttribute()) { - auto *AFD = dyn_cast(D); - if (auto *ASD = dyn_cast(D)) - AFD = ASD->getAccessor(AccessorKind::Get); - assert(AFD && "Must resolve '@differentiable' attribute declaration"); + // Set original declaration in `@differentiable` attributes. + auto *AFD = dyn_cast(D); + if (auto *ASD = dyn_cast(D)) + AFD = ASD->getAccessor(AccessorKind::Get); + if (AFD) { for (auto *attr : D->getAttrs().getAttributes()) { auto *diffAttr = const_cast(attr); diffAttr->setOriginalFunction(AFD); diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index f113f0d635810..9be79fb4efb54 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2166,6 +2166,24 @@ static bool attributeChainContains(DeclAttribute *attr) { return tempAttrs.hasAttribute(); } +// SWIFT_ENABLE_TENSORFLOW +// Set original declaration in `@differentiable` attributes. +// +// Serializing/deserializing the original declaration DeclID in +// `@differentiable` attributes does not work because it causes +// `@differentiable` attribute deserialization to enter an infinite loop. +// +// Instead, call this ad-hoc function after deserializing a declaration to set +// it as the original declaration in its `@differentiable` attributes. +static void setOriginalDeclarationInDifferentiableAttributes( + AbstractFunctionDecl *decl, DeclAttribute *attrs) { + DeclAttributes tempAttrs; + tempAttrs.setRawAttributeChain(attrs); + for (auto *attr : tempAttrs.getAttributes()) + const_cast(attr)->setOriginalFunction(decl); +} +// SWIFT_ENABLE_TENSORFLOW END + Decl *ModuleFile::getDecl(DeclID DID) { Expected deserialized = getDeclChecked(DID); if (!deserialized) { @@ -2566,6 +2584,11 @@ class swift::DeclDeserializer { ctor->setImplicitlyUnwrappedOptional(isIUO); ctor->computeType(); + // SWIFT_ENABLE_TENSORFLOW + // Set original declaration in `@differentiable` attributes. + setOriginalDeclarationInDifferentiableAttributes(ctor, DAttrs); + // SWIFT_ENABLE_TENSORFLOW END + return ctor; } @@ -3039,6 +3062,11 @@ class swift::DeclDeserializer { // Set the interface type. fn->computeType(); + // SWIFT_ENABLE_TENSORFLOW + // Set original declaration in `@differentiable` attributes. + setOriginalDeclarationInDifferentiableAttributes(fn, DAttrs); + // SWIFT_ENABLE_TENSORFLOW END + return fn; } @@ -4051,7 +4079,6 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { // SWIFT_ENABLE_TENSORFLOW case decls_block::Differentiable_DECL_ATTR: { bool isImplicit; - DeclID originalDeclId; bool linear; uint64_t jvpNameId; DeclID jvpDeclId; @@ -4061,10 +4088,9 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { ArrayRef parameters; serialization::decls_block::DifferentiableDeclAttrLayout::readRecord( - scratch, isImplicit, originalDeclId, linear, jvpNameId, jvpDeclId, - vjpNameId, vjpDeclId, derivativeGenSigId, parameters); + scratch, isImplicit, linear, jvpNameId, jvpDeclId, vjpNameId, + vjpDeclId, derivativeGenSigId, parameters); - FuncDecl *originalDecl = cast(MF.getDecl(originalDeclId)); Optional jvp; FuncDecl *jvpDecl = nullptr; if (jvpNameId != 0) @@ -4086,10 +4112,11 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { parametersBitVector[i] = parameters[i]; auto *indices = IndexSubset::get(ctx, parametersBitVector); - auto diffAttr = - DifferentiableAttr::create(originalDecl, isImplicit, SourceLoc(), - SourceRange(), linear, indices, jvp, vjp, - derivativeGenSig); + auto *diffAttr = DifferentiableAttr::create( + ctx, isImplicit, SourceLoc(), SourceRange(), linear, + /*parsedParameters*/ {}, jvp, vjp, /*trailingWhereClause*/ nullptr); + diffAttr->setParameterIndices(indices); + diffAttr->setDerivativeGenericSignature(derivativeGenSig); diffAttr->setJVPFunction(jvpDecl); diffAttr->setVJPFunction(vjpDecl); Attr = diffAttr; diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 502e5e27103b4..b95f6272e8bb2 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 525; // @differentiable attribute original declaration +const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type /// A standard hash seed used for all string hashes in a serialized module. /// @@ -1759,7 +1759,6 @@ namespace decls_block { using DifferentiableDeclAttrLayout = BCRecordLayout< Differentiable_DECL_ATTR, BCFixed<1>, // Implicit flag. - DeclIDField, // Original function declaration. BCFixed<1>, // Linear flag. IdentifierIDField, // JVP name. DeclIDField, // JVP function declaration. diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index c8cd4a5304978..81878804293f3 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2309,10 +2309,8 @@ class Serializer::DeclSerializer : public DeclVisitor { case DAK_Differentiable: { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); - assert(attr->getOriginalFunction() && "@differentiable attribute must have original function resolved"); - DeclID originalRef = S.addDeclRef(attr->getOriginalFunction()); IdentifierID jvpName = 0; DeclID jvpRef = 0; @@ -2335,7 +2333,7 @@ class Serializer::DeclSerializer : public DeclVisitor { indices.push_back(paramIndices->contains(i)); DifferentiableDeclAttrLayout::emitRecord( - S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), originalRef, + S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, S.addGenericSignatureRef(attr->getDerivativeGenericSignature()), indices); diff --git a/test/AutoDiff/differentiable_attr_serialization.swift b/test/AutoDiff/differentiable_attr_serialization.swift new file mode 100644 index 0000000000000..bfaa2c453396a --- /dev/null +++ b/test/AutoDiff/differentiable_attr_serialization.swift @@ -0,0 +1,22 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -emit-module %s -o %t/differentiable_attr_serialization.swiftmodule +// RUN: %target-swift-frontend -merge-modules -sil-merge-partial-modules -emit-module %t/differentiable_attr_serialization.swiftmodule + +// Test round-trip `@differentiable` attribute AST serialization. + +// Motivation: check that `@differentiable` attributes always have original +// declaration set. + +struct Foo: Differentiable { + @differentiable + func method() -> Self { self } + + @differentiable + init(_ x: Float) {} + + @differentiable + var computedProperty: Float { 1 } + + @differentiable + subscript() -> Float { 1 } +} From e2be0a2004395492e6e980deacbfe397f1c2d105 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 31 Oct 2019 16:53:41 -0700 Subject: [PATCH 3/5] Store original `Decl *` in `DifferentiableAttr`. Store original declaration `Decl *` in `DifferentiableAttr` instead of casting to `AbstractFunctionDecl *`. This helps ensure that the original declaration is set. --- include/swift/AST/Attr.h | 17 ++++------ lib/AST/Attr.cpp | 25 +++++++------- lib/Parse/ParseDecl.cpp | 34 +++++++++++++------ lib/Sema/TypeCheckAttr.cpp | 4 +-- lib/Serialization/Deserialization.cpp | 2 +- lib/Serialization/Serialization.cpp | 5 +-- .../differentiable_attr_serialization.swift | 10 ++++++ 7 files changed, 58 insertions(+), 39 deletions(-) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index b823c8f573696..8a50ee30d52a5 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1542,7 +1542,7 @@ class DifferentiableAttr final friend TrailingObjects; /// The declaration on which the `@differentiable` attribute is declared. - AbstractFunctionDecl *OriginalFunction = nullptr; + Decl *OriginalDeclaration = nullptr; /// Whether this function is linear. bool Linear; /// The number of parsed parameters specified in 'wrt:'. @@ -1575,7 +1575,7 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - explicit DifferentiableAttr(AbstractFunctionDecl *original, bool implicit, + explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *indices, Optional jvp, @@ -1591,18 +1591,15 @@ class DifferentiableAttr final Optional vjp, TrailingWhereClause *clause); - static DifferentiableAttr *create(AbstractFunctionDecl *original, - bool implicit, SourceLoc atLoc, - SourceRange baseRange, bool linear, - IndexSubset *indices, + static DifferentiableAttr *create(Decl *original, bool implicit, + SourceLoc atLoc, SourceRange baseRange, + bool linear, IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig); - AbstractFunctionDecl *getOriginalFunction() const { - return OriginalFunction; - } - void setOriginalFunction(AbstractFunctionDecl *decl); + Decl *getOriginalDeclaration() const { return OriginalDeclaration; } + void setOriginalDeclaration(Decl *decl); /// Get the optional 'jvp:' function name and location. /// Use this instead of `getJVPFunction` to check whether the attribute has a diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 39def50cacb08..64136ae943db2 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1454,17 +1454,16 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit, getTrailingObjects()); } -DifferentiableAttr::DifferentiableAttr(AbstractFunctionDecl *original, - bool implicit, SourceLoc atLoc, - SourceRange baseRange, bool linear, - IndexSubset *indices, +DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit, + SourceLoc atLoc, SourceRange baseRange, + bool linear, IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) { - setOriginalFunction(original); + setOriginalDeclaration(original); setDerivativeGenericSignature(derivativeGenSig); } @@ -1484,10 +1483,9 @@ DifferentiableAttr::create(ASTContext &context, bool implicit, } DifferentiableAttr * -DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit, - SourceLoc atLoc, SourceRange baseRange, - bool linear, IndexSubset *indices, - Optional jvp, +DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc, + SourceRange baseRange, bool linear, + IndexSubset *indices, Optional jvp, Optional vjp, GenericSignature derivativeGenSig) { auto &ctx = original->getASTContext(); @@ -1498,10 +1496,11 @@ DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit, std::move(vjp), derivativeGenSig); } -void DifferentiableAttr::setOriginalFunction(AbstractFunctionDecl *decl) { - assert(!OriginalFunction && "Original function cannot have already been set"); - assert(decl && "Original function must be non-null"); - OriginalFunction = decl; +void DifferentiableAttr::setOriginalDeclaration(Decl *decl) { + assert(decl && "Original declaration must be non-null"); + assert(!OriginalDeclaration && + "Original declaration cannot have already been set"); + OriginalDeclaration = decl; } void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index d653b1a6603e4..50dbfd027c4d8 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -3272,6 +3272,13 @@ void Parser::delayParseFromBeginningToHere(ParserPosition BeginParserPosition, consumeToken(); } +// SWIFT_ENABLE_TENSORFLOW +static void setOriginalFunctionInDifferentiableAttributes( + DeclAttributes Attributes, Decl *D) { + for (auto *attr : Attributes.getAttributes()) + const_cast(attr)->setOriginalDeclaration(D); +} + /// Parse a single syntactic declaration and return a list of decl /// ASTs. This can return multiple results for var decls that bind to multiple /// values, structs that define a struct decl and a constructor, etc. @@ -3717,17 +3724,7 @@ Parser::parseDecl(ParseDeclOptions Flags, Handler(quoteDecl); } } - - // Set original declaration in `@differentiable` attributes. - auto *AFD = dyn_cast(D); - if (auto *ASD = dyn_cast(D)) - AFD = ASD->getAccessor(AccessorKind::Get); - if (AFD) { - for (auto *attr : D->getAttrs().getAttributes()) { - auto *diffAttr = const_cast(attr); - diffAttr->setOriginalFunction(AFD); - } - } + setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D); // SWIFT_ENABLE_TENSORFLOW END } } @@ -5526,6 +5523,12 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags, accessors.record(*this, PrimaryVar, Invalid); + // SWIFT_ENABLE_TENSORFLOW + for (auto *accessor : accessors.Accessors) + setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(), + accessor); + // SWIFT_ENABLE_TENSORFLOW END + return makeParserResult(PrimaryVar); } @@ -5786,6 +5789,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags, pattern->forEachVariable([&](VarDecl *VD) { VD->setStatic(StaticLoc.isValid()); VD->getAttrs() = Attributes; + // SWIFT_ENABLE_TENSORFLOW + setOriginalFunctionInDifferentiableAttributes(Attributes, VD); + // SWIFT_ENABLE_TENSORFLOW END setLocalDiscriminator(VD); Decls.push_back(VD); if (hasOpaqueReturnTy && sf) { @@ -7038,6 +7044,12 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc, accessors.record(*this, Subscript, (Invalid || !Status.isSuccess())); + // SWIFT_ENABLE_TENSORFLOW + for (auto *accessor : accessors.Accessors) + setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(), + accessor); + // SWIFT_ENABLE_TENSORFLOW END + // No need to setLocalDiscriminator because subscripts cannot // validly appear outside of type decls. return makeParserResult(Status, Subscript); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 35b9983187349..00fec8b96c2f7 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3274,8 +3274,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return; } - if (!attr->getOriginalFunction()) - attr->setOriginalFunction(original); + if (!attr->getOriginalDeclaration()) + attr->setOriginalDeclaration(original); TC.resolveDeclSignature(original); auto *originalFnTy = original->getInterfaceType()->castTo(); bool isMethod = original->hasImplicitSelfDecl(); diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 9be79fb4efb54..ff28bebbe625b 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2180,7 +2180,7 @@ static void setOriginalDeclarationInDifferentiableAttributes( DeclAttributes tempAttrs; tempAttrs.setRawAttributeChain(attrs); for (auto *attr : tempAttrs.getAttributes()) - const_cast(attr)->setOriginalFunction(decl); + const_cast(attr)->setOriginalDeclaration(decl); } // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 81878804293f3..0a3ec9a313060 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2309,8 +2309,9 @@ class Serializer::DeclSerializer : public DeclVisitor { case DAK_Differentiable: { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); - assert(attr->getOriginalFunction() && - "@differentiable attribute must have original function resolved"); + assert(attr->getOriginalDeclaration() && + "@differentiable attribute must have original declaration " + "resolved"); IdentifierID jvpName = 0; DeclID jvpRef = 0; diff --git a/test/AutoDiff/differentiable_attr_serialization.swift b/test/AutoDiff/differentiable_attr_serialization.swift index bfaa2c453396a..9e634f44613db 100644 --- a/test/AutoDiff/differentiable_attr_serialization.swift +++ b/test/AutoDiff/differentiable_attr_serialization.swift @@ -17,6 +17,16 @@ struct Foo: Differentiable { @differentiable var computedProperty: Float { 1 } + var computedPropertyGetter: Float { + @differentiable + get { 1 } + } + @differentiable subscript() -> Float { 1 } + + subscript(_ x: Float) -> Float { + @differentiable + get { 1 } + } } From 12bd5680c2e41573761a62648d63185327b2df4e Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 31 Oct 2019 17:15:32 -0700 Subject: [PATCH 4/5] Fixes. - Add tight assertion `attr->getOriginalDeclaration` to TypeCheckAttr.cpp. - Move `setOriginalFunctionInDifferentiableAttributes` call out of `if (declWasHandledAlready(D))` body. - This was found necessary while requestifying parameter indices resolution. --- lib/Parse/ParseDecl.cpp | 4 +++- lib/Sema/TypeCheckAttr.cpp | 5 +++-- lib/Serialization/Serialization.cpp | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 50dbfd027c4d8..5e2768d0209c6 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -3724,9 +3724,11 @@ Parser::parseDecl(ParseDeclOptions Flags, Handler(quoteDecl); } } - setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D); // SWIFT_ENABLE_TENSORFLOW END } + // SWIFT_ENABLE_TENSORFLOW + setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D); + // SWIFT_ENABLE_TENSORFLOW END } if (!DeclResult.isParseError()) { diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 00fec8b96c2f7..d02dc8c6d40db 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3274,8 +3274,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return; } - if (!attr->getOriginalDeclaration()) - attr->setOriginalDeclaration(original); + assert(attr->getOriginalDeclaration() && + "`@differentiable` attribute should have original declaration set " + "during construction or parsing"); TC.resolveDeclSignature(original); auto *originalFnTy = original->getInterfaceType()->castTo(); bool isMethod = original->hasImplicitSelfDecl(); diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 0a3ec9a313060..97e6d6a273e7a 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2310,8 +2310,8 @@ class Serializer::DeclSerializer : public DeclVisitor { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); assert(attr->getOriginalDeclaration() && - "@differentiable attribute must have original declaration " - "resolved"); + "`@differentiable` attribute should have original declaration set " + "during construction or parsing"); IdentifierID jvpName = 0; DeclID jvpRef = 0; From 8e2d8756eaf333fa1a1f7072754d2b9c00a86ead Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 31 Oct 2019 17:34:05 -0700 Subject: [PATCH 5/5] Always set original declaration while deserializing `@differentiable` attributes. --- lib/Serialization/Deserialization.cpp | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index ff28bebbe625b..1acf8dee433d9 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2176,7 +2176,7 @@ static bool attributeChainContains(DeclAttribute *attr) { // Instead, call this ad-hoc function after deserializing a declaration to set // it as the original declaration in its `@differentiable` attributes. static void setOriginalDeclarationInDifferentiableAttributes( - AbstractFunctionDecl *decl, DeclAttribute *attrs) { + Decl *decl, DeclAttribute *attrs) { DeclAttributes tempAttrs; tempAttrs.setRawAttributeChain(attrs); for (auto *attr : tempAttrs.getAttributes()) @@ -2584,11 +2584,6 @@ class swift::DeclDeserializer { ctor->setImplicitlyUnwrappedOptional(isIUO); ctor->computeType(); - // SWIFT_ENABLE_TENSORFLOW - // Set original declaration in `@differentiable` attributes. - setOriginalDeclarationInDifferentiableAttributes(ctor, DAttrs); - // SWIFT_ENABLE_TENSORFLOW END - return ctor; } @@ -3062,11 +3057,6 @@ class swift::DeclDeserializer { // Set the interface type. fn->computeType(); - // SWIFT_ENABLE_TENSORFLOW - // Set original declaration in `@differentiable` attributes. - setOriginalDeclarationInDifferentiableAttributes(fn, DAttrs); - // SWIFT_ENABLE_TENSORFLOW END - return fn; } @@ -4266,9 +4256,16 @@ DeclDeserializer::getDeclCheckedImpl() { &MF, declOrOffset, static_cast(recordID)); switch (recordID) { + // SWIFT_ENABLE_TENSORFLOW + // Set original declaration in `@differentiable` attributes. #define CASE(RECORD_NAME) \ - case decls_block::RECORD_NAME##Layout::Code: \ - return deserialize##RECORD_NAME(scratch, blobData); + case decls_block::RECORD_NAME##Layout::Code: {\ + auto decl = deserialize##RECORD_NAME(scratch, blobData); \ + if (decl) \ + setOriginalDeclarationInDifferentiableAttributes(decl.get(), DAttrs); \ + return decl; \ + } + // SWIFT_ENABLE_TENSORFLOW END CASE(TypeAlias) CASE(GenericTypeParamDecl)