-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numerical differentiation support to clad #261
Conversation
Codecov Report
@@ Coverage Diff @@
## master #261 +/- ##
==========================================
+ Coverage 88.20% 88.46% +0.25%
==========================================
Files 28 28
Lines 3636 3735 +99
==========================================
+ Hits 3207 3304 +97
- Misses 429 431 +2
|
7b4e1bd
to
ac70839
Compare
7a9ebec
to
e9bc67d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add tests for the diagnostics branches? Codecov tells where they are.
That is blocked right now. When clad fails to derive a function, it crashes. Let me put up an issue for this soon. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The more I think the more I convince myself that PR, at some stage, should turn into a callback class based implementation. It is an opt-in piece and also somewhat an AD bail out plan. Let's resolve the comments we got so far and discuss how to move forward given the limited amount of time we have left.
// Try numerically deriving it. | ||
if (NArgs == 1) { | ||
// Build a clone call expression so that we can correctly | ||
// scope the function to be differentiated. | ||
Expr* call = m_Sema | ||
.ActOnCallExpr(getCurrentScope(), | ||
Clone(CE->getCallee()), | ||
noLoc, | ||
llvm::MutableArrayRef<Expr*>(CallArgs), | ||
noLoc) | ||
.get(); | ||
Expr* fnCallee = cast<CallExpr>(call)->getCallee(); | ||
OverloadedDerivedFn = GetSingleArgCentralDiffCall(fnCallee, | ||
DerivedCallArgs | ||
[0], | ||
0, 1, | ||
DerivedCallArgs); | ||
asGrad = !OverloadedDerivedFn; | ||
} else { | ||
// Build a clone call expression so that we can correctly | ||
// scope the function to be differentiated. | ||
Expr* call = m_Sema | ||
.ActOnCallExpr(getCurrentScope(), | ||
Clone(CE->getCallee()), noLoc, | ||
llvm::MutableArrayRef<Expr*>( | ||
CallArgs), | ||
noLoc) | ||
.get(); | ||
Expr* fnCallee = cast<CallExpr>(call)->getCallee(); | ||
OverloadedDerivedFn = GetMultiArgCentralDiffCall( | ||
fnCallee, CEType.getCanonicalType(), CE->getNumArgs(), | ||
NumericalDiffMultiArg, DerivedCallArgs, DerivedCallOutputArgs); | ||
} | ||
if (!OverloadedDerivedFn) { | ||
// Function was not derived => issue a warning. | ||
diag(DiagnosticsEngine::Warning, | ||
CE->getBeginLoc(), | ||
"function '%0' was not differentiated because clad failed to " | ||
"differentiate it and no suitable overload was found in " | ||
"namespace 'custom_derivatives', and function may not be " | ||
"eligible for numerical differentiation.", | ||
{FD->getNameAsString()}); | ||
return StmtDiff(Clone(CE)); | ||
} else { | ||
diag(DiagnosticsEngine::Remark, noLoc, | ||
"Falling back to numerical differentiation for '%0' since no " | ||
"suitable overload was found and clad could not derive it. " | ||
"To disable this feature, compile your programs with " | ||
"-DCLAD_NO_NUM_DIFF.", | ||
{FD->getNameAsString()}); | ||
} | ||
} else { | ||
OverloadedDerivedFn = m_Sema | ||
.ActOnCallExpr(getCurrentScope(), | ||
BuildDeclRef(derivedFD), | ||
noLoc, | ||
llvm::MutableArrayRef<Expr*>( | ||
DerivedCallArgs), | ||
noLoc) | ||
.get(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this into a separate function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The whole thing? I don't know if that would be feasible...We will have to send a lot of information as parameters to that function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about putting 1406 - 1423 in a function? Putting all of it will require a function with 8-9 arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vgvassilev How do you think it looks now? If it is not good, I can put it all in a function.
noLoc, | ||
llvm::MutableArrayRef<Expr*>(CallArgs), | ||
noLoc) | ||
.get(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider moving into a separate function.
} | ||
|
||
/// A funtion to calculate the derivative of a function using the central | ||
/// difference formula. Note: we do not propogate errors resulting in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate here and maybe in the comments.
dfb7b2d
to
66d6b03
Compare
@@ -179,6 +179,12 @@ namespace clad { | |||
estimationPlugin->InstantiateCustomModel(*m_DerivativeBuilder)); | |||
} | |||
} | |||
|
|||
// If enabled, set the proper fields in derivative builder. | |||
if (m_DO.PrintNumDiffErrorInfo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
d20aedc
to
b709ff1
Compare
Clad can now build calls to a "forward" (single-arg) and "reverse" (multi-arg) numerical diff function. Numerical diff can be used in the following contexts: - With clad forward and reverse mode for all in-built scalar and non scalar types. - With clad and standalone for "forward" and "reverse" derivatives. - Standalone for in-built scalar and non-scalar types. - Standalone with support for user defined types as input (both value and pointer forms). - Standalone to differentiate overloaded operators. - Standalone to differentiate functors. Numerical diff error estimates can be printed with the help of the '-fprint-num-diff-errors' flag, and numerical diff may be disabled using the '-DCLAD_NO_NUM_DIFF' flag during compilation.
b709ff1
to
53db6d1
Compare
This functionality will allow clad to be able to use numerical differentiation as a backup (or even standalone) in the case that a differentiation request is made for a function that is not visible to clad. Full list of possible functionalities is as given below:
PR Checklist:
clad::numerical_diff
).Good to implement:
Known bugs: