From 3912cd074c696c1e55a1caf03559bd238d21cf55 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 18 Jun 2015 14:07:08 -0400 Subject: [PATCH] MAINT: more careful dtype treatment in block diagonal matrix construction --- scipy/linalg/special_matrices.py | 3 ++- scipy/linalg/tests/test_special_matrices.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scipy/linalg/special_matrices.py b/scipy/linalg/special_matrices.py index 14e550e46444..a2fa185181eb 100644 --- a/scipy/linalg/special_matrices.py +++ b/scipy/linalg/special_matrices.py @@ -531,7 +531,8 @@ def block_diag(*arrs): "greater than 2: %s" % bad_args) shapes = np.array([a.shape if a.size > 0 else [0, 0] for a in arrs]) - out = np.zeros(np.sum(shapes, axis=0), dtype=arrs[0].dtype) + out_dtype = np.find_common_type([arr.dtype for arr in arrs], []) + out = np.zeros(np.sum(shapes, axis=0), dtype=out_dtype) r, c = 0, 0 for i, (rr, cc) in enumerate(shapes): diff --git a/scipy/linalg/tests/test_special_matrices.py b/scipy/linalg/tests/test_special_matrices.py index 09ab18670edc..b4e951f6ee3c 100644 --- a/scipy/linalg/tests/test_special_matrices.py +++ b/scipy/linalg/tests/test_special_matrices.py @@ -244,6 +244,11 @@ def test_dtype(self): x = block_diag([[True]]) assert_equal(x.dtype, bool) + def test_mixed_dtypes(self): + actual = block_diag([[1]], [[1j]]) + desired = np.array([[1, 0], [0, 1j]]) + assert_array_equal(actual, desired) + def test_scalar_and_1d_args(self): a = block_diag(1) assert_equal(a.shape, (1,1))