Skip to content

Commit

Permalink
Adding ndarray type checks in python c-extensions.
Browse files Browse the repository at this point in the history
  • Loading branch information
shinmorino committed Apr 30, 2018
1 parent 434af02 commit 4e3bf09
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions sqaodc/pyglue/pyglue.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,36 @@

namespace sq = sqaod;


template<class T> inline
void throwErrorForInvalidArray(PyObject *obj) {
throwError("Unsupported type.");
}

template<> inline
void throwErrorForInvalidArray<float>(PyObject *obj) {
bool ok = PyArray_Check(obj) && (PyArray_TYPE((PyArrayObject*)obj) == NPY_FLOAT);
throwErrorIf(!ok, "Invalid array type.");
}

template<> inline
void throwErrorForInvalidArray<double>(PyObject *obj) {
bool ok = PyArray_Check(obj) && (PyArray_TYPE((PyArrayObject*)obj) == NPY_DOUBLE);
throwErrorIf(!ok, "Invalid array type.");
}

template<> inline
void throwErrorForInvalidArray<char>(PyObject *obj) {
bool ok = PyArray_Check(obj) && (PyArray_TYPE((PyArrayObject*)obj) == NPY_INT8);
throwErrorIf(!ok, "Invalid array type.");
}


template<class real>
struct NpMatrixType {
typedef sqaod::MatrixType<real> Matrix;
NpMatrixType(PyObject *pyObj) {
throwErrorForInvalidArray<real>(pyObj);
PyArrayObject *arr = (PyArrayObject*)pyObj;
real *data = (real*)PyArray_DATA(arr);
assert(PyArray_NDIM(arr) == 2);
Expand Down Expand Up @@ -86,6 +112,7 @@ struct NpVectorType {


NpVectorType(PyObject *pyObj) {
throwErrorForInvalidArray<real>(pyObj);
obj = pyObj;
PyArrayObject *arr = (PyArrayObject*)pyObj;
real *data = (real*)PyArray_DATA(arr);
Expand Down Expand Up @@ -122,6 +149,7 @@ struct NpScalarRefType {
typedef sqaod::VectorType<real> Vector;

NpScalarRefType(PyObject *pyObj) {
throwErrorForInvalidArray<real>(pyObj);
obj = pyObj;
PyArrayObject *arr = (PyArrayObject*)pyObj;
throwErrorIf(3 <= PyArray_NDIM(arr), "not a scalar.");
Expand Down Expand Up @@ -206,11 +234,13 @@ PyObject *newScalarObj(float v) {
/* Helpers for dtypes */
inline
bool isFloat64(PyObject *dtype) {
return dtype == (PyObject*)&PyFloat64ArrType_Type;
/* Since PyFloat64ArrType_Type may be defined as another type, using PyDoubleArrType_Type. */
return dtype == (PyObject*)&PyDoubleArrType_Type;
}
inline
bool isFloat32(PyObject *dtype) {
return dtype == (PyObject*)&PyFloat32ArrType_Type;
/* Since PyFloat32ArrType_Type may be defined as another type, using PyFloatArrType_Type. */
return dtype == (PyObject*)&PyFloatArrType_Type;
}

#define ASSERT_DTYPE(dtype) if (!isFloat32(dtype) && !isFloat64(dtype)) \
Expand Down

0 comments on commit 4e3bf09

Please sign in to comment.