Simplify pullback calls in the reverse mode #802

Merged
merged 10 commits into from
Mar 18, 2024

PetroZarytskyi (Collaborator) commented Mar 5, 2024

This PR simplifies both the code of RMV::VisitCallExpr and the code generated by it. In particular, it replaces the _grad/_r variable pairs with single _r variables. The PR also removes some dead code from RMV::VisitCallExpr, addresses clang-tidy warnings triggered by it, and improves its test coverage.
Fixes #801.

github-actions bot (Contributor) left a comment

clang-tidy made some suggestions

lib/Differentiator/ReverseModeVisitor.cpp
lib/Differentiator/ReverseModeVisitor.cpp
lib/Differentiator/ReverseModeVisitor.cpp
lib/Differentiator/ReverseModeVisitor.cpp (outdated)
lib/Differentiator/VisitorBase.cpp (outdated)
@vgvassilev vgvassilev requested a review from parth-07 March 6, 2024 06:02
github-actions bot (Contributor) left a comment

clang-tidy made some suggestions

lib/Differentiator/ReverseModeVisitor.cpp (outdated)
lib/Differentiator/VisitorBase.cpp (outdated)

parth-07 (Collaborator) commented Mar 7, 2024

The work in this pull request seems good. Thank you for proactively trying to improve one of the most complicated components in the codebase. However, I think we may need to look at the issue and the solution more closely.

Please correct me if I am wrong anywhere. My autodiff skills are a bit rusty now.

Given a function:

double do_something(double p_u, double p_v) { ... }

If the primal code is r = do_something(u, v), we previously differentiated the call to do_something as follows:

double _r_d0 = _d_res;
double _grad0 = 0;
double _grad1 = 0;
do_something_pullback(u, v, _r_d0, _grad0, _grad1);
double _r0 = _grad0;
double _r1 = _grad1;
*_d_u += _r0;
*_d_v += _r1;

With your patch, the same call will be differentiated as:

double _r_d0 = _d_res;
do_something_pullback(u, v, _r_d0, &*_d_u, &*_d_v);

The two versions look similar, and it may seem that they should produce the same result mathematically. However, they do not, for two reasons:

  • A function call hides assignment operations: the values of the arguments are assigned to the parameters.
  • Arguments and parameters are different mathematical variables, and therefore they should not share adjoint variables.

Let's look at what happens in a function call more closely.

Primal code

r = do_something(u, v); // u and v are passed by value

This call can be made clearer by expanding it as follows:

double p_u = u;
double p_v = v;

r = do_something(p_u, p_v); // p_u and p_v are passed by reference

In the original case, u and v are passed by value, and the assignments to p_u and p_v happen implicitly. We cannot see them, but they are very much there. Why does this matter? In the expanded version, can we use the adjoint variables _d_u and _d_v to represent the adjoints of p_u and p_v? No: p_u is not the same variable as u, and p_v is not the same variable as v, so they need their own adjoint variables. By the same logic, we cannot pass the adjoints of u and v as the grad variables to do_something_pullback.
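
The point that a by-value parameter is its own variable can be checked in plain C++, independent of Clad. The following is only a minimal sketch; zero_copy and zero_in_place are hypothetical names introduced for illustration and do not appear in the PR.

```cpp
#include <cassert>

// Hypothetical helper: p_u is a by-value parameter, i.e. a separate variable
// initialized by the implicit assignment p_u = u at the call site.
double zero_copy(double p_u) {
    p_u = 0;      // only the local copy is overwritten
    return p_u;
}

// Hypothetical helper: p_u is a reference, so it aliases the caller's variable.
double zero_in_place(double& p_u) {
    p_u = 0;      // the caller's variable is overwritten
    return p_u;
}
```

After zero_copy(u) the caller's u keeps its value, while zero_in_place(u) overwrites it. Only in the by-reference case do the argument and the parameter denote the same variable, which is the case where sharing an adjoint would be consistent with this reasoning.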

A practical example where this distinction matters:

#include "clad/Differentiator/Differentiator.h"
#include <iostream>
#define show(x) std::cout << #x << ": " << x <<"\n";

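// u is a by-value copy, so zeroing it does not affect the caller's variable.
double reset(double u) {
    u = 0;
    return u;
}

double fn(double u, double v) {
    double res = u + v;
    res += reset(u); // always adds 0, so it contributes nothing to d(res)/du
    res += u;
    return res;
}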

int main() {
    auto fn_grad = clad::gradient(fn);
    double u = 3, v = 5, du = 0, dv = 0;
    fn_grad.execute(u, v, &du, &dv);
    show(du);
    show(dv);
}

With this patch, the above code (incorrectly) outputs:

du: 1
dv: 1

The correct output is:

du: 2
dv: 1
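
To see where the two results come from, the reverse passes can be written out by hand. The sketch below is not Clad's generated code: reset_pullback, fn_grad_separate, and fn_grad_shared are hypothetical names that only mimic the shape of the pullback calls quoted earlier, under the assumption that the reverse of an overwrite such as u = 0 clears the adjoint of the overwritten variable.

```cpp
#include <cassert>

// Hand-written pullback for: double reset(double u) { u = 0; return u; }
// (hypothetical; modeled on the pullback signature quoted in this thread).
void reset_pullback(double u, double d_y, double* d_u) {
    (void)u;       // the primal value is not needed for this function
    *d_u += d_y;   // reverse of `return u;`
    *d_u = 0;      // reverse of `u = 0;`: the old value of u has no influence
}

// Reverse pass of fn with a separate adjoint for the by-value parameter.
void fn_grad_separate(double u, double v, double* d_u, double* d_v) {
    (void)v;
    double d_res = 1;
    *d_u += d_res;                  // reverse of `res += u;`
    double r0 = 0;                  // adjoint of reset's parameter
    reset_pullback(u, d_res, &r0);  // reverse of `res += reset(u);`
    *d_u += r0;                     // reverse of the implicit copy into the parameter
    *d_u += d_res;                  // reverse of `double res = u + v;`
    *d_v += d_res;
}

// Reverse pass of fn where the parameter shares the argument's adjoint.
void fn_grad_shared(double u, double v, double* d_u, double* d_v) {
    (void)v;
    double d_res = 1;
    *d_u += d_res;                  // reverse of `res += u;`
    reset_pullback(u, d_res, d_u);  // wipes the contribution accumulated above
    *d_u += d_res;                  // reverse of `double res = u + v;`
    *d_v += d_res;
}
```

Starting from du = dv = 0 with u = 3 and v = 5, fn_grad_separate yields du = 2 and dv = 1, while fn_grad_shared yields du = 1 and dv = 1, matching the correct and incorrect outputs above.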

Please let me know your thoughts on this.

parth-07 (Collaborator) left a comment

Please see the above comment for details.

PetroZarytskyi (Collaborator, Author) commented

@parth-07
Hi Parth. Thank you for your comment. I agree with your reasoning; I got confused. I still think this PR can be fixed to actually simplify things without breaking anything, even though the changes will not be that drastic. But it needs more work.

github-actions bot (Contributor) left a comment

clang-tidy made some suggestions

lib/Differentiator/ReverseModeVisitor.cpp (outdated)

parth-07 (Collaborator) commented

"I still think this PR can be fixed to actually simplify things without breaking anything."

That's great! It's important to simplify the call differentiation component. It has grown really complex over the years.

@PetroZarytskyi PetroZarytskyi force-pushed the simplify-pullback branch 3 times, most recently from 836aa4f to b813c44 Compare March 12, 2024 22:56
codecov bot commented Mar 12, 2024

Codecov Report

Attention: Patch coverage is 97.67442%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 94.95%. Comparing base (8a77f81) to head (b340d05).
Report is 4 commits behind head on master.

@@            Coverage Diff             @@
##           master     #802      +/-   ##
==========================================
+ Coverage   94.86%   94.95%   +0.08%     
==========================================
  Files          49       49              
  Lines        7357     7468     +111     
==========================================
+ Hits         6979     7091     +112     
+ Misses        378      377       -1     
Files                                                    Coverage Δ
include/clad/Differentiator/CladUtils.h                  100.00% <ø> (ø)
include/clad/Differentiator/ErrorEstimator.h             100.00% <ø> (ø)
...e/clad/Differentiator/MultiplexExternalRMVSource.h    100.00% <ø> (ø)
include/clad/Differentiator/ReverseModeVisitor.h         97.87% <ø> (ø)
include/clad/Differentiator/VisitorBase.h                100.00% <ø> (ø)
lib/Differentiator/CladUtils.cpp                         93.00% <100.00%> (-3.75%) ⬇️
lib/Differentiator/ErrorEstimator.cpp                    99.02% <100.00%> (ø)
lib/Differentiator/MultiplexExternalRMVSource.cpp        90.52% <100.00%> (ø)
lib/Differentiator/ReverseModeVisitor.cpp                97.29% <100.00%> (+0.70%) ⬆️
lib/Differentiator/VisitorBase.cpp                       97.74% <ø> (-0.15%) ⬇️
... and 1 more

... and 2 files with indirect coverage changes

@PetroZarytskyi PetroZarytskyi force-pushed the simplify-pullback branch 2 times, most recently from 20c252f to 19e5e9e Compare March 13, 2024 10:51
clang-tidy review says "All clean, LGTM! 👍"

parth-07 (Collaborator) commented

Hi @PetroZarytskyi

Can you please add more details in the pull-request description / commit-message regarding how the pullback calls are being simplified?

PetroZarytskyi (Collaborator, Author) commented

Hi @parth-07, I updated the PR message as well as the #801 issue.

parth-07 (Collaborator) left a comment

Looks good apart from a few minor comments.

include/clad/Differentiator/CladUtils.h
// The argument is passed by reference if it's passed as an L-value.
// However, if arg is a MaterializeTemporaryExpr, then arg is a
// temporary variable passed as a const reference.
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg);

Collaborator

I don't understand why an l-value is treated as a reference type:

int a = b; // a is an l-value, but not a reference.
int &a_ref = a; // a_ref is an l-value and a reference.

Collaborator Author

arg is the argument expression passed to the function. If the function expects a reference-type argument, then arg is an l-value (usually a DeclRefExpr). When the function expects a non-reference-type argument, arg is implicitly converted to an r-value, and its AST looks roughly like this:

ImplicitCastExpr <l-value to r-value>
-DeclRefExpr

So arg will be an r-value. At least this is my understanding.
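
The two viewpoints in this exchange can be reconciled by separating an expression's value category from its declared type. The sketch below uses only standard C++; by_value, by_ref, read_const_ref, and value_category_checks are hypothetical names, and the AST shapes mentioned in the comments are the ones described in this thread rather than something the code itself inspects.

```cpp
#include <cassert>
#include <type_traits>

// Non-reference parameter: per the discussion above, the argument expression
// is wrapped in an l-value to r-value conversion, so the callee gets a copy.
void by_value(int x) { x = 42; }

// Reference parameter: the argument stays an l-value and x aliases it.
void by_ref(int& x) { x = 42; }

// Const reference parameter: also accepts a temporary, the case that the
// reviewed code excludes via MaterializeTemporaryExpr.
int read_const_ref(const int& x) { return x; }

// Value category versus declared type, checked at compile time.
inline void value_category_checks() {
    int a = 0;
    (void)a;
    // The declared type of a is int, not a reference type...
    static_assert(!std::is_reference_v<decltype(a)>, "a is declared as int");
    // ...but the expression a is an l-value.
    static_assert(std::is_lvalue_reference_v<decltype((a))>, "a is an l-value");
    // An arithmetic result is a pr-value, not an l-value.
    static_assert(!std::is_reference_v<decltype((a + 1))>, "a + 1 is a pr-value");
}
```

Calling by_value(a) leaves a unchanged, by_ref(a) overwrites it, and read_const_ref(7) binds a temporary. This is consistent with the reading that the check in the reviewed snippet asks how the argument expression is passed, not whether the variable was declared as a reference.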

include/clad/Differentiator/VisitorBase.h (outdated)
@vgvassilev vgvassilev added this to the v1.5 milestone Mar 16, 2024
github-actions bot (Contributor) left a comment

clang-tidy made some suggestions

lib/Differentiator/ReverseModeVisitor.cpp (outdated)

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev vgvassilev merged commit 7cbefdf into vgvassilev:master Mar 18, 2024
88 checks passed
@PetroZarytskyi PetroZarytskyi deleted the simplify-pullback branch March 19, 2024 13:43
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Simplify pullback calls in the reverse mode AD
3 participants