From bafd17feffc75235df2312353a94a2e8d2f45871 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Tue, 22 Sep 2020 14:26:08 +0300 Subject: [PATCH 1/5] add closures to ODEs --- stan/math/prim/err/check_finite.hpp | 5 +- stan/math/prim/fun/value_of.hpp | 13 +++ stan/math/prim/functor.hpp | 1 + stan/math/prim/functor/closure_adapter.hpp | 104 ++++++++++++++++++ stan/math/prim/functor/coupled_ode_system.hpp | 8 +- stan/math/prim/functor/integrate_ode_rk45.hpp | 18 ++- ...grate_ode_std_vector_interface_adapter.hpp | 28 ++++- stan/math/prim/functor/ode_rk45.hpp | 27 ++++- .../prim/functor/ode_store_sensitivities.hpp | 8 +- stan/math/prim/meta.hpp | 1 + stan/math/prim/meta/is_stan_closure.hpp | 56 ++++++++++ stan/math/rev/core/accumulate_adjoints.hpp | 27 +++++ stan/math/rev/core/count_vars.hpp | 25 +++++ stan/math/rev/core/deep_copy_vars.hpp | 15 ++- stan/math/rev/core/save_varis.hpp | 25 +++++ stan/math/rev/core/zero_adjoints.hpp | 21 ++++ stan/math/rev/functor/coupled_ode_system.hpp | 13 ++- stan/math/rev/functor/cvodes_integrator.hpp | 17 +-- stan/math/rev/functor/integrate_ode_adams.hpp | 24 +++- stan/math/rev/functor/integrate_ode_bdf.hpp | 23 +++- stan/math/rev/functor/ode_adams.hpp | 26 ++++- stan/math/rev/functor/ode_bdf.hpp | 26 ++++- .../rev/functor/ode_store_sensitivities.hpp | 12 +- .../prim/functor/coupled_ode_system_test.cpp | 21 ++-- ..._ode_std_vector_interface_adapter_test.cpp | 11 +- .../functor/ode_store_sensitivities_test.cpp | 2 +- .../rev/functor/coupled_ode_system_test.cpp | 97 ++++++++++------ ..._ode_std_vector_interface_adapter_test.cpp | 18 +-- .../math/rev/functor/ode_rk45_rev_test.cpp | 80 ++++++++++++++ .../functor/ode_store_sensitivities_test.cpp | 4 +- 30 files changed, 638 insertions(+), 118 deletions(-) create mode 100644 stan/math/prim/functor/closure_adapter.hpp create mode 100644 stan/math/prim/meta/is_stan_closure.hpp diff --git a/stan/math/prim/err/check_finite.hpp b/stan/math/prim/err/check_finite.hpp index ddb6cab80fc..e4e4bf33947 100644 --- a/stan/math/prim/err/check_finite.hpp +++ b/stan/math/prim/err/check_finite.hpp @@ -17,7 +17,7 @@ namespace math { * @param y variable to check * @throw domain_error if y is infinity, -infinity, or NaN */ -template +template * = nullptr> inline void check_finite(const char* function, const char* name, const T_y& y) { if (check_finite_screen(y)) { auto is_good = [](const auto& y) { return std::isfinite(y); }; @@ -25,6 +25,9 @@ inline void check_finite(const char* function, const char* name, const T_y& y) { } } +template * = nullptr> +inline void check_finite(const char* function, const char* name, const T& y) {} + } // namespace math } // namespace stan diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 8cc47e80025..ade5098625d 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -60,6 +60,19 @@ inline auto value_of(EigMat&& M) { std::forward(M)); } +/** + * Closures that capture non-arithmetic types have value_of__() method. + * + * @tparam F Input element type + * @param[in] f Input closure + * @return closure + **/ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto value_of(const F& f) { + return f.value_of__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index aab4147aa96..dc6899110fc 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp new file mode 100644 index 00000000000..47c10796007 --- /dev/null +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -0,0 +1,104 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP + +#include +#include + +namespace stan { +namespace math { + +template +struct closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = closure_adapter; + static const size_t vars_count__ = 0; + F f_; + + explicit closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(args..., msgs); + } + auto value_of__() const { return closure_adapter(f_); } + auto deep_copy_vars__() const { return closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct simple_closure { + using captured_scalar_t__ = return_type_t; + using ValueOf__ = simple_closure()))>; + const size_t vars_count__; + F f_; + T s_; + + explicit simple_closure(const F& f, T s) + : f_(f), s_(s), vars_count__(count_vars(s)) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(s_, args..., msgs); + } + auto value_of__() const { return ValueOf__(f_, value_of(s_)); } + auto deep_copy_vars__() const { + return simple_closure(f_, deep_copy_vars(s_)); + } + void zero_adjoints__() { zero_adjoints(s_); } + double* accumulate_adjoints__(double* dest) const { + return accumulate_adjoints(dest, s_); + } + template + Vari** save_varis__(Vari** dest) const { + return save_varis(dest, s_); + } +}; + +template +auto from_lambda(F f) { + return closure_adapter(f); +} + +template +auto from_lambda(F f, T a) { + return simple_closure(f, a); +} + +namespace internal { + +template +struct ode_closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = ode_closure_adapter; + static const size_t vars_count__ = 0; + const F f_; + + explicit ode_closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + Args... args) const { + return f_(t, y, msgs, args...); + } + auto value_of__() const { return ode_closure_adapter(f_); } + auto deep_copy_vars__() const { return ode_closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +} // namespace internal + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/functor/coupled_ode_system.hpp b/stan/math/prim/functor/coupled_ode_system.hpp index a4dbe69603d..166aa21a427 100644 --- a/stan/math/prim/functor/coupled_ode_system.hpp +++ b/stan/math/prim/functor/coupled_ode_system.hpp @@ -68,7 +68,7 @@ struct coupled_ode_system_impl { dz_dt.resize(y.size()); Eigen::VectorXd f_y_t - = apply([&](const Args&... args) { return f_(t, y, msgs_, args...); }, + = apply([&](const Args&... args) { return f_(msgs_, t, y, args...); }, args_tuple_); check_size_match("coupled_ode_system", "dy_dt", f_y_t.size(), "states", @@ -104,14 +104,14 @@ struct coupled_ode_system_impl { template struct coupled_ode_system : public coupled_ode_system_impl< - std::is_arithmetic>::value, F, T_y0, + std::is_arithmetic>::value, F, T_y0, Args...> { coupled_ode_system(const F& f, const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) : coupled_ode_system_impl< - std::is_arithmetic>::value, F, T_y0, - Args...>(f, y0, msgs, args...) {} + std::is_arithmetic>::value, F, T_y0, + Args...>(f, y0, msgs, args...) {} }; } // namespace math diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 1e568d6b1ea..71471b69804 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -14,7 +14,7 @@ namespace math { * @deprecated use ode_rk45 */ template + typename T_ts, require_stan_closure_t* = nullptr> inline auto integrate_ode_rk45( const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -26,7 +26,7 @@ inline auto integrate_ode_rk45( ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; y_converted.reserve(y.size()); for (size_t i = 0; i < y.size(); ++i) @@ -35,6 +35,20 @@ inline auto integrate_ode_rk45( return y_converted; } +template * = nullptr> +inline auto integrate_ode_rk45( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, + double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { + closure_adapter cl(f); + return integrate_ode_rk45(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 9b0178e27b5..2b85a4e0269 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -21,16 +21,34 @@ namespace internal { */ template struct integrate_ode_std_vector_interface_adapter { + using captured_scalar_t__ = typename F::captured_scalar_t__; + using ValueOf__ + = integrate_ode_std_vector_interface_adapter; + const int vars_count__; const F f_; - explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} + explicit integrate_ode_std_vector_interface_adapter(const F& f) + : vars_count__(f.vars_count__), f_(f) {} template - auto operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const std::vector& theta, - const std::vector& x, + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { - return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); + return to_vector(f_(msgs, t, to_array_1d(y), theta, x, x_int)); + } + + auto value_of__() const { return ValueOf__(f_.value_of__()); } + auto deep_copy_vars__() const { + return integrate_ode_std_vector_interface_adapter(f_.deep_copy_vars__()); + } + void zero_adjoints__() const { f_.zero_adjoints__(); } + double* accumulate_adjoints__(double* dest) const { + return f_.accumulate_adjoints__(dest); + } + template + Vari** save_varis__(Vari** dest) const { + return f_.save_varis__(dest); } }; diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index b85a705553b..7fb6a8d2387 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_ODE_RK45_HPP #include +#include #include #include #include @@ -51,8 +52,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template -std::vector, + typename... Args, require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol_impl(const char* function_name, const F& f, const Eigen::Matrix& y0_arg, T_t0 t0, @@ -91,7 +92,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, absolute_tolerance); check_positive(function_name, "max_num_steps", max_num_steps); - using return_t = return_type_t; + using return_t = return_type_t; // creates basic or coupled system by template specializations coupled_ode_system coupled_system(f, y0, msgs, args...); @@ -140,6 +141,22 @@ ode_rk45_tol_impl(const char* function_name, const F& f, return y; } +template * = nullptr> +std::vector, + Eigen::Dynamic, 1>> +ode_rk45_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0_arg, T_t0 t0, + const std::vector& ts, double relative_tolerance, + double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_rk45_tol_impl(function_name, cl, y0_arg, t0, ts, + relative_tolerance, absolute_tolerance, + max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in @@ -178,7 +195,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, */ template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const Eigen::Matrix& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -224,7 +241,7 @@ ode_rk45_tol(const F& f, const Eigen::Matrix& y0_arg, */ template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45(const F& f, const Eigen::Matrix& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/functor/ode_store_sensitivities.hpp b/stan/math/prim/functor/ode_store_sensitivities.hpp index ec42a5506e4..9dfd650626f 100644 --- a/stan/math/prim/functor/ode_store_sensitivities.hpp +++ b/stan/math/prim/functor/ode_store_sensitivities.hpp @@ -28,10 +28,10 @@ namespace math { * @param args Extra arguments passed unmodified through to ODE right hand side * @return ODE state */ -template < - typename F, typename T_y0_t0, typename T_t0, typename T_t, typename... Args, - typename - = require_all_arithmetic_t...>> +template , T_y0_t0, T_t0, + T_t, scalar_type_t...>> Eigen::VectorXd ode_store_sensitivities( const F& f, const std::vector& coupled_state, const Eigen::Matrix& y0, T_t0 t0, T_t t, diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index 7040e2969b7..9ff5a2c314d 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -202,6 +202,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp new file mode 100644 index 00000000000..5e4ebf09987 --- /dev/null +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -0,0 +1,56 @@ +#ifndef STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP +#define STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP + +#include +#include +#include + +#include + +namespace stan { + +/** + * Checks if type is a closure object. + * @tparam The type to check + * @ingroup type_trait + */ +template +struct is_stan_closure : std::false_type {}; + +template +struct is_stan_closure> + : std::true_type {}; + +template +struct scalar_type> { + using type = typename T::captured_scalar_t__; +}; + +STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); + +template +struct fn_return_type { + using type = double; +}; + +template +struct fn_return_type> { + using type = typename T::captured_scalar_t__; +}; + +/** + * Convenience type for the return type of the specified template + * parameters. + * + * @tparam F callable type + * @tparam Ts sequence of types + * @see return_type + * @ingroup type_trait + */ +template +using fn_return_type_t + = return_type_t::type, Args...>; + +} // namespace stan + +#endif diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index e5b27354ebd..670618dfebb 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -29,6 +29,11 @@ template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, + typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); @@ -121,6 +126,28 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { return accumulate_adjoints(dest + x.size(), std::forward(args)...); } +/** + * Accumulate adjoints from f (a closure type containing vars) + * into storage pointed to by dest, + * increment the adjoint storage pointer, + * recursively accumulate the adjoints of the rest of the arguments, + * and return final position of storage pointer. + * + * @tparam F A closure type capturing vars. + * @tparam Pargs Types of remaining arguments + * @param dest Pointer to where adjoints are to be accumulated + * @param f A closure holding vars to accumulate over + * @param args Further args to accumulate over + * @return Final position of adjoint storage pointer + */ +template *, + require_not_st_arithmetic*, + typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { + return accumulate_adjoints(f.accumulate_adjoints__(dest), + std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index b0b536a27ab..d46707ab2b0 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -29,6 +29,11 @@ inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args); template inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, + typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args); + template >* = nullptr, typename... Pargs> inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args); @@ -110,6 +115,26 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { return count_vars_impl(count + 1, std::forward(args)...); } +/** + * Count the number of vars in f (a closure capturing vars), + * add it to the running total, + * count the number of vars in the remaining arguments + * and return the result. + * + * @tparam F A closure type + * @tparam Pargs Types of remaining arguments + * @param[in] count The current count of the number of vars + * @param[in] f A closure holding vars + * @param[in] args objects to be forwarded to recursive call of + * `count_vars_impl` + */ +template *, + require_not_st_arithmetic*, + typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { + return count_vars_impl(count + f.vars_count__, std::forward(args)...); +} + /** * Arguments without vars contribute zero to the total number of vars. * diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 5eb1b3d8542..f46d3f0dc4c 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -19,7 +19,7 @@ namespace math { * @param arg For lvalue references this will be passed by reference. * Otherwise it will be moved. */ -template >> +template * = nullptr> inline decltype(auto) deep_copy_vars(Arith&& arg) { return std::forward(arg); } @@ -81,6 +81,19 @@ inline auto deep_copy_vars(EigT&& arg) { .eval(); } +/** + * Copy the vars in f but reallocate new varis for them + * + * @tparam F A closure type + * @param f A closure of vars + * @return A new std::vector of vars + */ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto deep_copy_vars(F&& f) { + return f.deep_copy_vars__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index 46c1f48799d..d06c75ca693 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -29,6 +29,11 @@ template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, + typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args); @@ -118,6 +123,26 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { return save_varis(dest + x.size(), std::forward(args)...); } +/** + * Save the vari pointers in f into the memory pointed to by dest, + * increment the dest storage pointer, + * recursively call save_varis on the rest of the arguments, + * and return the final value of the dest storage pointer. + * + * @tparam F A closure type with var value type + * @tparam Pargs Types of remaining arguments + * @param[in, out] dest Pointer to where vari pointers are saved + * @param[in] f A closure capturing vars + * @param[in] args Additional arguments to have their varis saved + * @return Final position of dest pointer + */ +template *, + require_not_st_arithmetic*, + typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { + return save_varis(f.save_varis__(dest), std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 36368d443ee..e68964e6ee0 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -16,6 +16,10 @@ inline void zero_adjoints(T& x, Pargs&... args); template inline void zero_adjoints(var& x, Pargs&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline void zero_adjoints(F& f, Pargs&... args); + template inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args); @@ -58,6 +62,23 @@ inline void zero_adjoints(var& x, Pargs&... args) { zero_adjoints(args...); } +/** + * Zero the adjoints of the varis of every var in a closure. + * Recursively call zero_adjoints on the rest of the arguments. + * + * @tparam F type of current argument + * @tparam Pargs type of rest of arguments + * + * @param f current argument + * @param args rest of arguments to zero + */ +template *, + require_not_st_arithmetic*> +inline void zero_adjoints(F& f, Pargs&... args) { + f.zero_adjoints__(); + zero_adjoints(args...); +} + /** * Zero the adjoints of the varis of every var in an Eigen::Matrix * container. Recursively call zero_adjoints on the rest of the arguments. diff --git a/stan/math/rev/functor/coupled_ode_system.hpp b/stan/math/rev/functor/coupled_ode_system.hpp index ac278e7665b..4da783d6746 100644 --- a/stan/math/rev/functor/coupled_ode_system.hpp +++ b/stan/math/rev/functor/coupled_ode_system.hpp @@ -65,7 +65,7 @@ namespace math { */ template struct coupled_ode_system_impl { - const F& f_; + F f_; const Eigen::Matrix& y0_; std::tuple()))...> local_args_tuple_; @@ -89,11 +89,11 @@ struct coupled_ode_system_impl { coupled_ode_system_impl(const F& f, const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) - : f_(f), + : f_(deep_copy_vars(f)), y0_(y0), local_args_tuple_(deep_copy_vars(args)...), num_y0_vars_(count_vars(y0_)), - num_args_vars(count_vars(args...)), + num_args_vars(count_vars(f, args...)), N_(y0.size()), args_adjoints_(num_args_vars), y_adjoints_(N_), @@ -125,7 +125,7 @@ struct coupled_ode_system_impl { y_vars.coeffRef(n) = z[n]; Eigen::Matrix f_y_t_vars - = apply([&](auto&&... args) { return f_(t, y_vars, msgs_, args...); }, + = apply([&](auto&&... args) { return f_(msgs_, t, y_vars, args...); }, local_args_tuple_); check_size_match("coupled_ode_system", "dy_dt", f_y_t_vars.size(), "states", @@ -142,13 +142,14 @@ struct coupled_ode_system_impl { apply( [&](auto&&... args) { - accumulate_adjoints(args_adjoints_.data(), args...); + accumulate_adjoints(args_adjoints_.data(), f_, args...); }, local_args_tuple_); // The vars here do not live on the nested stack so must be zero'd // separately - apply([&](auto&&... args) { zero_adjoints(args...); }, local_args_tuple_); + apply([&](auto&&... args) { zero_adjoints(f_, args...); }, + local_args_tuple_); // No need to zero adjoints after last sweep if (i + 1 < N_) { diff --git a/stan/math/rev/functor/cvodes_integrator.hpp b/stan/math/rev/functor/cvodes_integrator.hpp index 70b33f21740..8b98975284f 100644 --- a/stan/math/rev/functor/cvodes_integrator.hpp +++ b/stan/math/rev/functor/cvodes_integrator.hpp @@ -31,11 +31,12 @@ namespace math { template class cvodes_integrator { - using T_Return = return_type_t; + using T_Return = return_type_t; using T_y0_t0 = return_type_t; const char* function_name_; const F& f_; + typename F::ValueOf__ value_of_f_; const Eigen::Matrix y0_; const T_t0 t0_; const std::vector& ts_; @@ -102,9 +103,9 @@ class cvodes_integrator { inline void rhs(double t, const double y[], double dy_dt[]) const { const Eigen::VectorXd y_vec = Eigen::Map(y, N_); - Eigen::VectorXd dy_dt_vec - = apply([&](auto&&... args) { return f_(t, y_vec, msgs_, args...); }, - value_of_args_tuple_); + Eigen::VectorXd dy_dt_vec = apply( + [&](auto&&... args) { return value_of_f_(msgs_, t, y_vec, args...); }, + value_of_args_tuple_); check_size_match("cvodes_integrator", "dy_dt", dy_dt_vec.size(), "states", N_); @@ -121,8 +122,9 @@ class cvodes_integrator { Eigen::MatrixXd Jfy; auto f_wrapped = [&](const Eigen::Matrix& y) { - return apply([&](auto&&... args) { return f_(t, y, msgs_, args...); }, - value_of_args_tuple_); + return apply( + [&](auto&&... args) { return value_of_f_(msgs_, t, y, args...); }, + value_of_args_tuple_); }; jacobian(f_wrapped, Eigen::Map(y, N_), fy, Jfy); @@ -191,6 +193,7 @@ class cvodes_integrator { std::ostream* msgs, const T_Args&... args) : function_name_(function_name), f_(f), + value_of_f_(value_of(f)), y0_(y0.template cast()), t0_(t0), ts_(ts), @@ -202,7 +205,7 @@ class cvodes_integrator { absolute_tolerance_(absolute_tolerance), max_num_steps_(max_num_steps), num_y0_vars_(count_vars(y0_)), - num_args_vars_(count_vars(args...)), + num_args_vars_(count_vars(f, args...)), coupled_ode_(f, y0_, msgs, args...), coupled_state_(coupled_ode_.initial_state()) { check_finite(function_name, "initial state", y0_); diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 0cba70a321e..b2cf7373dad 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -14,8 +15,8 @@ namespace math { * @deprecated use ode_adams */ template -std::vector>> + typename T_ts, require_stan_closure_t* = nullptr> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +30,7 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); @@ -37,6 +38,23 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, return y_converted; } +template * = nullptr> +std::vector>> +integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, + const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, + double relative_tolerance = 1e-10, + double absolute_tolerance = 1e-10, + int max_num_steps = 1e8) { + closure_adapter cl(f); + return integrate_ode_adams(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index c3877bdb875..3b5bf9eef29 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_BDF_HPP #include +#include #include #include #include @@ -14,8 +15,8 @@ namespace math { * @deprecated use ode_bdf */ template -std::vector>> + typename T_ts, require_stan_closure_t* = nullptr> +std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +30,7 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); @@ -37,6 +38,22 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, return y_converted; } +template * = nullptr> +std::vector>> +integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, + const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, + double relative_tolerance = 1e-10, + double absolute_tolerance = 1e-10, int max_num_steps = 1e8) { + closure_adapter cl(f); + return integrate_ode_bdf(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index f02f4695ec2..9d289de7e1e 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_FUNCTOR_ODE_ADAMS_HPP #define STAN_MATH_REV_FUNCTOR_ODE_ADAMS_HPP +#include #include #include #include @@ -45,8 +46,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template -std::vector, + typename... T_Args, require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_adams_tol_impl(const char* function_name, const F& f, const Eigen::Matrix& y0, @@ -61,6 +62,19 @@ ode_adams_tol_impl(const char* function_name, const F& f, return integrator(); } +template * = nullptr> +auto ode_adams_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_adams_tol_impl(function_name, cl, y0, t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Adams-Moulton solver from @@ -96,8 +110,8 @@ ode_adams_tol_impl(const char* function_name, const F& f, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_adams_tol(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, double relative_tolerance, double absolute_tolerance, @@ -139,8 +153,8 @@ ode_adams_tol(const F& f, const Eigen::Matrix& y0, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_adams(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, const T_Args&... args) { diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index 26f348352f9..b6348c91f3c 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_FUNCTOR_ODE_BDF_HPP #define STAN_MATH_REV_FUNCTOR_ODE_BDF_HPP +#include #include #include #include @@ -45,8 +46,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template -std::vector, + typename... T_Args, require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_bdf_tol_impl(const char* function_name, const F& f, const Eigen::Matrix& y0, @@ -61,6 +62,19 @@ ode_bdf_tol_impl(const char* function_name, const F& f, return integrator(); } +template * = nullptr> +auto ode_bdf_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_bdf_tol_impl(function_name, cl, y0, t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the stiff backward differentiation formula @@ -96,8 +110,8 @@ ode_bdf_tol_impl(const char* function_name, const F& f, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_bdf_tol(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, double relative_tolerance, double absolute_tolerance, @@ -139,8 +153,8 @@ ode_bdf_tol(const F& f, const Eigen::Matrix& y0, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_bdf(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, const T_Args&... args) { diff --git a/stan/math/rev/functor/ode_store_sensitivities.hpp b/stan/math/rev/functor/ode_store_sensitivities.hpp index 8a93ff846df..39e005adf0b 100644 --- a/stan/math/rev/functor/ode_store_sensitivities.hpp +++ b/stan/math/rev/functor/ode_store_sensitivities.hpp @@ -31,7 +31,7 @@ namespace math { */ template , T_y0_t0, T_t0, T_t, scalar_type_t...>* = nullptr> Eigen::Matrix ode_store_sensitivities( const F& f, const std::vector& coupled_state, @@ -39,7 +39,7 @@ Eigen::Matrix ode_store_sensitivities( const T_t& t, std::ostream* msgs, const Args&... args) { const size_t N = y0.size(); const size_t num_y0_vars = count_vars(y0); - const size_t num_args_vars = count_vars(args...); + const size_t num_args_vars = count_vars(f, args...); const size_t num_t0_vars = count_vars(t0); const size_t num_t_vars = count_vars(t); Eigen::Matrix yt(N); @@ -51,12 +51,12 @@ Eigen::Matrix ode_store_sensitivities( Eigen::VectorXd f_y_t; if (is_var::value) - f_y_t = f(value_of(t), y, msgs, eval(value_of(args))...); + f_y_t = value_of(f)(msgs, value_of(t), y, eval(value_of(args))...); Eigen::VectorXd f_y0_t0; if (is_var::value) - f_y0_t0 - = f(value_of(t0), eval(value_of(y0)), msgs, eval(value_of(args))...); + f_y0_t0 = value_of(f)(msgs, value_of(t0), eval(value_of(y0)), + eval(value_of(args))...); const size_t total_vars = num_y0_vars + num_args_vars + num_t0_vars + num_t_vars; @@ -64,7 +64,7 @@ Eigen::Matrix ode_store_sensitivities( vari** varis = ChainableStack::instance_->memalloc_.alloc_array(total_vars); - save_varis(varis, y0, args..., t0, t); + save_varis(varis, y0, f, args..., t0, t); // memory for a column major jacobian double* jacobian_mem diff --git a/test/unit/math/prim/functor/coupled_ode_system_test.cpp b/test/unit/math/prim/functor/coupled_ode_system_test.cpp index 679d2b33b7e..b67571ff935 100644 --- a/test/unit/math/prim/functor/coupled_ode_system_test.cpp +++ b/test/unit/math/prim/functor/coupled_ode_system_test.cpp @@ -15,7 +15,9 @@ struct StanMathCoupledOdeSystem : public ::testing::Test { TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { using stan::math::coupled_ode_system; - mock_ode_functor base_ode; + using adapted_ode_functor + = stan::math::internal::ode_closure_adapter; + adapted_ode_functor base_ode{mock_ode_functor()}; const int N = 3; const int M = 4; @@ -28,7 +30,7 @@ TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { for (int m = 0; m < M; m++) theta_d[m] = 10 * (m + 1); - coupled_ode_system, + coupled_ode_system, std::vector, std::vector> coupled_system_dd(base_ode, y0_d, &msgs, theta_d, x, x_int); @@ -43,7 +45,9 @@ TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { TEST_F(StanMathCoupledOdeSystem, size) { using stan::math::coupled_ode_system; - mock_ode_functor base_ode; + using adapted_ode_functor + = stan::math::internal::ode_closure_adapter; + adapted_ode_functor base_ode{mock_ode_functor()}; const int N = 3; const int M = 4; @@ -51,7 +55,7 @@ TEST_F(StanMathCoupledOdeSystem, size) { Eigen::VectorXd y0_d(N); std::vector theta_d(M, 0.0); - coupled_ode_system + coupled_ode_system coupled_system_dd(base_ode, y0_d, &msgs, 1, 1.0, y0_d); EXPECT_EQ(N, coupled_system_dd.size()); @@ -59,18 +63,21 @@ TEST_F(StanMathCoupledOdeSystem, size) { TEST_F(StanMathCoupledOdeSystem, recover_exception) { using stan::math::coupled_ode_system; + using adapted_ode_functor = stan::math::internal::ode_closure_adapter< + mock_throwing_ode_functor>; std::string message = "ode throws"; const int N = 3; const int M = 4; - mock_throwing_ode_functor throwing_ode(message); + adapted_ode_functor throwing_ode{ + mock_throwing_ode_functor(message)}; Eigen::VectorXd y0_d(N); std::vector theta_v(M); - coupled_ode_system, double, - std::vector, std::vector, std::vector> + coupled_ode_system, + std::vector, std::vector> coupled_system_dd(throwing_ode, y0_d, &msgs, theta_v, x, x_int); std::vector y(3); diff --git a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp index 1d7f484a356..1d6fb20bf83 100644 --- a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -5,9 +5,10 @@ #include TEST(StanMath, check_values) { - harm_osc_ode_data_fun harm_osc; - stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> + stan::math::closure_adapter harm_osc{ + harm_osc_ode_data_fun()}; + stan::math::internal::integrate_ode_std_vector_interface_adapter harm_osc_adapted(harm_osc); std::vector theta = {0.15}; @@ -19,9 +20,9 @@ TEST(StanMath, check_values) { double t = 1.0; Eigen::VectorXd out1 - = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); + = stan::math::to_vector(harm_osc(nullptr, t, y, theta, x, x_int)); Eigen::VectorXd out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); EXPECT_MATRIX_FLOAT_EQ(out1, out2); } diff --git a/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp b/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp index acef36e6884..e3db0a860e3 100644 --- a/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp +++ b/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp @@ -6,7 +6,7 @@ #include TEST(MathPrim, ode_store_sensitivities) { - mock_ode_functor base_ode; + stan::math::closure_adapter base_ode{mock_ode_functor()}; size_t N = 5; diff --git a/test/unit/math/rev/functor/coupled_ode_system_test.cpp b/test/unit/math/rev/functor/coupled_ode_system_test.cpp index e513a20a3c3..62f5e86062d 100644 --- a/test/unit/math/rev/functor/coupled_ode_system_test.cpp +++ b/test/unit/math/rev/functor/coupled_ode_system_test.cpp @@ -17,11 +17,13 @@ struct StanAgradRevOde : public ::testing::Test { // ******************** DV **************************** TEST_F(StanAgradRevOde, coupled_ode_system_dv) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::vector theta; std::vector z0; @@ -61,7 +63,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_dv) { TEST_F(StanAgradRevOde, initial_state_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -91,7 +94,8 @@ TEST_F(StanAgradRevOde, initial_state_dv) { TEST_F(StanAgradRevOde, size_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -110,7 +114,8 @@ TEST_F(StanAgradRevOde, size_dv) { TEST_F(StanAgradRevOde, memory_recovery_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -135,6 +140,7 @@ TEST_F(StanAgradRevOde, memory_recovery_dv) { TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -145,7 +151,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::VectorXd y0_d = Eigen::VectorXd::Zero(N); std::vector theta_v(M, 0.0); @@ -169,11 +176,13 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { TEST_F(StanAgradRevOde, coupled_ode_system_vd) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::vector theta; std::vector z0; @@ -220,7 +229,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vd) { TEST_F(StanAgradRevOde, initial_state_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -251,7 +261,8 @@ TEST_F(StanAgradRevOde, initial_state_vd) { TEST_F(StanAgradRevOde, size_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -271,7 +282,8 @@ TEST_F(StanAgradRevOde, size_vd) { TEST_F(StanAgradRevOde, memory_recovery_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -297,6 +309,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vd) { TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -307,7 +320,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::Matrix y0_v = Eigen::VectorXd::Zero(N).template cast(); @@ -332,6 +346,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { TEST_F(StanAgradRevOde, coupled_ode_system_vv) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; @@ -346,7 +361,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { std::vector theta_var(1); theta_var[0] = 0.15; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::size_t stack_size = stan::math::nested_size(); @@ -377,7 +393,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { theta_double[0] = 0.15; Eigen::VectorXd dy_dt_base - = harm_osc(0.0, y0_double, &msgs, theta_double, x, x_int); + = harm_osc(&msgs, 0.0, y0_double, theta_double, x, x_int); EXPECT_FLOAT_EQ(dy_dt_base[0], dz_dt[0]); EXPECT_FLOAT_EQ(dy_dt_base[1], dz_dt[1]); @@ -392,7 +408,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { TEST_F(StanAgradRevOde, initial_state_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -424,7 +441,8 @@ TEST_F(StanAgradRevOde, initial_state_vv) { TEST_F(StanAgradRevOde, size_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -445,7 +463,8 @@ TEST_F(StanAgradRevOde, size_vv) { TEST_F(StanAgradRevOde, memory_recovery_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -472,6 +491,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vv) { TEST_F(StanAgradRevOde, memory_recovery_exception_vv) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -482,7 +502,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vv) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::Matrix y0_v = Eigen::VectorXd::Zero(N).template cast(); @@ -539,6 +560,7 @@ struct ayt { TEST_F(StanAgradRevOde, coupled_ode_system_var) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -547,10 +569,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_var) { double a = 1.3; var av = a; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -572,6 +594,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_var) { TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -579,10 +602,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { std::vector av = {1.3}; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -604,6 +627,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -612,10 +636,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vector) { Eigen::Matrix av(1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -637,6 +661,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -645,10 +670,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { Eigen::Matrix av(1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -670,6 +695,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -678,10 +704,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { Eigen::Matrix av(1, 1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -703,6 +729,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { TEST_F(StanAgradRevOde, coupled_ode_system_extra_args) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -720,11 +747,11 @@ TEST_F(StanAgradRevOde, coupled_ode_system_extra_args) { Eigen::MatrixXd e6(1, 1); e6 << 0.1; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system + coupled_ode_system system(func, y0v, &msgs, av, e1, e2, e3, e4, e5, e6); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; diff --git a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp index f8611f7f57a..87512e58fce 100644 --- a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -8,8 +8,8 @@ TEST(StanMathRev, vd) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -22,7 +22,7 @@ TEST(StanMathRev, vd) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(theta.size()); @@ -44,8 +44,8 @@ TEST(StanMathRev, dv) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -58,7 +58,7 @@ TEST(StanMathRev, dv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(y.size()); @@ -80,8 +80,8 @@ TEST(StanMathRev, vv) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -94,7 +94,7 @@ TEST(StanMathRev, vv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs_theta_1(theta.size()); diff --git a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp index a1665e53e35..6d021038b06 100644 --- a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp +++ b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp @@ -265,6 +265,86 @@ TEST(StanMathOde_ode_rk45, scalar_std_vector_args) { EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); } +TEST(StanMathOde_ode_rk45, closure_var) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [&](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); + EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, closure_double) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, higher_order) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + auto wrapper + = [](const auto& t, const auto& y, std::ostream* msgs, const auto& fa, + const auto& b) { return fa(msgs, t, y, b); }; + + var output = stan::math::ode_rk45(wrapper, y0, t0, ts, nullptr, f, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + TEST(StanMathOde_ode_rk45, std_vector_std_vector_args) { using stan::math::var; diff --git a/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp b/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp index b49028f03e9..5c15f33ac88 100644 --- a/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp +++ b/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp @@ -61,7 +61,7 @@ TEST(AgradRev, ode_store_sensitivities) { double a = 1.3; var av = a; - ayt func; + stan::math::internal::ode_closure_adapter func{ayt()}; double t0 = 0.5; double t = 0.75; @@ -110,7 +110,7 @@ TEST(AgradRev, ode_store_sensitivities_matrix) { Eigen::Matrix av(1, 1); av(0, 0) = a; - aytm func; + stan::math::internal::ode_closure_adapter func{aytm()}; double t0 = 0.5; double t = 0.75; From c0eb7cd4027634a2df47473d2a733653108034d2 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 22 Sep 2020 15:22:20 +0000 Subject: [PATCH 2/5] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/functor/coupled_ode_system.hpp | 4 ++-- stan/math/rev/core/accumulate_adjoints.hpp | 6 ++---- stan/math/rev/core/count_vars.hpp | 6 ++---- stan/math/rev/core/save_varis.hpp | 6 ++---- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/stan/math/prim/functor/coupled_ode_system.hpp b/stan/math/prim/functor/coupled_ode_system.hpp index 166aa21a427..c5b9a58adb2 100644 --- a/stan/math/prim/functor/coupled_ode_system.hpp +++ b/stan/math/prim/functor/coupled_ode_system.hpp @@ -110,8 +110,8 @@ struct coupled_ode_system const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) : coupled_ode_system_impl< - std::is_arithmetic>::value, F, T_y0, - Args...>(f, y0, msgs, args...) {} + std::is_arithmetic>::value, F, T_y0, + Args...>(f, y0, msgs, args...) {} }; } // namespace math diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index 670618dfebb..1a95dba46af 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -30,8 +30,7 @@ template * = nullptr, inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); template * = nullptr, - require_not_st_arithmetic* = nullptr, - typename... Pargs> + require_not_st_arithmetic* = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args); template * = nullptr, @@ -140,8 +139,7 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { * @param args Further args to accumulate over * @return Final position of adjoint storage pointer */ -template *, - require_not_st_arithmetic*, +template *, require_not_st_arithmetic*, typename... Pargs> inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { return accumulate_adjoints(f.accumulate_adjoints__(dest), diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index d46707ab2b0..78bcdcef699 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -30,8 +30,7 @@ template inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); template * = nullptr, - require_not_st_arithmetic* = nullptr, - typename... Pargs> + require_not_st_arithmetic* = nullptr, typename... Pargs> inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args); template >* = nullptr, @@ -128,8 +127,7 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { * @param[in] args objects to be forwarded to recursive call of * `count_vars_impl` */ -template *, - require_not_st_arithmetic*, +template *, require_not_st_arithmetic*, typename... Pargs> inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { return count_vars_impl(count + f.vars_count__, std::forward(args)...); diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index d06c75ca693..3c66e0fbf69 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -30,8 +30,7 @@ template * = nullptr, inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args); template * = nullptr, - require_not_st_arithmetic* = nullptr, - typename... Pargs> + require_not_st_arithmetic* = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, F& f, Pargs&&... args); template * = nullptr, @@ -136,8 +135,7 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { * @param[in] args Additional arguments to have their varis saved * @return Final position of dest pointer */ -template *, - require_not_st_arithmetic*, +template *, require_not_st_arithmetic*, typename... Pargs> inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { return save_varis(f.save_varis__(dest), std::forward(args)...); From ed0c7ffac21e5288f0f5e245137a82aad5b419d4 Mon Sep 17 00:00:00 2001 From: Ben Date: Mon, 16 Nov 2020 14:58:20 -0500 Subject: [PATCH 3/5] Worked on reduce_sum a bit (Issue #2197) --- stan/math/prim/functor/reduce_sum.hpp | 40 ++++++--- stan/math/prim/functor/reduce_sum_static.hpp | 25 ++++-- stan/math/rev/functor/reduce_sum.hpp | 55 ++++++++---- .../rev/functor/reduce_sum_closure_test.cpp | 88 +++++++++++++++++++ 4 files changed, 173 insertions(+), 35 deletions(-) create mode 100644 test/unit/math/rev/functor/reduce_sum_closure_test.cpp diff --git a/stan/math/prim/functor/reduce_sum.hpp b/stan/math/prim/functor/reduce_sum.hpp index 927189137c0..cbb5471460e 100644 --- a/stan/math/prim/functor/reduce_sum.hpp +++ b/stan/math/prim/functor/reduce_sum.hpp @@ -45,12 +45,14 @@ struct reduce_sum_impl, struct recursive_reducer { Vec vmapped_; std::ostream* msgs_; + const ReduceFunction& f_; std::tuple args_tuple_; return_type_t sum_{0.0}; - recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args) + recursive_reducer(Vec&& vmapped, std::ostream* msgs, + const ReduceFunction& f, Args&&... args) : vmapped_(std::forward(vmapped)), - msgs_(msgs), + msgs_(msgs), f_(f), args_tuple_(std::forward(args)...) {} /** @@ -61,7 +63,7 @@ struct reduce_sum_impl, */ recursive_reducer(recursive_reducer& other, tbb::split) : vmapped_(other.vmapped_), - msgs_(other.msgs_), + msgs_(other.msgs_), f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -85,8 +87,8 @@ struct reduce_sum_impl, sum_ += apply( [&](auto&&... args) { - return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, msgs_, - args...); + return f_(msgs_, sub_slice, r.begin(), r.end() - 1, + args...); }, args_tuple_); } @@ -143,13 +145,14 @@ struct reduce_sum_impl, */ inline ReturnType operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, std::ostream* msgs, + const ReduceFunction& f, Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { return 0.0; } recursive_reducer worker(std::forward(vmapped), msgs, - std::forward(args)...); + f, std::forward(args)...); if (auto_partitioning) { tbb::parallel_reduce( @@ -192,28 +195,41 @@ struct reduce_sum_impl, * @return Sum of terms */ template , typename... Args> + typename = require_vector_like_t, + require_stan_closure_t* = nullptr, + typename... Args> inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, - Args&&... args) { - using return_type = return_type_t; + const ReduceFunction& f, Args&&... args) { + using return_type = return_type_t; check_positive("reduce_sum", "grainsize", grainsize); #ifdef STAN_THREADS return internal::reduce_sum_impl()(std::forward(vmapped), true, - grainsize, msgs, + grainsize, msgs, f, std::forward(args)...); #else if (vmapped.empty()) { return return_type(0.0); } - return ReduceFunction()(std::forward(vmapped), 0, vmapped.size() - 1, - msgs, std::forward(args)...); + return f(msgs, std::forward(vmapped), 0, vmapped.size() - 1, + std::forward(args)...); #endif } +template , + require_not_stan_closure_t* = nullptr, + typename... Args> +inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, + Args&&... args) { + ReduceFunction f; + closure_adapter cl(f); + return reduce_sum(vmapped, grainsize, msgs, cl, args...); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/reduce_sum_static.hpp b/stan/math/prim/functor/reduce_sum_static.hpp index 235a6d8c1b6..1f3e1fd52ec 100644 --- a/stan/math/prim/functor/reduce_sum_static.hpp +++ b/stan/math/prim/functor/reduce_sum_static.hpp @@ -41,28 +41,41 @@ namespace math { * @return Sum of terms */ template , typename... Args> + typename = require_vector_like_t, + require_stan_closure_t* = nullptr, + typename... Args> auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, - Args&&... args) { - using return_type = return_type_t; + const ReduceFunction& f, Args&&... args) { + using return_type = return_type_t; check_positive("reduce_sum", "grainsize", grainsize); #ifdef STAN_THREADS return internal::reduce_sum_impl()(std::forward(vmapped), false, - grainsize, msgs, + grainsize, msgs, f, std::forward(args)...); #else if (vmapped.empty()) { return return_type(0); } - return ReduceFunction()(std::forward(vmapped), 0, vmapped.size() - 1, - msgs, std::forward(args)...); + return f(std::forward(vmapped), 0, vmapped.size() - 1, + msgs, std::forward(args)...); #endif } +template , + require_not_stan_closure_t* = nullptr, + typename... Args> +auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, + Args&&... args) { + ReduceFunction f; + internal::ode_closure_adapter cl(f); + return reduce_sum_static(vmapped, grainsize, msgs, cl, args...); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/functor/reduce_sum.hpp b/stan/math/rev/functor/reduce_sum.hpp index f0931060b7c..de5161e0613 100644 --- a/stan/math/rev/functor/reduce_sum.hpp +++ b/stan/math/rev/functor/reduce_sum.hpp @@ -38,23 +38,28 @@ struct reduce_sum_impl, ReturnType, */ struct recursive_reducer { const size_t num_vars_per_term_; + const size_t num_vars_closure_; // Number of vars in the closure const size_t num_vars_shared_terms_; // Number of vars in shared arguments double* sliced_partials_; // Points to adjoints of the partial calculations Vec vmapped_; std::ostream* msgs_; + const ReduceFunction& f_; std::tuple args_tuple_; double sum_{0.0}; Eigen::VectorXd args_adjoints_{0}; template - recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms, + recursive_reducer(size_t num_vars_per_term, + size_t num_vars_closure, + size_t num_vars_shared_terms, double* sliced_partials, VecT&& vmapped, - std::ostream* msgs, ArgsT&&... args) + std::ostream* msgs, const ReduceFunction& f, ArgsT&&... args) : num_vars_per_term_(num_vars_per_term), + num_vars_closure_(num_vars_closure), num_vars_shared_terms_(num_vars_shared_terms), sliced_partials_(sliced_partials), vmapped_(std::forward(vmapped)), - msgs_(msgs), + msgs_(msgs), f_(f), args_tuple_(std::forward(args)...) {} /* @@ -65,10 +70,11 @@ struct reduce_sum_impl, ReturnType, */ recursive_reducer(recursive_reducer& other, tbb::split) : num_vars_per_term_(other.num_vars_per_term_), + num_vars_closure_(other.num_vars_closure_), num_vars_shared_terms_(other.num_vars_shared_terms_), sliced_partials_(other.sliced_partials_), vmapped_(other.vmapped_), - msgs_(other.msgs_), + msgs_(other.msgs_), f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -90,7 +96,8 @@ struct reduce_sum_impl, ReturnType, } if (args_adjoints_.size() == 0) { - args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_); + args_adjoints_ = Eigen::VectorXd::Zero(num_vars_closure_ + + num_vars_shared_terms_); } // Initialize nested autodiff stack @@ -104,6 +111,9 @@ struct reduce_sum_impl, ReturnType, local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i])); } + // Create a copy of the functor + auto f_local_copy = deep_copy_vars(f_); + // Create nested autodiff copies of all shared arguments that do not point // back to main autodiff stack auto args_tuple_local_copy = apply( @@ -116,8 +126,8 @@ struct reduce_sum_impl, ReturnType, // Perform calculation var sub_sum_v = apply( [&](auto&&... args) { - return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1, - msgs_, args...); + return f_local_copy(msgs_, local_sub_slice, r.begin(), r.end() - 1, + args...); }, args_tuple_local_copy); @@ -131,10 +141,13 @@ struct reduce_sum_impl, ReturnType, accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_, std::move(local_sub_slice)); + // Accumulate adjoints of closure arguments + accumulate_adjoints(args_adjoints_.data(), f_local_copy); + // Accumulate adjoints of shared_arguments apply( [&](auto&&... args) { - accumulate_adjoints(args_adjoints_.data(), + accumulate_adjoints(args_adjoints_.data() + num_vars_closure_, std::forward(args)...); }, std::move(args_tuple_local_copy)); @@ -197,7 +210,8 @@ struct reduce_sum_impl, ReturnType, * @return Summation of all terms */ inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, - std::ostream* msgs, Args&&... args) const { + std::ostream* msgs, const ReduceFunction& f, + Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { @@ -206,22 +220,27 @@ struct reduce_sum_impl, ReturnType, const std::size_t num_vars_per_term = count_vars(vmapped[0]); const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term; + const std::size_t num_vars_closure = count_vars(f); const std::size_t num_vars_shared_terms = count_vars(args...); vari** varis = ChainableStack::instance_->memalloc_.alloc_array( - num_vars_sliced_terms + num_vars_shared_terms); + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms); double* partials = ChainableStack::instance_->memalloc_.alloc_array( - num_vars_sliced_terms + num_vars_shared_terms); + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms); save_varis(varis, vmapped); - save_varis(varis + num_vars_sliced_terms, args...); + save_varis(varis + num_vars_sliced_terms, f); + save_varis(varis + num_vars_sliced_terms + num_vars_closure, args...); for (size_t i = 0; i < num_vars_sliced_terms; ++i) { partials[i] = 0.0; } - recursive_reducer worker(num_vars_per_term, num_vars_shared_terms, partials, - std::forward(vmapped), msgs, + recursive_reducer worker(num_vars_per_term, + num_vars_closure, + num_vars_shared_terms, + partials, + std::forward(vmapped), msgs, f, std::forward(args)...); if (auto_partitioning) { @@ -234,13 +253,15 @@ struct reduce_sum_impl, ReturnType, partitioner); } - for (size_t i = 0; i < num_vars_shared_terms; ++i) { + for (size_t i = 0; i < num_vars_closure + num_vars_shared_terms; ++i) { partials[num_vars_sliced_terms + i] = worker.args_adjoints_(i); } return var(new precomputed_gradients_vari( - worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis, - partials)); + worker.sum_, num_vars_sliced_terms + + num_vars_closure + + num_vars_shared_terms, + varis, partials)); } }; } // namespace internal diff --git a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp new file mode 100644 index 00000000000..0ae82ccd9c4 --- /dev/null +++ b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp @@ -0,0 +1,88 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +struct closure_adapter { + template + auto operator()(const T_slice& subslice, std::size_t start, + std::size_t end, std::ostream* msgs, + const F& f, Args... args) { + return f(msgs, subslice, start, end, args...); + } +}; + +TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { + using stan::math::var; + using stan::math::from_lambda; + using stan::math::test::get_new_msg; + + double lambda_d = 10.0; + const std::size_t groups = 10; + const std::size_t elems_per_group = 1000; + const std::size_t elems = groups * elems_per_group; + + std::vector data(elems); + std::vector gidx(elems); + + for (std::size_t i = 0; i != elems; ++i) { + data[i] = i; + gidx[i] = i / elems_per_group; + } + + std::vector vlambda_v; + + for (std::size_t i = 0; i != groups; ++i) + vlambda_v.push_back(i + 0.2); + + var lambda_v = vlambda_v[0]; + + auto functor = from_lambda( + [](auto& lambda, auto& slice, std::size_t start, std::size_t end, auto& gidx, std::ostream * msgs) { + const std::size_t num_terms = end - start + 1; + std::decay_t lambda_slice(num_terms); + for (std::size_t i = 0; i != num_terms; ++i) + lambda_slice[i] = lambda[gidx[start + i]]; + return stan::math::poisson_lpmf(slice, lambda_slice); + }, vlambda_v); + + var poisson_lpdf = stan::math::reduce_sum( + data, 5, get_new_msg(), functor, gidx); + + std::vector vref_lambda_v; + for (std::size_t i = 0; i != elems; ++i) { + vref_lambda_v.push_back(vlambda_v[gidx[i]]); + } + var lambda_ref = vlambda_v[0]; + var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v); + + EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref)); + + stan::math::grad(poisson_lpdf_ref.vi_); + const double lambda_ref_adj = lambda_ref.adj(); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf.vi_); + const double lambda_adj = lambda_v.adj(); + + EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj) + << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl + << "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl + << "value of poisson lpdf : " << poisson_lpdf.val() << std::endl + << "gradient wrt to lambda: " << lambda_adj << std::endl; + + var poisson_lpdf_static + = stan::math::reduce_sum_static(data, 5, get_new_msg(), functor, gidx); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf_static.vi_); + const double lambda_adj_static = lambda_v.adj(); + EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj); + stan::math::recover_memory(); + + stan::math::recover_memory(); +} From be52b7bd42c117ec054701eb6b0af11926547644 Mon Sep 17 00:00:00 2001 From: Ben Date: Mon, 16 Nov 2020 15:27:18 -0500 Subject: [PATCH 4/5] Added reduce_sum closure adapter (Issue #2197) --- stan/math/prim/functor/closure_adapter.hpp | 25 ++++++++++++++++++++ stan/math/prim/functor/reduce_sum.hpp | 2 +- stan/math/prim/functor/reduce_sum_static.hpp | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index 47c10796007..ade815f7743 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -96,6 +96,31 @@ struct ode_closure_adapter { } }; +template +struct reduce_sum_closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = reduce_sum_closure_adapter; + static const size_t vars_count__ = 0; + const F f_; + + explicit reduce_sum_closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, const std::vector& sub_slice, + std::size_t start, std::size_t end, + Args... args) const { + return f_(sub_slice, start, end, msgs, args...); + } + auto value_of__() const { return reduce_sum_closure_adapter(f_); } + auto deep_copy_vars__() const { return reduce_sum_closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + } // namespace internal } // namespace math diff --git a/stan/math/prim/functor/reduce_sum.hpp b/stan/math/prim/functor/reduce_sum.hpp index cbb5471460e..7fe18fcfa22 100644 --- a/stan/math/prim/functor/reduce_sum.hpp +++ b/stan/math/prim/functor/reduce_sum.hpp @@ -226,7 +226,7 @@ template cl(f); + internal::reduce_sum_closure_adapter cl(f); return reduce_sum(vmapped, grainsize, msgs, cl, args...); } diff --git a/stan/math/prim/functor/reduce_sum_static.hpp b/stan/math/prim/functor/reduce_sum_static.hpp index 1f3e1fd52ec..77ac5066112 100644 --- a/stan/math/prim/functor/reduce_sum_static.hpp +++ b/stan/math/prim/functor/reduce_sum_static.hpp @@ -72,7 +72,7 @@ template cl(f); + internal::reduce_sum_closure_adapter cl(f); return reduce_sum_static(vmapped, grainsize, msgs, cl, args...); } From d6f3dd4522c6424927dd0562a32b9d4d09b768e2 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 18 Nov 2020 08:52:20 +0000 Subject: [PATCH 5/5] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/functor/closure_adapter.hpp | 3 +- stan/math/prim/functor/reduce_sum.hpp | 27 +++++------ stan/math/prim/functor/reduce_sum_static.hpp | 11 ++--- stan/math/rev/functor/reduce_sum.hpp | 47 +++++++++---------- .../rev/functor/reduce_sum_closure_test.cpp | 31 ++++++------ 5 files changed, 57 insertions(+), 62 deletions(-) diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp index ade815f7743..7df249ccfc6 100644 --- a/stan/math/prim/functor/closure_adapter.hpp +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -107,8 +107,7 @@ struct reduce_sum_closure_adapter { template auto operator()(std::ostream* msgs, const std::vector& sub_slice, - std::size_t start, std::size_t end, - Args... args) const { + std::size_t start, std::size_t end, Args... args) const { return f_(sub_slice, start, end, msgs, args...); } auto value_of__() const { return reduce_sum_closure_adapter(f_); } diff --git a/stan/math/prim/functor/reduce_sum.hpp b/stan/math/prim/functor/reduce_sum.hpp index 7fe18fcfa22..7fe632abf8b 100644 --- a/stan/math/prim/functor/reduce_sum.hpp +++ b/stan/math/prim/functor/reduce_sum.hpp @@ -50,9 +50,10 @@ struct reduce_sum_impl, return_type_t sum_{0.0}; recursive_reducer(Vec&& vmapped, std::ostream* msgs, - const ReduceFunction& f, Args&&... args) + const ReduceFunction& f, Args&&... args) : vmapped_(std::forward(vmapped)), - msgs_(msgs), f_(f), + msgs_(msgs), + f_(f), args_tuple_(std::forward(args)...) {} /** @@ -63,7 +64,8 @@ struct reduce_sum_impl, */ recursive_reducer(recursive_reducer& other, tbb::split) : vmapped_(other.vmapped_), - msgs_(other.msgs_), f_(other.f_), + msgs_(other.msgs_), + f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -87,8 +89,7 @@ struct reduce_sum_impl, sum_ += apply( [&](auto&&... args) { - return f_(msgs_, sub_slice, r.begin(), r.end() - 1, - args...); + return f_(msgs_, sub_slice, r.begin(), r.end() - 1, args...); }, args_tuple_); } @@ -145,14 +146,13 @@ struct reduce_sum_impl, */ inline ReturnType operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, std::ostream* msgs, - const ReduceFunction& f, - Args&&... args) const { + const ReduceFunction& f, Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { return 0.0; } - recursive_reducer worker(std::forward(vmapped), msgs, - f, std::forward(args)...); + recursive_reducer worker(std::forward(vmapped), msgs, f, + std::forward(args)...); if (auto_partitioning) { tbb::parallel_reduce( @@ -196,8 +196,7 @@ struct reduce_sum_impl, */ template , - require_stan_closure_t* = nullptr, - typename... Args> + require_stan_closure_t* = nullptr, typename... Args> inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, const ReduceFunction& f, Args&&... args) { using return_type = return_type_t; @@ -215,14 +214,14 @@ inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, } return f(msgs, std::forward(vmapped), 0, vmapped.size() - 1, - std::forward(args)...); + std::forward(args)...); #endif } template , - require_not_stan_closure_t* = nullptr, - typename... Args> + require_not_stan_closure_t* = nullptr, + typename... Args> inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, Args&&... args) { ReduceFunction f; diff --git a/stan/math/prim/functor/reduce_sum_static.hpp b/stan/math/prim/functor/reduce_sum_static.hpp index 77ac5066112..07fb07c00b8 100644 --- a/stan/math/prim/functor/reduce_sum_static.hpp +++ b/stan/math/prim/functor/reduce_sum_static.hpp @@ -42,8 +42,7 @@ namespace math { */ template , - require_stan_closure_t* = nullptr, - typename... Args> + require_stan_closure_t* = nullptr, typename... Args> auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, const ReduceFunction& f, Args&&... args) { using return_type = return_type_t; @@ -60,15 +59,15 @@ auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, return return_type(0); } - return f(std::forward(vmapped), 0, vmapped.size() - 1, - msgs, std::forward(args)...); + return f(std::forward(vmapped), 0, vmapped.size() - 1, msgs, + std::forward(args)...); #endif } template , - require_not_stan_closure_t* = nullptr, - typename... Args> + require_not_stan_closure_t* = nullptr, + typename... Args> auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, Args&&... args) { ReduceFunction f; diff --git a/stan/math/rev/functor/reduce_sum.hpp b/stan/math/rev/functor/reduce_sum.hpp index de5161e0613..dc4d17afcb2 100644 --- a/stan/math/rev/functor/reduce_sum.hpp +++ b/stan/math/rev/functor/reduce_sum.hpp @@ -38,7 +38,7 @@ struct reduce_sum_impl, ReturnType, */ struct recursive_reducer { const size_t num_vars_per_term_; - const size_t num_vars_closure_; // Number of vars in the closure + const size_t num_vars_closure_; // Number of vars in the closure const size_t num_vars_shared_terms_; // Number of vars in shared arguments double* sliced_partials_; // Points to adjoints of the partial calculations Vec vmapped_; @@ -49,17 +49,17 @@ struct reduce_sum_impl, ReturnType, Eigen::VectorXd args_adjoints_{0}; template - recursive_reducer(size_t num_vars_per_term, - size_t num_vars_closure, - size_t num_vars_shared_terms, - double* sliced_partials, VecT&& vmapped, - std::ostream* msgs, const ReduceFunction& f, ArgsT&&... args) + recursive_reducer(size_t num_vars_per_term, size_t num_vars_closure, + size_t num_vars_shared_terms, double* sliced_partials, + VecT&& vmapped, std::ostream* msgs, + const ReduceFunction& f, ArgsT&&... args) : num_vars_per_term_(num_vars_per_term), - num_vars_closure_(num_vars_closure), + num_vars_closure_(num_vars_closure), num_vars_shared_terms_(num_vars_shared_terms), sliced_partials_(sliced_partials), vmapped_(std::forward(vmapped)), - msgs_(msgs), f_(f), + msgs_(msgs), + f_(f), args_tuple_(std::forward(args)...) {} /* @@ -70,11 +70,12 @@ struct reduce_sum_impl, ReturnType, */ recursive_reducer(recursive_reducer& other, tbb::split) : num_vars_per_term_(other.num_vars_per_term_), - num_vars_closure_(other.num_vars_closure_), + num_vars_closure_(other.num_vars_closure_), num_vars_shared_terms_(other.num_vars_shared_terms_), sliced_partials_(other.sliced_partials_), vmapped_(other.vmapped_), - msgs_(other.msgs_), f_(other.f_), + msgs_(other.msgs_), + f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -96,8 +97,8 @@ struct reduce_sum_impl, ReturnType, } if (args_adjoints_.size() == 0) { - args_adjoints_ = Eigen::VectorXd::Zero(num_vars_closure_ + - num_vars_shared_terms_); + args_adjoints_ + = Eigen::VectorXd::Zero(num_vars_closure_ + num_vars_shared_terms_); } // Initialize nested autodiff stack @@ -113,7 +114,7 @@ struct reduce_sum_impl, ReturnType, // Create a copy of the functor auto f_local_copy = deep_copy_vars(f_); - + // Create nested autodiff copies of all shared arguments that do not point // back to main autodiff stack auto args_tuple_local_copy = apply( @@ -127,7 +128,7 @@ struct reduce_sum_impl, ReturnType, var sub_sum_v = apply( [&](auto&&... args) { return f_local_copy(msgs_, local_sub_slice, r.begin(), r.end() - 1, - args...); + args...); }, args_tuple_local_copy); @@ -211,7 +212,7 @@ struct reduce_sum_impl, ReturnType, */ inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, std::ostream* msgs, const ReduceFunction& f, - Args&&... args) const { + Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { @@ -236,12 +237,9 @@ struct reduce_sum_impl, ReturnType, partials[i] = 0.0; } - recursive_reducer worker(num_vars_per_term, - num_vars_closure, - num_vars_shared_terms, - partials, - std::forward(vmapped), msgs, f, - std::forward(args)...); + recursive_reducer worker( + num_vars_per_term, num_vars_closure, num_vars_shared_terms, partials, + std::forward(vmapped), msgs, f, std::forward(args)...); if (auto_partitioning) { tbb::parallel_reduce( @@ -258,10 +256,9 @@ struct reduce_sum_impl, ReturnType, } return var(new precomputed_gradients_vari( - worker.sum_, num_vars_sliced_terms + - num_vars_closure + - num_vars_shared_terms, - varis, partials)); + worker.sum_, + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms, varis, + partials)); } }; } // namespace internal diff --git a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp index 0ae82ccd9c4..6720b2e8361 100644 --- a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp +++ b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp @@ -8,17 +8,16 @@ #include struct closure_adapter { - template - auto operator()(const T_slice& subslice, std::size_t start, - std::size_t end, std::ostream* msgs, - const F& f, Args... args) { + template + auto operator()(const T_slice& subslice, std::size_t start, std::size_t end, + std::ostream* msgs, const F& f, Args... args) { return f(msgs, subslice, start, end, args...); } }; TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { - using stan::math::var; using stan::math::from_lambda; + using stan::math::var; using stan::math::test::get_new_msg; double lambda_d = 10.0; @@ -42,16 +41,18 @@ TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { var lambda_v = vlambda_v[0]; auto functor = from_lambda( - [](auto& lambda, auto& slice, std::size_t start, std::size_t end, auto& gidx, std::ostream * msgs) { - const std::size_t num_terms = end - start + 1; - std::decay_t lambda_slice(num_terms); - for (std::size_t i = 0; i != num_terms; ++i) - lambda_slice[i] = lambda[gidx[start + i]]; - return stan::math::poisson_lpmf(slice, lambda_slice); - }, vlambda_v); - - var poisson_lpdf = stan::math::reduce_sum( - data, 5, get_new_msg(), functor, gidx); + [](auto& lambda, auto& slice, std::size_t start, std::size_t end, + auto& gidx, std::ostream* msgs) { + const std::size_t num_terms = end - start + 1; + std::decay_t lambda_slice(num_terms); + for (std::size_t i = 0; i != num_terms; ++i) + lambda_slice[i] = lambda[gidx[start + i]]; + return stan::math::poisson_lpmf(slice, lambda_slice); + }, + vlambda_v); + + var poisson_lpdf + = stan::math::reduce_sum(data, 5, get_new_msg(), functor, gidx); std::vector vref_lambda_v; for (std::size_t i = 0; i != elems; ++i) {