Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: sparse: Propagate dtype through CSR/CSC constructors #13403

Merged
merged 4 commits into from
Jan 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion scipy/sparse/compressed.py
Expand Up @@ -51,7 +51,8 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
if len(arg1) == 2:
# (data, ij) format
from .coo import coo_matrix
other = self.__class__(coo_matrix(arg1, shape=shape))
other = self.__class__(coo_matrix(arg1, shape=shape,
dtype=dtype))
self._set_self(other)
elif len(arg1) == 3:
# (data, indices, indptr) format
Expand Down
7 changes: 4 additions & 3 deletions scipy/sparse/coo.py
Expand Up @@ -131,9 +131,10 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
M, N = arg1
self._shape = check_shape((M, N))
idx_dtype = get_index_dtype(maxval=max(M, N))
data_dtype = getdtype(dtype, default=float)
self.row = np.array([], dtype=idx_dtype)
self.col = np.array([], dtype=idx_dtype)
self.data = np.array([], getdtype(dtype, default=float))
self.data = np.array([], dtype=data_dtype)
self.has_canonical_format = True
else:
try:
Expand All @@ -154,11 +155,11 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False):
self._shape = check_shape((M, N))

idx_dtype = get_index_dtype(maxval=max(self.shape))
data_dtype = getdtype(dtype, obj, default=float)
self.row = np.array(row, copy=copy, dtype=idx_dtype)
self.col = np.array(col, copy=copy, dtype=idx_dtype)
self.data = np.array(obj, copy=copy)
self.data = np.array(obj, copy=copy, dtype=data_dtype)
self.has_canonical_format = False

else:
if isspmatrix(arg1):
if isspmatrix_coo(arg1) and copy:
Expand Down
14 changes: 11 additions & 3 deletions scipy/sparse/tests/test_base.py
Expand Up @@ -3523,6 +3523,11 @@ def test_constructor4(self):
ij = vstack((row,col))
csr = csr_matrix((data,ij),(4,3))
assert_array_equal(arange(12).reshape(4,3),csr.todense())

# using Python lists and a specified dtype
csr = csr_matrix(([2**63 + 1, 1], ([0, 1], [0, 1])), dtype=np.uint64)
dense = array([[2**63 + 1, 0], [0, 1]], dtype=np.uint64)
assert_array_equal(dense, csr.toarray())

def test_constructor5(self):
# infer dimensions from arrays
Expand Down Expand Up @@ -4086,13 +4091,16 @@ def test_constructor1(self):
# unsorted triplet format
row = array([2, 3, 1, 3, 0, 1, 3, 0, 2, 1, 2])
col = array([0, 1, 0, 0, 1, 1, 2, 2, 2, 2, 1])
data = array([6., 10., 3., 9., 1., 4.,
11., 2., 8., 5., 7.])
data = array([6., 10., 3., 9., 1., 4., 11., 2., 8., 5., 7.])

coo = coo_matrix((data,(row,col)),(4,3))

assert_array_equal(arange(12).reshape(4,3),coo.todense())

# using Python lists and a specified dtype
coo = coo_matrix(([2**63 + 1, 1], ([0, 1], [0, 1])), dtype=np.uint64)
dense = array([[2**63 + 1, 0], [0, 1]], dtype=np.uint64)
assert_array_equal(dense, coo.toarray())

def test_constructor2(self):
# unsorted triplet format with duplicates (which are summed)
row = array([0,1,2,2,2,2,0,0,2,2])
Expand Down