NaN when fitting with derivative #77
Comments
From your prints it looks like the derivative loss is just exploding after a couple of batches and then becomes NaN. Even though you set `force_weight: 0`, the derivative loss is still computed and added to the total loss here:
torchmd-net/torchmdnet/module.py Line 125 in b9785d2
We might want to change this so a loss term is only computed if its weighting factor is not equal to 0 (see the sketch below); however, it is also not desirable that training breaks if one of the terms becomes NaN. I'm guessing the fitting works with `derivative: false`? The loss generally also seems quite high; are you standardizing your data in some way, e.g. atomrefs?
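For illustration, a minimal sketch of that guard (not the actual torchmd-net code; the names `loss_y`, `loss_dy`, `energy_weight` and `force_weight` are only illustrative):

```python
import torch

def total_loss(loss_y, loss_dy, energy_weight, force_weight):
    # Only include a term when its weight is nonzero, so a nan in that term
    # cannot leak into the total loss through multiplication by zero.
    loss = energy_weight * loss_y
    if force_weight != 0:
        loss = loss + force_weight * loss_dy
    return loss

# 0 * nan is still nan, which is why weighting the derivative loss with 0
# does not neutralize it once it becomes non-finite:
print(0.0 * float("nan"))                                                    # nan
print(total_loss(torch.tensor(1.0), torch.tensor(float("nan")), 1.0, 0.0))   # tensor(1.)
```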
Here's a larger number of steps before the point where it becomes nan:
As you can see, the derivative loss doesn't explode. It fluctuates but doesn't show any obvious trend. On the step just before things become nan, it's actually much smaller than on some earlier steps.
The derivative loss only becomes nan because all of the outputs have become nan. The regular loss also becomes nan at the same time. Until that happens, the derivative loss is finite and well behaved. Multiplying it by 0 should therefore mean that the derivatives have no effect at all on the fitting process.
Correct, at least in the sense that it doesn't become nan. I've been unable to get the model to actually fit the data very well, though.
I do a least squares fit to find a reference energy for each atom type, then subtract those per-atom reference energies from each molecule's total energy.
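A rough sketch of that kind of reference-energy fit (not the author's actual script; the arrays are toy values):

```python
import torch

# counts[i, t] = number of atoms of type t in molecule i,
# energies[i]  = total energy of molecule i (toy values).
counts = torch.tensor([[2.0, 1.0, 0.0],
                       [0.0, 2.0, 1.0],
                       [1.0, 1.0, 1.0]])
energies = torch.tensor([-76.4, -113.3, -95.1])

# One reference energy per atom type from a least-squares fit.
ref_per_type = torch.linalg.lstsq(counts, energies.unsqueeze(1)).solution.squeeze(1)

# Train on the residual after subtracting the per-atom references.
targets = energies - counts @ ref_per_type
```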
But the regular loss only becomes nan because the nan derivative loss is added to it, even though it is multiplied by 0. So with the current code, the derivative loss does influence training if it is nan, even if `force_weight` is 0. We could change this behavior by also checking the weighting factor here:
torchmd-net/torchmdnet/module.py Line 80 in b9785d2
However, `force_weight: 0` together with `derivative: true` is a nonsensical combination of arguments: it does nothing except slow down training (and if the derivative becomes nan, something else is broken anyway).
I wouldn't call losses as high as 1.3091e+08 well behaved. Since you're saying that the model's derivative output becomes nan and it's not the mean squared error that produces it, it seems like some operation in the model has a nan gradient with respect to the atom coordinates. Could you try running the builtin torch anomaly detection to see which operation introduces the nan? You can enable it by calling `torch.autograd.set_detect_anomaly(True)` (see the snippet below). But since the derivative losses are already very large and chaotic before they become nan, I expect the problem to be in the data. This, however, is also a bit confusing to me, as it occurs only a few batches into training, where I would expect the model to have smoother gradients with respect to the coordinates. What are your learning rate warmup and learning rate set to?
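For reference, enabling it looks like this (the global call and the context manager are both standard PyTorch; everything else about the training loop is left out):

```python
import torch

# Every backward pass now raises an error at the operation that produced a
# non-finite value in the backward graph, together with its forward traceback.
torch.autograd.set_detect_anomaly(True)

# Or, to limit the (considerable) overhead, wrap only the suspicious part:
# with torch.autograd.detect_anomaly():
#     loss.backward()
```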
Yet the derivatives (and derivative loss) are not nan. That's the point here. They remain finite until after the model has already somehow gotten messed up and started producing nan for all outputs. The nan derivative loss is an effect, not a cause. The only explanation I can think of is that the second derivative of the output (that is, the gradient of the derivative loss) is becoming nan even though the derivative loss itself is not. What could cause that?
It's not that I want to use `force_weight: 0` in practice; I only set it to isolate the problem.
Here's what I get from the anomaly detector:
That is possible, but I would still argue that this comes from something in the data, as it only occurs after a couple of batches, not immediately, and training using the derivatives has worked without problems in the past. The anomaly detection tells us that the nan is introduced by this call to `torch.norm`:
torchmd-net/torchmdnet/models/utils.py Line 281 in b9785d2
I can look into it if you can provide example code to reproduce this; otherwise I suggest looking at the tensor that gets passed to `torch.norm` there. I came across this minimal example:

```python
import torch

c = torch.zeros(1, requires_grad=True)
x = torch.randn(1, requires_grad=True)

y = torch.norm(c) * x

deriv = torch.autograd.grad(y, c, create_graph=True)[0]
deriv.backward()
# x.grad == tensor([nan])
```

The important part here is that the gradient of `torch.norm` at an all-zero input is 0/0: the first backward still gives a finite result, but differentiating through it again (which happens here because of `create_graph=True`) produces nan, and that nan shows up in the gradients of everything connected to it (`x.grad` in the example). In your case I suspect the tensor passed to `torch.norm` contains such zero entries.
I've been tracing through the code to figure out where the nan first appears. It's coming from the embedding at this line: torchmd-net/torchmdnet/models/torchmd_et.py Line 158 in b9785d2
It starts producing all nans for one particular atom type (13). I added this line just before it:
`print(self.embedding(torch.tensor([13], device=z.device)))`
For the first 33 batches it prints reasonable values, and then it abruptly changes to all nan.
The error in …
Did you check the gradient of the embedding weights before they become nan? My guess is that the backward call of the derivative loss (the second derivative, through `torch.norm`) produces nan gradients for the embedding weights, which the optimizer step then turns into nan weights.
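A hypothetical debugging snippet for that check (not part of torchmd-net; `model` is whatever module holds the embedding), run right before the optimizer step:

```python
import torch

def report_nonfinite_grads(model, step):
    # Print every parameter whose gradient contains nan/inf (e.g. the embedding
    # weights) before the optimizer step writes the damage into the weights.
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"step {step}: non-finite gradient in {name}")
```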
I'm working on tracking it further. I set the batch size to 1 so I can identify the exact sample where it happens. Here is the batch immediately before it becomes nan. Nothing looks obviously wrong with it.
Continuing to track it down. Here is the call to `torch.norm`:
torchmd-net/torchmdnet/models/utils.py Lines 280 to 281 in b9785d2
On the step that creates problems, here is the value of the tensor that gets passed to it. The final row of it consists entirely of zeros. That tensor gets generated here:
torchmd-net/torchmdnet/models/torchmd_et.py Lines 172 to 177 in b9785d2
It gets initialized to all zeros, and then the attention layers add increments to it. But none of them adds anything to the final row, leaving it all zeros.
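While debugging, such atoms can be spotted directly with a check like this (hypothetical snippet; it assumes `vec` has shape `[n_atoms, 3, hidden_channels]`):

```python
import torch

def find_noninteracting_atoms(vec):
    # Rows of the vector features that stayed exactly zero after all attention
    # layers, i.e. atoms that never received a contribution from any neighbor.
    zero_rows = vec.abs().sum(dim=(1, 2)) == 0
    return torch.nonzero(zero_rows).flatten()
```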
I'm guessing that last atom is outside the cutoff radius of all other atoms. The closest it is to another atom is about 7A. Is this outside your upper cutoff? If yes, it will never interact with another atom inside the model, leaving all the features the way they were initialized, which is 0 for the vector features.
I think this should throw a debuggable error, or even just skip the gradients of vector features that are 0. We could do something similar to this:
torchmd-net/torchmdnet/models/utils.py Lines 223 to 225 in b9785d2
Here we basically detach zero entries in the distance calculation before calling `torch.norm`, so they are not included in the backward pass, which would otherwise also cause nan gradients (a sketch of the same idea for the vector features is below).
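A minimal sketch of that masking idea applied to a row-wise norm over the vector features (a hypothetical helper, not the actual torchmd-net code):

```python
import torch

def zero_safe_row_norm(v):
    # Call torch.norm only on rows with at least one nonzero entry; all-zero
    # rows keep a plain 0, so neither the backward nor the double backward ever
    # differentiates norm() at 0, which is what produces the nan.
    mask = (v != 0).any(dim=-1)
    out = torch.zeros(v.shape[0], device=v.device, dtype=v.dtype)
    out[mask] = torch.norm(v[mask], dim=-1)
    return out

# An isolated atom corresponds to an all-zero row: it contributes 0 and is
# simply excluded from the gradient computation instead of poisoning it.
v = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], requires_grad=True)
print(zero_safe_row_norm(v))   # tensor([3.7417, 0.0000], grad_fn=...)
```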
@giadefa @peastman what do you think: should this just throw an error, or continue without considering these gradients and maybe print a warning?
Having it continue without considering the gradients sounds like the right solution. It looks like this particular sample is a DES370K dimer. It consists of a water molecule and a K+ ion, which are just over 7A apart. Printing a warning does make sense, so people will know there are samples their model can't possibly fit without increasing the cutoff.
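Such samples could also be flagged up front by scanning the dataset; here is a hypothetical sketch (nothing in it is part of torchmd-net, and `cutoff_upper` just stands for the configured upper cutoff):

```python
import torch

def atoms_beyond_cutoff(pos, cutoff_upper):
    # pos: [n_atoms, 3] coordinates of one sample. Returns the indices of atoms
    # whose nearest neighbor is farther away than the cutoff, i.e. atoms the
    # model cannot possibly fit because they interact with nothing.
    dist = torch.cdist(pos, pos)
    dist.fill_diagonal_(float("inf"))   # ignore self-distances
    nearest = dist.min(dim=1).values
    return torch.nonzero(nearest > cutoff_upper).flatten()

# A water/K+-like arrangement with the ion roughly 7A from everything else:
pos = torch.tensor([[0.00, 0.00, 0.00],
                    [0.96, 0.00, 0.00],
                    [-0.24, 0.93, 0.00],
                    [7.20, 0.00, 0.00]])
print(atoms_beyond_cutoff(pos, cutoff_upper=5.0))   # tensor([3])
```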
I can also confirm that increasing the cutoff to 8A makes the error go away.
I'm trying to fit an equivariant transformer model. If I specify `derivative: true` in the configuration file to use derivatives in fitting, then after only a few training steps the model output becomes nan. This happens even if I also specify `force_weight: 0.0`. The derivatives shouldn't affect the loss at all in that case, yet it still causes fitting to fail.
The obvious explanation would be if I had a nan in the training data somewhere, since that would cause the loss to also be nan even after multiplying by 0. But I verified that's not the case. Immediately after it computes the loss here:
torchmd-net/torchmdnet/module.py Line 89 in b9785d2
I added a check for non-finite values. Here's the relevant output from the log: `batch.dy` never contains a non-finite value. Any idea what could be causing this?
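A check along those lines could look like this (the variable names are illustrative, not necessarily what module.py actually uses):

```python
import torch

def check_finite(batch_dy, loss):
    # Called right after the loss computation while debugging, to catch the
    # first training step at which anything stops being finite.
    assert torch.isfinite(batch_dy).all(), "non-finite value in batch.dy"
    assert torch.isfinite(loss).all(), "non-finite loss"
```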