
scaled_dot_product_attention behaves differently between v2.0 and v2.1 #110213

Closed
ydshieh opened this issue Sep 28, 2023 · 12 comments
Labels
module: multi-headed-attention triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


ydshieh commented Sep 28, 2023

🐛 Describe the bug

With torch v2.1, scaled_dot_product_attention on GPU gives nan when a row of the attention mask contains only large negative values (e.g. torch.finfo(q.dtype).min, used to mean that no position should be attended to). On CPU, it does not give nan.

With torch v2.0, there is no nan on either CPU or GPU, and the values are the same as those given by v2.1 on CPU.

I understand it doesn't really make sense for a sequence to have no position to attend to. However, I am wondering whether this nan value in torch v2.1 is intentional or unexpected.

This causes issues in the Falcon implementation in transformers when left padding is used.

Reproduce

(running with torch v2.1)

import torch
from torch.nn import functional as F

torch.manual_seed(0)

a = 3  # sequence length
b = 4  # head dimension

q = torch.randn(size=(1, 1, a, b))
k = torch.randn(size=(1, 1, a, b))
v = torch.randn(size=(1, 1, a, b))

def check(q, k, v, device):

    q = q.to(device)
    k = k.to(device)
    v = v.to(device)

    # Additive mask whose first row is entirely torch.finfo(q.dtype).min,
    # i.e. the first query position has nowhere to attend.
    neg_value = torch.finfo(q.dtype).min
    mask = [[neg_value, neg_value, neg_value], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    mask = torch.tensor([[mask]]).to(device)

    o = F.scaled_dot_product_attention(q, k, v, mask, 0.0, is_causal=False)
    print(o)

check(q, k, v, "cpu")
check(q, k, v, "cuda")

Outputs

  • with torch v2.0 (both CPU and GPU) or torch v2.1 (CPU)
tensor([[[[ 0.1210,  0.3627, -0.9969, -0.6149],
          [ 0.1295,  0.4572, -1.0491, -0.6166],
          [ 0.1095,  0.3819, -0.7369, -0.8267]]]])
  • torch v2.1 (GPU)
tensor([[[[    nan,     nan,     nan,     nan],
          [ 0.1295,  0.4572, -1.0491, -0.6166],
          [ 0.1095,  0.3819, -0.7369, -0.8267]]]], device='cuda:0')

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.16 (default, Jun 12 2023, 21:00:42) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Ti Laptop GPU
Nvidia driver version: 517.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2400
DeviceID=CPU0
Family=198
L2CacheSize=11776
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2400
Name=12th Gen Intel(R) Core(TM) i7-12800H
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.1.0+cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchvision==0.16.0+cu118
[conda] numpy 1.24.4 pypi_0 pypi
[conda] torch 2.1.0+cu118 pypi_0 pypi
[conda] torchaudio 2.1.0+cu118 pypi_0 pypi
[conda] torchvision 0.16.0+cu118 pypi_0 pypi

@janeyx99
Contributor

cc @drisspg

@janeyx99 janeyx99 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Sep 28, 2023
Contributor

drisspg commented Sep 28, 2023

So the reason for this, as you said, is the masking out of an entire row, similar to this issue: #103749. There was a change from 2.0 to 2.1, though: we can now run fused attention kernels with arbitrary attention masks. This means your code was likely being run on the math path before and is now running with the fused kernel.

The fused attention kernels use the iterative (online) softmax algorithm, which produces all NaNs as output for a row of large negative values, while regular softmax evenly distributes the attention among all entries in that case. It would be very costly to check at runtime whether an entire row is masked.

I could see this being a check in debug mode, though.

As a workaround, I would try running your code with memory-efficient attention turned off via the context manager (see the sketch below):

enable_mem_efficient: bool = True,
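For reference, a minimal sketch of that workaround, assuming the torch.backends.cuda.sdp_kernel context manager from torch 2.0/2.1 (q, k, v and the mask mirror the repro above):

import torch
from torch.nn import functional as F

q = torch.randn(1, 1, 3, 4, device="cuda")
k = torch.randn(1, 1, 3, 4, device="cuda")
v = torch.randn(1, 1, 3, 4, device="cuda")

# Additive mask whose first query row is fully masked, as in the repro.
neg = torch.finfo(q.dtype).min
mask = torch.tensor([[[[neg, neg, neg],
                       [0.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0]]]], device="cuda")

# Disable the fused kernels so SDPA falls back to the math path, which treats
# fully-masked rows the way torch 2.0 did (uniform attention instead of NaN).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
print(out)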

Author

ydshieh commented Sep 29, 2023

Thank you. Turning off memory-efficient attention looks to me like it defeats the purpose of using scaled_dot_product_attention?
I guess we will have to do something to the masks manually before calling scaled_dot_product_attention.

(as you can see, left padding + a causal mask will cause shorter sequences to have all positions masked for the leading tokens)

@ydshieh ydshieh closed this as completed Sep 29, 2023

fxmarty commented Oct 3, 2023

@drisspg I am wondering if you have another suggestion apart from disabling the memory-efficient kernel path.

My current understanding is that efficient batched inference with padding support is not a by-product of #96099, contrary to what I thought before. Is that correct?

To me (for inference), it is not a big deal that some nan values appear (they appear in meaningless positions anyway); however, it is a big issue that the nan values propagate to all positions in later SDPA calls.

A solution could be to overwrite the nan values with some dummy value to stop them from propagating, but that is surely inefficient.

see https://www.diffchecker.com/4UwU6uKK/ (math vs mem-efficient)



fxmarty commented Oct 3, 2023

Solution: attend to at least one token even for padding rows. This does not influence the result, given that softmax is computed over the last dimension (see the sketch below).
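For illustration, a hedged sketch of that idea on an additive float mask (attend_to_first_token is a hypothetical helper, not the actual transformers change): for any query row that is fully masked, unmask a single key position so the fused kernel's softmax has something to normalize over.

import torch

def attend_to_first_token(attn_mask: torch.Tensor) -> torch.Tensor:
    # attn_mask: additive mask of shape (batch, heads, q_len, kv_len) where
    # masked entries hold -inf or torch.finfo(dtype).min and kept entries hold 0.
    min_value = torch.finfo(attn_mask.dtype).min
    fully_masked = (attn_mask <= min_value).all(dim=-1)  # (batch, heads, q_len)
    patched = attn_mask.clone()
    # Unmask the first key position only for rows with no attendable position.
    patched[..., 0] = torch.where(fully_masked, torch.zeros_like(patched[..., 0]), patched[..., 0])
    return patched

A fully-masked row then attends to key 0 only; its output is still meaningless, but it is finite, so it no longer poisons later SDPA calls.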

Author

ydshieh commented Oct 4, 2023

attend to at least one token even for padding rows

This could fix the issue, and yes, we don't care about the output at those padding positions.

But in terms of testing/debugging, it would be difficult if we don't keep the behavior of torch 2.0 and earlier.
i.e. if the attention was previously distributed evenly among all entries, it's better to do the same in our custom manipulation too. Otherwise, when we compare the (forward) outputs given by torch 2.0 and torch 2.1 and see that they differ, users or developers need to remember that those differences are at the padding positions and are OK.


fxmarty commented Oct 4, 2023

What I did in huggingface/transformers#26572 is attend to all tokens equally (basically having [0, 0, 0, 0, 0] instead of [-inf, -inf, -inf, -inf, -inf] on padding rows); maybe that works with regard to keeping an even distribution of the attention among all entries for padding rows (see the sketch below).
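A rough sketch of that variant (again a hypothetical helper, not the code from huggingface/transformers#26572): replace fully-masked rows of an additive mask with all zeros, so the attention on those rows is spread evenly over every key, matching the even distribution discussed above.

import torch

def unmask_padding_rows(attn_mask: torch.Tensor) -> torch.Tensor:
    # attn_mask: additive mask of shape (batch, heads, q_len, kv_len).
    min_value = torch.finfo(attn_mask.dtype).min
    # Rows where every entry is masked (-inf or finfo.min).
    fully_masked = (attn_mask <= min_value).all(dim=-1, keepdim=True)
    # Turn those rows into all zeros -> the softmax becomes uniform over the keys.
    return torch.where(fully_masked, torch.zeros_like(attn_mask), attn_mask)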

Author

ydshieh commented Oct 4, 2023

OK, very nice! Could you check that they give the same outputs with this change (with SDPA on GPU + torch 2.1) and when running without SDPA (or with SDPA on CPU), for a sequence with nothing to attend to?


fxmarty commented Oct 4, 2023

Will do!


fxmarty commented Oct 12, 2023

Interestingly, this issue happens only in fp32 and not in fp16; maybe -65504 (torch.finfo(torch.float16).min) is not a small enough value, as the softmax may be computed in fp32. A quick way to check is sketched below.
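A self-contained sketch to probe that observation (assuming a CUDA device; the shapes mirror the repro above, and the expected outputs are taken from the reports in this thread):

import torch
from torch.nn import functional as F

for dtype in (torch.float32, torch.float16):
    q = torch.randn(1, 1, 3, 4, device="cuda", dtype=dtype)
    k = torch.randn(1, 1, 3, 4, device="cuda", dtype=dtype)
    v = torch.randn(1, 1, 3, 4, device="cuda", dtype=dtype)
    neg = torch.finfo(dtype).min  # about -3.4e38 for fp32, -65504 for fp16
    mask = torch.zeros(1, 1, 3, 3, device="cuda", dtype=dtype)
    mask[..., 0, :] = neg  # first query row fully masked
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    # Reportedly: NaN row for fp32, finite values for fp16.
    print(dtype, out[0, 0, 0])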


FarzanT commented May 9, 2024

Solution: attend to at least a token even for padding. This does not influence the result given that softmax is computed on the last dimension

Hi @fxmarty, @drisspg, I know I'm kind of late to the party, but using an attn_mask that attempts to ignore certain rows entirely still leads to NaNs that propagate through all layers of the neural network. This is the case whether FlashAttention v1 or v2 is used in PyTorch >= 2.1.

My case involves a cross-attention layer where the query and keys both have padded elements. I don't want any of the padded elements to be attended to.

Setting the first column of the attn_mask to False using attn_mask[:, :, 0] = False, as @fxmarty suggested, is one way of preventing the problem. This feels quite hacky, and I'm not sure it's the right solution, as the model now attends to the first element of the key even for padded elements in the query (?).

I think we should establish a guideline for this case now, so that people like me don't get stuck for a day trying to resolve the issue.

Thank you for your time!


fxmarty commented Jun 10, 2024

Hi @FarzanT, I am only seeing your message now. You should probably open another issue with a repro/example.

In transformers, we use https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/modeling_attn_mask_utils.py#L189, which unmasks rows that are entirely padding in the source dimension of the (batch_size, source_length, target_length) mask, effectively dealing with padding in the source dimension.
