-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[AutoDiff] Requestify @differentiable attribute parameter indices.
#28017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1634,6 +1634,29 @@ class SynthesizeDefaultInitRequest | |
| bool isCached() const { return true; } | ||
| }; | ||
|
|
||
| // SWIFT_ENABLE_TENSORFLOW | ||
| class DifferentiableAttributeParameterIndicesRequest : | ||
| public SimpleRequest<DifferentiableAttributeParameterIndicesRequest, | ||
| IndexSubset *(DifferentiableAttr *, Decl *), | ||
| CacheKind::SeparatelyCached> { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason we can't use Not necessary for this PR: It would be super good if we also made this request return the other stuff that's resolved by the type checker (jvp, vjp, and where clause). Then DifferentiableAttr would become a completely stateless thing which would be a big simplification, and which seems very aligned with the overall goals of the requesitification refactoring! We don't even need to separate them into separate requests (as mentioned in a TODO further below) to do this. We could just define some struct that contains indices, vjp, jvp, and where clause and have the request return that.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I tried to make There exist other bool PropertyWrapperTypeInfoRequest::isCached() const {
auto nominal = std::get<0>(getStorage());
return nominal->getAttrs().hasAttribute<PropertyWrapperAttr>();;
}
bool AttachedPropertyWrappersRequest::isCached() const {
auto var = std::get<0>(getStorage());
return !var->getAttrs().isEmpty();
}
...There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks for investigating! Here are the test failures when Implementation at dan-zheng@diff-attr-request-cleanup. Perhaps the test failures are due to orthogonal reasons, I'll investigate further. |
||
| public: | ||
| using SimpleRequest::SimpleRequest; | ||
|
|
||
| private: | ||
| friend SimpleRequest; | ||
|
|
||
| // Evaluation. | ||
| llvm::Expected<IndexSubset *> | ||
| evaluate(Evaluator &evaluator, DifferentiableAttr *attr, Decl *decl) const; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove I believe that a design goal of these request things is that it should be easy for anyone to make a request, so reducing the number of arguments that the callers need to figure out is good. |
||
|
|
||
| public: | ||
| // Separate caching. | ||
| bool isCached() const { return true; } | ||
| Optional<IndexSubset *> getCachedResult() const; | ||
| void cacheResult(IndexSubset *value) const; | ||
| }; | ||
| // SWIFT_ENABLE_TENSORFLOW END | ||
|
|
||
| // Allow AnyValue to compare two Type values, even though Type doesn't | ||
| // support ==. | ||
| template<> | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,8 @@ | |
| #include "swift/AST/GenericSignatureBuilder.h" | ||
| #include "swift/AST/Module.h" | ||
| #include "swift/AST/TypeRepr.h" | ||
| // SWIFT_ENABLE_TENSORFLOW | ||
| #include "swift/AST/TypeCheckRequests.h" | ||
| #include "swift/AST/Types.h" | ||
| // SWIFT_ENABLE_TENSORFLOW | ||
| #include "swift/AST/ParameterList.h" | ||
|
|
@@ -1461,9 +1463,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit, | |
| Optional<DeclNameWithLoc> vjp, | ||
| GenericSignature derivativeGenSig) | ||
| : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), | ||
| Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)), | ||
| ParameterIndices(indices) { | ||
| Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) { | ||
| setOriginalDeclaration(original); | ||
| setParameterIndices(indices); | ||
| setDerivativeGenericSignature(derivativeGenSig); | ||
| } | ||
|
|
||
|
|
@@ -1503,6 +1505,31 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *decl) { | |
| OriginalDeclaration = decl; | ||
| } | ||
|
|
||
| bool DifferentiableAttr::hasComputedParameterIndices() const { | ||
| return ParameterIndicesAndBit.getInt(); | ||
| } | ||
|
|
||
| IndexSubset *DifferentiableAttr::getParameterIndices() const { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems more in line with (my understanding of) the request system philosophy to have callers make the requests directly instead of going through wrapper methods like this. Could you delete this method and make requests directly instead? Removing If there are a huge number of callers of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, my understanding is that wrappers around requests are desirable to keep caller code short; request evaluation is just an internal detail. I don't think there are downsides to using wrappers instead of direct results. Examples:
Let me know if that makes sense! |
||
| assert(getOriginalDeclaration() && | ||
| "Original declaration must have been resolved"); | ||
| auto &ctx = getOriginalDeclaration()->getASTContext(); | ||
| return evaluateOrDefault( | ||
| ctx.evaluator, | ||
| DifferentiableAttributeParameterIndicesRequest{ | ||
| const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()}, | ||
| nullptr); | ||
| } | ||
|
|
||
| void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as |
||
| assert(getOriginalDeclaration() && | ||
| "Original declaration must have been resolved"); | ||
| auto &ctx = getOriginalDeclaration()->getASTContext(); | ||
| ctx.evaluator.cacheOutput( | ||
| DifferentiableAttributeParameterIndicesRequest{ | ||
| const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()}, | ||
| std::move(paramIndices)); | ||
| } | ||
|
|
||
| void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { | ||
| JVPFunction = decl; | ||
| if (decl && !JVP) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,6 +83,14 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F, | |
| !constant.autoDiffDerivativeFunctionIdentifier && | ||
| !constant.isStoredPropertyInitializer() && | ||
| !constant.isThunk()) { | ||
| // NOTE: Validate `@differentiable` attributes on `AccessorDecl`s by calling | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't completely understand this. The request seems to only delete attributes in error cases (
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These calls to |
||
| // `getParameterIndices`. This is significant to prevent duplicate SIL | ||
| // `[differentiable]` attribute generation: `getParameterIndices` deletes | ||
| // `@differentiable` attributes whose original declaration is an | ||
| // `AbstractStorageDecl`. | ||
| if (isa<AccessorDecl>(decl)) | ||
| for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) | ||
| (void)A->getParameterIndices(); | ||
| for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) { | ||
| // Get lowered argument indices. | ||
| auto *paramIndices = A->getParameterIndices(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be named
DifferentiableAttributeTypeCheckRequest? It currently does a lot more than calculate the parameter indices, so the current name is pretty confusing in the places where you're calling it to do something different. (Especially the accessordecls in SILFunctionBuilder.cpp).