diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index 4a0a93876e4..4665c7a4dbd 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -11,6 +11,119 @@ namespace stan { namespace math { +namespace internal { +/** + * `vari` for csr_matrix_times_vector + * @note `csr_matrix_times_vector` uses the old inheritance + * style to set up the reverse pass because of a linking + * issue on Windows when using `-flto`. + * + * The struct stores the forward-pass result and both operands; `chain()` + * forwards the stored `res_`, `w_mat_` and `b_` to the `chain_internal` + * overload selected by their types. + * + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or `double`. Or a `var<T>` where `T` inherits from + * `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * + */ +template +struct csr_adjoint : public vari { + std::decay_t res_; + std::decay_t w_mat_; + std::decay_t b_; + + template + csr_adjoint(T1&& res, T2&& w_mat, T3&& b) + : vari(0.0), + res_(std::forward(res)), + w_mat_(std::forward(w_mat)), + b_(std::forward(b)) {} + + void chain() { chain_internal(res_, w_mat_, b_); } + + /** + * Overload for calculating adjoints of `w_mat` and `b`. + * Accumulates `res.adj() * b.val().transpose()` into `w_mat.adj()` and + * `w_mat.val().transpose() * res.adj()` into `b.adj()`. + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var`. 
Or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ + template * = nullptr, + require_rev_matrix_t* = nullptr> + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.val().transpose(); + b.adj() += w_mat.val().transpose() * res.adj(); + } + + /** + * Overload for calculating adjoints of `w_mat`. + * Accumulates `res.adj() * b.transpose()` into `w_mat.adj()`; `b` is used + * as-is (no `.val()`) and is not updated. + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `double` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ + template * = nullptr, + require_not_rev_matrix_t* = nullptr> + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.transpose(); + } + + /** + * Overload for calculating adjoints of `b`. + * Accumulates `w_mat.transpose() * res.adj()` into `b.adj()`; `w_mat` is + * used as-is (no `.val()`) and is not updated. + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `double` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ + template * = nullptr, + require_rev_matrix_t* = nullptr> + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + b.adj() += w_mat.transpose() * res.adj(); + } +}; + +/** + * Helper function to construct the csr_adjoint struct. 
+ * @note The pointer returned by `new` is deliberately not stored. + * NOTE(review): presumably `vari` registers the object for the reverse + * pass and manages its memory, so nothing leaks; confirm against `vari`. + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or `double`. Or a `var<T>` where `T` inherits from + * `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase` + * + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ +template +inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { + new csr_adjoint, std::decay_t, std::decay_t>( + std::forward(res), std::forward(w_mat), + std::forward(b)); + return; +} +} // namespace internal + /** * \addtogroup csr_format * Return the multiplication of the sparse matrix (specified by @@ -74,10 +187,7 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, sparse_var_value_t w_mat_arena = to_soa_sparse_matrix(m, n, w, u_arena, v_arena); arena_t res = w_mat_arena.val() * value_of(b_arena); - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { - w_mat_arena.adj() += res.adj() * b_arena.val().transpose(); - b_arena.adj() += w_mat_arena.val().transpose() * res.adj(); - }); + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); return return_t(res); } else if (!is_constant::value) { arena_t> b_arena = b; @@ -85,18 +195,14 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), v_arena.data(), w_val_arena.data()); arena_t res = w_val_mat * value_of(b_arena); - reverse_pass_callback([w_val_mat, res, b_arena]() mutable { - b_arena.adj() += w_val_mat.transpose() * res.adj(); - }); + stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena); return return_t(res); } else { sparse_var_value_t w_mat_arena = to_soa_sparse_matrix(m, n, w, 
u_arena, v_arena); auto b_arena = to_arena(value_of(b)); arena_t res = w_mat_arena.val() * b_arena; - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { - w_mat_arena.adj() += res.adj() * b_arena.transpose(); - }); + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); return return_t(res); } }