diff --git a/stan/math/prim/err/check_finite.hpp b/stan/math/prim/err/check_finite.hpp index 4479543f09c..34ee1526099 100644 --- a/stan/math/prim/err/check_finite.hpp +++ b/stan/math/prim/err/check_finite.hpp @@ -140,6 +140,9 @@ inline void check_finite(const char* function, const char* name, } } +template * = nullptr> +inline void check_finite(const char* function, const char* name, const T& y) {} + } // namespace math } // namespace stan diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 6f6bced5c64..ac2368ac759 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -60,6 +60,19 @@ inline auto value_of(EigMat&& M) { std::forward(M)); } +/** + * Closures that capture non-arithmetic types have value_of__() method. + * + * @tparam F Input element type + * @param[in] f Input closure + * @return closure + **/ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto value_of(const F& f) { + return f.value_of__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index aab4147aa96..dc6899110fc 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp new file mode 100644 index 00000000000..7df249ccfc6 --- /dev/null +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -0,0 +1,128 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP + +#include +#include + +namespace stan { +namespace math { + +template +struct closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = closure_adapter; + static const size_t vars_count__ = 0; + F f_; + + explicit closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(args..., msgs); + } + auto value_of__() const { return closure_adapter(f_); } + auto deep_copy_vars__() const { return closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct simple_closure { + using captured_scalar_t__ = return_type_t; + using ValueOf__ = simple_closure()))>; + const size_t vars_count__; + F f_; + T s_; + + explicit simple_closure(const F& f, T s) + : f_(f), s_(s), vars_count__(count_vars(s)) {} + + template + auto operator()(std::ostream* msgs, Args... args) const { + return f_(s_, args..., msgs); + } + auto value_of__() const { return ValueOf__(f_, value_of(s_)); } + auto deep_copy_vars__() const { + return simple_closure(f_, deep_copy_vars(s_)); + } + void zero_adjoints__() { zero_adjoints(s_); } + double* accumulate_adjoints__(double* dest) const { + return accumulate_adjoints(dest, s_); + } + template + Vari** save_varis__(Vari** dest) const { + return save_varis(dest, s_); + } +}; + +template +auto from_lambda(F f) { + return closure_adapter(f); +} + +template +auto from_lambda(F f, T a) { + return simple_closure(f, a); +} + +namespace internal { + +template +struct ode_closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = ode_closure_adapter; + static const size_t vars_count__ = 0; + const F f_; + + explicit ode_closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + Args... args) const { + return f_(t, y, msgs, args...); + } + auto value_of__() const { return ode_closure_adapter(f_); } + auto deep_copy_vars__() const { return ode_closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +template +struct reduce_sum_closure_adapter { + using captured_scalar_t__ = double; + using ValueOf__ = reduce_sum_closure_adapter; + static const size_t vars_count__ = 0; + const F f_; + + explicit reduce_sum_closure_adapter(const F& f) : f_(f) {} + + template + auto operator()(std::ostream* msgs, const std::vector& sub_slice, + std::size_t start, std::size_t end, Args... args) const { + return f_(sub_slice, start, end, msgs, args...); + } + auto value_of__() const { return reduce_sum_closure_adapter(f_); } + auto deep_copy_vars__() const { return reduce_sum_closure_adapter(f_); } + void zero_adjoints__() const {} + double* accumulate_adjoints__(double* dest) const { return dest; } + template + Vari** save_varis(Vari** dest) const { + return dest; + } +}; + +} // namespace internal + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/functor/coupled_ode_system.hpp b/stan/math/prim/functor/coupled_ode_system.hpp index a4dbe69603d..c5b9a58adb2 100644 --- a/stan/math/prim/functor/coupled_ode_system.hpp +++ b/stan/math/prim/functor/coupled_ode_system.hpp @@ -68,7 +68,7 @@ struct coupled_ode_system_impl { dz_dt.resize(y.size()); Eigen::VectorXd f_y_t - = apply([&](const Args&... args) { return f_(t, y, msgs_, args...); }, + = apply([&](const Args&... args) { return f_(msgs_, t, y, args...); }, args_tuple_); check_size_match("coupled_ode_system", "dy_dt", f_y_t.size(), "states", @@ -104,13 +104,13 @@ struct coupled_ode_system_impl { template struct coupled_ode_system : public coupled_ode_system_impl< - std::is_arithmetic>::value, F, T_y0, + std::is_arithmetic>::value, F, T_y0, Args...> { coupled_ode_system(const F& f, const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) : coupled_ode_system_impl< - std::is_arithmetic>::value, F, T_y0, + std::is_arithmetic>::value, F, T_y0, Args...>(f, y0, msgs, args...) {} }; diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 1e568d6b1ea..71471b69804 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -14,7 +14,7 @@ namespace math { * @deprecated use ode_rk45 */ template + typename T_ts, require_stan_closure_t* = nullptr> inline auto integrate_ode_rk45( const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -26,7 +26,7 @@ inline auto integrate_ode_rk45( ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; y_converted.reserve(y.size()); for (size_t i = 0; i < y.size(); ++i) @@ -35,6 +35,20 @@ inline auto integrate_ode_rk45( return y_converted; } +template * = nullptr> +inline auto integrate_ode_rk45( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, + double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { + closure_adapter cl(f); + return integrate_ode_rk45(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 9b0178e27b5..2b85a4e0269 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -21,16 +21,34 @@ namespace internal { */ template struct integrate_ode_std_vector_interface_adapter { + using captured_scalar_t__ = typename F::captured_scalar_t__; + using ValueOf__ + = integrate_ode_std_vector_interface_adapter; + const int vars_count__; const F f_; - explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} + explicit integrate_ode_std_vector_interface_adapter(const F& f) + : vars_count__(f.vars_count__), f_(f) {} template - auto operator()(const T0& t, const Eigen::Matrix& y, - std::ostream* msgs, const std::vector& theta, - const std::vector& x, + auto operator()(std::ostream* msgs, const T0& t, + const Eigen::Matrix& y, + const std::vector& theta, const std::vector& x, const std::vector& x_int) const { - return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs)); + return to_vector(f_(msgs, t, to_array_1d(y), theta, x, x_int)); + } + + auto value_of__() const { return ValueOf__(f_.value_of__()); } + auto deep_copy_vars__() const { + return integrate_ode_std_vector_interface_adapter(f_.deep_copy_vars__()); + } + void zero_adjoints__() const { f_.zero_adjoints__(); } + double* accumulate_adjoints__(double* dest) const { + return f_.accumulate_adjoints__(dest); + } + template + Vari** save_varis__(Vari** dest) const { + return f_.save_varis__(dest); } }; diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 8a90bd1e2c7..36ebd14f8c2 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_ODE_RK45_HPP #include +#include #include #include #include @@ -53,8 +54,9 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> -std::vector, + typename... Args, require_eigen_vector_t* = nullptr, + require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, @@ -100,7 +102,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, absolute_tolerance); check_positive(function_name, "max_num_steps", max_num_steps); - using return_t = return_type_t; + using return_t = return_type_t; // creates basic or coupled system by template specializations auto&& coupled_system = apply( [&](const auto&... args_ref) { @@ -158,6 +160,22 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, return y; } +template * = nullptr> +std::vector, + Eigen::Dynamic, 1>> +ode_rk45_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0_arg, T_t0 t0, + const std::vector& ts, double relative_tolerance, + double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_rk45_tol_impl(function_name, cl, y0_arg, t0, ts, + relative_tolerance, absolute_tolerance, + max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Runge-Kutta 45 solver in @@ -196,7 +214,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -242,7 +260,7 @@ ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/functor/ode_store_sensitivities.hpp b/stan/math/prim/functor/ode_store_sensitivities.hpp index ec42a5506e4..9dfd650626f 100644 --- a/stan/math/prim/functor/ode_store_sensitivities.hpp +++ b/stan/math/prim/functor/ode_store_sensitivities.hpp @@ -28,10 +28,10 @@ namespace math { * @param args Extra arguments passed unmodified through to ODE right hand side * @return ODE state */ -template < - typename F, typename T_y0_t0, typename T_t0, typename T_t, typename... Args, - typename - = require_all_arithmetic_t...>> +template , T_y0_t0, T_t0, + T_t, scalar_type_t...>> Eigen::VectorXd ode_store_sensitivities( const F& f, const std::vector& coupled_state, const Eigen::Matrix& y0, T_t0 t0, T_t t, diff --git a/stan/math/prim/functor/reduce_sum.hpp b/stan/math/prim/functor/reduce_sum.hpp index 927189137c0..7fe632abf8b 100644 --- a/stan/math/prim/functor/reduce_sum.hpp +++ b/stan/math/prim/functor/reduce_sum.hpp @@ -45,12 +45,15 @@ struct reduce_sum_impl, struct recursive_reducer { Vec vmapped_; std::ostream* msgs_; + const ReduceFunction& f_; std::tuple args_tuple_; return_type_t sum_{0.0}; - recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args) + recursive_reducer(Vec&& vmapped, std::ostream* msgs, + const ReduceFunction& f, Args&&... args) : vmapped_(std::forward(vmapped)), msgs_(msgs), + f_(f), args_tuple_(std::forward(args)...) {} /** @@ -62,6 +65,7 @@ struct reduce_sum_impl, recursive_reducer(recursive_reducer& other, tbb::split) : vmapped_(other.vmapped_), msgs_(other.msgs_), + f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -85,8 +89,7 @@ struct reduce_sum_impl, sum_ += apply( [&](auto&&... args) { - return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, msgs_, - args...); + return f_(msgs_, sub_slice, r.begin(), r.end() - 1, args...); }, args_tuple_); } @@ -143,12 +146,12 @@ struct reduce_sum_impl, */ inline ReturnType operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, std::ostream* msgs, - Args&&... args) const { + const ReduceFunction& f, Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { return 0.0; } - recursive_reducer worker(std::forward(vmapped), msgs, + recursive_reducer worker(std::forward(vmapped), msgs, f, std::forward(args)...); if (auto_partitioning) { @@ -192,28 +195,40 @@ struct reduce_sum_impl, * @return Sum of terms */ template , typename... Args> + typename = require_vector_like_t, + require_stan_closure_t* = nullptr, typename... Args> inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, - Args&&... args) { - using return_type = return_type_t; + const ReduceFunction& f, Args&&... args) { + using return_type = return_type_t; check_positive("reduce_sum", "grainsize", grainsize); #ifdef STAN_THREADS return internal::reduce_sum_impl()(std::forward(vmapped), true, - grainsize, msgs, + grainsize, msgs, f, std::forward(args)...); #else if (vmapped.empty()) { return return_type(0.0); } - return ReduceFunction()(std::forward(vmapped), 0, vmapped.size() - 1, - msgs, std::forward(args)...); + return f(msgs, std::forward(vmapped), 0, vmapped.size() - 1, + std::forward(args)...); #endif } +template , + require_not_stan_closure_t* = nullptr, + typename... Args> +inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs, + Args&&... args) { + ReduceFunction f; + internal::reduce_sum_closure_adapter cl(f); + return reduce_sum(vmapped, grainsize, msgs, cl, args...); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/reduce_sum_static.hpp b/stan/math/prim/functor/reduce_sum_static.hpp index 235a6d8c1b6..07fb07c00b8 100644 --- a/stan/math/prim/functor/reduce_sum_static.hpp +++ b/stan/math/prim/functor/reduce_sum_static.hpp @@ -41,28 +41,40 @@ namespace math { * @return Sum of terms */ template , typename... Args> + typename = require_vector_like_t, + require_stan_closure_t* = nullptr, typename... Args> auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, - Args&&... args) { - using return_type = return_type_t; + const ReduceFunction& f, Args&&... args) { + using return_type = return_type_t; check_positive("reduce_sum", "grainsize", grainsize); #ifdef STAN_THREADS return internal::reduce_sum_impl()(std::forward(vmapped), false, - grainsize, msgs, + grainsize, msgs, f, std::forward(args)...); #else if (vmapped.empty()) { return return_type(0); } - return ReduceFunction()(std::forward(vmapped), 0, vmapped.size() - 1, - msgs, std::forward(args)...); + return f(std::forward(vmapped), 0, vmapped.size() - 1, msgs, + std::forward(args)...); #endif } +template , + require_not_stan_closure_t* = nullptr, + typename... Args> +auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs, + Args&&... args) { + ReduceFunction f; + internal::reduce_sum_closure_adapter cl(f); + return reduce_sum_static(vmapped, grainsize, msgs, cl, args...); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index 5d3f139cd22..a387a99ecef 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -213,6 +213,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp new file mode 100644 index 00000000000..5e4ebf09987 --- /dev/null +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -0,0 +1,56 @@ +#ifndef STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP +#define STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP + +#include +#include +#include + +#include + +namespace stan { + +/** + * Checks if type is a closure object. + * @tparam The type to check + * @ingroup type_trait + */ +template +struct is_stan_closure : std::false_type {}; + +template +struct is_stan_closure> + : std::true_type {}; + +template +struct scalar_type> { + using type = typename T::captured_scalar_t__; +}; + +STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); + +template +struct fn_return_type { + using type = double; +}; + +template +struct fn_return_type> { + using type = typename T::captured_scalar_t__; +}; + +/** + * Convenience type for the return type of the specified template + * parameters. + * + * @tparam F callable type + * @tparam Ts sequence of types + * @see return_type + * @ingroup type_trait + */ +template +using fn_return_type_t + = return_type_t::type, Args...>; + +} // namespace stan + +#endif diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index e5b27354ebd..1a95dba46af 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -29,6 +29,10 @@ template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); @@ -121,6 +125,27 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { return accumulate_adjoints(dest + x.size(), std::forward(args)...); } +/** + * Accumulate adjoints from f (a closure type containing vars) + * into storage pointed to by dest, + * increment the adjoint storage pointer, + * recursively accumulate the adjoints of the rest of the arguments, + * and return final position of storage pointer. + * + * @tparam F A closure type capturing vars. + * @tparam Pargs Types of remaining arguments + * @param dest Pointer to where adjoints are to be accumulated + * @param f A closure holding vars to accumulate over + * @param args Further args to accumulate over + * @return Final position of adjoint storage pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { + return accumulate_adjoints(f.accumulate_adjoints__(dest), + std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index b0b536a27ab..78bcdcef699 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -29,6 +29,10 @@ inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args); template inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args); + template >* = nullptr, typename... Pargs> inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args); @@ -110,6 +114,25 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { return count_vars_impl(count + 1, std::forward(args)...); } +/** + * Count the number of vars in f (a closure capturing vars), + * add it to the running total, + * count the number of vars in the remaining arguments + * and return the result. + * + * @tparam F A closure type + * @tparam Pargs Types of remaining arguments + * @param[in] count The current count of the number of vars + * @param[in] f A closure holding vars + * @param[in] args objects to be forwarded to recursive call of + * `count_vars_impl` + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { + return count_vars_impl(count + f.vars_count__, std::forward(args)...); +} + /** * Arguments without vars contribute zero to the total number of vars. * diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 5eb1b3d8542..f46d3f0dc4c 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -19,7 +19,7 @@ namespace math { * @param arg For lvalue references this will be passed by reference. * Otherwise it will be moved. */ -template >> +template * = nullptr> inline decltype(auto) deep_copy_vars(Arith&& arg) { return std::forward(arg); } @@ -81,6 +81,19 @@ inline auto deep_copy_vars(EigT&& arg) { .eval(); } +/** + * Copy the vars in f but reallocate new varis for them + * + * @tparam F A closure type + * @param f A closure of vars + * @return A new std::vector of vars + */ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto deep_copy_vars(F&& f) { + return f.deep_copy_vars__(); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index c53a5390539..6e19d54fa53 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -29,6 +29,10 @@ template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args); @@ -118,6 +122,25 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { return save_varis(dest + x.size(), std::forward(args)...); } +/** + * Save the vari pointers in f into the memory pointed to by dest, + * increment the dest storage pointer, + * recursively call save_varis on the rest of the arguments, + * and return the final value of the dest storage pointer. + * + * @tparam F A closure type with var value type + * @tparam Pargs Types of remaining arguments + * @param[in, out] dest Pointer to where vari pointers are saved + * @param[in] f A closure capturing vars + * @param[in] args Additional arguments to have their varis saved + * @return Final position of dest pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { + return save_varis(f.save_varis__(dest), std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index 36368d443ee..e68964e6ee0 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -16,6 +16,10 @@ inline void zero_adjoints(T& x, Pargs&... args); template inline void zero_adjoints(var& x, Pargs&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline void zero_adjoints(F& f, Pargs&... args); + template inline void zero_adjoints(Eigen::Matrix& x, Pargs&... args); @@ -58,6 +62,23 @@ inline void zero_adjoints(var& x, Pargs&... args) { zero_adjoints(args...); } +/** + * Zero the adjoints of the varis of every var in a closure. + * Recursively call zero_adjoints on the rest of the arguments. + * + * @tparam F type of current argument + * @tparam Pargs type of rest of arguments + * + * @param f current argument + * @param args rest of arguments to zero + */ +template *, + require_not_st_arithmetic*> +inline void zero_adjoints(F& f, Pargs&... args) { + f.zero_adjoints__(); + zero_adjoints(args...); +} + /** * Zero the adjoints of the varis of every var in an Eigen::Matrix * container. Recursively call zero_adjoints on the rest of the arguments. diff --git a/stan/math/rev/functor/coupled_ode_system.hpp b/stan/math/rev/functor/coupled_ode_system.hpp index 83783d4d03f..cfb4977ede6 100644 --- a/stan/math/rev/functor/coupled_ode_system.hpp +++ b/stan/math/rev/functor/coupled_ode_system.hpp @@ -65,7 +65,7 @@ namespace math { */ template struct coupled_ode_system_impl { - const F& f_; + F f_; const Eigen::Matrix& y0_; std::tuple()))...> local_args_tuple_; @@ -89,11 +89,11 @@ struct coupled_ode_system_impl { coupled_ode_system_impl(const F& f, const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) - : f_(f), + : f_(deep_copy_vars(f)), y0_(y0), local_args_tuple_(deep_copy_vars(args)...), num_y0_vars_(count_vars(y0_)), - num_args_vars(count_vars(args...)), + num_args_vars(count_vars(f, args...)), N_(y0.size()), args_adjoints_(num_args_vars), y_adjoints_(N_), @@ -125,7 +125,7 @@ struct coupled_ode_system_impl { y_vars.coeffRef(n) = z[n]; Eigen::Matrix f_y_t_vars - = apply([&](auto&&... args) { return f_(t, y_vars, msgs_, args...); }, + = apply([&](auto&&... args) { return f_(msgs_, t, y_vars, args...); }, local_args_tuple_); check_size_match("coupled_ode_system", "dy_dt", f_y_t_vars.size(), "states", @@ -142,13 +142,14 @@ struct coupled_ode_system_impl { apply( [&](auto&&... args) { - accumulate_adjoints(args_adjoints_.data(), args...); + accumulate_adjoints(args_adjoints_.data(), f_, args...); }, local_args_tuple_); // The vars here do not live on the nested stack so must be zero'd // separately - apply([&](auto&&... args) { zero_adjoints(args...); }, local_args_tuple_); + apply([&](auto&&... args) { zero_adjoints(f_, args...); }, + local_args_tuple_); // No need to zero adjoints after last sweep if (i + 1 < N_) { diff --git a/stan/math/rev/functor/cvodes_integrator.hpp b/stan/math/rev/functor/cvodes_integrator.hpp index 4e40354540b..d50d4823ebf 100644 --- a/stan/math/rev/functor/cvodes_integrator.hpp +++ b/stan/math/rev/functor/cvodes_integrator.hpp @@ -31,11 +31,12 @@ namespace math { template class cvodes_integrator { - using T_Return = return_type_t; + using T_Return = return_type_t; using T_y0_t0 = return_type_t; const char* function_name_; const F& f_; + typename F::ValueOf__ value_of_f_; const Eigen::Matrix y0_; const T_t0 t0_; const std::vector& ts_; @@ -102,9 +103,9 @@ class cvodes_integrator { inline void rhs(double t, const double y[], double dy_dt[]) const { const Eigen::VectorXd y_vec = Eigen::Map(y, N_); - Eigen::VectorXd dy_dt_vec - = apply([&](auto&&... args) { return f_(t, y_vec, msgs_, args...); }, - value_of_args_tuple_); + Eigen::VectorXd dy_dt_vec = apply( + [&](auto&&... args) { return value_of_f_(msgs_, t, y_vec, args...); }, + value_of_args_tuple_); check_size_match("cvodes_integrator", "dy_dt", dy_dt_vec.size(), "states", N_); @@ -121,8 +122,9 @@ class cvodes_integrator { Eigen::MatrixXd Jfy; auto f_wrapped = [&](const Eigen::Matrix& y) { - return apply([&](auto&&... args) { return f_(t, y, msgs_, args...); }, - value_of_args_tuple_); + return apply( + [&](auto&&... args) { return value_of_f_(msgs_, t, y, args...); }, + value_of_args_tuple_); }; jacobian(f_wrapped, Eigen::Map(y, N_), fy, Jfy); @@ -191,6 +193,7 @@ class cvodes_integrator { std::ostream* msgs, const T_Args&... args) : function_name_(function_name), f_(f), + value_of_f_(value_of(f)), y0_(y0.template cast()), t0_(t0), ts_(ts), @@ -202,7 +205,7 @@ class cvodes_integrator { absolute_tolerance_(absolute_tolerance), max_num_steps_(max_num_steps), num_y0_vars_(count_vars(y0_)), - num_args_vars_(count_vars(args...)), + num_args_vars_(count_vars(f, args...)), coupled_ode_(f, y0_, msgs, args...), coupled_state_(coupled_ode_.initial_state()) { check_finite(function_name, "initial state", y0_); diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 0cba70a321e..b2cf7373dad 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -14,8 +15,8 @@ namespace math { * @deprecated use ode_adams */ template -std::vector>> + typename T_ts, require_stan_closure_t* = nullptr> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +30,7 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); @@ -37,6 +38,23 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, return y_converted; } +template * = nullptr> +std::vector>> +integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, + const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, + double relative_tolerance = 1e-10, + double absolute_tolerance = 1e-10, + int max_num_steps = 1e8) { + closure_adapter cl(f); + return integrate_ode_adams(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index c3877bdb875..3b5bf9eef29 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_BDF_HPP #include +#include #include #include #include @@ -14,8 +15,8 @@ namespace math { * @deprecated use ode_bdf */ template -std::vector>> + typename T_ts, require_stan_closure_t* = nullptr> +std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -29,7 +30,7 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); @@ -37,6 +38,22 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, return y_converted; } +template * = nullptr> +std::vector>> +integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, + const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs = nullptr, + double relative_tolerance = 1e-10, + double absolute_tolerance = 1e-10, int max_num_steps = 1e8) { + closure_adapter cl(f); + return integrate_ode_bdf(cl, y0, t0, ts, theta, x, x_int, msgs, + relative_tolerance, absolute_tolerance, + max_num_steps); +} + } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index ee0bdafbbd5..cd283596a09 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_FUNCTOR_ODE_ADAMS_HPP #define STAN_MATH_REV_FUNCTOR_ODE_ADAMS_HPP +#include #include #include #include @@ -45,8 +46,9 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> -std::vector, + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, @@ -65,6 +67,19 @@ ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } +template * = nullptr> +auto ode_adams_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_adams_tol_impl(function_name, cl, y0, t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the non-stiff Adams-Moulton solver from @@ -100,8 +115,8 @@ ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, */ template * = nullptr> -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_adams_tol(const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, double relative_tolerance, double absolute_tolerance, @@ -143,8 +158,8 @@ ode_adams_tol(const F& f, const T_y0& y0, const T_t0& t0, */ template * = nullptr> -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_adams(const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, const T_Args&... args) { diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index a07af2e3339..8f69e3abfd5 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_FUNCTOR_ODE_BDF_HPP #define STAN_MATH_REV_FUNCTOR_ODE_BDF_HPP +#include #include #include #include @@ -46,8 +47,9 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> -std::vector, + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_stan_closure_t* = nullptr> +std::vector, Eigen::Dynamic, 1>> ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, @@ -66,6 +68,19 @@ ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, args_ref_tuple); } +template * = nullptr> +auto ode_bdf_tol_impl(const char* function_name, const F& f, + const Eigen::Matrix& y0, + const T_t0& t0, const std::vector& ts, + double relative_tolerance, double absolute_tolerance, + long int max_num_steps, // NOLINT(runtime/int) + std::ostream* msgs, const T_Args&... args) { + internal::ode_closure_adapter cl(f); + return ode_bdf_tol_impl(function_name, cl, y0, t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, args...); +} + /** * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of * times, { t1, t2, t3, ... } using the stiff backward differentiation formula @@ -101,8 +116,8 @@ ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, */ template * = nullptr> -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_bdf_tol(const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, double relative_tolerance, double absolute_tolerance, @@ -144,8 +159,8 @@ ode_bdf_tol(const F& f, const T_y0& y0, const T_t0& t0, */ template * = nullptr> -std::vector, - Eigen::Dynamic, 1>> +std::vector, Eigen::Dynamic, 1>> ode_bdf(const F& f, const T_y0& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, const T_Args&... args) { double relative_tolerance = 1e-10; diff --git a/stan/math/rev/functor/ode_store_sensitivities.hpp b/stan/math/rev/functor/ode_store_sensitivities.hpp index 8a93ff846df..39e005adf0b 100644 --- a/stan/math/rev/functor/ode_store_sensitivities.hpp +++ b/stan/math/rev/functor/ode_store_sensitivities.hpp @@ -31,7 +31,7 @@ namespace math { */ template , T_y0_t0, T_t0, T_t, scalar_type_t...>* = nullptr> Eigen::Matrix ode_store_sensitivities( const F& f, const std::vector& coupled_state, @@ -39,7 +39,7 @@ Eigen::Matrix ode_store_sensitivities( const T_t& t, std::ostream* msgs, const Args&... args) { const size_t N = y0.size(); const size_t num_y0_vars = count_vars(y0); - const size_t num_args_vars = count_vars(args...); + const size_t num_args_vars = count_vars(f, args...); const size_t num_t0_vars = count_vars(t0); const size_t num_t_vars = count_vars(t); Eigen::Matrix yt(N); @@ -51,12 +51,12 @@ Eigen::Matrix ode_store_sensitivities( Eigen::VectorXd f_y_t; if (is_var::value) - f_y_t = f(value_of(t), y, msgs, eval(value_of(args))...); + f_y_t = value_of(f)(msgs, value_of(t), y, eval(value_of(args))...); Eigen::VectorXd f_y0_t0; if (is_var::value) - f_y0_t0 - = f(value_of(t0), eval(value_of(y0)), msgs, eval(value_of(args))...); + f_y0_t0 = value_of(f)(msgs, value_of(t0), eval(value_of(y0)), + eval(value_of(args))...); const size_t total_vars = num_y0_vars + num_args_vars + num_t0_vars + num_t_vars; @@ -64,7 +64,7 @@ Eigen::Matrix ode_store_sensitivities( vari** varis = ChainableStack::instance_->memalloc_.alloc_array(total_vars); - save_varis(varis, y0, args..., t0, t); + save_varis(varis, y0, f, args..., t0, t); // memory for a column major jacobian double* jacobian_mem diff --git a/stan/math/rev/functor/reduce_sum.hpp b/stan/math/rev/functor/reduce_sum.hpp index f0931060b7c..dc4d17afcb2 100644 --- a/stan/math/rev/functor/reduce_sum.hpp +++ b/stan/math/rev/functor/reduce_sum.hpp @@ -38,23 +38,28 @@ struct reduce_sum_impl, ReturnType, */ struct recursive_reducer { const size_t num_vars_per_term_; + const size_t num_vars_closure_; // Number of vars in the closure const size_t num_vars_shared_terms_; // Number of vars in shared arguments double* sliced_partials_; // Points to adjoints of the partial calculations Vec vmapped_; std::ostream* msgs_; + const ReduceFunction& f_; std::tuple args_tuple_; double sum_{0.0}; Eigen::VectorXd args_adjoints_{0}; template - recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms, - double* sliced_partials, VecT&& vmapped, - std::ostream* msgs, ArgsT&&... args) + recursive_reducer(size_t num_vars_per_term, size_t num_vars_closure, + size_t num_vars_shared_terms, double* sliced_partials, + VecT&& vmapped, std::ostream* msgs, + const ReduceFunction& f, ArgsT&&... args) : num_vars_per_term_(num_vars_per_term), + num_vars_closure_(num_vars_closure), num_vars_shared_terms_(num_vars_shared_terms), sliced_partials_(sliced_partials), vmapped_(std::forward(vmapped)), msgs_(msgs), + f_(f), args_tuple_(std::forward(args)...) {} /* @@ -65,10 +70,12 @@ struct reduce_sum_impl, ReturnType, */ recursive_reducer(recursive_reducer& other, tbb::split) : num_vars_per_term_(other.num_vars_per_term_), + num_vars_closure_(other.num_vars_closure_), num_vars_shared_terms_(other.num_vars_shared_terms_), sliced_partials_(other.sliced_partials_), vmapped_(other.vmapped_), msgs_(other.msgs_), + f_(other.f_), args_tuple_(other.args_tuple_) {} /** @@ -90,7 +97,8 @@ struct reduce_sum_impl, ReturnType, } if (args_adjoints_.size() == 0) { - args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_); + args_adjoints_ + = Eigen::VectorXd::Zero(num_vars_closure_ + num_vars_shared_terms_); } // Initialize nested autodiff stack @@ -104,6 +112,9 @@ struct reduce_sum_impl, ReturnType, local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i])); } + // Create a copy of the functor + auto f_local_copy = deep_copy_vars(f_); + // Create nested autodiff copies of all shared arguments that do not point // back to main autodiff stack auto args_tuple_local_copy = apply( @@ -116,8 +127,8 @@ struct reduce_sum_impl, ReturnType, // Perform calculation var sub_sum_v = apply( [&](auto&&... args) { - return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1, - msgs_, args...); + return f_local_copy(msgs_, local_sub_slice, r.begin(), r.end() - 1, + args...); }, args_tuple_local_copy); @@ -131,10 +142,13 @@ struct reduce_sum_impl, ReturnType, accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_, std::move(local_sub_slice)); + // Accumulate adjoints of closure arguments + accumulate_adjoints(args_adjoints_.data(), f_local_copy); + // Accumulate adjoints of shared_arguments apply( [&](auto&&... args) { - accumulate_adjoints(args_adjoints_.data(), + accumulate_adjoints(args_adjoints_.data() + num_vars_closure_, std::forward(args)...); }, std::move(args_tuple_local_copy)); @@ -197,7 +211,8 @@ struct reduce_sum_impl, ReturnType, * @return Summation of all terms */ inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, - std::ostream* msgs, Args&&... args) const { + std::ostream* msgs, const ReduceFunction& f, + Args&&... args) const { const std::size_t num_terms = vmapped.size(); if (vmapped.empty()) { @@ -206,23 +221,25 @@ struct reduce_sum_impl, ReturnType, const std::size_t num_vars_per_term = count_vars(vmapped[0]); const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term; + const std::size_t num_vars_closure = count_vars(f); const std::size_t num_vars_shared_terms = count_vars(args...); vari** varis = ChainableStack::instance_->memalloc_.alloc_array( - num_vars_sliced_terms + num_vars_shared_terms); + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms); double* partials = ChainableStack::instance_->memalloc_.alloc_array( - num_vars_sliced_terms + num_vars_shared_terms); + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms); save_varis(varis, vmapped); - save_varis(varis + num_vars_sliced_terms, args...); + save_varis(varis + num_vars_sliced_terms, f); + save_varis(varis + num_vars_sliced_terms + num_vars_closure, args...); for (size_t i = 0; i < num_vars_sliced_terms; ++i) { partials[i] = 0.0; } - recursive_reducer worker(num_vars_per_term, num_vars_shared_terms, partials, - std::forward(vmapped), msgs, - std::forward(args)...); + recursive_reducer worker( + num_vars_per_term, num_vars_closure, num_vars_shared_terms, partials, + std::forward(vmapped), msgs, f, std::forward(args)...); if (auto_partitioning) { tbb::parallel_reduce( @@ -234,12 +251,13 @@ struct reduce_sum_impl, ReturnType, partitioner); } - for (size_t i = 0; i < num_vars_shared_terms; ++i) { + for (size_t i = 0; i < num_vars_closure + num_vars_shared_terms; ++i) { partials[num_vars_sliced_terms + i] = worker.args_adjoints_(i); } return var(new precomputed_gradients_vari( - worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis, + worker.sum_, + num_vars_sliced_terms + num_vars_closure + num_vars_shared_terms, varis, partials)); } }; diff --git a/test/unit/math/prim/functor/coupled_ode_system_test.cpp b/test/unit/math/prim/functor/coupled_ode_system_test.cpp index 679d2b33b7e..b67571ff935 100644 --- a/test/unit/math/prim/functor/coupled_ode_system_test.cpp +++ b/test/unit/math/prim/functor/coupled_ode_system_test.cpp @@ -15,7 +15,9 @@ struct StanMathCoupledOdeSystem : public ::testing::Test { TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { using stan::math::coupled_ode_system; - mock_ode_functor base_ode; + using adapted_ode_functor + = stan::math::internal::ode_closure_adapter; + adapted_ode_functor base_ode{mock_ode_functor()}; const int N = 3; const int M = 4; @@ -28,7 +30,7 @@ TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { for (int m = 0; m < M; m++) theta_d[m] = 10 * (m + 1); - coupled_ode_system, + coupled_ode_system, std::vector, std::vector> coupled_system_dd(base_ode, y0_d, &msgs, theta_d, x, x_int); @@ -43,7 +45,9 @@ TEST_F(StanMathCoupledOdeSystem, initial_state_dd) { TEST_F(StanMathCoupledOdeSystem, size) { using stan::math::coupled_ode_system; - mock_ode_functor base_ode; + using adapted_ode_functor + = stan::math::internal::ode_closure_adapter; + adapted_ode_functor base_ode{mock_ode_functor()}; const int N = 3; const int M = 4; @@ -51,7 +55,7 @@ TEST_F(StanMathCoupledOdeSystem, size) { Eigen::VectorXd y0_d(N); std::vector theta_d(M, 0.0); - coupled_ode_system + coupled_ode_system coupled_system_dd(base_ode, y0_d, &msgs, 1, 1.0, y0_d); EXPECT_EQ(N, coupled_system_dd.size()); @@ -59,18 +63,21 @@ TEST_F(StanMathCoupledOdeSystem, size) { TEST_F(StanMathCoupledOdeSystem, recover_exception) { using stan::math::coupled_ode_system; + using adapted_ode_functor = stan::math::internal::ode_closure_adapter< + mock_throwing_ode_functor>; std::string message = "ode throws"; const int N = 3; const int M = 4; - mock_throwing_ode_functor throwing_ode(message); + adapted_ode_functor throwing_ode{ + mock_throwing_ode_functor(message)}; Eigen::VectorXd y0_d(N); std::vector theta_v(M); - coupled_ode_system, double, - std::vector, std::vector, std::vector> + coupled_ode_system, + std::vector, std::vector> coupled_system_dd(throwing_ode, y0_d, &msgs, theta_v, x, x_int); std::vector y(3); diff --git a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp index 1d7f484a356..1d6fb20bf83 100644 --- a/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -5,9 +5,10 @@ #include TEST(StanMath, check_values) { - harm_osc_ode_data_fun harm_osc; - stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> + stan::math::closure_adapter harm_osc{ + harm_osc_ode_data_fun()}; + stan::math::internal::integrate_ode_std_vector_interface_adapter harm_osc_adapted(harm_osc); std::vector theta = {0.15}; @@ -19,9 +20,9 @@ TEST(StanMath, check_values) { double t = 1.0; Eigen::VectorXd out1 - = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); + = stan::math::to_vector(harm_osc(nullptr, t, y, theta, x, x_int)); Eigen::VectorXd out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); EXPECT_MATRIX_FLOAT_EQ(out1, out2); } diff --git a/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp b/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp index acef36e6884..e3db0a860e3 100644 --- a/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp +++ b/test/unit/math/prim/functor/ode_store_sensitivities_test.cpp @@ -6,7 +6,7 @@ #include TEST(MathPrim, ode_store_sensitivities) { - mock_ode_functor base_ode; + stan::math::closure_adapter base_ode{mock_ode_functor()}; size_t N = 5; diff --git a/test/unit/math/rev/functor/coupled_ode_system_test.cpp b/test/unit/math/rev/functor/coupled_ode_system_test.cpp index e513a20a3c3..62f5e86062d 100644 --- a/test/unit/math/rev/functor/coupled_ode_system_test.cpp +++ b/test/unit/math/rev/functor/coupled_ode_system_test.cpp @@ -17,11 +17,13 @@ struct StanAgradRevOde : public ::testing::Test { // ******************** DV **************************** TEST_F(StanAgradRevOde, coupled_ode_system_dv) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::vector theta; std::vector z0; @@ -61,7 +63,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_dv) { TEST_F(StanAgradRevOde, initial_state_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -91,7 +94,8 @@ TEST_F(StanAgradRevOde, initial_state_dv) { TEST_F(StanAgradRevOde, size_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -110,7 +114,8 @@ TEST_F(StanAgradRevOde, size_dv) { TEST_F(StanAgradRevOde, memory_recovery_dv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -135,6 +140,7 @@ TEST_F(StanAgradRevOde, memory_recovery_dv) { TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -145,7 +151,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::VectorXd y0_d = Eigen::VectorXd::Zero(N); std::vector theta_v(M, 0.0); @@ -169,11 +176,13 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_dv) { TEST_F(StanAgradRevOde, coupled_ode_system_vd) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::vector theta; std::vector z0; @@ -220,7 +229,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vd) { TEST_F(StanAgradRevOde, initial_state_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -251,7 +261,8 @@ TEST_F(StanAgradRevOde, initial_state_vd) { TEST_F(StanAgradRevOde, size_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -271,7 +282,8 @@ TEST_F(StanAgradRevOde, size_vd) { TEST_F(StanAgradRevOde, memory_recovery_vd) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -297,6 +309,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vd) { TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -307,7 +320,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::Matrix y0_v = Eigen::VectorXd::Zero(N).template cast(); @@ -332,6 +346,7 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vd) { TEST_F(StanAgradRevOde, coupled_ode_system_vv) { using stan::math::coupled_ode_system; + using stan::math::internal::ode_closure_adapter; // Run nested autodiff in this scope stan::math::nested_rev_autodiff nested; @@ -346,7 +361,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { std::vector theta_var(1); theta_var[0] = 0.15; - harm_osc_ode_fun_eigen harm_osc; + ode_closure_adapter harm_osc{ + harm_osc_ode_fun_eigen()}; std::size_t stack_size = stan::math::nested_size(); @@ -377,7 +393,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { theta_double[0] = 0.15; Eigen::VectorXd dy_dt_base - = harm_osc(0.0, y0_double, &msgs, theta_double, x, x_int); + = harm_osc(&msgs, 0.0, y0_double, theta_double, x, x_int); EXPECT_FLOAT_EQ(dy_dt_base[0], dz_dt[0]); EXPECT_FLOAT_EQ(dy_dt_base[1], dz_dt[1]); @@ -392,7 +408,8 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) { TEST_F(StanAgradRevOde, initial_state_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -424,7 +441,8 @@ TEST_F(StanAgradRevOde, initial_state_vv) { TEST_F(StanAgradRevOde, size_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -445,7 +463,8 @@ TEST_F(StanAgradRevOde, size_vv) { TEST_F(StanAgradRevOde, memory_recovery_vv) { using stan::math::coupled_ode_system; using stan::math::var; - mock_ode_functor base_ode; + using stan::math::internal::ode_closure_adapter; + ode_closure_adapter base_ode{mock_ode_functor()}; const size_t N = 3; const size_t M = 4; @@ -472,6 +491,7 @@ TEST_F(StanAgradRevOde, memory_recovery_vv) { TEST_F(StanAgradRevOde, memory_recovery_exception_vv) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; std::string message = "ode throws"; const size_t N = 3; @@ -482,7 +502,8 @@ TEST_F(StanAgradRevOde, memory_recovery_exception_vv) { std::stringstream scoped_message; scoped_message << "iteration " << n; SCOPED_TRACE(scoped_message.str()); - mock_throwing_ode_functor throwing_ode(message, 1); + ode_closure_adapter> + throwing_ode{mock_throwing_ode_functor(message, 1)}; Eigen::Matrix y0_v = Eigen::VectorXd::Zero(N).template cast(); @@ -539,6 +560,7 @@ struct ayt { TEST_F(StanAgradRevOde, coupled_ode_system_var) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -547,10 +569,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_var) { double a = 1.3; var av = a; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -572,6 +594,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_var) { TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -579,10 +602,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { std::vector av = {1.3}; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -604,6 +627,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_std_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -612,10 +636,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vector) { Eigen::Matrix av(1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -637,6 +661,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -645,10 +670,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { Eigen::Matrix av(1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -670,6 +695,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_row_vector) { TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -678,10 +704,10 @@ TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { Eigen::Matrix av(1, 1); av << 1.3; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system system(func, y0v, - &msgs, av); + coupled_ode_system system( + func, y0v, &msgs, av); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; @@ -703,6 +729,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_matrix) { TEST_F(StanAgradRevOde, coupled_ode_system_extra_args) { using stan::math::coupled_ode_system; using stan::math::var; + using stan::math::internal::ode_closure_adapter; Eigen::VectorXd y0(2); y0 << 0.1, 0.2; @@ -720,11 +747,11 @@ TEST_F(StanAgradRevOde, coupled_ode_system_extra_args) { Eigen::MatrixXd e6(1, 1); e6 << 0.1; - ayt func; + ode_closure_adapter func{ayt()}; - coupled_ode_system + coupled_ode_system system(func, y0v, &msgs, av, e1, e2, e3, e4, e5, e6); std::vector z = {y0(0), y0(1), 3.2, 3.3, 4.4, 4.5, 2.1, 2.2}; diff --git a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp index f8611f7f57a..87512e58fce 100644 --- a/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp +++ b/test/unit/math/rev/functor/integrate_ode_std_vector_interface_adapter_test.cpp @@ -8,8 +8,8 @@ TEST(StanMathRev, vd) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -22,7 +22,7 @@ TEST(StanMathRev, vd) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(theta.size()); @@ -44,8 +44,8 @@ TEST(StanMathRev, dv) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -58,7 +58,7 @@ TEST(StanMathRev, dv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs1(y.size()); @@ -80,8 +80,8 @@ TEST(StanMathRev, vv) { using stan::math::var; harm_osc_ode_data_fun harm_osc; stan::math::internal::integrate_ode_std_vector_interface_adapter< - harm_osc_ode_data_fun> - harm_osc_adapted(harm_osc); + stan::math::closure_adapter> + harm_osc_adapted{stan::math::from_lambda(harm_osc)}; std::vector theta = {0.15}; std::vector y = {1.0, 0.5}; @@ -94,7 +94,7 @@ TEST(StanMathRev, vv) { Eigen::Matrix out1 = stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr)); Eigen::Matrix out2 - = harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int); + = harm_osc_adapted(nullptr, t, stan::math::to_vector(y), theta, x, x_int); stan::math::sum(out1).grad(); Eigen::VectorXd adjs_theta_1(theta.size()); diff --git a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp index a1665e53e35..6d021038b06 100644 --- a/test/unit/math/rev/functor/ode_rk45_rev_test.cpp +++ b/test/unit/math/rev/functor/ode_rk45_rev_test.cpp @@ -265,6 +265,86 @@ TEST(StanMathOde_ode_rk45, scalar_std_vector_args) { EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); } +TEST(StanMathOde_ode_rk45, closure_var) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [&](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); + EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, closure_double) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + var output = stan::math::ode_rk45(f, y0, t0, ts, nullptr, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + +TEST(StanMathOde_ode_rk45, higher_order) { + using stan::math::var; + + Eigen::VectorXd y0 = Eigen::VectorXd::Zero(1); + double t0 = 0.0; + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + auto f = stan::math::from_lambda( + [](const auto& a, const auto& t, const auto& y, const auto& b, + std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0); + + auto wrapper + = [](const auto& t, const auto& y, std::ostream* msgs, const auto& fa, + const auto& b) { return fa(msgs, t, y, b); }; + + var output = stan::math::ode_rk45(wrapper, y0, t0, ts, nullptr, f, a1)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); +} + TEST(StanMathOde_ode_rk45, std_vector_std_vector_args) { using stan::math::var; diff --git a/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp b/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp index b49028f03e9..5c15f33ac88 100644 --- a/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp +++ b/test/unit/math/rev/functor/ode_store_sensitivities_test.cpp @@ -61,7 +61,7 @@ TEST(AgradRev, ode_store_sensitivities) { double a = 1.3; var av = a; - ayt func; + stan::math::internal::ode_closure_adapter func{ayt()}; double t0 = 0.5; double t = 0.75; @@ -110,7 +110,7 @@ TEST(AgradRev, ode_store_sensitivities_matrix) { Eigen::Matrix av(1, 1); av(0, 0) = a; - aytm func; + stan::math::internal::ode_closure_adapter func{aytm()}; double t0 = 0.5; double t = 0.75; diff --git a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp new file mode 100644 index 00000000000..6720b2e8361 --- /dev/null +++ b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +struct closure_adapter { + template + auto operator()(const T_slice& subslice, std::size_t start, std::size_t end, + std::ostream* msgs, const F& f, Args... args) { + return f(msgs, subslice, start, end, args...); + } +}; + +TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { + using stan::math::from_lambda; + using stan::math::var; + using stan::math::test::get_new_msg; + + double lambda_d = 10.0; + const std::size_t groups = 10; + const std::size_t elems_per_group = 1000; + const std::size_t elems = groups * elems_per_group; + + std::vector data(elems); + std::vector gidx(elems); + + for (std::size_t i = 0; i != elems; ++i) { + data[i] = i; + gidx[i] = i / elems_per_group; + } + + std::vector vlambda_v; + + for (std::size_t i = 0; i != groups; ++i) + vlambda_v.push_back(i + 0.2); + + var lambda_v = vlambda_v[0]; + + auto functor = from_lambda( + [](auto& lambda, auto& slice, std::size_t start, std::size_t end, + auto& gidx, std::ostream* msgs) { + const std::size_t num_terms = end - start + 1; + std::decay_t lambda_slice(num_terms); + for (std::size_t i = 0; i != num_terms; ++i) + lambda_slice[i] = lambda[gidx[start + i]]; + return stan::math::poisson_lpmf(slice, lambda_slice); + }, + vlambda_v); + + var poisson_lpdf + = stan::math::reduce_sum(data, 5, get_new_msg(), functor, gidx); + + std::vector vref_lambda_v; + for (std::size_t i = 0; i != elems; ++i) { + vref_lambda_v.push_back(vlambda_v[gidx[i]]); + } + var lambda_ref = vlambda_v[0]; + var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v); + + EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref)); + + stan::math::grad(poisson_lpdf_ref.vi_); + const double lambda_ref_adj = lambda_ref.adj(); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf.vi_); + const double lambda_adj = lambda_v.adj(); + + EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj) + << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl + << "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl + << "value of poisson lpdf : " << poisson_lpdf.val() << std::endl + << "gradient wrt to lambda: " << lambda_adj << std::endl; + + var poisson_lpdf_static + = stan::math::reduce_sum_static(data, 5, get_new_msg(), functor, gidx); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf_static.vi_); + const double lambda_adj_static = lambda_v.adj(); + EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj); + stan::math::recover_memory(); + + stan::math::recover_memory(); +}