Skip to content

Commit

Permalink
Merge pull request #1216 from andrjohns/feature/issue-1197-eigen-plug…
Browse files Browse the repository at this point in the history
…in-autodiff-types

Issue 1197 - Use Eigen's plugin system to views for autodiff types
  • Loading branch information
SteveBronder committed Jul 11, 2019
2 parents c11a3d7 + 0e02f00 commit 1f2d306
Show file tree
Hide file tree
Showing 37 changed files with 723 additions and 106 deletions.
2 changes: 2 additions & 0 deletions stan/math/fwd/mat.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef STAN_MATH_FWD_MAT_HPP
#define STAN_MATH_FWD_MAT_HPP

#include <stan/math/prim/mat/fun/Eigen.hpp>

#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/meta.hpp>

Expand Down
24 changes: 2 additions & 22 deletions stan/math/fwd/mat/fun/determinant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/mat/fun/multiply.hpp>
#include <stan/math/fwd/mat/fun/multiply.hpp>
#include <stan/math/prim/mat/fun/inverse.hpp>
#include <stan/math/fwd/mat/fun/inverse.hpp>
#include <stan/math/prim/mat/err/check_square.hpp>

namespace stan {
Expand All @@ -15,25 +11,9 @@ namespace math {
template <typename T, int R, int C>
inline fvar<T> determinant(const Eigen::Matrix<fvar<T>, R, C>& m) {
check_square("determinant", "m", m);
Eigen::Matrix<T, R, C> m_deriv(m.rows(), m.cols());
Eigen::Matrix<T, R, C> m_val(m.rows(), m.cols());

for (size_type i = 0; i < m.rows(); i++) {
for (size_type j = 0; j < m.cols(); j++) {
m_deriv(i, j) = m(i, j).d_;
m_val(i, j) = m(i, j).val_;
}
}

Eigen::Matrix<T, R, C> m_inv = inverse(m_val);
m_deriv = multiply(m_inv, m_deriv);

fvar<T> result;
result.val_ = m_val.determinant();
result.d_ = result.val_ * m_deriv.trace();

// FIXME: I think this will overcopy compared to retur fvar<T>(...);
return result;
const T vals = m.val().determinant();
return fvar<T>(vals, vals * (m.val().inverse() * m.d()).trace());
}

} // namespace math
Expand Down
23 changes: 5 additions & 18 deletions stan/math/fwd/mat/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
@@ -1,33 +1,20 @@
#ifndef STAN_MATH_FWD_MAT_FUN_LOG_SUM_EXP_HPP
#define STAN_MATH_FWD_MAT_FUN_LOG_SUM_EXP_HPP

#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/mat/fun/log_sum_exp.hpp>
#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/fwd/scal/fun/log.hpp>
#include <stan/math/fwd/scal/fun/exp.hpp>

namespace stan {
namespace math {

// FIXME: cut-and-paste from fwd/log_sum_exp.hpp; should
// be able to generalize
template <typename T, int R, int C>
fvar<T> log_sum_exp(const Eigen::Matrix<fvar<T>, R, C>& v) {
using std::exp;
using std::log;
Eigen::Matrix<T, R, C> vals = v.val();
Eigen::Matrix<T, R, C> exp_vals = vals.array().exp();

Eigen::Matrix<T, 1, Eigen::Dynamic> vals(v.size());
for (int i = 0; i < v.size(); ++i)
vals[i] = v(i).val_;
T deriv(0.0);
T denominator(0.0);
for (int i = 0; i < v.size(); ++i) {
T exp_vi = exp(vals[i]);
denominator += exp_vi;
deriv += v(i).d_ * exp_vi;
}
return fvar<T>(log_sum_exp(vals), deriv / denominator);
return fvar<T>(log_sum_exp(vals),
v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
}

} // namespace math
Expand Down
1 change: 0 additions & 1 deletion stan/math/fwd/mat/fun/qr_Q.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/prim/scal/err/check_greater_or_equal.hpp>
#include <stan/math/fwd/core.hpp>
#include <Eigen/QR>

namespace stan {
namespace math {
Expand Down
1 change: 0 additions & 1 deletion stan/math/fwd/mat/fun/qr_R.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/prim/scal/err/check_greater_or_equal.hpp>
#include <stan/math/fwd/core.hpp>
#include <Eigen/QR>

namespace stan {
namespace math {
Expand Down
2 changes: 2 additions & 0 deletions stan/math/mix/mat.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef STAN_MATH_MIX_MAT_HPP
#define STAN_MATH_MIX_MAT_HPP

#include <stan/math/prim/mat/fun/Eigen.hpp>

#include <stan/math/mix/mat/fun/typedefs.hpp>
#include <stan/math/mix/meta.hpp>

Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/diagonal_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/err/check_opencl.hpp>
#include <stan/math/opencl/kernels/scalar_mul_diagonal.hpp>
#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>

namespace stan {
namespace math {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <stan/math/opencl/kernels/add.hpp>
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/zeros.hpp>
#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>

namespace stan {
namespace math {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/multiply_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <stan/math/opencl/zeros.hpp>
#include <stan/math/opencl/sub_block.hpp>

#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>

namespace stan {
namespace math {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/mat.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef STAN_MATH_PRIM_MAT_HPP
#define STAN_MATH_PRIM_MAT_HPP

#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>

#include <stan/math/prim/mat/err/check_cholesky_factor.hpp>
Expand Down Expand Up @@ -43,7 +44,6 @@
#include <stan/math/prim/mat/err/is_unit_vector.hpp>
#include <stan/math/prim/mat/err/validate_non_negative_index.hpp>

#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/prim/mat/fun/LDLT_factor.hpp>
#include <stan/math/prim/mat/fun/Phi.hpp>
#include <stan/math/prim/mat/fun/Phi_approx.hpp>
Expand Down
187 changes: 187 additions & 0 deletions stan/math/prim/mat/eigen_plugins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
#ifndef STAN_MATH_PRIM_MAT_EIGEN_PLUGINS_H
#define STAN_MATH_PRIM_MAT_EIGEN_PLUGINS_H


/**
* Reimplements is_fvar without requiring external math headers
*
* decltype((void)(T::d_)) is a pre C++17 replacement for
* std::void_t<decltype(T::d_)>
*
* TODO(Andrew): Replace with std::void_t after move to C++17
*/
template<class, class = void>
struct is_fvar : std::false_type
{ };
template<class T>
struct is_fvar<T, decltype((void)(T::d_))> : std::true_type
{ };

//TODO(Andrew): Replace std::is_const<>::value with std::is_const_v<> after move to C++17
template<typename T>
using double_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const double&,
double&>;

template<typename T>
using vari_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const decltype(T::vi_)&,
decltype(T::vi_)&>;

template<typename T>
using forward_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const decltype(T::val_)&,
decltype(T::val_)&>;

/**
* Structure to return a view to the values in a var, vari*, and fvar<T>.
* To identify the correct member to call for a given input, templates
* check a combination of whether the input is a pointer (i.e. vari*)
* and/or whether the input has member ".d_" (i.e. fvar).
*
* For definitions of EIGEN_EMPTY_STRUCT_CTOR, EIGEN_DEVICE_FUNC, and
* EIGEN_STRONG_INLINE; see: https://eigen.tuxfamily.org/dox/XprHelper_8h_source.html
*/
struct val_Op{
EIGEN_EMPTY_STRUCT_CTOR(val_Op);

//Returns value from a vari*
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_pointer<T>::value, double_return_t<T>>
operator()(T &v) const { return v->val_; }

//Returns value from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<(!std::is_pointer<T>::value && !is_fvar<T>::value
&& !std::is_arithmetic<T>::value), double_return_t<T>>
operator()(T &v) const { return v.vi_->val_; }

//Returns value from an fvar
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<is_fvar<T>::value, forward_return_t<T>>
operator()(T &v) const { return v.val_; }

//Returns double unchanged from input
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_arithmetic<T>::value, double_return_t<T>>
operator()(T& v) const { return v; }
};

/**
* Coefficient-wise function applying val_Op struct to a matrix of const var
* or vari* and returning a view to the const matrix of doubles containing
* the values
*/
inline const CwiseUnaryOp<val_Op, const Derived>
val() const { return CwiseUnaryOp<val_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying val_Op struct to a matrix of var
* or vari* and returning a view to the values
*/
inline CwiseUnaryView<val_Op, Derived>
val() { return CwiseUnaryView<val_Op, Derived>(derived());
}

/**
* Structure to return tangent from an fvar.
*/
struct d_Op {
EIGEN_EMPTY_STRUCT_CTOR(d_Op);

//Returns tangent from an fvar
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
forward_return_t<T> operator()(T &v) const { return v.d_; }
};

/**
* Coefficient-wise function applying d_Op struct to a matrix of const fvar<T>
* and returning a const matrix of type T containing the tangents
*/
inline const CwiseUnaryOp<d_Op, const Derived>
d() const { return CwiseUnaryOp<d_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying d_Op struct to a matrix of fvar<T>
* and returning a view to a matrix of type T of the tangents that can
* be modified
*/
inline CwiseUnaryView<d_Op, Derived>
d() { return CwiseUnaryView<d_Op, Derived>(derived());
}

/**
* Structure to return adjoints from var and vari*. Deduces whether the variables
* are pointers (i.e. vari*) to determine whether to return the adjoint or
* first point to the underlying vari* (in the case of var).
*/
struct adj_Op {
EIGEN_EMPTY_STRUCT_CTOR(adj_Op);

//Returns adjoint from a vari*
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_pointer<T>::value, double_return_t<T>>
operator()(T &v) const { return v->adj_; }

//Returns adjoint from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<!std::is_pointer<T>::value, double_return_t<T>>
operator()(T &v) const { return v.vi_->adj_; }
};

/**
* Coefficient-wise function applying adj_Op struct to a matrix of const var
* and returning a const matrix of type T containing the values
*/
inline const CwiseUnaryOp<adj_Op, const Derived>
adj() const { return CwiseUnaryOp<adj_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying adj_Op struct to a matrix of var
* and returning a view to a matrix of doubles of the adjoints that can
* be modified
*/
inline CwiseUnaryView<adj_Op, Derived>
adj() { return CwiseUnaryView<adj_Op, Derived>(derived());
}
/**
* Structure to return vari* from a var.
*/
struct vi_Op {
EIGEN_EMPTY_STRUCT_CTOR(vi_Op);

//Returns vari* from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
vari_return_t<T> operator()(T &v) const { return v.vi_; }
};

/**
* Coefficient-wise function applying vi_Op struct to a matrix of const var
* and returning a const matrix of vari*
*/
inline const CwiseUnaryOp<vi_Op, const Derived>
vi() const { return CwiseUnaryOp<vi_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying vi_Op struct to a matrix of var
* and returning a view to a matrix of vari* that can be modified
*/
inline CwiseUnaryView<vi_Op, Derived>
vi() { return CwiseUnaryView<vi_Op, Derived>(derived());
}

#define EIGEN_STAN_MATRIXBASE_PLUGIN

#endif
2 changes: 1 addition & 1 deletion stan/math/prim/mat/err/check_finite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <stan/math/prim/scal/err/domain_error.hpp>
#include <stan/math/prim/scal/err/check_finite.hpp>
#include <stan/math/prim/mat/fun/value_of.hpp>
#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <boost/math/special_functions/fpclassify.hpp>

namespace stan {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/mat/err/is_lower_triangular.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define STAN_MATH_PRIM_MAT_ERR_IS_LOWER_TRIANGULAR_HPP

#include <stan/math/prim/meta.hpp>
#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>

namespace stan {
namespace math {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/mat/err/is_mat_finite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define STAN_MATH_PRIM_MAT_ERR_IS_MAT_FINITE_HPP

#include <stan/math/prim/meta.hpp>
#include <Eigen/Dense>
#include <stan/math/prim/mat/fun/Eigen.hpp>

namespace stan {
namespace math {
Expand Down
10 changes: 10 additions & 0 deletions stan/math/prim/mat/fun/Eigen.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
#ifndef STAN_MATH_PRIM_MAT_FUN_EIGEN_HPP
#define STAN_MATH_PRIM_MAT_FUN_EIGEN_HPP

#ifdef EIGEN_MATRIXBASE_PLUGIN
#ifndef EIGEN_STAN_MATRIXBASE_PLUGIN
#error "Stan uses Eigen's EIGEN_MATRIXBASE_PLUGIN macro. To use your own "
"plugin add the eigen_plugin.h file to your plugin."
#endif
#else
#define EIGEN_MATRIXBASE_PLUGIN "stan/math/prim/mat/eigen_plugins.h"
#endif

#include <Eigen/Dense>
#include <Eigen/Sparse>
#include <Eigen/QR>
#include <Eigen/src/Core/NumTraits.h>

Expand Down
1 change: 0 additions & 1 deletion stan/math/prim/mat/fun/csr_extract_u.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <Eigen/Sparse>
#include <vector>
#include <numeric>

Expand Down

0 comments on commit 1f2d306

Please sign in to comment.