In [2]:
import torch
from torch import nn

## Single Head Attention

In [21]:
# single head attention

# Hyperparameters
batch_size = 2
seq_len = 3
embedding_dim = 4
head_size = 5

x = torch.randint(0, 100, (batch_size, seq_len)) # (batch_size, seq_len)

print('Initial shape', x.shape)
print(x)
print()

#embdding
embedding_layer = nn.Embedding(100, embedding_dim) # (vocab_size, input_dim)
x = embedding_layer(x) # (batch_size, seq_len, embedding_dim)
print('Shape after embedding', x.shape)
print(x)
print()

# Linear layers for Q, K, V projections
W_q = nn.Linear(embedding_dim, head_size)
W_k = nn.Linear(embedding_dim, head_size)
W_v = nn.Linear(embedding_dim, head_size)
W_o = nn.Linear(head_size, head_size)

### FORWARD PASS
# Assuming x is of shape (batch_size, seq_len, input_dim)
# Project inputs to Q, K, V
Q = W_q(x) # (batch_size, seq_len, head_size)
K = W_k(x) # (batch_size, seq_len, head_size)
V = W_v(x) # (batch_size, seq_len, head_size)
print('Q shape', Q.shape)
print('K shape', K.shape)
print('V shape', V.shape)
print()

# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5) # (batch_size, seq_len, seq_len)
print('Scores shape', scores.shape)
print(scores)
print()

# apply causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float('-inf'))
print('Causal scores shape', scores.shape)
print(scores)
print()
weights = torch.softmax(scores, dim=-1) # (batch_size, seq_len, seq_len)
print('Weights shape', weights.shape)
print(weights)
print()
# Apply attention weights to V
output = torch.matmul(weights, V) # (batch_size, seq_len, head_size)
print('weighted values shape', output.shape)
print(output)
print()
# Project back to original dimension
output = W_o(output)
print('Final Output shape', output.shape)
print(output)
print()


Initial shape torch.Size([2, 3])
tensor([[66, 85, 31],
        [14, 74, 53]])

Shape after embedding torch.Size([2, 3, 4])
tensor([[[ 0.3844,  0.0395, -0.7434,  1.5816],
         [-0.3342,  0.8615,  0.4445, -1.3960],
         [ 2.1648,  0.4029, -0.2743,  1.3082]],

        [[-0.2259, -0.1743, -0.4262, -1.4836],
         [-0.9625, -0.6392,  0.5664, -0.5657],
         [-0.4407,  0.7306, -1.9145,  0.2075]]], grad_fn=<EmbeddingBackward0>)

Q shape torch.Size([2, 3, 5])
K shape torch.Size([2, 3, 5])
V shape torch.Size([2, 3, 5])

Scores shape torch.Size([2, 3, 3])
tensor([[[ 0.0290, -0.3966,  0.1607],
         [-0.2134,  0.2727, -0.6381],
         [-0.1263, -0.2367, -0.4838]],

        [[-0.0196,  0.2575, -0.0710],
         [-0.1938, -0.1350, -0.2213],
         [-0.1482, -0.1566,  0.1393]]], grad_fn=<DivBackward0>)

Causal scores shape torch.Size([2, 3, 3])
tensor([[[ 0.0290,    -inf,    -inf],
         [-0.2134,  0.2727,    -inf],
         [-0.1263, -0.2367, -0.4838]],

        [[-0.0196, 

## Multi Head Attention

In [26]:
 # multi head attention

# single head attention
from torch import nn

# Hyperparameters
batch_size = 2
seq_len = 3
embedding_dim = 8
num_heads = 4
head_size = embedding_dim // num_heads

x = torch.randint(0, 100, (batch_size, seq_len)) # (batch_size, seq_len)

print('Initial shape', x.shape)
print()

#embdding
embedding_layer = nn.Embedding(100, embedding_dim) # (vocab_size, input_dim)
x = embedding_layer(x) # (batch_size, seq_len, embedding_dim)
print('Shape after embedding', x.shape)
print()

# Linear layers for Q, K, V projections
W_q = nn.Linear(embedding_dim, embedding_dim)
W_k = nn.Linear(embedding_dim, embedding_dim)
W_v = nn.Linear(embedding_dim, embedding_dim)
W_o = nn.Linear(embedding_dim, embedding_dim)

### FORWARD PASS
# Assuming x is of shape (batch_size, seq_len, input_dim)
# Project inputs to Q, K, V
Q = W_q(x) # (batch_size, seq_len, embedding_dim)
K = W_k(x) # (batch_size, seq_len, embedding_dim)
V = W_v(x) # (batch_size, seq_len, embedding_dim)
print('Q shape', Q.shape)
print('K shape', K.shape)
print('V shape', V.shape)
print()

# split heads
Q = Q.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2) # (batch_size, num_heads, seq_len, head_size)
K = K.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2) # (batch_size, num_heads, seq_len, head_size)      
V = V.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2) # (batch_size, num_heads, seq_len, head_size)

# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5) # (batch_size, num_heads, seq_len, seq_len)
print('Scores shape', scores.shape)
print()

# apply causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float('-inf'))
print('Causal scores shape', scores.shape)
print()

weights = torch.softmax(scores, dim=-1) # (batch_size, num_heads, seq_len, seq_len)
print('Weights shape', weights.shape)
print(weights)
print()

# Apply attention weights to V
output = torch.matmul(weights, V) # batch_size, num_heads, seq_len, head_size)
print('weighted values shape', output.shape)
print()

# Project back to original dimension
# we need to use continguous before applying view as the output of transpose is non-continguous in memory
output = W_o(output.transpose(1, 2).contiguous().view(batch_size, seq_len, embedding_dim))
print('Final Output shape', output.shape)


Initial shape torch.Size([2, 3])

Shape after embedding torch.Size([2, 3, 8])

Q shape torch.Size([2, 3, 8])
K shape torch.Size([2, 3, 8])
V shape torch.Size([2, 3, 8])

Scores shape torch.Size([2, 4, 3, 3])

Causal scores shape torch.Size([2, 4, 3, 3])

Weights shape torch.Size([2, 4, 3, 3])
tensor([[[[1.0000, 0.0000, 0.0000],
          [0.4477, 0.5523, 0.0000],
          [0.3448, 0.3598, 0.2954]],

         [[1.0000, 0.0000, 0.0000],
          [0.4888, 0.5112, 0.0000],
          [0.2664, 0.4853, 0.2483]],

         [[1.0000, 0.0000, 0.0000],
          [0.5272, 0.4728, 0.0000],
          [0.2904, 0.3663, 0.3433]],

         [[1.0000, 0.0000, 0.0000],
          [0.3566, 0.6434, 0.0000],
          [0.2164, 0.3818, 0.4018]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.4978, 0.5022, 0.0000],
          [0.3588, 0.2964, 0.3448]],

         [[1.0000, 0.0000, 0.0000],
          [0.5675, 0.4325, 0.0000],
          [0.3245, 0.3506, 0.3249]],

         [[1.0000, 0.0000, 0.0000],
        

Using Pytorch's attention layer

In [9]:
# Hyperparameters
batch_size = 2
seq_len = 3
embedding_dim = 8
num_heads = 4
head_size = embedding_dim // num_heads

x = torch.randint(0, 100, (batch_size, seq_len)) # (batch_size, seq_len)

#embdding
embedding_layer = nn.Embedding(100, embedding_dim) # (vocab_size, input_dim)
x = embedding_layer(x) # (batch_size, seq_len, embedding_dim)

attention = nn.MultiheadAttention(
    embed_dim=embedding_dim,
    num_heads=num_heads,
    batch_first=True  # Use (batch, seq, features) format
)


causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=bool), 
    diagonal=1
)

attention(x, x, x, is_causal=True, attn_mask=causal_mask) #

(tensor([[[ 0.4867, -0.6963,  0.3553,  1.6682, -0.8163,  0.4333, -0.7774,
           -0.2174],
          [ 0.4151, -0.8389,  0.5260,  1.5996, -0.6560,  0.3786, -0.8686,
           -0.2580],
          [ 0.2302, -0.4075,  0.3258,  1.1119, -0.4806,  0.3360, -0.5904,
           -0.1410]],
 
         [[-0.7617,  0.1197,  0.2304, -1.1195,  0.5100, -1.3840,  0.8277,
           -0.8969],
          [-0.4709,  0.5864, -0.3686, -0.8999,  0.2155, -0.6151,  0.6710,
           -0.2450],
          [-0.0963, -0.0068,  0.0778, -0.2031, -0.1081, -0.1642, -0.1416,
           -0.1137]]], grad_fn=<TransposeBackward0>),
 tensor([[[1.0000, 0.0000, 0.0000],
          [0.4459, 0.5541, 0.0000],
          [0.3307, 0.4116, 0.2577]],
 
         [[1.0000, 0.0000, 0.0000],
          [0.4688, 0.5312, 0.0000],
          [0.3090, 0.3182, 0.3728]]], grad_fn=<MeanBackward1>))

In [11]:
list(attention.parameters())

[Parameter containing:
 tensor([[-0.3527, -0.4131,  0.2531, -0.1708, -0.2861, -0.1642,  0.2186, -0.2150],
         [-0.3680,  0.3517,  0.2795,  0.1738, -0.1350, -0.1377,  0.1143,  0.1266],
         [ 0.4017, -0.1394,  0.1766,  0.1344, -0.1631, -0.0773, -0.4112,  0.2540],
         [ 0.1289,  0.3852, -0.1793, -0.1650, -0.0402, -0.3146,  0.3904, -0.3643],
         [-0.1100,  0.0730,  0.4076,  0.0636, -0.2499,  0.0209, -0.1799,  0.1883],
         [-0.3686, -0.3249, -0.3289,  0.0247, -0.1940, -0.3537, -0.1217, -0.3896],
         [-0.1471,  0.2450, -0.3825,  0.1600, -0.0596,  0.4094, -0.1570, -0.0447],
         [ 0.1567,  0.0915, -0.1380,  0.0324, -0.4106,  0.4170, -0.3286,  0.0297],
         [ 0.1417,  0.2760,  0.0196,  0.0301, -0.3777, -0.2168, -0.1833,  0.1492],
         [ 0.2148,  0.1869, -0.3434,  0.0595,  0.0860,  0.1215,  0.1664,  0.0398],
         [ 0.1115,  0.3742,  0.2065,  0.2267, -0.1438,  0.1401, -0.2448,  0.2265],
         [-0.1933, -0.2717, -0.0476,  0.3238, -0.3507,  0.0037, 