From 09aed519226aa3d2456a5b21c3308fb408260ba5 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 30 Sep 2019 17:06:03 -0700 Subject: [PATCH 1/2] Add ASTScope support for `@differentiable` attribute. `@differentiable` attribute where clauses may refer to generic parameters from some generic context. Without special ASTScope support for `@differentiable` attributes, ASTScopeLookup.cpp logic tries to resolve the generic parameter DeclNames in the where clause based on source location alone (`ASTScopeImpl::findChildContaining`) and fails. The fix is to add a special `DifferentiableAttributeScope`, mimicking `SpecializeAttributeScope`. Every `@differentiable` attribute has its own scope, derived from the declaration on which it is declared. Unlike `@_specialize`, `@differentiable` may also be declared on `AbstractStorageDecl` declarations (subscripts and variables). Resolves TF-815. --- include/swift/AST/ASTScope.h | 38 ++++++++++++ lib/AST/ASTScope.cpp | 3 + lib/AST/ASTScopeCreation.cpp | 62 +++++++++++++++++++ lib/AST/ASTScopeLookup.cpp | 36 +++++++++++ lib/AST/ASTScopeSourceRange.cpp | 7 +++ lib/AST/UnqualifiedLookup.cpp | 7 +-- .../astscope-differentiable-attr.swift | 54 ++++++++++++++++ 7 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 test/NameBinding/astscope-differentiable-attr.swift diff --git a/include/swift/AST/ASTScope.h b/include/swift/AST/ASTScope.h index 30d1d91dfd952..7a6bc43b825c5 100644 --- a/include/swift/AST/ASTScope.h +++ b/include/swift/AST/ASTScope.h @@ -1531,6 +1531,44 @@ class SpecializeAttributeScope final : public ASTScopeImpl { DeclConsumer) const override; }; +// SWIFT_ENABLE_TENSORFLOW +/// A `@differentiable` attribute scope. +/// This exists because `@differentiable` attribute may have a where clause +/// referring to generic parameters from some generic context. +class DifferentiableAttributeScope final : public ASTScopeImpl { +public: + DifferentiableAttr *const differentiableAttr; + ValueDecl *const attributedDeclaration; + + DifferentiableAttributeScope(DifferentiableAttr *diffAttr, + ValueDecl *decl) + : differentiableAttr(diffAttr), attributedDeclaration(decl) { + } + virtual ~DifferentiableAttributeScope() {} + + std::string getClassName() const override; + SourceRange + getSourceRangeOfThisASTNode(bool omitAssertions = false) const override; + NullablePtr addressForPrinting() const override { + return differentiableAttr; + } + + NullablePtr + getEnclosingAbstractStorageDecl() const override; + + NullablePtr getDeclAttributeIfAny() const override { + return differentiableAttr; + } + NullablePtr getReferrent() const override; + +protected: + ASTScopeImpl *expandSpecifically(ScopeCreator &) override; + bool lookupLocalsOrMembers(ArrayRef, + DeclConsumer) const override; + bool doesContextMatchStartingContext(const DeclContext *) const override; +}; +// SWIFT_ENABLE_TENSORFLOW END + class SubscriptDeclScope final : public ASTScopeImpl { public: SubscriptDecl *const decl; diff --git a/lib/AST/ASTScope.cpp b/lib/AST/ASTScope.cpp index 88133878bdd3b..adce677b4a120 100644 --- a/lib/AST/ASTScope.cpp +++ b/lib/AST/ASTScope.cpp @@ -231,6 +231,9 @@ DEFINE_GET_CLASS_NAME(ClosureParametersScope) DEFINE_GET_CLASS_NAME(ClosureBodyScope) DEFINE_GET_CLASS_NAME(TopLevelCodeScope) DEFINE_GET_CLASS_NAME(SpecializeAttributeScope) +// SWIFT_ENABLE_TENSORFLOW +DEFINE_GET_CLASS_NAME(DifferentiableAttributeScope) +// SWIFT_ENABLE_TENSORFLOW END DEFINE_GET_CLASS_NAME(SubscriptDeclScope) DEFINE_GET_CLASS_NAME(VarDeclScope) DEFINE_GET_CLASS_NAME(EnumElementScope) diff --git a/lib/AST/ASTScopeCreation.cpp b/lib/AST/ASTScopeCreation.cpp index 7d4b08d752cf1..44de37d6c66b2 100644 --- a/lib/AST/ASTScopeCreation.cpp +++ b/lib/AST/ASTScopeCreation.cpp @@ -60,6 +60,11 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) { static SourceRange getRangeableSourceRange(const SpecializeAttr *a) { return a->getRange(); } +// SWIFT_ENABLE_TENSORFLOW +static SourceRange getRangeableSourceRange(const DifferentiableAttr *a) { + return a->getRange(); +} +// SWIFT_ENABLE_TENSORFLOW END static SourceRange getRangeableSourceRange(const ASTNode n) { return n.getSourceRange(); } @@ -94,6 +99,19 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) { llvm::errs() << "SpecializeAttr\n"; } +// SWIFT_ENABLE_TENSORFLOW +static void dumpRangeable(const DifferentiableAttr *a, + llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED; +static void dumpRangeable(const DifferentiableAttr *a, llvm::raw_ostream &f) { + llvm::errs() << "DifferentiableAttr\n"; +} +static void dumpRangeable(DifferentiableAttr *a, + llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED; +static void dumpRangeable(DifferentiableAttr *a, llvm::raw_ostream &f) { + llvm::errs() << "DifferentiableAttr\n"; +} +// SWIFT_ENABLE_TENSORFLOW END + /// For Debugging template bool doesRangeableRangeMatch(const T *x, const SourceManager &SM, @@ -435,6 +453,18 @@ class ScopeCreator final { fn(specializeAttr); } + // SWIFT_ENABLE_TENSORFLOW + void forEachDifferentiableAttrInSourceOrder( + Decl *decl, function_ref fn) { + std::vector sortedDifferentiableAttrs; + for (auto *attr : decl->getAttrs()) + if (auto *diffAttr = dyn_cast(attr)) + sortedDifferentiableAttrs.push_back(diffAttr); + for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs)) + fn(diffAttr); + } + // SWIFT_ENABLE_TENSORFLOW END + std::vector expandIfConfigClausesThenCullAndSortElementsOrMembers( ArrayRef input) const { auto cleanedupNodes = sortBySourceRange(cull(expandIfConfigClauses(input))); @@ -1045,6 +1075,15 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder( return enclosingAbstractStorageDecl == ad->getStorage(); }); + // SWIFT_ENABLE_TENSORFLOW + // Create scopes for `@differentiable` attributes. + forEachDifferentiableAttrInSourceOrder( + asd, [&](DifferentiableAttr *diffAttr) { + ifUniqueConstructExpandAndInsert( + parent, diffAttr, asd); + }); + // SWIFT_ENABLE_TENSORFLOW END + // Sort in order to include synthesized ones, which are out of order. // Part of rdar://53921774 rm extra copy for (auto *accessor : sortBySourceRange(accessorsToScope)) @@ -1152,6 +1191,9 @@ NO_EXPANSION(GenericParamScope) NO_EXPANSION(ASTSourceFileScope) NO_EXPANSION(ClosureParametersScope) NO_EXPANSION(SpecializeAttributeScope) +// SWIFT_ENABLE_TENSORFLOW +NO_EXPANSION(DifferentiableAttributeScope) +// SWIFT_ENABLE_TENSORFLOW END NO_EXPANSION(ConditionalClausePatternUseScope) NO_EXPANSION(LookupParentDiversionScope) @@ -1309,6 +1351,17 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint( scopeCreator.ifUniqueConstructExpandAndInsert( this, specializeAttr, decl); }); + + // SWIFT_ENABLE_TENSORFLOW + // Create scopes for `@differentiable` attributes. + scopeCreator.forEachDifferentiableAttrInSourceOrder( + decl, [&](DifferentiableAttr *diffAttr) { + scopeCreator + .ifUniqueConstructExpandAndInsert( + this, diffAttr, decl); + }); + // SWIFT_ENABLE_TENSORFLOW END + // Create scopes for generic and ordinary parameters. // For a subscript declaration, the generic and ordinary parameters are in an // ancestor scope, so don't make them here. @@ -1636,6 +1689,12 @@ NullablePtr SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const { return getParent().get()->getEnclosingAbstractStorageDecl(); } +// SWIFT_ENABLE_TENSORFLOW +NullablePtr +DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const { + return getParent().get()->getEnclosingAbstractStorageDecl(); +} +// SWIFT_ENABLE_TENSORFLOW END NullablePtr AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl() const { return getParent().get()->getEnclosingAbstractStorageDecl(); @@ -1784,6 +1843,9 @@ GET_REFERRENT(AbstractStmtScope, getStmt()) GET_REFERRENT(CaptureListScope, getExpr()) GET_REFERRENT(WholeClosureScope, getExpr()) GET_REFERRENT(SpecializeAttributeScope, specializeAttr) +// SWIFT_ENABLE_TENSORFLOW +GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr) +// SWIFT_ENABLE_TENSORFLOW END GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope(this)); const Decl * diff --git a/lib/AST/ASTScopeLookup.cpp b/lib/AST/ASTScopeLookup.cpp index ee17af8631ce9..115542bbcef6a 100644 --- a/lib/AST/ASTScopeLookup.cpp +++ b/lib/AST/ASTScopeLookup.cpp @@ -194,6 +194,21 @@ bool GenericParamScope::doesContextMatchStartingContext( return false; } +// SWIFT_ENABLE_TENSORFLOW +bool DifferentiableAttributeScope::doesContextMatchStartingContext( + const DeclContext *context) const { + // Need special logic to handle case where `attributedDeclaration` is an + // `AbstractStorageDecl` (`SubscriptDecl` or `VarDecl`). The initial starting + // context in `ASTScopeImpl::findStartingScopeForLookup` will be an accessor + // of the `attributedDeclaration`. + if (auto *asd = dyn_cast(attributedDeclaration)) + for (auto accessor : asd->getAllAccessors()) + if (up_cast(accessor) == context) + return true; + return false; +} +// SWIFT_ENABLE_TENSORFLOW END + #pragma mark lookup methods that run once per scope void ASTScopeImpl::lookup(SmallVectorImpl &history, @@ -424,6 +439,27 @@ bool SpecializeAttributeScope::lookupLocalsOrMembers( return false; } +// SWIFT_ENABLE_TENSORFLOW +bool DifferentiableAttributeScope::lookupLocalsOrMembers( + ArrayRef, DeclConsumer consumer) const { + auto visitAbstractFunctionDecl = [&](AbstractFunctionDecl *afd) { + if (auto *params = afd->getGenericParams()) + for (auto *param : params->getParams()) + if (consumer.consume({param}, DeclVisibilityKind::GenericParameter)) + return true; + return false; + }; + if (auto *afd = dyn_cast(attributedDeclaration)) { + return visitAbstractFunctionDecl(afd); + } else if (auto *asd = dyn_cast(attributedDeclaration)) { + for (auto *accessor : asd->getAllAccessors()) + if (visitAbstractFunctionDecl(accessor)) + return true; + } + return false; +} +// SWIFT_ENABLE_TENSORFLOW END + bool BraceStmtScope::lookupLocalsOrMembers(ArrayRef, DeclConsumer consumer) const { // All types and functions are visible anywhere within a brace statement diff --git a/lib/AST/ASTScopeSourceRange.cpp b/lib/AST/ASTScopeSourceRange.cpp index b9105fb5117fb..3742a4be884ce 100644 --- a/lib/AST/ASTScopeSourceRange.cpp +++ b/lib/AST/ASTScopeSourceRange.cpp @@ -193,6 +193,13 @@ SourceRange SpecializeAttributeScope::getSourceRangeOfThisASTNode( return specializeAttr->getRange(); } +// SWIFT_ENABLE_TENSORFLOW +SourceRange DifferentiableAttributeScope::getSourceRangeOfThisASTNode( + const bool omitAssertions) const { + return differentiableAttr->getRange(); +} +// SWIFT_ENABLE_TENSORFLOW END + SourceRange AbstractFunctionBodyScope::getSourceRangeOfThisASTNode( const bool omitAssertions) const { return decl->getBodySourceRange(); diff --git a/lib/AST/UnqualifiedLookup.cpp b/lib/AST/UnqualifiedLookup.cpp index a6e5d8f1bc951..4cfbcc4e2aaeb 100644 --- a/lib/AST/UnqualifiedLookup.cpp +++ b/lib/AST/UnqualifiedLookup.cpp @@ -485,12 +485,7 @@ void UnqualifiedLookupFactory::performUnqualifiedLookup() { DC, initialIsCascadingUse}; const bool crosscheckUnqualifiedLookup = Ctx.LangOpts.CrosscheckUnqualifiedLookup; - // SWIFT_ENABLE_TENSORFLOW - // NOTE(TF-815): using AST scopes for lookup causes standard library - // type-checking for `@differentiable` attributes to fail. - if ((false)) { - // if (useASTScopesForLookup()) { - // SWIFT_ENABLE_TENSORFLOW END + if (useASTScopesForLookup()) { static bool haveWarned = false; if (!haveWarned && Ctx.LangOpts.WarnIfASTScopeLookup) { haveWarned = true; diff --git a/test/NameBinding/astscope-differentiable-attr.swift b/test/NameBinding/astscope-differentiable-attr.swift new file mode 100644 index 0000000000000..9514d2a4b34d7 --- /dev/null +++ b/test/NameBinding/astscope-differentiable-attr.swift @@ -0,0 +1,54 @@ +// SWIFT_ENABLE_TENSORFLOW +// Check that ASTScope lookup works for `@differentiable` attribute. + +// NOTE(TF-815): Without custom scope support, ASTScopeLookup crashes for +// `@differentiable` attribute with where clauses on subscript and `var` +// declarations. + +// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup + +struct Test { + var element: Element +} +extension Test: Differentiable where Element: Differentiable {} +extension Test { + @differentiable(where Element: Differentiable) + init(_ element: Element) { + self.element = element + } + + @differentiable(where Element: Differentiable) + func method() -> Element { + element + } + + @differentiable(where T: Differentiable) + func method(_ x: T) -> T { + x + } + + // NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support. + @differentiable(where Element: Differentiable) + subscript(implicitGetterOnly_ : Void) -> Element { + element + } + + subscript(explicitGetterAndSetter _: Void) -> Element { + @differentiable(where Element: Differentiable) + get { element } + set {} + } + + // NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support. + @differentiable(where Element: Differentiable) + var computedProperty: Element { + element + } + + var computedPropertyExplicitGetter: Element { + @differentiable(where Element: Differentiable) + get { + element + } + } +} From 774efc16cfc51c5cb919845e299f7f69b8114178 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 30 Sep 2019 17:09:02 -0700 Subject: [PATCH 2/2] Fix ASTScope support for `@differentiating` attribute. `Decl::getSourceRangeIncludingAttrs` should not consider implicit `@differentiable` attributes generated during `@differentiating` attribute type-checking. TF-835 tracks robust lowering for `@differentiating` attributes that does not involve generating implicit `@differentiable` attributes, circumventing this issue. --- lib/AST/ASTScopeCreation.cpp | 8 +++- lib/AST/Decl.cpp | 10 +++++ .../astscope-differentiating-attr.swift | 37 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 test/NameBinding/astscope-differentiating-attr.swift diff --git a/lib/AST/ASTScopeCreation.cpp b/lib/AST/ASTScopeCreation.cpp index 44de37d6c66b2..041d7a8e62973 100644 --- a/lib/AST/ASTScopeCreation.cpp +++ b/lib/AST/ASTScopeCreation.cpp @@ -459,7 +459,13 @@ class ScopeCreator final { std::vector sortedDifferentiableAttrs; for (auto *attr : decl->getAttrs()) if (auto *diffAttr = dyn_cast(attr)) - sortedDifferentiableAttrs.push_back(diffAttr); + // NOTE(TF-835): Skipping implicit `@differentiable` attributes is + // necessary to avoid verification failure: + // `ASTScopeImpl::verifyThatChildrenAreContainedWithin`. + // Perhaps this check is no longer necessary after TF-835: robust + // `@differentiating` attribute lowering. + if (!diffAttr->isImplicit()) + sortedDifferentiableAttrs.push_back(diffAttr); for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs)) fn(diffAttr); } diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index c9ea7e2d7967b..defff427b9757 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -464,6 +464,16 @@ SourceRange Decl::getSourceRangeIncludingAttrs() const { } for (auto Attr : getAttrs()) { + // SWIFT_ENABLE_TENSORFLOW + // Skip implicitly `@differentiable` attribute generated during + // `@differentiating` attribute type-checking. + // TODO(TF-835): Instead of generating implicit `@differentiable` + // attributes, lower `@differentiating` attributes to `[differentiable]` + // attributes on the referenced declaration. + if (auto *diffAttr = dyn_cast(Attr)) + if (diffAttr->isImplicit()) + continue; + // SWIFT_ENABLE_TENSORFLOW END if (Attr->getRange().isValid()) Range.widen(Attr->getRangeWithAt()); } diff --git a/test/NameBinding/astscope-differentiating-attr.swift b/test/NameBinding/astscope-differentiating-attr.swift new file mode 100644 index 0000000000000..13eabc6fdf864 --- /dev/null +++ b/test/NameBinding/astscope-differentiating-attr.swift @@ -0,0 +1,37 @@ +// SWIFT_ENABLE_TENSORFLOW +// Check that ASTScope lookup works for `@differentiating` attribute. + +// NOTE(TF-835): This test is only necessary because `@differentiating` +// attribute type-checking generates implicit `@differentiable` attributes +// on the referenced declaration. Robust lowering for `@differentiating` +// attributes should make special logic regarding implicit `@differentiable` +// attributes unnecessary. + +// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup + +struct Test { + var element: Element +} +extension Test: Differentiable where Element: Differentiable {} +extension Test { + static func +(lhs: Self, rhs: Self) -> Self { + lhs + } + static func -(lhs: Self, rhs: Self) -> Self { + lhs + } +} + +extension Test where Element : Differentiable { + @differentiating(+) + internal static func _vjpAdd(lhs: Self, rhs: Self) + -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) { + return (lhs + rhs, { v in (v, v) }) + } + + @differentiating(-) + internal static func _vjpSubtract(lhs: Self, rhs: Self) + -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) { + return (lhs + rhs, { v in (v, v) }) + } +}