
NaN when fitting with derivative #77

Closed
peastman opened this issue May 2, 2022 · 15 comments

Comments

@peastman
Collaborator

peastman commented May 2, 2022

I'm trying to fit an equivariant transformer model. If I specify derivative: true in the configuration file to use derivatives in fitting, the model output becomes nan after only a few training steps. This happens even if I also specify force_weight: 0.0. The derivatives shouldn't affect the loss at all in that case, yet fitting still fails. The obvious explanation would be a nan somewhere in the training data, since that would make the loss nan even after multiplying by 0, but I verified that's not the case. Immediately after the line that computes the loss

loss_dy = loss_fn(deriv, batch.dy)

I added

print(loss_dy, torch.all(torch.isfinite(deriv)), torch.all(torch.isfinite(batch.dy)))

Here's the relevant output from the log.

Epoch 0:   1%|          | 31/5483 [00:06<19:36,  4.63it/s, loss=1.28e+07, v_num=_]tensor(11670.3730, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 32/5483 [00:06<19:32,  4.65it/s, loss=1.25e+07, v_num=_]tensor(273794.6562, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 33/5483 [00:07<19:28,  4.67it/s, loss=1.25e+07, v_num=_]tensor(nan, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(False, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 34/5483 [00:07<19:25,  4.68it/s, loss=nan, v_num=_]     tensor(nan, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(False, device='cuda:0') tensor(True, device='cuda:0')

batch.dy never contains a non-finite value.
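(The check over the dataset amounts to something like the following sketch; it assumes each sample exposes .y and .dy tensors:)

import torch

# Sketch: scan every sample for non-finite labels.
def find_nonfinite_samples(dataset):
    bad = []
    for i, sample in enumerate(dataset):
        if not torch.isfinite(sample.y).all() or not torch.isfinite(sample.dy).all():
            bad.append(i)
    return bad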

Any idea what could be causing this?

@PhilippThoelke
Collaborator

From your prints it looks like the derivative loss is just exploding after a couple of batches and then becomes NaN. Even though force_weight is 0, the NaN derivative loss multiplied by 0 still renders the global loss NaN:

loss = loss_y * self.hparams.energy_weight + loss_dy * self.hparams.force_weight
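(A quick illustration of the arithmetic at play here, plain PyTorch rather than torchmd-net code:)

import torch

loss_y = torch.tensor(12.3)
loss_dy = torch.tensor(float("nan"))

# Multiplying nan by 0 still gives nan, so the weighted sum is poisoned
# even when force_weight is 0.
print(loss_y * 1.0 + loss_dy * 0.0)  # tensor(nan)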

We might want to change this so a loss term is only computed if its weighting factor is non-zero; however, one of the terms becoming NaN is not expected behavior either.

I'm guessing the fitting works with derivative: false? How much learning rate warmup are you using? In my experience it helps a lot with chaotic gradients during early training; 10k steps have worked well for me in the past. Have you tried decreasing the learning rate?
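(For concreteness, warmup here just means ramping the learning rate up from near zero over the first training steps; a generic PyTorch sketch, not necessarily the scheduler torchmd-net uses:)

import torch

model = torch.nn.Linear(8, 1)            # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
warmup_steps = 10000                      # e.g. the ~10k steps mentioned above

# Scale the learning rate linearly up to its nominal value;
# scheduler.step() is called once per optimizer step.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)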

The loss also seems generally quite high; are you standardizing your data in some way? E.g. with atomrefs, or just setting standardize: true to compute the training set's mean and standard deviation before training and scale-shift the model's prediction accordingly.

@peastman
Collaborator Author

peastman commented May 2, 2022

From your prints it looks like the derivative loss is just exploding after a couple of batches and then becomes NaN.

Here's a larger number of steps before the point where it becomes nan:

Epoch 0:   0%|          | 24/5483 [00:05<20:32,  4.43it/s, loss=1.69e+07, v_num=_]tensor(8333.0967, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   0%|          | 25/5483 [00:05<20:19,  4.48it/s, loss=1.68e+07, v_num=_]tensor(9625.0684, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   0%|          | 26/5483 [00:05<20:08,  4.52it/s, loss=1.68e+07, v_num=_]tensor(3447841.5000, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   0%|          | 27/5483 [00:05<20:06,  4.52it/s, loss=1.65e+07, v_num=_]tensor(2137326.7500, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 28/5483 [00:06<20:01,  4.54it/s, loss=1.68e+07, v_num=_]tensor(1.3091e+08, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 29/5483 [00:06<19:55,  4.56it/s, loss=1.85e+07, v_num=_]tensor(669516.5000, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 30/5483 [00:06<19:52,  4.57it/s, loss=1.63e+07, v_num=_]tensor(29315306., device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 31/5483 [00:06<19:46,  4.60it/s, loss=1.28e+07, v_num=_]tensor(11670.3730, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 32/5483 [00:06<19:41,  4.61it/s, loss=1.25e+07, v_num=_]tensor(273794.6562, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
Epoch 0:   1%|          | 33/5483 [00:07<19:36,  4.63it/s, loss=1.25e+07, v_num=_]tensor(nan, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(False, device='cuda:0') tensor(True, device='cuda:0')

As you can see, the derivative loss doesn't explode. It fluctuates but doesn't show any obvious trend. On the step just before things become nan, it's actually much smaller than on some earlier steps.

Even though force_weight is 0, the NaN derivative loss multiplied by 0 still renders the global loss NaN:

The derivative loss only becomes nan because all of the outputs have become nan. The regular loss also becomes nan at the same time. Until that happens, the derivative loss is finite and well behaved. Multiplying it by 0 should therefore mean that the derivatives have no effect at all on the fitting process.

I'm guessing the fitting works with derivative: false?

Correct, at least in the sense that it doesn't become nan. I've been unable to get the model to actually fit the data very well, though.

The loss generally also seems quite high, are you standardizing your data in some way?

I do a least-squares fit to find a reference energy for each atom type and subtract those per-atom reference energies from each molecule's total energy.
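(Roughly along these lines; a self-contained sketch with made-up numbers, not the actual preprocessing code:)

import torch

# counts[i, t] = number of atoms of type t in molecule i (illustrative values),
# energies[i] = total energy of molecule i.
counts = torch.tensor([[2.0, 1.0, 0.0],
                       [1.0, 0.0, 2.0],
                       [3.0, 1.0, 1.0]])
energies = torch.tensor([-76.4, -113.3, -154.9])

# Least-squares fit of one reference energy per atom type.
ref_energy = torch.linalg.lstsq(counts, energies.unsqueeze(-1)).solution.squeeze(-1)

# Subtract the summed reference energies from each molecule's total energy.
adjusted = energies - counts @ ref_energy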

@PhilippThoelke
Collaborator

The derivative loss only becomes nan because all of the outputs have become nan. The regular loss also becomes nan at the same time.

But the regular loss only becomes nan because the nan derivative loss is added to it, even though it is multiplied by 0. So with the current code the derivative loss does influence training if it is nan, even if force_weight == 0, because (energy_loss + 0 * nan) == nan. This makes the gradients nan, then the weights, and the model breaks.

We could change this behavior by also checking for force_weight != 0 here

if self.hparams.derivative:

However, force_weight: 0 combined with derivative: true is a nonsensical combination of arguments: it does nothing except slow down training (unless the derivative becomes nan, in which case something else is broken).

Until that happens, the derivative loss is finite and well behaved.

I wouldn't call losses as high as 1.3091e+08 well behaved. Since you're saying that the model's derivative output becomes nan and it's not the mean squared error that produces it, it seems like some operation in the model has a nan gradient with respect to the atom coordinates. Could you try running the built-in torch anomaly detection to see which operation introduces the nan? You can enable it by calling torch.autograd.set_detect_anomaly(True) at some point before training begins. It will then raise an exception when the autograd engine encounters nan somewhere and print the traceback that led to it.
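(A self-contained toy example of what enabling it looks like; the nan here is injected artificially just to trigger the check:)

import torch

torch.autograd.set_detect_anomaly(True)  # call once, before the training loop

x = torch.ones(1, requires_grad=True)
y = x * float("nan")   # inject a nan purely to demonstrate the mechanism

# backward() now raises a RuntimeError naming the backward function that
# produced the nan, and prints the traceback of the forward call that created y.
y.backward()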

But since the derivative losses are already very large and chaotic before they become nan, I expect the problem to be in the data. This is a bit confusing to me, however, as it occurs only a few batches into training, where I would expect the model to still have fairly smooth gradients with respect to the coordinates.

What are your learning rate warmup and learning rate set to?

@peastman
Collaborator Author

peastman commented May 2, 2022

But the regular loss only becomes nan because the nan derivative loss is added to it, even though multiplied by 0.

Yet the derivatives (and derivative loss) are not nan. That's the point here. They remain finite until after the model has already somehow gotten messed up and started producing nan for all outputs. The nan derivative loss is an effect, not a cause. The only explanation I can think of is that the second derivative of the output (that is, the gradient of the derivative loss) is becoming nan even though the derivative loss itself is not. What could cause that?

We could change this behavior by also checking for force_weight != 0 here

It's not that I want to have force_weight be 0 in practice. I just tested that as part of debugging. I want to use the derivatives for fitting. Initially I assumed the nans must mean I had too large a force weight, or too high a learning rate, or something wrong with my training data. But the problem persists even when force_weight is 0, which proves the problem isn't any of those things.

@peastman
Collaborator Author

peastman commented May 2, 2022

Here's what I get from the anomaly detector.

[W python_anomaly_mode.cpp:104] Warning: Error detected in DivBackward0. Traceback of forward call that caused the error:
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/traceback.py", line 197, in format_stack
    return format_list(extract_stack(f, limit=limit))
 (function _print_stack)
[W python_anomaly_mode.cpp:109] Warning: 

Previous calculation was induced by NormBackward1. Traceback of forward call that induced the previous calculation:
  File "/home/peastman/workspace/torchmd-net/scripts/train.py", line 174, in <module>
    main()
  File "/home/peastman/workspace/torchmd-net/scripts/train.py", line 167, in main
    trainer.fit(model, data)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 499, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 738, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 442, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 143, in optimizer_step
    super().optimizer_step(*args, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/optim/adamw.py", line 92, in step
    loss = closure()
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 733, in train_step_and_backward_closure
    split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 823, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 290, in training_step
    training_step_output = self.trainer.accelerator.training_step(args)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 337, in training_step
    return self.model(*args, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 46, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 60, in training_step
    return self.step(batch, mse_loss, "train")
  File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 78, in step
    s=batch.s if self.hparams.spin else None)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/peastman/workspace/torchmd-net/torchmdnet/module.py", line 57, in forward
    return self.model(z, pos, batch=batch, q=q, s=s)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/peastman/workspace/torchmd-net/torchmdnet/models/model.py", line 173, in forward
    x = self.output_model.pre_reduce(x, v, z, pos, batch)
  File "/home/peastman/workspace/torchmd-net/torchmdnet/models/output_modules.py", line 74, in pre_reduce
    x, v = layer(x, v)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/peastman/workspace/torchmd-net/torchmdnet/models/utils.py", line 281, in forward
    vec1 = torch.norm(self.vec1_proj(v), dim=-2)
  File "/home/peastman/miniconda3/envs/torchmd/lib/python3.7/site-packages/torch/functional.py", line 1442, in norm
    return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
 (function _print_stack)

@PhilippThoelke
Collaborator

The only explanation I can think of is that the second derivative of the output (that is, the gradient of the derivative loss) is becoming nan even though the derivative loss itself is not. What could cause that?

That is possible, but I would still argue that this comes from something in the data, as it only occurs after a couple of batches, not immediately, and training with derivatives has worked without problems in the past. The anomaly detection tells us that the call to torch.norm in the gated equivariant block produced the nan during the backward call of a division.

vec1 = torch.norm(self.vec1_proj(v), dim=-2)

I can look into it if you can provide example code to reproduce this; otherwise I suggest looking at the tensor that gets passed to torch.norm and checking what could cause nan to appear during the backward pass.

I came across torch.norm producing nan gradients before. Here is a small example of what happened for me in the past:

import torch
c = torch.zeros(1, requires_grad=True)
x = torch.randn(1, requires_grad=True)
y = torch.norm(c) * x
deriv = torch.autograd.grad(y, c, create_graph=True)[0]
deriv.backward()
# x.grad == tensor([nan])

The important part here is that c is zero and that the second derivative of y is computed not with respect to c but with respect to some other tensor. The second derivative of torch.norm(0) with respect to its own input is well defined in PyTorch.

In your case v or the weights of self.vec1_proj would have to be zero for this to occur though, which shouldn't really happen.

@peastman
Collaborator Author

peastman commented May 3, 2022

I've been tracing through the code to figure out where the nan first appears. It's coming from the embedding at this line:

x = self.embedding(z)

It starts producing all nans for one particular atom type (13). I added this line just before it:

print(self.embedding(torch.tensor([13], device=z.device)))

For the first 33 batches it prints reasonable values:

tensor([[ 1.2304e+00, -5.6380e-01, -7.2504e-01, -7.5315e-01,  1.4812e-01,
         -1.0052e-01,  4.1293e-01,  2.0958e-01,  1.6800e-03,  1.1809e+00,
         -2.6126e-01,  3.2459e-01, -1.1151e+00,  1.4460e+00,  7.9803e-01,
         -1.4435e+00, -4.5231e-01,  2.1724e-01, -4.9895e-01, -4.3095e-01,
         -3.2661e-01,  1.6516e+00,  1.6782e+00, -7.1074e-01,  3.9565e-01,
          9.0223e-01, -5.0934e-01,  7.7105e-01,  5.6627e-01, -2.1207e+00,
         -8.3474e-01,  2.0902e-01,  2.0115e+00,  1.0476e+00,  5.2538e-01,
         -1.2637e-01,  1.2200e+00, -6.5442e-01, -3.5076e+00,  1.4348e-01,
         -2.7620e-01, -9.8142e-01,  1.8136e-01, -5.2072e-01,  4.7254e-01,
          1.2096e-01,  1.8359e-01,  6.5353e-01, -7.0402e-01,  8.7412e-01,
         -8.2406e-01, -1.4989e+00, -1.6871e+00, -4.3589e-01,  2.4797e-01,
         -1.1724e+00,  1.4642e+00, -1.7117e+00, -9.5990e-01, -1.5022e-01,
          4.9057e-01, -4.8097e-01,  2.5388e-01,  2.2166e-01]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)

And then it abruptly changes to

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

The error in torch.norm() is a consequence of this. It happens because the input is all nan.

@PhilippThoelke
Collaborator

Did you check the gradient torch.norm produces in the training step before the embedding layer produces nan? I would assume that the torch.norm backward call produces nans somewhere before the embedding generates nans in a forward pass, otherwise the anomaly detection would have reported some earlier operation in the network (I believe the GatedEquivariantBlocks only occur in the output network).

My guess is that the backward call of torch.norm produces nan due to some input being zero (probably some atom with atom type 13), which then produces nan gradients, making the embedding weights become nan. Then, in the next forward pass you notice the embedding layer generating nan because its weights for atom type 13 have been updated to nan in the last backward pass.
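(A sketch of that mechanism in isolation; the Embedding/SGD setup here is illustrative, not the actual torchmd-net training loop:)

import torch

emb = torch.nn.Embedding(num_embeddings=20, embedding_dim=4)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)

out = emb(torch.tensor([13]))
out.sum().backward()

# Pretend a nan gradient arrived from somewhere downstream (e.g. through torch.norm).
emb.weight.grad[13] = float("nan")
opt.step()

# From now on every forward pass returns nan for atom type 13.
print(emb(torch.tensor([13])))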

@peastman
Collaborator Author

peastman commented May 3, 2022

I'm working on tracking it further. I set the batch size to 1 so I can identify the exact sample where it happens. Here is the batch immediately before it becomes nan. Nothing looks obviously wrong with it.

y tensor([[-13.7331]], device='cuda:0')
z tensor([21, 10, 10, 13], device='cuda:0')
pos tensor([[-0.0897,  2.9169, -0.9472],
        [ 0.4741,  3.7232, -0.9294],
        [-0.4512,  2.8872, -1.8618],
        [ 0.1198, -3.7892,  1.2756]], device='cuda:0')
dy tensor([[ -24.1020,  -75.6053,   99.4888],
        [  68.4237,   91.6197,    6.4975],
        [ -44.4332,  -13.0008, -106.2225],
        [   0.1146,   -3.0116,    0.2329]], device='cuda:0')

@peastman
Collaborator Author

peastman commented May 3, 2022

Continuing to track it down. Here is the call to torch.norm().

def forward(self, x, v):
    vec1 = torch.norm(self.vec1_proj(v), dim=-2)

On the step that creates problems, here is the value of v:

tensor([[[-1.1570e-04,  1.2356e-02,  1.3049e-02, -1.6728e-02, -9.5599e-04,
          -1.4275e-02,  3.2444e-02, -3.5882e-02, -2.5314e-02,  1.5275e-02,
           1.8673e-02, -8.8951e-03, -9.1384e-03, -5.2834e-03,  8.9776e-03,
          -1.0675e-02, -1.0605e-02,  2.3503e-02, -2.3497e-02,  7.2124e-02,
          -1.4842e-02, -1.5459e-02,  1.5129e-02,  1.8291e-03,  1.9483e-02,
          -4.3428e-03,  1.5156e-02,  7.4844e-03, -5.1067e-03,  1.4550e-02,
           2.5732e-02, -1.5664e-02, -6.0905e-03, -9.8123e-03, -7.0684e-03,
           1.9341e-02,  9.7827e-04,  1.1466e-02, -2.4368e-02,  1.2992e-02,
          -1.6020e-02, -8.7237e-03,  7.7276e-03, -9.4767e-03,  9.0692e-03,
          -1.0772e-02, -1.2357e-02, -2.9550e-02,  7.9949e-03, -1.2781e-02,
          -1.4876e-02,  2.5431e-02,  9.9952e-03,  1.5916e-02,  7.6407e-03,
          -1.9517e-03, -4.7064e-04, -1.4775e-02,  2.2727e-02, -1.0818e-02,
          -1.6621e-02, -9.0240e-03, -3.0447e-02, -1.9935e-02],
         [-4.4596e-04,  4.7462e-02,  5.0119e-02, -6.4254e-02, -3.6721e-03,
          -5.4823e-02,  1.2461e-01, -1.3783e-01, -9.7244e-02,  5.8682e-02,
           7.1718e-02, -3.4169e-02, -3.5104e-02, -2.0298e-02,  3.4491e-02,
          -4.1000e-02, -4.0737e-02,  9.0276e-02, -9.0252e-02,  2.7704e-01,
          -5.7007e-02, -5.9376e-02,  5.8112e-02,  7.0303e-03,  7.4845e-02,
          -1.6679e-02,  5.8208e-02,  2.8735e-02, -1.9619e-02,  5.5882e-02,
           9.8840e-02, -6.0169e-02, -2.3403e-02, -3.7691e-02, -2.7145e-02,
           7.4286e-02,  3.7611e-03,  4.4048e-02, -9.3600e-02,  4.9906e-02,
          -6.1530e-02, -3.3509e-02,  2.9696e-02, -3.6405e-02,  3.4836e-02,
          -4.1379e-02, -4.7466e-02, -1.1351e-01,  3.0711e-02, -4.9108e-02,
          -5.7132e-02,  9.7681e-02,  3.8387e-02,  6.1134e-02,  2.9361e-02,
          -7.4980e-03, -1.8058e-03, -5.6749e-02,  8.7294e-02, -4.1547e-02,
          -6.3853e-02, -3.4636e-02, -1.1696e-01, -7.6556e-02],
         [ 5.1616e-04, -5.4819e-02, -5.7884e-02,  7.4210e-02,  4.2412e-03,
           6.3312e-02, -1.4392e-01,  1.5920e-01,  1.1232e-01, -6.7783e-02,
          -8.2829e-02,  3.9466e-02,  4.0547e-02,  2.3447e-02, -3.9840e-02,
           4.7352e-02,  4.7052e-02, -1.0426e-01,  1.0424e-01, -3.1998e-01,
           6.5841e-02,  6.8575e-02, -6.7119e-02, -8.1228e-03, -8.6449e-02,
           1.9263e-02, -6.7223e-02, -3.3178e-02,  2.2661e-02, -6.4537e-02,
          -1.1416e-01,  6.9497e-02,  2.7036e-02,  4.3532e-02,  3.1348e-02,
          -8.5796e-02, -4.3464e-03, -5.0877e-02,  1.0811e-01, -5.7641e-02,
           7.1062e-02,  3.8703e-02, -3.4307e-02,  4.2049e-02, -4.0235e-02,
           4.7794e-02,  5.4824e-02,  1.3110e-01, -3.5471e-02,  5.6730e-02,
           6.5982e-02, -1.1282e-01, -4.4331e-02, -7.0606e-02, -3.3919e-02,
           8.6609e-03,  2.0842e-03,  6.5542e-02, -1.0082e-01,  4.7980e-02,
           7.3756e-02,  3.9984e-02,  1.3509e-01,  8.8409e-02]],

        [[-5.3662e-02, -1.3002e-03,  6.7549e-02, -1.8016e-01, -2.8846e-02,
          -1.0311e-01, -1.6201e-01, -1.0164e-01, -2.8890e-02, -1.8539e-02,
           6.5367e-02,  8.0637e-02,  3.7315e-02,  8.9029e-02, -6.7195e-02,
          -6.6469e-02, -4.5529e-03,  6.7728e-02, -7.8329e-02,  1.5945e-01,
          -2.3219e-02, -5.6128e-02,  1.5219e-01, -8.7525e-02, -1.1971e-01,
          -1.1049e-01,  9.3875e-02,  4.3002e-02,  1.0266e-01, -8.6096e-02,
          -2.3253e-01,  6.9009e-02,  6.8919e-02,  3.1887e-02, -1.3905e-01,
          -7.8211e-02, -8.9267e-02,  6.6138e-03,  6.6659e-02, -8.2718e-02,
          -1.8479e-01,  7.0869e-02, -4.4919e-03,  2.0764e-02, -6.2990e-02,
           7.4010e-02, -3.9534e-03, -1.0274e-01, -2.5200e-02, -6.2025e-02,
          -8.8812e-02, -5.3786e-02,  1.0888e-01, -1.9360e-01, -5.9768e-02,
           1.0663e-01,  5.6589e-02, -2.2376e-01,  1.7173e-01,  4.9487e-02,
          -8.9506e-02,  2.3867e-01, -9.1645e-02,  7.5434e-02],
         [-8.3304e-02, -2.2644e-02,  6.9337e-02, -2.2757e-01, -4.4984e-02,
          -9.5871e-02, -1.9985e-01, -1.3295e-01, -2.0204e-02, -1.8339e-03,
           8.6053e-02,  3.6800e-02,  7.6971e-02,  8.9832e-02, -9.4344e-02,
          -1.0008e-01,  1.4580e-02,  8.8303e-02, -1.5200e-01,  1.8036e-01,
          -1.5019e-02, -8.1709e-02,  1.6719e-01, -1.5630e-01, -9.9280e-02,
          -1.3909e-01,  1.3272e-01,  5.9805e-02,  8.4932e-02, -7.7291e-02,
          -2.5447e-01,  1.1055e-01,  8.2119e-02,  5.0159e-02, -1.3833e-01,
          -7.7866e-02, -8.1785e-02,  8.0914e-04,  5.1539e-02, -1.4258e-01,
          -1.9879e-01,  1.1336e-01,  2.1038e-02,  4.2901e-02, -6.0271e-02,
           8.7875e-02, -1.6966e-02, -9.9128e-02, -3.2679e-02, -3.3978e-02,
          -1.0476e-01, -7.4457e-02,  1.4345e-01, -2.6732e-01, -8.5677e-02,
           1.4048e-01,  7.7028e-02, -2.8253e-01,  2.2968e-01,  6.5828e-02,
          -1.1895e-01,  2.9679e-01, -9.9255e-02,  3.5689e-02],
         [ 1.0446e-02,  3.8479e-02,  5.2687e-02, -6.1507e-02,  5.9944e-03,
          -9.8899e-02, -6.4179e-02, -2.6240e-02, -4.0050e-02, -4.6329e-02,
           1.5854e-02,  1.4810e-01, -4.2558e-02,  7.2323e-02, -5.3927e-03,
           7.1819e-03, -3.9233e-02,  1.8020e-02,  7.1589e-02,  9.3443e-02,
          -3.4447e-02,  8.7476e-04,  9.8378e-02,  5.4898e-02, -1.3711e-01,
          -3.8591e-02,  5.8264e-03,  4.5084e-03,  1.1796e-01, -8.7695e-02,
          -1.5212e-01, -1.9778e-02,  3.2672e-02, -7.4267e-03, -1.1662e-01,
          -6.5481e-02, -8.7874e-02,  1.6241e-02,  8.3286e-02,  4.2366e-02,
          -1.2726e-01, -1.9998e-02, -5.1038e-02, -2.3814e-02, -5.7262e-02,
           3.5662e-02,  2.0839e-02, -9.1860e-02, -7.0310e-03, -1.0340e-01,
          -4.4073e-02, -6.2840e-03,  2.6198e-02, -2.3868e-02, -1.5335e-03,
           2.5658e-02,  9.0364e-03, -7.6584e-02,  3.4969e-02,  1.0744e-02,
          -1.9642e-02,  9.0142e-02, -6.1875e-02,  1.3620e-01]],

        [[ 2.9959e-02, -1.3315e-02, -6.1897e-02,  1.3606e-01,  1.5966e-02,
           1.0127e-01,  1.2561e-01,  7.3653e-02,  3.2908e-02,  2.8695e-02,
          -4.6991e-02, -1.0519e-01, -7.8683e-03, -8.2637e-02,  4.4302e-02,
           3.9223e-02,  1.7279e-02, -4.9274e-02,  2.3031e-02, -1.3475e-01,
           2.7277e-02,  3.5026e-02, -1.3199e-01,  3.4953e-02,  1.2576e-01,
           8.3764e-02, -6.1263e-02, -2.8732e-02, -1.0798e-01,  8.6438e-02,
           2.0233e-01, -3.6198e-02, -5.5407e-02, -1.7355e-02,  1.3042e-01,
           7.3313e-02,  8.8501e-02, -1.0134e-02, -7.2575e-02,  3.6532e-02,
           1.6313e-01, -3.7287e-02,  2.1578e-02, -4.3274e-03,  6.0709e-02,
          -5.9712e-02, -5.1663e-03,  9.8448e-02,  1.8454e-02,  7.7048e-02,
           7.2126e-02,  3.6180e-02, -7.8196e-02,  1.3069e-01,  3.8203e-02,
          -7.6573e-02, -3.8959e-02,  1.6905e-01, -1.2100e-01, -3.5113e-02,
           6.3588e-02, -1.8343e-01,  8.0450e-02, -9.7542e-02],
         [-7.7452e-03, -3.3503e-02, -4.7628e-02,  5.8166e-02, -4.4944e-03,
           8.8804e-02,  6.0035e-02,  2.5443e-02,  3.5637e-02,  4.0847e-02,
          -1.5473e-02, -1.3111e-01,  3.6146e-02, -6.5284e-02,  6.4016e-03,
          -4.5761e-03,  3.4307e-02, -1.7419e-02, -6.0408e-02, -8.5475e-02,
           3.0609e-02,  6.5815e-04, -8.9592e-02, -4.5629e-02,  1.2253e-01,
           3.6431e-02, -7.4545e-03, -5.0170e-03, -1.0540e-01,  7.8608e-02,
           1.3846e-01,  1.5490e-02, -3.0219e-02,  5.6653e-03,  1.0516e-01,
           5.9048e-02,  7.8844e-02, -1.4323e-02, -7.4272e-02, -3.4828e-02,
           1.1559e-01,  1.5634e-02,  4.4594e-02,  2.0229e-02,  5.1500e-02,
          -3.2954e-02, -1.8062e-02,  8.2660e-02,  6.7658e-03,  9.1686e-02,
           4.0660e-02,  6.8380e-03, -2.5589e-02,  2.5703e-02,  2.8494e-03,
          -2.5060e-02, -9.3077e-03,  7.2410e-02, -3.4824e-02, -1.0616e-02,
           1.9385e-02, -8.4604e-02,  5.6246e-02, -1.2061e-01],
         [ 9.4715e-02,  2.6375e-02, -7.7761e-02,  2.5717e-01,  5.1156e-02,
           1.0706e-01,  2.2567e-01,  1.5041e-01,  2.2224e-02,  1.2684e-03,
          -9.7370e-02, -3.9167e-02, -8.8096e-02, -1.0067e-01,  1.0697e-01,
           1.1369e-01, -1.7234e-02, -9.9887e-02,  1.7373e-01, -2.0303e-01,
           1.6439e-02,  9.2733e-02, -1.8800e-01,  1.7832e-01,  1.1026e-01,
           1.5717e-01, -1.5051e-01, -6.7772e-02, -9.4314e-02,  8.6174e-02,
           2.8611e-01, -1.2580e-01, -9.2617e-02, -5.7051e-02,  1.5493e-01,
           8.7219e-02,  9.1268e-02, -6.3852e-04, -5.7025e-02,  1.6254e-01,
           2.2335e-01, -1.2899e-01, -2.4765e-02, -4.9098e-02,  6.7397e-02,
          -9.9095e-02,  1.9616e-02,  1.1088e-01,  3.6962e-02,  3.6750e-02,
           1.1811e-01,  8.4375e-02, -1.6233e-01,  3.0294e-01,  9.7195e-02,
          -1.5895e-01, -8.7253e-02,  3.1927e-01, -2.6002e-01, -7.4512e-02,
           1.3464e-01, -3.3522e-01,  1.1155e-01, -3.8115e-02]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],
       device='cuda:0', grad_fn=<AddBackward0>)

Its final row consists entirely of zeros. That tensor gets generated as vec in TorchMD_ET.forward():

vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device)
for attn in self.attention_layers:
    dx, dvec = attn(x, vec, edge_index, edge_weight, edge_attr, edge_vec)
    x = x + dx
    vec = vec + dvec

It gets initialized to all zeros, and then the attention layers add increments to it. But none of them adds anything to the final row, leaving it all zero.

@PhilippThoelke
Collaborator

I'm guessing that last atom is outside the cutoff radius of all other atoms. The closest it gets to another atom is about 7 Å. Is this outside your upper cutoff? If so, it will never interact with another atom inside the model, leaving all its features the way they were initialized, which is 0 for the vector features.
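(For reference, this can be checked directly from the coordinates printed above:)

import torch

pos = torch.tensor([[-0.0897,  2.9169, -0.9472],
                    [ 0.4741,  3.7232, -0.9294],
                    [-0.4512,  2.8872, -1.8618],
                    [ 0.1198, -3.7892,  1.2756]])

# Distance from the last atom to its nearest neighbour.
print(torch.cdist(pos, pos)[3, :3].min())  # roughly 7.1 A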

@PhilippThoelke
Collaborator

PhilippThoelke commented May 3, 2022

I think this should throw a debuggable error or even just skip the gradients of vector features that are 0. We could do something similar to this

mask = edge_index[0] != edge_index[1]
edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device)
edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)

Here we basically detach zero entries in the distance calculation before calling torch.norm, so they are not included in the backward pass, where they would otherwise also cause nan gradients.
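(Applied to the vector features, the same idea would look roughly like this; a sketch only, not an actual patch:)

import torch

# Hypothetical vector features: the last row belongs to an atom with no
# neighbours and is therefore all zero.
v = torch.tensor([[ 1.0,  2.0,  3.0],
                  [ 0.5, -0.5,  0.5],
                  [ 0.0,  0.0,  0.0]], requires_grad=True)

# Only rows with a non-zero entry are passed to torch.norm, so no backward
# pass (first or second order) ever divides by a zero norm.
mask = (v != 0).any(dim=-1)
vec_norm = torch.zeros(v.size(0))
vec_norm[mask] = torch.norm(v[mask], dim=-1)

vec_norm.sum().backward()
print(v.grad)  # finite everywhere; the all-zero row gets zero gradient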

@giadefa @peastman what do you think, should this just throw an error or continue without considering these gradients and maybe print a warning?

@peastman
Collaborator Author

peastman commented May 3, 2022

Having it continue without considering the gradients sounds like the right solution.

It looks like this particular sample is a DES370K dimer. It consists of a water molecule and a K+ ion, which are just over 7 Å apart. Printing a warning does make sense, so people will know there are samples their model can't possibly fit without increasing the cutoff.

@peastman
Collaborator Author

peastman commented May 3, 2022

I can also confirm that increasing the cutoff to 8 Å makes the error go away.

@PhilippThoelke
Collaborator

@peastman I have created a pull request (#79) for what we discussed. Could you make sure it fixes your issue by training using this branch?
