-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Framework for generic fvar<T> support through finite-differences #2929
Merged
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
9faaf42
Initial integrate_1d
andrjohns a26d064
Working single-order gradients
andrjohns 41a551b
Mix tests
andrjohns 11c4a23
Stray test
andrjohns e3cb963
Update doc
andrjohns 051bd8d
cpplint
andrjohns e3451d4
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot 21560cf
Fix failures
andrjohns 3435d1b
Missed typedef
andrjohns 77e4812
Simplify fvar framework
andrjohns 162bf58
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot 9c36e70
Namespace missed
andrjohns fd27178
Fix compatibility with nested containers
andrjohns fd77364
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot 85c9e5e
Move serializer definition into Math proper
andrjohns c09bcf4
Merge conflict
andrjohns 9d9d58c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot fffc14f
includes
andrjohns 24a9e71
doxygen
andrjohns 00b1303
Namespaces
andrjohns 3be550d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot 4209cc9
Scalar type
andrjohns 1536d12
Optimise deserializer use, generalise finite_diff input to remove copy
andrjohns e09096b
Merge commit '4cf25de56d29ef39c93eb2595d13dcfd65f97818' into HEAD
yashikno 880a7d9
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot b4d9d30
Update to c++-style cast
andrjohns 119ca9b
Merge commit 'eb3b5d769f60e93e79be51a2c723ce368e8437f9' into HEAD
yashikno 8037cd5
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot f8d886d
Update doc & naming
andrjohns ad303aa
Merge branch 'develop' into fvar-support
andrjohns File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP | ||
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/functor/apply_scalar_binary.hpp> | ||
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
#include <stan/math/prim/fun/sum.hpp> | ||
#include <stan/math/prim/fun/serializer.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
/** | ||
* Helper function for aggregating tangents if the respective input argument | ||
* was an fvar<T> type. | ||
* | ||
* Overload for when the input is not an fvar<T> and no tangents are needed. | ||
* | ||
* @tparam FuncTangent Type of tangent calculated by finite-differences | ||
* @tparam InputArg Type of the function input argument | ||
* @param tangent Calculated tangent | ||
* @param arg Input argument | ||
*/ | ||
template <typename FuncTangent, typename InputArg, | ||
require_not_st_fvar<InputArg>* = nullptr> | ||
inline constexpr double aggregate_tangent(const FuncTangent& tangent, | ||
const InputArg& arg) { | ||
return 0; | ||
} | ||
|
||
/** | ||
* Helper function for aggregating tangents if the respective input argument | ||
* was an fvar<T> type. | ||
* | ||
* Overload for when the input is an fvar<T> and its tangent needs to be | ||
* aggregated. | ||
* | ||
* @tparam FuncTangent Type of tangent calculated by finite-differences | ||
* @tparam InputArg Type of the function input argument | ||
* @param tangent Calculated tangent | ||
* @param arg Input argument | ||
*/ | ||
template <typename FuncTangent, typename InputArg, | ||
require_st_fvar<InputArg>* = nullptr> | ||
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) { | ||
return sum(apply_scalar_binary( | ||
tangent, arg, [](const auto& x, const auto& y) { return x * y.d_; })); | ||
} | ||
} // namespace internal | ||
|
||
/** | ||
* Construct an fvar<T> where the tangent is calculated by finite-differencing. | ||
* Finite-differencing is only perfomed where the scalar type to be evaluated is | ||
* `fvar<T>. | ||
* | ||
* Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly | ||
* supported through auto-diffing the finite-differencing process. | ||
* | ||
* @tparam F Type of functor for which fvar<T> support is needed | ||
* @tparam TArgs Template parameter pack of the types passed in the `operator()` | ||
* of the functor type `F`. Must contain at least on type whose | ||
* scalar type is `fvar<T>` | ||
* @param func Functor for which fvar<T> support is needed | ||
* @param args Parameter pack of arguments to be passed to functor. | ||
*/ | ||
template <typename F, typename... TArgs, | ||
require_any_st_fvar<TArgs...>* = nullptr> | ||
inline auto finite_diff(const F& func, const TArgs&... args) { | ||
using FvarT = return_type_t<TArgs...>; | ||
using FvarInnerT = typename FvarT::Scalar; | ||
|
||
std::vector<FvarInnerT> serialised_args | ||
= serialize<FvarInnerT>(value_of(args)...); | ||
|
||
auto serial_functor = [&](const auto& v) { | ||
auto v_deserializer = to_deserializer(v); | ||
return func(v_deserializer.read(args)...); | ||
}; | ||
|
||
FvarInnerT rtn_value; | ||
std::vector<FvarInnerT> grad; | ||
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad); | ||
|
||
FvarInnerT rtn_grad = 0; | ||
auto grad_deserializer = to_deserializer(grad); | ||
// Use a fold-expression to aggregate tangents for input arguments | ||
static_cast<void>( | ||
std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent( | ||
grad_deserializer.read(args), args), | ||
0)...}); | ||
|
||
return FvarT(rtn_value, rtn_grad); | ||
} | ||
|
||
/** | ||
* Construct an fvar<T> where the tangent is calculated by finite-differencing. | ||
* Finite-differencing is only perfomed where the scalar type to be evaluated is | ||
* `fvar<T>. | ||
* | ||
* This overload is used when no fvar<T> arguments are passed and simply | ||
* evaluates the functor with the provided arguments. | ||
* | ||
* @tparam F Type of functor | ||
* @tparam TArgs Template parameter pack of the types passed in the `operator()` | ||
* of the functor type `F`. Must contain no type whose | ||
* scalar type is `fvar<T>` | ||
* @param func Functor | ||
* @param args... Parameter pack of arguments to be passed to functor. | ||
*/ | ||
template <typename F, typename... TArgs, | ||
require_all_not_st_fvar<TArgs...>* = nullptr> | ||
inline auto finite_diff(const F& func, const TArgs&... args) { | ||
return func(args...); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP | ||
#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/prim/functor/integrate_1d.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
#include <stan/math/prim/meta/forward_as.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <stan/math/fwd/functor/finite_diff.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
/** | ||
* Return the integral of f from a to b to the given relative tolerance | ||
* | ||
* @tparam F Type of f | ||
* @tparam T_a type of first limit | ||
* @tparam T_b type of second limit | ||
* @tparam Args types of parameter pack arguments | ||
* | ||
* @param f the functor to integrate | ||
* @param a lower limit of integration | ||
* @param b upper limit of integration | ||
* @param relative_tolerance relative tolerance passed to Boost quadrature | ||
* @param[in, out] msgs the print stream for warning messages | ||
* @param args additional arguments to pass to f | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F, typename T_a, typename T_b, typename... Args, | ||
require_any_st_fvar<T_a, T_b, Args...> * = nullptr> | ||
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl( | ||
const F &f, const T_a &a, const T_b &b, double relative_tolerance, | ||
std::ostream *msgs, const Args &... args) { | ||
using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>; | ||
|
||
// Wrap integrate_1d call in a functor where the input arguments are only | ||
// for which tangents are needed | ||
auto a_val = value_of(a); | ||
auto b_val = value_of(b); | ||
auto func | ||
= [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) { | ||
return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs, | ||
args_var...); | ||
}; | ||
FvarT ret = finite_diff(func, args...); | ||
|
||
// Calculate tangents w.r.t. integration bounds if needed | ||
if (is_fvar<T_a>::value || is_fvar<T_b>::value) { | ||
auto val_args = std::make_tuple(value_of(args)...); | ||
if (is_fvar<T_a>::value) { | ||
ret.d_ += math::forward_as<FvarT>(a).d_ | ||
* math::apply( | ||
[&](auto &&... tuple_args) { | ||
return -f(a_val, 0.0, msgs, tuple_args...); | ||
}, | ||
val_args); | ||
} | ||
if (is_fvar<T_b>::value) { | ||
ret.d_ += math::forward_as<FvarT>(b).d_ | ||
* math::apply( | ||
[&](auto &&... tuple_args) { | ||
return f(b_val, 0.0, msgs, tuple_args...); | ||
}, | ||
val_args); | ||
} | ||
} | ||
return ret; | ||
} | ||
|
||
/** | ||
* Compute the integral of the single variable function f from a to b to within | ||
* a specified relative tolerance. a and b can be finite or infinite. | ||
* | ||
* @tparam T_a type of first limit | ||
* @tparam T_b type of second limit | ||
* @tparam T_theta type of parameters | ||
* @tparam T Type of f | ||
* | ||
* @param f the functor to integrate | ||
* @param a lower limit of integration | ||
* @param b upper limit of integration | ||
* @param theta additional parameters to be passed to f | ||
* @param x_r additional data to be passed to f | ||
* @param x_i additional integer data to be passed to f | ||
* @param[in, out] msgs the print stream for warning messages | ||
* @param relative_tolerance relative tolerance passed to Boost quadrature | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F, typename T_a, typename T_b, typename T_theta, | ||
require_any_fvar_t<T_a, T_b, T_theta> * = nullptr> | ||
inline return_type_t<T_a, T_b, T_theta> integrate_1d( | ||
const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta, | ||
const std::vector<double> &x_r, const std::vector<int> &x_i, | ||
std::ostream *msgs, const double relative_tolerance) { | ||
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance, | ||
msgs, theta, x_r, x_i); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You didn't do this, but should these reserves be
y.reserve(y.size() + x.size())
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not in this case, since the aim is to return a container of the same size as
x