Skip to content

Commit

Permalink
BUG: sparse: Propagate dtype through CSR/CSC constructors (scipy#13403)
Browse files Browse the repository at this point in the history
Propagate the dtype parameter through the intermediate COO format constructor,
so the given dtype is preserved throughout the intermediate stages of creating a
sparse matrix.

Closes scipygh-13329.
  • Loading branch information
perimosocordiae authored and tylerjereddy committed Feb 11, 2021
1 parent 65199f3 commit 494b25c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
3 changes: 2 additions & 1 deletion scipy/sparse/compressed.py
Expand Up @@ -51,7 +51,8 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
if len(arg1) == 2:
# (data, ij) format
from .coo import coo_matrix
other = self.__class__(coo_matrix(arg1, shape=shape))
other = self.__class__(coo_matrix(arg1, shape=shape,
dtype=dtype))
self._set_self(other)
elif len(arg1) == 3:
# (data, indices, indptr) format
Expand Down
7 changes: 4 additions & 3 deletions scipy/sparse/coo.py
Expand Up @@ -131,9 +131,10 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
M, N = arg1
self._shape = check_shape((M, N))
idx_dtype = get_index_dtype(maxval=max(M, N))
data_dtype = getdtype(dtype, default=float)
self.row = np.array([], dtype=idx_dtype)
self.col = np.array([], dtype=idx_dtype)
self.data = np.array([], getdtype(dtype, default=float))
self.data = np.array([], dtype=data_dtype)
self.has_canonical_format = True
else:
try:
Expand All @@ -154,11 +155,11 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
self._shape = check_shape((M, N))

idx_dtype = get_index_dtype(maxval=max(self.shape))
data_dtype = getdtype(dtype, obj, default=float)
self.row = np.array(row, copy=copy, dtype=idx_dtype)
self.col = np.array(col, copy=copy, dtype=idx_dtype)
self.data = np.array(obj, copy=copy)
self.data = np.array(obj, copy=copy, dtype=data_dtype)
self.has_canonical_format = False

else:
if isspmatrix(arg1):
if isspmatrix_coo(arg1) and copy:
Expand Down
14 changes: 11 additions & 3 deletions scipy/sparse/tests/test_base.py
Expand Up @@ -3523,6 +3523,11 @@ def test_constructor4(self):
ij = vstack((row,col))
csr = csr_matrix((data,ij),(4,3))
assert_array_equal(arange(12).reshape(4,3),csr.todense())

# using Python lists and a specified dtype
csr = csr_matrix(([2**63 + 1, 1], ([0, 1], [0, 1])), dtype=np.uint64)
dense = array([[2**63 + 1, 0], [0, 1]], dtype=np.uint64)
assert_array_equal(dense, csr.toarray())

def test_constructor5(self):
# infer dimensions from arrays
Expand Down Expand Up @@ -4086,13 +4091,16 @@ def test_constructor1(self):
# unsorted triplet format
row = array([2, 3, 1, 3, 0, 1, 3, 0, 2, 1, 2])
col = array([0, 1, 0, 0, 1, 1, 2, 2, 2, 2, 1])
data = array([6., 10., 3., 9., 1., 4.,
11., 2., 8., 5., 7.])
data = array([6., 10., 3., 9., 1., 4., 11., 2., 8., 5., 7.])

coo = coo_matrix((data,(row,col)),(4,3))

assert_array_equal(arange(12).reshape(4,3),coo.todense())

# using Python lists and a specified dtype
coo = coo_matrix(([2**63 + 1, 1], ([0, 1], [0, 1])), dtype=np.uint64)
dense = array([[2**63 + 1, 0], [0, 1]], dtype=np.uint64)
assert_array_equal(dense, coo.toarray())

def test_constructor2(self):
# unsorted triplet format with duplicates (which are summed)
row = array([0,1,2,2,2,2,0,0,2,2])
Expand Down

0 comments on commit 494b25c

Please sign in to comment.