Skip to content

Commit

Permalink
Merge pull request #12206 from ilayn/tall_pinv
Browse files Browse the repository at this point in the history
MAINT:lstsq: Switch to tranposed problem if the array is tall
  • Loading branch information
tylerjereddy committed Jun 9, 2020
2 parents c38df7b + d6b24ca commit 99b8660
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
15 changes: 9 additions & 6 deletions scipy/linalg/basic.py
Expand Up @@ -673,8 +673,8 @@ def solve_toeplitz(c_or_cr, b, check_finite=True):
c = _asarray_validated(c_or_cr, check_finite=check_finite).ravel()
r = c.conjugate()

# Form a 1-D array of values to be used in the matrix, containing a reversed
# copy of r[1:], followed by c.
# Form a 1-D array of values to be used in the matrix, containing a
# reversed copy of r[1:], followed by c.
vals = np.concatenate((r[-1:0:-1], c))
if b is None:
raise ValueError('illegal value, `b` is a required argument')
Expand Down Expand Up @@ -1302,20 +1302,23 @@ def pinv(a, cond=None, rcond=None, return_rank=False, check_finite=True):
"""
a = _asarray_validated(a, check_finite=check_finite)
b = np.identity(a.shape[0], dtype=a.dtype)
# If a is sufficiently tall it is cheaper to compute using the transpose
trans = a.shape[0] / a.shape[1] >= 1.1
b = np.eye(a.shape[1] if trans else a.shape[0], dtype=a.dtype)

if rcond is not None:
cond = rcond

if cond is None:
cond = max(a.shape) * np.spacing(a.real.dtype.type(1))

x, resids, rank, s = lstsq(a, b, cond=cond, check_finite=False)
x, resids, rank, s = lstsq(a.T if trans else a, b,
cond=cond, check_finite=False)

if return_rank:
return x, rank
return (x.T if trans else x), rank
else:
return x
return x.T if trans else x


def pinv2(a, cond=None, rcond=None, return_rank=False, check_finite=True):
Expand Down
7 changes: 7 additions & 0 deletions scipy/linalg/tests/test_basic.py
Expand Up @@ -1262,6 +1262,13 @@ def test_native_list_argument(self):
a_pinv2 = pinv2(a)
assert_array_almost_equal(a_pinv, a_pinv2)

def test_tall_transposed(self):
a = random([10, 2])
a_pinv = pinv(a)
# The result will be transposed internally hence will be a C-layout
# instead of the typical LAPACK output with Fortran-layout
assert a_pinv.flags['C_CONTIGUOUS']


class TestPinvSymmetric(object):

Expand Down

0 comments on commit 99b8660

Please sign in to comment.