Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 28 additions & 61 deletions stan/math/prim/fun/log_gamma_q_dgamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/fabs.hpp>
#include <stan/math/prim/fun/gamma_p.hpp>
#include <stan/math/prim/fun/gamma_q.hpp>
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
Expand All @@ -19,17 +20,6 @@
namespace stan {
namespace math {

/**
* Result structure containing log(Q(a,z)) and its gradient with respect to a.
*
* Both members share the single scalar type T; neither is given a default
* initializer here, so a default-constructed instance holds indeterminate
* values for arithmetic T until assigned.
*
* NOTE(review): the surrounding hunk header (-19,17 +20,6) and the
* std::pair-returning signature further down suggest this struct is deleted
* by the change being shown -- confirm against the real file.
*
* @tparam T return type
*/
template <typename T>
struct log_gamma_q_result {
T log_q; ///< log(Q(a,z)) where Q is upper regularized incomplete gamma
T dlog_q_da; ///< d/da log(Q(a,z))
};

namespace internal {

/**
Expand All @@ -40,50 +30,36 @@ namespace internal {
* @tparam T_z Type of value parameter z (double or fvar types)
* @param a Shape parameter
* @param z Value at which to evaluate
* @param precision Convergence threshold
* @param precision Convergence threshold, default of sqrt(machine_epsilon)
* @param max_steps Maximum number of continued fraction iterations
* @return log(Q(a,z)) with the return type of T_a and T_z
*/
// NOTE(review): this span is a rendered diff without +/- markers, so the
// pre-change and post-change forms of several statements appear back to back.
// The "old"/"new" tags below are inferred from that ordering and from which
// names later lines use -- confirm against the real file.
//
// Returns log_prefactor - log(f), where
//   log_prefactor = a * log(z) - z - lgamma(a)
// and f is a continued-fraction value accumulated as a running product of
// per-step factors delta = C * D. The loop ends after max_steps steps or as
// soon as |delta - 1| < precision.
template <typename T_a, typename T_z>
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
// old default
double precision = 1e-16,
// new default
double precision = 1.49012e-08,
int max_steps = 250) {
using stan::math::lgamma;
using stan::math::log;
using stan::math::value_of;
using std::fabs;
using T_return = return_type_t<T_a, T_z>;

const T_return log_prefactor = a * log(z) - z - lgamma(a);

// old: b is a running term that the loop mutates with b += 2.0
T_return b = z + 1.0 - a;
T_return C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
// new: the starting term is kept fixed and b is rebuilt from it each step
T_return b_init = z + 1.0 - a;
// C is replaced by EPSILON when its magnitude is below EPSILON; it is later
// used as a divisor in an / C.
T_return C = (fabs(value_of(b_init)) >= EPSILON) ? b_init : std::decay_t<decltype(b_init)>(EPSILON);
T_return D = 0.0;
T_return f = C;

for (int i = 1; i <= max_steps; ++i) {
T_a an = -i * (i - a);
// old
b += 2.0;

// new: same sequence of b values, computed from the step index
const T_return b = b_init + 2.0 * i;
D = b + an * D;
// old: magnitude guard written as an if-statement on D itself
if (fabs(D) < EPSILON) {
D = EPSILON;
}
// new: magnitude guard on value_of(D); D is inverted a few lines below
D = (fabs(value_of(D)) >= EPSILON) ? D : std::decay_t<decltype(D)>(EPSILON);
C = b + an / C;
// old: magnitude guard on C itself
if (fabs(C) < EPSILON) {
C = EPSILON;
}

// new: magnitude guard on value_of(C)
C = (fabs(value_of(C)) >= EPSILON) ? C : std::decay_t<decltype(C)>(EPSILON);
D = inv(D);
// old
T_return delta = C * D;
// new
const T_return delta = C * D;
f *= delta;

// old
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
// new
const double delta_m1 = fabs(value_of(delta) - 1.0);
// Stop once the latest multiplicative factor is within precision of 1.
if (delta_m1 < precision) {
break;
}
}

return log_prefactor - log(f);
}

Expand All @@ -102,52 +78,43 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
* @tparam T_z type of the value parameter
* @param a shape parameter (must be positive)
* @param z value parameter (must be non-negative)
* @param precision convergence threshold
* @param precision convergence threshold, default of sqrt(machine_epsilon)
* @param max_steps maximum iterations for continued fraction
* @return log(Q(a,z)) and d/da log(Q(a,z)), in that order (first/second of the returned std::pair)
*/
// NOTE(review): this span is a rendered diff without +/- markers; the
// pre-change (log_gamma_q_result, *_dbl names) and post-change (std::pair,
// *_val names) forms of each statement appear back to back. The "old"/"new"
// tags are inferred from that pattern -- confirm against the real file.
//
// Summary of the visible logic: a and z are reduced to double via value_of,
// then one of two branches produces log(Q) and d/da log(Q):
//   * z > a + 1  : log(Q) from internal::log_q_gamma_cf
//   * otherwise  : log(Q) = log1m(gamma_p(a, z))
// In both branches the gradient is grad_reg_inc_gamma(...) / Q, except that
// the second branch falls back to log(z) - digamma(a) when Q is not > 0.
template <typename T_a, typename T_z>
// old signature
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
using std::exp;
using std::log;
// new signature
inline std::pair<return_type_t<T_a, T_z>, return_type_t<T_a, T_z>> log_gamma_q_dgamma(
const T_a& a, const T_z& z, double precision = 1.49012e-08, int max_steps = 250) {
using T_return = return_type_t<T_a, T_z>;

// old
const double a_dbl = value_of(a);
const double z_dbl = value_of(z);

log_gamma_q_result<T_return> result;

// new
const double a_val = value_of(a);
const double z_val = value_of(z);
// For z > a + 1, use continued fraction for better numerical stability
// old
if (z_dbl > a_dbl + 1.0) {
result.log_q = internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps);

// new: the pair is built with log(Q) in .first and a 0.0 placeholder in
// .second that is overwritten below
if (z_val > a_val + 1.0) {
std::pair<T_return, T_return> result{internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
// grad_reg_inc_gamma computes dQ/da
// old
const double Q_val = exp(result.log_q);
// new
const T_return Q_val = exp(result.first);
const double dQ_da
// old
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
result.dlog_q_da = dQ_da / Q_val;

// new
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
result.second = dQ_da / Q_val;
return result;
} else {
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
// old
const double P_val = gamma_p(a_dbl, z_dbl);
result.log_q = log1m(P_val);

// new
const double P_val = gamma_p(a_val, z_val);
std::pair<T_return, T_return> result{log1m(P_val), 0.0};
// Gradient: d/da log(Q) = (1/Q) * dQ/da
// grad_reg_inc_gamma computes dQ/da
// old
const double Q_val = exp(result.log_q);
// new
const T_return Q_val = exp(result.first);
// Only divide by Q when it is strictly positive.
if (Q_val > 0) {
const double dQ_da
// old
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
result.dlog_q_da = dQ_da / Q_val;
// new
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
result.second = dQ_da / Q_val;
} else {
// Fallback if Q rounds to zero - use asymptotic approximation
// old
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
// new
result.second = log(z_val) - digamma(a_val);
}
// new: each branch returns its own locally declared result
return result;
}

// old: single return of the function-scope result
return result;
}

} // namespace math
Expand Down
127 changes: 59 additions & 68 deletions stan/math/prim/prob/gamma_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,79 +22,77 @@
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>
#include <optional>

namespace stan {
namespace math {
namespace internal {
// Plain result record used by the Q_eval-returning helper signatures below:
// a value, a derivative with respect to alpha, and a success flag. All
// members have in-class initializers, so a default-constructed Q_eval
// reports ok == false with both numeric fields at 0.0.
// NOTE(review): the std::optional<std::pair<...>> signatures later in this
// hunk suggest this struct is removed by the change -- confirm.
template <typename T>
struct Q_eval {
T log_Q{0.0};  // log Q value
T dlogQ_dalpha{0.0};  // d(log Q) / d(alpha)
bool ok{false};  // set by the helpers to signal a usable (finite) log_Q
};

/**
* Computes log q and d(log q) / d(alpha) using continued fraction.
*
* NOTE(review): this span is a rendered diff without +/- markers; the
* Q_eval-based (old) and std::optional<std::pair>-based (new) forms of each
* statement appear back to back. Old/new tags are inferred -- confirm.
*
* Visible behavior of the std::optional form:
*  - When any_fvar is false and T_shape is an autodiff type, both the value
*    and the derivative come from log_gamma_q_dgamma evaluated on
*    value_of_rec(alpha) and value_of_rec(beta_y).
*  - Otherwise the value comes from internal::log_q_gamma_cf, and the
*    derivative slot stays 0.0 unless T_shape is an autodiff type.
*  - std::nullopt is returned whenever the computed log q is not finite.
*
* @tparam any_fvar selects between the two value paths above (defined by the
*   caller; its exact meaning is not visible here)
* @tparam partials_fvar when true, the derivative is read from the `.d_`
*   member after re-evaluating with alpha's `.d_` set to 1 and beta_y's to 0
* @tparam T_shape type tested with is_autodiff_v to decide whether a
*   derivative with respect to alpha is needed
* @param alpha shape argument
* @param beta_y second argument, passed as z to the continued fraction
* @return pair (log q, d log q / d alpha), or std::nullopt if log q is not
*   finite
*/
// old signature
template <typename T, typename T_shape, bool any_fvar, bool partials_fvar>
static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
Q_eval<T> out;
// new signature
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1, typename T2>
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
eval_q_cf(const T1& alpha, const T2& beta_y) {
using scalar_t = return_type_t<T1, T2>;
using ret_t = std::pair<scalar_t, scalar_t>;
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
// old
auto log_q_result
// new
std::pair<double, double> log_q_result
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
// old
out.log_Q = log_q_result.log_q;
out.dlogQ_dalpha = log_q_result.dlog_q_da;
// new: finiteness is checked here and failure is reported as nullopt
if (likely(std::isfinite(value_of_rec(log_q_result.first)))) {
return std::optional{log_q_result};
} else {
return std::optional<ret_t>{std::nullopt};
}
} else {
// old
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
// new: bail out before any derivative work if the value is not finite
ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0};
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
return std::optional<ret_t>{std::nullopt};
}
if constexpr (is_autodiff_v<T_shape>) {
if constexpr (!partials_fvar) {
// old
out.dlogQ_dalpha
// new
out.second
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha))
// old
/ exp(out.log_Q);
// new
/ exp(out.first);
} else {
// Seed the tangents (alpha: 1, beta_y: 0), re-run the continued fraction,
// and read the resulting tangent as d(log q)/d(alpha).
// old
T alpha_unit = alpha;
// new
auto alpha_unit = alpha;
alpha_unit.d_ = 1;
// old
T beta_y_unit = beta_y;
// new
auto beta_y_unit = beta_y;
beta_y_unit.d_ = 0;
// old
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
out.dlogQ_dalpha = log_Q_fvar.d_;
// new
auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
out.second = log_Q_fvar.d_;
}
}
// new
return std::optional{out};
}

// old: success flag computed once at the end
out.ok = std::isfinite(value_of_rec(out.log_Q));
return out;
}

/**
* Computes log q and d(log q) / d(alpha) using log1m.
*
* NOTE(review): this span is a rendered diff without +/- markers; the
* Q_eval-based (old) and std::optional<std::pair>-based (new) forms of each
* statement appear back to back. Old/new tags are inferred -- confirm.
*
* The value is log1m(gamma_p(alpha, beta_y)). If that value is not finite the
* function reports failure (ok == false in the old form, std::nullopt in the
* new form) before any derivative is computed. The derivative slot stays at
* its initial 0.0 unless T_shape is an autodiff type.
*
* @tparam partials_fvar when true, the derivative is read from the `.d_`
*   member after re-evaluating with alpha's `.d_` set to 1 and beta_y's to 0;
*   when false, grad_reg_lower_inc_gamma is used instead
* @tparam T_shape type tested with is_autodiff_v to decide whether a
*   derivative with respect to alpha is needed
* @param alpha shape argument
* @param beta_y second argument passed to gamma_p
* @return pair (log q, d log q / d alpha), or a failure indication if log q
*   is not finite
*/
// old signature
template <typename T, typename T_shape, bool partials_fvar>
static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
Q_eval<T> out;
out.log_Q = log1m(gamma_p(alpha, beta_y));

if (!std::isfinite(value_of_rec(out.log_Q))) {
out.ok = false;
return out;
// new signature
template <bool partials_fvar, typename T_shape, typename T1, typename T2>
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
eval_q_log1m(const T1& alpha, const T2& beta_y) {
using scalar_t = return_type_t<T1, T2>;
using ret_t = std::pair<scalar_t, scalar_t>;
ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0};
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
return std::optional<ret_t>{std::nullopt};
}

if constexpr (is_autodiff_v<T_shape>) {
if constexpr (partials_fvar) {
// Seed the tangents (alpha: 1, beta_y: 0), re-evaluate, and read the
// resulting tangent as d(log q)/d(alpha).
// old
T alpha_unit = alpha;
// new
auto alpha_unit = alpha;
alpha_unit.d_ = 1;
// old
T beta_unit = beta_y;
// new
auto beta_unit = beta_y;
beta_unit.d_ = 0;
// old
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
out.dlogQ_dalpha = log_Q_fvar.d_;
// new
auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
out.second = log_Q_fvar.d_;
} else {
// Negated lower-gamma gradient divided by Q = exp(log q); presumably
// grad_reg_lower_inc_gamma returns dP/dalpha -- verify against its docs.
// old
out.dlogQ_dalpha
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
// new
out.second
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first);
}
}

// old
out.ok = true;
return out;
// new
return std::optional{out};
}
} // namespace internal

Expand Down Expand Up @@ -137,63 +135,56 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
for (size_t n = 0; n < N; n++) {
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
const T_partials_return y_dbl = y_vec.val(n);
if (y_dbl == 0.0) {
const T_partials_return y_val = y_vec.val(n);
if (y_val == 0.0) {
continue;
}
if (y_dbl == INFTY) {
if (y_val == INFTY) {
return ops_partials.build(negative_infinity());
}

const T_partials_return alpha_dbl = alpha_vec.val(n);
const T_partials_return beta_dbl = beta_vec.val(n);
const T_partials_return alpha_val = alpha_vec.val(n);
const T_partials_return beta_val = beta_vec.val(n);

const T_partials_return beta_y = beta_dbl * y_dbl;
const T_partials_return beta_y = beta_val * y_val;
if (beta_y == INFTY) {
return ops_partials.build(negative_infinity());
}

const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
internal::Q_eval<T_partials_return> result;
if (use_continued_fraction) {
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
partials_fvar>(alpha_dbl, beta_y);
std::optional<std::pair<T_partials_return, T_partials_return>> result;
if (beta_y > alpha_val + 1.0) {
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
} else {
result
= internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>(
alpha_dbl, beta_y);

if (!result.ok && beta_y > 0.0) {
result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y);
if (!result && beta_y > 0.0) {
// Fallback to continued fraction if log1m fails
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
partials_fvar>(alpha_dbl, beta_y);
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
}
}
if (!result.ok) {
if (unlikely(!result)) {
return ops_partials.build(negative_infinity());
}

P += result.log_Q;
P += result->first;

if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
const T_partials_return log_y = log(y_dbl);
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
const T_partials_return log_y = log(y_val);
const T_partials_return alpha_minus_one = fma(alpha_val, log_y, -log_y);

const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
- lgamma(alpha_dbl) + alpha_minus_one
const T_partials_return log_pdf = alpha_val * log(beta_val)
- lgamma(alpha_val) + alpha_minus_one
- beta_y;

const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q
const T_partials_return hazard = exp(log_pdf - result->first); // f/Q

if constexpr (is_autodiff_v<T_y>) {
partials<0>(ops_partials)[n] -= hazard;
}
if constexpr (is_autodiff_v<T_inv_scale>) {
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard;
}
}
if constexpr (is_autodiff_v<T_shape>) {
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
partials<1>(ops_partials)[n] += result->second;
}
}
return ops_partials.build(P);
Expand Down