
Results of MultiheadAttention depend on the query length #33841

Closed
amorgun opened this issue Feb 26, 2020 · 9 comments
Assignees: zhangguanheng66
Labels: module: nn, oncall: transformer/mha, triaged

Comments

amorgun commented Feb 26, 2020

🐛 Bug

MultiheadAttention should yield the same result if I split the query into several chunks and then concatenate the chunk results back together. Currently this does not hold for some chunk sizes.

To Reproduce

Steps to reproduce the behavior:

import torch
import numpy as np

np.random.seed(12)
in1 = torch.FloatTensor(np.random.rand(256**2, 1, 64))
in2 = torch.FloatTensor(np.random.rand(128, 1, 64))
model = torch.nn.MultiheadAttention(embed_dim=64, num_heads=1)
model = model.eval()
with torch.no_grad():
    out1 = model(in2, in1, in1)[0]
    out2 = model(in2, in1, in1)[0]
    print(torch.equal(out1, out2))  # prints True, just a sanity check
    out3 = torch.cat([model(in2[:64], in1, in1)[0], model(in2[64:], in1, in1)[0]])
    print(torch.equal(out1, out3))  # prints False
    out4 = torch.cat([model(in2[:100], in1, in1)[0], model(in2[100:], in1, in1)[0]])
    print(torch.equal(out1, out4))  # prints False

Expected behavior

The example script should print True three times.

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 440.33.01
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.6.1
[pip3] numpy==1.17.3
[pip3] pytorch-pretrained-bert==0.6.2
[pip3] segmentation-models-pytorch==0.1.0
[pip3] torch==1.4.0
[pip3] torchfile==0.1.0
[pip3] torchvision==0.5.0
[conda] Could not collect

Additional context

ailzhang (Contributor) commented

cc @zhangguanheng66: would you mind taking a look here?

@ailzhang added the triaged label on Feb 27, 2020
@zhangguanheng66 added the module: nn label on Feb 29, 2020
zhangguanheng66 (Contributor) commented

This looks like a floating-point difference. I'm not sure whether it comes from the nn.MultiheadAttention module or from the torch.cat call; it may be worth some investigation.

amorgun (Author) commented Mar 2, 2020

It is possible to rewrite the example script without torch.cat:

import torch
import numpy as np

np.random.seed(12)
in1 = torch.FloatTensor(np.random.rand(256**2, 1, 64))
in2 = torch.FloatTensor(np.random.rand(128, 1, 64))
model = torch.nn.MultiheadAttention(embed_dim=64, num_heads=1)
model = model.eval()
with torch.no_grad():
    out1 = model(in2, in1, in1)[0]
    print(torch.equal(out1[:64], model(in2[:64], in1, in1)[0]))  # prints False
    print(torch.equal(out1[:128], model(in2[:128], in1, in1)[0]))  # prints True

Also, I've found a bug in the original script. It does not work correctly for chunk size 64 either.
Do I understand correctly that MultiheadAttention(a[:i], b, c)[0] should always be equal to MultiheadAttention(a, b, c)[0][:i]?
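
For reference, here is how I would quantify the mismatch rather than testing for bitwise equality (a minimal sketch reusing in1, in2, model and out1 from the script above; the atol value is just an illustrative choice):

with torch.no_grad():
    prefix = model(in2[:64], in1, in1)[0]
    # Measure the size of the discrepancy instead of demanding exact equality.
    print((out1[:64] - prefix).abs().max())              # expected to be tiny, at float32 rounding level
    print(torch.allclose(out1[:64], prefix, atol=1e-5))  # expected True if only rounding noise remains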

tstumm commented Mar 2, 2020

The input shape for the attention module is (SeqLen, BatchSize, EmbeddingDim). You're slicing the sequence dimension, which does indeed influence the calculation. That's actually the correct behavior, as you are attending each token in every sequence to every other token in the same sequence. If you're slicing the sequence, then you won't be attending to the whole sequence, yielding the minor discrepancies you've been observing.

If you want to chunk at the batch dimension, you should go for the second dimension.
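
For illustration, here is a standalone sketch of batch-dimension chunking (small made-up shapes; none of these names come from the scripts above):

import torch

torch.manual_seed(0)
embed, batch = 8, 4
k = torch.rand(10, batch, embed)   # (seq_len, batch, embed) keys/values
q = torch.rand(5, batch, embed)    # (seq_len, batch, embed) queries
mha = torch.nn.MultiheadAttention(embed_dim=embed, num_heads=2).eval()
with torch.no_grad():
    full = mha(q, k, k)[0]
    # Split along dim 1 (the batch dimension), run each half, and stitch back together.
    halves = torch.cat([mha(q[:, :2], k[:, :2], k[:, :2])[0],
                        mha(q[:, 2:], k[:, 2:], k[:, 2:])[0]], dim=1)
    print((full - halves).abs().max())  # expected to be zero or at float32 rounding level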

amorgun (Author) commented Mar 2, 2020

@tstumm I agree that chunking along the batch dimension is correct. But it looks like chunking along the sequence dimension should also be possible, because queries do not attend to one another, and each result element depends only on its respective query, all keys, and all values.
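
To spell out what I mean, here is a minimal sketch of single-head scaled dot-product attention (ignoring the input/output projections of MultiheadAttention; Q, K, V and d_k are just illustrative names). Each output row uses only its own query together with all keys and values:

import torch

torch.manual_seed(0)
d_k = 8
Q = torch.rand(5, d_k)   # 5 queries
K = torch.rand(7, d_k)   # 7 keys
V = torch.rand(7, d_k)   # 7 values

# Full computation: one softmax row per query.
full = torch.softmax(Q @ K.t() / d_k ** 0.5, dim=-1) @ V

# Row-by-row computation: each output row depends only on its own query, K and V.
rows = [torch.softmax(q.unsqueeze(0) @ K.t() / d_k ** 0.5, dim=-1) @ V for q in Q]
print((full - torch.cat(rows)).abs().max())  # expected to be zero or negligible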

@amorgun changed the title from "Results of MultiheadAttention depend on the key length" to "Results of MultiheadAttention depend on the query length" on Mar 2, 2020
zhangguanheng66 (Contributor) commented

I thought about whether we should expect the behavior reported in this issue from MultiheadAttention. Since the difference is very small, I tend to say "yes". If we take a close look at the MHA module, the heads are split along the embedding dimension, but I may be missing something within MHA.
Another interesting point: the hypothesis fails for the 64 case but holds for the 128 case.
If someone is interested in this problem, an analysis with smaller numbers for seq, batch, and embedding may help gain some insight (see the sketch below).
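
A possible starting point for such an analysis (a sketch with small, arbitrarily chosen sizes, not taken from the scripts above), reporting the size of the discrepancy rather than exact equality:

import torch

torch.manual_seed(0)
embed, heads = 8, 2
kv = torch.rand(16, 1, embed)   # keys/values: seq_len 16, batch 1
q = torch.rand(6, 1, embed)     # queries: seq_len 6, batch 1
mha = torch.nn.MultiheadAttention(embed_dim=embed, num_heads=heads).eval()
with torch.no_grad():
    full = mha(q, kv, kv)[0]
    prefix = mha(q[:3], kv, kv)[0]
    print((full[:3] - prefix).abs().max())  # how far the prefix result drifts, if at all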

@zhangguanheng66 self-assigned this on Mar 2, 2020
tstumm commented Mar 3, 2020

import torch.nn.functional as F

# model and in2 are the objects from the repro script above
a = F.linear(in2, model.in_proj_weight, model.in_proj_bias)[:64]
b = F.linear(in2[:64], model.in_proj_weight, model.in_proj_bias)
print(a.dist(b))

Prints tensor(4.2136e-06), so this may be related to #34060

amorgun (Author) commented Mar 3, 2020

I tried several different values for embedding, chunk_size and num_heads.
My assumption works for some combinations, but not all of them.
My script:

import torch
import numpy as np

def apply_chunked(model, q, k, v, chunk_size):
    parts = []
    for start_idx in torch.arange(0, q.shape[0], chunk_size):
        end_idx = start_idx + chunk_size
        chunk_q = q[start_idx:end_idx]
        parts.append(model(chunk_q, k, v)[0])
    return torch.cat(parts)
    

torch.manual_seed(12)
np.random.seed(12)
for embedding in [64, 96, 128, 144]:
    print(f'embedding: {embedding}')
    in1 = torch.FloatTensor(np.random.rand(256**2, 1, embedding))
    in2 = torch.FloatTensor(np.random.rand(256, 1, embedding))
    for num_heads in [1,4,8,16]:
        for chunk_size in [16, 32, 64, 96, 128, 140]:
            model = torch.nn.MultiheadAttention(embed_dim=embedding, num_heads=num_heads)
            model = model.eval()
            with torch.no_grad():
                out_true = model(in2, in1, in1)[0]
                print(f'\tnum_heads={num_heads}\tchunk_size={chunk_size}\t'
                      f'verdict_joined={torch.equal(out_true, apply_chunked(model, in2, in1, in1, chunk_size))}\t'
                      f'verdict_prefix={torch.equal(out_true[:chunk_size], model(in2[:chunk_size], in1, in1)[0])}')
        print()

Output:

embedding: 64
	num_heads=1	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=1	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=96	verdict_joined=False	verdict_prefix=True
	num_heads=1	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=140	verdict_joined=True	verdict_prefix=True

	num_heads=4	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=64	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=96	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=140	verdict_joined=False	verdict_prefix=True

	num_heads=8	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=64	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=96	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=140	verdict_joined=False	verdict_prefix=True

	num_heads=16	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=64	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=96	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=140	verdict_joined=False	verdict_prefix=True

embedding: 96
	num_heads=1	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=1	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=96	verdict_joined=False	verdict_prefix=True
	num_heads=1	chunk_size=128	verdict_joined=True	verdict_prefix=True
	num_heads=1	chunk_size=140	verdict_joined=True	verdict_prefix=True

	num_heads=4	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=96	verdict_joined=False	verdict_prefix=True
	num_heads=4	chunk_size=128	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=140	verdict_joined=False	verdict_prefix=True

	num_heads=8	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=96	verdict_joined=False	verdict_prefix=True
	num_heads=8	chunk_size=128	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=140	verdict_joined=False	verdict_prefix=True

	num_heads=16	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=96	verdict_joined=False	verdict_prefix=True
	num_heads=16	chunk_size=128	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=140	verdict_joined=False	verdict_prefix=True

embedding: 128
	num_heads=1	chunk_size=16	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=4	chunk_size=16	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=8	chunk_size=16	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=16	chunk_size=16	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=140	verdict_joined=False	verdict_prefix=False

embedding: 144
	num_heads=1	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=1	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=1	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=4	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=4	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=4	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=8	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=8	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=8	chunk_size=140	verdict_joined=False	verdict_prefix=False

	num_heads=16	chunk_size=16	verdict_joined=True	verdict_prefix=True
	num_heads=16	chunk_size=32	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=64	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=96	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=128	verdict_joined=False	verdict_prefix=False
	num_heads=16	chunk_size=140	verdict_joined=False	verdict_prefix=False

zhangguanheng66 (Contributor) commented

Investigated by @tstumm in #34060. IMO, this is expected and not a PyTorch-specific thing: a single-precision float only gives you 6-7 digits of precision.
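
For anyone hitting this later, a tolerance-based comparison is the appropriate check here (a sketch reusing out1 and out3 from the original repro script; the tolerance is an assumption based on the float32 precision mentioned above):

print((out1 - out3).abs().max())              # expected to be at the float32 rounding level
print(torch.allclose(out1, out3, atol=1e-5))  # expected True if only rounding noise remains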

Feel free to re-open the issue if you have any follow-up questions.
