
[BUG] NaN in gradients of scaled_dot_product_attention operation with mem_efficient backend #125674

Open
walkacross opened this issue May 7, 2024 · 1 comment
Labels
oncall: transformer/mha Issues related to Transformers and MultiheadAttention

Comments

@walkacross

walkacross commented May 7, 2024

🐛 Describe the bug

Reproducing the bug

@mstebelev found out that the memory-efficient attention kernel on float32 CUDA tensors gives NaN gradients even though the inputs and the incoming gradient are reasonably bounded. The math backend does not produce NaNs with this input.

Below is a minimal script to reproduce the issue (from @otto-dev):

$ pip show torch
Name: torch
Version: 2.4.0.dev20240427+cu121

$ python --version
Python 3.11.6

import torch
import torch.nn as nn
import torch.optim as optim

embed_dim = 1024
batch_size = 32
seq_length = 50
num_iterations = 100000
learning_rate = 0.01
device = torch.device("cuda")
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = nn.TransformerEncoderLayer(embed_dim, 16, dropout=0.01, batch_first=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
inputs = torch.full((seq_length, batch_size, embed_dim), 1000.0, device=device)

# Training loop
for i in range(num_iterations):
    print(f'Iteration {i + 1}')
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
    optimizer.step()

It turns out it is very easy to reproduce: train a TransformerEncoderLayer for a few iterations on a tensor that contains nothing but the number 1000. The higher the number (here, 1000), the sooner the error occurs.

Running the reproduction script:

$ python repro.py 
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/graph.py:767: UserWarning: Error detected in ScaledDotProductEfficientAttentionBackward0. Traceback of forward call that caused the error:
  File "/otto-dev/repro.py", line 22, in <module>
    output = model(inputs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 748, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 756, in _sa_block
    x = self.self_attn(x, x, x,
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1266, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 5527, in multi_head_attention_forward
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/otto-dev/repro.py", line 24, in <module>
    loss.backward()
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/_tensor.py", line 523, in backward
    torch.autograd.backward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/otto-dev/venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Function 'ScaledDotProductEfficientAttentionBackward0' returned nan values in its 0th output.

This exception is not necessarily triggered immediately during training; it may appear only after training for many epochs.

The issue comes from the memory-efficient attention kernel; the math backend does not produce NaNs with this input.

Code from @SamAdamDay, forcing the math backend via sdpa_kernel:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.attention import SDPBackend, sdpa_kernel

embed_dim = 1024
batch_size = 32
seq_length = 50
num_iterations = 1000
learning_rate = 0.01
device = torch.device("cuda")
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = nn.TransformerEncoderLayer(embed_dim, 16, dropout=0.01, batch_first=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
inputs = torch.full((seq_length, batch_size, embed_dim), 1000.0, device=device)

# Training loop
with sdpa_kernel(SDPBackend.MATH):
    for i in range(num_iterations):
        print(f'Iteration {i + 1}')
        optimizer.zero_grad()
        output = model(inputs)
        loss = output.mean()
        loss.backward()
        optimizer.step()
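
To confirm that only the memory-efficient kernel is responsible, the op can also be isolated from the rest of the model. The sketch below is not from the original report; it assumes the torch.nn.attention.sdpa_kernel API (PyTorch >= 2.3) and uses deliberately large float32 inputs, so whether NaNs actually appear depends on the chosen magnitudes.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.manual_seed(0)
device = torch.device("cuda")
# (batch, num_heads, seq_len, head_dim) float32 tensors with deliberately large values
q = torch.randn(32, 16, 50, 64, device=device) * 1e4
k = torch.randn(32, 16, 50, 64, device=device) * 1e4
v = torch.randn(32, 16, 50, 64, device=device) * 1e4

for name, backend in [("efficient", SDPBackend.EFFICIENT_ATTENTION),
                      ("math", SDPBackend.MATH)]:
    q_i = q.clone().requires_grad_(True)
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q_i, k, v)
    out.mean().backward()
    # Check whether the gradient w.r.t. the query contains NaN under this backend
    print(f"{name}: grad contains NaN = {torch.isnan(q_i.grad).any().item()}")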

More evidence and discussion

It is clear that this is a very serious bug that can affect many researchers' model training processes and experimental results. It should be fixed quickly.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @bhosmer @cpuhrsch @erichan1 @drisspg

Versions

2.1.0+cu121 - 2.4.0.dev20240427+cu121

@jbschlosser jbschlosser added the oncall: transformer/mha Issues related to Transformers and MultiheadAttention label May 7, 2024
@drisspg
Contributor

drisspg commented May 16, 2024

One update:
I tried locally compiling with string(APPEND CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false"), and the error still repros.

Also, while using NanInfdetect, I was able to determine that this training loop is encouraging very large values for q, k, and v:

  4.9956e+08, -4.9956e+08, -4.9956e+08,  ..., -4.9956e+08, -4.9956e+08, -4.9956e+08],
[ 2.7200e+08,  2.7200e+08,  2.7200e+08,  ...,  2.7200e+08,  2.7200e+08,  2.7200e+08],
...
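
NanInfdetect is an external debugging tool; as a rough, hypothetical stand-in (not part of the original report), a forward hook can log how large the values entering the self-attention block become while running the training loop above:

import torch
import torch.nn as nn

def log_max_abs(module, inputs, output):
    # inputs[0] is the query tensor passed to self-attention; its magnitude,
    # together with the in_proj weights, determines how large q, k and v get
    print(f"max |self_attn input| = {inputs[0].abs().max().item():.3e}")

model = nn.TransformerEncoderLayer(1024, 16, dropout=0.01, batch_first=True).cuda()
hook = model.self_attn.register_forward_hook(log_max_abs)
# ... run the training loop from the reproduction script above, then hook.remove()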

If you set norm_first=True, I am unable to reproduce the error. The fused kernels are known to be more sensitive to large values; I have seen this directly for FAv2 because of its use of FMA, but given the first update above I don't think that's the case here.
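
For reference, a sketch of that workaround: the original reproduction script with the single change norm_first=True (pre-LayerNorm), which the comment above reports as avoiding the error.

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda")
# Same setup as the reproduction script, but with norm_first=True so LayerNorm
# is applied before the input reaches the attention block
model = nn.TransformerEncoderLayer(1024, 16, dropout=0.01, batch_first=True,
                                   norm_first=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
inputs = torch.full((50, 32, 1024), 1000.0, device=device)

for i in range(100):
    optimizer.zero_grad()
    loss = model(inputs).mean()
    loss.backward()
    optimizer.step()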
