MultiheadAttention is_causal=True is ignored if need_weights=True #99282

Closed
lendle opened this issue Apr 17, 2023 · 3 comments
Labels
oncall: transformer/mha Issues related to Transformers and MultiheadAttention


@lendle

lendle commented Apr 17, 2023

🐛 Describe the bug

When need_weights=True, is_causal is ignored in MultiheadAttention.forward, and the result is returned without causal masking.

import torch
import torch.nn as nn

batch_size = 4
seq_len = 3
embedding_dim = 8
num_heads = 2

mha = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embedding_dim, batch_first=True)
x = torch.randn(batch_size, seq_len, embedding_dim)

mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
    no_mask = mha(x, x, x, need_weights=False)[0]
    with_attn_mask = mha(x, x, x, need_weights=True, attn_mask=mask)[0]
    with_is_causal_need_weights = mha(x, x, x, need_weights=True, is_causal=True)[0]
    with_is_causal_no_need_weights = mha(x, x, x, need_weights=False, is_causal=True)[0]

# succeeds
assert with_attn_mask.allclose(with_is_causal_no_need_weights)

# both should succeed but fail
assert with_attn_mask.allclose(with_is_causal_need_weights), "is_causal should match regardless of 'need_weights'"
assert not no_mask.allclose(with_is_causal_need_weights), "no mask should NOT match is_causal=True, need_weights=True"
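A minimal way to confirm whether causal masking was actually applied is to pass an explicit mask and inspect the returned weights. This is a sketch assuming PyTorch 2.0 or later and the default dropout=0.0: if the mask took effect, every attention weight above the diagonal must be exactly zero, and the need_weights=True output must agree with the need_weights=False output for the same mask. The variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, embedding_dim, num_heads = 4, 3, 8, 2

mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
x = torch.randn(batch_size, seq_len, embedding_dim)

# Float mask: 0.0 on and below the diagonal, -inf above it.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

# Same explicit mask on both code paths; nothing relies on is_causal here.
out_weights_path, weights = mha(x, x, x, attn_mask=causal_mask, need_weights=True)
out_no_weights, _ = mha(x, x, x, attn_mask=causal_mask, need_weights=False)

# Weights are averaged over heads by default: shape (batch_size, seq_len, seq_len).
above_diag = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
leaked = weights[:, above_diag].abs().max().item()
print("max weight above diagonal:", leaked)
print("paths agree:", torch.allclose(out_weights_path, out_no_weights, atol=1e-5))
```

Running the same check with is_causal=True and no attn_mask on 2.0.0 would show nonzero weights above the diagonal, which is the behavior reported here.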
    

Versions

PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4
Libc version: glibc-2.28

Python version: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 11.3.109
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: Tesla T4
GPU 1: Tesla T4

Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
Address sizes:       46 bits physical, 48 bits virtual
CPU(s):              4
On-line CPU(s) list: 0-3
Thread(s) per core:  2
Core(s) per socket:  2
Socket(s):           1
NUMA node(s):        1
Vendor ID:           GenuineIntel
CPU family:          6
Model:               63
Model name:          Intel(R) Xeon(R) CPU @ 2.30GHz
Stepping:            0
CPU MHz:             2299.998
BogoMIPS:            4599.99
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            256K
L3 cache:            46080K
NUMA node0 CPU(s):   0-3
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.0.1
[pip3] torch==2.0.0
[pip3] torcharrow==0.1.0
[pip3] torchdata==0.6.0+a6c4904
[pip3] torchmetrics==0.11.4
[conda] dlenv-pytorch-1-13-gpu    1.0.20230310     py37h003b471_0    file:///tmp/conda-pkgs
[conda] numpy                     1.21.6                   pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] torch-xla                 1.13                     pypi_0    pypi
[conda] torchvision               0.14.1                   pypi_0    pypi

cc @jbschlosser @bhosmer @cpuhrsch @erichan1

ngimel added the oncall: transformer/mha label Apr 17, 2023
@mikekgfb
Contributor

mikekgfb commented Apr 20, 2023

#97214 clarifies that is_causal is a hint and has no semantic effect on its own. If is_causal is supplied, nn.MHA and F.MHA may optionally use it in lieu of attn_mask.

Please verify against the documentation in the nightlies and in the upcoming 2.0.1 bug-fix release. Starting with 2.0.1, this case should be flagged as an error as follows:

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )
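Under these hint semantics, is_causal is always passed together with the mask it describes. A sketch, assuming PyTorch 2.0.1 or later: supplying both should give the same output whether need_weights is True (the mask is used) or False (the mask may be replaced by SDPA's causal flag), and both should match a call that passes the mask alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, embedding_dim, num_heads = 2, 5, 8, 2

mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
x = torch.randn(batch_size, seq_len, embedding_dim)
mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

# Reference: the mask alone, no hint.
ref, _ = mha(x, x, x, attn_mask=mask, need_weights=False)

# The hint is supplied alongside the mask it describes, never instead of it.
hinted_weights, _ = mha(x, x, x, attn_mask=mask, is_causal=True, need_weights=True)
hinted_no_weights, _ = mha(x, x, x, attn_mask=mask, is_causal=True, need_weights=False)

print(torch.allclose(ref, hinted_weights, atol=1e-5))
print(torch.allclose(ref, hinted_no_weights, atol=1e-5))
```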

@zmurez

zmurez commented May 17, 2023

This seems strange given the lines right below:

    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None

We require an attn_mask to be given but then pass along None instead, so in this case there is no reason to create the attn_mask. To avoid an unnecessary memory allocation, one can pass attn_mask=torch.empty(1,1), which passes all the asserts but is then thrown away here. Why not allow attn_mask==None together with is_causal==True? The attn_mask=torch.empty(1,1) trick seems prone to creating unintended issues. Do you recommend simply allocating the correct triu matrix and assuming the added time/memory is negligible?
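The mask handling discussed here can be modeled as a small decision function. This is a hypothetical simplification for illustration only (`mask_source` is not a PyTorch API): it combines the 2.0.1 check quoted in the previous comment with the lines quoted above. It shows that a mask passed with is_causal=True is discarded only when there is no key padding mask and need_weights=False; on every other path the mask, including a placeholder such as torch.empty(1,1), would be treated as the real mask.

```python
def mask_source(is_causal: bool, attn_mask_given: bool,
                key_padding_mask_given: bool, need_weights: bool) -> str:
    """Hypothetical model of the MHA mask handling quoted in this thread.

    Returns a label for what actually constrains attention.
    """
    # 2.0.1+ check: the is_causal hint requires a real attn_mask.
    if is_causal and not attn_mask_given:
        raise RuntimeError("Need attn_mask if specifying the is_causal hint.")
    # No key padding mask and no weights requested: the mask is dropped and
    # SDPA receives the is_causal flag instead.
    if is_causal and not key_padding_mask_given and not need_weights:
        return "is_causal flag (attn_mask discarded)"
    # Every other path uses the caller's mask as given
    # (merged with the key padding mask, if one is present).
    if attn_mask_given:
        return "attn_mask"
    return "no attn_mask"


for need_weights in (False, True):
    print(need_weights, mask_source(True, True, False, need_weights))
```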

@mikekgfb
Contributor

mikekgfb commented May 17, 2023

To be clear, this code is in the general MHA implementation that has to support many use cases. For MHA, is_causal is a hint, and the attention mask is necessary because there are many legacy code paths that require an attention mask. So, in the nn.MHA and nn.Transformer* module API, is_causal was added as a hint to the existing attention mechanism, and the instantiated attention mask is required.
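Since the module API needs an instantiated mask, its cost is worth quantifying: a float32 causal mask for sequence length n takes 4·n² bytes (16 MiB at n = 2048). One option, sketched here under that assumption and not as a maintainer recommendation, is to build the mask once at a maximum length and slice it per call. `CausalSelfAttention` and `max_seq_len` are hypothetical names, not part of PyTorch; the sketch assumes PyTorch 2.0.1 or later so that passing both the mask and the is_causal hint is valid.

```python
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Hypothetical wrapper: allocates the causal mask once and reuses it."""

    def __init__(self, embed_dim: int, num_heads: int, max_seq_len: int):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # Built once; persistent=False keeps it out of the state_dict.
        mask = nn.Transformer.generate_square_subsequent_mask(max_seq_len)
        self.register_buffer("causal_mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, need_weights: bool = False):
        seq_len = x.size(1)
        # The top-left block of a larger causal mask is the causal mask for seq_len.
        mask = self.causal_mask[:seq_len, :seq_len]
        # Real mask plus the hint, per the hint semantics described above.
        return self.mha(x, x, x, attn_mask=mask, is_causal=True, need_weights=need_weights)


torch.manual_seed(0)
layer = CausalSelfAttention(embed_dim=8, num_heads=2, max_seq_len=16)
x = torch.randn(2, 5, 8)
out, weights = layer(x, need_weights=True)
out_no_weights, _ = layer(x)
print(layer.causal_mask.numel() * layer.causal_mask.element_size())  # 1024 bytes for 16x16 float32
```

The mask is allocated once per module rather than per forward call, so the per-step cost is only a slice view.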

However, we recommend that many users call the new SDPA operator (torch.nn.functional.scaled_dot_product_attention) directly. Because it has fewer execution scenarios to support, SDPA implements is_causal as an alternative to the attention mask. For an example, see https://pytorch.org/blog/accelerating-large-language-models/, which shows how we enabled nanoGPT to use SDPA directly and without a mask.
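A minimal sketch of the direct route, assuming PyTorch 2.0 or later: torch.nn.functional.scaled_dot_product_attention accepts is_causal=True with no mask tensor at all. Queries, keys, and values are laid out as (batch, heads, seq_len, head_dim); the explicit boolean mask and the hand-written softmax are included only to check that the flag computes the same thing.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# No mask tensor is allocated: causality is requested with a flag.
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Check 1: explicit boolean mask, where True means "may attend".
allowed = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Check 2: plain math with the default 1/sqrt(head_dim) scaling.
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
scores = scores.masked_fill(~allowed, float("-inf"))
out_manual = scores.softmax(dim=-1) @ v

print(torch.allclose(out_flag, out_mask, atol=1e-5))
print(torch.allclose(out_flag, out_manual, atol=1e-5))
```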
