Skip to content

Commit

Permalink
FIX:linalg:Guard against possible permute_l out of bound behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
ilayn authored and tylerjereddy committed Jun 28, 2023
1 parent 7ec5010 commit d9ac3f3
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions scipy/linalg/_decomp_lu_cython.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ cdef void lu_decompose(cnp.ndarray[lapack_t, ndim=2] a,
int[::1] perm,
bint permute_l) noexcept:
"""LU decomposition and copy operations using ?getrf routines
This function overwrites inputs. For interfacing LAPACK,
it creates a memory buffer and copies into with F-order
then swaps back to C order hence no need for dealing with
Fortran arrays which are inconvenient.
After the LU factorization, to minimize the amount of data
copied, for rectangle arrays, and depending on the size,
the smaller portion is copied out to U and the rest becomes
Expand Down Expand Up @@ -117,7 +117,7 @@ cdef void lu_decompose(cnp.ndarray[lapack_t, ndim=2] a,
for ind1 in range(mn):
a[ind1, ind1] = 1
a[ind1, ind1+1:mn] = 0

else: # square or fat, "a" holds bigger U

lu[0, 0] = 1
Expand All @@ -128,22 +128,31 @@ cdef void lu_decompose(cnp.ndarray[lapack_t, ndim=2] a,
for ind2 in range(mn - 1): # cols
for ind1 in range(ind2+1, m): # rows
a[ind1, ind2] = 0

if permute_l:
# b still exists -> use it as temp array
# we copy everything to b and pick back
# rows from b as dictated by perm
memcpy(bb,
(&a[0, 0] if m > n else &lu[0, 0]),
m*n*sizeof(lapack_t)
)
for ind1 in range(m):
if perm[ind1] == ind1: # Row is not permuted
continue
else:
memcpy((&a[ind1, 0] if m > n else &lu[ind1, 0]),
bb + (perm[ind1]*mn), # stride rows
mn*sizeof(lapack_t)) # copy 1 row of mem

if m > n:
memcpy(bb, &a[0, 0], m*mn*sizeof(lapack_t))
for ind1 in range(m):
if perm[ind1] == ind1:
continue
else:
memcpy(&a[ind1, 0],
bb + (perm[ind1]*mn), # row stride
mn*sizeof(lapack_t)) # copy one row of memory

else: # same but for lu array
memcpy(bb, &lu[0, 0], mn*n*sizeof(lapack_t))
for ind1 in range(mn):
if perm[ind1] == ind1:
continue
else:
memcpy(&lu[ind1, 0],
bb + (perm[ind1]*mn),
mn*sizeof(lapack_t))


@cython.nonecheck(False)
Expand Down

0 comments on commit d9ac3f3

Please sign in to comment.