Skip to content

Commit

Permalink
BUG: sparse: ensure index dtype is large enough to pass all parameter…
Browse files Browse the repository at this point in the history
…s to sparsetools

The index dtype size selection for CSR, CSC, BSR did not take into
account that the matrix size is also passed on to sparsetools, as
integers of the index dtype size.

Fix this by changing the dtype selection to also account for the integer
parameters passed on to sparsetools.

Because the integer type is never downgraded, and the array shape
parameters do not change (outside .reshape, which is unimplemented for
these types), it is sufficient to ensure this only in __init__.
  • Loading branch information
pv committed Feb 22, 2016
1 parent bc11753 commit 65244b9
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 7 deletions.
15 changes: 13 additions & 2 deletions scipy/sparse/bsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False, blocksize=None):
if (M % R) != 0 or (N % C) != 0:
raise ValueError('shape must be multiple of blocksize')

idx_dtype = get_index_dtype(maxval=N//C)
# Select index dtype large enough to pass array and
# scalar parameters to sparsetools
idx_dtype = get_index_dtype(maxval=max(M//R, N//C, R, C))
self.indices = np.zeros(0, dtype=idx_dtype)
self.indptr = np.zeros(M//R + 1, dtype=idx_dtype)

Expand All @@ -157,7 +159,16 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False, blocksize=None):
elif len(arg1) == 3:
# (data,indices,indptr) format
(data, indices, indptr) = arg1
idx_dtype = get_index_dtype((indices, indptr), check_contents=True)

# Select index dtype large enough to pass array and
# scalar parameters to sparsetools
maxval = None
if shape is not None:
maxval = max(shape)
if blocksize is not None:
maxval = max(maxval, max(blocksize))
idx_dtype = get_index_dtype((indices, indptr), maxval=maxval, check_contents=True)

self.indices = np.array(indices, copy=copy, dtype=idx_dtype)
self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
self.data = np.array(data, copy=copy, dtype=getdtype(dtype, data))
Expand Down
13 changes: 11 additions & 2 deletions scipy/sparse/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
# create empty matrix
self.shape = arg1 # spmatrix checks for errors here
M, N = self.shape
idx_dtype = get_index_dtype(maxval=self._swap((M,N))[1])
# Select index dtype large enough to pass array and
# scalar parameters to sparsetools
idx_dtype = get_index_dtype(maxval=max(M,N))
self.data = np.zeros(0, getdtype(dtype, default=float))
self.indices = np.zeros(0, idx_dtype)
self.indptr = np.zeros(self._swap((M,N))[0] + 1, dtype=idx_dtype)
Expand All @@ -50,7 +52,14 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
elif len(arg1) == 3:
# (data, indices, indptr) format
(data, indices, indptr) = arg1
idx_dtype = get_index_dtype((indices, indptr), check_contents=True)

# Select index dtype large enough to pass array and
# scalar parameters to sparsetools
maxval = None
if shape is not None:
maxval = max(shape)
idx_dtype = get_index_dtype((indices, indptr), maxval=maxval, check_contents=True)

self.indices = np.array(indices, copy=copy, dtype=idx_dtype)
self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
self.data = np.array(data, copy=copy, dtype=dtype)
Expand Down
13 changes: 11 additions & 2 deletions scipy/sparse/sparsetools/sparsetools.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,18 @@ call_thunk(char ret_spec, const char *spec, thunk_t *thunk, PyObject *args)
}
else if (*p == 'i') {
/* Integer scalars */
Py_ssize_t value;
PY_LONG_LONG value;

value = PyInt_AsSsize_t(arg_arrays[j]);
#if PY_VERSION_HEX >= 0x03000000
value = PyLong_AsLongLong(arg_arrays[j]);
#else
if (PyInt_Check(arg_arrays[j])) {
value = PyInt_AsLong(arg_arrays[j]);
}
else {
value = PyLong_AsLongLong(arg_arrays[j]);
}
#endif
if (PyErr_Occurred()) {
goto fail;
}
Expand Down
83 changes: 82 additions & 1 deletion scipy/sparse/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3271,6 +3271,30 @@ def test_has_canonical_format(self):
M.sum_duplicates()
assert_equal(2, len(M.indices)) # unaffected content

def test_scalar_idx_dtype(self):
# Check that index dtype takes into account all parameters
# passed to sparsetools, including the scalar ones
indptr = np.zeros(2, dtype=np.int32)
indices = np.zeros(0, dtype=np.int32)
vals = np.zeros(0)
a = csr_matrix((vals, indices, indptr), shape=(1, 2**31-1))
b = csr_matrix((vals, indices, indptr), shape=(1, 2**31))
ij = np.zeros((2, 0), dtype=np.int32)
c = csr_matrix((vals, ij), shape=(1, 2**31-1))
d = csr_matrix((vals, ij), shape=(1, 2**31))
e = csr_matrix((1, 2**31-1))
f = csr_matrix((1, 2**31))
assert_equal(a.indptr.dtype, np.int32)
assert_equal(b.indptr.dtype, np.int64)
assert_equal(c.indptr.dtype, np.int32)
assert_equal(d.indptr.dtype, np.int64)
assert_equal(e.indptr.dtype, np.int32)
assert_equal(f.indptr.dtype, np.int64)

# These shouldn't fail
for x in [a, b, c, d, e, f]:
x + x


class TestCSC(sparse_test_class()):
spmatrix = csc_matrix
Expand Down Expand Up @@ -3387,6 +3411,30 @@ def test_fancy_indexing_broadcast(self):
SIJ = SIJ.todense()
assert_equal(SIJ, D[I,J])

def test_scalar_idx_dtype(self):
# Check that index dtype takes into account all parameters
# passed to sparsetools, including the scalar ones
indptr = np.zeros(2, dtype=np.int32)
indices = np.zeros(0, dtype=np.int32)
vals = np.zeros(0)
a = csc_matrix((vals, indices, indptr), shape=(2**31-1, 1))
b = csc_matrix((vals, indices, indptr), shape=(2**31, 1))
ij = np.zeros((2, 0), dtype=np.int32)
c = csc_matrix((vals, ij), shape=(2**31-1, 1))
d = csc_matrix((vals, ij), shape=(2**31, 1))
e = csr_matrix((1, 2**31-1))
f = csr_matrix((1, 2**31))
assert_equal(a.indptr.dtype, np.int32)
assert_equal(b.indptr.dtype, np.int64)
assert_equal(c.indptr.dtype, np.int32)
assert_equal(d.indptr.dtype, np.int64)
assert_equal(e.indptr.dtype, np.int32)
assert_equal(f.indptr.dtype, np.int64)

# These shouldn't fail
for x in [a, b, c, d, e, f]:
x + x


class TestDOK(sparse_test_class(minmax=False, nnz_axis=False)):
spmatrix = dok_matrix
Expand Down Expand Up @@ -3824,6 +3872,37 @@ def test_iterator(self):
def test_setdiag(self):
pass

def test_scalar_idx_dtype(self):
# Check that index dtype takes into account all parameters
# passed to sparsetools, including the scalar ones
indptr = np.zeros(2, dtype=np.int32)
indices = np.zeros(0, dtype=np.int32)
vals = np.zeros((0, 1, 1))
a = bsr_matrix((vals, indices, indptr), shape=(1, 2**31-1))
b = bsr_matrix((vals, indices, indptr), shape=(1, 2**31))
c = bsr_matrix((1, 2**31-1))
d = bsr_matrix((1, 2**31))
assert_equal(a.indptr.dtype, np.int32)
assert_equal(b.indptr.dtype, np.int64)
assert_equal(c.indptr.dtype, np.int32)
assert_equal(d.indptr.dtype, np.int64)

try:
vals2 = np.zeros((0, 1, 2**31-1))
vals3 = np.zeros((0, 1, 2**31))
e = bsr_matrix((vals2, indices, indptr), shape=(1, 2**31-1))
f = bsr_matrix((vals3, indices, indptr), shape=(1, 2**31))
assert_equal(e.indptr.dtype, np.int32)
assert_equal(f.indptr.dtype, np.int64)
except (MemoryError, ValueError):
# May fail on 32-bit Python
e = 0
f = 0

# These shouldn't fail
for x in [a, b, c, d, e, f]:
x + x


#------------------------------------------------------------------------------
# Tests for non-canonical representations (with duplicates, unsorted indices)
Expand Down Expand Up @@ -3989,7 +4068,8 @@ class Test64Bit(object):
# The following features are missing, so skip the tests:
SKIP_TESTS = {
'test_expm': 'expm for 64-bit indices not available',
'test_solve': 'linsolve for 64-bit indices not available'
'test_solve': 'linsolve for 64-bit indices not available',
'test_scalar_idx_dtype': 'test implemented in base class',
}

def _create_some_matrix(self, mat_cls, m, n):
Expand Down Expand Up @@ -4127,5 +4207,6 @@ def check_unlimited():
check_limited()
check_unlimited()


if __name__ == "__main__":
run_module_suite()

0 comments on commit 65244b9

Please sign in to comment.