scaled_dot_product_attention behaves differently between v2.0 and v2.1 #110213
Comments
cc @drisspg |
So the reason for this, like you said, is that an entire row is masked out, similar to this issue: #103749. There was a change from 2.0 to 2.1, though: we can now run fused attention kernels with arbitrary attention masks. This means that your code was likely being run on the math path before and is now running with a fused kernel. The fused attention kernels use the iterative softmax algorithm, which for a row of all large negative values produces all NaNs as output, while regular softmax in that case evenly distributes the attention among all entries. It would be very costly to check at runtime whether an entire row is masked; I could see this being a check in debug mode, however. As a workaround, I would try running your code with memory-efficient attention turned off via the context manager: pytorch/torch/backends/cuda/__init__.py Line 275 in 81da6db
|
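A minimal sketch of that workaround (my illustration, not code from the thread; the shapes and mask values are made up), assuming torch 2.1's torch.backends.cuda.sdp_kernel context manager:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8, device="cuda")
k = torch.randn(1, 2, 4, 8, device="cuda")
v = torch.randn(1, 2, 4, 8, device="cuda")

# Additive mask whose first query row is entirely masked out (no key may be
# attended to), as happens with left padding + a causal mask.
attn_mask = torch.zeros(1, 1, 4, 4, device="cuda")
attn_mask[..., 0, :] = torch.finfo(q.dtype).min

# Disable the flash and memory-efficient kernels so SDPA falls back to the
# math path, where a fully masked row yields evenly spread attention
# instead of nan.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```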
Thank you. Turning off memory-efficient attention looks to me like it defeats the purpose of using it (as you can see, left padding + a causal mask will cause shorter sequences to have all places masked for their leading tokens). |
@drisspg I am wondering whether you have another suggestion apart from disabling the memory-efficient kernel path. My current understanding is that efficient batched inference with padding support is not a by-product of #96099, contrary to what I thought before. Is that correct? To me (for inference), it is not a big deal that some nan values appear at the padding positions. A solution could be to override the nan with some dummy value to avoid it propagating; see https://www.diffchecker.com/4UwU6uKK/ (math vs mem-efficient) |
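A rough sketch of that override idea (my own illustration, not from the thread), using torch.nan_to_num on the attention output:

```python
import torch

# Stand-in for the SDPA output of a batch in which one query row was fully
# masked: on the fused GPU kernels that row comes back as nan.
attn_out = torch.randn(1, 2, 4, 8)
attn_out[:, :, 0, :] = float("nan")  # the fully padded row

# Its values are never used for real tokens, so overwrite them with a finite
# dummy value to stop the nan from propagating through later layers.
attn_out = torch.nan_to_num(attn_out, nan=0.0)
```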
Solution: attend to at least one token even for padding rows. This does not influence the result for the other rows, given that softmax is computed along the last dimension. |
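A sketch of that fix (my own, assuming a boolean mask where True means "may attend"): for every query row that is fully masked, unmask a single position before calling SDPA.

```python
import torch
import torch.nn.functional as F

def ensure_one_attendable(mask: torch.Tensor) -> torch.Tensor:
    """Unmask key position 0 of any query row that masks out every key.

    mask: bool tensor broadcastable to (..., L_q, L_k), True = may attend.
    """
    fully_masked = ~mask.any(dim=-1, keepdim=True)  # (..., L_q, 1)
    first_key = torch.zeros_like(mask)
    first_key[..., :1] = True                       # key position 0 only
    return mask | (fully_masked & first_key)

q = torch.randn(1, 2, 4, 8, device="cuda")
k = torch.randn(1, 2, 4, 8, device="cuda")
v = torch.randn(1, 2, 4, 8, device="cuda")
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool, device="cuda")
mask[..., 0, :] = False  # a fully padded query row
out = F.scaled_dot_product_attention(q, k, v, attn_mask=ensure_one_attendable(mask))
```

The output at those padding rows is meaningless either way; the point is only to keep the fused kernels from returning nan.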
This could fix the issue, and yes, we don't care about the output at those padding places. But in terms of testing/debugging, it would be difficult if we don't keep the behavior from torch 2.0 and before. |
What I did in huggingface/transformers#26572 is attend to all tokens equally on fully masked rows (basically zeroing those rows of the mask, so that softmax distributes the attention evenly). |
OK, very nice! Could you check that this gives the same outputs (with SDPA on GPU + torch 2.1) as running without SDPA (or with SDPA on CPU) for a sequence with no attention at all? |
Will do! |
Interestingly, this issue happens only in fp32 but not in fp16 - maybe |
Hi @fxmarty, @drisspg, I know I'm kinda late to the party, but I'm running into the same problem with an attention mask that fully masks some rows. My case involves a cross-attention layer where the query and the keys both have padded elements. I don't want any of the padded elements to be attended to, so unmasking the first column of the mask is not a good fit for my case. I think we should establish a guideline for this case now, so that people like me don't get stuck for a day trying to resolve the issue. Thank you for your time! |
Hi @FarzanT, I am only seeing your message now. You should probably open another issue with a repro/example. In transformers, we use https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/modeling_attn_mask_utils.py#L189, which unmasks fully padded rows in the source dimension of the (batch_size, source_length, target_length) mask, effectively dealing with padding in the source dimension. |
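For reference, a rough sketch of that row-unmasking idea (my paraphrase of the approach, not the actual transformers implementation), for an additive float mask where masked positions hold the dtype minimum:

```python
import torch

def unmask_fully_masked_rows(additive_mask: torch.Tensor) -> torch.Tensor:
    """Reset query rows that mask out every position to 0 (attend everywhere).

    additive_mask: float mask of shape (batch, 1, src_len, tgt_len) with
    torch.finfo(dtype).min at masked positions and 0 elsewhere. The output at
    fully padded rows is meaningless either way, but zeroing those rows keeps
    the fused SDPA kernels from producing nan.
    """
    min_value = torch.finfo(additive_mask.dtype).min
    fully_masked = (additive_mask == min_value).all(dim=-1, keepdim=True)
    return additive_mask.masked_fill(fully_masked, 0.0)
```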
🐛 Describe the bug

With torch v2.1, scaled_dot_product_attention on GPU gives nan when a sequence's attention mask holds only large negative values (e.g. torch.finfo(q.dtype).min, used to mean that no place may be attended to at all). On CPU, it does not give nan.

With torch v2.0, it gives no nan on both CPU and GPU, and those values are the same as the ones given by v2.1 + CPU.

I understand it doesn't really make sense for a sequence to have no place to attend to at all. However, I am wondering whether this nan value in torch v2.1 is intentional or unexpected. It causes issues in the falcon implementation in transformers when left padding is used.

Reproduce

(running with torch v2.1)
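A minimal sketch of a reproduction (mine, not the snippet from the original report; the shapes, sizes, and mask construction are assumptions) matching the description above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda"  # the nan only appears on GPU with torch 2.1

q = torch.randn(1, 1, 3, 4, device=device)
k = torch.randn(1, 1, 3, 4, device=device)
v = torch.randn(1, 1, 3, 4, device=device)

# Additive mask: the first query position may attend to nothing at all
# (every entry is the most negative representable value), as happens with
# left padding + a causal mask for a short sequence.
attn_mask = torch.zeros(1, 1, 3, 3, device=device)
attn_mask[..., 0, :] = torch.finfo(q.dtype).min

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out)  # first row is nan on torch 2.1 + GPU; finite on CPU or on torch 2.0
```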
Outputs

With torch v2.0 (CPU and GPU) or torch v2.1 (CPU): no nan. With torch v2.1 (GPU): nan.
Versions
Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.8.16 (default, Jun 12 2023, 21:00:42) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Ti Laptop GPU
Nvidia driver version: 517.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture=9
CurrentClockSpeed=2400
DeviceID=CPU0
Family=198
L2CacheSize=11776
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2400
Name=12th Gen Intel(R) Core(TM) i7-12800H
ProcessorType=3
Revision=
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.1.0+cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchvision==0.16.0+cu118
[conda] numpy 1.24.4 pypi_0 pypi
[conda] torch 2.1.0+cu118 pypi_0 pypi
[conda] torchaudio 2.1.0+cu118 pypi_0 pypi
[conda] torchvision 0.16.0+cu118 pypi_0 pypi