torch.lobpcg always breaks for autograd #38948
Comments
@pearu, could you please take a look at this?

The LOBPCG algorithm is iterative, and indeed, the …

For reference: #32531 - gradient of …
The paper https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf may be relevant here. @nikitaved, what do you think: is the problem of backward solvable with the restricted set of eigenpairs?
It is possible to derive the gradient for k-rank lobpcg, but it is not going to be unique because it will be a solution to an overdetermined system of linear equations. I will put my derivations up a bit later so that you can tell whether you agree with them...
Ah, sorry, the system is underdetermined, here it goes...
Let AU = UD, with U of shape (n, k) and U^T U = I, be the relation that the lobpcg output satisfies. Differentiating it, we can split the perturbation identity into the parts along U and orthogonal to U. As you can see, we have far fewer equations than the n^2 unknowns of the gradient wrt A. The non-uniqueness comes from the system being underdetermined...
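To make the counting explicit, here is one way to spell the argument out (the notation below is mine; the original equations are not shown above):

$$AU = UD, \qquad U^\top U = I_k, \qquad dA\,U + A\,dU = dU\,D + U\,dD.$$

In the backward pass this is read in reverse: given the cotangents $\bar U \in \mathbb{R}^{n\times k}$ and $\bar D \in \mathbb{R}^{k\times k}$, one seeks $\bar A \in \mathbb{R}^{n\times n}$ with $\langle \bar A, dA\rangle = \langle \bar U, dU\rangle + \langle \bar D, dD\rangle$ for every perturbation $dA$. Since $dU$ and $dD$ see $dA$ only through the $n \times k$ product $dA\,U$, the $n^2$ entries of $\bar A$ are constrained by at most $nk + k$ conditions, hence the system is underdetermined for $k < n$.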
But I guess we might run into an issue when both input matrices for lobpcg require gradients and they are not some function of each other...

But in this case, if you consider the rank-k problem, …
Take any k eigenvectors of A as U and fill the diagonal of D with the corresponding eigenvalues, there you go...
@albanD, I am not sure I am following you. We have U of shape (n, k) and diagonal D of shape (k, k), so that AU = UD holds. Replacing …
Yeah, exactly, lobpcg is a function A -> U, D such that AU = UD, in the case of the basic eigenvalue problem. In the backward we receive dU, dD and propagate them to dA through that equality.
This is not actually true, as lobpcg works only when …

This method does seem to produce … UPDATED: …
I am not sure I understand... In X only a submatrix of size (k, n) is non-zero; the rest is zero.
@pearu My concern is: how can the k top eigenvalues be exact if A has rank > k? …
@albanD, take k eigenvectors/eigenvalues of A, so that we have Au_i = lambda_i u_i. Now let U = (u_1, ..., u_k) and D = diag(lambda_1, ..., lambda_k). Then you have AU = UD.
Oh right, this system has only k columns! Thanks!
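A quick numerical check of that statement (this snippet is illustrative and uses torch.linalg.eigh rather than lobpcg):

```python
import torch

torch.manual_seed(0)
n, k = 6, 3
M = torch.randn(n, n, dtype=torch.float64)
A = (M + M.t()) / 2  # symmetric, full rank in general

w, V = torch.linalg.eigh(A)
U = V[:, :k]           # any k eigenvectors of A
D = torch.diag(w[:k])  # the corresponding eigenvalues on the diagonal

print(torch.allclose(A @ U, U @ D))  # True: AU = UD has only k columns
```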
The input to … before calling …
In that case you get an error on the order of 1e-3 (with the same parameters, n=10, k=3), which is a worst case, so OK I think.
@nikitaved, it seems to be ok (I had a little bug before). Here is a quick example of your method:

```
import torch as T

class LOBPCG2(T.autograd.Function):
    ...  # forward/backward implementation elided

T.random.manual_seed(123)
A = T.randn(6, 6)
S = A @ A.t()  # assumed: a symmetric matrix built from A
e, v = LOBPCG2.apply(S)
S_hat = T.einsum('ij,j,kj->ik', v, e, v)  # v * diag(e) * v^T
```
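For a self-contained picture, here is a minimal sketch of such a wrapper. It is not the code from this thread: its backward only propagates the eigenvalue gradients (d lambda_i / dA = u_i u_i^T for a symmetric A with simple eigenvalues) and leaves the eigenvector part to the k-rank machinery discussed above.

```python
import torch

class LOBPCG2(torch.autograd.Function):
    # Sketch only: eigenvector gradients (grad_v) are ignored, so this is
    # valid only for losses that depend on the eigenvalues alone.
    @staticmethod
    def forward(ctx, A, k):
        e, v = torch.lobpcg(A, k=k)
        ctx.save_for_backward(v)
        return e, v

    @staticmethod
    def backward(ctx, grad_e, grad_v):
        (v,) = ctx.saved_tensors
        # sum_i grad_e[i] * u_i u_i^T
        grad_A = torch.einsum('i,ji,ki->jk', grad_e, v, v)
        return grad_A, None  # no gradient for the integer k

torch.manual_seed(123)
M = torch.randn(10, 10, dtype=torch.float64)
S = (M @ M.t()).requires_grad_(True)  # SPD input
e, v = LOBPCG2.apply(S, 3)
e.sum().backward()  # S.grad is (approximately) sum_i u_i u_i^T
```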
@pearu, so, does it work? :)

Yes, the …

You can compare it for the case …
But it is analytic. I am sure there must be some linear equation solver somewhere in PyTorch...

We do have a …
Interestingly, the Jacobian wrt the eigenvalues given by finite differences doesn't match, because it generates a "dense" gradient while your method gives gradients only for the first two columns of A :/
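For the symmetric case with well-separated eigenvalues the two should agree; here is a hedged sketch of such a comparison (my own illustration, not the script used above):

```python
import torch

torch.manual_seed(0)
n = 6
M = torch.randn(n, n, dtype=torch.float64)
S = (M + M.t()) / 2  # symmetric test matrix

w, V = torch.linalg.eigh(S)
i = n - 1                                   # top eigenvalue
G_analytic = torch.outer(V[:, i], V[:, i])  # d lambda_i / dS = u_i u_i^T

# central finite differences, entry by entry, keeping S symmetric
eps = 1e-6
G_fd = torch.zeros_like(S)
for r in range(n):
    for c in range(n):
        E = torch.zeros_like(S)
        E[r, c] = eps
        Es = (E + E.t()) / 2
        wp = torch.linalg.eigvalsh(S + Es)[i]
        wm = torch.linalg.eigvalsh(S - Es)[i]
        G_fd[r, c] = (wp - wm) / (2 * eps)

print((G_fd - G_analytic).abs().max())  # tiny for well-separated eigenvalues
```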
If I understand the above correctly, the backward computation is algorithm agnostic, i.e. it comes from perturbation theory of eigenvalues and eigenvectors that are assumed to be computed exactly. LOBPCG only computes them approximately, so the results of the backward computation are going to match only quite approximately. E.g., in an extreme case where LOBPCG performs just one iteration, the results of the forward and backward computations are unlikely to match each other well. This is very different from what is typically meant. Please confirm or correct me.
yes
yes, that is all correct, however, I don't see problems with the above in the sense that if a user asks for less accuracy from lobpcg, s/he cannot expect forward/backward to be exact either.
Thanks for the confirmation! I hope that this distinction is made clear to avoid confusion. There are simple iterative methods where the forward/backward pair can be implemented quite accurately, for a fixed initial approximation.
@albanD, the …

You can do anything you want :D
@nikitaved Since the proposed backward algorithm is generic and thus unrelated to lobpcg, its interface should not mimic that of lobpcg.
Nothing I can do about it: the forward replicates the interface of lobpcg, and the backward requires grads for each input of the forward. I could probably use some old … Maybe it actually makes sense to create a method with a name something like geigk (generalized eigenvalue problem of rank k), exactly the …
Nope, most of the code has been removed. This doesn't work anymore.

Does such a generalized solver exist in other Python libs like scipy? If so, we can try to replicate their API.
Yes, scipy's …
Well, the formula above for the symmetric case, as autograd shows, is incorrect. Upon closer look I realized I cannot justify the substitution … There is an alternative derivation below, in which I am stuck at one point.

Here goes the critical part: …

All problems disappear, of course, once we know the whole eigenspace of A, because then the Sylvester equation can be written as B = dX ∘ C for some matrices B and C.
The Sylvester equation from above can be written as … This paper states that the equation from above can be solved explicitly as … Since …
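For reference, the standard explicit form this is pointing at (notation assumed here): with $\Lambda = \operatorname{diag}(\lambda_1,\dots,\lambda_n)$ and distinct eigenvalues,

$$(\Lambda\,dX - dX\,\Lambda)_{ij} = (\lambda_i - \lambda_j)\,dX_{ij},$$

so with $C_{ij} = \lambda_i - \lambda_j$ the equation $\Lambda\,dX - dX\,\Lambda = B$ is exactly the Hadamard form $B = dX \circ C$ mentioned above, solved entrywise off the diagonal by $dX_{ij} = B_{ij}/(\lambda_i - \lambda_j)$.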
…ward [only k-rank SYMEIG case] (#43002) Summary: As per title. Fixes #38948. Therein you can find some blueprints for the algorithm used in this PR. Pull Request resolved: #43002 Reviewed By: zou3519 Differential Revision: D23931326 Pulled By: albanD fbshipit-source-id: e6994af70d94145f974ef87aa5cea166d6deff1e
A PR that implements support for the symmetric case with …

There could be some numerical issues popping up once …
Hi @nikitaved, I've implemented the backward for symmetric A & B in xitorch for your reference. I have to admit it was not an easy task. I can write down the math derivation if you want.
@mfkasim1, hi! The symmetric case is already in PyTorch. The case for B requires the matrix square root; there is an issue I am assigned to. Maybe it is possible to avoid it. I am about to make a write-up. We could compare the math or become co-authors, up to you! You can check the code in torch/_lobpcg.py. Looks like you propagate through the whole eigenspace, right? Here the issue is that you only have a k-subspace, where k is generally much less than the dimension n.
@nikitaved Ah, I thought you hadn't implemented it for non-identity B.

The write-up sounds interesting. Where do you plan to write it?
I see, thank you! Then you are right! B is only identity as of now. Would be cool to see alternatives which involve different methods. Well, at least arXiv would be great for starters...
The math trick avoiding the square root of B is to notice that inv(B)*A is symmetric in the B-based scalar product, i.e. x'Bx. So one can apply whatever standard theory or algorithms for generalized eigenvalue problems, as soon as B is SPD (our case). After the substitution, in many places one gets inv(B)*B, so it just goes away. But in a few key spots one still needs to implement vector multiplication by inv(B)*A; that is where a linear solve for B is needed.
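A quick check of that claim (notation assumed): write $\langle x, y\rangle_B := x^\top B y$ for SPD $B$ and symmetric $A$. Then

$$\langle x, B^{-1}A\,y\rangle_B = x^\top B\,B^{-1}A\,y = x^\top A\,y = (Ax)^\top y = (B\,B^{-1}A\,x)^\top y = \langle B^{-1}A\,x,\; y\rangle_B,$$

so $B^{-1}A$ is self-adjoint in the $B$-inner product, and $B^{-1}$ only ever enters through linear solves with $B$.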
Upon further inspection, I can see how to extend the already implemented method to the generalized problem. And, yes, it does not require the matrix square root, although it does become more computationally heavy, as I can no longer exploit simultaneous diagonalizations in the explicit solution to the Sylvester equation (check needed). I will try to compare this solution to the one with the matrix square root. Not sure which one is going to be faster.
@nikitaved for your reference, I have written down the symeig derivative in xitorch: https://xitorch.readthedocs.io/en/latest/notes/deriv_symeig.html

Closing as the example in the description now works (thanks to #43002)
🐛 Bug
It seems that `torch.lobpcg` (https://pytorch.org/docs/stable/torch.html?highlight=lobpcg#torch.lobpcg) just always breaks when trying to take gradients via `backward`.

To Reproduce
Here's a minimalist example showing `lobpcg` breaking (a sketch follows below). Running that code produces an error.
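A minimal repro along these lines (a reconstruction; the exact original snippet may differ in details) exercises the failing path:

```python
import torch

torch.manual_seed(0)
M = torch.randn(10, 10, dtype=torch.float64)
S = (M @ M.t()).requires_grad_(True)  # symmetric positive definite input

e, v = torch.lobpcg(S, k=2)  # top-2 eigenpairs
e.sum().backward()           # on torch 1.5.0 this raised a RuntimeError
```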
I have a feeling that the problem is that `torch.lobpcg`'s implementation is using an in-place operation when it shouldn't be.

This happened when running `torch.__version__ == '1.5.0+cpu'` installed with pip on Windows 10 WSL (Windows Subsystem for Linux) on Python 3.5.2.

Can this be fixed, or is `torch.lobpcg` not meant to support autograd?

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @vincentqb @vishwakftw @jianyuh @mruberry @ssnl