Skip to content

Commit

Permalink
BUG: linalg: remove getr* from clapack, as they are not compatible wi…
Browse files Browse the repository at this point in the history
…th flapack routines (fixes #1458)

The getr* routines in CLAPACK have U unit diagonal in LU, whereas in
FLAPACK routines, L is unit diagonal.
  • Loading branch information
pv committed Jun 11, 2011
1 parent e23e462 commit 388b3da
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 141 deletions.
68 changes: 0 additions & 68 deletions scipy/lib/lapack/clapack.pyf.src
Expand Up @@ -43,74 +43,6 @@ python module clapack

end function <prefix>gesv

function <prefix>getrf(m,n,a,piv,info,rowmajor)

! lu,piv,info = getrf(a,rowmajor=1,overwrite_a=0)
! Compute an LU factorization of a general M-by-N matrix A.
! A * P = L * U
threadsafe
fortranname clapack_<prefix>getrf
integer intent(c,hide) :: <prefix>getrf
callstatement <prefix>getrf_return_value = info = (*f2py_func)(102-rowmajor,m,n,a,(rowmajor?n:m),piv)
callprotoargument const int,const int,const int,<ctype>*,const int,int*

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1

integer depend(a),intent(hide):: m = shape(a,0)
integer depend(a),intent(hide):: n = shape(a,1)
<ftype> dimension(m,n),intent(c,in,out,copy,out=lu) :: a
integer dimension((m<n?m:n)),depend(m,n),intent(out) :: piv
integer intent(out):: info

end function <prefix>getrf

function <prefix>getrs(n,nrhs,lu,piv,b,info,trans,rowmajor)

! x,info = getrs(lu,piv,b,trans=0,rowmajor=1,overwrite_b=0)
! Solve A * X = B if trans=0
! Solve A^T * X = B if trans=1
! Solve A^H * X = B if trans=2
! A * P = L * U

fortranname clapack_<prefix>getrs
integer intent(c,hide) :: <prefix>getrs
callstatement <prefix>getrs_return_value = info = (*f2py_func)(102-rowmajor,111+trans,n,nrhs,lu,n,piv,b,n)
callprotoargument const int,const int,const int,const int,<ctype>*,const int,int*,<ctype>*,const int

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1
integer optional,intent(in),check(trans>=0 && trans <=2) :: trans = 0

integer depend(lu),intent(hide):: n = shape(lu,0)
integer depend(b),intent(hide):: nrhs = shape(b,1)
<ftype> dimension(n,n),intent(c,in) :: lu
check(shape(lu,0)==shape(lu,1)) :: lu
integer dimension(n),intent(in),depend(n) :: piv
<ftype> dimension(n,nrhs),intent(in,out,copy,out=x),depend(n),check(shape(lu,0)==shape(b,0)) :: b
integer intent(out):: info
end function <prefix>getrs

function <prefix>getri(n,lu,piv,info,rowmajor)

! inv_a,info = getri(lu,piv,rowmajor=1,overwrite_lu=0)
! Find A inverse A^-1.
! A * P = L * U

fortranname clapack_<prefix>getri
integer intent(c,hide) :: <prefix>getri
callstatement <prefix>getri_return_value = info = (*f2py_func)(102-rowmajor,n,lu,n,piv)
callprotoargument const int,const int,<ctype>*,const int,const int*

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1

integer depend(lu),intent(hide):: n = shape(lu,0)
<ftype> dimension(n,n),intent(c,in,out,copy,out=inv_a) :: lu
check(shape(lu,0)==shape(lu,1)) :: lu
integer dimension(n),intent(in),depend(n) :: piv
integer intent(out):: info

end function <prefix>getri


function <prefix>posv(n,nrhs,a,b,info,lower,rowmajor)

! c,x,info = posv(a,b,lower=0,rowmajor=1,overwrite_a=0,overwrite_b=0)
Expand Down
68 changes: 0 additions & 68 deletions scipy/linalg/generic_clapack.pyf
Expand Up @@ -38,74 +38,6 @@ interface

end function <tchar=s,d,c,z>gesv

function <tchar=s,d,c,z>getrf(m,n,a,piv,info,rowmajor)

! lu,piv,info = getrf(a,rowmajor=1,overwrite_a=0)
! Compute an LU factorization of a general M-by-N matrix A.
! A * P = L * U
threadsafe
fortranname clapack_<tchar=s,d,c,z>getrf
integer intent(c,hide) :: <tchar=s,d,c,z>getrf
callstatement <tchar=s,d,c,z>getrf_return_value = info = (*f2py_func)(102-rowmajor,m,n,a,(rowmajor?n:m),piv)
callprotoargument const int,const int,const int,<type_in_c>*,const int,int*

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1

integer depend(a),intent(hide):: m = shape(a,0)
integer depend(a),intent(hide):: n = shape(a,1)
<type_in> dimension(m,n),intent(c,in,out,copy,out=lu) :: a
integer dimension((m<n?m:n)),depend(m,n),intent(out) :: piv
integer intent(out):: info

end function <tchar=s,d,c,z>getrf

function <tchar=s,d,c,z>getrs(n,nrhs,lu,piv,b,info,trans,rowmajor)

! x,info = getrs(lu,piv,b,trans=0,rowmajor=1,overwrite_b=0)
! Solve A * X = B if trans=0
! Solve A^T * X = B if trans=1
! Solve A^H * X = B if trans=2
! A * P = L * U

fortranname clapack_<tchar=s,d,c,z>getrs
integer intent(c,hide) :: <tchar=s,d,c,z>getrs
callstatement <tchar=s,d,c,z>getrs_return_value = info = (*f2py_func)(102-rowmajor,111+trans,n,nrhs,lu,n,piv,b,n)
callprotoargument const int,const int,const int,const int,<type_in_c>*,const int,int*,<type_in_c>*,const int

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1
integer optional,intent(in),check(trans>=0 && trans <=2) :: trans = 0

integer depend(lu),intent(hide):: n = shape(lu,0)
integer depend(b),intent(hide):: nrhs = shape(b,1)
<type_in> dimension(n,n),intent(c,in) :: lu
check(shape(lu,0)==shape(lu,1)) :: lu
integer dimension(n),intent(in),depend(n) :: piv
<type_in> dimension(n,nrhs),intent(in,out,copy,out=x),depend(n),check(shape(lu,0)==shape(b,0)) :: b
integer intent(out):: info
end function <tchar=s,d,c,z>getrs


function <tchar=s,d,c,z>getri(n,lu,piv,info,rowmajor)

! inv_a,info = getri(lu,piv,rowmajor=1,overwrite_lu=0)
! Find A inverse A^-1.
! A * P = L * U

fortranname clapack_<tchar=s,d,c,z>getri
integer intent(c,hide) :: <tchar=s,d,c,z>getri
callstatement <tchar=s,d,c,z>getri_return_value = info = (*f2py_func)(102-rowmajor,n,lu,n,piv)
callprotoargument const int,const int,<type_in_c>*,const int,const int*

integer optional,intent(in),check(rowmajor==1||rowmajor==0) :: rowmajor = 1

integer depend(lu),intent(hide):: n = shape(lu,0)
<type_in> dimension(n,n),intent(c,in,out,copy,out=inv_a) :: lu
check(shape(lu,0)==shape(lu,1)) :: lu
integer dimension(n),intent(in),depend(n) :: piv
integer intent(out):: info

end function <tchar=s,d,c,z>getri

function <tchar=s,d,c,z>posv(n,nrhs,a,b,info,lower,rowmajor)

! c,x,info = posv(a,b,lower=0,rowmajor=1,overwrite_a=0,overwrite_b=0)
Expand Down
22 changes: 17 additions & 5 deletions scipy/linalg/tests/test_decomp.py
Expand Up @@ -685,6 +685,15 @@ def test_medium1_complex(self):
"""Check lu decomposition on medium size, rectangular matrix."""
self._test_common(self.cmed)

def test_simple_known(self):
# Ticket #1458
for order in ['C', 'F']:
A = np.array([[2, 1],[0, 1.]], order=order)
LU, P = lu_factor(A)
assert_array_almost_equal(LU, np.array([[2, 1], [0, 1]]))
assert_array_equal(P, np.array([0, 1]))


class TestLUSingle(TestLU):
"""LU testers for single precision, real and double"""
def __init__(self, *args, **kw):
Expand All @@ -709,15 +718,18 @@ def setUp(self):
seed(1234)

def test_lu(self):
a = random((10,10))
a0 = random((10,10))
b = random((10,))

x1 = solve(a,b)
for order in ['C', 'F']:
a = np.array(a0, order=order)

x1 = solve(a,b)

lu_a = lu_factor(a)
x2 = lu_solve(lu_a,b)
lu_a = lu_factor(a)
x2 = lu_solve(lu_a,b)

assert_array_equal(x1,x2)
assert_array_equal(x1,x2)

class TestSVD(TestCase):
def setUp(self):
Expand Down

0 comments on commit 388b3da

Please sign in to comment.