From db58855e0788f576b1df1919b7703de31e8905e6 Mon Sep 17 00:00:00 2001 From: Anthony Scopatz Date: Tue, 14 Aug 2012 00:25:41 -0500 Subject: [PATCH] Refactored _padeXX() functs to use sparse or dense matrices. --- scipy/linalg/matfuncs.py | 80 ++++++++++++++++++-------- scipy/linalg/misc.py | 5 ++ scipy/sparse/base.py | 63 +------------------- scipy/sparse/linalg/dsolve/linsolve.py | 2 +- 4 files changed, 63 insertions(+), 87 deletions(-) diff --git a/scipy/linalg/matfuncs.py b/scipy/linalg/matfuncs.py index 2d3c75085f2f..cd6fd8308f0f 100644 --- a/scipy/linalg/matfuncs.py +++ b/scipy/linalg/matfuncs.py @@ -12,6 +12,9 @@ from numpy import matrix as mat import numpy as np +from scipy.sparse.base import isspmatrix +from scipy.sparse.construct import eye as speye + # Local imports from misc import norm from basic import solve, inv @@ -87,43 +90,67 @@ def expm(A, q=False): return R # implementation of Pade approximations of various degree using the algorithm presented in [Higham 2005] - +# These should apply to both dense and sparse matricies. def _pade3(A): b = (120., 60., 12., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) - A2 = dot(A,A) - U = dot(A , (b[3]*A2 + b[1]*ident)) + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + U = A.dot(b[3]*A2 + b[1]*ident) V = b[2]*A2 + b[0]*ident return U,V def _pade5(A): b = (30240., 15120., 3360., 420., 30., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) - A2 = dot(A,A) - A4 = dot(A2,A2) - U = dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident) + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + A4 = A2.dot(A2) + U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident) + V = b[4]*A4 + b[2]*A2 + b[0]*ident + return U,V + +def _pade5(A): + b = (30240., 15120., 3360., 420., 30., 1.) + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + A4 = A2.dot(A2) + U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident) V = b[4]*A4 + b[2]*A2 + b[0]*ident return U,V def _pade7(A): b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) - A2 = dot(A,A) - A4 = dot(A2,A2) - A6 = dot(A4,A2) - U = dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + A4 = A2.dot(A2) + A6 = A4.dot(A2) + U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V def _pade9(A): b = (17643225600., 8821612800., 2075673600., 302702400., 30270240., 2162160., 110880., 3960., 90., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) - A2 = dot(A,A) - A4 = dot(A2,A2) - A6 = dot(A4,A2) - A8 = dot(A6,A2) - U = dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + A4 = A2.dot(A2) + A6 = A4.dot(A2) + A8 = A6.dot(A2) + U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V @@ -131,12 +158,15 @@ def _pade13(A): b = (64764752532480000., 32382376266240000., 7771770303897600., 1187353796428800., 129060195264000., 10559470521600., 670442572800., 33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) - A2 = dot(A,A) - A4 = dot(A2,A2) - A6 = dot(A4,A2) - U = dot(A,dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) - V = dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident + if isspmatrix(A): + ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) + else: + ident = eye(A.shape[0], A.shape[1], dtype=A.dtype) + A2 = A.dot(A) + A4 = A2.dot(A2) + A6 = A4.dot(A2) + U = A.dot(A6.dot(b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) + V = A6.dot(b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V def expm2(A): diff --git a/scipy/linalg/misc.py b/scipy/linalg/misc.py index 179135660ed7..c20de7b8b26c 100644 --- a/scipy/linalg/misc.py +++ b/scipy/linalg/misc.py @@ -1,7 +1,12 @@ import numpy as np +from numpy import eye from numpy.linalg import LinAlgError import fblas +from scipy.sparse.base import isspmatrix +from scipy.sparse.construct import eye as speye + + __all__ = ['LinAlgError', 'norm'] _nrm2_prefix = {'f' : 's', 'F': 'sc', 'D': 'dz'} diff --git a/scipy/sparse/base.py b/scipy/sparse/base.py index 63abd5d19585..2d1ce8ef113a 100644 --- a/scipy/sparse/base.py +++ b/scipy/sparse/base.py @@ -639,6 +639,8 @@ def expm(self): -------- scipy.linalg.expm """ + from scipy.linalg.matfuncs import _pade3, _pade5, _pade7, _pade9, _pade13 + A_L1 = max(abs(self).sum(axis=0).flat) n_squarings = 0 @@ -691,67 +693,6 @@ def inv(self): return selfinv - -def _pade3(A): - from construct import eye - b = (120., 60., 12., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) - A2 = A.dot(A) - U = A.dot(b[3]*A2 + b[1]*ident) - V = b[2]*A2 + b[0]*ident - return U,V - -def _pade5(A): - from construct import eye - b = (30240., 15120., 3360., 420., 30., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) - A2 = A.dot(A) - A4 = A2.dot(A2) - U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident) - V = b[4]*A4 + b[2]*A2 + b[0]*ident - return U,V - -def _pade7(A): - from construct import eye - b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) - A2 = A.dot(A) - A4 = A2.dot(A2) - A6 = A4.dot(A2) - U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) - V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident - return U,V - -def _pade9(A): - from construct import eye - b = (17643225600., 8821612800., 2075673600., 302702400., 30270240., - 2162160., 110880., 3960., 90., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) - A2 = A.dot(A) - A4 = A2.dot(A2) - A6 = A4.dot(A2) - A8 = A6.dot(A2) - U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) - V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident - return U,V - -def _pade13(A): - from construct import eye - b = (64764752532480000., 32382376266240000., 7771770303897600., - 1187353796428800., 129060195264000., 10559470521600., 670442572800., - 33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.) - ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format) - A2 = A.dot(A) - A4 = A2.dot(A2) - A6 = A4.dot(A2) - U = A.dot(A6.dot(b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) - V = A6.dot(b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident - return U,V - - - - - def isspmatrix(x): return isinstance(x, spmatrix) diff --git a/scipy/sparse/linalg/dsolve/linsolve.py b/scipy/sparse/linalg/dsolve/linsolve.py index 987e3a77d2c2..a5a87533746a 100644 --- a/scipy/sparse/linalg/dsolve/linsolve.py +++ b/scipy/sparse/linalg/dsolve/linsolve.py @@ -62,7 +62,7 @@ def spsolve(A, b, permc_spec=None, use_umfpack=True): msg += "or matrix whose shape matches lhs (%s)" % (A.shape,) raise ValueError(msg) elif isspmatrix(b) and not (isspmatrix_csc(b) or isspmatrix_csr(b)): - B = csc_matrix(b) + b = csc_matrix(b) warn('solve requires b be CSC or CSR matrix format', SparseEfficiencyWarning) A.sort_indices()