Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions example/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "example.h"
#include <pybind11/eigen.h>
#include <Eigen/Cholesky>

Eigen::VectorXf double_col(const Eigen::VectorXf& x)
{ return 2.0f * x; }
Expand All @@ -19,6 +20,14 @@ Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x)
Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x)
{ return 2.0f * x; }

// Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended
Eigen::MatrixXd cholesky1(Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
Eigen::MatrixXd cholesky2(const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
Eigen::MatrixXd cholesky3(const Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
Eigen::MatrixXd cholesky4(Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
Eigen::MatrixXd cholesky5(Eigen::Ref<Eigen::MatrixXd> x) { return x.llt().matrixL(); }
Eigen::MatrixXd cholesky6(Eigen::Ref<const Eigen::MatrixXd> x) { return x.llt().matrixL(); }

typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXfRowMajor;
MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x)
{ return 2.0f * x; }
Expand All @@ -40,6 +49,12 @@ void init_eigen(py::module &m) {
m.def("double_row", &double_row);
m.def("double_mat_cm", &double_mat_cm);
m.def("double_mat_rm", &double_mat_rm);
m.def("cholesky1", &cholesky1);
m.def("cholesky2", &cholesky2);
m.def("cholesky3", &cholesky3);
m.def("cholesky4", &cholesky4);
m.def("cholesky5", &cholesky5);
m.def("cholesky6", &cholesky6);

m.def("fixed_r", [mat]() -> FixedMatrixR {
return FixedMatrixR(mat);
Expand Down
8 changes: 8 additions & 0 deletions example/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from example import sparse_passthrough_r, sparse_passthrough_c
from example import double_row, double_col
from example import double_mat_cm, double_mat_rm
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
try:
import numpy as np
import scipy
Expand Down Expand Up @@ -70,3 +71,10 @@ def check_got_vs_ref(got_x, ref_x):
for slice_idx, ref_mat in enumerate(slices):
print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat)))
print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat)))

i = 1
for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
mymat = chol(np.array([[1,2,4], [2,13,23], [4,23,77]]))
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
i += 1

6 changes: 6 additions & 0 deletions example/eigen.ref
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ double_mat_cm(1) = OK
double_mat_rm(1) = OK
double_mat_cm(2) = OK
double_mat_rm(2) = OK
cholesky1 OK
cholesky2 OK
cholesky3 OK
cholesky4 OK
cholesky5 OK
cholesky6 OK
35 changes: 34 additions & 1 deletion include/pybind11/eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ template <typename T> class is_eigen_dense {
static constexpr bool value = decltype(test(std::declval<T>()))::value;
};

// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
// it (since there is no reference!), but we can cast from it.
template <typename T> class is_eigen_ref {
private:
template<typename Derived> static typename std::enable_if<
std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
Derived>::type test(const Eigen::Ref<Derived> &);
static void test(...);
public:
typedef decltype(test(std::declval<T>())) Derived;
static constexpr bool value = !std::is_void<Derived>::value;
};

template <typename T> class is_eigen_sparse {
private:
template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
Expand All @@ -49,7 +62,7 @@ template <typename T> class is_eigen_sparse {
};

template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
static constexpr bool isVector = Type::IsVectorAtCompileTime;
Expand Down Expand Up @@ -149,6 +162,26 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::t
static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
};

template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
private:
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
using DerivedCaster = type_caster<Derived>;
DerivedCaster derived_caster;
protected:
std::unique_ptr<Type> value;
public:
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }

static PYBIND11_DESCR name() { return DerivedCaster::name(); }

operator Type*() { return value.get(); }
operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
};

template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
Expand Down