From 37741c4b80e573791d14ba1260265a9eee5e0e31 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen Date: Wed, 26 Feb 2014 01:32:26 +0200 Subject: [PATCH] BUG: sparse/_csparsetools: work around dtype equality bugs in earlier numpy versions --- scipy/sparse/_csparsetools.pyx | 54 +++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/scipy/sparse/_csparsetools.pyx b/scipy/sparse/_csparsetools.pyx index cf8af81c854a..92dd96e52ba5 100644 --- a/scipy/sparse/_csparsetools.pyx +++ b/scipy/sparse/_csparsetools.pyx @@ -33,24 +33,34 @@ ctypedef fused value_t: long double complex +# Use .char to work around dtype comparison bugs in earlier Numpy +# versions + DTYPE_NAME_MAP = { - np.dtype(np.bool_): "npy_bool", - np.dtype(np.int8): "npy_int8", - np.dtype(np.uint8): "npy_uint8", - np.dtype(np.int16): "npy_int16", - np.dtype(np.uint16): "npy_uint16", - np.dtype(np.int32): "npy_int32", - np.dtype(np.uint32): "npy_uint32", - np.dtype(np.int64): "npy_int64", - np.dtype(np.uint64): "npy_uint64", - np.dtype(np.float32): "npy_float32", - np.dtype(np.float64): "npy_float64", - np.dtype(np.longdouble): "long double", - np.dtype(np.complex64): "float complex", - np.dtype(np.complex128): "double complex", - np.dtype(np.clongdouble): "long double complex" + np.dtype(np.bool_).char: "npy_bool", + np.dtype(np.int8).char: "npy_int8", + np.dtype(np.uint8).char: "npy_uint8", + np.dtype(np.int16).char: "npy_int16", + np.dtype(np.uint16).char: "npy_uint16", + np.dtype(np.int32).char: "npy_int32", + np.dtype(np.uint32).char: "npy_uint32", + np.dtype(np.int64).char: "npy_int64", + np.dtype(np.uint64).char: "npy_uint64", + np.dtype(np.float32).char: "npy_float32", + np.dtype(np.float64).char: "npy_float64", + np.dtype(np.longdouble).char: "long double", + np.dtype(np.complex64).char: "float complex", + np.dtype(np.complex128).char: "double complex", + np.dtype(np.clongdouble).char: "long double complex" } +if np.dtype('q').itemsize == 4: + DTYPE_NAME_MAP['q'] = "npy_int32" + DTYPE_NAME_MAP['Q'] = "npy_uint32" +elif np.dtype('q').itemsize == 8: + DTYPE_NAME_MAP['q'] = "npy_int64" + DTYPE_NAME_MAP['Q'] = "npy_uint64" + def prepare_index_for_memoryview(cnp.ndarray i, cnp.ndarray j, cnp.ndarray x=None): """ @@ -135,10 +145,11 @@ def lil_insert(cnp.npy_intp M, cnp.npy_intp N, object[:] rows, object[:] datas, """ Work around broken Cython fused type dispatch """ + dtype = np.dtype(dtype) try: - key = DTYPE_NAME_MAP[dtype] + key = DTYPE_NAME_MAP[dtype.char] except KeyError: - raise ValueError("Unsupported data type: %r" % (dtype,)) + raise ValueError("Unsupported data type: %r" % (dtype.char,)) _lil_insert[key](M, N, rows, datas, i, j, x) @@ -239,10 +250,13 @@ def lil_fancy_set(cnp.npy_intp M, cnp.npy_intp N, Work around broken Cython fused type dispatch """ try: - key = DTYPE_NAME_MAP[values.dtype] - ikey = DTYPE_NAME_MAP[i_idx.dtype] + key = DTYPE_NAME_MAP[values.dtype.char] + except KeyError: + raise ValueError("Unsupported data type: %r" % (values.dtype.char,)) + try: + ikey = DTYPE_NAME_MAP[i_idx.dtype.char] except KeyError: - raise ValueError("Unsupported data type") + raise ValueError("Unsupported data type: %r" % (i_idx.dtype.char,)) if key != "npy_bool": _lil_fancy_set[ikey, key](M, N, rows, data, i_idx, j_idx, values)