Re-Implementation of Hypergeometric PFQ gradient function #2961

Merged
merged 39 commits into develop from grad_pfq-2
Apr 13, 2024
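For context, the function whose gradients this PR re-implements is the generalized hypergeometric function. Its series definition and the gradient identities below are the standard ones (stated here for reference, not taken from the PR itself):

```latex
{}_pF_q(a_1,\dots,a_p;\; b_1,\dots,b_q;\; z)
  = \sum_{k=0}^{\infty}
    \frac{\prod_{i=1}^{p}(a_i)_k}{\prod_{j=1}^{q}(b_j)_k}\,
    \frac{z^k}{k!},
\qquad (x)_k = \frac{\Gamma(x+k)}{\Gamma(x)}
```

Writing $t_k$ for the $k$-th term of the series, the gradients are

```latex
\frac{\partial\,{}_pF_q}{\partial z}
  = \frac{\prod_i a_i}{\prod_j b_j}\,
    {}_pF_q(a_1+1,\dots,a_p+1;\; b_1+1,\dots,b_q+1;\; z),
\qquad
\frac{\partial\,{}_pF_q}{\partial a_i}
  = \sum_{k=0}^{\infty}\bigl(\psi(a_i+k)-\psi(a_i)\bigr)\,t_k,
\qquad
\frac{\partial\,{}_pF_q}{\partial b_j}
  = -\sum_{k=0}^{\infty}\bigl(\psi(b_j+k)-\psi(b_j)\bigr)\,t_k
```

The infinite sums for the $a$ and $b$ gradients are what grad_pFq evaluates; per commit 3c5b15d below, they are skipped entirely when no autodiff argument needs them.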
Commits
39 commits
47cdde2
Initial natural scale impl
andrjohns Oct 4, 2023
f091e0c
Absolute impl on natural scale
andrjohns Oct 4, 2023
7d58254
Log scale initial impl
andrjohns Oct 4, 2023
af83ccd
Log scale increment
andrjohns Oct 4, 2023
c9cd18d
Replace primary definition
andrjohns Oct 4, 2023
d5ce3db
Don't calc for non-autodiff
andrjohns Oct 4, 2023
668f7fd
Initial mix working
andrjohns Oct 4, 2023
d98a289
Begin debugging autodiff
andrjohns Oct 4, 2023
071982d
Merge branch 'develop' into grad_pfq-2
andrjohns Oct 6, 2023
0795e5d
Initial ad testing working
andrjohns Oct 6, 2023
67fdd57
Updating test
andrjohns Oct 6, 2023
d52db96
Working ad tests
andrjohns Oct 7, 2023
01988f1
Simplified and tidied
andrjohns Oct 7, 2023
644f9cb
Update doc
andrjohns Oct 7, 2023
f713153
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 7, 2023
3c5b15d
Skip infsum if not needed
andrjohns Oct 7, 2023
50acf60
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 7, 2023
c1cbe26
Test fix asan error
andrjohns Oct 7, 2023
07b5f4c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 7, 2023
f810c55
Fix handling of zero-inputs
andrjohns Oct 8, 2023
535dadb
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 8, 2023
e769630
Update doc, additional simplification
andrjohns Oct 9, 2023
a054665
Additional simplification
andrjohns Oct 9, 2023
b9388b1
Fix sign-flip case
andrjohns Oct 9, 2023
fbcec85
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 9, 2023
e318d93
Merge branch 'develop' into grad_pfq-2
andrjohns Feb 15, 2024
c149fda
Avoid unneeded computation
andrjohns Feb 21, 2024
5f7c7b2
Merge branch 'develop' into grad_pfq-2
andrjohns Feb 21, 2024
2869827
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 21, 2024
b97f02a
Tidying
andrjohns Feb 24, 2024
af23ee6
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 24, 2024
30273c8
Update handling integers
andrjohns Feb 26, 2024
b631517
Optimise handling of negative args
andrjohns Feb 26, 2024
a35dc3d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 26, 2024
83edc29
New impl more accurate than 3F2
andrjohns Feb 26, 2024
8da99c5
Revert bool camel case
andrjohns Feb 27, 2024
1f1ff16
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 27, 2024
8919824
Trigger CI
andrjohns Feb 28, 2024
ecc713b
Apply review suggestions
andrjohns Apr 12, 2024
53 changes: 31 additions & 22 deletions stan/math/fwd/fun/hypergeometric_pFq.hpp
@@ -22,32 +22,41 @@ namespace math {
  * @return Generalized hypergeometric function
  */
 template <typename Ta, typename Tb, typename Tz,
-          require_all_matrix_t<Ta, Tb>* = nullptr,
-          require_return_type_t<is_fvar, Ta, Tb, Tz>* = nullptr>
-inline return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
-                                                    const Tz& z) {
-  using fvar_t = return_type_t<Ta, Tb, Tz>;
-  ref_type_t<Ta> a_ref = a;
-  ref_type_t<Tb> b_ref = b;
-  auto grad_tuple = grad_pFq(a_ref, b_ref, z);
-
-  typename fvar_t::Scalar grad = 0;
-
-  if (!is_constant<Ta>::value) {
-    grad += dot_product(forward_as<promote_scalar_t<fvar_t, Ta>>(a_ref).d(),
-                        std::get<0>(grad_tuple));
+          typename FvarT = return_type_t<Ta, Tb, Tz>,
+          bool GradA = !is_constant<Ta>::value,
+          bool GradB = !is_constant<Tb>::value,
+          bool GradZ = !is_constant<Tz>::value,
+          require_all_vector_t<Ta, Tb>* = nullptr,
+          require_return_type_t<is_fvar, FvarT>* = nullptr>
+inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
+  using PartialsT = partials_type_t<FvarT>;
+  using ARefT = ref_type_t<Ta>;
+  using BRefT = ref_type_t<Tb>;
+
+  ARefT a_ref = a;
+  BRefT b_ref = b;
+  auto&& a_val = value_of(a_ref);
+  auto&& b_val = value_of(b_ref);
+  auto&& z_val = value_of(z);
+  PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
+  auto grad_tuple = grad_pFq<GradA, GradB, GradZ>(pfq_val, a_val, b_val, z_val);
+
+  FvarT rtn = FvarT(pfq_val, 0.0);
+
+  if (GradA) {
+    rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, ARefT>>(a_ref).d(),
+                          std::get<0>(grad_tuple));
   }
-  if (!is_constant<Tb>::value) {
-    grad += dot_product(forward_as<promote_scalar_t<fvar_t, Tb>>(b_ref).d(),
-                        std::get<1>(grad_tuple));
+  if (GradB) {
+    rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, BRefT>>(b_ref).d(),
+                          std::get<1>(grad_tuple));
   }
-  if (!is_constant<Tz>::value) {
-    grad += forward_as<promote_scalar_t<fvar_t, Tz>>(z).d_
-            * std::get<2>(grad_tuple);
+  if (GradZ) {
+    rtn.d_ += forward_as<promote_scalar_t<FvarT, Tz>>(z).d_
+              * std::get<2>(grad_tuple);
   }
 
-  return fvar_t(
-      hypergeometric_pFq(value_of(a_ref), value_of(b_ref), value_of(z)), grad);
+  return rtn;
 }
 
 } // namespace math
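Taken together, the new forward-mode specialization computes the pFq value once, reuses it in the grad_pFq call (now templated on the compile-time GradA/GradB/GradZ flags so unused gradients are never computed), and accumulates the tangent directly into the returned fvar instead of rebuilding it at the end. A minimal usage sketch of the overload, written for this summary rather than taken from the PR, and assuming Stan Math's top-level stan/math/fwd.hpp header:

```cpp
#include <stan/math/fwd.hpp>
#include <iostream>

int main() {
  using stan::math::fvar;

  // 2F1(1.5, 2.0; 3.0; z) evaluated at z = 0.5. Seeding z's tangent
  // with 1.0 makes the returned fvar's tangent equal to d/dz pFq.
  Eigen::Matrix<fvar<double>, Eigen::Dynamic, 1> a(2);
  a << 1.5, 2.0;
  Eigen::Matrix<fvar<double>, Eigen::Dynamic, 1> b(1);
  b << 3.0;
  fvar<double> z(0.5, 1.0);

  fvar<double> res = stan::math::hypergeometric_pFq(a, b, z);
  std::cout << "value: " << res.val() << ", d/dz: " << res.tangent() << "\n";
  return 0;
}
```

Because a and b are passed as fvar vectors here, GradA and GradB are also true; their contributions to the tangent are simply zero since their tangents were left at the default of 0.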