Results of MultiheadAttention depend on the query length #33841
Comments
cc: @zhangguanheng66, would you mind taking a look here?
It seems to be a floating point difference. Not sure whether it comes from the nn.MultiheadAttention module or from the torch.cat function. Might be worth some investigation.
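For what it's worth, `torch.cat` itself should be exact, since concatenating contiguous slices is a pure memory copy. A quick sanity check (a minimal sketch with an arbitrary tensor):

```python
import torch

x = torch.rand(128, 1, 64)
# Concatenating contiguous slices back together copies memory only,
# so it reconstructs the original tensor bit-for-bit.
print(torch.equal(torch.cat([x[:64], x[64:]]), x))  # True
```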
It is possible to rewrite the example script without `torch.cat`:

```python
import torch
import numpy as np

np.random.seed(12)
in1 = torch.FloatTensor(np.random.rand(256**2, 1, 64))
in2 = torch.FloatTensor(np.random.rand(128, 1, 64))

model = torch.nn.MultiheadAttention(embed_dim=64, num_heads=1)
model = model.eval()

with torch.no_grad():
    out1 = model(in2, in1, in1)[0]
    print(torch.equal(out1[:64], model(in2[:64], in1, in1)[0]))    # prints False
    print(torch.equal(out1[:128], model(in2[:128], in1, in1)[0]))  # prints True
```

Also, I've found a bug in the original script: it does not work correctly for chunk size 64 either.
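If the mismatch is pure floating-point noise, a tolerance-based comparison should pass where the bitwise one fails. A quick check continuing the snippet above (assuming default `torch.allclose` tolerances):

```python
with torch.no_grad():
    out_prefix = model(in2[:64], in1, in1)[0]
# torch.equal requires bitwise identity; torch.allclose permits tiny
# floating-point deviations (rtol=1e-05, atol=1e-08 by default).
print(torch.equal(out1[:64], out_prefix))     # False, as above
print(torch.allclose(out1[:64], out_prefix))  # expected True if it is only float noise
```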
The input shape for the attention module is `(sequence length, batch size, embedding dimension)`. If you want to chunk at the batch dimension, you should chunk along the second dimension.
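To illustrate chunking along the batch dimension (dim 1 in the default `(L, N, E)` layout), a minimal sketch with made-up sizes; note that `q`, `k`, and `v` all have to be chunked together:

```python
import torch

model = torch.nn.MultiheadAttention(embed_dim=64, num_heads=1).eval()
q = torch.rand(128, 8, 64)   # (target length, batch, embedding)
kv = torch.rand(512, 8, 64)  # (source length, batch, embedding)

with torch.no_grad():
    full = model(q, kv, kv)[0]
    # Chunk along dim 1 (batch): each sample is attended to independently,
    # so batch chunking is always mathematically safe.
    parts = [model(q[:, i:i+4], kv[:, i:i+4], kv[:, i:i+4])[0]
             for i in range(0, 8, 4)]
    chunked = torch.cat(parts, dim=1)

print(torch.allclose(full, chunked))  # expected True, up to float noise
```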
@tstumm I agree that chunking along the batch dimension is correct. But it looks like chunking along the sequence dimension should also be possible, because queries do not attend to each other: each output element depends only on its own query, all keys, and all values.
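That independence is easy to check against a hand-rolled single-head attention (a minimal sketch with made-up sizes; `attn` is a hypothetical helper, not the module internals):

```python
import math
import torch

d = 64
q = torch.rand(128, d)
k = torch.rand(1000, d)
v = torch.rand(1000, d)

def attn(q, k, v):
    # softmax runs over the key axis, independently for every query row,
    # so each output row depends only on its own query plus all of k and v.
    w = torch.softmax(q @ k.t() / math.sqrt(d), dim=-1)
    return w @ v

full = attn(q, k, v)
prefix = attn(q[:64], k, v)
# Mathematically identical; bitwise equality may still fail because the
# GEMM kernel can change with the input size, hence allclose.
print(torch.allclose(full[:64], prefix))
```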
I thought about whether we should expect the behavior in this issue from MultiheadAttention. Since the difference is very small, I tend to say "yes". If we take a close look at the MHA module, the heads are split along the embedding dimension. But I may be missing something within the MHA.
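For reference, the split works roughly like this (a simplified sketch of the reshaping; the real implementation lives in `F.multi_head_attention_forward`):

```python
import torch

L, N, E, num_heads = 10, 2, 64, 4
head_dim = E // num_heads  # 16
x = torch.rand(L, N, E)

# Heads are contiguous slices of the embedding dimension:
# (L, N, E) -> (L, N * num_heads, head_dim) -> (N * num_heads, L, head_dim)
heads = x.view(L, N * num_heads, head_dim).transpose(0, 1)
print(heads.shape)  # torch.Size([8, 10, 16])
```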
The difference shows up already in the input projection:

```python
import torch.nn.functional as F

a = F.linear(in2, model.in_proj_weight, model.in_proj_bias)[:64]
b = F.linear(in2[:64], model.in_proj_weight, model.in_proj_bias)
print(a.dist(b))
```

Prints a small nonzero distance.
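This is not specific to the attention module; a plain matmul can already behave this way, because the BLAS backend may choose a different kernel (and a different summation order) for different input sizes. A sketch; the printed distance is backend-dependent and may be zero on some machines:

```python
import torch

x = torch.rand(128, 64)
w = torch.rand(64, 64)

a = (x @ w)[:64]   # multiply all 128 rows, then slice
b = x[:64] @ w     # slice first, then multiply only 64 rows
# Different GEMM code paths may accumulate in a different order,
# producing tiny bitwise differences even though the math is identical.
print((a - b).abs().max())
```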
I tried several different values for the embedding size, number of heads, and chunk size:

```python
import torch
import numpy as np

def apply_chunked(model, q, k, v, chunk_size):
    # Run the model on successive query chunks and stitch the results back.
    parts = []
    for start_idx in torch.arange(0, q.shape[0], chunk_size):
        end_idx = start_idx + chunk_size
        chunk_q = q[start_idx:end_idx]
        parts.append(model(chunk_q, k, v)[0])
    return torch.cat(parts)

torch.manual_seed(12)
np.random.seed(12)

for embedding in [64, 96, 128, 144]:
    print(f'embedding: {embedding}')
    in1 = torch.FloatTensor(np.random.rand(256**2, 1, embedding))
    in2 = torch.FloatTensor(np.random.rand(256, 1, embedding))
    for num_heads in [1, 4, 8, 16]:
        for chunk_size in [16, 32, 64, 96, 128, 140]:
            model = torch.nn.MultiheadAttention(embed_dim=embedding, num_heads=num_heads)
            model = model.eval()
            with torch.no_grad():
                out_true = model(in2, in1, in1)[0]
                print(f'\tnum_heads={num_heads}\tchunk_size={chunk_size}\t'
                      f'verdict_joined={torch.equal(out_true, apply_chunked(model, in2, in1, in1, chunk_size))}\t'
                      f'verdict_prefix={torch.equal(out_true[:chunk_size], model(in2[:chunk_size], in1, in1)[0])}')
    print()
```

Output:
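To quantify the mismatch rather than just flag it, the same helper can report the largest absolute deviation; a minimal sketch reusing `model`, `in1`, `in2`, and `out_true` from the last loop iteration:

```python
with torch.no_grad():
    out_chunked = apply_chunked(model, in2, in1, in1, 64)
    # If the discrepancy is float noise, the worst-case deviation should be
    # on the order of float32 epsilon rather than anything structural.
    print((out_true - out_chunked).abs().max())
```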
🐛 Bug
MultiheadAttention should yield the same result if I split the query into several chunks and then concatenate the chunk results back together. Currently this does not work for some chunk sizes.
To Reproduce
Steps to reproduce the behavior:
Expected behavior
The example script should print True three times.
Environment
Additional context