Skip to content

Commit

Permalink
Merge pull request #3053 from stan-dev/fix/csr-matrix-seperate-vari
Browse files Browse the repository at this point in the history
use a separate class for csr_matrix adjoint
  • Loading branch information
SteveBronder committed Apr 26, 2024
2 parents 08d8a22 + 04124da commit 9f759e1
Showing 1 changed file with 116 additions and 10 deletions.
126 changes: 116 additions & 10 deletions stan/math/rev/fun/csr_matrix_times_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,119 @@
namespace stan {
namespace math {

namespace internal {
/**
* `vari` for csr_matrix_times_vector
* @note `csr_matrix_times_vector` uses the old inheritance
* style to set up the reverse pass because of a linking
* issue on windows when using flto.
*
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or `double`. Or a `var<T>` where `T` inherits from
* `Eigen::SparseBase`
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type
* `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
*
*/
template <typename Result_, typename WMat_, typename B_>
struct csr_adjoint : public vari {
std::decay_t<Result_> res_;
std::decay_t<WMat_> w_mat_;
std::decay_t<B_> b_;

template <typename T1, typename T2, typename T3>
csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
: vari(0.0),
res_(std::forward<T1>(res)),
w_mat_(std::forward<T2>(w_mat)),
b_(std::forward<T3>(b)) {}

void chain() { chain_internal(res_, w_mat_, b_); }

/**
* Overload for calculating adjoints of `w_mat` and `b`
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
* `var`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @param res The vector result of the forward pass calculation
* @param w_mat A sparse matrix
* @param b A vector
*/
template <typename Result, typename WMat, typename B,
require_rev_matrix_t<WMat>* = nullptr,
require_rev_matrix_t<B>* = nullptr>
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
w_mat.adj() += res.adj() * b.val().transpose();
b.adj() += w_mat.val().transpose() * res.adj();
}

/**
* Overload for calculating adjoints of `w_mat`
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
* `double`
* @param res The vector result of the forward pass calculation
* @param w_mat A sparse matrix
* @param b A vector
*/
template <typename Result, typename WMat, typename B,
require_rev_matrix_t<WMat>* = nullptr,
require_not_rev_matrix_t<B>* = nullptr>
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
w_mat.adj() += res.adj() * b.transpose();
}

/**
* Overload for calculating adjoints of `b`
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
* type `double`
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
* `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @param res The vector result of the forward pass calculation
* @param w_mat A sparse matrix
* @param b A vector
*/
template <typename Result, typename WMat, typename B,
require_not_rev_matrix_t<WMat>* = nullptr,
require_rev_matrix_t<B>* = nullptr>
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
b.adj() += w_mat.transpose() * res.adj();
}
};

/**
* Helper function to construct the csr_adjoint struct.
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar
* type `var` or `double`. Or a `var<T>` where `T` inherits from
* `Eigen::SparseBase`
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type
* `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
*
* @param res The vector result of the forward pass calculation
* @param w_mat A sparse matrix
* @param b A vector
*/
template <typename Result_, typename WMat_, typename B_>
inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>(
std::forward<Result_>(res), std::forward<WMat_>(w_mat),
std::forward<B_>(b));
return;
}
} // namespace internal

/**
* \addtogroup csr_format
* Return the multiplication of the sparse matrix (specified by
Expand Down Expand Up @@ -74,29 +187,22 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
sparse_var_value_t w_mat_arena
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
});
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
return return_t(res);
} else if (!is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T2>> b_arena = b;
auto w_val_arena = to_arena(value_of(w));
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
v_arena.data(), w_val_arena.data());
arena_t<return_t> res = w_val_mat * value_of(b_arena);
reverse_pass_callback([w_val_mat, res, b_arena]() mutable {
b_arena.adj() += w_val_mat.transpose() * res.adj();
});
stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena);
return return_t(res);
} else {
sparse_var_value_t w_mat_arena
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
auto b_arena = to_arena(value_of(b));
arena_t<return_t> res = w_mat_arena.val() * b_arena;
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
w_mat_arena.adj() += res.adj() * b_arena.transpose();
});
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
return return_t(res);
}
}
Expand Down

0 comments on commit 9f759e1

Please sign in to comment.