Skip to content

Commit

Permalink
Migrate to pybin11 2.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
SylvainCorlay committed Jan 2, 2017
1 parent 7cb176d commit f77a71f
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 205 deletions.
2 changes: 1 addition & 1 deletion .appveyor.yml
Expand Up @@ -24,7 +24,7 @@ install:
- conda info -a
- conda install pytest -c conda-forge
- cd test
- conda install xtensor==0.2.2 pytest numpy pybind11==1.8.1 -c conda-forge
- conda install xtensor==0.2.2 pytest numpy pybind11==2.0.0 -c conda-forge
- xcopy /S %APPVEYOR_BUILD_FOLDER%\include %MINICONDA%\include

build_script:
Expand Down
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -59,7 +59,7 @@ install:
# Useful for debugging any issues with conda
- conda info -a
- cd test
- conda install xtensor==0.2.2 pytest numpy pybind11==1.8.1 -c conda-forge
- conda install xtensor==0.2.2 pytest numpy pybind11==2.0.0 -c conda-forge
- cp -r $TRAVIS_BUILD_DIR/include/* $HOME/miniconda/include/

script:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -28,7 +28,7 @@ conda install -c conda-forge xtensor-python

| `xtensor-python` | `xtensor` | `pybind11` |
|-------------------|------------|-------------|
| master | master | ^1.8.1 |
| master | master | ^2.0.0 |
| 0.2.0 | ^0.2.1 | ^1.8.1 |
| 0.1.0 | ^0.1.1 | ^1.8.1 |

Expand Down
107 changes: 82 additions & 25 deletions include/xtensor-python/pyarray.hpp
Expand Up @@ -14,17 +14,46 @@
#include <vector>

#include "pybind11/numpy.h"
#include "pybind11_backport.hpp"

#include "xtensor/xexpression.hpp"
#include "xtensor/xsemantic.hpp"
#include "xtensor/xiterator.hpp"

namespace xt
{
template <class T, int ExtraFlags>
class pyarray;
}

using pybind_array = pybind11::backport::array;
using buffer_info = pybind11::buffer_info;
namespace pybind11
{
namespace detail
{
template <typename T, int ExtraFlags>
struct pyobject_caster<xt::pyarray<T, ExtraFlags>>
{
using type = xt::pyarray<T, ExtraFlags>;

bool load(handle src, bool)
{
value = type::ensure(src);
return static_cast<bool>(value);
}

static handle cast(const handle &src, return_value_policy, handle)
{
return src.inc_ref();
}

PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
}
}

namespace xt
{

using pybind_array = pybind11::array;

/***********************
* pyarray declaration *
Expand Down Expand Up @@ -95,11 +124,11 @@ namespace xt

using closure_type = const self_type&;

PYBIND11_OBJECT_CVT(pyarray, pybind_array, is_non_null, m_ptr = ensure_(m_ptr));

pyarray();

explicit pyarray(const buffer_info& info);
pyarray(pybind11::handle h, borrowed_t);
pyarray(pybind11::handle h, stolen_t);
pyarray(const pybind11::object &o);

pyarray(const shape_type& shape,
const strides_type& strides,
Expand Down Expand Up @@ -188,6 +217,9 @@ namespace xt
template <class E>
pyarray& operator=(const xexpression<E>& e);

static pyarray ensure(pybind11::handle h);
static bool _check(pybind11::handle h);

private:

template<typename... Args>
Expand All @@ -199,11 +231,10 @@ namespace xt

static bool is_non_null(PyObject* ptr);

static PyObject *ensure_(PyObject* ptr);

mutable shape_type m_shape;
mutable strides_type m_strides;

static PyObject* raw_array_t(PyObject* ptr);
};

/**************************************
Expand All @@ -230,16 +261,29 @@ namespace xt

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray()
: pybind_array()
: pybind_array(0, static_cast<const_pointer>(nullptr))
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, borrowed_t) : pybind_array(h, borrowed)
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const buffer_info& info)
: pybind_array(info)
inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, stolen_t) : pybind_array(h, stolen)
{
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const pybind11::object &o) : pybind_array(raw_array_t(o.ptr()), stolen)
{
if (!m_ptr)
{
throw pybind11::error_already_set();
}
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
const strides_type& strides,
Expand Down Expand Up @@ -512,7 +556,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
{
return reinterpret_cast<storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -524,7 +568,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
{
return reinterpret_cast<const_storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<const_storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -536,7 +580,7 @@ namespace xt
template <class T, int ExtraFlags>
inline auto pyarray<T, ExtraFlags>::storage_cbegin() const -> const_storage_iterator
{
return reinterpret_cast<const_storage_iterator>(pybind11::backport::array_proxy(m_ptr)->data);
return reinterpret_cast<const_storage_iterator>(pybind11::detail::array_proxy(m_ptr)->data);
}

template <class T, int ExtraFlags>
Expand All @@ -560,6 +604,25 @@ namespace xt
return semantic_base::operator=(e);
}

template <class T, int ExtraFlags>
inline pyarray<T, ExtraFlags> pyarray<T, ExtraFlags>::ensure(pybind11::handle h)
{
auto result = pybind11::reinterpret_steal<pyarray>(raw_array_t(h.ptr()));
if (!pybind11::handle(result))
{
PyErr_Clear();
}
return result;
}

template <class T, int ExtraFlags>
inline bool pyarray<T, ExtraFlags>::_check(pybind11::handle h)
{
const auto &api = pybind11::detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(pybind11::detail::array_proxy(h.ptr())->descr, pybind11::dtype::of<T>().ptr());
}

// Private methods

template <class T, int ExtraFlags>
Expand Down Expand Up @@ -591,23 +654,17 @@ namespace xt
}

template <class T, int ExtraFlags>
inline PyObject* pyarray<T, ExtraFlags>::ensure_(PyObject* ptr)
inline PyObject* pyarray<T, ExtraFlags>::raw_array_t(PyObject* ptr)
{
if (ptr == nullptr)
{
return nullptr;
}
API& api = lookup_api();
PyObject* descr = api.PyArray_DescrFromType_(pybind11::detail::npy_format_descriptor<T>::value);
PyObject* result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result)
{
PyErr_Clear();
}
Py_DECREF(ptr);
return result;
return pybind11::detail::npy_api::get().PyArray_FromAny_(
ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
pybind11::detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr
);
}

}

#endif
Expand Down

0 comments on commit f77a71f

Please sign in to comment.