
Numerical stability during autodifferentiation of eigh #25

Closed
PabloAMC opened this issue Aug 19, 2023 · 9 comments · Fixed by #26
Labels
bug Something isn't working

Comments

@PabloAMC
Collaborator

PabloAMC commented Aug 19, 2023

In #17 we discussed that it would be quite important to implement the training self-consistently. Unfortunately, I was having lots of NaN-related issues when trying to differentiate through the self-consistent loop. With a simple example, I have been able to narrow down the problem to an external implementation of the generalized eigenvalue problem, provided in https://github.com/XanaduAI/DiffDFT/blob/main/grad_dft/external/eigh_impl.py and originally reported in https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e. The author also has some notes in https://jackd.github.io/posts/generalized-eig-jvp/.
The problem is that jax.scipy.linalg.eigh only supports the standard problem (b = None), even though this is not reported in the docs https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.eigh.html.
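A minimal sketch of this limitation (assuming current JAX behavior; the matrices here are just placeholders):

import jax.numpy as jnp
from jax.scipy.linalg import eigh

a = jnp.eye(2)
b = 2.0 * jnp.eye(2)

eigh(a)  # standard symmetric problem: supported
try:
    eigh(a, b)  # generalized problem (b is not None)
except NotImplementedError as err:
    print(err)  # only the b=None case is implemented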
Replacing the line https://github.com/XanaduAI/DiffDFT/blob/da7d384fcb1bae547fa9370ae7846bf939a6d580/grad_dft/evaluate.py#L544 with anything else removes the error. For example, even though it is nonsense, I can do

mo_energy = fock[:,:,0]
mo_coeff = fock + fock**2

which preserves the right matrix shapes and retains the dependency on the Fock matrix (so gradients are not 0). Then, for a very simple LDA model with a single layer, the gradients all look good even when the number of SCF iterations is 35, though it breaks somewhere between 35 and 40 SCF iterations.
I think solving this should be a high priority now.

@PabloAMC PabloAMC added the bug Something isn't working label Aug 19, 2023
@PabloAMC
Collaborator Author

PabloAMC commented Aug 19, 2023

I am investigating whether google/jax#5461 (comment) could be a simpler yet effective solution. It seems to allow for 2 training iterations before breaking down (with the number of SCF iterations set to 2). However, it returns erroneous eigenvalues, as it ends up not conserving the charge.

@PabloAMC
Collaborator Author

PabloAMC commented Aug 19, 2023

I have created an example to play around with this problem:
https://github.com/XanaduAI/DiffDFT/blob/main/examples/basic_examples/example_neural_scf_training.py
It should be deleted if we can't solve this in time. Changing the number of DIIS iterations in https://github.com/XanaduAI/DiffDFT/blob/9bbab42cfe95bbb6031c45613b5ab718ca9e9e5f/grad_dft/evaluate.py#L524 also has some minor effect.

@jackbaker1001
Collaborator

I'm going to experiment a bit with this now, but I think we can just do:

import jax.numpy as jnp

def generalized_eigh(A, B):
    # Reduce A x = lambda B x to a standard symmetric problem via the Cholesky factor of B
    L = jnp.linalg.cholesky(B)
    L_inv = jnp.linalg.inv(L)
    A_redo = L_inv.dot(A).dot(L_inv.T)
    return jnp.linalg.eigh(A_redo)
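For reference (standard linear algebra, not spelled out in the comment): with the Cholesky factorization B = L Lᵀ, the generalized problem A x = λ B x is equivalent to the standard problem (L⁻¹ A L⁻ᵀ) y = λ y with y = Lᵀ x. The eigenvalues carry over directly, but recovering the generalized eigenvectors requires the back-transformation x = L⁻ᵀ y, which is what the fix below adds.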

@jackbaker1001
Collaborator

Ah sorry this is the solution you already posted. Let me see if there is a better way of doing this...

@jackbaker1001
Collaborator

Ok @PabloAMC I think this code fixes the problem:

import numpy as np

def generalized_to_standard_eig(A, B):
    L = np.linalg.cholesky(B)
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = np.linalg.eigh(C)
    # Transform the eigenvectors of C back to the original (generalized) basis
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original

Basically, the eigenvectors were not transformed back to the original basis.
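A quick sanity check of the back-transformation (a sketch, not from the thread; the test matrices and the comparison against SciPy's generalized solver are illustrative assumptions):

import numpy as np
from scipy.linalg import eigh as scipy_eigh

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 6))
A = X + X.T                       # symmetric
Y = rng.normal(size=(6, 6))
B = Y @ Y.T + 6.0 * np.eye(6)     # symmetric positive definite

w, V = generalized_to_standard_eig(A, B)
w_ref, _ = scipy_eigh(A, B)       # SciPy solves the generalized problem directly

assert np.allclose(w, w_ref)                            # same eigenvalues
assert np.allclose(A @ V, (B @ V) * w, atol=1e-6)       # A v = w B v
assert np.allclose(V.T @ B @ V, np.eye(6), atol=1e-6)   # B-orthonormal eigenvectors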

@PabloAMC
Collaborator Author

PabloAMC commented Aug 21, 2023

You are right, that works, thanks @jackbaker1001. Unfortunately, it still gives errors when backpropagating, even with this version of the generalized eigenvalue problem. I still get NaN errors when I change the number of cycles from 0 to 1 in the additional example I included in the basic folder, https://github.com/XanaduAI/DiffDFT/blob/main/examples/basic_examples/example_neural_scf_training.py, specifically in https://github.com/XanaduAI/DiffDFT/blob/81444ce067a809e3e69bffdf81b4e7c1b5ca4d2b/examples/basic_examples/example_neural_scf_training.py#L138

@jackbaker1001
Collaborator

Ok @PabloAMC, so because this issue was about a differentiable eigh implementation, which we now have, I will put in a PR for this and merge it. I will keep the custom eigh implementation from before in the code, just in case it comes in handy later. It may also be removed later.

I will make a new issue about differentiating through the SCF iterator (presently, different implementations of DIIS), as this is now the problem!

@PabloAMC
Collaborator Author

I think we should reopen this. The problem is numerical stability during the calculation of gradients, not a lack of a differentiable implementation, and that does not seem to be solved yet.
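A minimal illustration of this kind of instability (a sketch, not from the thread): the differentiation rule for eigh contains 1/(λ_i − λ_j) factors for the eigenvector part, so gradients blow up when eigenvalues are degenerate or nearly degenerate, which is easy to hit inside an SCF loop.

import jax
import jax.numpy as jnp

def sum_of_eigenvectors(a):
    _, vecs = jnp.linalg.eigh(a)
    return jnp.sum(vecs)

a_ok = jnp.diag(jnp.array([1.0, 2.0, 3.0]))   # well-separated eigenvalues
a_bad = jnp.eye(3)                            # fully degenerate spectrum

print(jax.grad(sum_of_eigenvectors)(a_ok))    # finite entries
print(jax.grad(sum_of_eigenvectors)(a_bad))   # NaN/inf entries expected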

@PabloAMC PabloAMC reopened this Aug 22, 2023
@PabloAMC PabloAMC changed the title Substitute the eigh implementation by something fully differentiable Numerical stability during autodifferentiation of eigh Aug 22, 2023