# Shape study notebook

Here I take a close look, step by step, at how tensor shapes change inside the multi-head attention mechanism, following the scaled dot-product attention formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
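As a reference point, the formula can be written directly as a small function. This is a minimal sketch. It assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available for comparison. The function name `scaled_dot_product_attention_manual` is hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention_manual(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V over the last two dims.

    Works for any leading batch dims, e.g. (batch, num_heads, seq_len, d_k).
    """
    d_k = k.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., seq_q, seq_k)
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v                                  # (..., seq_q, d_v)


torch.manual_seed(0)
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)

manual = scaled_dot_product_attention_manual(q, k, v)
# Built-in version (assumes PyTorch >= 2.0); its default scale is 1/sqrt(q.size(-1)).
builtin = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)  # torch.Size([1, 8, 4, 64])
print(torch.allclose(manual, builtin, atol=1e-5))
```

Since the built-in function uses no mask and no dropout by default, both should agree up to floating-point error.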



In [1]:
import torch
import math

I assume that the `q`, `k`, and `v` tensors have already been split into heads and transposed, so each has shape `(batch, num_heads, seq_len, d_k)`, here `(1, 8, 4, 64)`. For details, see the [multihead.py](./multihead.py) file in the same directory as this notebook.
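The split is not shown in this notebook. A common way to do it is sketched below. This is an assumption about what such code typically looks like, not a copy of `multihead.py`. A `(batch, seq_len, d_model)` input is projected, reshaped to `(batch, seq_len, num_heads, d_k)`, and then the head axis is moved to dim 1. The values `d_model = 512` and `num_heads = 8` are chosen to match the shapes used below. `w_q` is a hypothetical projection layer.

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, num_heads = 1, 4, 512, 8
d_k = d_model // num_heads  # 64 dims per head

x = torch.randn(batch_size, seq_len, d_model)  # (1, 4, 512)
w_q = nn.Linear(d_model, d_model)              # hypothetical query projection

q = w_q(x)                                       # (1, 4, 512)
q = q.view(batch_size, seq_len, num_heads, d_k)  # (1, 4, 8, 64): split last dim into heads
q = q.transpose(1, 2)                            # (1, 8, 4, 64): heads become a batch-like dim

print(q.shape)  # torch.Size([1, 8, 4, 64])
```

Moving heads to dim 1 lets every later matmul treat `(batch, head)` as batch dimensions and operate on `(seq_len, d_k)` matrices.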

In [2]:
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)
d_k = k.shape[-1]

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 64])


In [3]:
# (1, 8, 4, 64) @ (1, 8, 64, 4) -> (1, 8, 4, 4): one 4x4 score matrix per head
qk_T = q @ k.transpose(-2, -1)
qk_T.shape


torch.Size([1, 8, 4, 4])
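The `@` operator treats every dimension except the last two as batch dimensions. For each of the 1 × 8 (batch, head) pairs, it therefore multiplies a `(4, 64)` query matrix by a `(64, 4)` transposed key matrix. Writing the same product with `torch.einsum` makes the contracted index explicit. This is only an equivalent formulation for illustration.

```python
import torch

q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)

# b = batch, h = head, i = query position, j = key position, d = head dim.
# d appears in both inputs but not in the output, so it is summed over.
scores_einsum = torch.einsum("bhid,bhjd->bhij", q, k)
scores_matmul = q @ k.transpose(-2, -1)

print(scores_einsum.shape)  # torch.Size([1, 8, 4, 4])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))
```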

In [4]:
attn_scores = qk_T / math.sqrt(d_k)
print(f"attn_scores.shape: {attn_scores.shape}")

attn_weights = torch.softmax(attn_scores, dim=-1)
print(f"attn_weights.shape: {attn_weights.shape}")

context_vec = attn_weights @ v
print(f"context_vec.shape: {context_vec.shape}")


attn_scores.shape: torch.Size([1, 8, 4, 4])
attn_weights.shape: torch.Size([1, 8, 4, 4])
context_vec.shape: torch.Size([1, 8, 4, 64])
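Two properties of this cell can be checked numerically. First, `softmax(..., dim=-1)` normalizes over the key axis, so every row of `attn_weights` sums to 1. Second, dividing by `sqrt(d_k)` keeps the score variance near 1. For independent standard-normal entries, a dot product of length `d_k` has variance `d_k`. The sample size and seed below are arbitrary, so the printed variances are approximate.

```python
import math

import torch

torch.manual_seed(0)
d_k = 64
q = torch.randn(1, 8, 4, d_k)
k = torch.randn(1, 8, 4, d_k)

attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
attn_weights = torch.softmax(attn_scores, dim=-1)

# Each query position's weights over the 4 key positions sum to 1.
row_sums = attn_weights.sum(dim=-1)  # shape (1, 8, 4)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))

# Variance of raw vs. scaled dot products over many random vector pairs.
a = torch.randn(10_000, d_k)
b = torch.randn(10_000, d_k)
raw = (a * b).sum(dim=-1)
scaled_var = (raw / math.sqrt(d_k)).var().item()
print(raw.var().item())  # roughly d_k = 64
print(scaled_var)        # roughly 1
```

Without the scaling, scores with variance around 64 would push the softmax toward near one-hot rows, which tends to shrink its gradients.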


In [11]:
"""
and now this context vector is concatenated back together to form the final output.
"""

intermediate = context_vec.transpose(1, 2)
print(f"intermediate.shape: {intermediate.shape}")
final_output = intermediate.contiguous().view(1, -1, 512)
print(f"final_output.shape: {final_output.shape}")

intermediate.shape: torch.Size([1, 4, 8, 64])
final_output.shape: torch.Size([1, 4, 512])
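`view` only regroups memory and `transpose` only reorders dimensions. The merge is therefore the exact inverse of the head split, which gives a cheap sanity check. The sketch below also applies an output projection `w_o` (hypothetical here). Standard multi-head attention follows the concatenation with a learned linear layer W^O. Whether `multihead.py` does the same is assumed, not shown in this notebook.

```python
import torch
import torch.nn as nn

batch_size, num_heads, seq_len, d_k = 1, 8, 4, 64
d_model = num_heads * d_k  # 512

# Stand-in for context_vec: one (seq_len, d_k) block per head.
context_vec = torch.randn(batch_size, num_heads, seq_len, d_k)

# Merge heads: (1, 8, 4, 64) -> (1, 4, 8, 64) -> (1, 4, 512)
intermediate = context_vec.transpose(1, 2)
print(intermediate.is_contiguous())  # False: transpose only swaps strides
merged = intermediate.contiguous().view(batch_size, seq_len, d_model)

# Split again and check that the original per-head tensor is recovered exactly.
split_back = merged.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(torch.equal(split_back, context_vec))  # True

# Hypothetical output projection W^O, as in standard multi-head attention.
w_o = nn.Linear(d_model, d_model)
output = w_o(merged)
print(output.shape)  # torch.Size([1, 4, 512])
```

Calling `.view` directly on the non-contiguous `intermediate` raises a `RuntimeError`. `.reshape` is an alternative that copies only when needed.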
