-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support of custom _forw
functions
#1037
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
Expr* customForwPassCE = | ||
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( | ||
forwPassFnName, args, getCurrentScope(), | ||
const_cast<DeclContext*>(FD->getDeclContext())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]
const_cast<DeclContext*>(FD->getDeclContext()));
^
5d64c63
to
08d8eba
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1037 +/- ##
==========================================
+ Coverage 94.09% 94.12% +0.02%
==========================================
Files 55 55
Lines 8250 8275 +25
==========================================
+ Hits 7763 7789 +26
+ Misses 487 486 -1
... and 2 files with indirect coverage changes
|
58129cd
to
a14f6aa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( | ||
forwPassFnName, args, getCurrentScope(), | ||
const_cast<DeclContext*>(FD->getDeclContext())); | ||
return customForwPassCE; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]
const_cast<DeclContext*>(FD->getDeclContext()));
^
a14f6aa
to
bbc015b
Compare
bbc015b
to
d77cb5e
Compare
_forw
functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm! @PetroZarytskyi, can you take a look at the change in the TBR part?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice! this should help a lot. @parth-07 is this consistent with what you'd thought about custom constructor support in the reverse mode? if so, I think we can do smth similar for them, I suppose
d77cb5e
to
bcf7911
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit adds support for custom (user-provided) `_forw` functions. A `_forw` function, if available, is called in place of the actual function. For example, if the primal code contains: ```cpp someFn(u, v, w); ``` and user has defined a custom `_reverse_forw` function for `someFn` as follows: ```cpp namespace clad { namespace custom_derivatives { void someFn_reverse_forw(double u, double v, double w, double *d_u, double *d_v, double *dw) { // ... // ... } } } ``` Then clad will generate the derivative function as follows: ```cpp // forward-pass clad::custom_derivatives::someFn_reverse_forw(u, v, w, d_u, d_v, d_w); // ... // reverse-pass; no change in reverse-pass someFn_pullback(u, v, w, d_u, d_v, d_w); // ... ``` But more importantly, why do we need such a functionality? Two reasons: - Supporting reference/pointer return types in the reverse-mode. This has been discussed at great length here: vgvassilev#425 (vgvassilev#425) - Supporting types whose elements grows dynamically, such as `std::vector` and `std::map`. The issue is that we correctly need to update the size/property of the adjoint variable when a function call updates the size/property of the corresponding primal variable. For example: a call to `vec.push_back(...)` should update the size of `_d_vec` as well. However, the actual function call does not modify the adjoint variable in any way. Here comes `_forw` functions to the rescue. `_forw` functions makes it possible to adjust the adjoint variable size/properties along with executing the actual function call. Please note that `_reverse_forw` function signature takes adjoint variables as arguments and return `clad::ValueAndAdjoint<U, V>` to support the reference/pointer return type.
bcf7911
to
6ede83c
Compare
This commit adds support for custom (user-provided)
_forw
functions.A
_forw
function, if available, is called in place of the actualfunction.
For example, if the primal code contains:
someFn(u, v, w);
and user has defined a custom
_forw
function forsomeFn
as follows:Then clad will generate the derivative function as follows:
But more importantly, why do we need such a functionality? Two reasons:
Supporting reference/pointer return types in the reverse-mode. This
has been discussed at great length here:
Add initial support for diff of ref return types in rev mode #425 (Add initial support for diff of ref return types in rev mode #425)
Supporting types whose elements grows dynamically, such as
std::vector
andstd::map
. The issue is that we correctlyneed to update the size/property of the adjoint variable when a
function call updates the size/property of the corresponding primal
variable. For example: a call to
vec.push_back(...)
should updatethe size of
_d_vec
as well. However, the actual function call doesnot modify the adjoint variable in any way. Here comes
_forw
functionsto the rescue.
_forw
functions makes it possible to adjust the adjointvariable size/properties along with executing the actual function call.
Please note that
_forw
function signature takes adjoint variables asarguments and return
clad::ValueAndAdjoint<U, V>
to support thereference/pointer return type.