Skip to content

Commit

Permalink
Add support of numpy 2.0.0b1
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Mar 15, 2024
1 parent 37ad9a1 commit 93993ee
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added
- Allow use of installed JRL-cmakemodule ([#446](https://github.com/stack-of-tasks/eigenpy/pull/446)
- Support of Numpy 2.0.0b1 ([#448](https://github.com/stack-of-tasks/eigenpy/pull/448))

### Fixed
- Fix unit test build in C++11 ([#442](https://github.com/stack-of-tasks/eigenpy/pull/442))
Expand Down
8 changes: 8 additions & 0 deletions include/eigenpy/numpy-allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ struct numpy_allocator_impl_matrix<Eigen::Ref<MatType, Options, Stride> > {
outer_stride = reverse_strides ? mat.innerStride()
: mat.outerStride();

#if NPY_ABI_VERSION < 0x02000000
const int elsize = call_PyArray_DescrFromType(Scalar_type_code)->elsize;
#else
const int elsize = PyDataType_ELSIZE(call_PyArray_DescrFromType(Scalar_type_code));
#endif
npy_intp strides[2] = {elsize * inner_stride, elsize * outer_stride};

PyArrayObject *pyArray = (PyArrayObject *)call_PyArray_New(
Expand Down Expand Up @@ -204,7 +208,11 @@ struct numpy_allocator_impl_matrix<
outer_stride = reverse_strides ? mat.innerStride()
: mat.outerStride();

#if NPY_ABI_VERSION < 0x02000000
const int elsize = call_PyArray_DescrFromType(Scalar_type_code)->elsize;
#else
const int elsize = PyDataType_ELSIZE(call_PyArray_DescrFromType(Scalar_type_code));
#endif
npy_intp strides[2] = {elsize * inner_stride, elsize * outer_stride};

PyArrayObject *pyArray = (PyArrayObject *)call_PyArray_New(
Expand Down
27 changes: 26 additions & 1 deletion include/eigenpy/numpy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,34 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#endif

/* Allow compiling against NumPy 1.x and 2.x
see: https://github.com/numpy/numpy/blob/afea8fd66f6bdbde855f5aff0b4e73eb0213c646/doc/source/reference/c-api/array.rst#L1224
*/
#if NPY_ABI_VERSION < 0x02000000
#define PyArray_DescrProto PyArray_Descr
#endif

#include <numpy/ndarrayobject.h>
#include <numpy/ufuncobject.h>

#if NPY_ABI_VERSION < 0x02000000
static inline PyArray_ArrFuncs *
PyDataType_GetArrFuncs(PyArray_Descr *descr)
{
return descr->f;
}
#endif

/* PEP 674 disallow using macros as l-values
see : https://peps.python.org/pep-0674/
*/
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
static inline void _Py_SET_TYPE(PyObject *o, PyTypeObject *type) {
Py_TYPE(o) = type;
}
#define Py_SET_TYPE(o, type) _Py_SET_TYPE((PyObject*)(o), type)
#endif

#if defined _WIN32 || defined __CYGWIN__
#define EIGENPY_GET_PY_ARRAY_TYPE(array) \
call_PyArray_MinScalarType(array)->type_num
Expand Down Expand Up @@ -170,7 +195,7 @@ inline void call_PyArray_InitArrFuncs(PyArray_ArrFuncs* funcs) {
PyArray_InitArrFuncs(funcs);
}

inline int call_PyArray_RegisterDataType(PyArray_Descr* dtype) {
inline int call_PyArray_RegisterDataType(PyArray_DescrProto* dtype) {
return PyArray_RegisterDataType(dtype);
}

Expand Down
4 changes: 2 additions & 2 deletions include/eigenpy/user-type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct SpecialMethods<T, NPY_USERDEF> {
char* srcptr = static_cast<char*>(src);

PyArrayObject* py_array = static_cast<PyArrayObject*>(array);
PyArray_CopySwapFunc* copyswap = PyArray_DESCR(py_array)->f->copyswap;
PyArray_CopySwapFunc* copyswap = PyDataType_GetArrFuncs(PyArray_DESCR(py_array))->copyswap;

for (npy_intp i = 0; i < n; i++) {
copyswap(dstptr, srcptr, swap, array);
Expand All @@ -189,7 +189,7 @@ struct SpecialMethods<T, NPY_USERDEF> {
return (npy_bool)(value != ZeroValue);
} else {
T tmp_value;
PyArray_DESCR(py_array)->f->copyswap(
PyDataType_GetArrFuncs(PyArray_DESCR(py_array))->copyswap(
&tmp_value, ip, PyArray_ISBYTESWAPPED(py_array), array);
return (npy_bool)(tmp_value != ZeroValue);
}
Expand Down
6 changes: 5 additions & 1 deletion src/numpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ void import_numpy() {
}

int PyArray_TypeNum(PyTypeObject* type) {
return PyArray_TypeNumFromName(const_cast<char*>(type->tp_name));
PyArray_Descr * descr = PyArray_DescrFromTypeObject(reinterpret_cast<PyObject*>(type));
if (descr == NULL) {
return NPY_NOTYPE;
}
return descr->type_num;
}

#if defined _WIN32 || defined __CYGWIN__
Expand Down
8 changes: 4 additions & 4 deletions src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ int Register::registerNewType(
throw std::invalid_argument("PyType_Ready fails to initialize input type.");
}

PyArray_Descr* descr_ptr =
new PyArray_Descr(*call_PyArray_DescrFromType(NPY_OBJECT));
PyArray_Descr& descr = *descr_ptr;
PyArray_DescrProto* descr_ptr = new PyArray_DescrProto();
Py_SET_TYPE(descr_ptr, &PyArrayDescr_Type);
PyArray_DescrProto& descr = *descr_ptr;
descr.typeobj = py_type_ptr;
descr.kind = 'V';
descr.byteorder = '=';
Expand Down Expand Up @@ -98,7 +98,7 @@ int Register::registerNewType(
PyArray_Descr* new_descr = call_PyArray_DescrFromType(code);

if (PyDict_SetItemString(py_type_ptr->tp_dict, "dtype",
(PyObject*)descr_ptr) < 0) {
(PyObject*)new_descr) < 0) {
throw std::invalid_argument("PyDict_SetItemString fails.");
}

Expand Down

0 comments on commit 93993ee

Please sign in to comment.