Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to pybind11 2.0.0 #21

Merged
merged 1 commit into from Jan 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .appveyor.yml
Expand Up @@ -24,7 +24,7 @@ install:
- conda info -a
- conda install pytest -c conda-forge
- cd test
- conda install xtensor==0.2.2 pytest numpy pybind11==1.8.1 -c conda-forge
- conda install xtensor==0.2.2 pytest numpy pybind11==2.0.0 -c conda-forge
- xcopy /S %APPVEYOR_BUILD_FOLDER%\include %MINICONDA%\include

build_script:
Expand Down
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -59,7 +59,7 @@ install:
# Useful for debugging any issues with conda
- conda info -a
- cd test
- conda install xtensor==0.2.2 pytest numpy pybind11==1.8.1 -c conda-forge
- conda install xtensor==0.2.2 pytest numpy pybind11==2.0.0 -c conda-forge
- cp -r $TRAVIS_BUILD_DIR/include/* $HOME/miniconda/include/

script:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -28,7 +28,7 @@ conda install -c conda-forge xtensor-python

| `xtensor-python` | `xtensor` | `pybind11` |
|-------------------|------------|-------------|
| master | master | ^1.8.1 |
| master | master | ^2.0.0 |
| 0.2.0 | ^0.2.1 | ^1.8.1 |
| 0.1.0 | ^0.1.1 | ^1.8.1 |

Expand Down
107 changes: 82 additions & 25 deletions include/xtensor-python/pyarray.hpp
Expand Up @@ -14,17 +14,46 @@
#include <vector>

#include "pybind11/numpy.h"
#include "pybind11_backport.hpp"

#include "xtensor/xexpression.hpp"
#include "xtensor/xsemantic.hpp"
#include "xtensor/xiterator.hpp"

namespace xt
{
template <class T, int ExtraFlags>
class pyarray;
}

using pybind_array = pybind11::backport::array;
using buffer_info = pybind11::buffer_info;
namespace pybind11
{
namespace detail
{
template <typename T, int ExtraFlags>
struct pyobject_caster<xt::pyarray<T, ExtraFlags>>
{
using type = xt::pyarray<T, ExtraFlags>;

bool load(handle src, bool)
{
value = type::ensure(src);
return static_cast<bool>(value);
}

static handle cast(const handle &src, return_value_policy, handle)
{
return src.inc_ref();
}

PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
}
}

namespace xt
{

using pybind_array = pybind11::array;

/***********************
* pyarray declaration *
Expand Down Expand Up @@ -95,11 +124,11 @@ namespace xt

using closure_type = const self_type&;

PYBIND11_OBJECT_CVT(pyarray, pybind_array, is_non_null, m_ptr = ensure_(m_ptr));

pyarray();

explicit pyarray(const buffer_info& info);
pyarray(pybind11::handle h, borrowed_t);
pyarray(pybind11::handle h, stolen_t);
pyarray(const pybind11::object &o);

pyarray(const shape_type& shape,
const strides_type& strides,
Expand Down Expand Up @@ -188,6 +217,9 @@ namespace xt
template <class E>
pyarray& operator=(const xexpression<E>& e);

static pyarray ensure(pybind11::handle h);
static bool _check(pybind11::handle h);

private:

template<typename... Args>
Expand All @@ -199,11 +231,10 @@ namespace xt

static bool is_non_null(PyObject* ptr);

static PyObject *ensure_(PyObject* ptr);

mutable shape_type m_shape;
mutable strides_type m_strides;

static PyObject* raw_array_t(PyObject* ptr);
};

/**************************************
Expand All @@ -230,16 +261,29 @@ namespace xt

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray()
: pybind_array()
: pybind_array(0, static_cast<const_pointer>(nullptr))
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, borrowed_t) : pybind_array(h, borrowed)
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const buffer_info& info)
: pybind_array(info)
inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, stolen_t) : pybind_array(h, stolen)
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const pybind11::object &o) : pybind_array(raw_array_t(o.ptr()), stolen)
{
if (!m_ptr)
{
throw pybind11::error_already_set();
}
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
const strides_type& strides,
Expand Down Expand Up @@ -512,7 +556,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
{
return reinterpret_cast<storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -524,7 +568,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
{
return reinterpret_cast<const_storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<const_storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -536,7 +580,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_cbegin() const -> const_storage_iterator
{
return reinterpret_cast<const_storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<const_storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -560,6 +604,25 @@ namespace xt
return semantic_base::operator=(e);
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags> pyarray<T, ExtraFlags>::ensure(pybind11::handle h)
{
auto result = pybind11::reinterpret_steal<pyarray>(raw_array_t(h.ptr()));
if (!pybind11::handle(result))
{
PyErr_Clear();
}
return result;
}

template <class T, int ExtraFlags>
inline bool pyarray<T, ExtraFlags>::_check(pybind11::handle h)
{
const auto &api = pybind11::detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(pybind11::detail::array_proxy(h.ptr())->descr, pybind11::dtype::of<T>().ptr());
}

// Private methods

template <class T, int ExtraFlags>
Expand Down Expand Up @@ -591,23 +654,17 @@ namespace xt
}

template <class T, int ExtraFlags>
inline PyObject* pyarray<T, ExtraFlags>::ensure_(PyObject* ptr)
inline PyObject* pyarray<T, ExtraFlags>::raw_array_t(PyObject* ptr)
{
if (ptr == nullptr)
{
return nullptr;
}
API& api = lookup_api();
PyObject* descr = api.PyArray_DescrFromType_(pybind11::detail::npy_format_descriptor<T>::value);
PyObject* result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result)
{
PyErr_Clear();
}
Py_DECREF(ptr);
return result;
return pybind11::detail::npy_api::get().PyArray_FromAny_(
ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
pybind11::detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr
);
}

}

#endif
Expand Down