Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hypergeometric 1F0 function and gradients #2962

Merged
merged 11 commits into from
Apr 2, 2024
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <stan/math/fwd/fun/gamma_p.hpp>
#include <stan/math/fwd/fun/gamma_q.hpp>
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
#include <stan/math/fwd/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
#include <stan/math/fwd/fun/hypot.hpp>
Expand Down
46 changes: 46 additions & 0 deletions stan/math/fwd/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Fvar or arithmetic type of 'a' argument
* @tparam Tz Fvar or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_fvar_t<Ta, Tz>* = nullptr>
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) {
partials_type_t<Ta> a_val = value_of(a);
partials_type_t<Tz> z_val = value_of(z);
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
if (!is_constant_all<Ta>::value) {
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val);
}
return rtn;
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
#include <stan/math/prim/fun/head.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
Expand Down
40 changes: 40 additions & 0 deletions stan/math/prim/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_less.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <boost/math/special_functions/hypergeometric_1F0.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Arithmetic type of 'a' argument
* @tparam Tz Arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) {
constexpr const char* function = "hypergeometric_1f0";
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0);

return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#include <stan/math/rev/fun/gp_periodic_cov.hpp>
#include <stan/math/rev/fun/grad.hpp>
#include <stan/math/rev/fun/grad_inc_beta.hpp>
#include <stan/math/rev/fun/hypergeometric_1F0.hpp>
#include <stan/math/rev/fun/hypergeometric_2F1.hpp>
#include <stan/math/rev/fun/hypergeometric_pFq.hpp>
#include <stan/math/rev/fun/hypot.hpp>
Expand Down
50 changes: 50 additions & 0 deletions stan/math/rev/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/rev/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Var or arithmetic type of 'a' argument
* @tparam Tz Var or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_var_t<Ta, Tz>* = nullptr>
var hypergeometric_1f0(const Ta& a, const Tz& z) {
double a_val = value_of(a);
double z_val = value_of(z);
double rtn = hypergeometric_1f0(a_val, z_val);
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
if (!is_constant_all<Ta>::value) {
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
forward_as<var>(z).adj() += vi.adj() * rtn * a_val * inv(1 - z_val);
}
});
}

} // namespace math
} // namespace stan
#endif
15 changes: 15 additions & 0 deletions test/unit/math/mix/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <test/unit/math/test_ad.hpp>

TEST(mathMixScalFun, hypergeometric_1f0) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::hypergeometric_1f0;
return hypergeometric_1f0(x1, x2);
};

stan::test::expect_ad(f, 5, 0.3);
stan::test::expect_ad(f, 3.4, 0.9);
stan::test::expect_ad(f, 3.4, 0.1);
stan::test::expect_ad(f, 5, -0.7);
stan::test::expect_ad(f, 7, -0.1);
stan::test::expect_ad(f, 2.8, 0.8);
}
23 changes: 23 additions & 0 deletions test/unit/math/prim/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <stan/math/prim.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <limits>

TEST(MathFunctions, hypergeometric_1f0Double) {
using stan::math::hypergeometric_1f0;
using stan::math::inv;

EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4));
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4));
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3));
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1));
}

TEST(MathFunctions, hypergeometric_1f0_throw) {
using stan::math::hypergeometric_1f0;

EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error);
}