Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

TST: sparse/isolve: reorganize the tests for iterative routines

Also add complex-valued test cases.
  • Loading branch information...
commit 575262f399419c12d2884712948f97f5a8e3c512 1 parent 9728c09
@pv pv authored
View
254 scipy/sparse/linalg/isolve/tests/test_iterative.py
@@ -2,145 +2,212 @@
""" Test functions for the sparse.linalg.isolve module
"""
-from numpy.testing import TestCase, assert_equal, assert_array_equal, assert_
+import numpy as np
+
+from numpy.testing import TestCase, assert_equal, assert_array_equal, \
+ assert_, assert_allclose
from numpy import zeros, ones, arange, array, abs, max
+from numpy.linalg import cond
from scipy.linalg import norm
from scipy.sparse import spdiags, csr_matrix
-from scipy.sparse.linalg.interface import LinearOperator
+from scipy.sparse.linalg import LinearOperator, aslinearoperator
from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres
-#def callback(x):
-# global A, b
-# res = b-dot(A,x)
-# #print "||A.x - b|| = " + str(norm(dot(A,x)-b))
-
-
#TODO check that method preserve shape and type
-#TODO test complex matrices
#TODO test both preconditioner methods
-N = 40
-data = ones((3,N))
-data[0,:] = 2
-data[1,:] = -1
-data[2,:] = -1
-Poisson1D = spdiags( data, [0,-1,1], N, N, format='csr')
-
-data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
-RandDiag = spdiags( data, [0], 10, 10, format='csr' )
-
-class TestIterative(TestCase):
- def setUp(self):
+class Case(object):
+ def __init__(self, name, A, skip=None):
+ self.name = name
+ self.A = A
+ if skip is None:
+ self.skip = []
+ else:
+ self.skip = skip
+ def __repr__(self):
+ return "<%s>" % self.name
+
+class IterativeParams(object):
+ def __init__(self):
# list of tuples (solver, symmetric, positive_definite )
- self.solvers = []
- self.solvers.append( (cg, True, True) )
- self.solvers.append( (cgs, False, False) )
- self.solvers.append( (bicg, False, False) )
- self.solvers.append( (bicgstab, False, False) )
- self.solvers.append( (gmres, False, False) )
- self.solvers.append( (qmr, False, False) )
- self.solvers.append( (minres, True, False) )
- self.solvers.append( (lgmres, False, False) )
+ solvers = [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres]
+ sym_solvers = [minres, cg]
+ posdef_solvers = [cg]
+ real_solvers = [minres]
+
+ self.solvers = solvers
# list of tuples (A, symmetric, positive_definite )
self.cases = []
# Symmetric and Positive Definite
- self.cases.append( (Poisson1D,True,True) )
+ N = 40
+ data = ones((3,N))
+ data[0,:] = 2
+ data[1,:] = -1
+ data[2,:] = -1
+ Poisson1D = spdiags(data, [0,-1,1], N, N, format='csr')
+ self.Poisson1D = Case("poisson1d", Poisson1D)
+ self.cases.append(self.Poisson1D)
# Symmetric and Negative Definite
- self.cases.append( (-Poisson1D,True,False) )
+ self.cases.append(Case("neg-poisson1d", -Poisson1D,
+ skip=posdef_solvers))
# Symmetric and Indefinite
- self.cases.append( (RandDiag,True,False) )
+ data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
+ RandDiag = spdiags( data, [0], 10, 10, format='csr' )
+ self.cases.append(Case("rand-diag", RandDiag, skip=posdef_solvers))
+
+ # Random real-valued
+ np.random.seed(1234)
+ data = np.random.rand(4, 4)
+ self.cases.append(Case("rand", data, skip=posdef_solvers+sym_solvers))
+
+ # Random symmetric real-valued
+ np.random.seed(1234)
+ data = np.random.rand(4, 4)
+ data = data + data.T
+ self.cases.append(Case("rand-sym", data, skip=posdef_solvers))
+
+ # Random pos-def symmetric real
+ np.random.seed(1234)
+ data = np.random.rand(9, 9)
+ data = np.dot(data.conj(), data.T)
+ self.cases.append(Case("rand-sym-pd", data))
+
+ # Random complex-valued
+ np.random.seed(1234)
+ data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
+ self.cases.append(Case("rand-cmplx", data,
+ skip=posdef_solvers+sym_solvers+real_solvers))
+
+ # Random hermitian complex-valued
+ np.random.seed(1234)
+ data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
+ data = data + data.T.conj()
+ self.cases.append(Case("rand-cmplx-herm", data,
+ skip=posdef_solvers+real_solvers))
+
+ # Random pos-def hermitian complex-valued
+ np.random.seed(1234)
+ data = np.random.rand(9, 9) + 1j*np.random.rand(9, 9)
+ data = np.dot(data.conj(), data.T)
+ self.cases.append(Case("rand-cmplx-sym-pd", data, skip=real_solvers))
# Non-symmetric and Positive Definite
- # bicg and cgs fail to converge on this one
- #data = ones((2,10))
- #data[0,:] = 2
- #data[1,:] = -1
- #A = spdiags( data, [0,-1], 10, 10, format='csr')
- #self.cases.append( (A,False,True) )
+ # cgs fails to converge on this one -- algorithmic limitation apparently
+ data = ones((2,10))
+ data[0,:] = 2
+ data[1,:] = -1
+ A = spdiags( data, [0,-1], 10, 10, format='csr')
+ self.cases.append(Case("nonsymposdef", A,
+ skip=sym_solvers+[cgs]))
- def test_maxiter(self):
- """test whether maxiter is respected"""
+def setup_module():
+ global params
+ params = IterativeParams()
- A = Poisson1D
- tol = 1e-12
+def check_maxiter(solver, case):
+ A = case.A
+ tol = 1e-12
- for solver,req_sym,req_pos in self.solvers:
- b = arange(A.shape[0], dtype=float)
- x0 = 0*b
+ b = arange(A.shape[0], dtype=float)
+ x0 = 0*b
- residuals = []
- def callback(x):
- residuals.append( norm(b - A*x) )
+ residuals = []
+ def callback(x):
+ residuals.append(norm(b - case.A*x))
- x, info = solver(A, b, x0=x0, tol=tol, maxiter=3, callback=callback)
+ x, info = solver(A, b, x0=x0, tol=tol, maxiter=3, callback=callback)
- assert_equal(len(residuals), 3)
- assert_equal(info, 3)
+ assert_equal(len(residuals), 3)
+ assert_equal(info, 3)
- def test_convergence(self):
- """test whether all methods converge"""
+def test_maxiter():
+ case = params.Poisson1D
+ for solver in params.solvers:
+ if solver in case.skip: continue
+ yield check_maxiter, solver, case
- tol = 1e-8
+def assert_normclose(a, b, tol=1e-8):
+ residual = norm(a - b)
+ tolerance = tol*norm(b)
+ msg = "residual (%g) not smaller than tolerance %g" % (residual, tolerance)
+ assert_(residual < tolerance, msg=msg)
- for solver,req_sym,req_pos in self.solvers:
- for A,sym,pos in self.cases:
- if req_sym and not sym: continue
- if req_pos and not pos: continue
+def check_convergence(solver, case):
+ tol = 1e-8
- b = arange(A.shape[0], dtype=float)
- x0 = 0*b
+ A = case.A
- x, info = solver(A, b, x0=x0, tol=tol)
+ b = arange(A.shape[0], dtype=float)
+ x0 = 0*b
- assert_array_equal(x0, 0*b) #ensure that x0 is not overwritten
- assert_equal(info,0)
+ x, info = solver(A, b, x0=x0, tol=tol)
- assert_( norm(b - A*x) < tol*norm(b) )
+ assert_array_equal(x0, 0*b) #ensure that x0 is not overwritten
+ assert_equal(info,0)
+ assert_normclose(A.dot(x), b, tol=tol)
- def test_precond(self):
- """test whether all methods accept a trivial preconditioner"""
+def test_convergence():
+ for solver in params.solvers:
+ for case in params.cases:
+ if solver in case.skip: continue
+ yield check_convergence, solver, case
- tol = 1e-8
+def check_precond_dummy(solver, case):
+ tol = 1e-8
- def identity(b,which=None):
- """trivial preconditioner"""
- return b
+ def identity(b,which=None):
+ """trivial preconditioner"""
+ return b
- for solver,req_sym,req_pos in self.solvers:
+ A = case.A
- for A,sym,pos in self.cases:
- if req_sym and not sym: continue
- if req_pos and not pos: continue
+ M,N = A.shape
+ D = spdiags( [1.0/A.diagonal()], [0], M, N)
- M,N = A.shape
- D = spdiags( [1.0/A.diagonal()], [0], M, N)
+ b = arange(A.shape[0], dtype=float)
+ x0 = 0*b
- b = arange(A.shape[0], dtype=float)
- x0 = 0*b
+ precond = LinearOperator(A.shape, identity, rmatvec=identity)
- precond = LinearOperator(A.shape, identity, rmatvec=identity)
+ if solver is qmr:
+ x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
+ else:
+ x, info = solver(A, b, M=precond, x0=x0, tol=tol)
+ assert_equal(info,0)
+ assert_normclose(A.dot(x), b, tol)
- if solver == qmr:
- x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
- else:
- x, info = solver(A, b, M=precond, x0=x0, tol=tol)
- assert_equal(info,0)
- assert_( norm(b - A*x) < tol*norm(b) )
+ A = aslinearoperator(A)
+ A.psolve = identity
+ A.rpsolve = identity
- A = A.copy()
- A.psolve = identity
- A.rpsolve = identity
+ x, info = solver(A, b, x0=x0, tol=tol)
+ assert_equal(info,0)
+ assert_normclose(A*x, b, tol=tol)
- x, info = solver(A, b, x0=x0, tol=tol)
- assert_equal(info,0)
- assert_( norm(b - A*x) < tol*norm(b) )
+def test_precond_dummy():
+ case = params.Poisson1D
+ for solver in params.solvers:
+ if solver in case.skip: continue
+ yield check_precond_dummy, solver, case
+def test_gmres_basic():
+ A = np.vander(np.arange(10) + 1)[:, ::-1]
+ b = np.zeros(10)
+ b[0] = 1
+ x = np.linalg.solve(A, b)
+
+ x_gm, err = gmres(A, b, restart=5, maxiter=1)
+
+ assert_allclose(x_gm[0], 0.359, rtol=1e-2)
+
+
+#------------------------------------------------------------------------------
class TestQMR(TestCase):
def test_leftright_precond(self):
@@ -176,8 +243,7 @@ def UT_solve(b):
x,info = qmr(A, b, tol=1e-8, maxiter=15, M1=M1, M2=M2)
assert_equal(info,0)
- assert_( norm(b - A*x) < 1e-8*norm(b) )
-
+ assert_normclose(A*x, b, tol=1e-8)
class TestGMRES(TestCase):
def test_callback(self):
View
17 scipy/sparse/linalg/tests/test_iterative.py
@@ -1,17 +0,0 @@
-import numpy as np
-from numpy.testing import run_module_suite, assert_almost_equal
-
-import scipy.sparse.linalg as spla
-
-def test_gmres_basic():
- A = np.vander(np.arange(10) + 1)[:, ::-1]
- b = np.zeros(10)
- b[0] = 1
- x = np.linalg.solve(A, b)
-
- x_gm, err = spla.gmres(A, b, restart=5, maxiter=1)
-
- assert_almost_equal(x_gm[0], 0.359, decimal=2)
-
-if __name__ == "__main__":
- run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.