-
-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2962 from stan-dev/hyper-1f0
Add Hypergeometric 1F0 function and gradients
- Loading branch information
Showing
8 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP | ||
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/hypergeometric_1F0.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns the Hypergeometric 1F0 function applied to the | ||
* input arguments: | ||
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} = | ||
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} = | ||
* a\left(1-z\right)^{-a-1} \f$ | ||
* | ||
* @tparam Ta Fvar or arithmetic type of 'a' argument | ||
* @tparam Tz Fvar or arithmetic type of 'z' argument | ||
* @param[in] a Scalar 'a' argument | ||
* @param[in] z Scalar z argument | ||
* @return Hypergeometric 1F0 function | ||
*/ | ||
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>, | ||
require_all_stan_scalar_t<Ta, Tz>* = nullptr, | ||
require_any_fvar_t<Ta, Tz>* = nullptr> | ||
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) { | ||
partials_type_t<Ta> a_val = value_of(a); | ||
partials_type_t<Tz> z_val = value_of(z); | ||
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0); | ||
if (!is_constant_all<Ta>::value) { | ||
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val); | ||
} | ||
if (!is_constant_all<Tz>::value) { | ||
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val); | ||
} | ||
return rtn; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP | ||
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err/check_less.hpp> | ||
#include <stan/math/prim/fun/boost_policy.hpp> | ||
#include <boost/math/special_functions/hypergeometric_1F0.hpp> | ||
#include <cmath> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns the Hypergeometric 1F0 function applied to the | ||
* input arguments: | ||
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} = | ||
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} = | ||
* a\left(1-z\right)^{-a-1} \f$ | ||
* | ||
* @tparam Ta Arithmetic type of 'a' argument | ||
* @tparam Tz Arithmetic type of 'z' argument | ||
* @param[in] a Scalar 'a' argument | ||
* @param[in] z Scalar z argument | ||
* @return Hypergeometric 1F0 function | ||
*/ | ||
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr> | ||
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) { | ||
constexpr const char* function = "hypergeometric_1f0"; | ||
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0); | ||
|
||
return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>()); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#ifndef STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP | ||
#define STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/hypergeometric_1F0.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
#include <stan/math/prim/fun/log1m.hpp> | ||
#include <stan/math/prim/fun/inv.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns the Hypergeometric 1F0 function applied to the | ||
* input arguments: | ||
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} = | ||
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$ | ||
* | ||
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} = | ||
* a\left(1-z\right)^{-a-1} \f$ | ||
* | ||
* @tparam Ta Var or arithmetic type of 'a' argument | ||
* @tparam Tz Var or arithmetic type of 'z' argument | ||
* @param[in] a Scalar 'a' argument | ||
* @param[in] z Scalar z argument | ||
* @return Hypergeometric 1F0 function | ||
*/ | ||
template <typename Ta, typename Tz, | ||
require_all_stan_scalar_t<Ta, Tz>* = nullptr, | ||
require_any_var_t<Ta, Tz>* = nullptr> | ||
var hypergeometric_1f0(const Ta& a, const Tz& z) { | ||
double a_val = value_of(a); | ||
double z_val = value_of(z); | ||
double rtn = hypergeometric_1f0(a_val, z_val); | ||
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable { | ||
if (!is_constant_all<Ta>::value) { | ||
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val); | ||
} | ||
if (!is_constant_all<Tz>::value) { | ||
forward_as<var>(z).adj() += vi.adj() * rtn * a_val * inv(1 - z_val); | ||
} | ||
}); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#include <test/unit/math/test_ad.hpp> | ||
|
||
TEST(mathMixScalFun, hypergeometric_1f0) { | ||
auto f = [](const auto& x1, const auto& x2) { | ||
using stan::math::hypergeometric_1f0; | ||
return hypergeometric_1f0(x1, x2); | ||
}; | ||
|
||
stan::test::expect_ad(f, 5, 0.3); | ||
stan::test::expect_ad(f, 3.4, 0.9); | ||
stan::test::expect_ad(f, 3.4, 0.1); | ||
stan::test::expect_ad(f, 5, -0.7); | ||
stan::test::expect_ad(f, 7, -0.1); | ||
stan::test::expect_ad(f, 2.8, 0.8); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include <stan/math/prim.hpp> | ||
#include <gtest/gtest.h> | ||
#include <cmath> | ||
#include <limits> | ||
|
||
TEST(MathFunctions, hypergeometric_1f0Double) { | ||
using stan::math::hypergeometric_1f0; | ||
using stan::math::inv; | ||
|
||
EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4)); | ||
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4)); | ||
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3)); | ||
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1)); | ||
} | ||
|
||
TEST(MathFunctions, hypergeometric_1f0_throw) { | ||
using stan::math::hypergeometric_1f0; | ||
|
||
EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error); | ||
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error); | ||
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error); | ||
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error); | ||
} |