diff --git a/scipy/linalg/_decomp_svd.py b/scipy/linalg/_decomp_svd.py index 24be62752e62..a2452e3dc21b 100644 --- a/scipy/linalg/_decomp_svd.py +++ b/scipy/linalg/_decomp_svd.py @@ -116,6 +116,20 @@ def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False, if lapack_driver not in ('gesdd', 'gesvd'): message = f'lapack_driver must be "gesdd" or "gesvd", not "{lapack_driver}"' raise ValueError(message) + + if lapack_driver == 'gesdd' and compute_uv: + # XXX: revisit int32 when ILP64 lapack becomes a thing + max_mn, min_mn = (m, n) if m > n else (n, m) + if full_matrices: + if max_mn*max_mn > numpy.iinfo(numpy.int32).max: + raise ValueError(f"Indexing a matrix size {max_mn} x {max_mn} " + " would incur integer overflow in LAPACK.") + else: + sz = max(m * min_mn, n * min_mn) + if max(m * min_mn, n * min_mn) > numpy.iinfo(numpy.int32).max: + raise ValueError(f"Indexing a matrix of {sz} elements would " + "incur an in integer overflow in LAPACK.") + funcs = (lapack_driver, lapack_driver + '_lwork') gesXd, gesXd_lwork = get_lapack_funcs(funcs, (a1,), ilp64='preferred') diff --git a/scipy/linalg/tests/test_decomp.py b/scipy/linalg/tests/test_decomp.py index bcd5b6611d80..f9e347050627 100644 --- a/scipy/linalg/tests/test_decomp.py +++ b/scipy/linalg/tests/test_decomp.py @@ -1089,6 +1089,14 @@ class TestSVD_GESVD(TestSVD_GESDD): lapack_driver = 'gesvd' +def test_svd_gesdd_nofegfault(): + # svd(a) with {U,VT}.size > INT_MAX does not segfault + # cf https://github.com/scipy/scipy/issues/14001 + df=np.ones((4799, 53130), dtype=np.float64) + with assert_raises(ValueError): + svd(df) + + class TestSVDVals: def test_empty(self):