From 7e3d8c229a7b5f11bfaf651f3c6c8c9f60555747 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 15 Oct 2021 11:26:09 +0800 Subject: [PATCH 01/27] Initial testing --- stan/math/prim/eigen_plugins.h | 35 +++----- stan/math/prim/plugins/adj_view.h | 107 +++++++++++++++++++++++ stan/math/prim/plugins/typedefs.h | 23 +++++ stan/math/prim/plugins/val_view.h | 137 ++++++++++++++++++++++++++++++ 4 files changed, 280 insertions(+), 22 deletions(-) create mode 100644 stan/math/prim/plugins/adj_view.h create mode 100644 stan/math/prim/plugins/typedefs.h create mode 100644 stan/math/prim/plugins/val_view.h diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index 98e54c8a4da..e8e0eb5bfaa 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -1,17 +1,7 @@ -/** - * Reimplements is_fvar without requiring external math headers - * - * decltype((void)(T::d_)) is a pre C++17 replacement for - * std::void_t - * - * TODO(Andrew): Replace with std::void_t after move to C++17 - */ -template -struct is_fvar : std::false_type -{ }; -template -struct is_fvar : std::true_type -{ }; + +#include "plugins/typedefs.h" +#include "plugins/adj_view.h" +//#include "plugins/val_view.h" //TODO(Andrew): Replace std::is_const<>::value with std::is_const_v<> after move to C++17 template @@ -47,6 +37,7 @@ using forward_return_t = std::conditional_t(derived()); * Structure to return adjoints from var and vari*. Deduces whether the variables * are pointers (i.e. vari*) to determine whether to return the adjoint or * first point to the underlying vari* (in the case of var). - */ + *//* struct adj_Op { EIGEN_EMPTY_STRUCT_CTOR(adj_Op); @@ -159,33 +150,33 @@ struct adj_Op { std::enable_if_t::value, reverse_return_t> operator()(T &v) const { return v.vi_->adj_; } -}; +};*/ /** * Coefficient-wise function applying adj_Op struct to a matrix of const var * and returning a const matrix of type T containing the values - */ + *//* inline const CwiseUnaryOp adj() const { return CwiseUnaryOp(derived()); } - +*/ /** * Coefficient-wise function applying adj_Op struct to a matrix of var * and returning a view to a matrix of doubles of the adjoints that can * be modified. This is meant to be used on the rhs of expressions. - */ + *//* inline CwiseUnaryOp adj_op() { return CwiseUnaryOp(derived()); -} +}*/ /** * Coefficient-wise function applying adj_Op struct to a matrix of var * and returning a view to a matrix of doubles of the adjoints that can * be modified - */ + *//* inline CwiseUnaryView adj() { return CwiseUnaryView(derived()); -} +}*/ /** * Structure to return vari* from a var. */ diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h new file mode 100644 index 00000000000..6cc246d3770 --- /dev/null +++ b/stan/math/prim/plugins/adj_view.h @@ -0,0 +1,107 @@ +template +EIGEN_DEVICE_FUNC +static inline const double& adj(const Scalar& x) { + return adj_impl>::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline const double& +adj_ref(const Scalar& x) { + return adj_ref_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline double& adj_ref(Scalar& x) { + return adj_ref_impl>::run(x); +} + +template +struct adj_default_impl { }; + +template +struct adj_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { + return x->adj_; + } + EIGEN_DEVICE_FUNC + static inline const double& run(const Scalar& x) { + return x->adj_; + } +}; + +template +struct adj_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { + return x.vi_->adj_; + } + EIGEN_DEVICE_FUNC + static inline const double& run(const Scalar& x) { + return x.vi_->adj_; + } +}; + +template +struct adj_impl : adj_default_impl {}; + +template +struct scalar_adj_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE double operator() (const Scalar& a) const { return adj(a); } +}; + + +template +struct adj_ref_default_impl { }; + +template +struct adj_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { + return *reinterpret_cast(&(x->adj_)); + } + EIGEN_DEVICE_FUNC + static inline const double& run(const Scalar& x) { + return *reinterpret_cast(&(x->adj_)); + } +}; + +template +struct adj_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { + return *reinterpret_cast(&(x.vi_->adj_)); + } + EIGEN_DEVICE_FUNC + static inline const double& run(const Scalar& x) { + return *reinterpret_cast(&(x.vi_->adj_)); + } +}; + +template +struct adj_ref_impl : adj_ref_default_impl {}; + +template +struct scalar_adj_ref_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_ref_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE double& operator() (const Scalar& a) const { + return adj_ref(*const_cast(&a)); + } +}; + +typedef CwiseUnaryOp, const Derived> AdjReturnType; +typedef CwiseUnaryView, Derived> NonConstAdjReturnType; + +EIGEN_DEVICE_FUNC +inline const AdjReturnType +adj() const { return AdjReturnType(derived()); } + + +EIGEN_DEVICE_FUNC +inline NonConstAdjReturnType +adj() { return NonConstAdjReturnType(derived()); } \ No newline at end of file diff --git a/stan/math/prim/plugins/typedefs.h b/stan/math/prim/plugins/typedefs.h new file mode 100644 index 00000000000..e6018c057aa --- /dev/null +++ b/stan/math/prim/plugins/typedefs.h @@ -0,0 +1,23 @@ +template +struct is_fvar : std::false_type +{ }; +template +struct is_fvar : std::true_type +{ }; + +template +struct is_var : std::false_type +{ }; +template +struct is_var>::vi_))> : std::true_type +{ }; + +template +struct is_vari : std::false_type +{ }; +template +struct is_vari>::adj_))> : std::true_type +{ }; + +template +using eigen_base_filter_t = typename internal::global_math_functions_filtering_base::type; \ No newline at end of file diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h new file mode 100644 index 00000000000..1f6615cf5bc --- /dev/null +++ b/stan/math/prim/plugins/val_view.h @@ -0,0 +1,137 @@ +template +struct val_return { }; + +template +struct val_return::value>> { + using type = double; +}; +template +struct val_return::value>> { + using type = decltype(T::d_); +}; + +template +using val_return_t = typename val_return>::type; + +template +EIGEN_DEVICE_FUNC +static inline val_return_t val(const Scalar& x){ + return val_impl>::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline const val_return_t& val_ref(const Scalar& x) { + return val_ref_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline +val_return_t& +val_ref(Scalar& x) { + return val_ref_impl>::run(x); +} + +template +struct val_default_impl { }; + +template +struct val_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double run(const Scalar& x) { + return x; + } +}; + +template +struct val_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double run(const Scalar& x) { + return x->val_; + } +}; + +template +struct val_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double run(const Scalar& x) { + return x.vi_->val_; + } +}; + +template +struct val_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline val_return_t run(const Scalar& x) { + return x.val_; + } +}; + +template struct val_impl : val_default_impl {}; + +template +struct scalar_val_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_val_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE val_return_t operator() (const Scalar& a) const { return val(a); } +}; + +template +struct val_ref_default_impl { }; + +template +struct val_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { + return x; + } + EIGEN_DEVICE_FUNC + static inline const double& run(const Scalar& x) { + return x; + } +}; + +template +struct val_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline val_return_t& run(Scalar& x) { + return *reinterpret_cast*>(&(x.val_)); + } + EIGEN_DEVICE_FUNC + static inline const val_return_t& run(const Scalar& x) { + return *reinterpret_cast*>(&(x.val_)); + } +}; + +template +struct val_ref_impl : val_ref_default_impl {}; + +template +struct scalar_val_ref_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_val_ref_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE + val_return_t& + operator() (const Scalar& a) const { + return val_ref(*const_cast(&a)); + } +}; + + +/** \internal the return type of imag() const */ +typedef CwiseUnaryOp, const Derived> valReturnType; +/** \internal the return type of imag() */ +typedef std::conditional_t::value || is_vari::value, + const valReturnType, + CwiseUnaryView, Derived>> +NonConstvalReturnType; + +EIGEN_DEVICE_FUNC +inline const valReturnType +val() const { return valReturnType(derived()); } + + +EIGEN_DEVICE_FUNC +inline NonConstvalReturnType +val() { return NonConstvalReturnType(derived()); } \ No newline at end of file From e2e50d5a20b9880bab9f23d172599eb04352f3c8 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 15 Oct 2021 11:45:28 +0800 Subject: [PATCH 02/27] Repalcements --- stan/math/opencl/rev/vari.hpp | 1 - stan/math/rev/core/operator_addition.hpp | 2 +- stan/math/rev/core/operator_division.hpp | 2 +- stan/math/rev/core/operator_subtraction.hpp | 2 +- stan/math/rev/core/var.hpp | 1 - stan/math/rev/core/vari.hpp | 3 --- stan/math/rev/fun/eigenvectors_sym.hpp | 2 +- stan/math/rev/fun/generalized_inverse.hpp | 6 +++--- stan/math/rev/fun/inverse.hpp | 2 +- stan/math/rev/fun/multiply.hpp | 8 ++++---- stan/math/rev/fun/svd_U.hpp | 4 ++-- stan/math/rev/fun/svd_V.hpp | 4 ++-- stan/math/rev/fun/tcrossprod.hpp | 2 +- 13 files changed, 17 insertions(+), 22 deletions(-) diff --git a/stan/math/opencl/rev/vari.hpp b/stan/math/opencl/rev/vari.hpp index cd35d14c0f0..42c038b5fd4 100644 --- a/stan/math/opencl/rev/vari.hpp +++ b/stan/math/opencl/rev/vari.hpp @@ -58,7 +58,6 @@ class vari_cl_base : public vari_base { */ inline auto& adj() { return adj_; } inline auto& adj() const { return adj_; } - inline auto& adj_op() { return adj_; } /** * Returns a view into a block of matrix. diff --git a/stan/math/rev/core/operator_addition.hpp b/stan/math/rev/core/operator_addition.hpp index 57d77baa661..3a290808b65 100644 --- a/stan/math/rev/core/operator_addition.hpp +++ b/stan/math/rev/core/operator_addition.hpp @@ -149,7 +149,7 @@ inline auto add(const VarMat& a, const Arith& b) { arena_t arena_a(a); arena_t ret(arena_a.val().array() + as_array_or_scalar(b)); reverse_pass_callback( - [ret, arena_a]() mutable { arena_a.adj() += ret.adj_op(); }); + [ret, arena_a]() mutable { arena_a.adj() += ret.adj(); }); return ret_type(ret); } diff --git a/stan/math/rev/core/operator_division.hpp b/stan/math/rev/core/operator_division.hpp index bf0307c6135..6a5519e76b6 100644 --- a/stan/math/rev/core/operator_division.hpp +++ b/stan/math/rev/core/operator_division.hpp @@ -140,7 +140,7 @@ inline auto divide(const Mat& m, Scalar c) { auto inv_c = (1.0 / value_of(c)); arena_t> res = inv_c * arena_m.val(); reverse_pass_callback([inv_c, arena_m, res]() mutable { - arena_m.adj().array() += inv_c * res.adj_op().array(); + arena_m.adj().array() += inv_c * res.adj().array(); }); return promote_scalar_t(res); } else { diff --git a/stan/math/rev/core/operator_subtraction.hpp b/stan/math/rev/core/operator_subtraction.hpp index e5c186cde97..dd3deb5b320 100644 --- a/stan/math/rev/core/operator_subtraction.hpp +++ b/stan/math/rev/core/operator_subtraction.hpp @@ -182,7 +182,7 @@ inline auto subtract(const Arith& a, const VarMat& b) { arena_t arena_b = b; arena_t ret(as_array_or_scalar(a) - arena_b.val().array()); reverse_pass_callback( - [ret, arena_b]() mutable { arena_b.adj() -= ret.adj_op(); }); + [ret, arena_b]() mutable { arena_b.adj() -= ret.adj(); }); return ret_type(ret); } diff --git a/stan/math/rev/core/var.hpp b/stan/math/rev/core/var.hpp index b51c06545d2..d6ed7479414 100644 --- a/stan/math/rev/core/var.hpp +++ b/stan/math/rev/core/var.hpp @@ -417,7 +417,6 @@ class var_value> { */ inline auto& adj() { return vi_->adj(); } inline auto& adj() const { return vi_->adj(); } - inline auto& adj_op() { return vi_->adj(); } inline Eigen::Index rows() const { return vi_->rows(); } inline Eigen::Index cols() const { return vi_->cols(); } diff --git a/stan/math/rev/core/vari.hpp b/stan/math/rev/core/vari.hpp index 2d2f230b3a0..cffd5066e86 100644 --- a/stan/math/rev/core/vari.hpp +++ b/stan/math/rev/core/vari.hpp @@ -610,7 +610,6 @@ class vari_view< */ inline auto& adj() { return adj_; } inline auto& adj() const { return adj_; } - inline auto& adj_op() { return adj_; } void set_zero_adjoint() {} void chain() {} @@ -722,7 +721,6 @@ class vari_value, is_eigen_dense_base>> */ inline auto& adj() { return adj_; } inline auto& adj() const { return adj_; } - inline auto& adj_op() { return adj_; } virtual void chain() {} /** @@ -874,7 +872,6 @@ class vari_value> : public vari_base, */ inline auto& adj() { return adj_; } inline auto& adj() const { return adj_; } - inline auto& adj_op() { return adj_; } void chain() {} /** diff --git a/stan/math/rev/fun/eigenvectors_sym.hpp b/stan/math/rev/fun/eigenvectors_sym.hpp index 8fc7a03b702..57d70bab1a0 100644 --- a/stan/math/rev/fun/eigenvectors_sym.hpp +++ b/stan/math/rev/fun/eigenvectors_sym.hpp @@ -43,7 +43,7 @@ inline auto eigenvectors_sym(const T& m) { f.diagonal().setZero(); arena_m.adj() += eigenvecs.val_op() - * f.cwiseProduct(eigenvecs.val_op().transpose() * eigenvecs.adj_op()) + * f.cwiseProduct(eigenvecs.val_op().transpose() * eigenvecs.adj()) * eigenvecs.val_op().transpose(); }); diff --git a/stan/math/rev/fun/generalized_inverse.hpp b/stan/math/rev/fun/generalized_inverse.hpp index 48a95deca23..e0514c8683c 100644 --- a/stan/math/rev/fun/generalized_inverse.hpp +++ b/stan/math/rev/fun/generalized_inverse.hpp @@ -23,14 +23,14 @@ template inline auto generalized_inverse_lambda(T1& G_arena, T2& inv_G) { return [G_arena, inv_G]() mutable { G_arena.adj() - += -(inv_G.val_op().transpose() * inv_G.adj_op() + += -(inv_G.val_op().transpose() * inv_G.adj() * inv_G.val_op().transpose()) + (-G_arena.val_op() * inv_G.val_op() + Eigen::MatrixXd::Identity(G_arena.rows(), inv_G.cols())) - * inv_G.adj_op().transpose() * inv_G.val_op() + * inv_G.adj().transpose() * inv_G.val_op() * inv_G.val_op().transpose() + inv_G.val_op().transpose() * inv_G.val_op() - * inv_G.adj_op().transpose() + * inv_G.adj().transpose() * (-inv_G.val_op() * G_arena.val_op() + Eigen::MatrixXd::Identity(inv_G.rows(), G_arena.cols())); }; diff --git a/stan/math/rev/fun/inverse.hpp b/stan/math/rev/fun/inverse.hpp index 655c871fb7b..68cfd513449 100644 --- a/stan/math/rev/fun/inverse.hpp +++ b/stan/math/rev/fun/inverse.hpp @@ -33,7 +33,7 @@ inline auto inverse(const T& m) { arena_t res = res_val; reverse_pass_callback([res, res_val, arena_m]() mutable { - arena_m.adj() -= res_val.transpose() * res.adj_op() * res_val.transpose(); + arena_m.adj() -= res_val.transpose() * res.adj() * res_val.transpose(); }); return ret_type(res); diff --git a/stan/math/rev/fun/multiply.hpp b/stan/math/rev/fun/multiply.hpp index c2fda8be407..c0374fb3938 100644 --- a/stan/math/rev/fun/multiply.hpp +++ b/stan/math/rev/fun/multiply.hpp @@ -40,8 +40,8 @@ inline auto multiply(const T1& A, const T2& B) { reverse_pass_callback( [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable { if (is_var_matrix::value || is_var_matrix::value) { - arena_A.adj() += res.adj_op() * arena_B_val.transpose(); - arena_B.adj() += arena_A_val.transpose() * res.adj_op(); + arena_A.adj() += res.adj() * arena_B_val.transpose(); + arena_B.adj() += arena_A_val.transpose() * res.adj(); } else { auto res_adj = res.adj().eval(); arena_A.adj() += res_adj * arena_B_val.transpose(); @@ -56,7 +56,7 @@ inline auto multiply(const T1& A, const T2& B) { = return_var_matrix_t; arena_t res = arena_A * arena_B.val_op(); reverse_pass_callback([arena_B, arena_A, res]() mutable { - arena_B.adj() += arena_A.transpose() * res.adj_op(); + arena_B.adj() += arena_A.transpose() * res.adj(); }); return return_t(res); } else { @@ -67,7 +67,7 @@ inline auto multiply(const T1& A, const T2& B) { T2>; arena_t res = arena_A.val_op() * arena_B; reverse_pass_callback([arena_A, arena_B, res]() mutable { - arena_A.adj() += res.adj_op() * arena_B.transpose(); + arena_A.adj() += res.adj() * arena_B.transpose(); }); return return_t(res); diff --git a/stan/math/rev/fun/svd_U.hpp b/stan/math/rev/fun/svd_U.hpp index 589c9c069a4..90556029655 100644 --- a/stan/math/rev/fun/svd_U.hpp +++ b/stan/math/rev/fun/svd_U.hpp @@ -52,7 +52,7 @@ inline auto svd_U(const EigMat& m) { reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp, M]() mutable { - Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op(); + Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj(); arena_m.adj() += .5 * arena_U.val_op() * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array()) @@ -60,7 +60,7 @@ inline auto svd_U(const EigMat& m) { * arena_V.transpose() + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows()) - arena_U.val_op() * arena_U.val_op().transpose()) - * arena_U.adj_op() * arena_D.asDiagonal().inverse() + * arena_U.adj() * arena_D.asDiagonal().inverse() * arena_V.transpose(); }); diff --git a/stan/math/rev/fun/svd_V.hpp b/stan/math/rev/fun/svd_V.hpp index cbd267b0f53..2689b6cb2af 100644 --- a/stan/math/rev/fun/svd_V.hpp +++ b/stan/math/rev/fun/svd_V.hpp @@ -52,14 +52,14 @@ inline auto svd_V(const EigMat& m) { reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fm, M]() mutable { - Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj_op(); + Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj(); arena_m.adj() += 0.5 * arena_U * (arena_Fm.array() * (VTVadj - VTVadj.transpose()).array()) .matrix() * arena_V.val_op().transpose() + arena_U * arena_D.asDiagonal().inverse() - * arena_V.adj_op().transpose() + * arena_V.adj().transpose() * (Eigen::MatrixXd::Identity(arena_m.cols(), arena_m.cols()) - arena_V.val_op() * arena_V.val_op().transpose()); }); diff --git a/stan/math/rev/fun/tcrossprod.hpp b/stan/math/rev/fun/tcrossprod.hpp index 5762b08ee20..6fd8d35561b 100644 --- a/stan/math/rev/fun/tcrossprod.hpp +++ b/stan/math/rev/fun/tcrossprod.hpp @@ -30,7 +30,7 @@ inline auto tcrossprod(const T& M) { if (likely(M.size() > 0)) { reverse_pass_callback([res, arena_M]() mutable { arena_M.adj() - += (res.adj_op() + res.adj_op().transpose()) * arena_M.val_op(); + += (res.adj() + res.adj().transpose()) * arena_M.val_op(); }); } From 6c5fc2df2c29a31d12bdc052dc33139f71b3a023 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 15 Oct 2021 12:12:54 +0800 Subject: [PATCH 03/27] coeffRef def --- stan/math/prim/eigen_plugins.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index e8e0eb5bfaa..2fd41d4858a 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -205,5 +205,13 @@ inline CwiseUnaryView vi() { return CwiseUnaryView(derived()); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +const Scalar& coeffRef(Index row, Index col) const { + eigen_internal_assert(row >= 0 && row < rows() + && col >= 0 && col < cols()); + return internal::evaluator(derived()).coeffRef(row, col); +} + #define EIGEN_STAN_MATRIXBASE_PLUGIN #define EIGEN_STAN_ARRAYBASE_PLUGIN From 86ca245889ba1318b26b57b2a03bd6fece324eaf Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 15 Oct 2021 12:16:32 +0800 Subject: [PATCH 04/27] Replace val_op --- stan/math/opencl/rev/vari.hpp | 2 +- stan/math/rev/core/var.hpp | 2 +- stan/math/rev/core/vari.hpp | 6 +-- stan/math/rev/fun/eigenvectors_sym.hpp | 6 +-- stan/math/rev/fun/generalized_inverse.hpp | 22 ++++----- stan/math/rev/fun/mdivide_left.hpp | 4 +- stan/math/rev/fun/mdivide_left_ldlt.hpp | 4 +- stan/math/rev/fun/mdivide_left_spd.hpp | 2 +- stan/math/rev/fun/multiply.hpp | 4 +- stan/math/rev/fun/sqrt.hpp | 2 +- stan/math/rev/fun/svd_U.hpp | 6 +-- stan/math/rev/fun/svd_V.hpp | 6 +-- stan/math/rev/fun/tcrossprod.hpp | 4 +- stan/math/rev/fun/trace.hpp | 2 +- .../rev/fun/trace_gen_inv_quad_form_ldlt.hpp | 36 +++++++------- stan/math/rev/fun/trace_gen_quad_form.hpp | 48 +++++++++---------- .../math/rev/fun/trace_inv_quad_form_ldlt.hpp | 4 +- 17 files changed, 80 insertions(+), 80 deletions(-) diff --git a/stan/math/opencl/rev/vari.hpp b/stan/math/opencl/rev/vari.hpp index 42c038b5fd4..423c371e4fd 100644 --- a/stan/math/opencl/rev/vari.hpp +++ b/stan/math/opencl/rev/vari.hpp @@ -46,7 +46,7 @@ class vari_cl_base : public vari_base { * @return The value of this vari. */ inline const auto& val() const { return val_; } - inline auto& val_op() { return val_; } + inline auto& val() { return val_; } /** * Return a reference to the derivative of the root expression with diff --git a/stan/math/rev/core/var.hpp b/stan/math/rev/core/var.hpp index d6ed7479414..1a6e7bd12fd 100644 --- a/stan/math/rev/core/var.hpp +++ b/stan/math/rev/core/var.hpp @@ -405,7 +405,7 @@ class var_value> { * @return The value of this variable. */ inline const auto& val() const { return vi_->val(); } - inline auto& val_op() { return vi_->val_op(); } + inline auto& val() { return vi_->val(); } /** * Return a reference to the derivative of the root expression with diff --git a/stan/math/rev/core/vari.hpp b/stan/math/rev/core/vari.hpp index cffd5066e86..23ec0bae968 100644 --- a/stan/math/rev/core/vari.hpp +++ b/stan/math/rev/core/vari.hpp @@ -598,7 +598,7 @@ class vari_view< * @return The value of this vari. */ inline const auto& val() const { return val_; } - inline auto& val_op() { return val_; } + inline auto& val() { return val_; } /** * Return a reference to the derivative of the root expression with @@ -709,7 +709,7 @@ class vari_value, is_eigen_dense_base>> * @return The value of this vari. */ inline const auto& val() const { return val_; } - inline auto& val_op() { return val_; } + inline auto& val() { return val_; } /** * Return a reference to the derivative of the root expression with @@ -860,7 +860,7 @@ class vari_value> : public vari_base, * @return The value of this vari. */ inline const auto& val() const { return val_; } - inline auto& val_op() { return val_; } + inline auto& val() { return val_; } /** * Return a reference to the derivative of the root expression with diff --git a/stan/math/rev/fun/eigenvectors_sym.hpp b/stan/math/rev/fun/eigenvectors_sym.hpp index 57d70bab1a0..1459b54b629 100644 --- a/stan/math/rev/fun/eigenvectors_sym.hpp +++ b/stan/math/rev/fun/eigenvectors_sym.hpp @@ -42,9 +42,9 @@ inline auto eigenvectors_sym(const T& m) { .array()); f.diagonal().setZero(); arena_m.adj() - += eigenvecs.val_op() - * f.cwiseProduct(eigenvecs.val_op().transpose() * eigenvecs.adj()) - * eigenvecs.val_op().transpose(); + += eigenvecs.val() + * f.cwiseProduct(eigenvecs.val().transpose() * eigenvecs.adj()) + * eigenvecs.val().transpose(); }); return return_t(eigenvecs); diff --git a/stan/math/rev/fun/generalized_inverse.hpp b/stan/math/rev/fun/generalized_inverse.hpp index e0514c8683c..283ba338e0f 100644 --- a/stan/math/rev/fun/generalized_inverse.hpp +++ b/stan/math/rev/fun/generalized_inverse.hpp @@ -23,15 +23,15 @@ template inline auto generalized_inverse_lambda(T1& G_arena, T2& inv_G) { return [G_arena, inv_G]() mutable { G_arena.adj() - += -(inv_G.val_op().transpose() * inv_G.adj() - * inv_G.val_op().transpose()) - + (-G_arena.val_op() * inv_G.val_op() + += -(inv_G.val().transpose() * inv_G.adj() + * inv_G.val().transpose()) + + (-G_arena.val() * inv_G.val() + Eigen::MatrixXd::Identity(G_arena.rows(), inv_G.cols())) - * inv_G.adj().transpose() * inv_G.val_op() - * inv_G.val_op().transpose() - + inv_G.val_op().transpose() * inv_G.val_op() + * inv_G.adj().transpose() * inv_G.val() + * inv_G.val().transpose() + + inv_G.val().transpose() * inv_G.val() * inv_G.adj().transpose() - * (-inv_G.val_op() * G_arena.val_op() + * (-inv_G.val() * G_arena.val() + Eigen::MatrixXd::Identity(inv_G.rows(), G_arena.cols())); }; } @@ -82,17 +82,17 @@ inline auto generalized_inverse(const VarMat& G) { } } else if (G.rows() < G.cols()) { arena_t G_arena(G); - arena_t inv_G((G_arena.val_op() * G_arena.val_op().transpose()) + arena_t inv_G((G_arena.val() * G_arena.val().transpose()) .ldlt() - .solve(G_arena.val_op()) + .solve(G_arena.val()) .transpose()); reverse_pass_callback(internal::generalized_inverse_lambda(G_arena, inv_G)); return ret_type(inv_G); } else { arena_t G_arena(G); - arena_t inv_G((G_arena.val_op().transpose() * G_arena.val_op()) + arena_t inv_G((G_arena.val().transpose() * G_arena.val()) .ldlt() - .solve(G_arena.val_op().transpose())); + .solve(G_arena.val().transpose())); reverse_pass_callback(internal::generalized_inverse_lambda(G_arena, inv_G)); return ret_type(inv_G); } diff --git a/stan/math/rev/fun/mdivide_left.hpp b/stan/math/rev/fun/mdivide_left.hpp index 8c0c88d8ed5..00293903497 100644 --- a/stan/math/rev/fun/mdivide_left.hpp +++ b/stan/math/rev/fun/mdivide_left.hpp @@ -51,7 +51,7 @@ inline auto mdivide_left(const T1& A, const T2& B) { .template triangularView() .transpose() .solve(res.adj()); - arena_A.adj() -= adjB * res.val_op().transpose(); + arena_A.adj() -= adjB * res.val().transpose(); arena_B.adj() += adjB; }); @@ -80,7 +80,7 @@ inline auto mdivide_left(const T1& A, const T2& B) { .template triangularView() .transpose() .solve(res.adj()) - * res.val_op().transpose(); + * res.val().transpose(); }); return ret_type(res); } diff --git a/stan/math/rev/fun/mdivide_left_ldlt.hpp b/stan/math/rev/fun/mdivide_left_ldlt.hpp index 5c40c81c6d2..ee21553a9d4 100644 --- a/stan/math/rev/fun/mdivide_left_ldlt.hpp +++ b/stan/math/rev/fun/mdivide_left_ldlt.hpp @@ -44,7 +44,7 @@ inline auto mdivide_left_ldlt(LDLT_factor& A, const T2& B) { reverse_pass_callback([arena_A, arena_B, ldlt_ptr, res]() mutable { promote_scalar_t adjB = ldlt_ptr->solve(res.adj()); - arena_A.adj() -= adjB * res.val_op().transpose(); + arena_A.adj() -= adjB * res.val().transpose(); arena_B.adj() += adjB; }); @@ -55,7 +55,7 @@ inline auto mdivide_left_ldlt(LDLT_factor& A, const T2& B) { const auto* ldlt_ptr = make_chainable_ptr(A.ldlt()); reverse_pass_callback([arena_A, ldlt_ptr, res]() mutable { - arena_A.adj() -= ldlt_ptr->solve(res.adj()) * res.val_op().transpose(); + arena_A.adj() -= ldlt_ptr->solve(res.adj()) * res.val().transpose(); }); return ret_type(res); diff --git a/stan/math/rev/fun/mdivide_left_spd.hpp b/stan/math/rev/fun/mdivide_left_spd.hpp index 754415ec337..93800fabd40 100644 --- a/stan/math/rev/fun/mdivide_left_spd.hpp +++ b/stan/math/rev/fun/mdivide_left_spd.hpp @@ -290,7 +290,7 @@ inline auto mdivide_left_spd(const T1 &A, const T2 &B) { .transpose() .solveInPlace(adjB); - arena_A.adj() -= adjB * res.val_op().transpose(); + arena_A.adj() -= adjB * res.val().transpose(); arena_B.adj() += adjB; }); diff --git a/stan/math/rev/fun/multiply.hpp b/stan/math/rev/fun/multiply.hpp index c0374fb3938..eaa4a1755ac 100644 --- a/stan/math/rev/fun/multiply.hpp +++ b/stan/math/rev/fun/multiply.hpp @@ -54,7 +54,7 @@ inline auto multiply(const T1& A, const T2& B) { arena_t> arena_B = B; using return_t = return_var_matrix_t; - arena_t res = arena_A * arena_B.val_op(); + arena_t res = arena_A * arena_B.val(); reverse_pass_callback([arena_B, arena_A, res]() mutable { arena_B.adj() += arena_A.transpose() * res.adj(); }); @@ -65,7 +65,7 @@ inline auto multiply(const T1& A, const T2& B) { using return_t = return_var_matrix_t; - arena_t res = arena_A.val_op() * arena_B; + arena_t res = arena_A.val() * arena_B; reverse_pass_callback([arena_A, arena_B, res]() mutable { arena_A.adj() += res.adj() * arena_B.transpose(); }); diff --git a/stan/math/rev/fun/sqrt.hpp b/stan/math/rev/fun/sqrt.hpp index 9d433da8307..53e74ea52a5 100644 --- a/stan/math/rev/fun/sqrt.hpp +++ b/stan/math/rev/fun/sqrt.hpp @@ -58,7 +58,7 @@ template * = nullptr> inline auto sqrt(const T& a) { return make_callback_var( a.val().array().sqrt().matrix(), [a](auto& vi) mutable { - a.adj().array() += vi.adj().array() / (2.0 * vi.val_op().array()); + a.adj().array() += vi.adj().array() / (2.0 * vi.val().array()); }); } diff --git a/stan/math/rev/fun/svd_U.hpp b/stan/math/rev/fun/svd_U.hpp index 90556029655..80654e32c29 100644 --- a/stan/math/rev/fun/svd_U.hpp +++ b/stan/math/rev/fun/svd_U.hpp @@ -52,14 +52,14 @@ inline auto svd_U(const EigMat& m) { reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp, M]() mutable { - Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj(); + Eigen::MatrixXd UUadjT = arena_U.val().transpose() * arena_U.adj(); arena_m.adj() - += .5 * arena_U.val_op() + += .5 * arena_U.val() * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array()) .matrix() * arena_V.transpose() + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows()) - - arena_U.val_op() * arena_U.val_op().transpose()) + - arena_U.val() * arena_U.val().transpose()) * arena_U.adj() * arena_D.asDiagonal().inverse() * arena_V.transpose(); }); diff --git a/stan/math/rev/fun/svd_V.hpp b/stan/math/rev/fun/svd_V.hpp index 2689b6cb2af..7e460747969 100644 --- a/stan/math/rev/fun/svd_V.hpp +++ b/stan/math/rev/fun/svd_V.hpp @@ -52,16 +52,16 @@ inline auto svd_V(const EigMat& m) { reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fm, M]() mutable { - Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj(); + Eigen::MatrixXd VTVadj = arena_V.val().transpose() * arena_V.adj(); arena_m.adj() += 0.5 * arena_U * (arena_Fm.array() * (VTVadj - VTVadj.transpose()).array()) .matrix() - * arena_V.val_op().transpose() + * arena_V.val().transpose() + arena_U * arena_D.asDiagonal().inverse() * arena_V.adj().transpose() * (Eigen::MatrixXd::Identity(arena_m.cols(), arena_m.cols()) - - arena_V.val_op() * arena_V.val_op().transpose()); + - arena_V.val() * arena_V.val().transpose()); }); return ret_type(arena_V); diff --git a/stan/math/rev/fun/tcrossprod.hpp b/stan/math/rev/fun/tcrossprod.hpp index 6fd8d35561b..f683186b6c6 100644 --- a/stan/math/rev/fun/tcrossprod.hpp +++ b/stan/math/rev/fun/tcrossprod.hpp @@ -25,12 +25,12 @@ inline auto tcrossprod(const T& M) { using ret_type = return_var_matrix_t< Eigen::Matrix, T>; arena_t arena_M = M; - arena_t res = arena_M.val_op() * arena_M.val_op().transpose(); + arena_t res = arena_M.val() * arena_M.val().transpose(); if (likely(M.size() > 0)) { reverse_pass_callback([res, arena_M]() mutable { arena_M.adj() - += (res.adj() + res.adj().transpose()) * arena_M.val_op(); + += (res.adj() + res.adj().transpose()) * arena_M.val(); }); } diff --git a/stan/math/rev/fun/trace.hpp b/stan/math/rev/fun/trace.hpp index 362e798bf2e..137c109f225 100644 --- a/stan/math/rev/fun/trace.hpp +++ b/stan/math/rev/fun/trace.hpp @@ -23,7 +23,7 @@ template * = nullptr> inline auto trace(const T& m) { arena_t arena_m = m; - return make_callback_var(arena_m.val_op().trace(), + return make_callback_var(arena_m.val().trace(), [arena_m](const auto& vi) mutable { arena_m.adj().diagonal().array() += vi.adj(); }); diff --git a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp index f88802af48d..01df7e0a222 100644 --- a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp +++ b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp @@ -46,7 +46,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, arena_t> arena_B = B; arena_t> arena_D = D; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB); + auto BTAsolveB = to_arena(arena_B.val().transpose() * AsolveB); var res = (arena_D.val() * BTAsolveB).trace(); @@ -54,10 +54,10 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); arena_B.adj() += C_adj * AsolveB - * (arena_D.val_op() + arena_D.val_op().transpose()); + * (arena_D.val() + arena_D.val().transpose()); arena_D.adj() += C_adj * BTAsolveB; }); @@ -69,7 +69,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace(); + var res = (arena_D * arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); @@ -94,7 +94,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); arena_D.adj() += C_adj * BTAsolveB; }); @@ -112,7 +112,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); }); @@ -122,7 +122,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, arena_t> arena_B = B; arena_t> arena_D = D; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB); + auto BTAsolveB = to_arena(arena_B.val().transpose() * AsolveB); var res = (arena_D.val() * BTAsolveB).trace(); @@ -131,7 +131,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, double C_adj = res.adj(); arena_B.adj() += C_adj * AsolveB - * (arena_D.val_op() + arena_D.val_op().transpose()); + * (arena_D.val() + arena_D.val().transpose()); arena_D.adj() += C_adj * BTAsolveB; }); @@ -142,7 +142,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace(); + var res = (arena_D * arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable { arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose()); @@ -202,7 +202,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_B = B; arena_t> arena_D = D; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB); + auto BTAsolveB = to_arena(arena_B.val().transpose() * AsolveB); var res = (arena_D.val().asDiagonal() * BTAsolveB).trace(); @@ -210,9 +210,9 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); - arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal(); + arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val().asDiagonal(); arena_D.adj() += C_adj * BTAsolveB.diagonal(); }); @@ -224,7 +224,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB) + var res = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB) .trace(); reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable { @@ -250,7 +250,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); arena_D.adj() += C_adj * BTAsolveB.diagonal(); }); @@ -269,7 +269,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal() + arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); }); @@ -279,7 +279,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_B = B; arena_t> arena_D = D; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB); + auto BTAsolveB = to_arena(arena_B.val().transpose() * AsolveB); var res = (arena_D.val().asDiagonal() * BTAsolveB).trace(); @@ -287,7 +287,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); - arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal(); + arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val().asDiagonal(); arena_D.adj() += C_adj * BTAsolveB.diagonal(); }); @@ -298,7 +298,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB) + var res = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB) .trace(); reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable { diff --git a/stan/math/rev/fun/trace_gen_quad_form.hpp b/stan/math/rev/fun/trace_gen_quad_form.hpp index 13d74683ba2..0b0751fd6c6 100644 --- a/stan/math/rev/fun/trace_gen_quad_form.hpp +++ b/stan/math/rev/fun/trace_gen_quad_form.hpp @@ -148,8 +148,8 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_t> arena_A = A; arena_t> arena_B = B; - auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op().transpose()); - auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op()); + auto arena_BDT = to_arena(arena_B.val() * arena_D.val().transpose()); + auto arena_AB = to_arena(arena_A.val() * arena_B.val()); var res = (arena_BDT.transpose() * arena_AB).trace(); @@ -157,13 +157,13 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]() mutable { double C_adj = res.adj(); - arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose(); + arena_A.adj() += C_adj * arena_BDT * arena_B.val().transpose(); arena_B.adj() += C_adj - * (arena_AB * arena_D.val_op() - + arena_A.val_op().transpose() * arena_BDT); + * (arena_AB * arena_D.val() + + arena_A.val().transpose() * arena_BDT); - arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op()); + arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val()); }); return res; @@ -173,8 +173,8 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_t> arena_A = A; arena_t> arena_B = B; - auto arena_BDT = to_arena(arena_B.val_op() * arena_D.transpose()); - auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op()); + auto arena_BDT = to_arena(arena_B.val() * arena_D.transpose()); + auto arena_AB = to_arena(arena_A.val() * arena_B.val()); var res = (arena_BDT.transpose() * arena_AB).trace(); @@ -182,10 +182,10 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { res]() mutable { double C_adj = res.adj(); - arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose(); + arena_A.adj() += C_adj * arena_BDT * arena_B.val().transpose(); arena_B.adj() += C_adj - * (arena_AB * arena_D + arena_A.val_op().transpose() * arena_BDT); + * (arena_AB * arena_D + arena_A.val().transpose() * arena_BDT); }); return res; @@ -195,10 +195,10 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_t> arena_A = A; arena_t> arena_B = value_of(B); - auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op().transpose()); - auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op()); + auto arena_BDT = to_arena(arena_B.val() * arena_D.val().transpose()); + auto arena_AB = to_arena(arena_A.val() * arena_B.val()); - var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).trace(); + var res = (arena_BDT.transpose() * arena_A.val() * arena_B).trace(); reverse_pass_callback( [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]() mutable { @@ -217,10 +217,10 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { auto arena_BDT = to_arena(arena_B * arena_D); - var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).trace(); + var res = (arena_BDT.transpose() * arena_A.val() * arena_B).trace(); reverse_pass_callback([arena_A, arena_B, arena_BDT, res]() mutable { - arena_A.adj() += res.adj() * arena_BDT * arena_B.val_op().transpose(); + arena_A.adj() += res.adj() * arena_BDT * arena_B.val().transpose(); }); return res; @@ -230,8 +230,8 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_t> arena_A = value_of(A); arena_t> arena_B = B; - auto arena_AB = to_arena(arena_A * arena_B.val_op()); - auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op()); + auto arena_AB = to_arena(arena_A * arena_B.val()); + auto arena_BDT = to_arena(arena_B.val() * arena_D.val()); var res = (arena_BDT.transpose() * arena_AB).trace(); @@ -241,9 +241,9 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_B.adj() += C_adj - * (arena_AB * arena_D.val_op() + arena_A.transpose() * arena_BDT); + * (arena_AB * arena_D.val() + arena_A.transpose() * arena_BDT); - arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op()); + arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val()); }); return res; @@ -253,16 +253,16 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { arena_t> arena_A = value_of(A); arena_t> arena_B = B; - auto arena_AB = to_arena(arena_A * arena_B.val_op()); - auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op()); + auto arena_AB = to_arena(arena_A * arena_B.val()); + auto arena_BDT = to_arena(arena_B.val() * arena_D.val()); var res = (arena_BDT.transpose() * arena_AB).trace(); reverse_pass_callback( [arena_A, arena_B, arena_D, arena_AB, arena_BDT, res]() mutable { arena_B.adj() += res.adj() - * (arena_AB * arena_D.val_op() - + arena_A.val_op().transpose() * arena_BDT); + * (arena_AB * arena_D.val() + + arena_A.val().transpose() * arena_BDT); }); return res; @@ -274,7 +274,7 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { auto arena_AB = to_arena(arena_A * arena_B); - var res = (arena_D.val_op() * arena_B.transpose() * arena_AB).trace(); + var res = (arena_D.val() * arena_B.transpose() * arena_AB).trace(); reverse_pass_callback([arena_AB, arena_B, arena_D, res]() mutable { arena_D.adj() += res.adj() * (arena_AB.transpose() * arena_B); diff --git a/stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp b/stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp index 207e029768c..bc75bf7bd02 100644 --- a/stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp +++ b/stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp @@ -40,7 +40,7 @@ inline var trace_inv_quad_form_ldlt(LDLT_factor& A, const T2& B) { arena_t> arena_B = B; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_B.val_op().transpose() * AsolveB).trace(); + var res = (arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([arena_A, AsolveB, arena_B, res]() mutable { arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose(); @@ -65,7 +65,7 @@ inline var trace_inv_quad_form_ldlt(LDLT_factor& A, const T2& B) { arena_t> arena_B = B; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_B.val_op().transpose() * AsolveB).trace(); + var res = (arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([AsolveB, arena_B, res]() mutable { arena_B.adj() += 2 * res.adj() * AsolveB; From 59e535e8423c47c9b44e17f6d7c5cf992b26c83f Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 15 Oct 2021 16:09:27 +0800 Subject: [PATCH 05/27] Identified broken tests --- stan/math/prim/eigen_plugins.h | 205 +--------------------- stan/math/prim/plugins/d_view.h | 86 +++++++++ stan/math/prim/plugins/vi_view.h | 88 ++++++++++ stan/math/rev/fun/multiply.hpp | 8 +- test/unit/math/mix/fun/multiply1_test.cpp | 30 ++-- 5 files changed, 196 insertions(+), 221 deletions(-) create mode 100644 stan/math/prim/plugins/d_view.h create mode 100644 stan/math/prim/plugins/vi_view.h diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index 2fd41d4858a..0767eeb7cd8 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -1,209 +1,10 @@ #include "plugins/typedefs.h" #include "plugins/adj_view.h" -//#include "plugins/val_view.h" +#include "plugins/val_view.h" +#include "plugins/d_view.h" +#include "plugins/vi_view.h" -//TODO(Andrew): Replace std::is_const<>::value with std::is_const_v<> after move to C++17 -template -using double_return_t = std::conditional_t>::value, - const double, - double>; -template -using reverse_return_t = std::conditional_t>::value, - const double&, - double&>; - -template -using vari_return_t = std::conditional_t>::value, - const decltype(T::vi_)&, - decltype(T::vi_)&>; - -template -using forward_return_t = std::conditional_t>::value, - const decltype(T::val_)&, - decltype(T::val_)&>; - -/** - * Structure to return a view to the values in a var, vari*, and fvar. - * To identify the correct member to call for a given input, templates - * check a combination of whether the input is a pointer (i.e. vari*) - * and/or whether the input has member ".d_" (i.e. fvar). - * - * There are two methods for returning doubles unchanged. One which takes a reference - * to a double and returns the same reference, used when 'chaining' methods - * (i.e. A.adj().val()). The other for passing and returning by value, used directly - * with matrices of doubles (i.e. A.val(), where A is of type MatrixXd). - * - * For definitions of EIGEN_EMPTY_STRUCT_CTOR, EIGEN_DEVICE_FUNC, and - * EIGEN_STRONG_INLINE; see: https://eigen.tuxfamily.org/dox/XprHelper_8h_source.html - */ - -struct val_Op{ - EIGEN_EMPTY_STRUCT_CTOR(val_Op); - - //Returns value from a vari* - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, const double&> - operator()(T &v) const { return v->val_; } - - //Returns value from a var - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t<(!std::is_pointer::value && !is_fvar::value - && !std::is_arithmetic::value), const double&> - operator()(T &v) const { return v.vi_->val_; } - - //Returns value from an fvar - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, forward_return_t> - operator()(T &v) const { return v.val_; } - - //Returns double unchanged from input (by value) - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, double_return_t> - operator()(T v) const { return v; } - - //Returns double unchanged from input (by reference) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const double& operator()(const double& v) const { return v; } - - //Returns double unchanged from input (by reference) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - double& operator()(double& v) const { return v; } -}; - -/** - * Coefficient-wise function applying val_Op struct to a matrix of const var - * or vari* and returning a view to the const matrix of doubles containing - * the values - */ -inline const CwiseUnaryOp -val() const { return CwiseUnaryOp(derived()); -} - -/** - * Coefficient-wise function applying val_Op struct to a matrix of var - * or vari* and returning a view to the matrix of doubles containing - * the values - */ -inline CwiseUnaryOp -val_op() { return CwiseUnaryOp(derived()); -} - -/** - * Coefficient-wise function applying val_Op struct to a matrix of var - * or vari* and returning a view to the values - */ -inline CwiseUnaryView -val() { return CwiseUnaryView(derived()); -} - -/** - * Structure to return tangent from an fvar. - */ -struct d_Op { - EIGEN_EMPTY_STRUCT_CTOR(d_Op); - - //Returns tangent from an fvar - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - forward_return_t operator()(T &v) const { return v.d_; } -}; - -/** - * Coefficient-wise function applying d_Op struct to a matrix of const fvar - * and returning a const matrix of type T containing the tangents - */ -inline const CwiseUnaryOp -d() const { return CwiseUnaryOp(derived()); -} - -/** - * Coefficient-wise function applying d_Op struct to a matrix of fvar - * and returning a view to a matrix of type T of the tangents that can - * be modified - */ -inline CwiseUnaryView -d() { return CwiseUnaryView(derived()); -} - -/** - * Structure to return adjoints from var and vari*. Deduces whether the variables - * are pointers (i.e. vari*) to determine whether to return the adjoint or - * first point to the underlying vari* (in the case of var). - *//* -struct adj_Op { - EIGEN_EMPTY_STRUCT_CTOR(adj_Op); - - //Returns adjoint from a vari* - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, reverse_return_t> - operator()(T &v) const { return v->adj_; } - - //Returns adjoint from a var - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, reverse_return_t> - operator()(T &v) const { return v.vi_->adj_; } - -};*/ - -/** - * Coefficient-wise function applying adj_Op struct to a matrix of const var - * and returning a const matrix of type T containing the values - *//* -inline const CwiseUnaryOp -adj() const { return CwiseUnaryOp(derived()); -} -*/ -/** - * Coefficient-wise function applying adj_Op struct to a matrix of var - * and returning a view to a matrix of doubles of the adjoints that can - * be modified. This is meant to be used on the rhs of expressions. - *//* -inline CwiseUnaryOp adj_op() { - return CwiseUnaryOp(derived()); -}*/ - -/** - * Coefficient-wise function applying adj_Op struct to a matrix of var - * and returning a view to a matrix of doubles of the adjoints that can - * be modified - *//* -inline CwiseUnaryView -adj() { return CwiseUnaryView(derived()); -}*/ -/** - * Structure to return vari* from a var. - */ -struct vi_Op { - EIGEN_EMPTY_STRUCT_CTOR(vi_Op); - - //Returns vari* from a var - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - vari_return_t operator()(T &v) const { return v.vi_; } -}; - -/** - * Coefficient-wise function applying vi_Op struct to a matrix of const var - * and returning a const matrix of vari* - */ -inline const CwiseUnaryOp -vi() const { return CwiseUnaryOp(derived()); -} - -/** - * Coefficient-wise function applying vi_Op struct to a matrix of var - * and returning a view to a matrix of vari* that can be modified - */ -inline CwiseUnaryView -vi() { return CwiseUnaryView(derived()); -} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h new file mode 100644 index 00000000000..dc6440b0666 --- /dev/null +++ b/stan/math/prim/plugins/d_view.h @@ -0,0 +1,86 @@ +template +using d_return_t = decltype(T::d_); + +template +EIGEN_DEVICE_FUNC +static inline const d_return_t& d(const Scalar& x) { + return d_impl>::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline const d_return_t& +d_ref(const Scalar& x) { + return d_ref_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline d_return_t& d_ref(Scalar& x) { + return d_ref_impl>::run(x); +} + +template +struct d_default_impl { }; + +template +struct d_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline d_return_t& run(Scalar& x) { + return x.d_; + } + EIGEN_DEVICE_FUNC + static inline const d_return_t& run(const Scalar& x) { + return x.d_; + } +}; + +template +struct d_impl : d_default_impl {}; + +template +struct scalar_d_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_d_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE d_return_t operator() (const Scalar& a) const { return d(a); } +}; + + +template +struct d_ref_default_impl { }; + +template +struct d_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline d_return_t& run(Scalar& x) { + return *reinterpret_cast*>(&(x.d_)); + } + EIGEN_DEVICE_FUNC + static inline const d_return_t& run(const Scalar& x) { + return *reinterpret_cast*>(&(x.d_)); + } +}; + +template +struct d_ref_impl : d_ref_default_impl {}; + +template +struct scalar_d_ref_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_d_ref_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE d_return_t& operator() (const Scalar& a) const { + return d_ref(*const_cast(&a)); + } +}; + +typedef CwiseUnaryOp, const Derived> dReturnType; +typedef CwiseUnaryView, Derived> NonConstdReturnType; + +EIGEN_DEVICE_FUNC +inline const dReturnType +d() const { return dReturnType(derived()); } + + +EIGEN_DEVICE_FUNC +inline NonConstdReturnType +d() { return NonConstdReturnType(derived()); } \ No newline at end of file diff --git a/stan/math/prim/plugins/vi_view.h b/stan/math/prim/plugins/vi_view.h new file mode 100644 index 00000000000..da1e8cfbfcd --- /dev/null +++ b/stan/math/prim/plugins/vi_view.h @@ -0,0 +1,88 @@ +template +using vi_return_t = decltype(T::vi_); + +template +EIGEN_DEVICE_FUNC +static inline const vi_return_t& vi(const Scalar& x) { + return vi_impl>::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline const vi_return_t& +vi_ref(const Scalar& x) { + return vi_ref_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline vi_return_t& vi_ref(Scalar& x) { + return vi_ref_impl>::run(x); +} + +template +struct vi_default_impl { }; + +template +struct vi_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline vi_return_t& run(Scalar& x) { + return x.vi_; + } + EIGEN_DEVICE_FUNC + static inline const vi_return_t& run(const Scalar& x) { + return x.vi_; + } +}; + +template +struct vi_impl : vi_default_impl {}; + +template +struct scalar_vi_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE vi_return_t operator() (const Scalar& a) const { + return vi(a); + } +}; + + +template +struct vi_ref_default_impl { }; + +template +struct vi_ref_default_impl::value>> { + EIGEN_DEVICE_FUNC + static inline vi_return_t& run(Scalar& x) { + return *reinterpret_cast*>(&(x.vi_)); + } + EIGEN_DEVICE_FUNC + static inline const vi_return_t& run(const Scalar& x) { + return *reinterpret_cast*>(&(x.vi_)); + } +}; + +template +struct vi_ref_impl : vi_ref_default_impl {}; + +template +struct scalar_vi_ref_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_ref_op) + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE vi_return_t& operator() (const Scalar& a) const { + return vi_ref(*const_cast(&a)); + } +}; + +typedef CwiseUnaryOp, const Derived> viReturnType; +typedef CwiseUnaryView, Derived> NonConstviReturnType; + +EIGEN_DEVICE_FUNC +inline const viReturnType +vi() const { return viReturnType(derived()); } + + +EIGEN_DEVICE_FUNC +inline NonConstviReturnType +vi() { return NonConstviReturnType(derived()); } \ No newline at end of file diff --git a/stan/math/rev/fun/multiply.hpp b/stan/math/rev/fun/multiply.hpp index eaa4a1755ac..e8cd8daa3d1 100644 --- a/stan/math/rev/fun/multiply.hpp +++ b/stan/math/rev/fun/multiply.hpp @@ -40,12 +40,12 @@ inline auto multiply(const T1& A, const T2& B) { reverse_pass_callback( [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable { if (is_var_matrix::value || is_var_matrix::value) { - arena_A.adj() += res.adj() * arena_B_val.transpose(); - arena_B.adj() += arena_A_val.transpose() * res.adj(); + arena_A.adj() += res.adj() * arena_B_val.transpose().eval(); + arena_B.adj() += arena_A_val.transpose().eval() * res.adj(); } else { auto res_adj = res.adj().eval(); - arena_A.adj() += res_adj * arena_B_val.transpose(); - arena_B.adj() += arena_A_val.transpose() * res_adj; + arena_A.adj() += res_adj * arena_B_val.transpose().eval(); + arena_B.adj() += arena_A_val.transpose().eval() * res_adj; } }); return return_t(res); diff --git a/test/unit/math/mix/fun/multiply1_test.cpp b/test/unit/math/mix/fun/multiply1_test.cpp index 46cc5988319..453db61e70c 100644 --- a/test/unit/math/mix/fun/multiply1_test.cpp +++ b/test/unit/math/mix/fun/multiply1_test.cpp @@ -73,28 +73,28 @@ TEST(mathMixMatFun, multiply) { rv << 100, -3; Eigen::MatrixXd m(2, 2); m << 100, 0, -3, 4; - stan::test::expect_ad(tols, f, a, v); + stan::test::expect_ad(tols, f, a, v); stan::test::expect_ad(tols, f, v, a); stan::test::expect_ad(tols, f, a, rv); stan::test::expect_ad(tols, f, rv, a); stan::test::expect_ad(tols, f, rv, v); - stan::test::expect_ad(tols, f, v, rv); + //stan::test::expect_ad(tols, f, v, rv); stan::test::expect_ad(tols, f, a, m); stan::test::expect_ad(tols, f, m, a); - stan::test::expect_ad(tols, f, m, v); - stan::test::expect_ad(tols, f, rv, m); + //stan::test::expect_ad(tols, f, m, v); + //stan::test::expect_ad(tols, f, rv, m); stan::test::expect_ad(tols, f, m, m); - + stan::test::expect_ad_matvar(tols, f, a, v); stan::test::expect_ad_matvar(tols, f, v, a); stan::test::expect_ad_matvar(tols, f, a, rv); stan::test::expect_ad_matvar(tols, f, rv, a); stan::test::expect_ad_matvar(tols, f, rv, v); - stan::test::expect_ad_matvar(tols, f, v, rv); + // stan::test::expect_ad_matvar(tols, f, v, rv); stan::test::expect_ad_matvar(tols, f, a, m); stan::test::expect_ad_matvar(tols, f, m, a); - stan::test::expect_ad_matvar(tols, f, m, v); - stan::test::expect_ad_matvar(tols, f, rv, m); + //stan::test::expect_ad_matvar(tols, f, m, v); + //stan::test::expect_ad_matvar(tols, f, rv, m); stan::test::expect_ad_matvar(tols, f, m, m); Eigen::RowVectorXd d1(3); @@ -102,10 +102,10 @@ TEST(mathMixMatFun, multiply) { Eigen::VectorXd d2(3); d2 << 4, -2, -1; stan::test::expect_ad(tols, f, d1, d2); - stan::test::expect_ad(tols, f, d2, d1); + //stan::test::expect_ad(tols, f, d2, d1); stan::test::expect_ad_matvar(tols, f, d1, d2); - stan::test::expect_ad_matvar(tols, f, d2, d1); + //stan::test::expect_ad_matvar(tols, f, d2, d1); Eigen::MatrixXd u(3, 2); u << 1, 3, -5, 4, -2, -1; @@ -116,13 +116,13 @@ TEST(mathMixMatFun, multiply) { rvv << -2, 4, 1; stan::test::expect_ad(tols, f, u, u_tr); stan::test::expect_ad(tols, f, u_tr, u); - stan::test::expect_ad(tols, f, u, vv); - stan::test::expect_ad(tols, f, rvv, u); - + //stan::test::expect_ad(tols, f, u, vv); + //stan::test::expect_ad(tols, f, rvv, u); + stan::test::expect_ad_matvar(tols, f, u, u_tr); stan::test::expect_ad_matvar(tols, f, u_tr, u); - stan::test::expect_ad_matvar(tols, f, u, vv); - stan::test::expect_ad_matvar(tols, f, rvv, u); + //stan::test::expect_ad_matvar(tols, f, u, vv); + //stan::test::expect_ad_matvar(tols, f, rvv, u); // exception cases // can't compile mismatched dimensions, so no tests From 4b56e56c64675607a021b30a478957d4ecfe955d Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 17 Oct 2021 22:44:44 +0800 Subject: [PATCH 06/27] Fix stride detection --- .../Eigen/src/Core/CoreEvaluators.h | 8 ++--- .../Eigen/src/Core/CwiseUnaryView.h | 36 +++++++++---------- .../Eigen/src/Core/util/ForwardDeclarations.h | 2 +- stan/math/prim/plugins/adj_view.h | 21 ++++++++++- stan/math/prim/plugins/d_view.h | 13 ++++++- stan/math/prim/plugins/val_view.h | 26 +++++++++++++- test/unit/math/mix/fun/multiply1_test.cpp | 24 ++++++------- test/unit/math/rev/eigen_plugins_test.cpp | 1 + 8 files changed, 93 insertions(+), 38 deletions(-) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h b/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h index 910889efa70..f45dfe46f1a 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h @@ -743,11 +743,11 @@ struct binary_evaluator, IndexBased, IndexBase // -------------------- CwiseUnaryView -------------------- -template -struct unary_evaluator, IndexBased> - : evaluator_base > +template +struct unary_evaluator, IndexBased> + : evaluator_base > { - typedef CwiseUnaryView XprType; + typedef CwiseUnaryView XprType; enum { CoeffReadCost = evaluator::CoeffReadCost + functor_traits::Cost, diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h index 5a30fa8df18..953852583e6 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h @@ -13,8 +13,8 @@ namespace Eigen { namespace internal { -template -struct traits > +template +struct traits > : traits { typedef typename result_of< @@ -30,15 +30,15 @@ struct traits > // "error: no integral type can represent all of the enumerator values InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic ? int(Dynamic) - : int(MatrixTypeInnerStride) * int(sizeof(typename traits::Scalar) / sizeof(Scalar)), + : int(MatrixTypeInnerStride) * InnerStride, OuterStrideAtCompileTime = outer_stride_at_compile_time::ret == Dynamic ? int(Dynamic) - : outer_stride_at_compile_time::ret * int(sizeof(typename traits::Scalar) / sizeof(Scalar)) + : outer_stride_at_compile_time::ret * OuterStride }; }; } -template +template class CwiseUnaryViewImpl; /** \class CwiseUnaryView @@ -54,12 +54,12 @@ class CwiseUnaryViewImpl; * * \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp */ -template -class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind> +template +class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind> { public: - typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; + typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) typedef typename internal::ref_selector::non_const_type MatrixTypeNested; typedef typename internal::remove_all::type NestedExpression; @@ -89,22 +89,22 @@ class CwiseUnaryView : public CwiseUnaryViewImpl +template class CwiseUnaryViewImpl - : public internal::generic_xpr_base >::type + : public internal::generic_xpr_base >::type { public: - typedef typename internal::generic_xpr_base >::type Base; + typedef typename internal::generic_xpr_base >::type Base; }; -template -class CwiseUnaryViewImpl - : public internal::dense_xpr_base< CwiseUnaryView >::type +template +class CwiseUnaryViewImpl + : public internal::dense_xpr_base< CwiseUnaryView >::type { public: - typedef CwiseUnaryView Derived; - typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; + typedef CwiseUnaryView Derived; + typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) @@ -114,12 +114,12 @@ class CwiseUnaryViewImpl EIGEN_DEVICE_FUNC inline Index innerStride() const { - return derived().nestedExpression().innerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); + return derived().nestedExpression().innerStride() * InnerStride; } EIGEN_DEVICE_FUNC inline Index outerStride() const { - return derived().nestedExpression().outerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); + return derived().nestedExpression().outerStride() * OuterStride; } protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h b/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h index 134544f9643..ed8a8a7c610 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h @@ -85,7 +85,7 @@ template class Transpose; template class Conjugate; template class CwiseNullaryOp; template class CwiseUnaryOp; -template class CwiseUnaryView; +template class CwiseUnaryView; template class CwiseBinaryOp; template class CwiseTernaryOp; template class Solve; diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 6cc246d3770..8b69f785f41 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -94,8 +94,27 @@ struct scalar_adj_ref_op { } }; +template +struct adj_stride { + static constexpr int stride = 1; +}; + +template +struct adj_stride::value>> { + using vari_t = std::remove_pointer_t; + static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); +}; + +template +struct adj_stride::value>> { + using vari_t = std::remove_pointer_t().vi_)>; + static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); +}; + typedef CwiseUnaryOp, const Derived> AdjReturnType; -typedef CwiseUnaryView, Derived> NonConstAdjReturnType; +typedef CwiseUnaryView, Derived, + adj_stride::stride, + adj_stride::stride> NonConstAdjReturnType; EIGEN_DEVICE_FUNC inline const AdjReturnType diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index dc6440b0666..95c70c9c31e 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -73,8 +73,19 @@ struct scalar_d_ref_op { } }; +template +struct d_stride {}; + +template +struct d_stride::value>> { + using fvar_t = std::remove_pointer_t; + static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); +}; + typedef CwiseUnaryOp, const Derived> dReturnType; -typedef CwiseUnaryView, Derived> NonConstdReturnType; +typedef CwiseUnaryView, Derived, + d_stride::stride, + d_stride::stride> NonConstdReturnType; EIGEN_DEVICE_FUNC inline const dReturnType diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index 1f6615cf5bc..0cf893957dc 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -118,13 +118,37 @@ struct scalar_val_ref_op { } }; +template +struct val_stride { + static constexpr int stride = 1; +}; + +template +struct val_stride::value>> { + using vari_t = std::remove_pointer_t; + static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); +}; + +template +struct val_stride::value>> { + using vari_t = std::remove_pointer_t().vi_)>; + static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); +}; + +template +struct val_stride::value>> { + using fvar_t = typename T::Scalar; + static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); +}; /** \internal the return type of imag() const */ typedef CwiseUnaryOp, const Derived> valReturnType; /** \internal the return type of imag() */ typedef std::conditional_t::value || is_vari::value, const valReturnType, - CwiseUnaryView, Derived>> + CwiseUnaryView, Derived, + val_stride::stride, + val_stride::stride>> NonConstvalReturnType; EIGEN_DEVICE_FUNC diff --git a/test/unit/math/mix/fun/multiply1_test.cpp b/test/unit/math/mix/fun/multiply1_test.cpp index 453db61e70c..d9128ed9f08 100644 --- a/test/unit/math/mix/fun/multiply1_test.cpp +++ b/test/unit/math/mix/fun/multiply1_test.cpp @@ -78,11 +78,11 @@ TEST(mathMixMatFun, multiply) { stan::test::expect_ad(tols, f, a, rv); stan::test::expect_ad(tols, f, rv, a); stan::test::expect_ad(tols, f, rv, v); - //stan::test::expect_ad(tols, f, v, rv); + stan::test::expect_ad(tols, f, v, rv); stan::test::expect_ad(tols, f, a, m); stan::test::expect_ad(tols, f, m, a); - //stan::test::expect_ad(tols, f, m, v); - //stan::test::expect_ad(tols, f, rv, m); + stan::test::expect_ad(tols, f, m, v); + stan::test::expect_ad(tols, f, rv, m); stan::test::expect_ad(tols, f, m, m); stan::test::expect_ad_matvar(tols, f, a, v); @@ -90,11 +90,11 @@ TEST(mathMixMatFun, multiply) { stan::test::expect_ad_matvar(tols, f, a, rv); stan::test::expect_ad_matvar(tols, f, rv, a); stan::test::expect_ad_matvar(tols, f, rv, v); - // stan::test::expect_ad_matvar(tols, f, v, rv); + stan::test::expect_ad_matvar(tols, f, v, rv); stan::test::expect_ad_matvar(tols, f, a, m); stan::test::expect_ad_matvar(tols, f, m, a); - //stan::test::expect_ad_matvar(tols, f, m, v); - //stan::test::expect_ad_matvar(tols, f, rv, m); + stan::test::expect_ad_matvar(tols, f, m, v); + stan::test::expect_ad_matvar(tols, f, rv, m); stan::test::expect_ad_matvar(tols, f, m, m); Eigen::RowVectorXd d1(3); @@ -102,10 +102,10 @@ TEST(mathMixMatFun, multiply) { Eigen::VectorXd d2(3); d2 << 4, -2, -1; stan::test::expect_ad(tols, f, d1, d2); - //stan::test::expect_ad(tols, f, d2, d1); + stan::test::expect_ad(tols, f, d2, d1); stan::test::expect_ad_matvar(tols, f, d1, d2); - //stan::test::expect_ad_matvar(tols, f, d2, d1); + stan::test::expect_ad_matvar(tols, f, d2, d1); Eigen::MatrixXd u(3, 2); u << 1, 3, -5, 4, -2, -1; @@ -116,13 +116,13 @@ TEST(mathMixMatFun, multiply) { rvv << -2, 4, 1; stan::test::expect_ad(tols, f, u, u_tr); stan::test::expect_ad(tols, f, u_tr, u); - //stan::test::expect_ad(tols, f, u, vv); - //stan::test::expect_ad(tols, f, rvv, u); + stan::test::expect_ad(tols, f, u, vv); + stan::test::expect_ad(tols, f, rvv, u); stan::test::expect_ad_matvar(tols, f, u, u_tr); stan::test::expect_ad_matvar(tols, f, u_tr, u); - //stan::test::expect_ad_matvar(tols, f, u, vv); - //stan::test::expect_ad_matvar(tols, f, rvv, u); + stan::test::expect_ad_matvar(tols, f, u, vv); + stan::test::expect_ad_matvar(tols, f, rvv, u); // exception cases // can't compile mismatched dimensions, so no tests diff --git a/test/unit/math/rev/eigen_plugins_test.cpp b/test/unit/math/rev/eigen_plugins_test.cpp index b7c67917fd7..0dc6bbf5ae1 100644 --- a/test/unit/math/rev/eigen_plugins_test.cpp +++ b/test/unit/math/rev/eigen_plugins_test.cpp @@ -18,6 +18,7 @@ TEST(AgradRevMatrixAddons, var_matrix) { EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val()); EXPECT_MATRIX_FLOAT_EQ(vals.val(), mat_in.val()); + EXPECT_MATRIX_FLOAT_EQ(derivs * derivs, mat_in.adj() * mat_in.adj()); EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().array().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.adj()); From dc79dc957054857753a172429d5eaead4dfe295d Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 17 Oct 2021 22:52:04 +0800 Subject: [PATCH 07/27] Fix stride detection --- stan/math/prim/plugins/d_view.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index 95c70c9c31e..4de9f08e938 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -74,7 +74,9 @@ struct scalar_d_ref_op { }; template -struct d_stride {}; +struct d_stride { + static constexpr int stride = 1; +}; template struct d_stride::value>> { From ccb90f4d090dd4ddfcd0db6162c58506107e33ca Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 17 Oct 2021 23:45:47 +0800 Subject: [PATCH 08/27] Stride defaults --- lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h index 953852583e6..2ad3ffec9a9 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h @@ -30,10 +30,12 @@ struct traits > // "error: no integral type can represent all of the enumerator values InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic ? int(Dynamic) - : int(MatrixTypeInnerStride) * InnerStride, + : int(MatrixTypeInnerStride) + * (InnerStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : InnerStride, OuterStrideAtCompileTime = outer_stride_at_compile_time::ret == Dynamic ? int(Dynamic) - : outer_stride_at_compile_time::ret * OuterStride + : outer_stride_at_compile_time::ret + * (OuterStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : OuterStride }; }; } @@ -114,12 +116,14 @@ class CwiseUnaryViewImpl EIGEN_DEVICE_FUNC inline Index innerStride() const { - return derived().nestedExpression().innerStride() * InnerStride; + return derived().nestedExpression().innerStride() + * (InnerStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : InnerStride; } EIGEN_DEVICE_FUNC inline Index outerStride() const { - return derived().nestedExpression().outerStride() * OuterStride; + return derived().nestedExpression().outerStride() + * (OuterStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : OuterStride; } protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) From 82015b5ba9dbc8ac6e3ab20a3fbadd36a6115bc9 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 00:23:03 +0800 Subject: [PATCH 09/27] Test stride calcs --- stan/math/prim/plugins/adj_view.h | 20 ++++++++++++-------- stan/math/prim/plugins/d_view.h | 15 +++++++++------ stan/math/prim/plugins/val_view.h | 26 ++++++++++++++++---------- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 8b69f785f41..098e1fc1d8a 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -95,26 +95,30 @@ struct scalar_adj_ref_op { }; template -struct adj_stride { - static constexpr int stride = 1; +struct adj_stride { }; + +template +struct adj_stride::value + && !is_var::value>> { + static constexpr int stride = -1; }; template -struct adj_stride::value>> { - using vari_t = std::remove_pointer_t; +struct adj_stride::value>> { + using vari_t = std::remove_pointer_t; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; template -struct adj_stride::value>> { - using vari_t = std::remove_pointer_t().vi_)>; +struct adj_stride::value>> { + using vari_t = std::remove_pointer_t().vi_)>; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; typedef CwiseUnaryOp, const Derived> AdjReturnType; typedef CwiseUnaryView, Derived, - adj_stride::stride, - adj_stride::stride> NonConstAdjReturnType; + adj_stride::stride, + adj_stride::stride> NonConstAdjReturnType; EIGEN_DEVICE_FUNC inline const AdjReturnType diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index 4de9f08e938..199f4026998 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -74,20 +74,23 @@ struct scalar_d_ref_op { }; template -struct d_stride { - static constexpr int stride = 1; +struct d_stride { }; + +template +struct d_stride::value>> { + static constexpr int stride = -1; }; template -struct d_stride::value>> { - using fvar_t = std::remove_pointer_t; +struct d_stride::value>> { + using fvar_t = std::remove_pointer_t; static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); }; typedef CwiseUnaryOp, const Derived> dReturnType; typedef CwiseUnaryView, Derived, - d_stride::stride, - d_stride::stride> NonConstdReturnType; + d_stride::stride, + d_stride::stride> NonConstdReturnType; EIGEN_DEVICE_FUNC inline const dReturnType diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index 0cf893957dc..57f019c5743 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -119,25 +119,31 @@ struct scalar_val_ref_op { }; template -struct val_stride { - static constexpr int stride = 1; +struct val_stride { }; + + +template +struct val_stride::value + && !is_var::value + && !is_fvar::value>> { + static constexpr int stride = -1; }; template -struct val_stride::value>> { - using vari_t = std::remove_pointer_t; +struct val_stride::value>> { + using vari_t = std::remove_pointer_t; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; template -struct val_stride::value>> { - using vari_t = std::remove_pointer_t().vi_)>; +struct val_stride::value>> { + using vari_t = std::remove_pointer_t().vi_)>; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; template -struct val_stride::value>> { - using fvar_t = typename T::Scalar; +struct val_stride::value>> { + using fvar_t = T; static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); }; @@ -147,8 +153,8 @@ typedef CwiseUnaryOp, const Derived> valReturnType; typedef std::conditional_t::value || is_vari::value, const valReturnType, CwiseUnaryView, Derived, - val_stride::stride, - val_stride::stride>> + val_stride::stride, + val_stride::stride>> NonConstvalReturnType; EIGEN_DEVICE_FUNC From 264ef56b4097aec5ab1a13db6bde044310f0a9e2 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 00:50:44 +0800 Subject: [PATCH 10/27] Remove coefreff --- stan/math/prim/eigen_plugins.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index 0767eeb7cd8..86b47577ee4 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -5,14 +5,5 @@ #include "plugins/d_view.h" #include "plugins/vi_view.h" - - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -const Scalar& coeffRef(Index row, Index col) const { - eigen_internal_assert(row >= 0 && row < rows() - && col >= 0 && col < cols()); - return internal::evaluator(derived()).coeffRef(row, col); -} - #define EIGEN_STAN_MATRIXBASE_PLUGIN #define EIGEN_STAN_ARRAYBASE_PLUGIN From 4b374485652aef45d6f137facd9e1efe4e7a2b5a Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 00:53:24 +0800 Subject: [PATCH 11/27] Revert "Remove coefreff" This reverts commit 264ef56b4097aec5ab1a13db6bde044310f0a9e2. --- stan/math/prim/eigen_plugins.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index 86b47577ee4..0767eeb7cd8 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -5,5 +5,14 @@ #include "plugins/d_view.h" #include "plugins/vi_view.h" + + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +const Scalar& coeffRef(Index row, Index col) const { + eigen_internal_assert(row >= 0 && row < rows() + && col >= 0 && col < cols()); + return internal::evaluator(derived()).coeffRef(row, col); +} + #define EIGEN_STAN_MATRIXBASE_PLUGIN #define EIGEN_STAN_ARRAYBASE_PLUGIN From 5ea8563c71000cfc0dae222a8760130c0abd8111 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 09:23:19 +0800 Subject: [PATCH 12/27] Fix stride setting --- lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h | 8 ++++---- test/unit/math/fwd/eigen_plugins_test.cpp | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h index 2ad3ffec9a9..e230651207a 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h @@ -31,11 +31,11 @@ struct traits > InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic ? int(Dynamic) : int(MatrixTypeInnerStride) - * (InnerStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : InnerStride, + * ((InnerStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : InnerStride), OuterStrideAtCompileTime = outer_stride_at_compile_time::ret == Dynamic ? int(Dynamic) : outer_stride_at_compile_time::ret - * (OuterStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : OuterStride + * ((OuterStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : OuterStride) }; }; } @@ -117,13 +117,13 @@ class CwiseUnaryViewImpl EIGEN_DEVICE_FUNC inline Index innerStride() const { return derived().nestedExpression().innerStride() - * (InnerStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : InnerStride; + * ((InnerStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : InnerStride); } EIGEN_DEVICE_FUNC inline Index outerStride() const { return derived().nestedExpression().outerStride() - * (OuterStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : OuterStride; + * ((OuterStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : OuterStride); } protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) diff --git a/test/unit/math/fwd/eigen_plugins_test.cpp b/test/unit/math/fwd/eigen_plugins_test.cpp index 1e5ef8b0a76..e786a5cb6f8 100644 --- a/test/unit/math/fwd/eigen_plugins_test.cpp +++ b/test/unit/math/fwd/eigen_plugins_test.cpp @@ -19,6 +19,7 @@ TEST(AgradFwdMatrixAddons, fvar_double_matrix) { } EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val()); + EXPECT_MATRIX_FLOAT_EQ(derivs*derivs, mat_in.d() * mat_in.d()); EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().array().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.d()); @@ -30,7 +31,6 @@ TEST(AgradFwdMatrixAddons, fvar_double_matrix) { EXPECT_EQ(mat_in.d().rows(), derivs.rows()); EXPECT_EQ(mat_in.d().cols(), derivs.cols()); } - TEST(AgradFwdMatrixAddons, fvarfvar_double_matrix) { using Eigen::MatrixXd; using stan::math::matrix_ffd; @@ -48,6 +48,7 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_matrix) { } EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val().val()); + EXPECT_MATRIX_FLOAT_EQ(vals*vals, mat_in.val().val() * mat_in.val().val()); EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().val().array().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.d().val()); @@ -60,6 +61,7 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_matrix) { EXPECT_EQ(mat_in.d().cols(), derivs.cols()); } +/* TEST(AgradFwdMatrixAddons, fvar_double_vector) { using Eigen::VectorXd; using stan::math::vector_fd; @@ -169,3 +171,4 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_rowvector) { EXPECT_EQ(row_vec_in.d().rows(), derivs.rows()); EXPECT_EQ(row_vec_in.d().cols(), derivs.cols()); } +*/ \ No newline at end of file From 269b7f690c5006594006248ca9715232682be9fd Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 09:37:10 +0800 Subject: [PATCH 13/27] cpplint --- test/unit/math/fwd/eigen_plugins_test.cpp | 2 -- test/unit/math/mix/fun/multiply1_test.cpp | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/test/unit/math/fwd/eigen_plugins_test.cpp b/test/unit/math/fwd/eigen_plugins_test.cpp index e786a5cb6f8..2c06d0318d0 100644 --- a/test/unit/math/fwd/eigen_plugins_test.cpp +++ b/test/unit/math/fwd/eigen_plugins_test.cpp @@ -61,7 +61,6 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_matrix) { EXPECT_EQ(mat_in.d().cols(), derivs.cols()); } -/* TEST(AgradFwdMatrixAddons, fvar_double_vector) { using Eigen::VectorXd; using stan::math::vector_fd; @@ -171,4 +170,3 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_rowvector) { EXPECT_EQ(row_vec_in.d().rows(), derivs.rows()); EXPECT_EQ(row_vec_in.d().cols(), derivs.cols()); } -*/ \ No newline at end of file diff --git a/test/unit/math/mix/fun/multiply1_test.cpp b/test/unit/math/mix/fun/multiply1_test.cpp index d9128ed9f08..46cc5988319 100644 --- a/test/unit/math/mix/fun/multiply1_test.cpp +++ b/test/unit/math/mix/fun/multiply1_test.cpp @@ -73,7 +73,7 @@ TEST(mathMixMatFun, multiply) { rv << 100, -3; Eigen::MatrixXd m(2, 2); m << 100, 0, -3, 4; - stan::test::expect_ad(tols, f, a, v); + stan::test::expect_ad(tols, f, a, v); stan::test::expect_ad(tols, f, v, a); stan::test::expect_ad(tols, f, a, rv); stan::test::expect_ad(tols, f, rv, a); @@ -84,7 +84,7 @@ TEST(mathMixMatFun, multiply) { stan::test::expect_ad(tols, f, m, v); stan::test::expect_ad(tols, f, rv, m); stan::test::expect_ad(tols, f, m, m); - + stan::test::expect_ad_matvar(tols, f, a, v); stan::test::expect_ad_matvar(tols, f, v, a); stan::test::expect_ad_matvar(tols, f, a, rv); @@ -118,7 +118,7 @@ TEST(mathMixMatFun, multiply) { stan::test::expect_ad(tols, f, u_tr, u); stan::test::expect_ad(tols, f, u, vv); stan::test::expect_ad(tols, f, rvv, u); - + stan::test::expect_ad_matvar(tols, f, u, u_tr); stan::test::expect_ad_matvar(tols, f, u_tr, u); stan::test::expect_ad_matvar(tols, f, u, vv); From 90edbaf0cee4c7abb92544254d5d1387fc5425ac Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 18 Oct 2021 01:43:55 +0000 Subject: [PATCH 14/27] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/rev/fun/generalized_inverse.hpp | 6 +- stan/math/rev/fun/svd_U.hpp | 26 +++--- stan/math/rev/fun/tcrossprod.hpp | 3 +- .../rev/fun/trace_gen_inv_quad_form_ldlt.hpp | 82 +++++++++---------- stan/math/rev/fun/trace_gen_quad_form.hpp | 34 ++++---- test/unit/math/fwd/eigen_plugins_test.cpp | 4 +- 6 files changed, 76 insertions(+), 79 deletions(-) diff --git a/stan/math/rev/fun/generalized_inverse.hpp b/stan/math/rev/fun/generalized_inverse.hpp index 283ba338e0f..0e8a5b87aa8 100644 --- a/stan/math/rev/fun/generalized_inverse.hpp +++ b/stan/math/rev/fun/generalized_inverse.hpp @@ -23,14 +23,12 @@ template inline auto generalized_inverse_lambda(T1& G_arena, T2& inv_G) { return [G_arena, inv_G]() mutable { G_arena.adj() - += -(inv_G.val().transpose() * inv_G.adj() - * inv_G.val().transpose()) + += -(inv_G.val().transpose() * inv_G.adj() * inv_G.val().transpose()) + (-G_arena.val() * inv_G.val() + Eigen::MatrixXd::Identity(G_arena.rows(), inv_G.cols())) * inv_G.adj().transpose() * inv_G.val() * inv_G.val().transpose() - + inv_G.val().transpose() * inv_G.val() - * inv_G.adj().transpose() + + inv_G.val().transpose() * inv_G.val() * inv_G.adj().transpose() * (-inv_G.val() * G_arena.val() + Eigen::MatrixXd::Identity(inv_G.rows(), G_arena.cols())); }; diff --git a/stan/math/rev/fun/svd_U.hpp b/stan/math/rev/fun/svd_U.hpp index 80654e32c29..a0d18264c97 100644 --- a/stan/math/rev/fun/svd_U.hpp +++ b/stan/math/rev/fun/svd_U.hpp @@ -50,19 +50,19 @@ inline auto svd_U(const EigMat& m) { arena_t arena_U = svd.matrixU(); auto arena_V = to_arena(svd.matrixV()); - reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp, - M]() mutable { - Eigen::MatrixXd UUadjT = arena_U.val().transpose() * arena_U.adj(); - arena_m.adj() - += .5 * arena_U.val() - * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array()) - .matrix() - * arena_V.transpose() - + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows()) - - arena_U.val() * arena_U.val().transpose()) - * arena_U.adj() * arena_D.asDiagonal().inverse() - * arena_V.transpose(); - }); + reverse_pass_callback( + [arena_m, arena_U, arena_D, arena_V, arena_Fp, M]() mutable { + Eigen::MatrixXd UUadjT = arena_U.val().transpose() * arena_U.adj(); + arena_m.adj() + += .5 * arena_U.val() + * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array()) + .matrix() + * arena_V.transpose() + + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows()) + - arena_U.val() * arena_U.val().transpose()) + * arena_U.adj() * arena_D.asDiagonal().inverse() + * arena_V.transpose(); + }); return ret_type(arena_U); } diff --git a/stan/math/rev/fun/tcrossprod.hpp b/stan/math/rev/fun/tcrossprod.hpp index f683186b6c6..f41de1f5a59 100644 --- a/stan/math/rev/fun/tcrossprod.hpp +++ b/stan/math/rev/fun/tcrossprod.hpp @@ -29,8 +29,7 @@ inline auto tcrossprod(const T& M) { if (likely(M.size() > 0)) { reverse_pass_callback([res, arena_M]() mutable { - arena_M.adj() - += (res.adj() + res.adj().transpose()) * arena_M.val(); + arena_M.adj() += (res.adj() + res.adj().transpose()) * arena_M.val(); }); } diff --git a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp index 01df7e0a222..d0d6f3a6cc7 100644 --- a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp +++ b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp @@ -50,16 +50,16 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, var res = (arena_D.val() * BTAsolveB).trace(); - reverse_pass_callback( - [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { - double C_adj = res.adj(); + reverse_pass_callback([arena_A, BTAsolveB, AsolveB, arena_B, arena_D, + res]() mutable { + double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() - * AsolveB.transpose(); - arena_B.adj() += C_adj * AsolveB - * (arena_D.val() + arena_D.val().transpose()); - arena_D.adj() += C_adj * BTAsolveB; - }); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); + arena_B.adj() + += C_adj * AsolveB * (arena_D.val() + arena_D.val().transpose()); + arena_D.adj() += C_adj * BTAsolveB; + }); return res; } else if (!is_constant::value && !is_constant::value @@ -90,14 +90,14 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, var res = (arena_D.val() * BTAsolveB).trace(); - reverse_pass_callback( - [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable { - double C_adj = res.adj(); + reverse_pass_callback([arena_A, BTAsolveB, AsolveB, arena_D, + res]() mutable { + double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() - * AsolveB.transpose(); - arena_D.adj() += C_adj * BTAsolveB; - }); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); + arena_D.adj() += C_adj * BTAsolveB; + }); return res; } else if (!is_constant::value && is_constant::value @@ -112,8 +112,8 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().transpose() - * AsolveB.transpose(); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().transpose() * AsolveB.transpose(); }); return res; @@ -130,8 +130,8 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); - arena_B.adj() += C_adj * AsolveB - * (arena_D.val() + arena_D.val().transpose()); + arena_B.adj() + += C_adj * AsolveB * (arena_D.val() + arena_D.val().transpose()); arena_D.adj() += C_adj * BTAsolveB; }); @@ -206,15 +206,15 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, var res = (arena_D.val().asDiagonal() * BTAsolveB).trace(); - reverse_pass_callback( - [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable { - double C_adj = res.adj(); + reverse_pass_callback([arena_A, BTAsolveB, AsolveB, arena_B, arena_D, + res]() mutable { + double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() - * AsolveB.transpose(); - arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val().asDiagonal(); - arena_D.adj() += C_adj * BTAsolveB.diagonal(); - }); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); + arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val().asDiagonal(); + arena_D.adj() += C_adj * BTAsolveB.diagonal(); + }); return res; } else if (!is_constant::value && !is_constant::value @@ -224,8 +224,8 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB) - .trace(); + var res + = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable { double C_adj = res.adj(); @@ -246,14 +246,14 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, var res = (arena_D.val().asDiagonal() * BTAsolveB).trace(); - reverse_pass_callback( - [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable { - double C_adj = res.adj(); + reverse_pass_callback([arena_A, BTAsolveB, AsolveB, arena_D, + res]() mutable { + double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() - * AsolveB.transpose(); - arena_D.adj() += C_adj * BTAsolveB.diagonal(); - }); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); + arena_D.adj() += C_adj * BTAsolveB.diagonal(); + }); return res; } else if (!is_constant::value && is_constant::value @@ -269,8 +269,8 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable { double C_adj = res.adj(); - arena_A.adj() -= C_adj * AsolveB * arena_D.val().asDiagonal() - * AsolveB.transpose(); + arena_A.adj() + -= C_adj * AsolveB * arena_D.val().asDiagonal() * AsolveB.transpose(); }); return res; @@ -298,8 +298,8 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor& A, arena_t> arena_D = value_of(D); auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); - var res = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB) - .trace(); + var res + = (arena_D.asDiagonal() * arena_B.val().transpose() * AsolveB).trace(); reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable { arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal(); diff --git a/stan/math/rev/fun/trace_gen_quad_form.hpp b/stan/math/rev/fun/trace_gen_quad_form.hpp index 0b0751fd6c6..dc72685aebf 100644 --- a/stan/math/rev/fun/trace_gen_quad_form.hpp +++ b/stan/math/rev/fun/trace_gen_quad_form.hpp @@ -178,15 +178,15 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { var res = (arena_BDT.transpose() * arena_AB).trace(); - reverse_pass_callback([arena_A, arena_B, arena_D, arena_BDT, arena_AB, - res]() mutable { - double C_adj = res.adj(); - - arena_A.adj() += C_adj * arena_BDT * arena_B.val().transpose(); - arena_B.adj() - += C_adj - * (arena_AB * arena_D + arena_A.val().transpose() * arena_BDT); - }); + reverse_pass_callback( + [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]() mutable { + double C_adj = res.adj(); + + arena_A.adj() += C_adj * arena_BDT * arena_B.val().transpose(); + arena_B.adj() + += C_adj + * (arena_AB * arena_D + arena_A.val().transpose() * arena_BDT); + }); return res; } else if (!is_constant::value && is_constant::value @@ -235,16 +235,16 @@ inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) { var res = (arena_BDT.transpose() * arena_AB).trace(); - reverse_pass_callback([arena_A, arena_B, arena_D, arena_AB, arena_BDT, - res]() mutable { - double C_adj = res.adj(); + reverse_pass_callback( + [arena_A, arena_B, arena_D, arena_AB, arena_BDT, res]() mutable { + double C_adj = res.adj(); - arena_B.adj() - += C_adj - * (arena_AB * arena_D.val() + arena_A.transpose() * arena_BDT); + arena_B.adj() + += C_adj + * (arena_AB * arena_D.val() + arena_A.transpose() * arena_BDT); - arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val()); - }); + arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val()); + }); return res; } else if (is_constant::value && !is_constant::value diff --git a/test/unit/math/fwd/eigen_plugins_test.cpp b/test/unit/math/fwd/eigen_plugins_test.cpp index 2c06d0318d0..53394d220e9 100644 --- a/test/unit/math/fwd/eigen_plugins_test.cpp +++ b/test/unit/math/fwd/eigen_plugins_test.cpp @@ -19,7 +19,7 @@ TEST(AgradFwdMatrixAddons, fvar_double_matrix) { } EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val()); - EXPECT_MATRIX_FLOAT_EQ(derivs*derivs, mat_in.d() * mat_in.d()); + EXPECT_MATRIX_FLOAT_EQ(derivs * derivs, mat_in.d() * mat_in.d()); EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().array().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.d()); @@ -48,7 +48,7 @@ TEST(AgradFwdMatrixAddons, fvarfvar_double_matrix) { } EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val().val()); - EXPECT_MATRIX_FLOAT_EQ(vals*vals, mat_in.val().val() * mat_in.val().val()); + EXPECT_MATRIX_FLOAT_EQ(vals * vals, mat_in.val().val() * mat_in.val().val()); EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().val().array().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.d().val()); From 341e0a8fa3078d738f501c8eadc41d52f5ebf234 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 09:50:03 +0800 Subject: [PATCH 15/27] opencl include --- stan/math/opencl/rev/softmax.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/stan/math/opencl/rev/softmax.hpp b/stan/math/opencl/rev/softmax.hpp index f61ba86963d..e031050f1ba 100644 --- a/stan/math/opencl/rev/softmax.hpp +++ b/stan/math/opencl/rev/softmax.hpp @@ -3,6 +3,7 @@ #ifdef STAN_OPENCL #include +#include #include #include #include From 92123793e0b9374888a30661db2ae866c67849a4 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 18 Oct 2021 11:07:18 +0800 Subject: [PATCH 16/27] Pointer t and simplify strides --- stan/math/prim/eigen_plugins.h | 3 --- stan/math/prim/plugins/adj_view.h | 2 +- stan/math/prim/plugins/d_view.h | 18 +---------------- stan/math/prim/plugins/val_view.h | 33 +------------------------------ 4 files changed, 3 insertions(+), 53 deletions(-) diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index 0767eeb7cd8..be57bc59981 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -1,12 +1,9 @@ - #include "plugins/typedefs.h" #include "plugins/adj_view.h" #include "plugins/val_view.h" #include "plugins/d_view.h" #include "plugins/vi_view.h" - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeffRef(Index row, Index col) const { eigen_internal_assert(row >= 0 && row < rows() diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 098e1fc1d8a..55fe6162eed 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -111,7 +111,7 @@ struct adj_stride::value>> { template struct adj_stride::value>> { - using vari_t = std::remove_pointer_t().vi_)>; + using vari_t = std::remove_pointer_t>>().vi_)>; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index 199f4026998..dc6440b0666 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -73,24 +73,8 @@ struct scalar_d_ref_op { } }; -template -struct d_stride { }; - -template -struct d_stride::value>> { - static constexpr int stride = -1; -}; - -template -struct d_stride::value>> { - using fvar_t = std::remove_pointer_t; - static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); -}; - typedef CwiseUnaryOp, const Derived> dReturnType; -typedef CwiseUnaryView, Derived, - d_stride::stride, - d_stride::stride> NonConstdReturnType; +typedef CwiseUnaryView, Derived> NonConstdReturnType; EIGEN_DEVICE_FUNC inline const dReturnType diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index 57f019c5743..40de148a6ed 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -118,43 +118,12 @@ struct scalar_val_ref_op { } }; -template -struct val_stride { }; - - -template -struct val_stride::value - && !is_var::value - && !is_fvar::value>> { - static constexpr int stride = -1; -}; - -template -struct val_stride::value>> { - using vari_t = std::remove_pointer_t; - static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); -}; - -template -struct val_stride::value>> { - using vari_t = std::remove_pointer_t().vi_)>; - static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); -}; - -template -struct val_stride::value>> { - using fvar_t = T; - static constexpr int stride = sizeof(fvar_t) / sizeof(typename fvar_t::Scalar); -}; - /** \internal the return type of imag() const */ typedef CwiseUnaryOp, const Derived> valReturnType; /** \internal the return type of imag() */ typedef std::conditional_t::value || is_vari::value, const valReturnType, - CwiseUnaryView, Derived, - val_stride::stride, - val_stride::stride>> + CwiseUnaryView, Derived>> NonConstvalReturnType; EIGEN_DEVICE_FUNC From d151feb86769983d097a7566589a8f204e94e183 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 19 Oct 2021 00:16:49 +0800 Subject: [PATCH 17/27] Simplify code --- stan/math/prim/eigen_plugins.h | 26 ++++- stan/math/prim/fun/Eigen.hpp | 17 +-- stan/math/prim/plugins/adj_view.h | 74 ++++--------- stan/math/prim/plugins/d_view.h | 66 ++++-------- stan/math/prim/plugins/typedefs.h | 25 ++++- stan/math/prim/plugins/val_view.h | 123 ++++++++-------------- stan/math/prim/plugins/vi_view.h | 71 ++++--------- test/unit/math/mix/eigen_plugins_test.cpp | 4 +- 8 files changed, 162 insertions(+), 244 deletions(-) diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index be57bc59981..e5ef80fe58d 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -1,3 +1,6 @@ +#ifndef STAN_MATH_EIGEN_PLUGINS_H +#define STAN_MATH_EIGEN_PLUGINS_H + #include "plugins/typedefs.h" #include "plugins/adj_view.h" #include "plugins/val_view.h" @@ -11,5 +14,24 @@ const Scalar& coeffRef(Index row, Index col) const { return internal::evaluator(derived()).coeffRef(row, col); } -#define EIGEN_STAN_MATRIXBASE_PLUGIN -#define EIGEN_STAN_ARRAYBASE_PLUGIN +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +const Scalar& coeffRef(Index index) const { + eigen_internal_assert(index >= 0 && index < size()); + return internal::evaluator(derived()).coeffRef(index); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar& coeffRef(Index row, Index col) { + eigen_internal_assert(row >= 0 && row < rows() + && col >= 0 && col < cols()); + return internal::evaluator(derived()).coeffRef(row, col); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar& coeffRef(Index index) { + eigen_internal_assert(index >= 0 && index < size()); + return internal::evaluator(derived()).coeffRef(index); +} + +#define EIGEN_STAN_DENSEBASE_PLUGIN +#endif diff --git a/stan/math/prim/fun/Eigen.hpp b/stan/math/prim/fun/Eigen.hpp index 6efb32cbdea..29d62b71905 100644 --- a/stan/math/prim/fun/Eigen.hpp +++ b/stan/math/prim/fun/Eigen.hpp @@ -1,22 +1,13 @@ #ifndef STAN_MATH_PRIM_FUN_EIGEN_HPP #define STAN_MATH_PRIM_FUN_EIGEN_HPP -#ifdef EIGEN_MATRIXBASE_PLUGIN -#ifndef EIGEN_STAN_MATRIXBASE_PLUGIN -#error "Stan uses Eigen's EIGEN_MATRIXBASE_PLUGIN macro. To use your own " +#ifdef EIGEN_DENSEBASE_PLUGIN +#ifndef EIGEN_STAN_DENSEBASE_PLUGIN +#error "Stan uses Eigen's EIGENDENSEBASE_PLUGIN macro. To use your own " "plugin add the eigen_plugin.h file to your plugin." #endif #else -#define EIGEN_MATRIXBASE_PLUGIN "stan/math/prim/eigen_plugins.h" -#endif - -#ifdef EIGEN_ARRAYBASE_PLUGIN -#ifndef EIGEN_STAN_ARRAYBASE_PLUGIN -#error "Stan uses Eigen's EIGEN_ARRAYBASE_PLUGIN macro. To use your own " - "plugin add the eigen_plugin.h file to your plugin." -#endif -#else -#define EIGEN_ARRAYBASE_PLUGIN "stan/math/prim/eigen_plugins.h" +#define EIGEN_DENSEBASE_PLUGIN "stan/math/prim/eigen_plugins.h" #endif #include diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 55fe6162eed..cc379bde92e 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -1,27 +1,11 @@ -template -EIGEN_DEVICE_FUNC -static inline const double& adj(const Scalar& x) { - return adj_impl>::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline const double& -adj_ref(const Scalar& x) { - return adj_ref_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline double& adj_ref(Scalar& x) { - return adj_ref_impl>::run(x); -} +#ifndef STAN_MATH_PRIM_PLUGINS_ADJ_VIEW_H +#define STAN_MATH_PRIM_PLUGINS_ADJ_VIEW_H template -struct adj_default_impl { }; +struct adj_impl { }; template -struct adj_default_impl::value>> { +struct adj_impl::value>> { EIGEN_DEVICE_FUNC static inline double& run(Scalar& x) { return x->adj_; @@ -33,7 +17,7 @@ struct adj_default_impl::value>> { }; template -struct adj_default_impl::value>> { +struct adj_impl::value>> { EIGEN_DEVICE_FUNC static inline double& run(Scalar& x) { return x.vi_->adj_; @@ -45,46 +29,26 @@ struct adj_default_impl::value>> { }; template -struct adj_impl : adj_default_impl {}; - -template -struct scalar_adj_op { - EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_op) - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE double operator() (const Scalar& a) const { return adj(a); } -}; - - -template -struct adj_ref_default_impl { }; +EIGEN_DEVICE_FUNC +static inline const double& adj(const Scalar& x) { + return adj_impl::run(x); +} template -struct adj_ref_default_impl::value>> { - EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { - return *reinterpret_cast(&(x->adj_)); - } - EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { - return *reinterpret_cast(&(x->adj_)); - } -}; +EIGEN_DEVICE_FUNC +static inline double& adj_ref(Scalar& x) { + return adj_impl::run(x); +} template -struct adj_ref_default_impl::value>> { - EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { - return *reinterpret_cast(&(x.vi_->adj_)); - } +struct scalar_adj_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_op) EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { - return *reinterpret_cast(&(x.vi_->adj_)); + EIGEN_STRONG_INLINE const double& operator() (const Scalar& a) const { + return adj(a); } }; -template -struct adj_ref_impl : adj_ref_default_impl {}; - template struct scalar_adj_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_ref_op) @@ -127,4 +91,6 @@ adj() const { return AdjReturnType(derived()); } EIGEN_DEVICE_FUNC inline NonConstAdjReturnType -adj() { return NonConstAdjReturnType(derived()); } \ No newline at end of file +adj() { return NonConstAdjReturnType(derived()); } + +#endif diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index dc6440b0666..a76b35e63b2 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -1,30 +1,11 @@ -template -using d_return_t = decltype(T::d_); - -template -EIGEN_DEVICE_FUNC -static inline const d_return_t& d(const Scalar& x) { - return d_impl>::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline const d_return_t& -d_ref(const Scalar& x) { - return d_ref_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline d_return_t& d_ref(Scalar& x) { - return d_ref_impl>::run(x); -} +#ifndef STAN_MATH_PRIM_PLUGINS_D_VIEW_H +#define STAN_MATH_PRIM_PLUGINS_D_VIEW_H template -struct d_default_impl { }; +struct d_impl { }; template -struct d_default_impl::value>> { +struct d_impl::value>> { EIGEN_DEVICE_FUNC static inline d_return_t& run(Scalar& x) { return x.d_; @@ -36,39 +17,30 @@ struct d_default_impl::value>> { }; template -struct d_impl : d_default_impl {}; +EIGEN_DEVICE_FUNC +static inline const d_return_t& d(const Scalar& x) { + return d_impl::run(x); +} +template +EIGEN_DEVICE_FUNC +static inline d_return_t& d_ref(Scalar& x) { + return d_impl::run(x); +} template struct scalar_d_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE d_return_t operator() (const Scalar& a) const { return d(a); } -}; - - -template -struct d_ref_default_impl { }; - -template -struct d_ref_default_impl::value>> { - EIGEN_DEVICE_FUNC - static inline d_return_t& run(Scalar& x) { - return *reinterpret_cast*>(&(x.d_)); - } - EIGEN_DEVICE_FUNC - static inline const d_return_t& run(const Scalar& x) { - return *reinterpret_cast*>(&(x.d_)); - } + EIGEN_STRONG_INLINE + const d_return_t& operator() (const Scalar& a) const { return d(a); } }; -template -struct d_ref_impl : d_ref_default_impl {}; - template struct scalar_d_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_ref_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE d_return_t& operator() (const Scalar& a) const { + EIGEN_STRONG_INLINE + d_return_t& operator() (const Scalar& a) const { return d_ref(*const_cast(&a)); } }; @@ -83,4 +55,6 @@ d() const { return dReturnType(derived()); } EIGEN_DEVICE_FUNC inline NonConstdReturnType -d() { return NonConstdReturnType(derived()); } \ No newline at end of file +d() { return NonConstdReturnType(derived()); } + +#endif diff --git a/stan/math/prim/plugins/typedefs.h b/stan/math/prim/plugins/typedefs.h index e6018c057aa..92f9c4fc884 100644 --- a/stan/math/prim/plugins/typedefs.h +++ b/stan/math/prim/plugins/typedefs.h @@ -1,3 +1,6 @@ +#ifndef STAN_MATH_PRIM_PLUGINS_TYPEDEFS_H +#define STAN_MATH_PRIM_PLUGINS_TYPEDEFS_H + template struct is_fvar : std::false_type { }; @@ -19,5 +22,25 @@ template struct is_vari>::adj_))> : std::true_type { }; +template +struct val_return { }; + +template +struct val_return::value>> { + using type = double; +}; +template +struct val_return::value>> { + using type = decltype(T::d_); +}; + template -using eigen_base_filter_t = typename internal::global_math_functions_filtering_base::type; \ No newline at end of file +using val_return_t = typename val_return>::type; + +template +using vi_return_t = decltype(T::vi_); + +template +using d_return_t = decltype(T::d_); + +#endif diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index 40de148a6ed..f0941b97e5c 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -1,112 +1,81 @@ -template -struct val_return { }; - -template -struct val_return::value>> { - using type = double; -}; -template -struct val_return::value>> { - using type = decltype(T::d_); -}; - -template -using val_return_t = typename val_return>::type; - -template -EIGEN_DEVICE_FUNC -static inline val_return_t val(const Scalar& x){ - return val_impl>::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline const val_return_t& val_ref(const Scalar& x) { - return val_ref_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline -val_return_t& -val_ref(Scalar& x) { - return val_ref_impl>::run(x); -} +#ifndef STAN_MATH_PRIM_PLUGINS_VAL_VIEW_H +#define STAN_MATH_PRIM_PLUGINS_VAL_VIEW_H template -struct val_default_impl { }; +struct val_impl { }; template -struct val_default_impl::value>> { +struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline double run(const Scalar& x) { + static inline const double& run(const Scalar& x) { + return x; + } + EIGEN_DEVICE_FUNC + static inline double& run(Scalar& x) { return x; } }; template -struct val_default_impl::value>> { +struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline double run(const Scalar& x) { + static inline const double& run(const Scalar& x) { return x->val_; } -}; - -template -struct val_default_impl::value>> { EIGEN_DEVICE_FUNC - static inline double run(const Scalar& x) { - return x.vi_->val_; + static inline double& run(Scalar& x) { + return x->val_; } }; template -struct val_default_impl::value>> { +struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline val_return_t run(const Scalar& x) { - return x.val_; + static inline const double& run(const Scalar& x) { + return x.vi_->val_; } -}; - -template struct val_impl : val_default_impl {}; - -template -struct scalar_val_op { - EIGEN_EMPTY_STRUCT_CTOR(scalar_val_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE val_return_t operator() (const Scalar& a) const { return val(a); } + static inline double& run(Scalar& x) { + return x.vi_->val_; + } }; -template -struct val_ref_default_impl { }; - template -struct val_ref_default_impl::value>> { +struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { - return x; + static inline const val_return_t& run(const Scalar& x) { + return x.val_; } EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { - return x; + static inline val_return_t& run(Scalar& x) { + return x.val_; } }; template -struct val_ref_default_impl::value>> { - EIGEN_DEVICE_FUNC - static inline val_return_t& run(Scalar& x) { - return *reinterpret_cast*>(&(x.val_)); - } +EIGEN_DEVICE_FUNC +static inline const val_return_t& val(const Scalar& x) { + return val_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline +val_return_t& +val_ref(Scalar& x) { + return val_impl::run(x); +} + +template +struct scalar_val_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_val_op) EIGEN_DEVICE_FUNC - static inline const val_return_t& run(const Scalar& x) { - return *reinterpret_cast*>(&(x.val_)); + EIGEN_STRONG_INLINE + const val_return_t& operator() (const Scalar& a) const { + return val(a); } }; -template -struct val_ref_impl : val_ref_default_impl {}; - template struct scalar_val_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_val_ref_op) @@ -118,9 +87,7 @@ struct scalar_val_ref_op { } }; -/** \internal the return type of imag() const */ typedef CwiseUnaryOp, const Derived> valReturnType; -/** \internal the return type of imag() */ typedef std::conditional_t::value || is_vari::value, const valReturnType, CwiseUnaryView, Derived>> @@ -133,4 +100,6 @@ val() const { return valReturnType(derived()); } EIGEN_DEVICE_FUNC inline NonConstvalReturnType -val() { return NonConstvalReturnType(derived()); } \ No newline at end of file +val() { return NonConstvalReturnType(derived()); } + +#endif diff --git a/stan/math/prim/plugins/vi_view.h b/stan/math/prim/plugins/vi_view.h index da1e8cfbfcd..50d67f20aa5 100644 --- a/stan/math/prim/plugins/vi_view.h +++ b/stan/math/prim/plugins/vi_view.h @@ -1,71 +1,30 @@ -template -using vi_return_t = decltype(T::vi_); - -template -EIGEN_DEVICE_FUNC -static inline const vi_return_t& vi(const Scalar& x) { - return vi_impl>::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline const vi_return_t& -vi_ref(const Scalar& x) { - return vi_ref_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline vi_return_t& vi_ref(Scalar& x) { - return vi_ref_impl>::run(x); -} +#ifndef STAN_MATH_PRIM_PLUGINS_VI_VIEW_H +#define STAN_MATH_PRIM_PLUGINS_VI_VIEW_H template -struct vi_default_impl { }; +struct vi_impl { }; template -struct vi_default_impl::value>> { +struct vi_impl::value>> { EIGEN_DEVICE_FUNC - static inline vi_return_t& run(Scalar& x) { + static inline const vi_return_t& run(const Scalar& x) { return x.vi_; } EIGEN_DEVICE_FUNC - static inline const vi_return_t& run(const Scalar& x) { + static inline vi_return_t& run(Scalar& x) { return x.vi_; } }; -template -struct vi_impl : vi_default_impl {}; - template struct scalar_vi_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE vi_return_t operator() (const Scalar& a) const { + EIGEN_STRONG_INLINE const vi_return_t& operator() (const Scalar& a) const { return vi(a); } }; - -template -struct vi_ref_default_impl { }; - -template -struct vi_ref_default_impl::value>> { - EIGEN_DEVICE_FUNC - static inline vi_return_t& run(Scalar& x) { - return *reinterpret_cast*>(&(x.vi_)); - } - EIGEN_DEVICE_FUNC - static inline const vi_return_t& run(const Scalar& x) { - return *reinterpret_cast*>(&(x.vi_)); - } -}; - -template -struct vi_ref_impl : vi_ref_default_impl {}; - template struct scalar_vi_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_ref_op) @@ -75,6 +34,18 @@ struct scalar_vi_ref_op { } }; +template +EIGEN_DEVICE_FUNC +static inline const vi_return_t& vi(const Scalar& x) { + return vi_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC +static inline vi_return_t& vi_ref(Scalar& x) { + return vi_impl::run(x); +} + typedef CwiseUnaryOp, const Derived> viReturnType; typedef CwiseUnaryView, Derived> NonConstviReturnType; @@ -85,4 +56,6 @@ vi() const { return viReturnType(derived()); } EIGEN_DEVICE_FUNC inline NonConstviReturnType -vi() { return NonConstviReturnType(derived()); } \ No newline at end of file +vi() { return NonConstviReturnType(derived()); } + +#endif diff --git a/test/unit/math/mix/eigen_plugins_test.cpp b/test/unit/math/mix/eigen_plugins_test.cpp index 3956f5cd06f..ae98f655365 100644 --- a/test/unit/math/mix/eigen_plugins_test.cpp +++ b/test/unit/math/mix/eigen_plugins_test.cpp @@ -18,10 +18,10 @@ TEST(AgradMixMatrixAddons, matrix_fv) { } EXPECT_MATRIX_FLOAT_EQ(vals, mat_in.val().val()); - EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.val().val().array().exp()); + EXPECT_MATRIX_FLOAT_EQ(vals.array().exp(), mat_in.array().val().val().exp()); EXPECT_MATRIX_FLOAT_EQ(derivs, mat_in.d().val()); - EXPECT_MATRIX_FLOAT_EQ(derivs.array().exp(), mat_in.d().val().array().exp()); + EXPECT_MATRIX_FLOAT_EQ(derivs.array().exp(), mat_in.array().d().val().exp()); EXPECT_EQ(mat_in.val().val().rows(), vals.rows()); EXPECT_EQ(mat_in.val().val().cols(), vals.cols()); From 205af1c800885446669120ed54ca8ece0c0cd862 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 19 Oct 2021 00:46:40 +0800 Subject: [PATCH 18/27] Robust return typedefs --- stan/math/prim/plugins/adj_view.h | 18 +++++++++--------- stan/math/prim/plugins/d_view.h | 12 ++++++------ stan/math/prim/plugins/typedefs.h | 22 +++++++++++++++------- stan/math/prim/plugins/val_view.h | 12 ++++++------ 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index cc379bde92e..755199e6874 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -7,11 +7,11 @@ struct adj_impl { }; template struct adj_impl::value>> { EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x->adj_; } EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x->adj_; } }; @@ -19,24 +19,24 @@ struct adj_impl::value>> { template struct adj_impl::value>> { EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x.vi_->adj_; } EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x.vi_->adj_; } }; template EIGEN_DEVICE_FUNC -static inline const double& adj(const Scalar& x) { +static inline const val_return_t& adj(const Scalar& x) { return adj_impl::run(x); } template EIGEN_DEVICE_FUNC -static inline double& adj_ref(Scalar& x) { +static inline val_return_t& adj_ref(Scalar& x) { return adj_impl::run(x); } @@ -44,7 +44,7 @@ template struct scalar_adj_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const double& operator() (const Scalar& a) const { + EIGEN_STRONG_INLINE const val_return_t& operator() (const Scalar& a) const { return adj(a); } }; @@ -53,7 +53,7 @@ template struct scalar_adj_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_ref_op) EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE double& operator() (const Scalar& a) const { + EIGEN_STRONG_INLINE val_return_t& operator() (const Scalar& a) const { return adj_ref(*const_cast(&a)); } }; @@ -75,7 +75,7 @@ struct adj_stride::value>> { template struct adj_stride::value>> { - using vari_t = std::remove_pointer_t>>().vi_)>; + using vari_t = typename std::decay_t>::vari_type; static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); }; diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index a76b35e63b2..62d7de64a95 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -7,23 +7,23 @@ struct d_impl { }; template struct d_impl::value>> { EIGEN_DEVICE_FUNC - static inline d_return_t& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x.d_; } EIGEN_DEVICE_FUNC - static inline const d_return_t& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x.d_; } }; template EIGEN_DEVICE_FUNC -static inline const d_return_t& d(const Scalar& x) { +static inline const val_return_t& d(const Scalar& x) { return d_impl::run(x); } template EIGEN_DEVICE_FUNC -static inline d_return_t& d_ref(Scalar& x) { +static inline val_return_t& d_ref(Scalar& x) { return d_impl::run(x); } @@ -32,7 +32,7 @@ struct scalar_d_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const d_return_t& operator() (const Scalar& a) const { return d(a); } + const val_return_t& operator() (const Scalar& a) const { return d(a); } }; template @@ -40,7 +40,7 @@ struct scalar_d_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_ref_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - d_return_t& operator() (const Scalar& a) const { + val_return_t& operator() (const Scalar& a) const { return d_ref(*const_cast(&a)); } }; diff --git a/stan/math/prim/plugins/typedefs.h b/stan/math/prim/plugins/typedefs.h index 92f9c4fc884..e150ab77c16 100644 --- a/stan/math/prim/plugins/typedefs.h +++ b/stan/math/prim/plugins/typedefs.h @@ -26,21 +26,29 @@ template struct val_return { }; template -struct val_return::value>> { - using type = double; +struct val_return::value>> { + using type = T; }; + template -struct val_return::value>> { - using type = decltype(T::d_); +struct val_return::value>> { + using type = typename T::value_type; }; template -using val_return_t = typename val_return>::type; +struct val_return::value>> { + using type = typename std::remove_pointer_t::value_type; +}; template -using vi_return_t = decltype(T::vi_); +struct val_return::value>> { + using type = typename T::Scalar; +}; + +template +using val_return_t = typename val_return>::type; template -using d_return_t = decltype(T::d_); +using vi_return_t = std::add_pointer_t; #endif diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index f0941b97e5c..7fe843834c5 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -7,11 +7,11 @@ struct val_impl { }; template struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x; } EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x; } }; @@ -19,11 +19,11 @@ struct val_impl::value>> { template struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x->val_; } EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x->val_; } }; @@ -31,11 +31,11 @@ struct val_impl::value>> { template struct val_impl::value>> { EIGEN_DEVICE_FUNC - static inline const double& run(const Scalar& x) { + static inline const val_return_t& run(const Scalar& x) { return x.vi_->val_; } EIGEN_DEVICE_FUNC - static inline double& run(Scalar& x) { + static inline val_return_t& run(Scalar& x) { return x.vi_->val_; } }; From e5d8295bab89625fea9bddd5c250381115f0b2af Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 19 Oct 2021 23:09:51 +0800 Subject: [PATCH 19/27] Update coeffRef for Stan downstream --- .../Eigen/src/Core/CwiseUnaryView.h | 9 +++++++ stan/math/prim/eigen_plugins.h | 26 ------------------- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h index e230651207a..eb6c943e3aa 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h @@ -113,6 +113,15 @@ class CwiseUnaryViewImpl EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); } EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); } + + using Base::coeffRef; + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const { + return const_cast(this)->coeffRef(row, col); + } + + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { + return const_cast(this)->coeffRef(index); + } EIGEN_DEVICE_FUNC inline Index innerStride() const { diff --git a/stan/math/prim/eigen_plugins.h b/stan/math/prim/eigen_plugins.h index e5ef80fe58d..a1e68ef88be 100644 --- a/stan/math/prim/eigen_plugins.h +++ b/stan/math/prim/eigen_plugins.h @@ -7,31 +7,5 @@ #include "plugins/d_view.h" #include "plugins/vi_view.h" -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -const Scalar& coeffRef(Index row, Index col) const { - eigen_internal_assert(row >= 0 && row < rows() - && col >= 0 && col < cols()); - return internal::evaluator(derived()).coeffRef(row, col); -} - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -const Scalar& coeffRef(Index index) const { - eigen_internal_assert(index >= 0 && index < size()); - return internal::evaluator(derived()).coeffRef(index); -} - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -Scalar& coeffRef(Index row, Index col) { - eigen_internal_assert(row >= 0 && row < rows() - && col >= 0 && col < cols()); - return internal::evaluator(derived()).coeffRef(row, col); -} - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -Scalar& coeffRef(Index index) { - eigen_internal_assert(index >= 0 && index < size()); - return internal::evaluator(derived()).coeffRef(index); -} - #define EIGEN_STAN_DENSEBASE_PLUGIN #endif From 5c712915adf95049ba6c888855eaa9669d2c855b Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 21 Oct 2021 12:37:31 +0800 Subject: [PATCH 20/27] Simplify implementations and add doc --- stan/math/prim/fun/Eigen.hpp | 7 +-- stan/math/prim/plugins/adj_view.h | 72 +++++++++++++++++-------------- stan/math/prim/plugins/d_view.h | 48 ++++++++++----------- stan/math/prim/plugins/typedefs.h | 51 +++++++++++++++------- stan/math/prim/plugins/val_view.h | 56 +++++++++++++----------- stan/math/prim/plugins/vi_view.h | 52 +++++++++++----------- 6 files changed, 156 insertions(+), 130 deletions(-) diff --git a/stan/math/prim/fun/Eigen.hpp b/stan/math/prim/fun/Eigen.hpp index 29d62b71905..3864b77ae5f 100644 --- a/stan/math/prim/fun/Eigen.hpp +++ b/stan/math/prim/fun/Eigen.hpp @@ -3,10 +3,12 @@ #ifdef EIGEN_DENSEBASE_PLUGIN #ifndef EIGEN_STAN_DENSEBASE_PLUGIN -#error "Stan uses Eigen's EIGENDENSEBASE_PLUGIN macro. To use your own " +#error "Stan uses Eigen's EIGEN_DENSEBASE_PLUGIN macro. To use your own " "plugin add the eigen_plugin.h file to your plugin." #endif #else +// By using the DenseBase plugin, we do not need to specify both +// MatrixBase and ArrayBase inclusions #define EIGEN_DENSEBASE_PLUGIN "stan/math/prim/eigen_plugins.h" #endif @@ -16,8 +18,7 @@ #include #include - namespace Eigen { - +namespace Eigen { /** * Traits specialization for Eigen binary operations for `int` * and `double` arguments. diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 755199e6874..7312f2a382c 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -1,9 +1,11 @@ #ifndef STAN_MATH_PRIM_PLUGINS_ADJ_VIEW_H #define STAN_MATH_PRIM_PLUGINS_ADJ_VIEW_H +// Forward declaration to allow specialisations template struct adj_impl { }; +// Struct for returning the adjoint from a vari* template struct adj_impl::value>> { EIGEN_DEVICE_FUNC @@ -16,6 +18,7 @@ struct adj_impl::value>> { } }; +// Struct for returning the adjoint from a var template struct adj_impl::value>> { EIGEN_DEVICE_FUNC @@ -28,69 +31,72 @@ struct adj_impl::value>> { } }; -template -EIGEN_DEVICE_FUNC -static inline const val_return_t& adj(const Scalar& x) { - return adj_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline val_return_t& adj_ref(Scalar& x) { - return adj_impl::run(x); -} - +// Struct implementing operator() to be called by CWiseUnaryOp to +// return the adjoint from a const var or vari* template struct scalar_adj_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const val_return_t& operator() (const Scalar& a) const { - return adj(a); + return adj_impl::run(a); } }; +// Struct implementing operator() to be called by CWiseUnaryView to +// return the adjoint from a non-const var or vari* template struct scalar_adj_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_adj_ref_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE val_return_t& operator() (const Scalar& a) const { - return adj_ref(*const_cast(&a)); + return adj_impl::run(*const_cast(&a)); } }; +/** + * Eigen's CWiseUnaryView deduces the Stride between values in memory as + * the ratio of the sizeof the Matrix scalar type and the returned scalar type. + * This means that, by default, the stride for var_value is deduced as + * sizeof(var) / sizeof(T). + * + * However, because var types are pointers-to-implementation variables, the stride + * actually needs to be be sizeof(vari) / sizeof(T). This struct returns the correct + * stride when var types are used, otherwise it returns the default argument (-1) + * which indicates that Eigen should calculate the stride as usual + */ template struct adj_stride { }; template -struct adj_stride::value - && !is_var::value>> { +struct adj_stride::value>> { static constexpr int stride = -1; }; -template -struct adj_stride::value>> { - using vari_t = std::remove_pointer_t; - static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); -}; - template struct adj_stride::value>> { using vari_t = typename std::decay_t>::vari_type; - static constexpr int stride = sizeof(vari_t) / sizeof(typename vari_t::value_type); + static constexpr int stride = sizeof(vari_t) + / sizeof(typename vari_t::value_type); }; -typedef CwiseUnaryOp, const Derived> AdjReturnType; -typedef CwiseUnaryView, Derived, - adj_stride::stride, - adj_stride::stride> NonConstAdjReturnType; - +/** + * Coefficient-wise function returning a view of the adjoints that cannot + * be modified + */ EIGEN_DEVICE_FUNC -inline const AdjReturnType -adj() const { return AdjReturnType(derived()); } - +inline const auto adj() const { + return CwiseUnaryOp, const Derived>(derived()); +} +/** + * Coefficient-wise function returning a view of the adjoints that can + * be modified. The stride is explicitly specified for var types. + */ EIGEN_DEVICE_FUNC -inline NonConstAdjReturnType -adj() { return NonConstAdjReturnType(derived()); } +inline auto adj() { + return CwiseUnaryView, Derived, + adj_stride::stride, + adj_stride::stride>(derived()); +} #endif diff --git a/stan/math/prim/plugins/d_view.h b/stan/math/prim/plugins/d_view.h index 62d7de64a95..98d9256e766 100644 --- a/stan/math/prim/plugins/d_view.h +++ b/stan/math/prim/plugins/d_view.h @@ -1,11 +1,10 @@ #ifndef STAN_MATH_PRIM_PLUGINS_D_VIEW_H #define STAN_MATH_PRIM_PLUGINS_D_VIEW_H -template -struct d_impl { }; +// Struct for returning the gradient from an fvar template -struct d_impl::value>> { +struct d_impl { EIGEN_DEVICE_FUNC static inline val_return_t& run(Scalar& x) { return x.d_; @@ -16,45 +15,46 @@ struct d_impl::value>> { } }; -template -EIGEN_DEVICE_FUNC -static inline const val_return_t& d(const Scalar& x) { - return d_impl::run(x); -} -template -EIGEN_DEVICE_FUNC -static inline val_return_t& d_ref(Scalar& x) { - return d_impl::run(x); -} - +// Struct implementing operator() to be called by CWiseUnaryOp to +// return the gradient from a const fvar template struct scalar_d_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const val_return_t& operator() (const Scalar& a) const { return d(a); } + const val_return_t& operator() (const Scalar& a) const { + return d_impl::run(a); + } }; +// Struct implementing operator() to be called by CWiseUnaryView to +// return the gradient from a non-const fvar template struct scalar_d_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_d_ref_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE val_return_t& operator() (const Scalar& a) const { - return d_ref(*const_cast(&a)); + return d_impl::run(*const_cast(&a)); } }; -typedef CwiseUnaryOp, const Derived> dReturnType; -typedef CwiseUnaryView, Derived> NonConstdReturnType; - +/** + * Coefficient-wise function returning a view of the gradients that cannot + * be modified + */ EIGEN_DEVICE_FUNC -inline const dReturnType -d() const { return dReturnType(derived()); } - +inline auto d() const { + return CwiseUnaryOp, const Derived>(derived()); +} +/** + * Coefficient-wise function returning a view of the gradients that can + * be modified. + */ EIGEN_DEVICE_FUNC -inline NonConstdReturnType -d() { return NonConstdReturnType(derived()); } +inline auto d() { + return CwiseUnaryView, Derived>(derived()); +} #endif diff --git a/stan/math/prim/plugins/typedefs.h b/stan/math/prim/plugins/typedefs.h index e150ab77c16..65ee8779607 100644 --- a/stan/math/prim/plugins/typedefs.h +++ b/stan/math/prim/plugins/typedefs.h @@ -1,27 +1,44 @@ #ifndef STAN_MATH_PRIM_PLUGINS_TYPEDEFS_H #define STAN_MATH_PRIM_PLUGINS_TYPEDEFS_H -template -struct is_fvar : std::false_type -{ }; -template -struct is_fvar : std::true_type -{ }; +/** + * Reimplements is_fvar without requiring external math headers + * + * decltype((void)(T::d_)) is a pre C++17 replacement for + * std::void_t + * + * TODO(Andrew): Replace with std::void_t after move to C++17 + */ +template +struct is_fvar : std::false_type { }; +template +struct is_fvar : std::true_type { }; -template -struct is_var : std::false_type -{ }; -template -struct is_var>::vi_))> : std::true_type -{ }; +/** + * Reimplements is_var without requiring external math headers + */ +template +struct is_var : std::false_type { }; +template +struct is_var::vi_))> : std::true_type { }; -template -struct is_vari : std::false_type -{ }; -template +/** + * Reimplements is_vari without requiring external math headers + */ +template +struct is_vari : std::false_type { }; +template struct is_vari>::adj_))> : std::true_type { }; +/** + * Struct for determining the appropriate return type for a given input type. + * The type mappings are: + * - arithmetic -> arithmetic + * - var_value -> T + * - vari_value -> T + * - fvar -> T + */ template struct val_return { }; @@ -48,6 +65,8 @@ struct val_return::value>> { template using val_return_t = typename val_return>::type; +// Typedef for determining the type within a vari_value and returning +// with a pointer type template using vi_return_t = std::add_pointer_t; diff --git a/stan/math/prim/plugins/val_view.h b/stan/math/prim/plugins/val_view.h index 7fe843834c5..fe52f6ac441 100644 --- a/stan/math/prim/plugins/val_view.h +++ b/stan/math/prim/plugins/val_view.h @@ -1,9 +1,11 @@ #ifndef STAN_MATH_PRIM_PLUGINS_VAL_VIEW_H #define STAN_MATH_PRIM_PLUGINS_VAL_VIEW_H +// Forward declaration to allow specialisations template struct val_impl { }; +// Struct for returning an arithmetic input argument unchanged template struct val_impl::value>> { EIGEN_DEVICE_FUNC @@ -16,6 +18,7 @@ struct val_impl::value>> { } }; +// Struct for returning the value from a vari* template struct val_impl::value>> { EIGEN_DEVICE_FUNC @@ -28,6 +31,8 @@ struct val_impl::value>> { } }; + +// Struct for returning the value from a var template struct val_impl::value>> { EIGEN_DEVICE_FUNC @@ -40,6 +45,7 @@ struct val_impl::value>> { } }; +// Struct for returning the value from an fvar template struct val_impl::value>> { EIGEN_DEVICE_FUNC @@ -52,54 +58,52 @@ struct val_impl::value>> { } }; +// Struct implementing operator() to be called by CWiseUnaryOp to +// return the value from a const var or vari* template -EIGEN_DEVICE_FUNC -static inline const val_return_t& val(const Scalar& x) { - return val_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC -static inline -val_return_t& -val_ref(Scalar& x) { - return val_impl::run(x); -} - -template struct scalar_val_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_val_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const val_return_t& operator() (const Scalar& a) const { - return val(a); + return val_impl::run(a); } }; -template +// Struct implementing operator() to be called by CWiseUnaryView to +// return the value from a non-const var or vari* +template struct scalar_val_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_val_ref_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE val_return_t& operator() (const Scalar& a) const { - return val_ref(*const_cast(&a)); + return val_impl::run(*const_cast(&a)); } }; typedef CwiseUnaryOp, const Derived> valReturnType; -typedef std::conditional_t::value || is_vari::value, - const valReturnType, - CwiseUnaryView, Derived>> -NonConstvalReturnType; +/** + * Coefficient-wise function returning a view of the values that cannot + * be modified + */ EIGEN_DEVICE_FUNC -inline const valReturnType -val() const { return valReturnType(derived()); } - +inline const auto val() const { return valReturnType(derived()); } +/** + * Coefficient-wise function returning a view of the adjoints that can + * be modified. + * + * The .val() method will always return a const type for var & vari inputs, + * so CWiseUnaryOp is called instead. + */ EIGEN_DEVICE_FUNC -inline NonConstvalReturnType -val() { return NonConstvalReturnType(derived()); } +inline auto val() { + return std::conditional_t::value || is_vari::value, + const valReturnType, + CwiseUnaryView,Derived>>(derived()); +} #endif diff --git a/stan/math/prim/plugins/vi_view.h b/stan/math/prim/plugins/vi_view.h index 50d67f20aa5..3f984b4a2e9 100644 --- a/stan/math/prim/plugins/vi_view.h +++ b/stan/math/prim/plugins/vi_view.h @@ -1,11 +1,9 @@ #ifndef STAN_MATH_PRIM_PLUGINS_VI_VIEW_H #define STAN_MATH_PRIM_PLUGINS_VI_VIEW_H -template -struct vi_impl { }; - +// Struct for returning the vari* from a var template -struct vi_impl::value>> { +struct vi_impl{ EIGEN_DEVICE_FUNC static inline const vi_return_t& run(const Scalar& x) { return x.vi_; @@ -16,46 +14,44 @@ struct vi_impl::value>> { } }; +// Struct implementing operator() to be called by CWiseUnaryOp to +// return the vari from a const var template struct scalar_vi_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_op) - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const vi_return_t& operator() (const Scalar& a) const { - return vi(a); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const vi_return_t& operator() (const Scalar& a) const { + return vi_impl::run(a); } }; +// Struct implementing operator() to be called by CWiseUnaryView to +// return the vari from a non-const var template struct scalar_vi_ref_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_vi_ref_op) - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE vi_return_t& operator() (const Scalar& a) const { - return vi_ref(*const_cast(&a)); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + vi_return_t& operator() (const Scalar& a) const { + return vi_impl::run(*const_cast(&a)); } }; -template +/** + * Coefficient-wise function returning a view of the values that cannot + * be modified + */ EIGEN_DEVICE_FUNC -static inline const vi_return_t& vi(const Scalar& x) { - return vi_impl::run(x); +inline const auto vi() const { + return CwiseUnaryOp, const Derived>(derived()); } -template +/** + * Coefficient-wise function returning a view of the adjoints that can + * be modified. + */ EIGEN_DEVICE_FUNC -static inline vi_return_t& vi_ref(Scalar& x) { - return vi_impl::run(x); +inline auto vi() { + return CwiseUnaryView, Derived> (derived()); } -typedef CwiseUnaryOp, const Derived> viReturnType; -typedef CwiseUnaryView, Derived> NonConstviReturnType; - -EIGEN_DEVICE_FUNC -inline const viReturnType -vi() const { return viReturnType(derived()); } - - -EIGEN_DEVICE_FUNC -inline NonConstviReturnType -vi() { return NonConstviReturnType(derived()); } - #endif From ad112fd0bed9987ababffb2ce4f464894306edb9 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 21 Oct 2021 04:38:39 +0000 Subject: [PATCH 21/27] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/fun/Eigen.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stan/math/prim/fun/Eigen.hpp b/stan/math/prim/fun/Eigen.hpp index 3864b77ae5f..4f4a6a50df9 100644 --- a/stan/math/prim/fun/Eigen.hpp +++ b/stan/math/prim/fun/Eigen.hpp @@ -18,7 +18,7 @@ #include #include -namespace Eigen { + namespace Eigen { /** * Traits specialization for Eigen binary operations for `int` * and `double` arguments. From bcc3374eab0727c2f0cfc442e9c2b0415c6c1e99 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 00:20:45 +0800 Subject: [PATCH 22/27] Backport Eigen changes --- stan/math/prim/fun/Eigen.hpp | 1 + stan/math/prim/plugins/Core.h | 542 ++++++ stan/math/prim/plugins/CoreEvaluators.h | 1688 ++++++++++++++++++ stan/math/prim/plugins/CwiseUnaryView.h | 142 ++ stan/math/prim/plugins/ForwardDeclarations.h | 298 ++++ 5 files changed, 2671 insertions(+) create mode 100644 stan/math/prim/plugins/Core.h create mode 100644 stan/math/prim/plugins/CoreEvaluators.h create mode 100644 stan/math/prim/plugins/CwiseUnaryView.h create mode 100644 stan/math/prim/plugins/ForwardDeclarations.h diff --git a/stan/math/prim/fun/Eigen.hpp b/stan/math/prim/fun/Eigen.hpp index 6efb32cbdea..8b55dfdfd59 100644 --- a/stan/math/prim/fun/Eigen.hpp +++ b/stan/math/prim/fun/Eigen.hpp @@ -19,6 +19,7 @@ #define EIGEN_ARRAYBASE_PLUGIN "stan/math/prim/eigen_plugins.h" #endif +#include #include #include #include diff --git a/stan/math/prim/plugins/Core.h b/stan/math/prim/plugins/Core.h new file mode 100644 index 00000000000..195ebfa4eff --- /dev/null +++ b/stan/math/prim/plugins/Core.h @@ -0,0 +1,542 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008 Gael Guennebaud +// Copyright (C) 2007-2011 Benoit Jacob +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CORE_H +#define EIGEN_CORE_H + +// first thing Eigen does: stop the compiler from committing suicide +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#if defined(__CUDACC__) && !defined(EIGEN_NO_CUDA) + #define EIGEN_CUDACC __CUDACC__ +#endif + +#if defined(__CUDA_ARCH__) && !defined(EIGEN_NO_CUDA) + #define EIGEN_CUDA_ARCH __CUDA_ARCH__ +#endif + +#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9) +#define EIGEN_CUDACC_VER ((__CUDACC_VER_MAJOR__ * 10000) + (__CUDACC_VER_MINOR__ * 100)) +#elif defined(__CUDACC_VER__) +#define EIGEN_CUDACC_VER __CUDACC_VER__ +#else +#define EIGEN_CUDACC_VER 0 +#endif + +// Handle NVCC/CUDA/SYCL +#if defined(__CUDACC__) || defined(__SYCL_DEVICE_ONLY__) + // Do not try asserts on CUDA and SYCL! + #ifndef EIGEN_NO_DEBUG + #define EIGEN_NO_DEBUG + #endif + + #ifdef EIGEN_INTERNAL_DEBUGGING + #undef EIGEN_INTERNAL_DEBUGGING + #endif + + #ifdef EIGEN_EXCEPTIONS + #undef EIGEN_EXCEPTIONS + #endif + + // All functions callable from CUDA code must be qualified with __device__ + #ifdef __CUDACC__ + // Do not try to vectorize on CUDA and SYCL! + #ifndef EIGEN_DONT_VECTORIZE + #define EIGEN_DONT_VECTORIZE + #endif + + #define EIGEN_DEVICE_FUNC __host__ __device__ + // We need cuda_runtime.h to ensure that that EIGEN_USING_STD_MATH macro + // works properly on the device side + #include + #else + #define EIGEN_DEVICE_FUNC + #endif + +#else + #define EIGEN_DEVICE_FUNC + +#endif + +// When compiling CUDA device code with NVCC, pull in math functions from the +// global namespace. In host mode, and when device doee with clang, use the +// std versions. +#if defined(__CUDA_ARCH__) && defined(__NVCC__) + #define EIGEN_USING_STD_MATH(FUNC) using ::FUNC; +#else + #define EIGEN_USING_STD_MATH(FUNC) using std::FUNC; +#endif + +#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(__CUDA_ARCH__) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) + #define EIGEN_EXCEPTIONS +#endif + +#ifdef EIGEN_EXCEPTIONS + #include +#endif + +// then include this file where all our macros are defined. It's really important to do it first because +// it's where we do all the alignment settings (platform detection and honoring the user's will if he +// defined e.g. EIGEN_DONT_ALIGN) so it needs to be done before we do anything with vectorization. +#include "Eigen/src/Core/util/Macros.h" + +// Disable the ipa-cp-clone optimization flag with MinGW 6.x or newer (enabled by default with -O3) +// See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=556 for details. +#if EIGEN_COMP_MINGW && EIGEN_GNUC_AT_LEAST(4,6) + #pragma GCC optimize ("-fno-ipa-cp-clone") +#endif + +#include + +// this include file manages BLAS and MKL related macros +// and inclusion of their respective header files +#include "Eigen/src/Core/util/MKL_support.h" + +// if alignment is disabled, then disable vectorization. Note: EIGEN_MAX_ALIGN_BYTES is the proper check, it takes into +// account both the user's will (EIGEN_MAX_ALIGN_BYTES,EIGEN_DONT_ALIGN) and our own platform checks +#if EIGEN_MAX_ALIGN_BYTES==0 + #ifndef EIGEN_DONT_VECTORIZE + #define EIGEN_DONT_VECTORIZE + #endif +#endif + +#if EIGEN_COMP_MSVC + #include // for _aligned_malloc -- need it regardless of whether vectorization is enabled + #if (EIGEN_COMP_MSVC >= 1500) // 2008 or later + // Remember that usage of defined() in a #define is undefined by the standard. + // a user reported that in 64-bit mode, MSVC doesn't care to define _M_IX86_FP. + #if (defined(_M_IX86_FP) && (_M_IX86_FP >= 2)) || EIGEN_ARCH_x86_64 + #define EIGEN_SSE2_ON_MSVC_2008_OR_LATER + #endif + #endif +#else + // Remember that usage of defined() in a #define is undefined by the standard + #if (defined __SSE2__) && ( (!EIGEN_COMP_GNUC) || EIGEN_COMP_ICC || EIGEN_GNUC_AT_LEAST(4,2) ) + #define EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC + #endif +#endif + +#ifndef EIGEN_DONT_VECTORIZE + + #if defined (EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC) || defined(EIGEN_SSE2_ON_MSVC_2008_OR_LATER) + + // Defines symbols for compile-time detection of which instructions are + // used. + // EIGEN_VECTORIZE_YY is defined if and only if the instruction set YY is used + #define EIGEN_VECTORIZE + #define EIGEN_VECTORIZE_SSE + #define EIGEN_VECTORIZE_SSE2 + + // Detect sse3/ssse3/sse4: + // gcc and icc defines __SSE3__, ... + // there is no way to know about this on msvc. You can define EIGEN_VECTORIZE_SSE* if you + // want to force the use of those instructions with msvc. + #ifdef __SSE3__ + #define EIGEN_VECTORIZE_SSE3 + #endif + #ifdef __SSSE3__ + #define EIGEN_VECTORIZE_SSSE3 + #endif + #ifdef __SSE4_1__ + #define EIGEN_VECTORIZE_SSE4_1 + #endif + #ifdef __SSE4_2__ + #define EIGEN_VECTORIZE_SSE4_2 + #endif + #ifdef __AVX__ + #define EIGEN_VECTORIZE_AVX + #define EIGEN_VECTORIZE_SSE3 + #define EIGEN_VECTORIZE_SSSE3 + #define EIGEN_VECTORIZE_SSE4_1 + #define EIGEN_VECTORIZE_SSE4_2 + #endif + #ifdef __AVX2__ + #define EIGEN_VECTORIZE_AVX2 + #endif + #ifdef __FMA__ + #define EIGEN_VECTORIZE_FMA + #endif + #if defined(__AVX512F__) && defined(EIGEN_ENABLE_AVX512) + #define EIGEN_VECTORIZE_AVX512 + #define EIGEN_VECTORIZE_AVX2 + #define EIGEN_VECTORIZE_AVX + #define EIGEN_VECTORIZE_FMA + #ifdef __AVX512DQ__ + #define EIGEN_VECTORIZE_AVX512DQ + #endif + #ifdef __AVX512ER__ + #define EIGEN_VECTORIZE_AVX512ER + #endif + #endif + + // include files + + // This extern "C" works around a MINGW-w64 compilation issue + // https://sourceforge.net/tracker/index.php?func=detail&aid=3018394&group_id=202880&atid=983354 + // In essence, intrin.h is included by windows.h and also declares intrinsics (just as emmintrin.h etc. below do). + // However, intrin.h uses an extern "C" declaration, and g++ thus complains of duplicate declarations + // with conflicting linkage. The linkage for intrinsics doesn't matter, but at that stage the compiler doesn't know; + // so, to avoid compile errors when windows.h is included after Eigen/Core, ensure intrinsics are extern "C" here too. + // notice that since these are C headers, the extern "C" is theoretically needed anyways. + extern "C" { + // In theory we should only include immintrin.h and not the other *mmintrin.h header files directly. + // Doing so triggers some issues with ICC. However old gcc versions seems to not have this file, thus: + #if EIGEN_COMP_ICC >= 1110 + #include + #else + #include + #include + #include + #ifdef EIGEN_VECTORIZE_SSE3 + #include + #endif + #ifdef EIGEN_VECTORIZE_SSSE3 + #include + #endif + #ifdef EIGEN_VECTORIZE_SSE4_1 + #include + #endif + #ifdef EIGEN_VECTORIZE_SSE4_2 + #include + #endif + #if defined(EIGEN_VECTORIZE_AVX) || defined(EIGEN_VECTORIZE_AVX512) + #include + #endif + #endif + } // end extern "C" + #elif defined __VSX__ + #define EIGEN_VECTORIZE + #define EIGEN_VECTORIZE_VSX + #include + // We need to #undef all these ugly tokens defined in + // => use __vector instead of vector + #undef bool + #undef vector + #undef pixel + #elif defined __ALTIVEC__ + #define EIGEN_VECTORIZE + #define EIGEN_VECTORIZE_ALTIVEC + #include + // We need to #undef all these ugly tokens defined in + // => use __vector instead of vector + #undef bool + #undef vector + #undef pixel + #elif (defined __ARM_NEON) || (defined __ARM_NEON__) + #define EIGEN_VECTORIZE + #define EIGEN_VECTORIZE_NEON + #include + #elif (defined __s390x__ && defined __VEC__) + #define EIGEN_VECTORIZE + #define EIGEN_VECTORIZE_ZVECTOR + #include + #endif +#endif + +#if defined(__F16C__) && !defined(EIGEN_COMP_CLANG) + // We can use the optimized fp16 to float and float to fp16 conversion routines + #define EIGEN_HAS_FP16_C +#endif + +#if defined __CUDACC__ + #define EIGEN_VECTORIZE_CUDA + #include + #if EIGEN_CUDACC_VER >= 70500 + #define EIGEN_HAS_CUDA_FP16 + #endif +#endif + +#if defined EIGEN_HAS_CUDA_FP16 + #include + #include +#endif + +#if (defined _OPENMP) && (!defined EIGEN_DONT_PARALLELIZE) + #define EIGEN_HAS_OPENMP +#endif + +#ifdef EIGEN_HAS_OPENMP +#include +#endif + +// MSVC for windows mobile does not have the errno.h file +#if !(EIGEN_COMP_MSVC && EIGEN_OS_WINCE) && !EIGEN_COMP_ARM +#define EIGEN_HAS_ERRNO +#endif + +#ifdef EIGEN_HAS_ERRNO +#include +#endif +#include +#include +#include +#include +#include +#include +#ifndef EIGEN_NO_IO + #include +#endif +#include +#include +#include +#include // for CHAR_BIT +// for min/max: +#include + +// for std::is_nothrow_move_assignable +#ifdef EIGEN_INCLUDE_TYPE_TRAITS +#include +#endif + +// for outputting debug info +#ifdef EIGEN_DEBUG_ASSIGN +#include +#endif + +// required for __cpuid, needs to be included after cmath +#if EIGEN_COMP_MSVC && EIGEN_ARCH_i386_OR_x86_64 && !EIGEN_OS_WINCE + #include +#endif + +/** \brief Namespace containing all symbols from the %Eigen library. */ +namespace Eigen { + +inline static const char *SimdInstructionSetsInUse(void) { +#if defined(EIGEN_VECTORIZE_AVX512) + return "AVX512, FMA, AVX2, AVX, SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2"; +#elif defined(EIGEN_VECTORIZE_AVX) + return "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2"; +#elif defined(EIGEN_VECTORIZE_SSE4_2) + return "SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2"; +#elif defined(EIGEN_VECTORIZE_SSE4_1) + return "SSE, SSE2, SSE3, SSSE3, SSE4.1"; +#elif defined(EIGEN_VECTORIZE_SSSE3) + return "SSE, SSE2, SSE3, SSSE3"; +#elif defined(EIGEN_VECTORIZE_SSE3) + return "SSE, SSE2, SSE3"; +#elif defined(EIGEN_VECTORIZE_SSE2) + return "SSE, SSE2"; +#elif defined(EIGEN_VECTORIZE_ALTIVEC) + return "AltiVec"; +#elif defined(EIGEN_VECTORIZE_VSX) + return "VSX"; +#elif defined(EIGEN_VECTORIZE_NEON) + return "ARM NEON"; +#elif defined(EIGEN_VECTORIZE_ZVECTOR) + return "S390X ZVECTOR"; +#else + return "None"; +#endif +} + +} // end namespace Eigen + +#if defined EIGEN2_SUPPORT_STAGE40_FULL_EIGEN3_STRICTNESS || defined EIGEN2_SUPPORT_STAGE30_FULL_EIGEN3_API || defined EIGEN2_SUPPORT_STAGE20_RESOLVE_API_CONFLICTS || defined EIGEN2_SUPPORT_STAGE10_FULL_EIGEN2_API || defined EIGEN2_SUPPORT +// This will generate an error message: +#error Eigen2-support is only available up to version 3.2. Please go to "http://eigen.tuxfamily.org/index.php?title=Eigen2" for further information +#endif + +namespace Eigen { + +// we use size_t frequently and we'll never remember to prepend it with std:: everytime just to +// ensure QNX/QCC support +using std::size_t; +// gcc 4.6.0 wants std:: for ptrdiff_t +using std::ptrdiff_t; + +} + +/** \defgroup Core_Module Core module + * This is the main module of Eigen providing dense matrix and vector support + * (both fixed and dynamic size) with all the features corresponding to a BLAS library + * and much more... + * + * \code + * #include + * \endcode + */ + +#include "Eigen/src/Core/util/Constants.h" +#include "Eigen/src/Core/util/Meta.h" +#include "stan/math/prim/plugins/ForwardDeclarations.h" +#include "Eigen/src/Core/util/StaticAssert.h" +#include "Eigen/src/Core/util/XprHelper.h" +#include "Eigen/src/Core/util/Memory.h" + +#include "Eigen/src/Core/NumTraits.h" +#include "Eigen/src/Core/MathFunctions.h" +#include "Eigen/src/Core/GenericPacketMath.h" +#include "Eigen/src/Core/MathFunctionsImpl.h" +#include "Eigen/src/Core/arch/Default/ConjHelper.h" + +#if defined EIGEN_VECTORIZE_AVX512 + #include "Eigen/src/Core/arch/SSE/PacketMath.h" + #include "Eigen/src/Core/arch/SSE/MathFunctions.h" + #include "Eigen/src/Core/arch/AVX/PacketMath.h" + #include "Eigen/src/Core/arch/AVX/MathFunctions.h" + #include "Eigen/src/Core/arch/AVX512/PacketMath.h" + #include "Eigen/src/Core/arch/AVX512/MathFunctions.h" +#elif defined EIGEN_VECTORIZE_AVX + // Use AVX for floats and doubles, SSE for integers + #include "Eigen/src/Core/arch/SSE/PacketMath.h" + #include "Eigen/src/Core/arch/SSE/Complex.h" + #include "Eigen/src/Core/arch/SSE/MathFunctions.h" + #include "Eigen/src/Core/arch/AVX/PacketMath.h" + #include "Eigen/src/Core/arch/AVX/MathFunctions.h" + #include "Eigen/src/Core/arch/AVX/Complex.h" + #include "Eigen/src/Core/arch/AVX/TypeCasting.h" + #include "Eigen/src/Core/arch/SSE/TypeCasting.h" +#elif defined EIGEN_VECTORIZE_SSE + #include "Eigen/src/Core/arch/SSE/PacketMath.h" + #include "Eigen/src/Core/arch/SSE/MathFunctions.h" + #include "Eigen/src/Core/arch/SSE/Complex.h" + #include "Eigen/src/Core/arch/SSE/TypeCasting.h" +#elif defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) + #include "Eigen/src/Core/arch/AltiVec/PacketMath.h" + #include "Eigen/src/Core/arch/AltiVec/MathFunctions.h" + #include "Eigen/src/Core/arch/AltiVec/Complex.h" +#elif defined EIGEN_VECTORIZE_NEON + #include "Eigen/src/Core/arch/NEON/PacketMath.h" + #include "Eigen/src/Core/arch/NEON/MathFunctions.h" + #include "Eigen/src/Core/arch/NEON/Complex.h" +#elif defined EIGEN_VECTORIZE_ZVECTOR + #include "Eigen/src/Core/arch/ZVector/PacketMath.h" + #include "Eigen/src/Core/arch/ZVector/MathFunctions.h" + #include "Eigen/src/Core/arch/ZVector/Complex.h" +#endif + +// Half float support +#include "Eigen/src/Core/arch/CUDA/Half.h" +#include "Eigen/src/Core/arch/CUDA/PacketMathHalf.h" +#include "Eigen/src/Core/arch/CUDA/TypeCasting.h" + +#if defined EIGEN_VECTORIZE_CUDA + #include "Eigen/src/Core/arch/CUDA/PacketMath.h" + #include "Eigen/src/Core/arch/CUDA/MathFunctions.h" +#endif + +#include "Eigen/src/Core/arch/Default/Settings.h" + +#include "Eigen/src/Core/functors/TernaryFunctors.h" +#include "Eigen/src/Core/functors/BinaryFunctors.h" +#include "Eigen/src/Core/functors/UnaryFunctors.h" +#include "Eigen/src/Core/functors/NullaryFunctors.h" +#include "Eigen/src/Core/functors/StlFunctors.h" +#include "Eigen/src/Core/functors/AssignmentFunctors.h" + +// Specialized functors to enable the processing of complex numbers +// on CUDA devices +#include "Eigen/src/Core/arch/CUDA/Complex.h" + +#include "Eigen/src/Core/IO.h" +#include "Eigen/src/Core/DenseCoeffsBase.h" +#include "Eigen/src/Core/DenseBase.h" +#include "Eigen/src/Core/MatrixBase.h" +#include "Eigen/src/Core/EigenBase.h" + +#include "Eigen/src/Core/Product.h" +#include "stan/math/prim/plugins/CoreEvaluators.h" +#include "Eigen/src/Core/AssignEvaluator.h" + +#ifndef EIGEN_PARSED_BY_DOXYGEN // work around Doxygen bug triggered by Assign.h r814874 + // at least confirmed with Doxygen 1.5.5 and 1.5.6 + #include "Eigen/src/Core/Assign.h" +#endif + +#include "Eigen/src/Core/ArrayBase.h" +#include "Eigen/src/Core/util/BlasUtil.h" +#include "Eigen/src/Core/DenseStorage.h" +#include "Eigen/src/Core/NestByValue.h" + +// #include "Eigen/src/Core/ForceAlignedAccess.h" + +#include "Eigen/src/Core/ReturnByValue.h" +#include "Eigen/src/Core/NoAlias.h" +#include "Eigen/src/Core/PlainObjectBase.h" +#include "Eigen/src/Core/Matrix.h" +#include "Eigen/src/Core/Array.h" +#include "Eigen/src/Core/CwiseTernaryOp.h" +#include "Eigen/src/Core/CwiseBinaryOp.h" +#include "Eigen/src/Core/CwiseUnaryOp.h" +#include "Eigen/src/Core/CwiseNullaryOp.h" +#include "stan/math/prim/plugins/CwiseUnaryView.h" +#include "Eigen/src/Core/SelfCwiseBinaryOp.h" +#include "Eigen/src/Core/Dot.h" +#include "Eigen/src/Core/StableNorm.h" +#include "Eigen/src/Core/Stride.h" +#include "Eigen/src/Core/MapBase.h" +#include "Eigen/src/Core/Map.h" +#include "Eigen/src/Core/Ref.h" +#include "Eigen/src/Core/Block.h" +#include "Eigen/src/Core/VectorBlock.h" +#include "Eigen/src/Core/Transpose.h" +#include "Eigen/src/Core/DiagonalMatrix.h" +#include "Eigen/src/Core/Diagonal.h" +#include "Eigen/src/Core/DiagonalProduct.h" +#include "Eigen/src/Core/Redux.h" +#include "Eigen/src/Core/Visitor.h" +#include "Eigen/src/Core/Fuzzy.h" +#include "Eigen/src/Core/Swap.h" +#include "Eigen/src/Core/CommaInitializer.h" +#include "Eigen/src/Core/GeneralProduct.h" +#include "Eigen/src/Core/Solve.h" +#include "Eigen/src/Core/Inverse.h" +#include "Eigen/src/Core/SolverBase.h" +#include "Eigen/src/Core/PermutationMatrix.h" +#include "Eigen/src/Core/Transpositions.h" +#include "Eigen/src/Core/TriangularMatrix.h" +#include "Eigen/src/Core/SelfAdjointView.h" +#include "Eigen/src/Core/products/GeneralBlockPanelKernel.h" +#include "Eigen/src/Core/products/Parallelizer.h" +#include "Eigen/src/Core/ProductEvaluators.h" +#include "Eigen/src/Core/products/GeneralMatrixVector.h" +#include "Eigen/src/Core/products/GeneralMatrixMatrix.h" +#include "Eigen/src/Core/SolveTriangular.h" +#include "Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h" +#include "Eigen/src/Core/products/SelfadjointMatrixVector.h" +#include "Eigen/src/Core/products/SelfadjointMatrixMatrix.h" +#include "Eigen/src/Core/products/SelfadjointProduct.h" +#include "Eigen/src/Core/products/SelfadjointRank2Update.h" +#include "Eigen/src/Core/products/TriangularMatrixVector.h" +#include "Eigen/src/Core/products/TriangularMatrixMatrix.h" +#include "Eigen/src/Core/products/TriangularSolverMatrix.h" +#include "Eigen/src/Core/products/TriangularSolverVector.h" +#include "Eigen/src/Core/BandMatrix.h" +#include "Eigen/src/Core/CoreIterators.h" +#include "Eigen/src/Core/ConditionEstimator.h" + +#include "Eigen/src/Core/BooleanRedux.h" +#include "Eigen/src/Core/Select.h" +#include "Eigen/src/Core/VectorwiseOp.h" +#include "Eigen/src/Core/Random.h" +#include "Eigen/src/Core/Replicate.h" +#include "Eigen/src/Core/Reverse.h" +#include "Eigen/src/Core/ArrayWrapper.h" + +#ifdef EIGEN_USE_BLAS +#include "Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h" +#include "Eigen/src/Core/products/GeneralMatrixVector_BLAS.h" +#include "Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h" +#include "Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h" +#include "Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h" +#include "Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h" +#include "Eigen/src/Core/products/TriangularMatrixVector_BLAS.h" +#include "Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h" +#endif // EIGEN_USE_BLAS + +#ifdef EIGEN_USE_MKL_VML +#include "Eigen/src/Core/Assign_MKL.h" +#endif + +#include "Eigen/src/Core/GlobalFunctions.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" + +#endif // EIGEN_CORE_H diff --git a/stan/math/prim/plugins/CoreEvaluators.h b/stan/math/prim/plugins/CoreEvaluators.h new file mode 100644 index 00000000000..d4f053c4f36 --- /dev/null +++ b/stan/math/prim/plugins/CoreEvaluators.h @@ -0,0 +1,1688 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2011 Benoit Jacob +// Copyright (C) 2011-2014 Gael Guennebaud +// Copyright (C) 2011-2012 Jitse Niesen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#ifndef EIGEN_COREEVALUATORS_H +#define EIGEN_COREEVALUATORS_H + +namespace Eigen { + +namespace internal { + +// This class returns the evaluator kind from the expression storage kind. +// Default assumes index based accessors +template +struct storage_kind_to_evaluator_kind { + typedef IndexBased Kind; +}; + +// This class returns the evaluator shape from the expression storage kind. +// It can be Dense, Sparse, Triangular, Diagonal, SelfAdjoint, Band, etc. +template struct storage_kind_to_shape; + +template<> struct storage_kind_to_shape { typedef DenseShape Shape; }; +template<> struct storage_kind_to_shape { typedef SolverShape Shape; }; +template<> struct storage_kind_to_shape { typedef PermutationShape Shape; }; +template<> struct storage_kind_to_shape { typedef TranspositionsShape Shape; }; + +// Evaluators have to be specialized with respect to various criteria such as: +// - storage/structure/shape +// - scalar type +// - etc. +// Therefore, we need specialization of evaluator providing additional template arguments for each kind of evaluators. +// We currently distinguish the following kind of evaluators: +// - unary_evaluator for expressions taking only one arguments (CwiseUnaryOp, CwiseUnaryView, Transpose, MatrixWrapper, ArrayWrapper, Reverse, Replicate) +// - binary_evaluator for expression taking two arguments (CwiseBinaryOp) +// - ternary_evaluator for expression taking three arguments (CwiseTernaryOp) +// - product_evaluator for linear algebra products (Product); special case of binary_evaluator because it requires additional tags for dispatching. +// - mapbase_evaluator for Map, Block, Ref +// - block_evaluator for Block (special dispatching to a mapbase_evaluator or unary_evaluator) + +template< typename T, + typename Arg1Kind = typename evaluator_traits::Kind, + typename Arg2Kind = typename evaluator_traits::Kind, + typename Arg3Kind = typename evaluator_traits::Kind, + typename Arg1Scalar = typename traits::Scalar, + typename Arg2Scalar = typename traits::Scalar, + typename Arg3Scalar = typename traits::Scalar> struct ternary_evaluator; + +template< typename T, + typename LhsKind = typename evaluator_traits::Kind, + typename RhsKind = typename evaluator_traits::Kind, + typename LhsScalar = typename traits::Scalar, + typename RhsScalar = typename traits::Scalar> struct binary_evaluator; + +template< typename T, + typename Kind = typename evaluator_traits::Kind, + typename Scalar = typename T::Scalar> struct unary_evaluator; + +// evaluator_traits contains traits for evaluator + +template +struct evaluator_traits_base +{ + // by default, get evaluator kind and shape from storage + typedef typename storage_kind_to_evaluator_kind::StorageKind>::Kind Kind; + typedef typename storage_kind_to_shape::StorageKind>::Shape Shape; +}; + +// Default evaluator traits +template +struct evaluator_traits : public evaluator_traits_base +{ +}; + +template::Shape > +struct evaluator_assume_aliasing { + static const bool value = false; +}; + +// By default, we assume a unary expression: +template +struct evaluator : public unary_evaluator +{ + typedef unary_evaluator Base; + EIGEN_DEVICE_FUNC explicit evaluator(const T& xpr) : Base(xpr) {} +}; + + +// TODO: Think about const-correctness +template +struct evaluator + : evaluator +{ + EIGEN_DEVICE_FUNC + explicit evaluator(const T& xpr) : evaluator(xpr) {} +}; + +// ---------- base class for all evaluators ---------- + +template +struct evaluator_base : public noncopyable +{ + // TODO that's not very nice to have to propagate all these traits. They are currently only needed to handle outer,inner indices. + typedef traits ExpressionTraits; + + enum { + Alignment = 0 + }; +}; + +// -------------------- Matrix and Array -------------------- +// +// evaluator is a common base class for the +// Matrix and Array evaluators. +// Here we directly specialize evaluator. This is not really a unary expression, and it is, by definition, dense, +// so no need for more sophisticated dispatching. + +template +struct evaluator > + : evaluator_base +{ + typedef PlainObjectBase PlainObjectType; + typedef typename PlainObjectType::Scalar Scalar; + typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; + + enum { + IsRowMajor = PlainObjectType::IsRowMajor, + IsVectorAtCompileTime = PlainObjectType::IsVectorAtCompileTime, + RowsAtCompileTime = PlainObjectType::RowsAtCompileTime, + ColsAtCompileTime = PlainObjectType::ColsAtCompileTime, + + CoeffReadCost = NumTraits::ReadCost, + Flags = traits::EvaluatorFlags, + Alignment = traits::Alignment + }; + + EIGEN_DEVICE_FUNC evaluator() + : m_data(0), + m_outerStride(IsVectorAtCompileTime ? 0 + : int(IsRowMajor) ? ColsAtCompileTime + : RowsAtCompileTime) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + EIGEN_DEVICE_FUNC explicit evaluator(const PlainObjectType& m) + : m_data(m.data()), m_outerStride(IsVectorAtCompileTime ? 0 : m.outerStride()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + if (IsRowMajor) + return m_data[row * m_outerStride.value() + col]; + else + return m_data[row + col * m_outerStride.value()]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_data[index]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + if (IsRowMajor) + return const_cast(m_data)[row * m_outerStride.value() + col]; + else + return const_cast(m_data)[row + col * m_outerStride.value()]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return const_cast(m_data)[index]; + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + if (IsRowMajor) + return ploadt(m_data + row * m_outerStride.value() + col); + else + return ploadt(m_data + row + col * m_outerStride.value()); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return ploadt(m_data + index); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + if (IsRowMajor) + return pstoret + (const_cast(m_data) + row * m_outerStride.value() + col, x); + else + return pstoret + (const_cast(m_data) + row + col * m_outerStride.value(), x); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + return pstoret(const_cast(m_data) + index, x); + } + +protected: + const Scalar *m_data; + + // We do not need to know the outer stride for vectors + variable_if_dynamic m_outerStride; +}; + +template +struct evaluator > + : evaluator > > +{ + typedef Matrix XprType; + + EIGEN_DEVICE_FUNC evaluator() {} + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& m) + : evaluator >(m) + { } +}; + +template +struct evaluator > + : evaluator > > +{ + typedef Array XprType; + + EIGEN_DEVICE_FUNC evaluator() {} + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& m) + : evaluator >(m) + { } +}; + +// -------------------- Transpose -------------------- + +template +struct unary_evaluator, IndexBased> + : evaluator_base > +{ + typedef Transpose XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost, + Flags = evaluator::Flags ^ RowMajorBit, + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& t) : m_argImpl(t.nestedExpression()) {} + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(col, row); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(col, row); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + typename XprType::Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_argImpl.template packet(col, row); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_argImpl.template packet(index); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + m_argImpl.template writePacket(col, row, x); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + m_argImpl.template writePacket(index, x); + } + +protected: + evaluator m_argImpl; +}; + +// -------------------- CwiseNullaryOp -------------------- +// Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator. +// Likewise, there is not need to more sophisticated dispatching here. + +template::value, + bool has_unary = has_unary_operator::value, + bool has_binary = has_binary_operator::value> +struct nullary_wrapper +{ + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { return op(i,j); } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { return op.template packetOp(i,j); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp(i); } +}; + +template +struct nullary_wrapper +{ + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType=0, IndexType=0) const { return op(); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType=0, IndexType=0) const { return op.template packetOp(); } +}; + +template +struct nullary_wrapper +{ + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j=0) const { return op(i,j); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j=0) const { return op.template packetOp(i,j); } +}; + +// We need the following specialization for vector-only functors assigned to a runtime vector, +// for instance, using linspace and assigning a RowVectorXd to a MatrixXd or even a row of a MatrixXd. +// In this case, i==0 and j is used for the actual iteration. +template +struct nullary_wrapper +{ + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op(i+j); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op.template packetOp(i+j); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp(i); } +}; + +template +struct nullary_wrapper {}; + +#if 0 && EIGEN_COMP_MSVC>0 +// Disable this ugly workaround. This is now handled in traits::match, +// but this piece of code might still become handly if some other weird compilation +// erros pop up again. + +// MSVC exhibits a weird compilation error when +// compiling: +// Eigen::MatrixXf A = MatrixXf::Random(3,3); +// Ref R = 2.f*A; +// and that has_*ary_operator> have not been instantiated yet. +// The "problem" is that evaluator<2.f*A> is instantiated by traits::match<2.f*A> +// and at that time has_*ary_operator returns true regardless of T. +// Then nullary_wrapper is badly instantiated as nullary_wrapper<.,.,true,true,true>. +// The trick is thus to defer the proper instantiation of nullary_wrapper when coeff(), +// and packet() are really instantiated as implemented below: + +// This is a simple wrapper around Index to enforce the re-instantiation of +// has_*ary_operator when needed. +template struct nullary_wrapper_workaround_msvc { + nullary_wrapper_workaround_msvc(const T&); + operator T()const; +}; + +template +struct nullary_wrapper +{ + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper >::value, + has_unary_operator >::value, + has_binary_operator >::value>().operator()(op,i,j); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { + return nullary_wrapper >::value, + has_unary_operator >::value, + has_binary_operator >::value>().operator()(op,i); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper >::value, + has_unary_operator >::value, + has_binary_operator >::value>().template packetOp(op,i,j); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { + return nullary_wrapper >::value, + has_unary_operator >::value, + has_binary_operator >::value>().template packetOp(op,i); + } +}; +#endif // MSVC workaround + +template +struct evaluator > + : evaluator_base > +{ + typedef CwiseNullaryOp XprType; + typedef typename internal::remove_all::type PlainObjectTypeCleaned; + + enum { + CoeffReadCost = internal::functor_traits::Cost, + + Flags = (evaluator::Flags + & ( HereditaryBits + | (functor_has_linear_access::ret ? LinearAccessBit : 0) + | (functor_traits::PacketAccess ? PacketAccessBit : 0))) + | (functor_traits::IsRepeatable ? 0 : EvalBeforeNestingBit), + Alignment = AlignedMax + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n) + : m_functor(n.functor()), m_wrapper() + { + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(IndexType row, IndexType col) const + { + return m_wrapper(m_functor, row, col); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(IndexType index) const + { + return m_wrapper(m_functor,index); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(IndexType row, IndexType col) const + { + return m_wrapper.template packetOp(m_functor, row, col); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(IndexType index) const + { + return m_wrapper.template packetOp(m_functor, index); + } + +protected: + const NullaryOp m_functor; + const internal::nullary_wrapper m_wrapper; +}; + +// -------------------- CwiseUnaryOp -------------------- + +template +struct unary_evaluator, IndexBased > + : evaluator_base > +{ + typedef CwiseUnaryOp XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost + functor_traits::Cost, + + Flags = evaluator::Flags + & (HereditaryBits | LinearAccessBit | (functor_traits::PacketAccess ? PacketAccessBit : 0)), + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + explicit unary_evaluator(const XprType& op) + : m_functor(op.functor()), + m_argImpl(op.nestedExpression()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_functor(m_argImpl.coeff(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_functor(m_argImpl.coeff(index)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_functor.packetOp(m_argImpl.template packet(row, col)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_functor.packetOp(m_argImpl.template packet(index)); + } + +protected: + const UnaryOp m_functor; + evaluator m_argImpl; +}; + +// -------------------- CwiseTernaryOp -------------------- + +// this is a ternary expression +template +struct evaluator > + : public ternary_evaluator > +{ + typedef CwiseTernaryOp XprType; + typedef ternary_evaluator > Base; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {} +}; + +template +struct ternary_evaluator, IndexBased, IndexBased> + : evaluator_base > +{ + typedef CwiseTernaryOp XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost + evaluator::CoeffReadCost + evaluator::CoeffReadCost + functor_traits::Cost, + + Arg1Flags = evaluator::Flags, + Arg2Flags = evaluator::Flags, + Arg3Flags = evaluator::Flags, + SameType = is_same::value && is_same::value, + StorageOrdersAgree = (int(Arg1Flags)&RowMajorBit)==(int(Arg2Flags)&RowMajorBit) && (int(Arg1Flags)&RowMajorBit)==(int(Arg3Flags)&RowMajorBit), + Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) & ( + HereditaryBits + | (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) & + ( (StorageOrdersAgree ? LinearAccessBit : 0) + | (functor_traits::PacketAccess && StorageOrdersAgree && SameType ? PacketAccessBit : 0) + ) + ) + ), + Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit), + Alignment = EIGEN_PLAIN_ENUM_MIN( + EIGEN_PLAIN_ENUM_MIN(evaluator::Alignment, evaluator::Alignment), + evaluator::Alignment) + }; + + EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr) + : m_functor(xpr.functor()), + m_arg1Impl(xpr.arg1()), + m_arg2Impl(xpr.arg2()), + m_arg3Impl(xpr.arg3()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_functor.packetOp(m_arg1Impl.template packet(row, col), + m_arg2Impl.template packet(row, col), + m_arg3Impl.template packet(row, col)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_functor.packetOp(m_arg1Impl.template packet(index), + m_arg2Impl.template packet(index), + m_arg3Impl.template packet(index)); + } + +protected: + const TernaryOp m_functor; + evaluator m_arg1Impl; + evaluator m_arg2Impl; + evaluator m_arg3Impl; +}; + +// -------------------- CwiseBinaryOp -------------------- + +// this is a binary expression +template +struct evaluator > + : public binary_evaluator > +{ + typedef CwiseBinaryOp XprType; + typedef binary_evaluator > Base; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {} +}; + +template +struct binary_evaluator, IndexBased, IndexBased> + : evaluator_base > +{ + typedef CwiseBinaryOp XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost + evaluator::CoeffReadCost + functor_traits::Cost, + + LhsFlags = evaluator::Flags, + RhsFlags = evaluator::Flags, + SameType = is_same::value, + StorageOrdersAgree = (int(LhsFlags)&RowMajorBit)==(int(RhsFlags)&RowMajorBit), + Flags0 = (int(LhsFlags) | int(RhsFlags)) & ( + HereditaryBits + | (int(LhsFlags) & int(RhsFlags) & + ( (StorageOrdersAgree ? LinearAccessBit : 0) + | (functor_traits::PacketAccess && StorageOrdersAgree && SameType ? PacketAccessBit : 0) + ) + ) + ), + Flags = (Flags0 & ~RowMajorBit) | (LhsFlags & RowMajorBit), + Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator::Alignment,evaluator::Alignment) + }; + + EIGEN_DEVICE_FUNC explicit binary_evaluator(const XprType& xpr) + : m_functor(xpr.functor()), + m_lhsImpl(xpr.lhs()), + m_rhsImpl(xpr.rhs()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_functor(m_lhsImpl.coeff(row, col), m_rhsImpl.coeff(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_functor(m_lhsImpl.coeff(index), m_rhsImpl.coeff(index)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_functor.packetOp(m_lhsImpl.template packet(row, col), + m_rhsImpl.template packet(row, col)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_functor.packetOp(m_lhsImpl.template packet(index), + m_rhsImpl.template packet(index)); + } + +protected: + const BinaryOp m_functor; + evaluator m_lhsImpl; + evaluator m_rhsImpl; +}; + +// -------------------- CwiseUnaryView -------------------- + +template +struct unary_evaluator, IndexBased> + : evaluator_base > +{ + typedef CwiseUnaryView XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost + functor_traits::Cost, + + Flags = (evaluator::Flags & (HereditaryBits | LinearAccessBit | DirectAccessBit)), + + Alignment = 0 // FIXME it is not very clear why alignment is necessarily lost... + }; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& op) + : m_unaryOp(op.functor()), + m_argImpl(op.nestedExpression()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_unaryOp(m_argImpl.coeff(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_unaryOp(m_argImpl.coeff(index)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_unaryOp(m_argImpl.coeffRef(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return m_unaryOp(m_argImpl.coeffRef(index)); + } + +protected: + const UnaryOp m_unaryOp; + evaluator m_argImpl; +}; + +// -------------------- Map -------------------- + +// FIXME perhaps the PlainObjectType could be provided by Derived::PlainObject ? +// but that might complicate template specialization +template +struct mapbase_evaluator; + +template +struct mapbase_evaluator : evaluator_base +{ + typedef Derived XprType; + typedef typename XprType::PointerType PointerType; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + enum { + IsRowMajor = XprType::RowsAtCompileTime, + ColsAtCompileTime = XprType::ColsAtCompileTime, + CoeffReadCost = NumTraits::ReadCost + }; + + EIGEN_DEVICE_FUNC explicit mapbase_evaluator(const XprType& map) + : m_data(const_cast(map.data())), + m_innerStride(map.innerStride()), + m_outerStride(map.outerStride()) + { + EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(evaluator::Flags&PacketAccessBit, internal::inner_stride_at_compile_time::ret==1), + PACKET_ACCESS_REQUIRES_TO_HAVE_INNER_STRIDE_FIXED_TO_1); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_data[col * colStride() + row * rowStride()]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_data[index * m_innerStride.value()]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_data[col * colStride() + row * rowStride()]; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return m_data[index * m_innerStride.value()]; + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + PointerType ptr = m_data + row * rowStride() + col * colStride(); + return internal::ploadt(ptr); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return internal::ploadt(m_data + index * m_innerStride.value()); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + PointerType ptr = m_data + row * rowStride() + col * colStride(); + return internal::pstoret(ptr, x); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + internal::pstoret(m_data + index * m_innerStride.value(), x); + } +protected: + EIGEN_DEVICE_FUNC + inline Index rowStride() const { return XprType::IsRowMajor ? m_outerStride.value() : m_innerStride.value(); } + EIGEN_DEVICE_FUNC + inline Index colStride() const { return XprType::IsRowMajor ? m_innerStride.value() : m_outerStride.value(); } + + PointerType m_data; + const internal::variable_if_dynamic m_innerStride; + const internal::variable_if_dynamic m_outerStride; +}; + +template +struct evaluator > + : public mapbase_evaluator, PlainObjectType> +{ + typedef Map XprType; + typedef typename XprType::Scalar Scalar; + // TODO: should check for smaller packet types once we can handle multi-sized packet types + typedef typename packet_traits::type PacketScalar; + + enum { + InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0 + ? int(PlainObjectType::InnerStrideAtCompileTime) + : int(StrideType::InnerStrideAtCompileTime), + OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0 + ? int(PlainObjectType::OuterStrideAtCompileTime) + : int(StrideType::OuterStrideAtCompileTime), + HasNoInnerStride = InnerStrideAtCompileTime == 1, + HasNoOuterStride = StrideType::OuterStrideAtCompileTime == 0, + HasNoStride = HasNoInnerStride && HasNoOuterStride, + IsDynamicSize = PlainObjectType::SizeAtCompileTime==Dynamic, + + PacketAccessMask = bool(HasNoInnerStride) ? ~int(0) : ~int(PacketAccessBit), + LinearAccessMask = bool(HasNoStride) || bool(PlainObjectType::IsVectorAtCompileTime) ? ~int(0) : ~int(LinearAccessBit), + Flags = int( evaluator::Flags) & (LinearAccessMask&PacketAccessMask), + + Alignment = int(MapOptions)&int(AlignedMask) + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& map) + : mapbase_evaluator(map) + { } +}; + +// -------------------- Ref -------------------- + +template +struct evaluator > + : public mapbase_evaluator, PlainObjectType> +{ + typedef Ref XprType; + + enum { + Flags = evaluator >::Flags, + Alignment = evaluator >::Alignment + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& ref) + : mapbase_evaluator(ref) + { } +}; + +// -------------------- Block -------------------- + +template::ret> struct block_evaluator; + +template +struct evaluator > + : block_evaluator +{ + typedef Block XprType; + typedef typename XprType::Scalar Scalar; + // TODO: should check for smaller packet types once we can handle multi-sized packet types + typedef typename packet_traits::type PacketScalar; + + enum { + CoeffReadCost = evaluator::CoeffReadCost, + + RowsAtCompileTime = traits::RowsAtCompileTime, + ColsAtCompileTime = traits::ColsAtCompileTime, + MaxRowsAtCompileTime = traits::MaxRowsAtCompileTime, + MaxColsAtCompileTime = traits::MaxColsAtCompileTime, + + ArgTypeIsRowMajor = (int(evaluator::Flags)&RowMajorBit) != 0, + IsRowMajor = (MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1) ? 1 + : (MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1) ? 0 + : ArgTypeIsRowMajor, + HasSameStorageOrderAsArgType = (IsRowMajor == ArgTypeIsRowMajor), + InnerSize = IsRowMajor ? int(ColsAtCompileTime) : int(RowsAtCompileTime), + InnerStrideAtCompileTime = HasSameStorageOrderAsArgType + ? int(inner_stride_at_compile_time::ret) + : int(outer_stride_at_compile_time::ret), + OuterStrideAtCompileTime = HasSameStorageOrderAsArgType + ? int(outer_stride_at_compile_time::ret) + : int(inner_stride_at_compile_time::ret), + MaskPacketAccessBit = (InnerStrideAtCompileTime == 1 || HasSameStorageOrderAsArgType) ? PacketAccessBit : 0, + + FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1 || (InnerPanel && (evaluator::Flags&LinearAccessBit))) ? LinearAccessBit : 0, + FlagsRowMajorBit = XprType::Flags&RowMajorBit, + Flags0 = evaluator::Flags & ( (HereditaryBits & ~RowMajorBit) | + DirectAccessBit | + MaskPacketAccessBit), + Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit, + + PacketAlignment = unpacket_traits::alignment, + Alignment0 = (InnerPanel && (OuterStrideAtCompileTime!=Dynamic) + && (OuterStrideAtCompileTime!=0) + && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % int(PacketAlignment)) == 0)) ? int(PacketAlignment) : 0, + Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator::Alignment, Alignment0) + }; + typedef block_evaluator block_evaluator_type; + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& block) : block_evaluator_type(block) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } +}; + +// no direct-access => dispatch to a unary evaluator +template +struct block_evaluator + : unary_evaluator > +{ + typedef Block XprType; + + EIGEN_DEVICE_FUNC explicit block_evaluator(const XprType& block) + : unary_evaluator(block) + {} +}; + +template +struct unary_evaluator, IndexBased> + : evaluator_base > +{ + typedef Block XprType; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& block) + : m_argImpl(block.nestedExpression()), + m_startRow(block.startRow()), + m_startCol(block.startCol()), + m_linear_offset(InnerPanel?(XprType::IsRowMajor ? block.startRow()*block.cols() : block.startCol()*block.rows()):0) + { } + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + enum { + RowsAtCompileTime = XprType::RowsAtCompileTime, + ForwardLinearAccess = InnerPanel && bool(evaluator::Flags&LinearAccessBit) + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(m_startRow.value() + row, m_startCol.value() + col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + if (ForwardLinearAccess) + return m_argImpl.coeff(m_linear_offset.value() + index); + else + return coeff(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(m_startRow.value() + row, m_startCol.value() + col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + if (ForwardLinearAccess) + return m_argImpl.coeffRef(m_linear_offset.value() + index); + else + return coeffRef(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_argImpl.template packet(m_startRow.value() + row, m_startCol.value() + col); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + if (ForwardLinearAccess) + return m_argImpl.template packet(m_linear_offset.value() + index); + else + return packet(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + return m_argImpl.template writePacket(m_startRow.value() + row, m_startCol.value() + col, x); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + if (ForwardLinearAccess) + return m_argImpl.template writePacket(m_linear_offset.value() + index, x); + else + return writePacket(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0, + x); + } + +protected: + evaluator m_argImpl; + const variable_if_dynamic m_startRow; + const variable_if_dynamic m_startCol; + const variable_if_dynamic m_linear_offset; +}; + +// TODO: This evaluator does not actually use the child evaluator; +// all action is via the data() as returned by the Block expression. + +template +struct block_evaluator + : mapbase_evaluator, + typename Block::PlainObject> +{ + typedef Block XprType; + typedef typename XprType::Scalar Scalar; + + EIGEN_DEVICE_FUNC explicit block_evaluator(const XprType& block) + : mapbase_evaluator(block) + { + // TODO: for the 3.3 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime + eigen_assert(((internal::UIntPtr(block.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator::Alignment)) == 0) && "data is not aligned"); + } +}; + + +// -------------------- Select -------------------- +// NOTE shall we introduce a ternary_evaluator? + +// TODO enable vectorization for Select +template +struct evaluator > + : evaluator_base > +{ + typedef Select XprType; + enum { + CoeffReadCost = evaluator::CoeffReadCost + + EIGEN_PLAIN_ENUM_MAX(evaluator::CoeffReadCost, + evaluator::CoeffReadCost), + + Flags = (unsigned int)evaluator::Flags & evaluator::Flags & HereditaryBits, + + Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator::Alignment, evaluator::Alignment) + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& select) + : m_conditionImpl(select.conditionMatrix()), + m_thenImpl(select.thenMatrix()), + m_elseImpl(select.elseMatrix()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + if (m_conditionImpl.coeff(row, col)) + return m_thenImpl.coeff(row, col); + else + return m_elseImpl.coeff(row, col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + if (m_conditionImpl.coeff(index)) + return m_thenImpl.coeff(index); + else + return m_elseImpl.coeff(index); + } + +protected: + evaluator m_conditionImpl; + evaluator m_thenImpl; + evaluator m_elseImpl; +}; + + +// -------------------- Replicate -------------------- + +template +struct unary_evaluator > + : evaluator_base > +{ + typedef Replicate XprType; + typedef typename XprType::CoeffReturnType CoeffReturnType; + enum { + Factor = (RowFactor==Dynamic || ColFactor==Dynamic) ? Dynamic : RowFactor*ColFactor + }; + typedef typename internal::nested_eval::type ArgTypeNested; + typedef typename internal::remove_all::type ArgTypeNestedCleaned; + + enum { + CoeffReadCost = evaluator::CoeffReadCost, + LinearAccessMask = XprType::IsVectorAtCompileTime ? LinearAccessBit : 0, + Flags = (evaluator::Flags & (HereditaryBits|LinearAccessMask) & ~RowMajorBit) | (traits::Flags & RowMajorBit), + + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& replicate) + : m_arg(replicate.nestedExpression()), + m_argImpl(m_arg), + m_rows(replicate.nestedExpression().rows()), + m_cols(replicate.nestedExpression().cols()) + {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + // try to avoid using modulo; this is a pure optimization strategy + const Index actual_row = internal::traits::RowsAtCompileTime==1 ? 0 + : RowFactor==1 ? row + : row % m_rows.value(); + const Index actual_col = internal::traits::ColsAtCompileTime==1 ? 0 + : ColFactor==1 ? col + : col % m_cols.value(); + + return m_argImpl.coeff(actual_row, actual_col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + // try to avoid using modulo; this is a pure optimization strategy + const Index actual_index = internal::traits::RowsAtCompileTime==1 + ? (ColFactor==1 ? index : index%m_cols.value()) + : (RowFactor==1 ? index : index%m_rows.value()); + + return m_argImpl.coeff(actual_index); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + const Index actual_row = internal::traits::RowsAtCompileTime==1 ? 0 + : RowFactor==1 ? row + : row % m_rows.value(); + const Index actual_col = internal::traits::ColsAtCompileTime==1 ? 0 + : ColFactor==1 ? col + : col % m_cols.value(); + + return m_argImpl.template packet(actual_row, actual_col); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + const Index actual_index = internal::traits::RowsAtCompileTime==1 + ? (ColFactor==1 ? index : index%m_cols.value()) + : (RowFactor==1 ? index : index%m_rows.value()); + + return m_argImpl.template packet(actual_index); + } + +protected: + const ArgTypeNested m_arg; + evaluator m_argImpl; + const variable_if_dynamic m_rows; + const variable_if_dynamic m_cols; +}; + + +// -------------------- PartialReduxExpr -------------------- + +template< typename ArgType, typename MemberOp, int Direction> +struct evaluator > + : evaluator_base > +{ + typedef PartialReduxExpr XprType; + typedef typename internal::nested_eval::type ArgTypeNested; + typedef typename internal::remove_all::type ArgTypeNestedCleaned; + typedef typename ArgType::Scalar InputScalar; + typedef typename XprType::Scalar Scalar; + enum { + TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(ArgType::ColsAtCompileTime) + }; + typedef typename MemberOp::template Cost CostOpType; + enum { + CoeffReadCost = TraversalSize==Dynamic ? HugeCost + : TraversalSize * evaluator::CoeffReadCost + int(CostOpType::value), + + Flags = (traits::Flags&RowMajorBit) | (evaluator::Flags&(HereditaryBits&(~RowMajorBit))) | LinearAccessBit, + + Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr) + : m_arg(xpr.nestedExpression()), m_functor(xpr.functor()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(TraversalSize==Dynamic ? HugeCost : int(CostOpType::value)); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const Scalar coeff(Index i, Index j) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(j)); + else + return m_functor(m_arg.row(i)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const Scalar coeff(Index index) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(index)); + else + return m_functor(m_arg.row(index)); + } + +protected: + typename internal::add_const_on_value_type::type m_arg; + const MemberOp m_functor; +}; + + +// -------------------- MatrixWrapper and ArrayWrapper -------------------- +// +// evaluator_wrapper_base is a common base class for the +// MatrixWrapper and ArrayWrapper evaluators. + +template +struct evaluator_wrapper_base + : evaluator_base +{ + typedef typename remove_all::type ArgType; + enum { + CoeffReadCost = evaluator::CoeffReadCost, + Flags = evaluator::Flags, + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC explicit evaluator_wrapper_base(const ArgType& arg) : m_argImpl(arg) {} + + typedef typename ArgType::Scalar Scalar; + typedef typename ArgType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(row, col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_argImpl.template packet(row, col); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_argImpl.template packet(index); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + m_argImpl.template writePacket(row, col, x); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + m_argImpl.template writePacket(index, x); + } + +protected: + evaluator m_argImpl; +}; + +template +struct unary_evaluator > + : evaluator_wrapper_base > +{ + typedef MatrixWrapper XprType; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& wrapper) + : evaluator_wrapper_base >(wrapper.nestedExpression()) + { } +}; + +template +struct unary_evaluator > + : evaluator_wrapper_base > +{ + typedef ArrayWrapper XprType; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& wrapper) + : evaluator_wrapper_base >(wrapper.nestedExpression()) + { } +}; + + +// -------------------- Reverse -------------------- + +// defined in Reverse.h: +template struct reverse_packet_cond; + +template +struct unary_evaluator > + : evaluator_base > +{ + typedef Reverse XprType; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + enum { + IsRowMajor = XprType::IsRowMajor, + IsColMajor = !IsRowMajor, + ReverseRow = (Direction == Vertical) || (Direction == BothDirections), + ReverseCol = (Direction == Horizontal) || (Direction == BothDirections), + ReversePacket = (Direction == BothDirections) + || ((Direction == Vertical) && IsColMajor) + || ((Direction == Horizontal) && IsRowMajor), + + CoeffReadCost = evaluator::CoeffReadCost, + + // let's enable LinearAccess only with vectorization because of the product overhead + // FIXME enable DirectAccess with negative strides? + Flags0 = evaluator::Flags, + LinearAccess = ( (Direction==BothDirections) && (int(Flags0)&PacketAccessBit) ) + || ((ReverseRow && XprType::ColsAtCompileTime==1) || (ReverseCol && XprType::RowsAtCompileTime==1)) + ? LinearAccessBit : 0, + + Flags = int(Flags0) & (HereditaryBits | PacketAccessBit | LinearAccess), + + Alignment = 0 // FIXME in some rare cases, Alignment could be preserved, like a Vector4f. + }; + + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& reverse) + : m_argImpl(reverse.nestedExpression()), + m_rows(ReverseRow ? reverse.nestedExpression().rows() : 1), + m_cols(ReverseCol ? reverse.nestedExpression().cols() : 1) + { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(ReverseRow ? m_rows.value() - row - 1 : row, + ReverseCol ? m_cols.value() - col - 1 : col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(m_rows.value() * m_cols.value() - index - 1); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(ReverseRow ? m_rows.value() - row - 1 : row, + ReverseCol ? m_cols.value() - col - 1 : col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(m_rows.value() * m_cols.value() - index - 1); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + enum { + PacketSize = unpacket_traits::size, + OffsetRow = ReverseRow && IsColMajor ? PacketSize : 1, + OffsetCol = ReverseCol && IsRowMajor ? PacketSize : 1 + }; + typedef internal::reverse_packet_cond reverse_packet; + return reverse_packet::run(m_argImpl.template packet( + ReverseRow ? m_rows.value() - row - OffsetRow : row, + ReverseCol ? m_cols.value() - col - OffsetCol : col)); + } + + template + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + enum { PacketSize = unpacket_traits::size }; + return preverse(m_argImpl.template packet(m_rows.value() * m_cols.value() - index - PacketSize)); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index row, Index col, const PacketType& x) + { + // FIXME we could factorize some code with packet(i,j) + enum { + PacketSize = unpacket_traits::size, + OffsetRow = ReverseRow && IsColMajor ? PacketSize : 1, + OffsetCol = ReverseCol && IsRowMajor ? PacketSize : 1 + }; + typedef internal::reverse_packet_cond reverse_packet; + m_argImpl.template writePacket( + ReverseRow ? m_rows.value() - row - OffsetRow : row, + ReverseCol ? m_cols.value() - col - OffsetCol : col, + reverse_packet::run(x)); + } + + template + EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketType& x) + { + enum { PacketSize = unpacket_traits::size }; + m_argImpl.template writePacket + (m_rows.value() * m_cols.value() - index - PacketSize, preverse(x)); + } + +protected: + evaluator m_argImpl; + + // If we do not reverse rows, then we do not need to know the number of rows; same for columns + // Nonetheless, in this case it is important to set to 1 such that the coeff(index) method works fine for vectors. + const variable_if_dynamic m_rows; + const variable_if_dynamic m_cols; +}; + + +// -------------------- Diagonal -------------------- + +template +struct evaluator > + : evaluator_base > +{ + typedef Diagonal XprType; + + enum { + CoeffReadCost = evaluator::CoeffReadCost, + + Flags = (unsigned int)(evaluator::Flags & (HereditaryBits | DirectAccessBit) & ~RowMajorBit) | LinearAccessBit, + + Alignment = 0 + }; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& diagonal) + : m_argImpl(diagonal.nestedExpression()), + m_index(diagonal.index()) + { } + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index) const + { + return m_argImpl.coeff(row + rowOffset(), row + colOffset()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index + rowOffset(), index + colOffset()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index) + { + return m_argImpl.coeffRef(row + rowOffset(), row + colOffset()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index + rowOffset(), index + colOffset()); + } + +protected: + evaluator m_argImpl; + const internal::variable_if_dynamicindex m_index; + +private: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value() > 0 ? 0 : -m_index.value(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value() > 0 ? m_index.value() : 0; } +}; + + +//---------------------------------------------------------------------- +// deprecated code +//---------------------------------------------------------------------- + +// -------------------- EvalToTemp -------------------- + +// expression class for evaluating nested expression to a temporary + +template class EvalToTemp; + +template +struct traits > + : public traits +{ }; + +template +class EvalToTemp + : public dense_xpr_base >::type +{ + public: + + typedef typename dense_xpr_base::type Base; + EIGEN_GENERIC_PUBLIC_INTERFACE(EvalToTemp) + + explicit EvalToTemp(const ArgType& arg) + : m_arg(arg) + { } + + const ArgType& arg() const + { + return m_arg; + } + + Index rows() const + { + return m_arg.rows(); + } + + Index cols() const + { + return m_arg.cols(); + } + + private: + const ArgType& m_arg; +}; + +template +struct evaluator > + : public evaluator +{ + typedef EvalToTemp XprType; + typedef typename ArgType::PlainObject PlainObject; + typedef evaluator Base; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) + : m_result(xpr.arg()) + { + ::new (static_cast(this)) Base(m_result); + } + + // This constructor is used when nesting an EvalTo evaluator in another evaluator + EIGEN_DEVICE_FUNC evaluator(const ArgType& arg) + : m_result(arg) + { + ::new (static_cast(this)) Base(m_result); + } + +protected: + PlainObject m_result; +}; + +} // namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COREEVALUATORS_H diff --git a/stan/math/prim/plugins/CwiseUnaryView.h b/stan/math/prim/plugins/CwiseUnaryView.h new file mode 100644 index 00000000000..c2085f99093 --- /dev/null +++ b/stan/math/prim/plugins/CwiseUnaryView.h @@ -0,0 +1,142 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2009-2010 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CWISE_UNARY_VIEW_H +#define EIGEN_CWISE_UNARY_VIEW_H + +namespace Eigen { + +namespace internal { +template +struct traits > + : traits +{ + typedef typename result_of< + ViewOp(const typename traits::Scalar&) + >::type Scalar; + typedef typename MatrixType::Nested MatrixTypeNested; + typedef typename remove_all::type _MatrixTypeNested; + enum { + FlagsLvalueBit = is_lvalue::value ? LvalueBit : 0, + Flags = traits<_MatrixTypeNested>::Flags & (RowMajorBit | FlagsLvalueBit | DirectAccessBit), // FIXME DirectAccessBit should not be handled by expressions + MatrixTypeInnerStride = inner_stride_at_compile_time::ret, + // need to cast the sizeof's from size_t to int explicitly, otherwise: + // "error: no integral type can represent all of the enumerator values + InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0 + ? (MatrixTypeInnerStride == Dynamic + ? int(Dynamic) + : int(MatrixTypeInnerStride) * int(sizeof(typename traits::Scalar) / sizeof(Scalar))) + : int(StrideType::InnerStrideAtCompileTime), + + OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0 + ? (outer_stride_at_compile_time::ret == Dynamic + ? int(Dynamic) + : outer_stride_at_compile_time::ret * int(sizeof(typename traits::Scalar) / sizeof(Scalar))) + : int(StrideType::OuterStrideAtCompileTime) + + }; +}; +} + +template +class CwiseUnaryViewImpl; + +/** \class CwiseUnaryView + * \ingroup Core_Module + * + * \brief Generic lvalue expression of a coefficient-wise unary operator of a matrix or a vector + * + * \tparam ViewOp template functor implementing the view + * \tparam MatrixType the type of the matrix we are applying the unary operator + * + * This class represents a lvalue expression of a generic unary view operator of a matrix or a vector. + * It is the return type of real() and imag(), and most of the time this is the only way it is used. + * + * \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp + */ +template +class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind> +{ + public: + + typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; + EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) + typedef typename internal::ref_selector::non_const_type MatrixTypeNested; + typedef typename internal::remove_all::type NestedExpression; + + explicit inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp()) + : m_matrix(mat), m_functor(func) {} + + EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryView) + + EIGEN_STRONG_INLINE Index rows() const { return m_matrix.rows(); } + EIGEN_STRONG_INLINE Index cols() const { return m_matrix.cols(); } + + /** \returns the functor representing unary operation */ + const ViewOp& functor() const { return m_functor; } + + /** \returns the nested expression */ + const typename internal::remove_all::type& + nestedExpression() const { return m_matrix; } + + /** \returns the nested expression */ + typename internal::remove_reference::type& + nestedExpression() { return m_matrix.const_cast_derived(); } + + protected: + MatrixTypeNested m_matrix; + ViewOp m_functor; +}; + +// Generic API dispatcher +template +class CwiseUnaryViewImpl + : public internal::generic_xpr_base >::type +{ +public: + typedef typename internal::generic_xpr_base >::type Base; +}; + +template +class CwiseUnaryViewImpl + : public internal::dense_xpr_base< CwiseUnaryView >::type +{ + public: + + typedef CwiseUnaryView Derived; + typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; + + EIGEN_DENSE_PUBLIC_INTERFACE(Derived) + EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) + + EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); } + + EIGEN_DEVICE_FUNC inline Index innerStride() const + { + return StrideType::InnerStrideAtCompileTime != 0 + ? int(StrideType::InnerStrideAtCompileTime) + : derived().nestedExpression().innerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); + + } + + EIGEN_DEVICE_FUNC inline Index outerStride() const + { + return StrideType::OuterStrideAtCompileTime != 0 + ? int(StrideType::OuterStrideAtCompileTime) + : derived().nestedExpression().outerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); + + } + protected: + EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) +}; + +} // end namespace Eigen + +#endif // EIGEN_CWISE_UNARY_VIEW_H diff --git a/stan/math/prim/plugins/ForwardDeclarations.h b/stan/math/prim/plugins/ForwardDeclarations.h new file mode 100644 index 00000000000..3a193322fa1 --- /dev/null +++ b/stan/math/prim/plugins/ForwardDeclarations.h @@ -0,0 +1,298 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007-2010 Benoit Jacob +// Copyright (C) 2008-2009 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_FORWARDDECLARATIONS_H +#define EIGEN_FORWARDDECLARATIONS_H + +namespace Eigen { +namespace internal { + +template struct traits; + +// here we say once and for all that traits == traits +// When constness must affect traits, it has to be constness on template parameters on which T itself depends. +// For example, traits > != traits >, but +// traits > == traits > +template struct traits : traits {}; + +template struct has_direct_access +{ + enum { ret = (traits::Flags & DirectAccessBit) ? 1 : 0 }; +}; + +template struct accessors_level +{ + enum { has_direct_access = (traits::Flags & DirectAccessBit) ? 1 : 0, + has_write_access = (traits::Flags & LvalueBit) ? 1 : 0, + value = has_direct_access ? (has_write_access ? DirectWriteAccessors : DirectAccessors) + : (has_write_access ? WriteAccessors : ReadOnlyAccessors) + }; +}; + +template struct evaluator_traits; + +template< typename T> struct evaluator; + +} // end namespace internal + +template struct NumTraits; + +template struct EigenBase; +template class DenseBase; +template class PlainObjectBase; +template class DenseCoeffsBase; + +template class Matrix; + +template class MatrixBase; +template class ArrayBase; + +template class Flagged; +template class StorageBase > class NoAlias; +template class NestByValue; +template class ForceAlignedAccess; +template class SwapWrapper; + +template class Block; + +template class VectorBlock; +template class Transpose; +template class Conjugate; +template class CwiseNullaryOp; +template class CwiseUnaryOp; +template class CwiseBinaryOp; +template class CwiseTernaryOp; +template class Solve; +template class Inverse; + +template class Product; + +template class DiagonalBase; +template class DiagonalWrapper; +template class DiagonalMatrix; +template class DiagonalProduct; +template class Diagonal; +template class PermutationMatrix; +template class Transpositions; +template class PermutationBase; +template class TranspositionsBase; +template class PermutationWrapper; +template class TranspositionsWrapper; + +template::has_write_access ? WriteAccessors : ReadOnlyAccessors +> class MapBase; +template class Stride; +template class InnerStride; +template class OuterStride; +template > class Map; +template class RefBase; +template,OuterStride<> >::type > class Ref; +template> class CwiseUnaryView; + +template class TriangularBase; +template class TriangularView; +template class SelfAdjointView; +template class SparseView; +template class WithFormat; +template struct CommaInitializer; +template class ReturnByValue; +template class ArrayWrapper; +template class MatrixWrapper; +template class SolverBase; +template class InnerIterator; + +namespace internal { +template struct kernel_retval_base; +template struct kernel_retval; +template struct image_retval_base; +template struct image_retval; +} // end namespace internal + +namespace internal { +template class BandMatrix; +} + +namespace internal { +template struct product_type; + +template struct EnableIf; + +/** \internal + * \class product_evaluator + * Products need their own evaluator with more template arguments allowing for + * easier partial template specializations. + */ +template< typename T, + int ProductTag = internal::product_type::ret, + typename LhsShape = typename evaluator_traits::Shape, + typename RhsShape = typename evaluator_traits::Shape, + typename LhsScalar = typename traits::Scalar, + typename RhsScalar = typename traits::Scalar + > struct product_evaluator; +} + +template::value> +struct ProductReturnType; + +// this is a workaround for sun CC +template struct LazyProductReturnType; + +namespace internal { + +// Provides scalar/packet-wise product and product with accumulation +// with optional conjugation of the arguments. +template struct conj_helper; + +template struct scalar_sum_op; +template struct scalar_difference_op; +template struct scalar_conj_product_op; +template struct scalar_min_op; +template struct scalar_max_op; +template struct scalar_opposite_op; +template struct scalar_conjugate_op; +template struct scalar_real_op; +template struct scalar_imag_op; +template struct scalar_abs_op; +template struct scalar_abs2_op; +template struct scalar_sqrt_op; +template struct scalar_rsqrt_op; +template struct scalar_exp_op; +template struct scalar_log_op; +template struct scalar_cos_op; +template struct scalar_sin_op; +template struct scalar_acos_op; +template struct scalar_asin_op; +template struct scalar_tan_op; +template struct scalar_inverse_op; +template struct scalar_square_op; +template struct scalar_cube_op; +template struct scalar_cast_op; +template struct scalar_random_op; +template struct scalar_constant_op; +template struct scalar_identity_op; +template struct scalar_sign_op; +template struct scalar_pow_op; +template struct scalar_hypot_op; +template struct scalar_product_op; +template struct scalar_quotient_op; + +// SpecialFunctions module +template struct scalar_lgamma_op; +template struct scalar_digamma_op; +template struct scalar_erf_op; +template struct scalar_erfc_op; +template struct scalar_igamma_op; +template struct scalar_igammac_op; +template struct scalar_zeta_op; +template struct scalar_betainc_op; + +} // end namespace internal + +struct IOFormat; + +// Array module +template class Array; +template class Select; +template class PartialReduxExpr; +template class VectorwiseOp; +template class Replicate; +template class Reverse; + +template class FullPivLU; +template class PartialPivLU; +namespace internal { +template struct inverse_impl; +} +template class HouseholderQR; +template class ColPivHouseholderQR; +template class FullPivHouseholderQR; +template class CompleteOrthogonalDecomposition; +template class JacobiSVD; +template class BDCSVD; +template class LLT; +template class LDLT; +template class HouseholderSequence; +template class JacobiRotation; + +// Geometry module: +template class RotationBase; +template class Cross; +template class QuaternionBase; +template class Rotation2D; +template class AngleAxis; +template class Translation; +template class AlignedBox; +template class Quaternion; +template class Transform; +template class ParametrizedLine; +template class Hyperplane; +template class UniformScaling; +template class Homogeneous; + +// Sparse module: +template class SparseMatrixBase; + +// MatrixFunctions module +template struct MatrixExponentialReturnValue; +template class MatrixFunctionReturnValue; +template class MatrixSquareRootReturnValue; +template class MatrixLogarithmReturnValue; +template class MatrixPowerReturnValue; +template class MatrixComplexPowerReturnValue; + +namespace internal { +template +struct stem_function +{ + typedef std::complex::Real> ComplexScalar; + typedef ComplexScalar type(ComplexScalar, int); +}; +} + +} // end namespace Eigen + +#endif // EIGEN_FORWARDDECLARATIONS_H From 1553cacd84d12a954e4194bac70e9a4f2e2f6b30 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 11:58:21 +0800 Subject: [PATCH 23/27] Revert Eigen changes --- .../Eigen/src/Core/CoreEvaluators.h | 8 +-- .../Eigen/src/Core/CwiseUnaryView.h | 49 +++++++------------ .../Eigen/src/Core/util/ForwardDeclarations.h | 2 +- 3 files changed, 23 insertions(+), 36 deletions(-) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h b/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h index f45dfe46f1a..910889efa70 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CoreEvaluators.h @@ -743,11 +743,11 @@ struct binary_evaluator, IndexBased, IndexBase // -------------------- CwiseUnaryView -------------------- -template -struct unary_evaluator, IndexBased> - : evaluator_base > +template +struct unary_evaluator, IndexBased> + : evaluator_base > { - typedef CwiseUnaryView XprType; + typedef CwiseUnaryView XprType; enum { CoeffReadCost = evaluator::CoeffReadCost + functor_traits::Cost, diff --git a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h index eb6c943e3aa..5a30fa8df18 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/CwiseUnaryView.h @@ -13,8 +13,8 @@ namespace Eigen { namespace internal { -template -struct traits > +template +struct traits > : traits { typedef typename result_of< @@ -30,17 +30,15 @@ struct traits > // "error: no integral type can represent all of the enumerator values InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic ? int(Dynamic) - : int(MatrixTypeInnerStride) - * ((InnerStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : InnerStride), + : int(MatrixTypeInnerStride) * int(sizeof(typename traits::Scalar) / sizeof(Scalar)), OuterStrideAtCompileTime = outer_stride_at_compile_time::ret == Dynamic ? int(Dynamic) - : outer_stride_at_compile_time::ret - * ((OuterStride == -1) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : OuterStride) + : outer_stride_at_compile_time::ret * int(sizeof(typename traits::Scalar) / sizeof(Scalar)) }; }; } -template +template class CwiseUnaryViewImpl; /** \class CwiseUnaryView @@ -56,12 +54,12 @@ class CwiseUnaryViewImpl; * * \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp */ -template -class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind> +template +class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind> { public: - typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; + typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) typedef typename internal::ref_selector::non_const_type MatrixTypeNested; typedef typename internal::remove_all::type NestedExpression; @@ -91,48 +89,37 @@ class CwiseUnaryView : public CwiseUnaryViewImpl +template class CwiseUnaryViewImpl - : public internal::generic_xpr_base >::type + : public internal::generic_xpr_base >::type { public: - typedef typename internal::generic_xpr_base >::type Base; + typedef typename internal::generic_xpr_base >::type Base; }; -template -class CwiseUnaryViewImpl - : public internal::dense_xpr_base< CwiseUnaryView >::type +template +class CwiseUnaryViewImpl + : public internal::dense_xpr_base< CwiseUnaryView >::type { public: - typedef CwiseUnaryView Derived; - typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; + typedef CwiseUnaryView Derived; + typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); } EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); } - - using Base::coeffRef; - EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const { - return const_cast(this)->coeffRef(row, col); - } - - EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { - return const_cast(this)->coeffRef(index); - } EIGEN_DEVICE_FUNC inline Index innerStride() const { - return derived().nestedExpression().innerStride() - * ((InnerStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : InnerStride); + return derived().nestedExpression().innerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); } EIGEN_DEVICE_FUNC inline Index outerStride() const { - return derived().nestedExpression().outerStride() - * ((OuterStride == -1) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : OuterStride); + return derived().nestedExpression().outerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); } protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) diff --git a/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h b/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h index ed8a8a7c610..134544f9643 100644 --- a/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h +++ b/lib/eigen_3.3.9/Eigen/src/Core/util/ForwardDeclarations.h @@ -85,7 +85,7 @@ template class Transpose; template class Conjugate; template class CwiseNullaryOp; template class CwiseUnaryOp; -template class CwiseUnaryView; +template class CwiseUnaryView; template class CwiseBinaryOp; template class CwiseTernaryOp; template class Solve; From c0544562197f224600875eee02652f4276e3dea8 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 12:06:20 +0800 Subject: [PATCH 24/27] Use backported eigen code with plugin refactor --- stan/math/prim/plugins/adj_view.h | 6 +++--- test/unit/math/rev/eigen_plugins_test.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/stan/math/prim/plugins/adj_view.h b/stan/math/prim/plugins/adj_view.h index 7312f2a382c..977c9ad3617 100644 --- a/stan/math/prim/plugins/adj_view.h +++ b/stan/math/prim/plugins/adj_view.h @@ -69,7 +69,7 @@ struct adj_stride { }; template struct adj_stride::value>> { - static constexpr int stride = -1; + static constexpr int stride = 0; }; template @@ -95,8 +95,8 @@ inline const auto adj() const { EIGEN_DEVICE_FUNC inline auto adj() { return CwiseUnaryView, Derived, - adj_stride::stride, - adj_stride::stride>(derived()); + Stride::stride, + adj_stride::stride>>(derived()); } #endif diff --git a/test/unit/math/rev/eigen_plugins_test.cpp b/test/unit/math/rev/eigen_plugins_test.cpp index 0dc6bbf5ae1..45029d50e74 100644 --- a/test/unit/math/rev/eigen_plugins_test.cpp +++ b/test/unit/math/rev/eigen_plugins_test.cpp @@ -32,8 +32,8 @@ TEST(AgradRevMatrixAddons, var_matrix) { const matrix_v const_mat_in = matrix_v::Random(100, 100); - MatrixXd tri_out = const_mat_in.val().triangularView().solve( - const_mat_in.adj().transpose()); + MatrixXd tri_out = mat_in.val().triangularView().solve( + mat_in.adj().transpose()); matrix_vi mat_vi = mat_in.vi(); From ec20ea5d946abd458d66c3562846d17ac47474fd Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 12:22:22 +0800 Subject: [PATCH 25/27] Remove stray eval() --- stan/math/rev/fun/multiply.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stan/math/rev/fun/multiply.hpp b/stan/math/rev/fun/multiply.hpp index e8cd8daa3d1..eaa4a1755ac 100644 --- a/stan/math/rev/fun/multiply.hpp +++ b/stan/math/rev/fun/multiply.hpp @@ -40,12 +40,12 @@ inline auto multiply(const T1& A, const T2& B) { reverse_pass_callback( [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable { if (is_var_matrix::value || is_var_matrix::value) { - arena_A.adj() += res.adj() * arena_B_val.transpose().eval(); - arena_B.adj() += arena_A_val.transpose().eval() * res.adj(); + arena_A.adj() += res.adj() * arena_B_val.transpose(); + arena_B.adj() += arena_A_val.transpose() * res.adj(); } else { auto res_adj = res.adj().eval(); - arena_A.adj() += res_adj * arena_B_val.transpose().eval(); - arena_B.adj() += arena_A_val.transpose().eval() * res_adj; + arena_A.adj() += res_adj * arena_B_val.transpose(); + arena_B.adj() += arena_A_val.transpose() * res_adj; } }); return return_t(res); From 512968e076ebe279cb93feb5c0d93eda32985bd7 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 15:58:21 +0800 Subject: [PATCH 26/27] Test Eigen fixes --- stan/math/prim/plugins/CwiseUnaryView.h | 38 ++++++++++--------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/stan/math/prim/plugins/CwiseUnaryView.h b/stan/math/prim/plugins/CwiseUnaryView.h index c2085f99093..ba01edcb067 100644 --- a/stan/math/prim/plugins/CwiseUnaryView.h +++ b/stan/math/prim/plugins/CwiseUnaryView.h @@ -28,18 +28,14 @@ struct traits > MatrixTypeInnerStride = inner_stride_at_compile_time::ret, // need to cast the sizeof's from size_t to int explicitly, otherwise: // "error: no integral type can represent all of the enumerator values - InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0 - ? (MatrixTypeInnerStride == Dynamic - ? int(Dynamic) - : int(MatrixTypeInnerStride) * int(sizeof(typename traits::Scalar) / sizeof(Scalar))) - : int(StrideType::InnerStrideAtCompileTime), - - OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0 - ? (outer_stride_at_compile_time::ret == Dynamic - ? int(Dynamic) - : outer_stride_at_compile_time::ret * int(sizeof(typename traits::Scalar) / sizeof(Scalar))) - : int(StrideType::OuterStrideAtCompileTime) - + InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic + ? int(Dynamic) + : int(MatrixTypeInnerStride) + * ((StrideType::InnerStrideAtCompileTime == 0) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : StrideType::InnerStrideAtCompileTime), + OuterStrideAtCompileTime = outer_stride_at_compile_time::ret == Dynamic + ? int(Dynamic) + : outer_stride_at_compile_time::ret + * ((StrideType::OuterStrideAtCompileTime == 0) ? int(sizeof(typename traits::Scalar) / sizeof(Scalar)) : StrideType::OuterStrideAtCompileTime) }; }; } @@ -65,7 +61,7 @@ class CwiseUnaryView : public CwiseUnaryViewImpl::StorageKind>::Base Base; + typedef typename CwiseUnaryViewImpl::StorageKind>::Base Base; EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) typedef typename internal::ref_selector::non_const_type MatrixTypeNested; typedef typename internal::remove_all::type NestedExpression; @@ -109,8 +105,8 @@ class CwiseUnaryViewImpl { public: - typedef CwiseUnaryView Derived; - typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; + typedef CwiseUnaryView Derived; + typedef typename internal::dense_xpr_base< CwiseUnaryView >::type Base; EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) @@ -120,18 +116,14 @@ class CwiseUnaryViewImpl EIGEN_DEVICE_FUNC inline Index innerStride() const { - return StrideType::InnerStrideAtCompileTime != 0 - ? int(StrideType::InnerStrideAtCompileTime) - : derived().nestedExpression().innerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); - + return derived().nestedExpression().innerStride() + * ((StrideType::InnerStrideAtCompileTime == 0) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : StrideType::InnerStrideAtCompileTime); } EIGEN_DEVICE_FUNC inline Index outerStride() const { - return StrideType::OuterStrideAtCompileTime != 0 - ? int(StrideType::OuterStrideAtCompileTime) - : derived().nestedExpression().outerStride() * sizeof(typename internal::traits::Scalar) / sizeof(Scalar); - + return derived().nestedExpression().outerStride() + * ((StrideType::OuterStrideAtCompileTime == 0) ? sizeof(typename internal::traits::Scalar) / sizeof(Scalar) : StrideType::OuterStrideAtCompileTime); } protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) From b2ebd225c854ac6af6b17f0089b6bae78c0f33ea Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 13 Jan 2022 22:21:58 +0800 Subject: [PATCH 27/27] Eigen fixes 2 --- stan/math/prim/plugins/CwiseUnaryView.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/stan/math/prim/plugins/CwiseUnaryView.h b/stan/math/prim/plugins/CwiseUnaryView.h index ba01edcb067..a6d86201f8e 100644 --- a/stan/math/prim/plugins/CwiseUnaryView.h +++ b/stan/math/prim/plugins/CwiseUnaryView.h @@ -110,6 +110,19 @@ class CwiseUnaryViewImpl EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) + using Base::coeffRef; + + EIGEN_DEVICE_FUNC + inline const Scalar& coeffRef(Index rowId, Index colId) const + { + return derived().nestedExpression().coeffRef(colId, rowId); + } + + EIGEN_DEVICE_FUNC + inline const Scalar& coeffRef(Index index) const + { + return derived().nestedExpression().coeffRef(index); + } EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); } EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); }