Skip to content

Commit

Permalink
Merge pull request #2962 from stan-dev/hyper-1f0
Browse files Browse the repository at this point in the history
Add Hypergeometric 1F0 function and gradients
  • Loading branch information
syclik committed Apr 2, 2024
2 parents 35d6d53 + 6e27745 commit 1f94ed3
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 0 deletions.
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <stan/math/fwd/fun/gamma_p.hpp>
#include <stan/math/fwd/fun/gamma_q.hpp>
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
#include <stan/math/fwd/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
#include <stan/math/fwd/fun/hypot.hpp>
Expand Down
46 changes: 46 additions & 0 deletions stan/math/fwd/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Fvar or arithmetic type of 'a' argument
* @tparam Tz Fvar or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_fvar_t<Ta, Tz>* = nullptr>
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) {
partials_type_t<Ta> a_val = value_of(a);
partials_type_t<Tz> z_val = value_of(z);
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
if (!is_constant_all<Ta>::value) {
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val);
}
return rtn;
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
#include <stan/math/prim/fun/head.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
Expand Down
40 changes: 40 additions & 0 deletions stan/math/prim/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_less.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <boost/math/special_functions/hypergeometric_1F0.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Arithmetic type of 'a' argument
* @tparam Tz Arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) {
constexpr const char* function = "hypergeometric_1f0";
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0);

return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#include <stan/math/rev/fun/gp_periodic_cov.hpp>
#include <stan/math/rev/fun/grad.hpp>
#include <stan/math/rev/fun/grad_inc_beta.hpp>
#include <stan/math/rev/fun/hypergeometric_1F0.hpp>
#include <stan/math/rev/fun/hypergeometric_2F1.hpp>
#include <stan/math/rev/fun/hypergeometric_pFq.hpp>
#include <stan/math/rev/fun/hypot.hpp>
Expand Down
50 changes: 50 additions & 0 deletions stan/math/rev/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/rev/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Var or arithmetic type of 'a' argument
* @tparam Tz Var or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_var_t<Ta, Tz>* = nullptr>
var hypergeometric_1f0(const Ta& a, const Tz& z) {
double a_val = value_of(a);
double z_val = value_of(z);
double rtn = hypergeometric_1f0(a_val, z_val);
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
if (!is_constant_all<Ta>::value) {
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
forward_as<var>(z).adj() += vi.adj() * rtn * a_val * inv(1 - z_val);
}
});
}

} // namespace math
} // namespace stan
#endif
15 changes: 15 additions & 0 deletions test/unit/math/mix/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <test/unit/math/test_ad.hpp>

TEST(mathMixScalFun, hypergeometric_1f0) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::hypergeometric_1f0;
return hypergeometric_1f0(x1, x2);
};

stan::test::expect_ad(f, 5, 0.3);
stan::test::expect_ad(f, 3.4, 0.9);
stan::test::expect_ad(f, 3.4, 0.1);
stan::test::expect_ad(f, 5, -0.7);
stan::test::expect_ad(f, 7, -0.1);
stan::test::expect_ad(f, 2.8, 0.8);
}
23 changes: 23 additions & 0 deletions test/unit/math/prim/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <stan/math/prim.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <limits>

TEST(MathFunctions, hypergeometric_1f0Double) {
using stan::math::hypergeometric_1f0;
using stan::math::inv;

EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4));
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4));
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3));
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1));
}

TEST(MathFunctions, hypergeometric_1f0_throw) {
using stan::math::hypergeometric_1f0;

EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error);
}

0 comments on commit 1f94ed3

Please sign in to comment.