[math] Support AD for TMath::LnGamma using functions from GSL #16747

guitargeek merged 2 commits into root-project:master

Conversation
Test Results: 18 files, 18 suites, 3d 17h 0m 4s ⏱️ — results for commit 5f9203c. ♻️ This comment has been updated with latest results.
```cpp
inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x);

inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
```
Would that ChatGPT suggestion work:
```cpp
inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
{
   // Constants
   constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
   constexpr double kMAXLOG = 709.782712893383973096206318587;
   double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
   bool cond_a_nonpositive = (a <= 0);
   bool cond_x_nonpositive = (x <= 0);
   bool cond_large_x_a = (x > 1.0) && (x > a);
   if (cond_a_nonpositive || cond_x_nonpositive)
      return;
   if (cond_large_x_a) {
      double _r_a = 0, _r_x = 0;
      inc_gamma_c_pullback(a, x, -_d_y, &_r_a, &_r_x);
      *_d_a += _r_a;
      *_d_x += _r_x;
      return;
   }
   double ax = a * std::log(x) - x - std::lgamma(a);
   if (ax < -kMAXLOG)
      return;
   ax = std::exp(ax);
   double r = a, c = 1.0, ans = 1.0;
   unsigned long iterations = 0;
   do {
      iterations++;
      r += 1.0;
      c *= x / r;
      ans += c;
   } while (c / ans > kMACHEP);
   // Apply derivatives on the accumulated results
   _d_ans += _d_y / a * ax;
   _d_ax += ans * _d_y / a;
   *_d_a += _d_y * -(ans * ax / (a * a));
   while (iterations--) {
      ans -= c;
      _d_c += _d_ans;
      double temp_c = c;
      c /= (x / r);
      *_d_x += temp_c * _d_c / r;
      _d_r += temp_c * _d_c * -(x / (r * r));
      r -= 1.0;
   }
   *_d_a += _d_r;
   _d_ax += std::exp(ax) * _d_ax;
   ax -= a * std::log(x);
   *_d_a += _d_ax * std::log(x) - _d_ax * std::lgamma(a);
   *_d_x += _d_ax * (a / x) - _d_ax;
}
```
No wait, it can't: the digamma doesn't even appear there anymore. The code can't be simplified like that; it drops an important inner derivative. I have applied some manual improvements to the code now, though, in particular to get rid of the goto statements.
Is there a way to test this improvement, e.g. by verifying that no warning is printed?
This is to avoid any num-diff fallback in RooFit, which results in annoying warnings for the user. A new function `ROOT::Math::digamma` is added to the public interface, which wraps `gsl_sf_psi`. The digamma function is the derivative of `lgamma`, so it is used in `CladDerivator.h` to define the derivatives of `TMath::LnGamma` and the related gamma functions that are used to define Poisson cdfs.
Good question, there is probably a way to check that no warnings are emitted by clang. I'll look into it tomorrow.
In particular, to get rid of `goto` statements.
Force-pushed from ee2e502 to 5f9203c
vgvassilev left a comment
LGTM with a grain of salt as I believe the code can be simplified a lot more. That’s also something that we need to do to the generated code in clad in the future.
CC: @PetroZarytskyi