Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Backport Eigen changes for proposed plugin refactor #2654

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7e3d8c2
Initial testing
Oct 15, 2021
e2e50d5
Repalcements
Oct 15, 2021
6c5fc2d
coeffRef def
Oct 15, 2021
86ca245
Replace val_op
Oct 15, 2021
59e535e
Identified broken tests
Oct 15, 2021
4b56e56
Fix stride detection
Oct 17, 2021
dc79dc9
Fix stride detection
Oct 17, 2021
ccb90f4
Stride defaults
Oct 17, 2021
82015b5
Test stride calcs
Oct 17, 2021
264ef56
Remove coefreff
Oct 17, 2021
4b37448
Revert "Remove coefreff"
Oct 17, 2021
5ea8563
Fix stride setting
Oct 18, 2021
269b7f6
cpplint
Oct 18, 2021
949116e
Merge commit '7d64047c4320ddd0a8f37a3071f04d6dfc852374' into HEAD
yashikno Oct 18, 2021
90edbaf
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 18, 2021
341e0a8
opencl include
Oct 18, 2021
9212379
Pointer t and simplify strides
Oct 18, 2021
d151feb
Simplify code
Oct 18, 2021
205af1c
Robust return typedefs
Oct 18, 2021
139f5c4
Merge branch 'stan-dev:develop' into plugin-refactor
andrjohns Oct 19, 2021
e5d8295
Update coeffRef for Stan downstream
Oct 19, 2021
5c71291
Simplify implementations and add doc
Oct 21, 2021
ad112fd
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 21, 2021
222161a
Merge branch 'develop' into plugin-refactor
andrjohns Dec 28, 2021
bcc3374
Backport Eigen changes
andrjohns Jan 12, 2022
1553cac
Revert Eigen changes
andrjohns Jan 13, 2022
c8d0cbc
Merge branch 'plugin-refactor' into feature/backport-view-stride
andrjohns Jan 13, 2022
c054456
Use backported eigen code with plugin refactor
andrjohns Jan 13, 2022
cce4f96
Merge branch 'stan-dev:develop' into backport-plugin-testing
andrjohns Jan 13, 2022
ec20ea5
Remove stray eval()
andrjohns Jan 13, 2022
0c61436
Merge branch 'backport-plugin-testing' of https://github.com/andrjohn…
andrjohns Jan 13, 2022
512968e
Test Eigen fixes
andrjohns Jan 13, 2022
b2ebd22
Eigen fixes 2
andrjohns Jan 13, 2022
1d52230
Merge branch 'develop' into backport-plugin-testing
andrjohns Jan 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/opencl/rev/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifdef STAN_OPENCL

#include <stan/math/opencl/prim/log_sum_exp.hpp>
#include <stan/math/opencl/prim/dot_product.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/value_of.hpp>
Expand Down
3 changes: 1 addition & 2 deletions stan/math/opencl/rev/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class vari_cl_base : public vari_base {
* @return The value of this vari.
*/
inline const auto& val() const { return val_; }
inline auto& val_op() { return val_; }
inline auto& val() { return val_; }

/**
* Return a reference to the derivative of the root expression with
Expand All @@ -58,7 +58,6 @@ class vari_cl_base : public vari_base {
*/
inline auto& adj() { return adj_; }
inline auto& adj() const { return adj_; }
inline auto& adj_op() { return adj_; }

/**
* Returns a view into a block of matrix.
Expand Down
225 changes: 9 additions & 216 deletions stan/math/prim/eigen_plugins.h
Original file line number Diff line number Diff line change
@@ -1,218 +1,11 @@
/**
* Reimplements is_fvar without requiring external math headers
*
* decltype((void)(T::d_)) is a pre C++17 replacement for
* std::void_t<decltype(T::d_)>
*
* TODO(Andrew): Replace with std::void_t after move to C++17
*/
template<class, class = void>
struct is_fvar : std::false_type
{ };
template<class T>
struct is_fvar<T, decltype((void)(T::d_))> : std::true_type
{ };
#ifndef STAN_MATH_EIGEN_PLUGINS_H
#define STAN_MATH_EIGEN_PLUGINS_H

//TODO(Andrew): Replace std::is_const<>::value with std::is_const_v<> after move to C++17
template<typename T>
using double_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const double,
double>;
template<typename T>
using reverse_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const double&,
double&>;
#include "plugins/typedefs.h"
#include "plugins/adj_view.h"
#include "plugins/val_view.h"
#include "plugins/d_view.h"
#include "plugins/vi_view.h"

template<typename T>
using vari_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const decltype(T::vi_)&,
decltype(T::vi_)&>;

template<typename T>
using forward_return_t = std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
const decltype(T::val_)&,
decltype(T::val_)&>;

/**
* Structure to return a view to the values in a var, vari*, and fvar<T>.
* To identify the correct member to call for a given input, templates
* check a combination of whether the input is a pointer (i.e. vari*)
* and/or whether the input has member ".d_" (i.e. fvar).
*
* There are two methods for returning doubles unchanged. One which takes a reference
* to a double and returns the same reference, used when 'chaining' methods
* (i.e. A.adj().val()). The other for passing and returning by value, used directly
* with matrices of doubles (i.e. A.val(), where A is of type MatrixXd).
*
* For definitions of EIGEN_EMPTY_STRUCT_CTOR, EIGEN_DEVICE_FUNC, and
* EIGEN_STRONG_INLINE; see: https://eigen.tuxfamily.org/dox/XprHelper_8h_source.html
*/
struct val_Op{
EIGEN_EMPTY_STRUCT_CTOR(val_Op);

//Returns value from a vari*
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_pointer<T>::value, const double&>
operator()(T &v) const { return v->val_; }

//Returns value from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<(!std::is_pointer<T>::value && !is_fvar<T>::value
&& !std::is_arithmetic<T>::value), const double&>
operator()(T &v) const { return v.vi_->val_; }

//Returns value from an fvar
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<is_fvar<T>::value, forward_return_t<T>>
operator()(T &v) const { return v.val_; }

//Returns double unchanged from input (by value)
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_arithmetic<T>::value, double_return_t<T>>
operator()(T v) const { return v; }

//Returns double unchanged from input (by reference)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const double& operator()(const double& v) const { return v; }

//Returns double unchanged from input (by reference)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double& operator()(double& v) const { return v; }
};

/**
* Coefficient-wise function applying val_Op struct to a matrix of const var
* or vari* and returning a view to the const matrix of doubles containing
* the values
*/
inline const CwiseUnaryOp<val_Op, const Derived>
val() const { return CwiseUnaryOp<val_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying val_Op struct to a matrix of var
* or vari* and returning a view to the matrix of doubles containing
* the values
*/
inline CwiseUnaryOp<val_Op, Derived>
val_op() { return CwiseUnaryOp<val_Op, Derived>(derived());
}

/**
* Coefficient-wise function applying val_Op struct to a matrix of var
* or vari* and returning a view to the values
*/
inline CwiseUnaryView<val_Op, Derived>
val() { return CwiseUnaryView<val_Op, Derived>(derived());
}

/**
* Structure to return tangent from an fvar.
*/
struct d_Op {
EIGEN_EMPTY_STRUCT_CTOR(d_Op);

//Returns tangent from an fvar
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
forward_return_t<T> operator()(T &v) const { return v.d_; }
};

/**
* Coefficient-wise function applying d_Op struct to a matrix of const fvar<T>
* and returning a const matrix of type T containing the tangents
*/
inline const CwiseUnaryOp<d_Op, const Derived>
d() const { return CwiseUnaryOp<d_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying d_Op struct to a matrix of fvar<T>
* and returning a view to a matrix of type T of the tangents that can
* be modified
*/
inline CwiseUnaryView<d_Op, Derived>
d() { return CwiseUnaryView<d_Op, Derived>(derived());
}

/**
* Structure to return adjoints from var and vari*. Deduces whether the variables
* are pointers (i.e. vari*) to determine whether to return the adjoint or
* first point to the underlying vari* (in the case of var).
*/
struct adj_Op {
EIGEN_EMPTY_STRUCT_CTOR(adj_Op);

//Returns adjoint from a vari*
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<std::is_pointer<T>::value, reverse_return_t<T>>
operator()(T &v) const { return v->adj_; }

//Returns adjoint from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::enable_if_t<!std::is_pointer<T>::value, reverse_return_t<T>>
operator()(T &v) const { return v.vi_->adj_; }

};

/**
* Coefficient-wise function applying adj_Op struct to a matrix of const var
* and returning a const matrix of type T containing the values
*/
inline const CwiseUnaryOp<adj_Op, const Derived>
adj() const { return CwiseUnaryOp<adj_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying adj_Op struct to a matrix of var
* and returning a view to a matrix of doubles of the adjoints that can
* be modified. This is meant to be used on the rhs of expressions.
*/
inline CwiseUnaryOp<adj_Op, Derived> adj_op() {
return CwiseUnaryOp<adj_Op, Derived>(derived());
}

/**
* Coefficient-wise function applying adj_Op struct to a matrix of var
* and returning a view to a matrix of doubles of the adjoints that can
* be modified
*/
inline CwiseUnaryView<adj_Op, Derived>
adj() { return CwiseUnaryView<adj_Op, Derived>(derived());
}
/**
* Structure to return vari* from a var.
*/
struct vi_Op {
EIGEN_EMPTY_STRUCT_CTOR(vi_Op);

//Returns vari* from a var
template<typename T = Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
vari_return_t<T> operator()(T &v) const { return v.vi_; }
};

/**
* Coefficient-wise function applying vi_Op struct to a matrix of const var
* and returning a const matrix of vari*
*/
inline const CwiseUnaryOp<vi_Op, const Derived>
vi() const { return CwiseUnaryOp<vi_Op, const Derived>(derived());
}

/**
* Coefficient-wise function applying vi_Op struct to a matrix of var
* and returning a view to a matrix of vari* that can be modified
*/
inline CwiseUnaryView<vi_Op, Derived>
vi() { return CwiseUnaryView<vi_Op, Derived>(derived());
}

#define EIGEN_STAN_MATRIXBASE_PLUGIN
#define EIGEN_STAN_ARRAYBASE_PLUGIN
#define EIGEN_STAN_DENSEBASE_PLUGIN
#endif
21 changes: 7 additions & 14 deletions stan/math/prim/fun/Eigen.hpp
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
#ifndef STAN_MATH_PRIM_FUN_EIGEN_HPP
#define STAN_MATH_PRIM_FUN_EIGEN_HPP

#ifdef EIGEN_MATRIXBASE_PLUGIN
#ifndef EIGEN_STAN_MATRIXBASE_PLUGIN
#error "Stan uses Eigen's EIGEN_MATRIXBASE_PLUGIN macro. To use your own "
#ifdef EIGEN_DENSEBASE_PLUGIN
#ifndef EIGEN_STAN_DENSEBASE_PLUGIN
#error "Stan uses Eigen's EIGEN_DENSEBASE_PLUGIN macro. To use your own "
"plugin add the eigen_plugin.h file to your plugin."
#endif
#else
#define EIGEN_MATRIXBASE_PLUGIN "stan/math/prim/eigen_plugins.h"
#endif

#ifdef EIGEN_ARRAYBASE_PLUGIN
#ifndef EIGEN_STAN_ARRAYBASE_PLUGIN
#error "Stan uses Eigen's EIGEN_ARRAYBASE_PLUGIN macro. To use your own "
"plugin add the eigen_plugin.h file to your plugin."
#endif
#else
#define EIGEN_ARRAYBASE_PLUGIN "stan/math/prim/eigen_plugins.h"
// By using the DenseBase plugin, we do not need to specify both
// MatrixBase and ArrayBase inclusions
#define EIGEN_DENSEBASE_PLUGIN "stan/math/prim/eigen_plugins.h"
#endif

#include <stan/math/prim/plugins/Core.h>
#include <Eigen/Dense>
#include <Eigen/Sparse>
#include <Eigen/QR>
#include <Eigen/src/Core/NumTraits.h>
#include <Eigen/SVD>

namespace Eigen {

/**
* Traits specialization for Eigen binary operations for `int`
* and `double` arguments.
Expand Down
Loading