Numerical stability during autodifferentiation of eigh #25
I am investigating whether this google/jax#5461 (comment) could be a simpler yet effective solution. It seems to allow for 2 training iterations before breaking down (having set scf iterations = 2). However, it returns erroneous eigenvalues, as it ends up not conserving the charge.
I have created an example to play around with this problem.
I have realized that https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e provides a hand-written definition of the derivatives, which is probably what is used in https://github.com/XanaduAI/DiffDFT/blob/9bbab42cfe95bbb6031c45613b5ab718ca9e9e5f/grad_dft/external/eigh_impl.py#L84
I'm going to experiment a bit with this now, but I think we can just do:

```python
import jax.numpy as jnp

def generalized_eigh(A, B):
    L = jnp.linalg.cholesky(B)
    L_inv = jnp.linalg.inv(L)
    A_redo = L_inv.dot(A).dot(L_inv.T)
    return jnp.linalg.eigh(A_redo)
```
Ah sorry, this is the solution you already posted. Let me see if there is a better way of doing this...
Ok @PabloAMC, I think this code fixes the problem:

```python
import numpy as np

def generalized_to_standard_eig(A, B):
    L = np.linalg.cholesky(B)
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = np.linalg.eigh(C)
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original
```

Basically, the eigenvectors were not being transformed back to the original basis.
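As a quick sanity check of the fix above (using illustrative random matrices, not anything from the repo), the returned pairs should satisfy the generalized eigenvalue equation A v = λ B v and be B-orthonormal:

```python
import numpy as np

# Restates the corrected routine above so this snippet is self-contained.
def generalized_to_standard_eig(A, B):
    L = np.linalg.cholesky(B)
    L_inv = np.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = np.linalg.eigh(C)
    # Transform eigenvectors back to the original basis: x = L^{-T} y.
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
A = M + M.T                     # symmetric A
N = rng.normal(size=(4, 4))
B = N @ N.T + 4.0 * np.eye(4)   # symmetric positive definite B

w, V = generalized_to_standard_eig(A, B)
# Each column v_j should satisfy A v_j = w_j B v_j.
residual = A @ V - (B @ V) * w
```

Note that the recovered eigenvectors are B-orthonormal (Vᵀ B V = I) rather than orthonormal in the usual sense, which is the expected normalization for a generalized eigenproblem.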
You are right, that works, thanks @jackbaker1001. Unfortunately, it still gives NaN errors when backpropagating, even with this version of the generalized eigenvalue problem. For instance, in the additional example I included in the basic folder, https://github.com/XanaduAI/DiffDFT/blob/main/examples/basic_examples/example_neural_scf_training.py, changing the number of cycles from 0 to 1 at https://github.com/XanaduAI/DiffDFT/blob/81444ce067a809e3e69bffdf81b4e7c1b5ca4d2b/examples/basic_examples/example_neural_scf_training.py#L138 still produces NaNs.
Ok @PabloAMC, since this issue was about a differentiable implementation of the generalized eigenvalue problem, I will open a new issue about differentiating through the SCF iterator (with its different implementations of DIIS at present), as that is now the problem!
I think we should reopen this. The problem is numerical stability during the calculation of gradients, not a lack of a differentiable implementation, and that does not seem solved yet.
In #17 we discussed that it would be quite important to implement the training self-consistently. Unfortunately, I was having lots of NaN-related issues when trying to differentiate through the self-consistent loop. With a simple example, I have been able to narrow the problem down to an external implementation of the generalized eigenvalue problem, provided in https://github.com/XanaduAI/DiffDFT/blob/main/grad_dft/external/eigh_impl.py and originally reported in https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e. The author also has some notes in https://jackd.github.io/posts/generalized-eig-jvp/.
The problem is that jax.scipy.linalg.eigh only supports differentiation when b = None, even though this is not reported in the docs: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.eigh.html. Substituting the line https://github.com/XanaduAI/DiffDFT/blob/da7d384fcb1bae547fa9370ae7846bf939a6d580/grad_dft/evaluate.py#L544 with anything else removes the error. For example, even though it is nonsense, I can do
which preserves the right matrix shapes and retains the dependency on the Fock matrix (so gradients are not zero). Then, for a very simple LDA model with a single layer etc., the gradients all look good even with 35 scf iterations, though it breaks down somewhere between 35 and 40 scf iterations.
I think solving this should be a high-priority problem now.
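For reference, the Cholesky-reduction workaround discussed in the comments can be sketched end to end in JAX as below. This is a minimal sketch, assuming B is well-conditioned and the eigenvalues of interest are non-degenerate (gradients of jnp.linalg.eigh blow up at degeneracies, which may be what happens deep into the SCF loop); the function names are illustrative, not from the repo.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def generalized_eigh(A, B):
    # Reduce A x = w B x to standard form: with B = L L^T,
    # C = L^{-1} A L^{-T} has the same eigenvalues, and x = L^{-T} y.
    L = jnp.linalg.cholesky(B)
    L_inv = jnp.linalg.inv(L)
    C = L_inv @ A @ L_inv.T
    w, y = jnp.linalg.eigh(C)
    return w, L_inv.T @ y

def lowest_eigenvalue(A, B):
    # Scalar objective so we can take a gradient through the solver.
    w, _ = generalized_eigh(A, B)
    return w[0]

# Small symmetric A and positive definite B with distinct eigenvalues.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
B = jnp.array([[2.0, 0.5], [0.5, 1.0]])
grad_A = jax.grad(lowest_eigenvalue)(A, B)
```

Since every step is a standard JAX primitive with a registered JVP, gradients flow without a hand-written derivative rule; the open question above is why they still become NaN after many SCF iterations.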