From 34cf55476748dcd0424b7f4e18c4a7280a6aab53 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Wed, 24 Apr 2024 17:51:44 -0400 Subject: [PATCH 1/4] use a separate class for csr_matrix adjoint --- stan/math/rev/fun/csr_matrix_times_vector.hpp | 57 +++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index 4a0a93876e4..b27dcb33f92 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -11,6 +11,50 @@ namespace stan { namespace math { +namespace internal { + template + struct csr_adjoint : public vari { + std::decay_t res_; + std::decay_t w_mat_; + std::decay_t b_; + template + csr_adjoint(T1&& res, T2&& w_mat, T3&& b) + : vari(0.0), res_(std::forward(res)), + w_mat_(std::forward(w_mat)), b_(std::forward(b)) {} + + void chain() { + chain_internal(res_, w_mat_, b_); + } + template * = nullptr, + require_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.val().transpose(); + b.adj() += w_mat.val().transpose() * res.adj(); + } + + template * = nullptr, + require_not_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.transpose(); + } + + template * = nullptr, + require_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + b.adj() += w_mat.transpose() * res.adj(); + } + }; + template + inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { + return new csr_adjoint( + std::forward(res), std::forward(w_mat), + std::forward(b)); + } +} + /** * \addtogroup csr_format * Return the multiplication of the sparse matrix (specified by @@ -74,10 +118,7 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, sparse_var_value_t w_mat_arena = to_soa_sparse_matrix(m, n, w, u_arena, v_arena); arena_t res = w_mat_arena.val() * 
value_of(b_arena); - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { - w_mat_arena.adj() += res.adj() * b_arena.val().transpose(); - b_arena.adj() += w_mat_arena.val().transpose() * res.adj(); - }); + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); return return_t(res); } else if (!is_constant::value) { arena_t> b_arena = b; @@ -85,18 +126,14 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), v_arena.data(), w_val_arena.data()); arena_t res = w_val_mat * value_of(b_arena); - reverse_pass_callback([w_val_mat, res, b_arena]() mutable { - b_arena.adj() += w_val_mat.transpose() * res.adj(); - }); + stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena); return return_t(res); } else { sparse_var_value_t w_mat_arena = to_soa_sparse_matrix(m, n, w, u_arena, v_arena); auto b_arena = to_arena(value_of(b)); arena_t res = w_mat_arena.val() * b_arena; - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { - w_mat_arena.adj() += res.adj() * b_arena.transpose(); - }); + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); return return_t(res); } } From a3a88a568c7d416a7fd2e09a1c31c0b7b801180f Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 24 Apr 2024 17:54:16 -0400 Subject: [PATCH 2/4] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/rev/fun/csr_matrix_times_vector.hpp | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index b27dcb33f92..09b64f45338 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -12,48 +12,48 @@ namespace stan { namespace math { namespace internal { - template - struct csr_adjoint : public vari { - std::decay_t res_; - std::decay_t w_mat_; - std::decay_t b_; - template - 
csr_adjoint(T1&& res, T2&& w_mat, T3&& b) - : vari(0.0), res_(std::forward(res)), - w_mat_(std::forward(w_mat)), b_(std::forward(b)) {} +template +struct csr_adjoint : public vari { + std::decay_t res_; + std::decay_t w_mat_; + std::decay_t b_; + template + csr_adjoint(T1&& res, T2&& w_mat, T3&& b) + : vari(0.0), + res_(std::forward(res)), + w_mat_(std::forward(w_mat)), + b_(std::forward(b)) {} - void chain() { - chain_internal(res_, w_mat_, b_); - } - template * = nullptr, - require_rev_matrix_t* = nullptr> - void chain_internal(Result&& res, WMat&& w_mat, B&& b) { - w_mat.adj() += res.adj() * b.val().transpose(); - b.adj() += w_mat.val().transpose() * res.adj(); - } + void chain() { chain_internal(res_, w_mat_, b_); } + template * = nullptr, + require_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.val().transpose(); + b.adj() += w_mat.val().transpose() * res.adj(); + } - template * = nullptr, - require_not_rev_matrix_t* = nullptr> - void chain_internal(Result&& res, WMat&& w_mat, B&& b) { - w_mat.adj() += res.adj() * b.transpose(); - } + template * = nullptr, + require_not_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + w_mat.adj() += res.adj() * b.transpose(); + } - template * = nullptr, - require_rev_matrix_t* = nullptr> - void chain_internal(Result&& res, WMat&& w_mat, B&& b) { - b.adj() += w_mat.transpose() * res.adj(); - } - }; - template - inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { - return new csr_adjoint( - std::forward(res), std::forward(w_mat), - std::forward(b)); + template * = nullptr, + require_rev_matrix_t* = nullptr> + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + b.adj() += w_mat.transpose() * res.adj(); } +}; +template +inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { + return new csr_adjoint(std::forward(res), + std::forward(w_mat), + std::forward(b)); } +} // namespace internal 
/** * \addtogroup csr_format From f74882535f9072d421d3461ca427b8a8a078a468 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 25 Apr 2024 13:27:19 -0400 Subject: [PATCH 3/4] update docs for new vari for csr_matrix_times_vector --- stan/math/rev/fun/csr_matrix_times_vector.hpp | 62 +++++++++++++++++-- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index 09b64f45338..4524a116043 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -12,11 +12,23 @@ namespace stan { namespace math { namespace internal { +/** + * `vari` for csr_matrix_times_vector + * @note `csr_matrix_times_vector` uses the old inheritance + * style to set up the reverse pass because of a linking + * issue on windows when using flto. + * + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * + */ template struct csr_adjoint : public vari { std::decay_t res_; std::decay_t w_mat_; std::decay_t b_; + template csr_adjoint(T1&& res, T2&& w_mat, T3&& b) : vari(0.0), @@ -25,33 +37,73 @@ struct csr_adjoint : public vari { b_(std::forward(b)) {} void chain() { chain_internal(res_, w_mat_, b_); } + + /** + * Overload for calculating adjoints of `w_mat` and `b` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. 
Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ template * = nullptr, require_rev_matrix_t* = nullptr> - void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { w_mat.adj() += res.adj() * b.val().transpose(); b.adj() += w_mat.val().transpose() * res.adj(); } + /** + * Overload for calculating adjoints of `w_mat` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `double` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ template * = nullptr, require_not_rev_matrix_t* = nullptr> - void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { w_mat.adj() += res.adj() * b.transpose(); } + /** + * Overload for calculating adjoints of `b` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `double` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ template * = nullptr, require_rev_matrix_t* = nullptr> - 
void chain_internal(Result&& res, WMat&& w_mat, B&& b) { + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { b.adj() += w_mat.transpose() * res.adj(); } }; + +/** + * Helper function to construct the csr_adjoint struct. + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * + * @param res The vector result of the forward pass calculation + * @param w_mat A sparse matrix + * @param b A vector + */ template -inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { - return new csr_adjoint(std::forward(res), +inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { + new csr_adjoint, std::decay_t, std::decay_t>(std::forward(res), std::forward(w_mat), std::forward(b)); + return; } } // namespace internal From 04124da73ef1d76f3a536ddb67a016212c86fd28 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 25 Apr 2024 13:27:49 -0400 Subject: [PATCH 4/4] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/rev/fun/csr_matrix_times_vector.hpp | 65 ++++++++++++------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index 4524a116043..4665c7a4dbd 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -14,14 +14,18 @@ namespace math { namespace internal { /** * `vari` for csr_matrix_times_vector - * @note `csr_matrix_times_vector` uses the old inheritance - * style to set up the reverse pass because of a linking + * @note `csr_matrix_times_vector` 
uses the old inheritance + * style to set up the reverse pass because of a linking * issue on windows when using flto. - * - * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` - * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::SparseBase` - * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::DenseBase` - * + * + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or `double`. Or a `var` where `T` inherits from + * `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or `double`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * */ template struct csr_adjoint : public vari { @@ -40,9 +44,12 @@ struct csr_adjoint : public vari { /** * Overload for calculating adjoints of `w_mat` and `b` - * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` - * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var` where `T` inherits from `Eigen::SparseBase` - * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var`. 
Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var`. Or a `var` where `T` inherits from `Eigen::DenseBase` * @param res The vector result of the forward pass calculation * @param w_mat A sparse matrix * @param b A vector @@ -57,9 +64,12 @@ struct csr_adjoint : public vari { /** * Overload for calculating adjoints of `w_mat` - * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` - * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var` where `T` inherits from `Eigen::SparseBase` - * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `double` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var`. 
Or a `var` where `T` inherits from `Eigen::SparseBase` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `double` * @param res The vector result of the forward pass calculation * @param w_mat A sparse matrix * @param b A vector @@ -73,9 +83,12 @@ struct csr_adjoint : public vari { /** * Overload for calculating adjoints of `b` - * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` - * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `double` - * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar + * type `double` + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or a `var` where `T` inherits from `Eigen::DenseBase` * @param res The vector result of the forward pass calculation * @param w_mat A sparse matrix * @param b A vector @@ -90,19 +103,23 @@ struct csr_adjoint : public vari { /** * Helper function to construct the csr_adjoint struct. - * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var` where `T` inherits from `Eigen::DenseBase` - * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var` where `T` inherits from `Eigen::SparseBase` - * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. 
Or a `var` where `T` inherits from `Eigen::DenseBase` - * + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or a `var` where `T` inherits from `Eigen::DenseBase` + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar + * type `var` or `double`. Or a `var` where `T` inherits from + * `Eigen::SparseBase` + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type + * `var` or `double`. Or a `var` where `T` inherits from `Eigen::DenseBase` + * * @param res The vector result of the forward pass calculation * @param w_mat A sparse matrix - * @param b A vector + * @param b A vector */ template inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { - new csr_adjoint, std::decay_t, std::decay_t>(std::forward(res), - std::forward(w_mat), - std::forward(b)); + new csr_adjoint, std::decay_t, std::decay_t>( + std::forward(res), std::forward(w_mat), + std::forward(b)); return; } } // namespace internal