Skip to content

[math] Support AD for TMath::LnGamma using functions from GSL#16747

Merged
guitargeek merged 2 commits intoroot-project:masterfrom
guitargeek:poisson_ad
Oct 28, 2024
Merged

[math] Support AD for TMath::LnGamma using functions from GSL#16747
guitargeek merged 2 commits intoroot-project:masterfrom
guitargeek:poisson_ad

Conversation

@guitargeek
Copy link
Copy Markdown
Contributor

This is to avoid any num-diff fallback in RooFit, which results in annoying warnings for the user.

A new function ROOT::Math::digamma is added to the public interface, which wraps gsl_sf_psi. The digamma function is the derivative of lgamma, so it is used in CladDerivator.h to define the derivatives of TMath::LnGamma and the related gamma functions that are used to define Poisson cdfs.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Oct 26, 2024

Test Results

    18 files      18 suites   3d 17h 0m 4s ⏱️
 2 696 tests  2 696 ✅ 0 💤 0 ❌
46 047 runs  46 047 ✅ 0 💤 0 ❌

Results for commit 5f9203c.

♻️ This comment has been updated with latest results.


inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x);

inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would that chatgpt suggestion work:

inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x) {
    // Constants
    constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
    constexpr double kMAXLOG = 709.782712893383973096206318587;
    
    double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
    bool cond_a_nonpositive = (a <= 0);
    bool cond_x_nonpositive = (x <= 0);
    bool cond_large_x_a = (x > 1.0) && (x > a);

    if (cond_a_nonpositive || cond_x_nonpositive) return;

    if (cond_large_x_a) {
        double _r_a = 0, _r_x = 0;
        inc_gamma_c_pullback(a, x, -_d_y, &_r_a, &_r_x);
        *_d_a += _r_a;
        *_d_x += _r_x;
        return;
    }

    double ax = a * std::log(x) - x - std::lgamma(a);
    if (ax < -kMAXLOG) return;

    ax = std::exp(ax);
    double r = a, c = 1.0, ans = 1.0;
    unsigned long iterations = 0;

    do {
        iterations++;
        r += 1.0;
        c *= x / r;
        ans += c;
    } while (c / ans > kMACHEP);

    // Apply derivatives on the accumulated results
    _d_ans += _d_y / a * ax;
    _d_ax += ans * _d_y / a;
    *_d_a += _d_y * -(ans * ax / (a * a));

    while (iterations--) {
        ans -= c;
        _d_c += _d_ans;
        
        double temp_c = c;
        c /= (x / r);
        *_d_x += temp_c * _d_c / r;
        _d_r += temp_c * _d_c * -(x / (r * r));
        r -= 1.0;
    }

    *_d_a += _d_r;
    _d_ax += std::exp(ax) * _d_ax;
    
    ax -= a * std::log(x);
    *_d_a += _d_ax * std::log(x) - _d_ax * std::lgamma(a);
    *_d_x += _d_ax * (a / x) - _d_ax;
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No wait, it can't. the digamma doesn't event appear there anymore. The code can't be simplified like that, getting rid of an important inner derivative. I have however applied some manual improvements to the code now, in particular to get rid of goto statements.

@dpiparo
Copy link
Copy Markdown
Member

dpiparo commented Oct 26, 2024

Is there a way to test this improvement, e.g. verifying no warning is printed?

This is to avoid any num-diff fallback in RooFit, which results in
annoying warnings for the user.

A new function `ROOT::Math::digamma` is added to the public interface,
which wraps `gsl_sf_psi`. The digamma function is the derivative of
`lgamma`, so it is used in `CladDerivator.h` to define the derivatives
of `TMath::LnGamma` and the related gamma funcitons that are used to
define Poisson cdfs.
@guitargeek
Copy link
Copy Markdown
Contributor Author

Good question, probably there is a way to see that no warnings are emitted by clang. I'll see tomorrow.

In particular, to get rid of `goto` statements.
Copy link
Copy Markdown
Member

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a grain of salt as I believe the code can be simplified a lot more. That’s also something that we need to do to the generated code in clad in the future.

CC: @PetroZarytskyi

@guitargeek guitargeek merged commit 94fc289 into root-project:master Oct 28, 2024
@guitargeek guitargeek deleted the poisson_ad branch October 28, 2024 13:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants