Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Loading…

Qr with pivoting #44

Closed
wants to merge 6 commits into from

4 participants

@collinstocks

I've added functionality to linalg.qr(). It now has an optional boolean keyword argument, ``pivoting``, which signals that the QR decomposition should be computed with column pivoting. See http://en.wikipedia.org/wiki/Qr_decomposition#Column_pivoting if you need more details on what that means.

This is implemented by wrapping geqp3 from the lapack library.

This feature is necessary in order to replicate the behavior of MATLAB's stepwisefit() function. I plan to submit an implementation of that function to scikits.statsmodels, since its developers have shown interest in one on their mailing list.

I have included fairly comprehensive tests of the new functionality, as you can verify.

This patch addresses ticket #1473.

Thanks for considering this patch.

-- Collin

@rgommers

You should add the "pivoting" keyword at the end of the argument list. As written, the change breaks backwards compatibility, because keyword args can also be given as positional args.

Done, thanks for the recommendation.

@collinstocks

Also, I forgot to mention above: I have included documentation of the new functionality of linalg.qr().

@scopatz
Collaborator

I have confirmed that this builds and tests. @rgommers, any further reason not to merge?

@rgommers
Owner

Looks good to me as far as I can judge (which doesn't include the f2py part).

@collinstocks

According to the documentation I could find, the f2py part is correct. However, someone more experienced with that should definitely take a look at that, especially given that my first attempt at that code resulted in a segmentation fault under certain circumstances because the real and complex versions of geqp3 have different prototypes.

@scopatz
Collaborator

Unfortunately, I am not all that familiar with f2py myself....

@collinstocks

I've sent out an email on scipy-dev and scipy-user to see if anyone happens to be familiar with f2py. Is there anyone in particular either of you has in mind who might be able to help?

@pv
Owner
pv commented

Thanks for the thorough work, this branch is merged in 6b30572

@pv pv closed this
@collinstocks

Glad to be of assistance. Looking forward to the next release.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Commits on Jul 20, 2011
  1. @collinstocks

    Feature Added - qr decomposition with pivoting (rank revealing qr fac…

    collinstocks authored
    …torization), along with tests
  2. @collinstocks

    Add more tests for qr with pivoting; change keyword from pivoted to p…

    collinstocks authored
    …ivoting based on feedback from scipy-dev; add more documentation
  3. @collinstocks
Commits on Jul 21, 2011
  1. @collinstocks

    BF - had previously assumed that <s,d>geqp3 and <c,z>geqp3 would have…

    collinstocks authored
    … the same prototype; include complex tests; fix typo in documentation
  2. @collinstocks
Commits on Jul 22, 2011
  1. @collinstocks
This page is out of date. Refresh to see the latest.
View
72 scipy/linalg/decomp_qr.py
@@ -1,3 +1,5 @@
+# Additions by Collin RM Stocks, July 2011
+
"""QR decomposition functions."""
import numpy
@@ -13,7 +15,7 @@
__all__ = ['qr', 'rq', 'qr_old']
-def qr(a, overwrite_a=False, lwork=None, mode='full'):
+def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False):
"""Compute QR decomposition of a matrix.
Calculate the decomposition :lm:`A = Q R` where Q is unitary/orthogonal
@@ -32,6 +34,11 @@ def qr(a, overwrite_a=False, lwork=None, mode='full'):
Determines what information is to be returned: either both Q and R
('full', default), only R ('r') or both Q and R but computed in
economy-size ('economic', see Notes).
+ pivoting : bool, optional
+ Whether or not factorization should include pivoting for rank-revealing
+ QR decomposition. If pivoting, compute the decomposition
+ :lm:`A P = Q R` as above, but where P is chosen such that the absolute
+ values of the diagonal entries of R are non-increasing.
Returns
-------
@@ -40,33 +47,51 @@ def qr(a, overwrite_a=False, lwork=None, mode='full'):
``mode='r'``.
R : double or complex ndarray
Of shape (M, N), or (K, N) for ``mode='economic'``. ``K = min(M, N)``.
+ P : integer ndarray
+ Of shape (N,) for ``pivoting=True``. Not returned if ``pivoting=False``.
Raises LinAlgError if decomposition fails
Notes
-----
This is an interface to the LAPACK routines dgeqrf, zgeqrf,
- dorgqr, and zungqr.
+ dorgqr, zungqr, dgeqp3, and zgeqp3.
If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
of (M,M) and (M,N), with ``K=min(M,N)``.
Examples
--------
- >>> from scipy import random, linalg, dot, allclose
+ >>> from scipy import random, linalg, dot, diag, all, allclose
>>> a = random.randn(9, 6)
+
>>> q, r = linalg.qr(a)
>>> allclose(a, dot(q, r))
True
>>> q.shape, r.shape
((9, 9), (9, 6))
+
>>> r2 = linalg.qr(a, mode='r')
>>> allclose(r, r2)
True
+
>>> q3, r3 = linalg.qr(a, mode='economic')
>>> q3.shape, r3.shape
((9, 6), (6, 6))
+ >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
+ >>> d = abs(diag(r4))
+ >>> all(d[1:] <= d[:-1])
+ True
+ >>> allclose(a[:, p4], dot(q4, r4))
+ True
+ >>> q4.shape, r4.shape, p4.shape
+ ((9, 9), (9, 6), (6,))
+
+ >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
+ >>> q5.shape, r5.shape, p5.shape
+ ((9, 6), (6, 6), (6,))
+
"""
if mode == 'qr':
# 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
@@ -82,23 +107,40 @@ def qr(a, overwrite_a=False, lwork=None, mode='full'):
M, N = a1.shape
overwrite_a = overwrite_a or (_datacopied(a1, a))
- geqrf, = get_lapack_funcs(('geqrf',), (a1,))
- if lwork is None or lwork == -1:
- # get optimal work array
- qr, tau, work, info = geqrf(a1, lwork=-1, overwrite_a=1)
- lwork = work[0].real.astype(numpy.int)
-
- qr, tau, work, info = geqrf(a1, lwork=lwork, overwrite_a=overwrite_a)
- if info < 0:
- raise ValueError("illegal value in %d-th argument of internal geqrf"
- % -info)
+ if pivoting:
+ geqp3, = get_lapack_funcs(('geqp3',), (a1,))
+ if lwork is None or lwork == -1:
+ # get optimal work array
+ qr, jpvt, tau, work, info = geqp3(a1, lwork=-1, overwrite_a=1)
+ lwork = work[0].real.astype(numpy.int)
+
+ qr, jpvt, tau, work, info = geqp3(a1, lwork=lwork,
+ overwrite_a=overwrite_a)
+ jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1
+ if info < 0:
+ raise ValueError("illegal value in %d-th argument of internal geqp3"
+ % -info)
+ else:
+ geqrf, = get_lapack_funcs(('geqrf',), (a1,))
+ if lwork is None or lwork == -1:
+ # get optimal work array
+ qr, tau, work, info = geqrf(a1, lwork=-1, overwrite_a=1)
+ lwork = work[0].real.astype(numpy.int)
+
+ qr, tau, work, info = geqrf(a1, lwork=lwork, overwrite_a=overwrite_a)
+ if info < 0:
+ raise ValueError("illegal value in %d-th argument of internal geqrf"
+ % -info)
if not mode == 'economic' or M < N:
R = special_matrices.triu(qr)
else:
R = special_matrices.triu(qr[0:N, 0:N])
if mode == 'r':
- return R
+ if pivoting:
+ return R, jpvt
+ else:
+ return R
if find_best_lapack_type((a1,))[0] in ('s', 'd'):
gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr,))
@@ -127,6 +169,8 @@ def qr(a, overwrite_a=False, lwork=None, mode='full'):
if info < 0:
raise ValueError("illegal value in %d-th argument of internal gorgqr"
% -info)
+ if pivoting:
+ return Q, R, jpvt
return Q, R
View
42 scipy/linalg/generic_flapack.pyf
@@ -7,6 +7,7 @@
!
! Additions by Travis Oliphant
! Additions by Tiziano Zito
+! Additions by Collin RM Stocks
! Usage:
! f2py -c generic_flapack.pyf -L/usr/local/lib/atlas -llapack -lf77blas -lcblas -latlas -lg2c
@@ -494,6 +495,47 @@ interface
end subroutine <tchar=c,z>gelss
+ subroutine <tchar=s,d>geqp3(m,n,a,jpvt,tau,work,lwork,info)
+
+ ! qr_a,jpvt,tau,work,info = geqp3(a,lwork=3*(n+1),overwrite_a=0)
+ ! Compute a QR factorization of a real M-by-N matrix A with column pivoting:
+ ! A * P = Q * R.
+
+ callstatement (*f2py_func)(&m,&n,a,&m,jpvt,tau,work,&lwork,&info)
+ callprotoargument int*,int*,<type_in_c>*,int*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+ integer intent(hide),depend(a):: m = shape(a,0)
+ integer intent(hide),depend(a):: n = shape(a,1)
+ <type_in> dimension(m,n),intent(in,out,copy,out=qr,aligned8) :: a
+ integer dimension(n),intent(out) :: jpvt
+ <type_in> dimension(MIN(m,n)),intent(out) :: tau
+
+ integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*(n+1)
+ <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+ integer intent(out) :: info
+ end subroutine <tchar=s,d>geqp3
+
+ subroutine <tchar=c,z>geqp3(m,n,a,jpvt,tau,work,lwork,rwork,info)
+
+ ! qr_a,jpvt,tau,work,info = geqp3(a,lwork=3*(n+1),overwrite_a=0)
+ ! Compute a QR factorization of a complex M-by-N matrix A with column pivoting:
+ ! A * P = Q * R.
+
+ callstatement (*f2py_func)(&m,&n,a,&m,jpvt,tau,work,&lwork,rwork,&info)
+ callprotoargument int*,int*,<type_in_c>*,int*,int*,<type_in_c>*,<type_in_c>*,int*,<type_in_c>*,int*
+
+ integer intent(hide),depend(a):: m = shape(a,0)
+ integer intent(hide),depend(a):: n = shape(a,1)
+ <type_in> dimension(m,n),intent(in,out,copy,out=qr,aligned8) :: a
+ integer dimension(n),intent(out) :: jpvt
+ <type_in> dimension(MIN(m,n)),intent(out) :: tau
+
+ integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*(n+1)
+ <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+ <type_in> dimension(2*n),intent(hide),depend(n) :: rwork
+ integer intent(out) :: info
+ end subroutine <tchar=c,z>geqp3
+
subroutine <tchar=s,d,c,z>geqrf(m,n,a,tau,work,lwork,info)
! qr_a,tau,work,info = geqrf(a,lwork=3*n,overwrite_a=0)
View
158 scipy/linalg/tests/test_decomp.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
#
# Created by: Pearu Peterson, March 2002
+# Additions by Collin RM Stocks, July 2011
#
""" Test functions for linalg.decomp module
@@ -861,12 +862,34 @@ def test_simple(self):
assert_array_almost_equal(dot(transpose(q),q),identity(3))
assert_array_almost_equal(dot(q,r),a)
+ def test_simple_pivoting(self):
+ a = np.asarray([[8,2,3],[2,9,3],[5,3,6]])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(3))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_trap(self):
a = [[8,2,3],[2,9,3]]
q,r = qr(a)
assert_array_almost_equal(dot(transpose(q),q),identity(2))
assert_array_almost_equal(dot(q,r),a)
+ def test_simple_trap_pivoting(self):
+ a = np.asarray([[8,2,3],[2,9,3]])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(2))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_tall(self):
# full version
a = [[8,2],[2,9],[5,3]]
@@ -874,6 +897,18 @@ def test_simple_tall(self):
assert_array_almost_equal(dot(transpose(q),q),identity(3))
assert_array_almost_equal(dot(q,r),a)
+ def test_simple_tall_pivoting(self):
+ # full version pivoting
+ a = np.asarray([[8,2],[2,9],[5,3]])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(3))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_tall_e(self):
# economy version
a = [[8,2],[2,9],[5,3]]
@@ -883,6 +918,18 @@ def test_simple_tall_e(self):
assert_equal(q.shape, (3,2))
assert_equal(r.shape, (2,2))
+ def test_simple_tall_e_pivoting(self):
+ # economy version pivoting
+ a = np.asarray([[8,2],[2,9],[5,3]])
+ q,r,p = qr(a, pivoting=True, mode='economic')
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(2))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p], mode='economic')
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_fat(self):
# full version
a = [[8,2,5],[2,9,3]]
@@ -892,6 +939,20 @@ def test_simple_fat(self):
assert_equal(q.shape, (2,2))
assert_equal(r.shape, (2,3))
+ def test_simple_fat_pivoting(self):
+ # full version pivoting
+ a = np.asarray([[8,2,5],[2,9,3]])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(2))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ assert_equal(q.shape, (2,2))
+ assert_equal(r.shape, (2,3))
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_fat_e(self):
# economy version
a = [[8,2,3],[2,9,5]]
@@ -901,12 +962,37 @@ def test_simple_fat_e(self):
assert_equal(q.shape, (2,2))
assert_equal(r.shape, (2,3))
+ def test_simple_fat_e_pivoting(self):
+ # economy version pivoting
+ a = np.asarray([[8,2,3],[2,9,5]])
+ q,r,p = qr(a, pivoting=True, mode='economic')
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(2))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ assert_equal(q.shape, (2,2))
+ assert_equal(r.shape, (2,3))
+ q2,r2 = qr(a[:,p], mode='economic')
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_simple_complex(self):
a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
q,r = qr(a)
assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
assert_array_almost_equal(dot(q,r),a)
+ def test_simple_complex_pivoting(self):
+ a = np.asarray([[3,3+4j,5],[5,2,2+7j],[3,2,7]])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_random(self):
n = 20
for k in range(2):
@@ -915,6 +1001,19 @@ def test_random(self):
assert_array_almost_equal(dot(transpose(q),q),identity(n))
assert_array_almost_equal(dot(q,r),a)
+ def test_random_pivoting(self):
+ n = 20
+ for k in range(2):
+ a = random([n,n])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(n))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_random_tall(self):
# full version
m = 200
@@ -925,6 +1024,21 @@ def test_random_tall(self):
assert_array_almost_equal(dot(transpose(q),q),identity(m))
assert_array_almost_equal(dot(q,r),a)
+ def test_random_tall_pivoting(self):
+ # full version pivoting
+ m = 200
+ n = 100
+ for k in range(2):
+ a = random([m,n])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(m))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_random_tall_e(self):
# economy version
m = 200
@@ -937,6 +1051,23 @@ def test_random_tall_e(self):
assert_equal(q.shape, (m,n))
assert_equal(r.shape, (n,n))
+ def test_random_tall_e_pivoting(self):
+ # economy version pivoting
+ m = 200
+ n = 100
+ for k in range(2):
+ a = random([m,n])
+ q,r,p = qr(a, pivoting=True, mode='economic')
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(n))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ assert_equal(q.shape, (m,n))
+ assert_equal(r.shape, (n,n))
+ q2,r2 = qr(a[:,p], mode='economic')
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_random_trap(self):
m = 100
n = 200
@@ -946,6 +1077,20 @@ def test_random_trap(self):
assert_array_almost_equal(dot(transpose(q),q),identity(m))
assert_array_almost_equal(dot(q,r),a)
+ def test_random_trap_pivoting(self):
+ m = 100
+ n = 200
+ for k in range(2):
+ a = random([m,n])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(transpose(q),q),identity(m))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
def test_random_complex(self):
n = 20
for k in range(2):
@@ -954,6 +1099,19 @@ def test_random_complex(self):
assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
assert_array_almost_equal(dot(q,r),a)
+ def test_random_complex_pivoting(self):
+ n = 20
+ for k in range(2):
+ a = random([n,n])+1j*random([n,n])
+ q,r,p = qr(a, pivoting=True)
+ d = abs(diag(r))
+ assert_(all(d[1:] <= d[:-1]))
+ assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
+ assert_array_almost_equal(dot(q,r),a[:,p])
+ q2,r2 = qr(a[:,p])
+ assert_array_almost_equal(q,q2)
+ assert_array_almost_equal(r,r2)
+
class TestRQ(TestCase):
def setUp(self):
Something went wrong with that request. Please try again.