🐛 Describe the bug
@mstebelev found that the memory-efficient attention kernel produces NaN gradients on float32 CUDA tensors, even though the inputs and the incoming gradient are reasonably bounded. The math backend does not produce NaNs with the same input.
Below is a minimal script that reproduces the issue (from @otto-dev):
import torch
import torch.nn as nn
import torch.optim as optim

embed_dim = 1024
batch_size = 32
seq_length = 50
num_iterations = 100000
learning_rate = 0.01

device = torch.device("cuda")
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = nn.TransformerEncoderLayer(embed_dim, 16, dropout=0.01, batch_first=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Constant large-valued input; larger values trigger the failure sooner
inputs = torch.full((seq_length, batch_size, embed_dim), 1000.0, device=device)

# Training loop
for i in range(num_iterations):
    print(f'Iteration {i + 1}')
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
    optimizer.step()
It turns out to be very easy to reproduce: train a TransformerEncoderLayer for a few iterations on a tensor containing nothing but the value 1000. The larger the value (here, 1000), the sooner the error occurs.
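To narrow the failure down to the attention kernel itself, one can bypass TransformerEncoderLayer and call F.scaled_dot_product_attention directly, pinning one backend at a time with the torch.backends.cuda.sdp_kernel context manager. The sketch below is mine, not from the original report; the shapes and random inputs are illustrative only, and actually reproducing the NaN still requires large activations like those the script above produces:

import torch
import torch.nn.functional as F

device = torch.device("cuda")
# (batch, heads, seq_len, head_dim); values chosen for illustration only
q = torch.randn(32, 16, 50, 64, device=device, requires_grad=True)
k = torch.randn(32, 16, 50, 64, device=device, requires_grad=True)
v = torch.randn(32, 16, 50, 64, device=device, requires_grad=True)
grad_out = torch.randn(32, 16, 50, 64, device=device)

for name, flags in [
    ("mem_efficient", dict(enable_flash=False, enable_math=False, enable_mem_efficient=True)),
    ("math", dict(enable_flash=False, enable_math=True, enable_mem_efficient=False)),
]:
    # Pin a single SDPA backend, run forward + backward, and check for NaNs
    with torch.backends.cuda.sdp_kernel(**flags):
        out = F.scaled_dot_product_attention(q, k, v)
        grads = torch.autograd.grad(out, (q, k, v), grad_out)
    print(name, "any NaN grads:", any(g.isnan().any().item() for g in grads))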
Running the reproduction script:
$ python repro.py
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/graph.py:767: UserWarning: Error detected in ScaledDotProductEfficientAttentionBackward0. Traceback of forward call that caused the error:
  File "/otto-dev/repro.py", line 22, in <module>
    output = model(inputs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 748, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 756, in _sa_block
    x = self.self_attn(x, x, x,
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1266, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 5527, in multi_head_attention_forward
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/otto-dev/repro.py", line 24, in <module>
    loss.backward()
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/_tensor.py", line 523, in backward
    torch.autograd.backward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Function 'ScaledDotProductEfficientAttentionBackward0' returned nan values in its 0th output.
The exception is not necessarily triggered immediately during training; it can appear only after many iterations or epochs.
The issue comes from the memory-efficient attention kernel; the math backend does not produce NaNs with this input.
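Until the kernel itself is fixed, a possible stopgap (my sketch, not an official workaround) is to disable the fused kernels and force the math backend around the forward/backward pass, reusing model, inputs, and optimizer from the repro script above:

import torch

# Force the math SDPA backend; flash and memory-efficient kernels are disabled
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
optimizer.step()

The math backend is slower and uses more memory, so this trades throughput for numerical robustness. On recent releases, torch.nn.attention.sdpa_kernel(SDPBackend.MATH) is the non-deprecated equivalent.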
If you set norm_first=True, I am unable to reproduce the error. The fused kernels are known to be more sensitive to large values; I have seen this directly for FAv2 because of its use of FMA. I don't think that is the case here, though, based on the first update.
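For concreteness, the pre-norm configuration referred to above is the repro script's model with one extra argument, everything else unchanged:

# norm_first=True applies LayerNorm before attention/FFN; per the comment
# above, the NaN does not reproduce in this configuration
model = nn.TransformerEncoderLayer(
    embed_dim, 16, dropout=0.01, batch_first=True, norm_first=True
).to(device)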
Code from @SamAdamDay
More evidence and discussion
This is clearly a serious bug that can affect many researchers' training runs and experimental results. It should be fixed quickly.
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @bhosmer @cpuhrsch @erichan1 @drisspg
Versions
2.1.0+cu121 - 2.4.0.dev20240427+cu121