Skip to content

Commit

Permalink
Merge pull request #4977 from argriffing/block-diag-dtype
Browse files Browse the repository at this point in the history
MAINT: more careful dtype treatment in block diagonal matrix construction
  • Loading branch information
ev-br committed Jun 22, 2015
2 parents 6370384 + 3912cd0 commit a344c3f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion scipy/linalg/special_matrices.py
Expand Up @@ -531,7 +531,8 @@ def block_diag(*arrs):
"greater than 2: %s" % bad_args)

shapes = np.array([a.shape if a.size > 0 else [0, 0] for a in arrs])
out = np.zeros(np.sum(shapes, axis=0), dtype=arrs[0].dtype)
out_dtype = np.find_common_type([arr.dtype for arr in arrs], [])
out = np.zeros(np.sum(shapes, axis=0), dtype=out_dtype)

r, c = 0, 0
for i, (rr, cc) in enumerate(shapes):
Expand Down
5 changes: 5 additions & 0 deletions scipy/linalg/tests/test_special_matrices.py
Expand Up @@ -244,6 +244,11 @@ def test_dtype(self):
x = block_diag([[True]])
assert_equal(x.dtype, bool)

def test_mixed_dtypes(self):
actual = block_diag([[1]], [[1j]])
desired = np.array([[1, 0], [0, 1j]])
assert_array_equal(actual, desired)

def test_scalar_and_1d_args(self):
a = block_diag(1)
assert_equal(a.shape, (1,1))
Expand Down

0 comments on commit a344c3f

Please sign in to comment.