Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hypergeometric 1F0 function and gradients #2962

Merged
merged 11 commits into from
Apr 2, 2024
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <stan/math/fwd/fun/gamma_p.hpp>
#include <stan/math/fwd/fun/gamma_q.hpp>
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
#include <stan/math/fwd/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
#include <stan/math/fwd/fun/hypot.hpp>
Expand Down
46 changes: 46 additions & 0 deletions stan/math/fwd/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/fwd/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Fvar or arithmetic type of 'a' argument
* @tparam Tz Fvar or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_any_fvar_t<Ta, Tz>* = nullptr>
auto hypergeometric_1f0(const Ta& a, const Tz& z) {
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
using FvarT = return_type_t<Ta, Tz>;

auto a_val = value_of(a);
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
auto z_val = value_of(z);
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
if (!is_constant_all<Ta>::value) {
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val);
}
return rtn;
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
#include <stan/math/prim/fun/head.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
Expand Down
40 changes: 40 additions & 0 deletions stan/math/prim/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_less.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <boost/math/special_functions/hypergeometric_1F0.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Arithmetic type of 'a' argument
* @tparam Tz Arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
auto hypergeometric_1f0(const Ta& a, const Tz& z) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use proper return type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the type assumptions being made, is Ta and Tz just double and double only?

If I'm reading this properly, this isn't vectorized or anything, is it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to allow all combinations of int/double (as well as size_t,float, etc.)

constexpr const char* function = "hypergeometric_1f0";
check_less("hypergeometric_1f0", "abs(z)", std::abs(z), 1.0);

return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#include <stan/math/rev/fun/gp_periodic_cov.hpp>
#include <stan/math/rev/fun/grad.hpp>
#include <stan/math/rev/fun/grad_inc_beta.hpp>
#include <stan/math/rev/fun/hypergeometric_1F0.hpp>
#include <stan/math/rev/fun/hypergeometric_2F1.hpp>
#include <stan/math/rev/fun/hypergeometric_pFq.hpp>
#include <stan/math/rev/fun/hypot.hpp>
Expand Down
48 changes: 48 additions & 0 deletions stan/math/rev/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP
#define STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/rev/core.hpp>

namespace stan {
namespace math {

/**
* Returns the Hypergeometric 1F0 function applied to the
* input arguments:
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
*
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
* a\left(1-z\right)^{-a-1} \f$
*
* @tparam Ta Var or arithmetic type of 'a' argument
* @tparam Tz Var or arithmetic type of 'z' argument
* @param[in] a Scalar 'a' argument
* @param[in] z Scalar z argument
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_any_var_t<Ta, Tz>* = nullptr>
var hypergeometric_1f0(const Ta& a, const Tz& z) {
double a_val = value_of(a);
double z_val = value_of(z);
double rtn = hypergeometric_1f0(a_val, z_val);
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
if (!is_constant_all<Ta>::value) {
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);
}
if (!is_constant_all<Tz>::value) {
forward_as<var>(z).adj() += vi.adj() * rtn * a_val * inv(1 - z_val);
}
});
}

} // namespace math
} // namespace stan
#endif
15 changes: 15 additions & 0 deletions test/unit/math/mix/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <test/unit/math/test_ad.hpp>

TEST(mathMixScalFun, hypergeometric_1f0) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::hypergeometric_1f0;
return hypergeometric_1f0(x1, x2);
};

stan::test::expect_ad(f, 5, 0.3);
stan::test::expect_ad(f, 3.4, 0.9);
stan::test::expect_ad(f, 3.4, 0.1);
stan::test::expect_ad(f, 5, -0.7);
stan::test::expect_ad(f, 7, -0.1);
stan::test::expect_ad(f, 2.8, 0.8);
}
23 changes: 23 additions & 0 deletions test/unit/math/prim/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <stan/math/prim.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <limits>

TEST(MathFunctions, hypergeometric_1f0Double) {
using stan::math::hypergeometric_1f0;
using stan::math::inv;

EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4));
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4));
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3));
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1));
}

TEST(MathFunctions, hypergeometric_1f0_throw) {
using stan::math::hypergeometric_1f0;

EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error);
}