Skip to content

Commit

Permalink
TST: fix ComplexWarnings by replacing .astype(float) by .real.astype(…
Browse files Browse the repository at this point in the history
…float).

Also adds tests that different dtypes are handled correctly in linalg/expm2.

(backport of r7037, r7043)
  • Loading branch information
rgommers committed Jan 16, 2011
1 parent 8682bd9 commit 1d06f67
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
8 changes: 4 additions & 4 deletions scipy/fftpack/tests/test_basic.py
Expand Up @@ -417,17 +417,17 @@ def test_size_accuracy(self):
for size in SMALL_COMPOSITE_SIZES + SMALL_PRIME_SIZES:
np.random.seed(1234)
x = np.random.rand(size, size) + 1j*np.random.rand(size, size)
y1 = fftn(x.astype(np.float32))
y2 = fftn(x.astype(np.float64)).astype(np.complex64)
y1 = fftn(x.real.astype(np.float32))
y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)

self.failUnless(y1.dtype == np.complex64)
assert_array_almost_equal_nulp(y1, y2, 2000)

for size in LARGE_COMPOSITE_SIZES + LARGE_PRIME_SIZES:
np.random.seed(1234)
x = np.random.rand(size, 3) + 1j*np.random.rand(size, 3)
y1 = fftn(x.astype(np.float32))
y2 = fftn(x.astype(np.float64)).astype(np.complex64)
y1 = fftn(x.real.astype(np.float32))
y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)

self.failUnless(y1.dtype == np.complex64)
assert_array_almost_equal_nulp(y1, y2, 2000)
Expand Down
6 changes: 5 additions & 1 deletion scipy/linalg/matfuncs.py
Expand Up @@ -91,7 +91,11 @@ def expm2(A):
t = 'd'
s,vr = eig(A)
vri = inv(vr)
return dot(dot(vr,diag(exp(s))),vri).astype(t)
r = dot(dot(vr,diag(exp(s))),vri)
if t in ['f', 'd']:
return r.real.astype(t)
else:
return r.astype(t)

def expm3(A, q=20):
"""Compute the matrix exponential using Taylor series.
Expand Down
9 changes: 9 additions & 0 deletions scipy/linalg/tests/test_matfuncs.py
Expand Up @@ -93,5 +93,14 @@ def test_zero(self):
assert_array_almost_equal(expm2(a),[[1,0],[0,1]])
assert_array_almost_equal(expm3(a),[[1,0],[0,1]])

def test_consistency(self):
a = array([[0.,1],[-1,0]])
assert_array_almost_equal(expm(a), expm2(a))
assert_array_almost_equal(expm(a), expm3(a))

a = array([[1j,1],[-1,-2j]])
assert_array_almost_equal(expm(a), expm2(a))
assert_array_almost_equal(expm(a), expm3(a))

if __name__ == "__main__":
run_module_suite()

0 comments on commit 1d06f67

Please sign in to comment.