Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Framework for generic fvar<T> support through finite-differences #2929

Merged
merged 30 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9faaf42
Initial integrate_1d
andrjohns Aug 7, 2023
a26d064
Working single-order gradients
andrjohns Aug 7, 2023
41a551b
Mix tests
andrjohns Aug 7, 2023
11c4a23
Stray test
andrjohns Aug 7, 2023
e3cb963
Update doc
andrjohns Aug 7, 2023
051bd8d
cpplint
andrjohns Aug 7, 2023
e3451d4
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 7, 2023
21560cf
Fix failures
andrjohns Aug 7, 2023
3435d1b
Missed typedef
andrjohns Aug 7, 2023
77e4812
Simplify fvar framework
andrjohns Aug 8, 2023
162bf58
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 8, 2023
9c36e70
Namespace missed
andrjohns Aug 8, 2023
fd27178
Fix compatibility with nested containers
andrjohns Aug 8, 2023
fd77364
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 8, 2023
85c9e5e
Move serializer definition into Math proper
andrjohns Aug 29, 2023
c09bcf4
Merge conflict
andrjohns Aug 29, 2023
9d9d58c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 29, 2023
fffc14f
includes
andrjohns Aug 29, 2023
24a9e71
doxygen
andrjohns Aug 29, 2023
00b1303
Namespaces
andrjohns Aug 29, 2023
3be550d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 29, 2023
4209cc9
Scalar type
andrjohns Aug 29, 2023
1536d12
Optimise deserializer use, generalise finite_diff input to remove copy
andrjohns Sep 18, 2023
e09096b
Merge commit '4cf25de56d29ef39c93eb2595d13dcfd65f97818' into HEAD
yashikno Sep 18, 2023
880a7d9
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 18, 2023
b4d9d30
Update to c++-style cast
andrjohns Sep 27, 2023
119ca9b
Merge commit 'eb3b5d769f60e93e79be51a2c723ce368e8437f9' into HEAD
yashikno Sep 27, 2023
8037cd5
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 27, 2023
f8d886d
Update doc & naming
andrjohns Sep 28, 2023
ad303aa
Merge branch 'develop' into fvar-support
andrjohns Sep 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stan/math/fwd/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include <stan/math/fwd/functor/apply_scalar_unary.hpp>
#include <stan/math/fwd/functor/gradient.hpp>
#include <stan/math/fwd/functor/finite_diff.hpp>
#include <stan/math/fwd/functor/hessian.hpp>
#include <stan/math/fwd/functor/integrate_1d.hpp>
#include <stan/math/fwd/functor/jacobian.hpp>
#include <stan/math/fwd/functor/operands_and_partials.hpp>
#include <stan/math/fwd/functor/partials_propagator.hpp>
Expand Down
120 changes: 120 additions & 0 deletions stan/math/fwd/functor/finite_diff.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/serializer.hpp>

namespace stan {
namespace math {
namespace internal {
/**
 * Helper for accumulating the tangent contribution of a single input
 * argument to a finite-difference gradient.
 *
 * This overload handles arguments whose scalar type is not fvar&lt;T&gt;; such
 * arguments carry no tangent, so their contribution is identically zero.
 *
 * @tparam FuncTangent Type of tangent calculated by finite-differences
 * @tparam InputArg Type of the function input argument
 * @param tangent Calculated tangent (unused)
 * @param arg Input argument (unused)
 * @return Zero, as non-fvar arguments contribute no tangent
 */
template <typename FuncTangent, typename InputArg,
          require_not_st_fvar<InputArg>* = nullptr>
inline constexpr double aggregate_tangent(const FuncTangent& /* tangent */,
                                          const InputArg& /* arg */) {
  return 0.0;
}

/**
 * Helper for accumulating the tangent contribution of a single input
 * argument to a finite-difference gradient.
 *
 * This overload handles arguments whose scalar type is fvar&lt;T&gt;: each
 * finite-difference tangent element is weighted by the tangent (d_) of the
 * corresponding input element, and the products are summed.
 *
 * @tparam FuncTangent Type of tangent calculated by finite-differences
 * @tparam InputArg Type of the function input argument
 * @param tangent Calculated tangent
 * @param arg Input argument
 * @return Sum of the elementwise products of the tangent and the input's d_
 */
template <typename FuncTangent, typename InputArg,
          require_st_fvar<InputArg>* = nullptr>
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
  const auto weight_by_input_tangent
      = [](const auto& t, const auto& fv) { return t * fv.d_; };
  return sum(apply_scalar_binary(tangent, arg, weight_by_input_tangent));
}
} // namespace internal

/**
 * Construct an fvar<T> where the tangent is calculated by finite-differencing.
 * Finite-differencing is only performed where the scalar type to be evaluated
 * is `fvar<T>`.
 *
 * Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly
 * supported through auto-diffing the finite-differencing process.
 *
 * @tparam F Type of functor for which fvar<T> support is needed
 * @tparam TArgs Template parameter pack of the types passed in the `operator()`
 *               of the functor type `F`. Must contain at least one type whose
 *               scalar type is `fvar<T>`
 * @param func Functor for which fvar<T> support is needed
 * @param args Parameter pack of arguments to be passed to functor.
 * @return An fvar<T> whose value is `func` evaluated at the input values and
 *         whose tangent is aggregated from the finite-difference gradient
 */
template <typename F, typename... TArgs,
          require_any_st_fvar<TArgs...>* = nullptr>
inline auto finite_diff(const F& func, const TArgs&... args) {
  using FvarT = return_type_t<TArgs...>;
  using FvarInnerT = typename FvarT::Scalar;

  // Flatten the values (tangents stripped) of all input arguments into one
  // contiguous vector so the gradient can be computed in a single pass
  std::vector<FvarInnerT> serialised_args
      = serialize<FvarInnerT>(value_of(args)...);

  // Wrapper functor which deserializes the flattened vector back into the
  // original argument shapes before invoking the user-provided functor
  auto serial_functor = [&](const auto& v) {
    auto v_deserializer = to_deserializer(v);
    return func(v_deserializer.read(args)...);
  };

  FvarInnerT rtn_value;
  std::vector<FvarInnerT> grad;
  finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);

  FvarInnerT rtn_grad = 0;
  auto grad_deserializer = to_deserializer(grad);
  // Expand an initializer-list over the parameter pack (pre-C++17 emulation
  // of a fold-expression) to accumulate each argument's tangent contribution
  static_cast<void>(
      std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent(
                                      grad_deserializer.read(args), args),
                                  0)...});

  return FvarT(rtn_value, rtn_grad);
}

/**
 * Construct an fvar<T> where the tangent is calculated by finite-differencing.
 * Finite-differencing is only performed where the scalar type to be evaluated
 * is `fvar<T>`.
 *
 * This overload is used when no fvar<T> arguments are passed and simply
 * evaluates the functor with the provided arguments.
 *
 * @tparam F Type of functor
 * @tparam TArgs Template parameter pack of the types passed in the `operator()`
 *               of the functor type `F`. Must contain no type whose scalar
 *               type is `fvar<T>`
 * @param func Functor
 * @param args Parameter pack of arguments to be passed to functor.
 * @return Result of evaluating the functor with the provided arguments
 */
template <typename F, typename... TArgs,
          require_all_not_st_fvar<TArgs...>* = nullptr>
inline auto finite_diff(const F& func, const TArgs&... args) {
  return func(args...);
}

} // namespace math
} // namespace stan

#endif
101 changes: 101 additions & 0 deletions stan/math/fwd/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP
#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/prim/functor/integrate_1d.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/meta/forward_as.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/fwd/functor/finite_diff.hpp>

namespace stan {
namespace math {
/**
 * Return the integral of f from a to b to the given relative tolerance.
 *
 * fvar<T> overload: the value is computed by the arithmetic
 * integrate_1d_impl, tangents w.r.t. the parameter pack are computed by
 * finite-differencing the integral, and tangents w.r.t. the integration
 * bounds are added analytically via the Leibniz integral rule.
 *
 * @tparam F Type of f
 * @tparam T_a type of first limit
 * @tparam T_b type of second limit
 * @tparam Args types of parameter pack arguments
 *
 * @param f the functor to integrate
 * @param a lower limit of integration
 * @param b upper limit of integration
 * @param relative_tolerance relative tolerance passed to Boost quadrature
 * @param[in, out] msgs the print stream for warning messages
 * @param args additional arguments to pass to f
 * @return numeric integral of function f
 */
template <typename F, typename T_a, typename T_b, typename... Args,
          require_any_st_fvar<T_a, T_b, Args...> * = nullptr>
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
    const F &f, const T_a &a, const T_b &b, double relative_tolerance,
    std::ostream *msgs, const Args &... args) {
  using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>;

  // Wrap the integrate_1d call in a functor that takes only the arguments
  // for which tangents may be needed; the bounds are captured as plain values
  auto a_val = value_of(a);
  auto b_val = value_of(b);
  auto func
      = [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) {
          return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs,
                                   args_var...);
        };
  FvarT ret = finite_diff(func, args...);

  // Tangents w.r.t. the integration bounds, via the Leibniz integral rule:
  // d/da of int_a^b f = -f(a), and d/db of int_a^b f = f(b)
  if (is_fvar<T_a>::value || is_fvar<T_b>::value) {
    auto val_args = std::make_tuple(value_of(args)...);
    if (is_fvar<T_a>::value) {
      ret.d_ += math::forward_as<FvarT>(a).d_
                * math::apply(
                    [&](auto &&... tuple_args) {
                      return -f(a_val, 0.0, msgs, tuple_args...);
                    },
                    val_args);
    }
    if (is_fvar<T_b>::value) {
      ret.d_ += math::forward_as<FvarT>(b).d_
                * math::apply(
                    [&](auto &&... tuple_args) {
                      return f(b_val, 0.0, msgs, tuple_args...);
                    },
                    val_args);
    }
  }
  return ret;
}

/**
 * Compute the integral of the single variable function f from a to b to
 * within a specified relative tolerance. a and b can be finite or infinite.
 *
 * fvar<T> overload: adapts the (f, theta, x_r, x_i) calling convention and
 * forwards to integrate_1d_impl.
 *
 * @tparam F Type of f
 * @tparam T_a type of first limit
 * @tparam T_b type of second limit
 * @tparam T_theta type of parameters
 *
 * @param f the functor to integrate
 * @param a lower limit of integration
 * @param b upper limit of integration
 * @param theta additional parameters to be passed to f
 * @param x_r additional data to be passed to f
 * @param x_i additional integer data to be passed to f
 * @param[in, out] msgs the print stream for warning messages
 * @param relative_tolerance relative tolerance passed to Boost quadrature
 * @return numeric integral of function f
 */
template <typename F, typename T_a, typename T_b, typename T_theta,
          require_any_fvar_t<T_a, T_b, T_theta> * = nullptr>
inline return_type_t<T_a, T_b, T_theta> integrate_1d(
    const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
    const std::vector<double> &x_r, const std::vector<int> &x_i,
    std::ostream *msgs, const double relative_tolerance) {
  integrate_1d_adapter<F> adapted_f(f);
  return integrate_1d_impl(adapted_f, a, b, relative_tolerance, msgs, theta,
                           x_r, x_i);
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/serializer.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
Expand Down
31 changes: 15 additions & 16 deletions test/unit/math/serializer.hpp → stan/math/prim/fun/serializer.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef TEST_UNIT_MATH_SERIALIZER_HPP
#define TEST_UNIT_MATH_SERIALIZER_HPP
#ifndef STAN_MATH_PRIM_FUN_SERIALIZER_HPP
#define STAN_MATH_PRIM_FUN_SERIALIZER_HPP

#include <stan/math.hpp>
#include <stan/math/prim/meta/promote_scalar_type.hpp>
#include <stan/math/prim/fun/to_vector.hpp>
#include <stan/math/prim/fun/to_array_1d.hpp>
#include <complex>
#include <string>
#include <vector>

namespace stan {
namespace test {
namespace math {

/**
* A class to store a sequence of values which can be deserialized
Expand Down Expand Up @@ -44,10 +46,10 @@ struct deserializer {
/**
* Construct a deserializer from the specified sequence of values.
*
* @param vals values to deserialize
* @param v_vals values to deserialize
*/
explicit deserializer(const Eigen::Matrix<T, -1, 1>& v_vals)
: position_(0), vals_(math::to_array_1d(v_vals)) {}
: position_(0), vals_(to_array_1d(v_vals)) {}

/**
* Read a scalar conforming to the shape of the specified argument,
Expand Down Expand Up @@ -94,8 +96,8 @@ struct deserializer {
*/
template <typename U, require_std_vector_t<U>* = nullptr,
require_not_st_complex<U>* = nullptr>
typename stan::math::promote_scalar_type<T, U>::type read(const U& x) {
typename stan::math::promote_scalar_type<T, U>::type y;
promote_scalar_t<T, U> read(const U& x) {
promote_scalar_t<T, U> y;
y.reserve(x.size());
Comment on lines +100 to 101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You didn't do this, but should these reserves be y.reserve(y.size() + x.size())?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in this case, since the aim is to return a container of the same size as x

for (size_t i = 0; i < x.size(); ++i)
y.push_back(read(x[i]));
Expand All @@ -113,9 +115,8 @@ struct deserializer {
* @return deserialized value with shape and size matching argument
*/
template <typename U, require_std_vector_st<is_complex, U>* = nullptr>
typename stan::math::promote_scalar_type<std::complex<T>, U>::type read(
const U& x) {
typename stan::math::promote_scalar_type<std::complex<T>, U>::type y;
promote_scalar_t<std::complex<T>, U> read(const U& x) {
promote_scalar_t<std::complex<T>, U> y;
y.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i)
y.push_back(read(x[i]));
Expand Down Expand Up @@ -257,9 +258,7 @@ struct serializer {
*
* @return serialized values
*/
const Eigen::Matrix<T, -1, 1>& vector_vals() {
return math::to_vector(vals_);
}
const Eigen::Matrix<T, -1, 1>& vector_vals() { return to_vector(vals_); }
};

/**
Expand Down Expand Up @@ -338,10 +337,10 @@ std::vector<real_return_t<T>> serialize_return(const T& x) {
*/
template <typename... Ts>
Eigen::VectorXd serialize_args(const Ts... xs) {
return math::to_vector(serialize<double>(xs...));
return to_vector(serialize<double>(xs...));
}

} // namespace test
} // namespace math
} // namespace stan

#endif
29 changes: 15 additions & 14 deletions stan/math/prim/functor/finite_diff_gradient_auto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,39 +46,40 @@ namespace math {
* @param[out] fx function applied to argument
* @param[out] grad_fx gradient of function at argument
*/
template <typename F>
void finite_diff_gradient_auto(const F& f, const Eigen::VectorXd& x, double& fx,
Eigen::VectorXd& grad_fx) {
Eigen::VectorXd x_temp(x);
template <typename F, typename VectorT,
typename ScalarT = return_type_t<VectorT>>
void finite_diff_gradient_auto(const F& f, const VectorT& x, ScalarT& fx,
VectorT& grad_fx) {
VectorT x_temp(x);
fx = f(x);
grad_fx.resize(x.size());
for (int i = 0; i < x.size(); ++i) {
double h = finite_diff_stepsize(x(i));
double h = finite_diff_stepsize(value_of_rec(x[i]));

double delta_f = 0;
ScalarT delta_f = 0;

x_temp(i) = x(i) + 3 * h;
x_temp[i] = x[i] + 3 * h;
delta_f += f(x_temp);

x_temp(i) = x(i) + 2 * h;
x_temp[i] = x[i] + 2 * h;
delta_f -= 9 * f(x_temp);

x_temp(i) = x(i) + h;
x_temp[i] = x[i] + h;
delta_f += 45 * f(x_temp);

x_temp(i) = x(i) + -3 * h;
x_temp[i] = x[i] + -3 * h;
delta_f -= f(x_temp);

x_temp(i) = x(i) + -2 * h;
x_temp[i] = x[i] + -2 * h;
delta_f += 9 * f(x_temp);

x_temp(i) = x(i) - h;
x_temp[i] = x[i] - h;
delta_f -= 45 * f(x_temp);

delta_f /= 60 * h;

x_temp(i) = x(i);
grad_fx(i) = delta_f;
x_temp[i] = x[i];
grad_fx[i] = delta_f;
}
}

Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ inline double integrate(const F& f, double a, double b,
* @return numeric integral of function f
*/
template <typename F, typename... Args,
require_all_not_st_var<Args...>* = nullptr>
require_all_st_arithmetic<Args...>* = nullptr>
inline double integrate_1d_impl(const F& f, double a, double b,
double relative_tolerance, std::ostream* msgs,
const Args&... args) {
Expand Down
Loading