Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,7 @@ class DifferentiableAttr final
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
friend TrailingObjects;
friend class DifferentiableAttributeParameterIndicesRequest;

/// The declaration on which the `@differentiable` attribute is declared.
Decl *OriginalDeclaration = nullptr;
Expand All @@ -1558,7 +1559,8 @@ class DifferentiableAttr final
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The bit stores whether the parameter indices have been computed.
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff derivative functions. Resolved by the
Expand All @@ -1575,9 +1577,9 @@ class DifferentiableAttr final
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenericSignature);
Expand Down Expand Up @@ -1611,12 +1613,9 @@ class DifferentiableAttr final
/// registered VJP.
Optional<DeclNameWithLoc> getVJP() const { return VJP; }

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
}
bool hasComputedParameterIndices() const;
IndexSubset *getParameterIndices() const;
void setParameterIndices(IndexSubset *paramIndices);

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
Expand Down Expand Up @@ -1647,8 +1646,7 @@ class DifferentiableAttr final
void setVJPFunction(FuncDecl *decl);

bool parametersMatch(const DifferentiableAttr &other) const {
assert(ParameterIndices && other.ParameterIndices);
return ParameterIndices == other.ParameterIndices;
return getParameterIndices() == other.getParameterIndices();
}

/// Get the derivative generic environment for the given `@differentiable`
Expand Down
23 changes: 23 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,29 @@ class SynthesizeDefaultInitRequest
bool isCached() const { return true; }
};

// SWIFT_ENABLE_TENSORFLOW
class DifferentiableAttributeParameterIndicesRequest :
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be named DifferentiableAttributeTypeCheckRequest? It currently does a lot more than calculate the parameter indices, so the current name is pretty confusing in the places where you're calling it to do something different. (Especially the accessordecls in SILFunctionBuilder.cpp).

public SimpleRequest<DifferentiableAttributeParameterIndicesRequest,
IndexSubset *(DifferentiableAttr *, Decl *),
CacheKind::SeparatelyCached> {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we can't use CacheKind::Cached to cache it in the evaluator? Then we wouldn't need to store the indices in the DifferentiableAttr, which would be a nice simplification.

Not necessary for this PR: It would be super good if we also made this request return the other stuff that's resolved by the type checker (jvp, vjp, and where clause). Then DifferentiableAttr would become a completely stateless thing which would be a big simplification, and which seems very aligned with the overall goals of the requesitification refactoring! We don't even need to separate them into separate requests (as mentioned in a TODO further below) to do this. We could just define some struct that contains indices, vjp, jvp, and where clause and have the request return that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we can't use CacheKind::Cached to cache it in the evaluator? Then we wouldn't need to store the indices in the DifferentiableAttr, which would be a nice simplification.

CacheKind::Cached sounds nice if possible!

I tried to make DifferentiableAttributeTypeCheckRequest cached: dan-zheng@diff-attr-request-cleanup
But I couldn't yet find a way to do it that works with attributes in non-primary-files:

Assertion failed: (paramIndices && "Parameter indices should have been resolved"), function addFunctionAttributes, file /Users/danielzheng/swift-tf/swift/lib/SIL/SILFunctionBuilder.cpp, line 98.
Stack dump:
0.	Program arguments: /Users/danielzheng/swift-tf/build/Ninja-ReleaseAssert+stdlib-Release/swift-macosx-x86_64/bin/swiftc -frontend -target x86_64-apple-macosx10.9 -module-cache-path /Users/danielzheng/swift-tf/build/Ninja-ReleaseAssert+stdlib-Release/swift-macosx-x86_64/swift-test-results/x86_64-apple-macosx10.9/clang-module-cache -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -swift-version 4 -ignore-module-source-info -typo-correction-limit 10 -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing /Users/danielzheng/swift-tf/swift/test/AutoDiff/tbdgen.swift -enable-testing
1.	Swift version 5.1.1-dev (Swift 37b5644988)
2.	While emitting SIL for getter for base (at /Users/danielzheng/swift-tf/swift/test/AutoDiff/tbdgen.swift:73:7)
0  swiftc                   0x000000010fac6115 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
1  swiftc                   0x000000010fac5118 llvm::sys::RunSignalHandlers() + 248
2  swiftc                   0x000000010fac6708 SignalHandler(int) + 264
3  libsystem_platform.dylib 0x00007fff728e4b5d _sigtramp + 29
4  swiftc                   0x0000000111ff3a08 cmark_strbuf__initbuf + 148307
5  libsystem_c.dylib        0x00007fff7279e6a6 abort + 127
6  libsystem_c.dylib        0x00007fff7276720d basename_r + 0
7  swiftc                   0x000000010fd41411 swift::SILFunctionBuilder::addFunctionAttributes(swift::SILFunction*, swift::DeclAttributes&, swift::SILModule&, swift::SILDeclRef) (.cold.8) + 3

There exist other isCached functions in lib/AST/TypeCheckRequests.cpp, but they seem like simpler cases than @differentiable attribute:

bool PropertyWrapperTypeInfoRequest::isCached() const {
  auto nominal = std::get<0>(getStorage());
  return nominal->getAttrs().hasAttribute<PropertyWrapperAttr>();;
}

bool AttachedPropertyWrappersRequest::isCached() const {
  auto var = std::get<0>(getStorage());
  return !var->getAttrs().isEmpty();
}

...

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that return true is a correct implementation of isCached(). isCached() decides whether the request should be cached, not whether the request has already been cached:
https://github.com/apple/swift/blob/8a9f53594e7b049c238f21cc6e22f5c10134e28c/include/swift/AST/Evaluator.h#L155

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that return true is a correct implementation of isCached().

Thanks for investigating! Here are the test failures when DifferentiableAttributeTypeCheckRequest ::isCached returns true: https://gist.github.com/dan-zheng/a22887eedd02a5a6c4ea1db2f4f38005

Implementation at dan-zheng@diff-attr-request-cleanup. Perhaps the test failures are due to orthogonal reasons, I'll investigate further.

public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
llvm::Expected<IndexSubset *>
evaluate(Evaluator &evaluator, DifferentiableAttr *attr, Decl *decl) const;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove Decl *decl and have this function look up the decl in the attr itself?

I believe that a design goal of these request things is that it should be easy for anyone to make a request, so reducing the number of arguments that the callers need to figure out is good.


public:
// Separate caching.
bool isCached() const { return true; }
Optional<IndexSubset *> getCachedResult() const;
void cacheResult(IndexSubset *value) const;
};
// SWIFT_ENABLE_TENSORFLOW END

// Allow AnyValue to compare two Type values, even though Type doesn't
// support ==.
template<>
Expand Down
5 changes: 5 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ SWIFT_REQUEST(TypeChecker, ClassAncestryFlagsRequest,
AncestryFlags(ClassDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
Type(AssociatedTypeDecl *), Cached, NoLocationInfo)
// SWIFT_ENABLE_TENSORFLOW
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeParameterIndicesRequest,
IndexSubset *(DifferentiableAttr *, Decl *),
SeparatelyCached, NoLocationInfo)
// SWIFT_ENABLE_TENSORFLOW END
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,
NoLocationInfo)
Expand Down
31 changes: 29 additions & 2 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/Module.h"
#include "swift/AST/TypeRepr.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/AST/ParameterList.h"
Expand Down Expand Up @@ -1461,9 +1463,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
ParameterIndices(indices) {
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
setOriginalDeclaration(original);
setParameterIndices(indices);
setDerivativeGenericSignature(derivativeGenSig);
}

Expand Down Expand Up @@ -1503,6 +1505,31 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
OriginalDeclaration = decl;
}

bool DifferentiableAttr::hasComputedParameterIndices() const {
return ParameterIndicesAndBit.getInt();
}

IndexSubset *DifferentiableAttr::getParameterIndices() const {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems more in line with (my understanding of) the request system philosophy to have callers make the requests directly instead of going through wrapper methods like this. Could you delete this method and make requests directly instead? Removing Decl *decl from the request makes this much easier.

If there are a huge number of callers of getParameterIndices, you could keep this for now to simplify the PR and then we can gradually migrate to making the request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, my understanding is that wrappers around requests are desirable to keep caller code short; request evaluation is just an internal detail. I don't think there are downsides to using wrappers instead of direct results.

Examples:

Let me know if that makes sense!

assert(getOriginalDeclaration() &&
"Original declaration must have been resolved");
auto &ctx = getOriginalDeclaration()->getASTContext();
return evaluateOrDefault(
ctx.evaluator,
DifferentiableAttributeParameterIndicesRequest{
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
nullptr);
}

void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as getParameterIndices.

assert(getOriginalDeclaration() &&
"Original declaration must have been resolved");
auto &ctx = getOriginalDeclaration()->getASTContext();
ctx.evaluator.cacheOutput(
DifferentiableAttributeParameterIndicesRequest{
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
std::move(paramIndices));
}

void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
JVPFunction = decl;
if (decl && !JVP)
Expand Down
18 changes: 18 additions & 0 deletions lib/AST/TypeCheckRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,21 @@ void swift::simple_display(llvm::raw_ostream &out,
out << "precedence group " << desc.ident << " at ";
desc.nameLoc.print(out, desc.dc->getASTContext().SourceMgr);
}

//----------------------------------------------------------------------------//
// DifferentiableAttributeParameterIndicesRequest computation.
//----------------------------------------------------------------------------//

Optional<IndexSubset *>
DifferentiableAttributeParameterIndicesRequest::getCachedResult() const {
auto *attr = std::get<0>(getStorage());
if (attr->hasComputedParameterIndices())
return attr->ParameterIndicesAndBit.getPointer();
return None;
}

void DifferentiableAttributeParameterIndicesRequest::cacheResult(
IndexSubset *parameterIndices) const {
auto *attr = std::get<0>(getStorage());
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
}
5 changes: 5 additions & 0 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3390,6 +3390,7 @@ static void setOriginalFunctionInDifferentiableAttributes(
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
}
// SWIFT_ENABLE_TENSORFLOW END

/// Parse a single syntactic declaration and return a list of decl
/// ASTs. This can return multiple results for var decls that bind to multiple
Expand Down Expand Up @@ -3836,6 +3837,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
// SWIFT_ENABLE_TENSORFLOW END
}
// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
// SWIFT_ENABLE_TENSORFLOW END
}
Expand Down Expand Up @@ -5592,6 +5594,7 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
accessors.record(*this, PrimaryVar, Invalid);

// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
for (auto *accessor : accessors.Accessors)
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
accessor);
Expand Down Expand Up @@ -5852,6 +5855,7 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
VD->setStatic(StaticLoc.isValid());
VD->getAttrs() = Attributes;
// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
// SWIFT_ENABLE_TENSORFLOW END
setLocalDiscriminator(VD);
Expand Down Expand Up @@ -7109,6 +7113,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));

// SWIFT_ENABLE_TENSORFLOW
// Set original declaration in `@differentiable` attributes.
for (auto *accessor : accessors.Accessors)
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
accessor);
Expand Down
8 changes: 8 additions & 0 deletions lib/SIL/SILFunctionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
!constant.autoDiffDerivativeFunctionIdentifier &&
!constant.isStoredPropertyInitializer() &&
!constant.isThunk()) {
// NOTE: Validate `@differentiable` attributes on `AccessorDecl`s by calling
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't completely understand this. The request seems to only delete attributes in error cases (diagnoseAndRemoveAttr). Is this call to getParameterIndices() only important for handling error cases, or is there a non-error case where it's also important?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These calls to getParameterIndices are necessary to trigger @differentiable attribute type-checking for attributes in non-primary-files. Otherwise, attribute parameter indices and other fields may not be resolved.

// `getParameterIndices`. This is significant to prevent duplicate SIL
// `[differentiable]` attribute generation: `getParameterIndices` deletes
// `@differentiable` attributes whose original declaration is an
// `AbstractStorageDecl`.
if (isa<AccessorDecl>(decl))
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
(void)A->getParameterIndices();
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
// Get lowered argument indices.
auto *paramIndices = A->getParameterIndices();
Expand Down
12 changes: 5 additions & 7 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,13 +613,12 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
// call to the getter.
if (member->getEffectiveAccess() > AccessLevel::Internal &&
!member->getAttrs().hasAttribute<DifferentiableAttr>()) {
(void)member->getAccessor(AccessorKind::Get)->getInterfaceType();
auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);
(void)getter->getInterfaceType();
// If member or its getter already has a `@differentiable` attribute,
// continue.
if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
member->getAccessor(AccessorKind::Get)
->getAttrs()
.hasAttribute<DifferentiableAttr>())
getter->getAttrs().hasAttribute<DifferentiableAttr>())
continue;
GenericSignature derivativeGenSig = GenericSignature();
// If the parent declaration context is an extension, the nominal type may
Expand All @@ -628,9 +627,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
derivativeGenSig = extDecl->getGenericSignature();
auto *diffableAttr = DifferentiableAttr::create(
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
derivativeGenSig);
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
/*linear*/ false, {}, None, None, derivativeGenSig);
member->getAttrs().add(diffableAttr);
// Set getter `@differentiable` attribute parameter indices.
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));
Expand Down
Loading