Skip to content

Commit

Permalink
mv sparse solve() into spsolve() and uses factorized()
Browse files Browse the repository at this point in the history
  • Loading branch information
scopatz committed Aug 14, 2012
1 parent 347f799 commit c929e84
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 66 deletions.
8 changes: 4 additions & 4 deletions scipy/sparse/base.py
Expand Up @@ -669,11 +669,11 @@ def expm(self):
else:
raise ValueError("invalid type: "+str(self.dtype))

from linalg import solve
from linalg import spsolve

P = U + V # p_m(A) : numerator
Q = -U + V # q_m(A) : denominator
R = solve(Q, P)
R = spsolve(Q, P)

# squaring step to undo scaling
for i in range(n_squarings):
Expand All @@ -685,9 +685,9 @@ def inv(self):
"""Computes the inverse of the sparse matrix.
"""
from construct import eye
from linalg import solve
from linalg import spsolve
I = eye(self.shape[0], self.shape[1], dtype=self.dtype, format=self.format)
selfinv = solve(self, I)
selfinv = spsolve(self, I)
return selfinv


Expand Down
86 changes: 41 additions & 45 deletions scipy/sparse/linalg/dsolve/linsolve.py
@@ -1,6 +1,6 @@
from warnings import warn

from numpy import asarray, empty, where
from numpy import asarray, empty, where, squeeze
from scipy.sparse import isspmatrix_csc, isspmatrix_csr, isspmatrix, \
SparseEfficiencyWarning, csc_matrix

Expand All @@ -18,7 +18,7 @@
useUmfpack = True


__all__ = [ 'use_solver', 'spsolve', 'solve', 'splu', 'spilu', 'factorized' ]
__all__ = [ 'use_solver', 'spsolve', 'splu', 'spilu', 'factorized' ]

def use_solver( **kwargs ):
"""
Expand All @@ -44,33 +44,40 @@ def use_solver( **kwargs ):


def spsolve(A, b, permc_spec=None, use_umfpack=True):
"""Solve the sparse linear system Ax=b """
if isspmatrix( b ):
b = b.toarray()

if b.ndim > 1:
if max( b.shape ) == b.size:
b = b.squeeze()
else:
raise ValueError("rhs must be a vector (has shape %s)" % (b.shape,))

"""Solve the sparse linear system Ax=b, where b may be a vector or a matrix."""
if not (isspmatrix_csc(A) or isspmatrix_csr(A)):
A = csc_matrix(A)
warn('spsolve requires CSC or CSR matrix format', SparseEfficiencyWarning)
warn('spsolve requires A be CSC or CSR matrix format', SparseEfficiencyWarning)

bisvector = (A.shape != b.shape)

if bisvector and isspmatrix( b ):
b = b.toarray()

if b.ndim > 1:
if max( b.shape ) == b.size:
b = b.squeeze()
else:
msg = "rhs must be a vector (has shape %s) " % (b.shape,)
msg += "or matrix whose shape matches lhs (%s)" % (A.shape,)
raise ValueError(msg)
elif isspmatrix(b) and not (isspmatrix_csc(b) or isspmatrix_csr(b)):
B = csc_matrix(b)
warn('solve requires b be CSC or CSR matrix format', SparseEfficiencyWarning)

A.sort_indices()
A = A.asfptype() #upcast to a floating point format
A = A.asfptype() # upcast to a floating point format

M, N = A.shape
if (M != N):
raise ValueError("matrix must be square (has shape %s)" % ((M, N),))
if M != b.size:
if bisvector and M != b.size:
raise ValueError("matrix - rhs size mismatch (%s - %s)"
% (A.shape, b.size))

use_umfpack = use_umfpack and useUmfpack

if isUmfpack and use_umfpack:
if bisvector and isUmfpack and use_umfpack:
if noScikit:
warn( 'scipy.sparse.linalg.dsolve.umfpack will be removed,'
' install scikits.umfpack instead', DeprecationWarning )
Expand All @@ -85,7 +92,7 @@ def spsolve(A, b, permc_spec=None, use_umfpack=True):
return umf.linsolve( umfpack.UMFPACK_A, A, b,
autoTranspose = True )

else:
elif bisvector:
if isspmatrix_csc(A):
flag = 1 # CSC format
elif isspmatrix_csr(A):
Expand All @@ -98,30 +105,18 @@ def spsolve(A, b, permc_spec=None, use_umfpack=True):
options = dict(ColPerm=permc_spec)
return _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr, b, flag,
options=options)[0]

def solve(A, B, permc_spec=None, use_umfpack=True):
"""Solve the sparse linear system Ax=B, where B may be a vector or a matrix."""
if not isspmatrix(B) or A.shape != B.shape:
return spsolve(A, B, permc_spec=permc_spec, use_umfpack=use_umfpack)

if not (isspmatrix_csc(A) or isspmatrix_csr(A)):
A = csc_matrix(A)
warn('solve requires CSC or CSR matrix format', SparseEfficiencyWarning)

if not (isspmatrix_csc(B) or isspmatrix_csr(B)):
B = csc_matrix(B)
warn('solve requires CSC or CSR matrix format', SparseEfficiencyWarning)

shape0 = A.shape[0]
tempj = empty(shape0, dtype=int)
X = A.__class__(A.shape)
for j in range(shape0):
Xj = spsolve(A, B[:,j])
w = where(Xj != 0.0)[0]
tempj.fill(j)
X = X + A.__class__((Xj[w],(w,tempj[:len(w)])),
shape=A.shape, dtype=A.dtype)
return X
else:
# Cover the case where b is also a matrix
Afactsolve = factorized(A)
tempj = empty(M, dtype=int)
x = A.__class__(A.shape)
for j in range(M):
xj = Afactsolve(squeeze(b[:,j].toarray()))
w = where(xj != 0.0)[0]
tempj.fill(j)
x = x + A.__class__((xj[w],(w,tempj[:len(w)])),
shape=A.shape, dtype=A.dtype)
return x


def splu(A, permc_spec=None, diag_pivot_thresh=None,
Expand Down Expand Up @@ -268,10 +263,11 @@ def factorized( A ):
"""
Return a fuction for solving a sparse linear system, with A pre-factorized.
Example:
solve = factorized( A ) # Makes LU decomposition.
x1 = solve( rhs1 ) # Uses the LU factors.
x2 = solve( rhs2 ) # Uses again the LU factors.
Examples
--------
solve = factorized( A ) # Makes LU decomposition.
x1 = solve( rhs1 ) # Uses the LU factors.
x2 = solve( rhs2 ) # Uses again the LU factors.
"""
if isUmfpack and useUmfpack:
if noScikit:
Expand Down
32 changes: 15 additions & 17 deletions scipy/sparse/linalg/dsolve/tests/test_linsolve.py
Expand Up @@ -8,7 +8,7 @@
import scipy.linalg
from scipy.linalg import norm, inv
from scipy.sparse import spdiags, SparseEfficiencyWarning, csc_matrix, csr_matrix
from scipy.sparse.linalg.dsolve import spsolve, solve, use_solver, splu, spilu
from scipy.sparse.linalg.dsolve import spsolve, use_solver, splu, spilu

warnings.simplefilter('ignore',SparseEfficiencyWarning)

Expand Down Expand Up @@ -40,7 +40,7 @@ def test_twodiags(self):

assert_( norm(b - Asp*x) < 10 * cond_A * eps )

def test_smoketest(self):
def test_bvector_smoketest(self):
Adense = matrix([[ 0., 1., 1.],
[ 1., 0., 1.],
[ 0., 0., 1.]])
Expand All @@ -52,19 +52,7 @@ def test_smoketest(self):

assert_array_almost_equal(x, x2)

def test_non_square(self):
# A is not square.
A = ones((3, 4))
b = ones((4, 1))
assert_raises(ValueError, spsolve, A, b)
# A2 and b2 have incompatible shapes.
A2 = csc_matrix(eye(3))
b2 = array([1.0, 2.0])
assert_raises(ValueError, spsolve, A2, b2)


class TestSolve(TestCase):
def test_solve_smoketest(self):
def test_bmatrix_smoketest(self):
Adense = matrix([[ 0., 1., 1.],
[ 1., 0., 1.],
[ 0., 0., 1.]])
Expand All @@ -73,9 +61,19 @@ def test_solve_smoketest(self):
x = random.randn(3, 3)
Bdense = As.dot(x)
Bs = csc_matrix(Bdense)
x2 = solve(As, Bs)
x2 = spsolve(As, Bs)
assert_array_almost_equal(x, x2.todense())

def test_non_square(self):
# A is not square.
A = ones((3, 4))
b = ones((4, 1))
assert_raises(ValueError, spsolve, A, b)
# A2 and b2 have incompatible shapes.
A2 = csc_matrix(eye(3))
b2 = array([1.0, 2.0])
assert_raises(ValueError, spsolve, A2, b2)

def test_example_comparison(self):
row = array([0,0,1,2,2,2])
col = array([0,2,2,0,1,2])
Expand All @@ -89,7 +87,7 @@ def test_example_comparison(self):
sN = csr_matrix((data, (row,col)), shape=(3,3), dtype=float)
N = sN.todense()

sX = solve(sM, sN)
sX = spsolve(sM, sN)
X = scipy.linalg.solve(M, N)

assert_array_almost_equal(X, sX.todense())
Expand Down

0 comments on commit c929e84

Please sign in to comment.