
torch.lobpcg always breaks for autograd #38948

Closed

chausies opened this issue May 23, 2020 · 68 comments

Labels: module: autograd (Related to torch.autograd, and the autograd engine in general); module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@chausies commented May 23, 2020

🐛 Bug

It seems that torch.lobpcg (https://pytorch.org/docs/stable/torch.html?highlight=lobpcg#torch.lobpcg) just always breaks when trying to take gradients via backward.

To Reproduce

Here's a minimal example showing lobpcg breaking.

# lob.py
import torch as T
T.autograd.set_detect_anomaly(True)

A = T.randn(10, 10)
A.requires_grad_()
S = A.matmul(A.t())
e, v = T.lobpcg(S, k=3)
S_hat = T.einsum('ij,j,kj->ik', v, e, v) # v * diag(e) * v^T
loss = S_hat.abs().sum()
loss.backward() # breaks here

Running that code produces the following error.

Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
  File "lob.py", line 9, in <module>
    e, v = T.lobpcg(S, k=3)
  File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 261, in lobpcg
    worker.run()
  File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 408, in run
    self.update()
  File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 343, in update
    self._update_ortho()
  File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 498, in _update_ortho
    self.X[:, nc:] = mm(S_, Z[:, :n - nc])
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:60)
Traceback (most recent call last):
  File "lob.py", line 12, in <module>
    loss.backward() # breaks here
  File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 5]], which is output 0 of SliceBackward, is at version 14; expected version 11 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I have a feeling that the problem is that torch.lobpcg's implementation is using an in-place operation when it shouldn't be.

This happened when running torch.__version__ == '1.5.0+cpu' installed with pip on Windows 10 WSL (Windows Subsystem for Linux) on Python 3.5.2.

Can this be fixed, or is torch.lobpcg not meant to support autograd?

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @vincentqb @vishwakftw @jianyuh @mruberry @ssnl

@ezyang (Contributor) commented May 26, 2020

@pearu, could you please take a look at this?

@ezyang added the module: autograd, module: operators, and triaged labels on May 26, 2020
@pearu (Collaborator) commented May 26, 2020

The LOBPCG algorithm is iterative, and indeed, the torch.lobpcg implementation uses in-place operations which cannot be avoided. However, the current issue can be resolved by implementing autograd support for torch.lobpcg: the same backward algorithm used in torch.eig ought to be a good starting point, though it needs to be modified to account for the k parameter that restricts the number of eigenpairs.

@pearu (Collaborator) commented May 26, 2020

For reference: #32531 - gradient of torch.eig.

@pearu (Collaborator) commented May 27, 2020

The paper https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf derives the backward formula for the standard eigenvalue problem, which involves the inverse of the eigenvector matrix. Since lobpcg provides only the first k eigenvectors, that formula is not applicable here.

@nikitaved what do you think, is the problem of backward solvable with the restricted set of eigenpairs?

@albanD (Collaborator) commented May 27, 2020

The torch.eig function has a similar limitation: if you set eigenvectors=False, you won't be able to compute the backward pass.

@nikitaved (Collaborator) commented May 27, 2020

It is possible to derive the gradient for rank-k lobpcg, but it is not going to be unique because it will be a solution to an overdetermined system of linear equations. I will post my derivations a bit later so that you can tell me whether you agree with them...

@nikitaved (Collaborator) commented May 27, 2020

Ah, sorry, the system is underdetermined, here it goes...
For simplicity, let's consider the basic rank-k eigenvalue problem of the form AU = UD, where A is n by n, rank(U) = rank(D) = k. Then applying the diff operator:

dA U + A dU = dU D + U dD,
U^T dA^T = (dU D + U dD - A dU)^T.

Let K := U^T, X:= dA^T, L:= (dU D + U dD - A dU)^T. We want to find X by solving the following matrix equation:

K X = L, where K,L in (k, n), X in (n, n).

We can split it into n systems of linear equations of the form

K X[i] = L[i].

As you can see, we have k < n equalities with n unknowns, so it is underdetermined.
One possible solution could be this:

Let I = {i_1, ..., i_k} be column indices such that rank(K[I]) = k, then
X[i]_I = K[I]^{-1} L[i] for indices in I, and X[i]_{[n] \ I} = 0.

The non-uniqueness comes from the system being underdetermined...
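
As an illustration, here is a minimal numerical sketch of this construction (not code from the thread; it assumes I = {1, ..., k}, i.e. that the first k columns of K are linearly independent):

```
import torch

torch.manual_seed(0)
n, k = 6, 2

# A synthetic instance of K X = L with K, L of shape (k, n) and unknown X of shape (n, n).
K = torch.randn(k, n, dtype=torch.float64)
L = torch.randn(k, n, dtype=torch.float64)

# Take I = the first k column indices (K[:, :k] is invertible almost surely for random K).
K_I = K[:, :k]
X = torch.zeros(n, n, dtype=torch.float64)
X[:k] = torch.inverse(K_I) @ L     # only the k rows indexed by I are non-zero

print(torch.allclose(K @ X, L))    # the underdetermined system is satisfied
```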

@nikitaved (Collaborator) commented May 27, 2020

But I guess we might run into an issue when both input matrices for lobpcg require gradients and they are not some function of each other...

@albanD (Collaborator) commented May 27, 2020

But in this case, if you consider the rank-k problem, AU = UD is not true in general. There is no rank-k matrix U and diagonal D that verify this. So you can't base your gradient derivation on this equality, right?

@nikitaved (Collaborator) commented May 27, 2020

Take any k eigenvectors of A as U and fill the diagonal of D with the corresponding eigenvalues, there you go...

@pearu (Collaborator) commented May 27, 2020

@albanD I am not sure I am following you. We have

E, V = torch.lobpcg(A, k=k)

so that

torch.mm(A, V) == torch.mm(V, torch.diag(E))

holds.

Replacing k with n will produce the same first k eigenpairs as when using the initial k.

@nikitaved (Collaborator) commented May 27, 2020

Yeah, exactly, lobpcg is a function A -> U, D such that AU = UD, in the case of the basic eigenvalue problem. In the backward we receive dU, dD and propagate them to dA through this equality.

@pearu (Collaborator) commented May 27, 2020

Replacing k with n will produce the same first k eigenpairs as when using the initial k.

This is not actually true as lobpcg works only when 3*k<=n. Sorry for the confusion. But lobpcg will return the first k eigenpairs.

@pearu (Collaborator) commented May 27, 2020

One possible solution could be this:

Let I = {i_1, ..., i_k} column indices, such that rank(K[I]) = k, then
X[i]_I = K[I]^{-1} B[i] for indices in I, and zero otherwise.

This method seems to produce an X such that K X = L holds for only k of the columns, but not for the others (as the corresponding columns in L are non-trivial).

UPDATED: B -> L

@nikitaved (Collaborator) commented May 27, 2020

One possible solution could be this:

Let I = {i_1, ..., i_k} column indices, such that rank(K[I]) = k, then
X[i]_I = K[I]^{-1} B[i] for indices in I, and zero otherwise.

This method seems to produce an X such that K X = B holds for only k of the columns, but not for the others (as the corresponding columns in B are non-trivial).

I am not sure I understand... In X only a submatrix of size (k, n) is non-zero, the rest is zero.
X[i][I] = K[I]^{-1} L[i], and X[i][[n] \ I] = 0. To be more precise, X has k non-zero rows, the rest is zero.

@albanD (Collaborator) commented May 27, 2020

@pearu My concern is how can the k top eigenvalues be exact if A has rank > k:

import torch

A = torch.rand(10, 10) # Most likely rank > 3
k = 3

E, V = torch.lobpcg(A, k=k)

# This will print != 0
print(( torch.mm(A, V) - torch.mm(V, torch.diag(E)) ).abs().max())

@nikitaved (Collaborator) commented May 27, 2020

@albanD, take k eigenvectors/eigenvalues of A, so that we have Au_i = lambda_i u_i. Now let U = (u_1, ..., u_k) and D = diag(lambda_1, ..., lambda_k). Then you have AU = UD.
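
A quick numerical check of this (a sketch, not code from the thread; it uses torch.linalg.eigh for exact eigenpairs instead of lobpcg):

```
import torch

torch.manual_seed(0)
n, k = 6, 2
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T                              # symmetric, so eigh applies

e, v = torch.linalg.eigh(A)              # full spectrum, ascending order
U, D = v[:, -k:], torch.diag(e[-k:])     # pick any k eigenpairs (here the top k)

print(torch.allclose(A @ U, U @ D))      # AU = UD holds with only k columns
```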

@albanD (Collaborator) commented May 27, 2020

Oh right, this system has only k columns! Thanks!

@pearu (Collaborator) commented May 27, 2020

My concern is how can the k top eigenvalues be exact if A has rank > k

The input to lobpcg must be a symmetric positive definite matrix. So try

A = torch.mm(A.transpose(-2, -1), A)

before calling lobpcg.

@albanD (Collaborator) commented May 27, 2020

In that case you get an error on the order of 1e-3 (with the same parameters, n=10, k=3), which is a worst-case measure, so I think that is ok.
But my concern is resolved by seeing that we only have k vectors for which we verify this equality, so it makes sense that k eigenpairs are enough.

@pearu (Collaborator) commented May 27, 2020

I am not sure I understand... In X only a submatrix of size (k, n) is non-zero, the rest is zero.
X[i][I] = K[I]^{-1} L[i], and X[i][[n] \ I] = 0. To be more precise, X has k non-zero rows, the rest is zero.

@nikitaved, it seems to be ok (I had a little bug before). Here is a quick example of your method:

```
import torch as T

class LOBPCG2(T.autograd.Function):

    @staticmethod
    def forward(ctx, A):
        k = 2
        e, v = T.lobpcg(A, k=k)
        res = T.mm(A, v) - T.mm(v, T.diag(e))
        assert (res.abs() < 1e-5).all()
        ctx.save_for_backward(e, v, A)
        return e, v

    @staticmethod
    def backward(ctx, de, dv):
        """
        solve `dA v + A dv = dv diag(e) + v diag(de)` for `dA`
        """
        e, v, A = ctx.saved_tensors

        vt = v.transpose(-2, -1)
        print('vt=', vt)
        print('de=', de)
        print('dv=', dv)
        rhs = (T.mm(dv, T.diag(e)) + T.mm(v, T.diag(de)) - T.mm(A, dv)).transpose(-2, -1)
        print('rhs=', rhs)

        n, k = v.shape
        K = vt[:, :vt.shape[0]]
        print('K.det=', K.det())  # should be > 0
        iK = K.inverse()

        dAt = T.zeros((n, n))
        dAt[:k] = T.mm(iK, rhs)[:k]
        print('dAt=', dAt)
        dA = dAt.transpose(-2, -1)

        res = T.mm(dA, v) + T.mm(A, dv) - T.mm(dv, T.diag(e)) - T.mm(v, T.diag(de))
        print('res=', res)
        return dA


T.random.manual_seed(123)

A = T.randn(6, 6)
S = A.matmul(A.t())
S.requires_grad_()

e, v = LOBPCG2.apply(S)

S_hat = T.einsum('ij,j,kj->ik', v, e, v)  # v * diag(e) * v^T
loss = S_hat.abs().sum()
loss.backward()
```
(the printed `res` is close to zero).

@nikitaved (Collaborator):

@pearu, so, does it work? :)

@pearu (Collaborator) commented May 27, 2020

Yes, the backward call works, but the method requires verification against the torch.symeig case and an understanding of the consequences of the arbitrariness of dA.

@nikitaved (Collaborator) commented May 27, 2020

You can compare it against torch.symeig for the case k = n, for example.

@nikitaved (Collaborator):

But it is analytic. I am sure there must be some linear equation solver somewhere in PyTorch..

@albanD (Collaborator) commented May 27, 2020

We do have a torch.solve.
Also if you want to compare to finite difference, you can use torch.autograd.gradcheck (with double input) with:

S = S.double().detach().requires_grad_(True)

# You need to make sure in your backward that your call to torch.zeros() creates the proper dtype
# by passing `dtype=A.dtype`.
T.autograd.gradcheck(LOBPCG2.apply, S)

Interestingly, the Jacobian wrt the eigenvalues given by the finite difference doesn't match, because it generates a "dense" gradient while your method gives gradients only for the first two columns of A :/
Couldn't see the Jacobian wrt the eigenvectors as the other one fails before :/

@lobpcg commented Aug 12, 2020

If I understand correctly the above, the backward computation is algorithm agnostic, i.e. it comes from perturbation theory of eigenvalues and eigenvectors that are assumed to be computed exactly. LOBPCG only computes approximately, so the results of the backward computation are going to match only approximately. E.g., in an extreme case where LOBPCG performs just one iteration, the results of the forward and backward computations are unlikely to match each other well. This is very different from what is typically meant. Please confirm or correct me.

@pearu (Collaborator) commented Aug 12, 2020

If I understand correctly the above, the backward computation is algorithm agnostic

yes

LOBPCG only computes approximately, so the results of the backward computation are going to match only approximately. E.g., in an extreme case where LOBPCG performs just one iteration, the results of the forward and backward computations are unlikely to match each other well. This is very different from what is typically meant. Please confirm or correct me.

yes, that is all correct. However, I don't see a problem with the above in the sense that if a user asks for less accuracy from lobpcg, s/he cannot expect forward/backward to be exact either.

@lobpcg commented Aug 12, 2020 via email

@nikitaved (Collaborator) commented Aug 12, 2020

@albanD , the lobpcg implementation accepts 14 parameters. Is it possible to implement forward/backward as a class method, and not as a static method, to make forward/backward accept and return only 2 parameters?

@albanD (Collaborator) commented Aug 12, 2020

You can do anything you want :D
You can make the whole thing a custom Function that takes all 14 parameters (and returns None gradients for most of them, as they are not differentiable).
Or you can have some pre/post processing that is autodiffed, and only the core is a simpler custom Function with a limited number of arguments for which you implement the forward/backward.
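
As an illustration of the first option, here is a hypothetical sketch (not PyTorch's actual implementation; the name LOBPCGFn and the placeholder backward are assumptions): only A is treated as differentiable, the extra options are passed through, and backward returns None for them.

```
import torch

class LOBPCGFn(torch.autograd.Function):
    # Hypothetical wrapper: only A is differentiable; k and niter are plain
    # options, so backward returns None gradients for them.

    @staticmethod
    def forward(ctx, A, k, niter):
        e, v = torch.lobpcg(A, k=k, niter=niter)
        ctx.save_for_backward(A, e, v)
        return e, v

    @staticmethod
    def backward(ctx, de, dv):
        A, e, v = ctx.saved_tensors
        dA = torch.zeros_like(A)  # placeholder: the actual backward formula goes here
        return dA, None, None     # one entry per forward argument; None = not differentiable
```

Usage would then look like `e, v = LOBPCGFn.apply(S, 3, 20)` for an SPD `S`, with only `S` receiving a gradient.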

@lobpcg commented Aug 12, 2020

@nikitaved Since the proposed backward algorithm is generic and thus unrelated to lobpcg, its interface should not mimic that of lobpcg.

@nikitaved (Collaborator) commented Aug 12, 2020

There is nothing I can do about that: the forward replicates the interface of lobpcg, and the backward requires grads for each input of forward. I could probably use some old autograd.Function style with a constructor to save the lobpcg-relevant data and implement forward/backward as for the perfect generalized eigenvalue solver, but the new style, with static methods, is probably better. I could be wrong, however.

Maybe it actually makes sense to create a method with a name something like geigk (the generalized eigenvalue problem of rank k, i.e. exactly the autograd.Function wrapper), then implement backward for it, and then state in the documentation that its forward is implemented with lobpcg. This method could use fewer parameters by ignoring most of the optional ones from lobpcg's interface.

@albanD (Collaborator) commented Aug 12, 2020

I could probably use some old autograd.Function style

Nope, most of that code has been removed. This doesn't work anymore.

This method could use fewer parameters by ignoring most of the optional ones from lobpcg's interface.

Does such a generalized solver exist in another Python lib like scipy? If so, we can try to replicate their API.
If none exists, we do have a bit more freedom here.

@nikitaved (Collaborator) commented Aug 12, 2020

Does such a generalized solver exist in another Python lib like scipy? If so, we can try to replicate their API.
If none exists, we do have a bit more freedom here.

Yes, scipy's eig does support the generalized interface. What about eigk, to distinguish it from the "full-rank" eig?

@lobpcg commented Aug 13, 2020

https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.eigsh.html
since this is the symmetric case.
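
For reference, a small sketch (not code from the thread; it assumes SPD A and B) of the two scipy interfaces linked above, where eigsh's k/M arguments are the closest analogue to torch.lobpcg:

```
import numpy as np
from scipy.linalg import eigh
from scipy.sparse.linalg import eigsh

rng = np.random.default_rng(0)
n, k = 20, 3
A = rng.standard_normal((n, n)); A = A @ A.T + n * np.eye(n)   # SPD
B = rng.standard_normal((n, n)); B = B @ B.T + n * np.eye(n)   # SPD

w_full, v_full = eigh(A, B)                 # dense generalized problem A v = w B v
w_k, v_k = eigsh(A, k=k, M=B, which='LM')   # only k eigenpairs, like lobpcg
```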

@nikitaved (Collaborator) commented Aug 19, 2020

Well, the formula above for the symmetric case is, as autograd shows, incorrect. Upon closer look I realized I cannot justify the substitution dU = U dC, as there is no guarantee that span(dU) is a subspace of span(U).

There is an alternative derivation below in which I am stuck at one point.

U^T U = I => dU^T U + U^T dU = 0 => du := U^T dU is skew-symmetric, so diag(du) = 0.
AU = UD => 
dA U + A dU = dU D + U dD [left-multiply by U^T]               (*)
U^T dA U + U^T A dU = du D + dD [AU = UD => U^T A = D U^T]
U^T dA U + D du = du D + dD
U^T dA U = (du D - D du) + dD [because du is skew symmetric, diag(du D) = diag(D du) = 0]

so we get
dD = I o (U^T dA U), and, similar to the derivation from above (**)
du = F o (U^T dA U), where F_ij = (d_j - d_i)^{-1}, F_ii = 0   (***)

Here goes the critical part

Let U_ortho be an orthonormal basis of a subspace orthogonal to the span(U).
Then we can write
dU = U du + U_ortho dX, where dX is (m-k) x k. This is true for some dX because <U, U_ortho> form a basis of R^m.

Left-multiplying (*) by U_ortho^T and using the decomposition of dU and orthogonality, we get:
U_ortho^T dA U + U_ortho^T A U_ortho dX = dX D. (x)

(x) is a Sylvester equation wrt dX and we need to solve it explicitly, so that dA enters this solution as a part of matrix products. Do you know how to solve it? D is diagonal.

All problems disappear, of course, once we know the whole eigenspace of A, because then the Sylvester equation can be written as B = dX o C for some matrices B and C.
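
The dD formula above can be sanity-checked with finite differences (a sketch, not code from the thread; torch.linalg.eigh serves as the exact reference and distinct eigenvalues are assumed):

```
import torch

torch.manual_seed(0)
n, eps = 5, 1e-6
A = torch.randn(n, n, dtype=torch.float64); A = A + A.T
dA = torch.randn(n, n, dtype=torch.float64); dA = dA + dA.T

e0, U = torch.linalg.eigh(A)
e1, _ = torch.linalg.eigh(A + eps * dA)

fd = (e1 - e0) / eps                  # finite-difference change of the eigenvalues
analytic = torch.diag(U.T @ dA @ U)   # dD = I o (U^T dA U), i.e. u_i^T dA u_i
print(torch.allclose(fd, analytic, atol=1e-4))
```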

@nikitaved (Collaborator) commented Aug 21, 2020

The Sylvester equation from above can be written as

C dX + dX D = E.

The symeig_backward assumes that A has distinct eigenvalues, so, under this assumption, spec(C) intersect spec(D) is empty and the system has a unique solution.

This paper

Hu, Q., & Cheng, D. (2006). 
The polynomial solution to the Sylvester matrix equation.
Applied mathematics letters, 19(9), 859-864.

states that the equation from above can be solved explicitly as:

dX = p_D(C)^{-1} \sum_{i=1}^k b_i \sum_{j=1}^{i-1} C^{i-1-j} E D^j,
where p_D is a characteristic polynomial of D with coefficients b_i.

Since D is diagonal, its characteristic polynomial is simple, and its coefficients can be found via Horner's rule.
Since E = U_ortho^T dA U, we can see that dA enters the explicit solution for dX with power 1 and only as part of matrix products, which is exactly what we need in order to gather the terms multiplied with dA in the backward AD.
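
For intuition, the Sylvester equation with diagonal D can also be solved column by column, since column j decouples into (C + d_j I) dX[:, j] = E[:, j]. A sketch of that approach (not the polynomial method used in the PR; each C + d_j I is assumed invertible):

```
import torch

def solve_sylvester_diag_D(C, d, E):
    # Solve C X + X D = E for X, where D = diag(d); each column decouples.
    n = C.shape[0]
    eye = torch.eye(n, dtype=C.dtype)
    cols = [torch.linalg.solve(C + dj * eye, E[:, j:j + 1]) for j, dj in enumerate(d)]
    return torch.cat(cols, dim=1)

torch.manual_seed(0)
m, k = 6, 2
C = torch.randn(m - k, m - k, dtype=torch.float64)
d = torch.rand(k, dtype=torch.float64) + 2.0       # C + d_j I is invertible almost surely
E = torch.randn(m - k, k, dtype=torch.float64)

dX = solve_sylvester_diag_D(C, d, E)
print(torch.allclose(C @ dX + dX @ torch.diag(d), E))
```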

facebook-github-bot pushed a commit that referenced this issue Sep 28, 2020
…ward [only k-rank SYMEIG case] (#43002)

Summary:
As per title. Fixes [#{38948}](#38948). Therein you can find some blueprints for the algorithm being used in this PR.

Pull Request resolved: #43002

Reviewed By: zou3519

Differential Revision: D23931326

Pulled By: albanD

fbshipit-source-id: e6994af70d94145f974ef87aa5cea166d6deff1e
@nikitaved (Collaborator) commented Sep 28, 2020

A PR that implements support for the symmetric case with B=I has been merged and is available in master. After a short break I plan to implement the remaining two cases, that is

  1. Grad wrt A for arbitrary B when B.requires_grad == False and
  2. Grad wrt A and B when both require grads.

There could be some numerical issues popping up once k is large enough. Any feedback on such behaviors is much appreciated!

@mruberry added the module: linear algebra label and removed the module: operators (deprecated) label on Oct 7, 2020
@mfkasim1 (Contributor):

Hi @nikitaved, I've implemented the backward for symmetric A & B in xitorch for your reference. I have to admit it was not an easy task. I can write down the math derivation if you want.

@nikitaved (Collaborator) commented Oct 31, 2020

@mfkasim1, hi! The symmetric case is already in PyTorch. The case with a general B requires the matrix square root; there is an issue I am assigned to. Maybe it is possible to avoid it. I am about to make a write-up. We could compare the math or become co-authors, up to you! You can check the code in torch/_lobpcg.py.

Looks like you propagate through the whole eigenspace, right? Here the issue is that you only have a k-subspace, where k is generally much less than the dimension n.

@mfkasim1 (Contributor) commented Oct 31, 2020

@nikitaved Ah, I thought you hadn't implemented it for non-identity B.
I'm using implicit matrices/linear operators for A and B in xitorch (you only need to know Av and Bv), so it works for sparse matrices if you're interested in implementing that too. However, it involves iterative solves (e.g. cg, bicgstab, etc.).

Looks like you propagate over the whole eigenspace, right? Here the issue is that you only have a k-subspace, where k is generally much less than the dimension n

The xitorch.linalg.symeig function actually looks very similar to lobpcg in that it only retrieves k eigenpairs for k << n and is intended for applications where you can't store the n-by-n matrix.

The write-up sounds interesting. Where do you plan to write it?

@nikitaved (Collaborator):

I see, thank you! Then you are right! B is only the identity as of now. It would be cool to see alternatives which involve different methods.

Well, at least arXiv would be great for starters...

@lobpcg commented Oct 31, 2020

The math trick for avoiding the square root of B is to notice that inv(B)*A is symmetric in the B-based scalar product, i.e. x'Bx. So one can apply whatever standard theory or algorithms for generalized eigenvalue problems, as long as B is SPD (our case). After the substitution, in many places one gets inv(B)B so it just goes away. But in a few key spots one still needs to implement vector multiplication by inv(B)*A - that is where a linear solve with B is needed.
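
A small numerical illustration of this observation (a sketch, not code from the thread): with symmetric A and SPD B, the operator M = inv(B) A is self-adjoint in the B inner product <x, y>_B = x^T B y.

```
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64); A = A + A.T                          # symmetric A
B = torch.randn(n, n, dtype=torch.float64); B = B @ B.T + n * torch.eye(n, dtype=torch.float64)  # SPD B

M = torch.linalg.solve(B, A)   # M = inv(B) @ A
x = torch.randn(n, dtype=torch.float64)
y = torch.randn(n, dtype=torch.float64)

lhs = x @ B @ (M @ y)   # <x, M y>_B
rhs = (M @ x) @ B @ y   # <M x, y>_B
print(torch.allclose(lhs, rhs))  # True: M is self-adjoint in the B inner product
```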

@nikitaved (Collaborator) commented Nov 1, 2020

Upon further inspection, I can see how to extend the already implemented method to the generalized problem. And, yes, it does not require the matrix square root, although it does become more computationally heavy, as I can no longer exploit simultaneous diagonalizations in the explicit solution to the Sylvester equation (check needed).

I will try to compare this solution to the solution with the matrix square root. Not sure which one is going to be faster.

@mfkasim1 (Contributor) commented Nov 2, 2020

@nikitaved for your reference, I have written down the symeig derivative in xitorch: https://xitorch.readthedocs.io/en/latest/notes/deriv_symeig.html

@pearu (Collaborator) commented Dec 14, 2023

Closing as the example in the description now works (thanks to #43002)

@pearu closed this as completed Dec 14, 2023