Skip to content

Commit

Permalink
BUG: sparse/_csparsetools: work around dtype equality bugs in earlier…
Browse files Browse the repository at this point in the history
… numpy versions
  • Loading branch information
pv committed Feb 25, 2014
1 parent 84cb181 commit 37741c4
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions scipy/sparse/_csparsetools.pyx
Expand Up @@ -33,24 +33,34 @@ ctypedef fused value_t:
long double complex


# Use .char to work around dtype comparison bugs in earlier Numpy
# versions

DTYPE_NAME_MAP = {
np.dtype(np.bool_): "npy_bool",
np.dtype(np.int8): "npy_int8",
np.dtype(np.uint8): "npy_uint8",
np.dtype(np.int16): "npy_int16",
np.dtype(np.uint16): "npy_uint16",
np.dtype(np.int32): "npy_int32",
np.dtype(np.uint32): "npy_uint32",
np.dtype(np.int64): "npy_int64",
np.dtype(np.uint64): "npy_uint64",
np.dtype(np.float32): "npy_float32",
np.dtype(np.float64): "npy_float64",
np.dtype(np.longdouble): "long double",
np.dtype(np.complex64): "float complex",
np.dtype(np.complex128): "double complex",
np.dtype(np.clongdouble): "long double complex"
np.dtype(np.bool_).char: "npy_bool",
np.dtype(np.int8).char: "npy_int8",
np.dtype(np.uint8).char: "npy_uint8",
np.dtype(np.int16).char: "npy_int16",
np.dtype(np.uint16).char: "npy_uint16",
np.dtype(np.int32).char: "npy_int32",
np.dtype(np.uint32).char: "npy_uint32",
np.dtype(np.int64).char: "npy_int64",
np.dtype(np.uint64).char: "npy_uint64",
np.dtype(np.float32).char: "npy_float32",
np.dtype(np.float64).char: "npy_float64",
np.dtype(np.longdouble).char: "long double",
np.dtype(np.complex64).char: "float complex",
np.dtype(np.complex128).char: "double complex",
np.dtype(np.clongdouble).char: "long double complex"
}

if np.dtype('q').itemsize == 4:
DTYPE_NAME_MAP['q'] = "npy_int32"
DTYPE_NAME_MAP['Q'] = "npy_uint32"
elif np.dtype('q').itemsize == 8:
DTYPE_NAME_MAP['q'] = "npy_int64"
DTYPE_NAME_MAP['Q'] = "npy_uint64"


def prepare_index_for_memoryview(cnp.ndarray i, cnp.ndarray j, cnp.ndarray x=None):
"""
Expand Down Expand Up @@ -135,10 +145,11 @@ def lil_insert(cnp.npy_intp M, cnp.npy_intp N, object[:] rows, object[:] datas,
"""
Work around broken Cython fused type dispatch
"""
dtype = np.dtype(dtype)
try:
key = DTYPE_NAME_MAP[dtype]
key = DTYPE_NAME_MAP[dtype.char]
except KeyError:
raise ValueError("Unsupported data type: %r" % (dtype,))
raise ValueError("Unsupported data type: %r" % (dtype.char,))

_lil_insert[key](M, N, rows, datas, i, j, x)

Expand Down Expand Up @@ -239,10 +250,13 @@ def lil_fancy_set(cnp.npy_intp M, cnp.npy_intp N,
Work around broken Cython fused type dispatch
"""
try:
key = DTYPE_NAME_MAP[values.dtype]
ikey = DTYPE_NAME_MAP[i_idx.dtype]
key = DTYPE_NAME_MAP[values.dtype.char]
except KeyError:
raise ValueError("Unsupported data type: %r" % (values.dtype.char,))
try:
ikey = DTYPE_NAME_MAP[i_idx.dtype.char]
except KeyError:
raise ValueError("Unsupported data type")
raise ValueError("Unsupported data type: %r" % (i_idx.dtype.char,))

if key != "npy_bool":
_lil_fancy_set[ikey, key](M, N, rows, data, i_idx, j_idx, values)
Expand Down

0 comments on commit 37741c4

Please sign in to comment.