PCA on sparse, noncentered data #12794

Open
pavlin-policar opened this issue Dec 15, 2018 · 68 comments · May be fixed by #24415

@pavlin-policar

I suppose this is more of a feature request than anything else. There are several implementations of PCA that can compute the decomposition on noncentered, sparse data, while the implementation here does not support sparse matrices at all.

A MATLAB implementation can be found here and a Python implementation here. So far, I've been using the Python implementation, but it's missing some things and will eventually be deprecated (facebookarchive/fbpca#9).

I haven't looked at the code or math too much, but as far as I'm aware, it's just a matter of adding a term to randomized_range_finder to account for centering.

Is this something that you guys are aware of, and is anyone working on this? This would be an awesome feature to have.

@rth
Member

rth commented Dec 15, 2018

Did you see the TruncatedSVD estimator? In the current implementation, it should be equivalent to PCA on already-centered data, except that it does not center the data itself and also supports sparse input.

@pavlin-policar
Author

pavlin-policar commented Dec 16, 2018

Yes, I am aware of TruncatedSVD, but that's different. It's true that SVD on centered data is PCA, but SVD on noncentered data is not equivalent to PCA.

It's easy to see the difference with a simple example:

>>> from sklearn import datasets, decomposition
>>> x = datasets.load_boston().data
>>> x.mean(axis=0)  # Not centered!
array([  3.61352356,  11.36363636,  11.13677866,   0.06916996,
         0.55469506,   6.28463439,  68.57490119,   3.79504269,
         9.54940711, 408.23715415,  18.4555336 , 356.67403162,
        12.65306324])

>>> sksvd = decomposition.TruncatedSVD(n_components=6)
>>> sksvd.fit(x)
>>> sksvd.singular_values_
array([12585.18158999,  3445.97405949,   645.75710949,   402.05046109,
         158.96461248,   121.50293599])

>>> skpca = decomposition.PCA(n_components=6)
>>> skpca.fit(x)
>>> skpca.singular_values_
array([3949.60823205, 1776.63071036,  642.86374839,  366.98207453,
        158.634553  ,  118.64982369])

Now, I'm dealing with huge sparse matrices. If I densified them, they would not fit into memory, so centering the matrix is not an option (we lose all sparsity with centering).

Yes, I could perform incremental PCA, loading chunks of the matrix into memory at a time, but that's really slow. Yes, I could perform SVD on these matrices, which is fine, but not the same as PCA. Now, like I said in my first comment, this can be done, i.e. we can compute PCA on noncentered data directly, without ever having to explicitly center it. This is what I'm missing. I hope it's clear enough this time. PCA is possible and very fast on sparse matrices, and I feel it really should be included in scikit-learn.

@jnothman
Member

jnothman commented Dec 18, 2018 via email

@glemaitre
Member

glemaitre commented Dec 19, 2018

Yes, I am aware of TruncatedSVD, but that's different. It's true that SVD on centered data is PCA, but SVD on noncentered data is not equivalent to PCA.

I am confused. It seems that you want PCA on non-centered data, which is, as you mentioned, different from PCA. So it seems that you want the truncated SVD, doesn't it?

@pavlin-policar
Author

Perhaps I've been unclear. I'll try to be more clear this time :)

What we know
We know that SVD is not PCA. It is different. But, we usually compute PCA by taking a matrix X, centering it, then applying SVD. This is equivalent to taking X, centering it, computing its covariance matrix C, then computing the eigenvalues and eigenvectors of C.

So this is nothing new. We typically compute PCA via the SVD. And when X is centered, SVD is the same as PCA. Otherwise they are different.
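
As a minimal sketch of this equivalence (plain NumPy/scikit-learn; the iris data below is used only for illustration):

import numpy as np
from sklearn import datasets, decomposition

X = datasets.load_iris().data
Xc = X - X.mean(axis=0)                      # explicit centering
pca = decomposition.PCA().fit(X)

# SVD of the centered matrix recovers the PCA singular values ...
s_centered = np.linalg.svd(Xc, compute_uv=False)
print(np.allclose(s_centered, pca.singular_values_))                      # True

# ... and the eigenvalues of the covariance matrix C are s**2 / (n - 1).
C = np.cov(X, rowvar=False)
eigvals = np.sort(np.linalg.eigvalsh(C))[::-1]
print(np.allclose(eigvals, pca.singular_values_ ** 2 / (len(X) - 1)))     # True

# SVD of the raw, non-centered matrix gives different singular values.
print(np.allclose(np.linalg.svd(X, compute_uv=False), pca.singular_values_))  # False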

The problem
Sometimes we can't center our matrix. If we have a huge sparse matrix, centering it will eliminate the sparseness and make the matrix dense. This is often impossible if X is really, really big. It can't fit into memory. So we can't do PCA. There are incremental ways to do this (I believe incremental PCA is used for this), but this requires a lot of disk reads, so it's slow.

We can still do SVD, but since X is not centered, this is not the same as PCA. PCA is often nicer because it's easier to interpret things. It's pretty hard to interpret an SVD (or I just might not be aware of it).

The solution
In the implementations I referenced (I don't know which paper the formulation comes from - I haven't gone through the maths), they implement randomized SVD and randomized PCA. But, PCA is computed with a randomized method that never needs to center X. To emphasise this, it is possible to compute actual PCA without ever having to center the original, potentially huge X matrix.

To achieve this, we take the already implemented randomized SVD algorithm, and add a couple of negative terms, which account for centering. The changes to the algorithm are very minor.

The randomized SVD implemented here is great and could be extended with just a couple of terms, and we could have it compute both the SVD and PCA. Setting a single flag e.g. apply_centering to randomized_svd could switch between the two.

Conclusion
I hope I've made this very clear this time. I also hope it's clear that having this in scikit-learn would be very beneficial. Again, this is a more efficient way to compute PCA on sparse matrices, one that doesn't require us to make them dense. PCA is already implemented in scikit-learn, so adding an implementation that supports sparse data seems like the natural next step. This is not the same as SVD. PCA is nicer than SVD because it has a clearer interpretation.

I could work on this if needed, but I am not at all familiar with the codebase and have fairly limited time.

@glemaitre
Member

glemaitre commented Dec 19, 2018 via email

@amueller
Member

Yeah, I think it's clear, and I think we probably want it, but we need to research the math and check speed trade-offs, the complexity of adding the method, etc.

@pavlin-policar
Author

pavlin-policar commented Dec 19, 2018

Actually, I've just briefly looked at the implementations and it seems to be astonishingly simple.

Randomized methods use power iteration to increase the spectral gap, which is just a fancy way of saying multiplication by the original matrix X multiple times. SVD already does this in the current implementation.
Since PCA is just SVD(X - means), and SVD is computed with power iteration, the power iteration now looks like this: (X - means) * Q, where Q is the basis. And if we use the distributive property of the dot product we can instead do (X * Q) - (means * Q), meaning X never has to be centered.

I don't think I've explained this properly (clearly I need to work on my explanations), but it's just a clever application of the distributive property of the dot product.
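
As a rough sketch of that identity (plain NumPy/SciPy; the shapes below are arbitrary), subtracting the means commutes with the projection onto Q, so X never has to be densified or modified:

import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
X = sp.random(1000, 200, density=0.01, format="csr", random_state=0)   # big sparse X
Q = rng.standard_normal((200, 10))                                     # projection basis

means = np.asarray(X.mean(axis=0))          # (1, n_features), small and dense

# Explicit centering densifies X -- exactly what we cannot afford:
reference = (X.toarray() - means) @ Q

# Implicit centering: (X - means) @ Q == X @ Q - means @ Q, with the second
# term broadcast over the rows.
implicit = X @ Q - means @ Q

print(np.allclose(reference, implicit))     # True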

@glemaitre
Member

If I am not wrong (I never really read the code before), it looks like the solver is randomized_svd, which implements:
Finding structure with randomness: Stochastic algorithms for constructing approximate matrix decompositions, Halko, et al., 2009 https://arxiv.org/abs/0909.4061

which is the one used by default in the TruncatedSVD.

@pavlin-policar
Author

pavlin-policar commented Dec 20, 2018

So, I've played around with this a little bit and I think these are all the changes needed in the randomized SVD solver.

The only real change is that whenever we apply a non-centered matrix A to any other matrix, we distribute the mean subtraction like so: (A - m) X = AX - mX, avoiding the need to ever center A.

diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index 19df5b161..0ab75721e 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -147,6 +147,7 @@ def safe_sparse_dot(a, b, dense_output=False):
 
 def randomized_range_finder(A, size, n_iter,
                             power_iteration_normalizer='auto',
+                            subtract_mean=False,
                             random_state=None):
     """Computes an orthonormal matrix whose range approximates the range of A.
 
@@ -211,28 +212,39 @@ def randomized_range_finder(A, size, n_iter,
         else:
             power_iteration_normalizer = 'LU'
 
+    if subtract_mean:
+        c = np.mean(A, axis=0).reshape((1, -1))
+        applyA = lambda X: safe_sparse_dot(A, X) - safe_sparse_dot(c, X)
+        applyAT = lambda X: safe_sparse_dot(A.T, X) - \
+                            safe_sparse_dot(c.T, Q.sum(axis=0).reshape((1, -1)))
+    else:
+        applyA = lambda X: safe_sparse_dot(A, X)
+        applyAT = lambda X: safe_sparse_dot(A.T, X)
+
+    Q = applyA(Q)
+
     # Perform power iterations with Q to further 'imprint' the top
     # singular vectors of A in Q
     for i in range(n_iter):
         if power_iteration_normalizer == 'none':
-            Q = safe_sparse_dot(A, Q)
-            Q = safe_sparse_dot(A.T, Q)
+            Q = applyAT(Q)
+            Q = applyA(Q)
         elif power_iteration_normalizer == 'LU':
-            Q, _ = linalg.lu(safe_sparse_dot(A, Q), permute_l=True)
-            Q, _ = linalg.lu(safe_sparse_dot(A.T, Q), permute_l=True)
+            Q, _ = linalg.lu(applyAT(Q), permute_l=True)
+            Q, _ = linalg.lu(applyA(Q), permute_l=True)
         elif power_iteration_normalizer == 'QR':
-            Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic')
-            Q, _ = linalg.qr(safe_sparse_dot(A.T, Q), mode='economic')
+            Q, _ = linalg.qr(applyAT(Q), mode='economic')
+            Q, _ = linalg.qr(applyA(Q), mode='economic')
 
     # Sample the range of A using by linear projection of Q
     # Extract an orthonormal basis
-    Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic')
+    Q, _ = linalg.qr(Q, mode='economic')
     return Q
 
 
 def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
                    power_iteration_normalizer='auto', transpose='auto',
-                   flip_sign=True, random_state=0):
+                   flip_sign=True, subtract_mean=False, random_state=0):
     """Computes a truncated randomized SVD
 
     Parameters
@@ -333,11 +345,20 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
         # this implementation is a bit faster with smaller shape[1]
         M = M.T
 
-    Q = randomized_range_finder(M, n_random, n_iter,
-                                power_iteration_normalizer, random_state)
+    Q = randomized_range_finder(
+        M,
+        size=n_random,
+        n_iter=n_iter,
+        power_iteration_normalizer=power_iteration_normalizer,
+        subtract_mean=subtract_mean,
+        random_state=random_state,
+    )
 
     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
+    if subtract_mean:
+        c = np.mean(M, axis=0).reshape((1, -1))
+        B -= np.dot(c.T, Q.sum(axis=0).reshape((1, -1)))
 
     # compute the SVD on the thin matrix: (k + p) wide
     Uhat, s, V = linalg.svd(B, full_matrices=False)

I've run all the tests for TruncatedSVD, and they pass. In the PCA learner, there's one problematic line:

total_var = np.var(X, ddof=1, axis=0)

which fails on sparse matrices. I'd expect something like sparse variance to be already implemented, but I'm not familiar with the codebase and don't know where to look. If not, it should be relatively straightforward to implement this as well.

@lesshaste

For the variance I guess you can just use the formula:

E[X^2] - (E[X])^2

@jeremiedbb
Member

I'd expect something like sparse variance to be already implemented, but I'm not familiar with the codebase and don't know where to look. If not, it should be relatively straightforward to implement this as well.

In sklearn you have mean_variance_axis in utils/sparsefuncs_fast.py, so that should not be an issue.
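
For reference, a small sketch of both routes (mean_variance_axis is importable from sklearn.utils.sparsefuncs; the data below is arbitrary):

import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs import mean_variance_axis

X = sp.random(500, 50, density=0.05, format="csr", random_state=0)

# Route 1: the existing helper, which works directly on CSR/CSC input.
means, variances = mean_variance_axis(X, axis=0)

# Route 2: the E[X^2] - (E[X])^2 identity, computed column-wise on sparse data.
ex2 = np.asarray(X.multiply(X).mean(axis=0)).ravel()
ex = np.asarray(X.mean(axis=0)).ravel()
print(np.allclose(variances, ex2 - ex ** 2))   # True (both use a denominator of n)

Note that the np.var(X, ddof=1) line quoted above uses the unbiased estimator, so matching it exactly needs an extra n / (n - 1) factor.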

@glemaitre
Member

I looked into the paper and they actually only speak of the SVD, which explains why our implementation does not remove the mean.

Then, your proposal looks good to me. Since we already have solver='auto' in PCA, we could default to 'randomized' for sparse matrices with few changes to the code.

@pavlin-policar thanks a lot for the clarifications and the patience ;)

@adrinjalali
Member

@pavlin-policar would you be keen on opening a PR with the changes?

@pavlin-policar
Author

Yeah, sure. I'll give it a shot.

@jeremiedbb Thanks, that's going to make things easier.

@pavlin-policar
Author

I've opened #12841 which implements the changes I described above.

@lobpcg
Contributor

lobpcg commented Jul 2, 2019

#12319 would be a (faster) alternative to 'randomized' for sparse matrices that also allows running matrix-free SVD, i.e. on a matrix given by a function. In particular, it allows performing implicit data centering without changing the data. For that matter, the ARPACK SVD solver is also matrix-free by design.

Thus, one should really make the implicit data centering solver-agnostic, allowing any matrix-free SVD solver to be used, rather than concentrating just on 'randomized'.

@atarashansky

atarashansky commented Dec 21, 2019

@lobpcg -- I'm really interested in allowing implicit data centering using your solver. I checked out the lobpcg_svd branch of your forked repo ( #12319 ). How should I go about implementing implicit data-centering using your solver? I'm not really that great with the math so I'd appreciate any pointers!

@lobpcg
Contributor

lobpcg commented Dec 21, 2019

@lobpcg -- I'm really interested in allowing implicit data centering using your solver. I checked out the lobpcg_svd branch of your forked repo ( #12319 ). How should I go about implementing implicit data-centering using your solver? I'm not really that great with the math so I'd appreciate any pointers!

@atarashansky
#12841 (unfortunately also stalled) already implements implicit data centering, only using different solvers. If you just need to run implicit data centering - that would probably be the easiest way.

If you want to code specifically LOBPCG for PCA with implicit data centering and submit as a PR to scikit, you probably want to pick up #12841, update it to resolve the conflicts, and add the support for lobpcg_svd #12319 . But both #12841 and #12319 are currently stuck apparently with no clear path to merging :-(

If you want to run code specifically with LOBPCG for PCA with implicit data centering yourself, it is now easy, since PCA is just SVD(X - means), and LOBPCG for SVD is already merged into scipy in scipy/scipy#10830. You just need to code the LinearOperator A multiplying (X - means) * Q as X * Q - means * Q (and similarly for the transpose) to pass to svds in https://github.com/scipy/scipy/blob/0db79f6a56ba8187380288372993800331e9f9ba/scipy/sparse/linalg/eigen/arpack/arpack.py, since now:

def svds(A, k=6, ncv=None, tol=0, which='LM', v0=None,
         maxiter=None, return_singular_vectors=True,
         solver='arpack'):
    """Compute the largest or smallest k singular values/vectors for a sparse matrix. The order of the singular values is not guaranteed.
    Parameters
    ----------
    A : {sparse matrix, LinearOperator}
        Array to compute the SVD on, of shape (M, N)
    k : int, optional
        Number of singular values and vectors to compute.
        Must be 1 <= k < min(A.shape).
    ncv : int, optional
        The number of Lanczos vectors generated
        ncv must be greater than k+1 and smaller than n;
        it is recommended that ncv > 2*k
        Default: ``min(n, max(2*k + 1, 20))``
    tol : float, optional
        Tolerance for singular values. Zero (default) means machine precision.
    which : str, ['LM' | 'SM'], optional
        Which `k` singular values to find:
            - 'LM' : largest singular values
            - 'SM' : smallest singular values
        .. versionadded:: 0.12.0
    v0 : ndarray, optional
        Starting vector for iteration, of length min(A.shape). Should be an
        (approximate) left singular vector if N > M and a right singular
        vector otherwise.
        Default: random
        .. versionadded:: 0.12.0
    maxiter : int, optional
        Maximum number of iterations.
        .. versionadded:: 0.12.0
    return_singular_vectors : bool or str, optional
        - True: return singular vectors (True) in addition to singular values.
        .. versionadded:: 0.12.0
        - "u": only return the u matrix, without computing vh (if N > M).
        - "vh": only return the vh matrix, without computing u (if N <= M).
        .. versionadded:: 0.16.0
    solver : str, optional
            Eigenvalue solver to use. Should be 'arpack' or 'lobpcg'.
            Default: 'arpack'

@atarashansky

atarashansky commented Dec 21, 2019

@lobpcg -- I have a pretty naive question:

The first dimension of Q will be the minimum dimension of the input data, right? If I have more features (m) than samples (n), then Q will be of shape (n x k). But 'means' will be shape (1 x m).

So, in " X*Q - means * Q", is 'means' the mean of either columns or rows depending on the shape of Q? And if its the mean of the rows, is that still equivalent to PCA (SVD on feature-centered data)?

EDIT: Ah, nevermind, I just got confused by the notation. I got it working! Thanks so much!

@atarashansky

atarashansky commented Dec 21, 2019

For others who are looking for a solution while waiting for the aforementioned PRs to go through, here is what worked for me:

import numpy as np
import scipy as sp
import scipy.sparse.linalg

def sparse_pca(X, npcs, mu=None):
    # X -- scipy sparse data matrix
    # npcs -- number of principal components
    # mu -- precomputed feature means. if None, calculates them from X.

    # compute mean of data features
    if mu is None: 
        mu = X.mean(0).A.flatten()[None,:]

    # dot product operator for the means
    mmat = mdot = mu.dot 
    # dot product operator for the transposed means
    mhmat = mhdot = mu.T.dot 
    # dot product operator for the data
    Xmat = Xdot = X.dot 
    # dot product operator for the transposed data
    XHmat = XHdot = X.T.conj().dot 
    # dot product operator for a vector of ones
    ones = np.ones(X.shape[0])[None,:].dot 

    # modify the matrix/vector dot products to subtract the means
    def matvec(x): 
        return Xdot(x) - mdot(x)
    def matmat(x): 
        return Xmat(x) - mmat(x)
    def rmatvec(x): 
        return XHdot(x) - mhdot(ones(x))
    def rmatmat(x): 
        return XHmat(x) - mhmat(ones(x))
    
    # construct the LinearOperator
    XL = sp.sparse.linalg.LinearOperator(matvec = matvec, dtype = X.dtype,
                                         matmat = matmat,
                                         shape = X.shape,
                                        rmatvec = rmatvec, rmatmat = rmatmat)
     
    # I chose to use 'arpack' in this case as it is 2x faster for my specific application
    u,s,v = sp.sparse.linalg.svds(XL,solver='arpack',k=npcs)
    
    # i like my eigenvalues sorted in decreasing order
    idx = np.argsort(-s)
    S = np.diag(s[idx])
    # principal components
    pcs = u[:,idx].dot(S) 
    # equivalent to PCA.components_ in sklearn 
    components_ = v[idx,:] 
    return pcs,components_
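
A quick sanity check of the snippet above on a small random sparse matrix (the shapes and component count below are arbitrary), comparing against dense PCA:

import numpy as np
import scipy as sp
import scipy.sparse
from sklearn.decomposition import PCA

X = sp.sparse.random(300, 40, density=0.05, format="csr", random_state=0)

pcs, components_ = sparse_pca(X, npcs=5)
pca = PCA(n_components=5, svd_solver="full").fit(X.toarray())

# Singular vectors are only defined up to sign, so compare absolute values;
# the difference should be near machine precision.
print(np.abs(np.abs(components_) - np.abs(pca.components_)).max())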

Anecdotally, on my dataset (120,000 x 10,000), lobpcg took ~290s, arpack took ~110s when computing the top 200 components. 10,000 / 200 = 50 which, I'm assuming, should be in the regime that lobpcg outperforms other methods, so I'd be curious to know what's going on here... Granted, that's a separate issue.

@lobpcg
Contributor

lobpcg commented Dec 22, 2019

@atarashansky
To evaluate the lobpcg vs. arpack comparison more deeply, you may want to play with k=npcs. arpack should just get more expensive in proportion to the increase of k, while lobpcg may suddenly converge faster, since it iterates the whole k-sized block together. You may also want to play with the tolerance and with choosing the initial approximations if you have a better guess than random. And of course make sure you link Python to fast BLAS3 libraries to take advantage of lobpcg's blocking advantage over arpack.

@atarashansky

atarashansky commented Dec 22, 2019

Does order of operations mitigate this issue? ones(x) forms a scalar and mhdot(ones(x)) is just a column vector multiplied by a scalar.

I was able to do what you suggest for the matvec and matmat, as seen above, but I ran into dimension mismatch errors when I tried doing the same for rmatvec and rmatmat (i.e. without explicitly including the vector of ones).

If X is (n x m) and mu is (1 x m), then mean centering is X * x - 1 * mu * x, where 1 * mu is an outer product between a column of 1s of size n and the means. In the code, however, we don't need the column of ones because mu * x will be broadcast and subtracted from every row of X * x. I ran into confusion when I tried doing the same for the transpose, (X^T - mu^T * 1^T) * x = X^T * x - mu^T * 1^T * x, as now the ones are in the middle of the product.

@lobpcg
Contributor

lobpcg commented Dec 22, 2019

@atarashansky I actually misread your code yesterday and made technically wrong comments (now deleted), but you still were able to see my concern.

Maybe try efficiently implementing mu^T (1^T * x) instead of (mu^T * 1^T) x, also keeping in mind that x is a matrix in rmatmat. Maybe something like (x^T * 1)^T could work?

You can try using a similar trick to code the X-part of rmatvec and rmatmat without explicitly forming X.T.conj() (which may double the memory use, also storing X^T ?) using instead something like (x^T*X)^T.

In https://www.mathworks.com/matlabcentral/fileexchange/48-locally-optimal-block-preconditioned-conjugate-gradient, I feed LOBPCG the function funA = @(v)((X*v)'*X)' for A = X'*X to perform svd(X).

Please let me know if you find the "optimal" way to code it, i.e., without doing anything to X and creating no new matrices of the size larger than size(x).

Another fun option to try is to actually generate (some or all of) the 4 matrices resulting from the product (X^T - mu^T * 1^T) * (X - 1 * mu). Depending on the size and sparsity of X, as well as on the convergence speed of LOBPCG, that might be a competing option. I have found, testing svd(X) with LOBPCG, that explicitly computing X'X once might be beneficial speed-wise compared to the no-extra-memory way of feeding funA = @(v)((X*v)'*X)' if size(X) is moderate.

@atarashansky

I tried running the function with and without the mean centering on a sparse 30,000 x 30,000 dataset, and the memory usage and timings are identical according to memit and timeit (using 'arpack').

I'm sure the way I coded it is indeed not optimal, but it seems that, practically, it makes no difference, at least not in the way svds is using the LinearOperator.

@lobpcg
Contributor

lobpcg commented Dec 24, 2019

Thanks for checking!
I do not remember how the LinearOperator is implemented, e.g., if XHmat = XHdot = X.T.conj().dot actually creates this matrix explicitly or not.

'arpack' is pre-compiled - I do not know if memit accurately captures the memory usage in such a case.

@dkobak
Contributor

dkobak commented Dec 26, 2019

@atarashansky This is an amazing code snippet! I haven't tried it out myself but assuming it works as advertised, this is really cool. I wasn't aware that one can feed LinearOperator directly into svds (how does the interfacing with ARPACK work in this case?).

@pavlin-policar You have that old PR that implements PCA for sparse matrices with the randomized solver. This looks like a great way to implement the same using the arpack solver. Does either of you want to add it to that PR?

@atarashansky

@dkobak -- To answer your question about how the interfacing with ARPACK works in this case, from @lobpcg above:

#12319 would be a (faster) alternative to 'randomized' for sparse matrices that also allows running matrix-free SVD, i.e. on a matrix given by a function. In particular, it allows performing implicit data centering without changing the data. For that matter, the ARPACK SVD solver is also matrix-free by design. Thus, one should really make the implicit data centering solver-agnostic, allowing any matrix-free SVD solver to be used, rather than concentrating just on 'randomized'.

Apparently, the ARPACK solver is also matrix-free by design. In svds, the customized dot products wrapped by the input LinearOperator are used to create the Hermitian operator XH_X, which is then passed to eigsh:

    def matvec_XH_X(x):
        return XH_dot(X_dot(x))

    def matmat_XH_X(x):
        return XH_mat(X_matmat(x))

    XH_X = LinearOperator(matvec=matvec_XH_X, dtype=A.dtype,
                          matmat=matmat_XH_X,
                          shape=(min(A.shape), min(A.shape)))

...

    elif solver == 'arpack' or solver is None:
        eigvals, eigvec = eigsh(XH_X, k=k, tol=tol ** 2, maxiter=maxiter,
                                ncv=ncv, which=which, v0=v0)

And then in eigsh:

    params = _SymmetricArpackParams(n, k, A.dtype.char, matvec, mode,
                                    M_matvec, Minv_matvec, sigma,
                                    ncv, v0, maxiter, which, tol)
    with _ARPACK_LOCK:
        while not params.converged:
            params.iterate()

        return params.extract(return_eigenvectors)

The ARPACK class just accepts the dot product functions as input, meaning the Hermitian operator and the input data matrix never need to be explicitly generated.

@lobpcg
Contributor

lobpcg commented Sep 10, 2022

I see. I'll have to check but the fear is that a sparse algorithm converts dense inputs to a sparse representation, which would be... inefficient.

It doesn't. All svds sparse algorithms in SciPy are matrix-free, i.e., they don't modify the data matrix M in any way, so the matrix access is read only to perform products by M.

Moreover, the matrix can be accessed only via a function that performs the product, e.g., the matrix entries could even be computed on the fly as needed.

@andportnoy
Contributor

Got it, thank you.

@andportnoy
Contributor

I have randomized SVD now working along with ARPACK, though the code is still at the proof of concept stage.

Apparently __matmul__ is not supported when a LinearOperator is on the RHS. Meanwhile A @ B, where A is a NumPy array and B is a LinearOperator, is used in the randomized SVD implementation (specifically the part where X is projected onto the small random space with basis given by the orthonormal Q).

As a workaround I used $AB = (B^TA^T)^T$ to force B to the LHS.
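
A minimal sketch of that workaround (standalone; aslinearoperator and the shapes below are used only to build a small demo operator):

import numpy as np
from scipy.sparse.linalg import aslinearoperator

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))                 # plain NumPy array on the left
B_dense = rng.standard_normal((3, 4))
B = aslinearoperator(B_dense)                   # LinearOperator on the right

# A @ B is what the randomized SVD code needs, but NumPy cannot coerce the
# LinearOperator on the RHS. Flipping it to the LHS via (B^T A^T)^T gives
# the same product:
product = (B.T @ A.T).T
print(np.allclose(product, A @ B_dense))        # True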

I tried to debug using a debug build of Python but I've only figured out that the object corresponding to the LinearOperator has its nd (number of dimensions) struct field equal to 0 when the object is interpreted as a NumPy array, and that fails a check in the NumPy __matmul__.

The code is in #24415.

@lobpcg
Contributor

lobpcg commented Oct 8, 2022

@andportnoy Yes, I am sure that LinearOperator must be on the LHS. Your workaround looks reasonable.

If you start using LOBPCG in addition to ARPACK, make sure to also specify matmat, not just matvec, in the definition of the LinearOperator.

Please file a bug report if you believe that this is a bug:
"LinearOperator has its nd (number of dimensions) struct field equal to 0 when the object is interpreted as a NumPy array, and that fails a check in the NumPy matmul."

@andportnoy
Contributor

Got it, thanks. Going to work on LOBPCG next (and PROPACK as well).

I'll spend some more time trying to root cause the error (why isn't this struct field populated? is that expected? can it be fixed?). Will file a bug report in SciPy.

@andportnoy
Contributor

andportnoy commented Oct 22, 2022

I've looked at the error; I don't know if it should be considered a bug. At a high level, when a ufunc is called, NumPy runs convert_ufunc_arguments, which in turn uses a bunch of helpers to examine the inputs and convert them to PyArrayObjects. The helpers end up concluding that the LinearOperator has zero dimensions, which breaks matmul's expectations with respect to the shapes of its inputs.

The easiest way to reproduce is np.array([]) @ None where None similarly ends up as a PyArrayObject with .nd=0.

NumPy very aggressively converts all kinds of objects to arrays, which can be useful but in this case it's slightly confusing. I would have preferred an error message along the lines of "cannot coerce a LinearOperator to a NumPy array".

Relevant stack traces:

max_ndim set to 0
#0  update_shape (curr_ndim=0, max_ndim=0x7fffffffc3b0, out_shape=0x7fffffffc580, new_ndim=0, 
    new_shape=0x0, sequence=0 '\000', flags=0x7fffffffc494)
    at numpy/core/src/multiarray/array_coercion.c:572
#1  0x00007fffe98ec220 in handle_scalar (
    obj=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, curr_dims=0, max_dims=0x7fffffffc3b0, out_descr=0x7fffffffc530, 
    out_shape=0x7fffffffc580, fixed_DType=0x0, flags=0x7fffffffc494, DType=0x0)
    at numpy/core/src/multiarray/array_coercion.c:738
#2  0x00007fffe98ed073 in PyArray_DiscoverDTypeAndShape_Recursive (
    obj=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, curr_dims=0, max_dims=32, out_descr=0x7fffffffc530, 
    out_shape=0x7fffffffc580, coercion_cache_tail_ptr=0x7fffffffc470, fixed_DType=0x0, 
    flags=0x7fffffffc494, never_copy=0) at numpy/core/src/multiarray/array_coercion.c:1133
#3  0x00007fffe98ed57e in PyArray_DiscoverDTypeAndShape (
    obj=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, max_dims=32, out_shape=0x7fffffffc580, coercion_cache=0x7fffffffc538, 
    fixed_DType=0x0, requested_descr=0x0, out_descr=0x7fffffffc530, never_copy=0)
    at numpy/core/src/multiarray/array_coercion.c:1289
#4  0x00007fffe9913346 in PyArray_FromAny (
    op=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, newtype=0x0, min_depth=0, max_depth=0, flags=0, context=0x0)
    at numpy/core/src/multiarray/ctors.c:1629
#5  0x00007fffe9ae0a92 in convert_ufunc_arguments (ufunc=0x7fffe6c58c00, full_args=..., 
    out_op=0x7fffffffc9b0, out_op_DTypes=0x7fffffffcab0, force_legacy_promotion=0x7fffffffc7cb "", 
    allow_legacy_promotion=0x7fffffffc7cc "\001", promoting_pyscalars=0x7fffffffc7cd "", 
    order_obj=0x0, out_order=0x7fffffffc7d0, casting_obj=0x0, out_casting=0x7fffffffc7d4, 
    subok_obj=0x0, out_subok=0x7fffffffc7ca "\001", where_obj=0x0, out_wheremask=0x7fffffffc7f8, 
    keepdims_obj=0x0, out_keepdims=0x7fffffffc7d8) at numpy/core/src/umath/ufunc_object.c:977
#6  0x00007fffe9aea9ed in ufunc_generic_fastcall (ufunc=0x7fffe6c58c00, args=0x7fffffffcef0, 
    len_args=2, kwnames=0x0, outer=0 '\000') at numpy/core/src/umath/ufunc_object.c:4896
#7  0x00007fffe9aeae8e in ufunc_generic_vectorcall (ufunc=<numpy.ufunc at remote 0x7fffe6c58c00>, 
    args=0x7fffffffcef0, len_args=2, kwnames=0x0) at numpy/core/src/umath/ufunc_object.c:5011
#8  0x00007ffff7b7a1cc in _PyObject_VectorcallTstate (tstate=0x555555577360, 
    callable=<numpy.ufunc at remote 0x7fffe6c58c00>, args=0x7fffffffcef0, nargsf=2, kwnames=0x0)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Include/cpython/abstract.h:114
#9  0x00007ffff7b7c5c8 in object_vacall (tstate=0x555555577360, base=0x0, 
    callable=<numpy.ufunc at remote 0x7fffe6c58c00>, vargs=0x7fffffffcf50)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Objects/call.c:734
#10 0x00007ffff7b7cba4 in PyObject_CallFunctionObjArgs (
    callable=<numpy.ufunc at remote 0x7fffe6c58c00>)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Objects/call.c:841
#11 0x00007fffe9a1c435 in PyArray_GenericBinaryFunction (
    m1=<numpy.ndarray at remote 0x7fff95382ec0>, 
    m2=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, op=<numpy.ufunc at remote 0x7fffe6c58c00>)
    at numpy/core/src/multiarray/number.c:270
#12 0x00007fffe9a1c934 in array_matrix_multiply (m1=<numpy.ndarray at remote 0x7fff95382ec0>, 
    m2=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>) at numpy/core/src/multiarray/number.c:347
#13 0x00007ffff7b4bb1f in binary_op1 (v=<numpy.ndarray at remote 0x7fff95382ec0>, 
    w=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, op_slot=272, op_name=0x7ffff7e3877e "@")
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Objects/abstract.c:891
#14 0x00007ffff7b4bcd8 in binary_op (v=<numpy.ndarray at remote 0x7fff95382ec0>, 
    w=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>, op_slot=272, op_name=0x7ffff7e3877e "@")
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Objects/abstract.c:930
#15 0x00007ffff7b4c65a in PyNumber_MatrixMultiply (v=<numpy.ndarray at remote 0x7fff95382ec0>, 
    w=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7fffea083b60>) at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Objects/abstract.c:1128
#16 0x00007ffff7cea99b in _PyEval_EvalFrameDefault (tstate=0x555555577360, 
    f=Frame 0x5555555d8f50, for file /home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py, line 19, in <module> (), throwflag=0)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/ceval.c:2015
#17 0x00007ffff7ce559f in _PyEval_EvalFrame (tstate=0x555555577360, 
    f=Frame 0x5555555d8f50, for file /home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py, line 19, in <module> (), throwflag=0)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Include/internal/pycore_ceval.h:46
#18 0x00007ffff7cfdd52 in _PyEval_Vector (tstate=0x555555577360, con=0x7fffffffe010, 
    locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), args=0x0, argcount=0, kwnames=0x0)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/ceval.c:5065
#19 0x00007ffff7ce86ee in PyEval_EvalCode (co=<code at remote 0x7fffea043780>, 
    globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), 
    locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated))
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/ceval.c:1134
#20 0x00007ffff7d69719 in run_eval_code_obj (tstate=0x555555577360, co=0x7fffea043780, 
    globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), 
    locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated))
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/pythonrun.c:1291
#21 0x00007ffff7d6981b in run_mod (mod=0x5555555fd440, 
    filename='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', 
    globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), 
    locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), flags=0x7fffffffe268, arena=0x7fffea0726e0)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/pythonrun.c:1312
#22 0x00007ffff7d694ad in pyrun_file (fp=0x555555588cf0, 
    filename='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', start=257, 
    globals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), 
    locals={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py') at remote 0x7fffea026580>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7fffea15c290>, '__file__': '/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', '__cached__': None, 'np': <module at remote 0x7fffea0987d0>, 'LinearOperator': <type at remote 0x555555a8f710>, 'linear_operator_from_matrix': <function at remote 0x7fffea188940>, 'A': <numpy.ndarray at remote 0x7fff95382ec0>, 'B': <_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a52e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a5230>) at remote 0x7ff...(truncated), closeit=1, flags=0x7fffffffe268)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/pythonrun.c:1208
#23 0x00007ffff7d678da in _PyRun_SimpleFileObject (fp=0x555555588cf0, 
    filename='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', closeit=1, 
    flags=0x7fffffffe268) at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/pythonrun.c:456
#24 0x00007ffff7d66cee in _PyRun_AnyFileObject (fp=0x555555588cf0, 
    filename='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', closeit=1, 
    flags=0x7fffffffe268) at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Python/pythonrun.c:90
#25 0x00007ffff7d98d97 in pymain_run_file_obj (
    program_name='/home/andrey/Projects/scikit-learn/dev-env-debug/bin/python', 
    filename='/home/andrey/Projects/scikit-learn/linear-operator-matmul-fail.py', 
    skip_source_first_line=0) at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:353
#26 0x00007ffff7d98e71 in pymain_run_file (config=0x55555555b680)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:372
#27 0x00007ffff7d9958d in pymain_run_python (exitcode=0x7fffffffe3c4)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:587
#28 0x00007ffff7d996c5 in Py_RunMain ()
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:666
#29 0x00007ffff7d99783 in pymain_main (args=0x7fffffffe440)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:696
#30 0x00007ffff7d9984b in Py_BytesMain (argc=2, argv=0x7fffffffe5a8)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Modules/main.c:720
#31 0x000055555555517d in main (argc=2, argv=0x7fffffffe5a8)
    at /usr/src/debug/python3.10-3.10.7-1.fc36.x86_64/Programs/python.c:15
PyArrayObject created (partial trace)
#0  PyArray_NewFromDescr_int (subtype=0x7fffe9d277a0 <PyArray_Type>, 
    descr=0x7fffe9d2a260 <OBJECT_Descr>, nd=0, dims=0x7fffffffc580, strides=0x0, data=0x0, 
    flags=0, obj=0x0, base=0x0, zeroed=0, allow_emptystring=0)
    at numpy/core/src/multiarray/ctors.c:693
#1  0x00007fffe9911cec in PyArray_NewFromDescrAndBase (subtype=0x7fffe9d277a0 <PyArray_Type>, 
    descr=0x7fffe9d2a260 <OBJECT_Descr>, nd=0, dims=0x7fffffffc580, strides=0x0, data=0x0, 
    flags=0, obj=0x0, base=0x0) at numpy/core/src/multiarray/ctors.c:958
#2  0x00007fffe9911c92 in PyArray_NewFromDescr (subtype=0x7fffe9d277a0 <PyArray_Type>, 
    descr=0x7fffe9d2a260 <OBJECT_Descr>, nd=0, dims=0x7fffffffc580, strides=0x0, data=0x0, 
    flags=0, obj=0x0) at numpy/core/src/multiarray/ctors.c:943
#3  0x00007fffe99139d9 in PyArray_FromAny (
    op=<_CustomLinearOperator(dtype=<numpy.dtype[float64] at remote 0x7fffe9d29900>, shape=(3, 4), args=(), _CustomLinearOperator__matvec_impl=<function at remote 0x7fffe5ea6af0>, _CustomLinearOperator__rmatvec_impl=<function at remote 0x7fffe5d322b0>, _CustomLinearOperator__rmatmat_impl=<function at remote 0x7fff953a92e0>, _CustomLinearOperator__matmat_impl=<function at remote 0x7fff953a9230>) at remote 0x7fffea083b60>, newtype=0x0, min_depth=0, max_depth=0, flags=0, context=0x0)
    at numpy/core/src/multiarray/ctors.c:1805
#4  0x00007fffe9ae0a92 in convert_ufunc_arguments (ufunc=0x7fffe6c58c00, full_args=..., 
    out_op=0x7fffffffc9b0, out_op_DTypes=0x7fffffffcab0, force_legacy_promotion=0x7fffffffc7cb "", 
    allow_legacy_promotion=0x7fffffffc7cc "\001", promoting_pyscalars=0x7fffffffc7cd "", 
    order_obj=0x0, out_order=0x7fffffffc7d0, casting_obj=0x0, out_casting=0x7fffffffc7d4, 
    subok_obj=0x0, out_subok=0x7fffffffc7ca "\001", where_obj=0x0, out_wheremask=0x7fffffffc7f8, 
    keepdims_obj=0x0, out_keepdims=0x7fffffffc7d8) at numpy/core/src/umath/ufunc_object.c:977
#5  0x00007fffe9aea9ed in ufunc_generic_fastcall (ufunc=0x7fffe6c58c00, args=0x7fffffffcef0, 
    len_args=2, kwnames=0x0, outer=0 '\000') at numpy/core/src/umath/ufunc_object.c:4896
#6  0x00007fffe9aeae8e in ufunc_generic_vectorcall (ufunc=<numpy.ufunc at remote 0x7fffe6c58c00>, 
    args=0x7fffffffcef0, len_args=2, kwnames=0x0) at numpy/core/src/umath/ufunc_object.c:5011

@andportnoy
Contributor

Does anyone have any general tips on debugging numerical issues? In my tests the outputs are "mostly" correct, but a few percent of the elements are off. I'm going to study the distribution of these errors a little bit and try to visualize what's happening.

@lobpcg
Contributor

lobpcg commented Nov 26, 2022

Does anyone have any general tips on debugging numerical issues? In my tests the outputs are "mostly" correct, but a few percent of the elements are off. I'm going to study the distribution of these errors a little bit and try to visualize what's happening.

It's generally difficult. Debugging commonly requires math knowledge of numerical stability of the values examined and properties of algorithms involved. Start with everything in double precision and test in multiple environments with different algorithms to determine if inaccuracies are environment or algorithm dependent.

@andportnoy
Contributor

Does anyone have any general tips on debugging numerical issues? In my tests the outputs are "mostly" correct, but a few percent of the elements are off. I'm going to study the distribution of these errors a little bit and try to visualize what's happening.

It's generally difficult. Debugging commonly requires math knowledge of numerical stability of the values examined and properties of algorithms involved. Start with everything in double precision and test in multiple environments with different algorithms to determine if inaccuracies are environment or algorithm dependent.

Got it. What do you mean by an environment?

@lobpcg
Contributor

lobpcg commented Nov 26, 2022

Environment = different numerical compilers and libraries, e.g. as provided by sklearn's built-in PR testing.

@andportnoy
Contributor

I see, thank you. Working on that now.

@andportnoy
Contributor

Summary

I added LOBPCG support and tested locally on all 100 global random seeds for a total of 18,400 test cases; 14,128 are failing, and all failures are due to partial numerical mismatches. LOBPCG did pretty well, better than the default full solver. The randomized solver needs the most attention. I haven't added PROPACK yet. All tests were run using 64-bit floating point. Going to continue testing in other environments.

Test matrix
PCA_SOLVERS = ["full", "arpack", "randomized", "auto", "lobpcg"]

SPARSE_M, SPARSE_N = 400, 300  # arbitrary
SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N)

@pytest.mark.parametrize("density", [0.01, 0.05, 0.10, 0.30])
@pytest.mark.parametrize("n_components", [1, 2, 3, 10, SPARSE_MAX_COMPONENTS])
@pytest.mark.parametrize("format", ["csr", "csc"])
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)

Pass rate by test parameter

Fraction of the tests that passed, broken down by test matrix parameter values.

Plot

test-pca-sparse-pass-rate

Elementwise mismatch rate by test parameter

For the tests that failed, the fraction of the elements that didn't pass the default assert_allclose tolerance tests.

Plot

test-pca-sparse-mismatch-rate

@lobpcg
Contributor

lobpcg commented Dec 4, 2022

Please have in mind that lobpcg internally calls the dense solver eigh if n_components > 5*min(N, M)

@andportnoy
Contributor

Please have in mind that lobpcg internally calls the dense solver eigh if n_components > 5*min(N, M)

Got it, thank you. I believe the test that I'm running shouldn't trigger the dense solver since I always have n_components <= min(N, M) (< for LOBPCG).

@andportnoy
Contributor

Given that most solvers have very few mismatched elements, I added rtol (relative tolerance) as a parameter to the test and ran the same test matrix crossed with rtol=[1e-07, 1e-06, ..., 1e-00]. Not sure in general what tolerance is appropriate. The results above used the default rtol=1e-07, atol=0 for 64 bit fp.

I also triggered CI runs across the global random seed range (without the tolerance parameter) and now presumably have data from x86 and ARM Linux runs.

Hope to visualize the data for both of the above next week. Will also need to compare intermediate outputs for the randomized solver since it has ~80% mismatches.

Depending on how the tolerance study turns out, I will try to debug the numerical mismatches for the solvers other than the randomized, which definitely needs attention.

@andportnoy
Contributor

andportnoy commented Dec 17, 2022

Relative tolerance study

I varied the relative tolerance used in elementwise comparisons of the components vectors (corresponding to $V^T$ in $A = U\Sigma V^T$) coming from PCA on sparse data with implicit centering against those coming from dense PCA on the same data. The range that I tested is $\{10^0, 10^{-1}, ..., 10^{-7}\}$ (8 positions total), where $10^{-7}$ is the default value. The absolute tolerance was set to 0 (default), and the final comparison in np.testing.assert_allclose is against rtol * abs(desired) + atol (see the NumPy docs).

Below I'm showing plots of the resulting test pass rates and elementwise mismatch rates (of the failed tests) for each solver.

Rate values are shown up to 2 significant digits. 0 and 1 are true 0 and 1.

Tentative conclusion

For my personal practical purposes, LOBPCG passes the correctness test on this small data set (300x400). Don't know if those tolerances/error rates are acceptable in scikit-learn.

ARPACK seems to always have a small number of elements that are significantly far away from the desired value; that could also be acceptable. Not sure yet how to debug.

Randomized is performing very badly, even when tolerance is relaxed. I'll try to compare intermediate outputs to see where the divergence comes from.

The full solver is not interesting for practical purposes.

Pass rate by solver

Fraction of the tests that passed for each solver, broken down by relative tolerance level.

Plot

passrate-tolerance-study-2022-12-10_12-08-40

Elementwise mismatch rate by solver

For the tests that failed, the fraction of the elements that didn't pass the comparison.

Plot

mismatch-tolerance-study-2022-12-10_12-15-17

@andportnoy
Contributor

I'm still trying to get a large number of tests to run across all environments on CI, I'll post an analysis of that later. It will likely only include the default tolerance level since the full range timed out.

@lobpcg
Contributor

lobpcg commented Dec 17, 2022

@andportnoy your test results look plausible, except for the full solver, which should be expected to be 100% accurate.

@andportnoy
Contributor

andportnoy commented Dec 17, 2022

@lobpcg Agreed, it's a little suspicious. I'll try to root cause.

In each case I'm essentially comparing the SVD of the densified centered data to the SVD of the implicitly centered sparse data, using the same solver. Maybe I should instead use the outputs of a single, highest-accuracy dense solver as the only reference. Or carry out the SVD using an arbitrary-precision library and convert the results to the nearest float64 values.

What do you think?

@lobpcg
Contributor

lobpcg commented Dec 17, 2022

You should probably use the full float64 solver as the baseline, after figuring out why it's not 100% accurate.

@andportnoy
Contributor

Getting back to this. Short term goals:

  1. figure out why the full solver fails most test cases,
  2. rerun the tests using the full solver outputs as reference.

@andportnoy
Contributor

andportnoy commented Feb 24, 2023

I discovered that I wasn't using the density parameter (1-sparsity) at all, so it defaulted to 0.01 everywhere (instead of the [0.01, 0.05, 0.1, 0.3] range that I expected). So the results in #12794 (comment) and #12794 (comment) are incorrect (correct only for 0.01 density).

@andportnoy
Contributor

Plots for #12794 (comment), redone, this time actually using the density parameter properly. Note the pass rate is much better for higher densities.

Overall pass rate

pca-sparse-passrate-600f2f141-20230228-141530

Mismatch rate of failed tests

pca-sparse-mismatch-600f2f141-20230228-141004

The plots can be reproduced by running

bash pca-sparse-debug/scripts/passrate-main.sh
bash pca-sparse-debug/scripts/mismatch-main.sh

on the branch (commit 600f2f1).

@andportnoy
Contributor

Plots for #12794 (comment) redone.

Pass rate by solver

pca-sparse-passrate-tolerance-600f2f141-20230228-142509

Elementwise mismatch rate by solver

pca-sparse-mismatch-tolerance-600f2f141-20230228-143349

@andportnoy
Contributor

Now need to get back to figuring out why the full solver has failures.

@andportnoy
Contributor

Full solver debug

Summary

                          19.03% failure rate
                                  |  use same averaging algorithm
                                  V
                          9.65% failure rate
                                  |  allow atol=1e-12
                                  V
                          2.5% failure rate
     increase dimensions  |               |  increase dimensions
from 400x300 to 400x3000  |               |  from 400x300 to 4000x300
                          V               V
                10% failure rate    0% failure rate
                                    (all tests pass)

Details

As is, 761/4000 (19.03%) of the full solver tests fail. These tests compare the following two outputs:

A. densify the sparse matrix first, pass the result to the solver
B. pass the sparse matrix to the solver as is.

I've found some issues that are increasing the failure rate:

  1. In both cases A and B the following is run on the input:
    self.mean_ = np.mean(X, axis=0)
    X -= self.mean_
    When X is a sparse matrix, however, np.mean dispatches to a different algorithm than when X is dense. This results in different mean values being computed and contributes to the mismatches.
    After the code is replaced with
    if issparse(X):
        self.mean_ = X.sum(axis=0) / X.shape[0]
    else:
        self.mean_ = np.mean(X, axis=0)
    the failure rate goes down to 386/4000 (9.65%).
  2. atol (absolute tolerance) was set to 0, causing mismatches when the expected element value is 0 and the observed value is not exactly equal to 0, however small. Allowing atol=1e-12 brings the failure rate down to 100/4000 (2.5%).
  3. The size and shape of the problem plays a significant role. Increasing the number of rows tenfold (400x300 -> 4000x300) at this point results in all tests passing. Increasing the number of columns tenfold (400x300 -> 400x3000), however, brings the failure rate up to 400/4000 (10%).

I don't know what role the distribution of our (randomly generated) inputs plays with respect to test outcomes.

@lobpcg
Contributor

lobpcg commented Apr 24, 2023

@andportnoy it appears that it might help your debugging if you run the extreme cases - examples with 1 or 2 rows or columns, where the svd part becomes trivial/exact so you would mostly check the operation of the mean.

Your randomly generated input, I guess, is expected to always have an approximately constant (e.g., zero) mean, depending on the distribution. To change this, you may want to add random shifts to randomly generated columns/rows.

@andportnoy
Contributor

One issue with the distribution here is that the entries are i.i.d., so the covariance matrix has a constant repeated on the diagonal and all off-diagonal entries zero… Would this be an "ill-conditioned" input to SVD, especially if the matrix is underdetermined?
