In [1]:
import torch
import torch.nn as nn

In [2]:
x = torch.tensor([[[-0.1, 0.1,  0.3]]])
layer = nn.MultiheadAttention(embed_dim=3, num_heads=1, bias=False, batch_first=True)

layer

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=False)
)

In [4]:
output_tensor, attn_output_weights = layer(x, x, x)

# Print the shape of the output tensor
print(output_tensor)

tensor([[[-0.0255, -0.0356,  0.0286]]], grad_fn=<TransposeBackward0>)


In [9]:
import torch.nn.functional as F
x = torch.tensor([[[-0.1, 0.1, 0.3]]])

q = torch.tensor(  [[-0.3561,  0.3674, -0.5108],
                    [ 0.5146, -0.4764, -0.1490],
                    [ 0.5072, -0.2932, -0.5633]]).float()
k = torch.tensor(  [[-0.4932, -0.4468,  0.0736],
                    [-0.6879, -0.4689, -0.1026],
                    [ 0.1847,  0.1858,  0.4469]]).float()
v = torch.tensor(  [[-0.4110, -0.4083, -0.5549],
                    [ 0.3921, -0.0746, -0.1336],
                    [-0.6555, -0.3418, -0.2980]]).float()
o = torch.tensor([[-0.3601,  0.2771, -0.0573],
                  [-0.0896,  0.0567, -0.2882],
                  [ 0.3200,  0.1517,  0.0580]]).float()
query = x@q.T
key = x@k.T
value = x@v.T

# Define the model parameters
embed_dim = 3
num_heads = 1
head_dim = embed_dim // num_heads

# Reshape query, key, and value to have shape (batch_size, num_heads, seq_len, head_dim)
query = query.view(num_heads, -1, head_dim)
key = key.view(num_heads, -1, head_dim)
value = value.view(num_heads, -1, head_dim)

query

# Step 3: Compute scaled dot-product attention
attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
attention_weights = F.softmax(attention_scores, dim=-1)
context = torch.matmul(attention_weights, value)


print('attention_scores', attention_scores)
print('attention_weights', attention_weights)
print('context', context)


# Step 4: Concatenate and project back
context = context.view(-1, embed_dim)
output = context@o.T

# Print the shape of the output tensor
print(output)

attention_scores tensor([[[-0.0198]]])
attention_weights tensor([[[1.]]])
context tensor([[[-0.1662, -0.0868, -0.0580]]])
tensor([[ 0.0391,  0.0267, -0.0697]])
