# Shape study notebook

This notebook traces, step by step, how tensor shapes change through the multi-head attention mechanism, following the formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
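
As a rough guide to the shape bookkeeping in the cells that follow, the formula can be annotated with tensor shapes. The labels $B$ (batch size), $h$ (number of heads), and $T$ (sequence length) are my own shorthand, and the sketch assumes the value dimension equals $d_k$, as it does for the tensors in this notebook:

$$
\underbrace{Q}_{(B,\,h,\,T,\,d_k)}\;\underbrace{K^T}_{(B,\,h,\,d_k,\,T)} \;\to\; (B,\,h,\,T,\,T),
\qquad
\underbrace{\text{softmax}(\cdot)}_{(B,\,h,\,T,\,T)}\;\underbrace{V}_{(B,\,h,\,T,\,d_k)} \;\to\; (B,\,h,\,T,\,d_k)
$$

Scaling by $\sqrt{d_k}$ and applying softmax are both shape-preserving, so only the two matrix products change the shape.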



In [1]:
import torch
import math

I assume that the q, k, and v tensors have already been split into heads and transposed, so each has shape (batch, heads, seq_len, d_k). For details, see the [multihead.py](./multihead.py) file in the same directory as this notebook.

In [2]:
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)
d_k = k.shape[-1]

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 64])
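
For reference, shapes like the ones above could come from a split-and-transpose step such as the following. This is only a minimal sketch under my own assumptions (a model width of 512 and 8 heads, with the variable names invented here); it is not necessarily how multihead.py does it.

```python
import torch

# Assumed sizes, chosen to match the shapes printed in this notebook.
batch, seq_len, d_model, num_heads = 1, 4, 512, 8
head_dim = d_model // num_heads  # 512 // 8 = 64

x = torch.randn(batch, seq_len, d_model)                # (1, 4, 512)

# Split the last dimension into (num_heads, head_dim).
x_split = x.view(batch, seq_len, num_heads, head_dim)   # (1, 4, 8, 64)

# Move the head axis in front of the sequence axis.
x_heads = x_split.transpose(1, 2)                       # (1, 8, 4, 64)

print(x_heads.shape)
```

With this layout, head `i` at token `t` holds the slice `x[0, t, i * head_dim:(i + 1) * head_dim]` of the original vector.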


In [3]:
qk_T = q @ k.transpose(-2, -1)
qk_T.shape

torch.Size([1, 8, 4, 4])
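
The batched product above contracts the last (d_k) axis and leaves one score per (query position, key position) pair in every head. One way to make that explicit is to write the same product with `torch.einsum`; the index letters below are just labels I picked for this sketch.

```python
import torch

q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)

# Same computation as in the cell above.
scores_matmul = q @ k.transpose(-2, -1)

# b = batch, h = head, i = query position, j = key position,
# d = d_k, which is summed out because it does not appear in the output.
scores_einsum = torch.einsum("bhid,bhjd->bhij", q, k)

print(scores_einsum.shape)
print(torch.allclose(scores_matmul, scores_einsum, atol=1e-4))
```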

In [4]:
attn_scores = qk_T / math.sqrt(d_k)
print(f"attn_scores.shape: {attn_scores.shape}")

attn_weights = torch.softmax(attn_scores, dim=-1)
print(f"attn_weights.shape: {attn_weights.shape}")

context_vec = attn_weights @ v
print(f"context_vec.shape: {context_vec.shape}")

attn_scores.shape: torch.Size([1, 8, 4, 4])
attn_weights.shape: torch.Size([1, 8, 4, 4])
context_vec.shape: torch.Size([1, 8, 4, 64])
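
The three steps above can be bundled into one function for quick experiments. This is a minimal sketch with a function name of my own choosing, without masking or dropout. It also checks that softmax over `dim=-1` makes every row of weights sum to 1, meaning each query distributes its attention across the keys.

```python
import math
import torch

def scaled_dot_product_attention_sketch(q, k, v):
    """Unmasked attention; returns (context, weights)."""
    d_k = k.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)   # (..., T_q, T_k)
    weights = torch.softmax(scores, dim=-1)               # normalize over keys
    return weights @ v, weights                           # (..., T_q, d_k)

q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)

context, weights = scaled_dot_product_attention_sketch(q, k, v)
row_sums = weights.sum(dim=-1)   # one sum per query position per head

print(context.shape, weights.shape)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))
```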


In [5]:
# The per-head context vectors are now concatenated back together
# to form the final output.

intermediate = context_vec.transpose(1, 2)
print(f"intermediate.shape: {intermediate.shape}")
final_output = intermediate.contiguous().view(1, -1, 512)
print(f"final_output.shape: {final_output.shape}")

intermediate.shape: torch.Size([1, 4, 8, 64])
final_output.shape: torch.Size([1, 4, 512])
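
The `.contiguous()` call in the cell above is there because `transpose` returns a strided view rather than a copy, and `view` cannot merge the head and feature axes of that view. The sketch below illustrates this on a random tensor of the same shape and shows that `reshape` is an equivalent shortcut, since it copies when needed.

```python
import torch

context_vec = torch.randn(1, 8, 4, 64)
intermediate = context_vec.transpose(1, 2)   # (1, 4, 8, 64), same memory, new strides

print(intermediate.is_contiguous())

# view() needs the merged axes to be laid out back to back in memory.
try:
    intermediate.view(1, -1, 512)
    view_failed = False
except RuntimeError:
    view_failed = True

merged_view = intermediate.contiguous().view(1, -1, 512)   # copy, then view
merged_reshape = intermediate.reshape(1, -1, 512)          # copies if it has to

print(view_failed, torch.equal(merged_view, merged_reshape))
```

After merging, the features of head `i` sit in columns `i * 64` to `(i + 1) * 64` of each token's 512-wide vector.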


In [6]:
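# The wildcard import runs multihead.py's module-level code (which prints the output below)
# and supplies the names used in later cells: q_encodings, k_encodings, create_causal_mask, device.
from multihead import *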


q_encodings.shape: torch.Size([1, 4, 512])
q_encodings: 1

k_encodings.shape: torch.Size([1, 4, 512])
k_encodings: 1

v_encodings.shape: torch.Size([1, 4, 512])
v_encodings: 1

q's size after q_encodings @ W_q: torch.Size([1, 4, 512])

q after splitting and transposing: torch.Size([1, 8, 4, 64])
k after splitting and transposing: torch.Size([1, 8, 4, 64]) what
value after splitting and transposing: torch.Size([1, 8, 4, 64])

Big H shape: torch.Size([1, 4, 512])

torch.Size([1, 4, 512]) tensor([[[ 0.4343, -0.0738,  0.0735,  ...,  0.0547,  0.3116, -0.2508],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [ 0.2527,  0.0563,  0.2185,  ...,  0.0551,  0.1700, -0.2223]]],
       grad_fn=<ViewBackward0>)
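
Two rows of the output above are entirely `nan`. I have not confirmed the cause in multihead.py, but one common way this happens in attention is a row whose scores are all `-inf` (every key masked out for that query): softmax then evaluates to 0/0. A minimal sketch of that effect, using a made-up score matrix:

```python
import torch

neg_inf = float("-inf")

# Row 0 has one unmasked score; row 1 is fully masked (hypothetical values).
scores = torch.tensor([[0.5, neg_inf],
                       [neg_inf, neg_inf]])

weights = torch.softmax(scores, dim=-1)
print(weights)

# Any nan weight propagates into the context vector through weights @ v.
v = torch.ones(2, 3)
context = weights @ v
print(context)
```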


In [7]:
q_encodings.shape, k_encodings.shape

(torch.Size([1, 4, 512]), torch.Size([1, 4, 512]))

In [8]:
mask = create_causal_mask(
    seq_len_q=q_encodings.shape[-2], seq_len_k=k_encodings.shape[-2], device=device
)
mask

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
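
A mask with this pattern can be built from the strict upper triangle of a matrix of ones. The function below is a hypothetical stand-in that I wrote for illustration. The real `create_causal_mask` lives in multihead.py and may be implemented differently (for example, it also takes a `device` argument).

```python
import torch

def causal_mask_sketch(seq_len_q, seq_len_k):
    """True marks key positions strictly after the query position."""
    ones = torch.ones(seq_len_q, seq_len_k)
    # diagonal=1 keeps only entries above the main diagonal.
    return ones.triu(diagonal=1).bool()

sketch_mask = causal_mask_sketch(4, 4)
print(sketch_mask)
```

Row `i` is `False` up to and including column `i`, so each token can attend to itself and to earlier tokens only.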

In [14]:
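# Scores are computed on the unsplit (1, 4, 512) encodings, so this is a single-head view.
attn_scores = (q_encodings @ k_encodings.transpose(-2, -1)) / math.sqrt(d_k)
# True entries in mask (future positions) become -inf, so softmax gives them zero weight.
attn_scores = attn_scores.masked_fill(mask, float("-inf"))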
attn_scores

tensor([[[ 96.8332,     -inf,     -inf,     -inf],
         [ 32.7087,  94.9597,     -inf,     -inf],
         [ 32.5465,  30.7693,  99.3592,     -inf],
         [ 26.6260,  34.8210,  28.1737, 100.3393]]],
       grad_fn=<MaskedFillBackward0>)

In [15]:
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.2196e-28, 1.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.6297e-30, 1.6285e-30, 1.0000e+00, 0.0000e+00],
         [9.6996e-33, 3.5139e-29, 4.5595e-32, 1.0000e+00]]],
       grad_fn=<SoftmaxBackward0>)
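
The rows above are almost one-hot because the diagonal scores exceed the others by roughly 60, and softmax depends only on score differences: a gap of `g` leaves the smaller entry with a weight of about `exp(-g)`. The sketch below checks this with the two finite scores from row 1 of the `attn_scores` output, copied here as literals.

```python
import math
import torch

# Finite scores from row 1 of the attn_scores output above.
row = torch.tensor([32.7087, 94.9597])
weights = torch.softmax(row, dim=-1)

gap = 94.9597 - 32.7087    # about 62.25
approx = math.exp(-gap)    # about 9.2e-28, in line with the weight shown above

print(weights[0].item(), approx)
```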