Ref: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [1]:
from importlib.metadata import version

print("torch version:", version("torch"))

import torch

torch version: 2.5.1


# Sentence Embedding

In [2]:
sentence = 'My shoes are small, my feet are big.'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'My': 0, 'are': 2, 'big.': 3, 'feet': 4, 'my': 5, 'shoes': 6, 'small': 7}


- Map each word in the sentence to its integer index:

In [3]:
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 6, 2, 7, 5, 4, 2, 3])
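
- Note that `dc` is built by enumerating the sorted word list with its duplicate `are`, so index 1 is overwritten and never used. The snippet below is a minimal sketch of a variant that enumerates unique words instead, assuming contiguous indices are preferred. The names `vocab` and `token_ids` are illustrative and not part of the original notebook, and the rest of the notebook keeps using `dc`.

```python
sentence = 'My shoes are small, my feet are big.'
words = sentence.replace(',', '').split()

# Original construction: enumerate over the sorted list *with* duplicates.
# The second 'are' overwrites the first, so index 1 never appears as a value.
dc = {s: i for i, s in enumerate(sorted(words))}

# Variant: enumerate over unique words, so indices run 0..len(vocab)-1.
vocab = {s: i for i, s in enumerate(sorted(set(words)))}
token_ids = [vocab[w] for w in words]

print(dc)
print(vocab)
print(token_ids)
```

With this variant an embedding table with 7 rows would be enough, whereas the original mapping needs 8 rows because its largest index is 7.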


- Next, we use an embedding layer to encode the integer indices as real-valued vectors. Here, we use a 2-dimensional embedding, so each input word is represented by a 2-dimensional vector. Since the sentence consists of 8 words, the result is an 8 × 2 embedding matrix:

In [4]:
torch.manual_seed(123) #for reproducibility
embed = torch.nn.Embedding(8, 2) #8 tokens, 2 dim vector
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])
torch.Size([8, 2])
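
- An embedding layer is essentially a trainable lookup table. As a small sketch (the names `looked_up` and `indexed` are illustrative), calling the layer on the indices should give the same result as indexing its weight matrix directly, and the two positions holding "are" should receive the same vector:

```python
import torch

torch.manual_seed(123)
embed = torch.nn.Embedding(8, 2)  # 8 rows (one per index), 2 columns
sentence_int = torch.tensor([0, 6, 2, 7, 5, 4, 2, 3])

looked_up = embed(sentence_int).detach()       # forward pass of the layer
indexed = embed.weight[sentence_int].detach()  # plain row indexing

print(torch.equal(looked_up, indexed))
print(torch.equal(looked_up[2], looked_up[6]))  # both positions hold "are"
```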


# Weight Matrices

Now, let’s discuss the widely used self-attention mechanism known as scaled dot-product attention, which is integrated into the transformer architecture.

Self-attention uses three weight matrices, $W_q$, $W_k$, and $W_v$, which are adjusted as model parameters during training. These matrices project the inputs into the query, key, and value components of the sequence, respectively.


Since we are computing the dot product between the query and key vectors, these two vectors have to contain the same number of elements. However, the number of elements in the value vector $v^{(i)}$, which determines the size of the resulting context vector, is arbitrary.

Here, we project the 2-dimensional inputs to 3 dimensions for queries and keys and to 4 dimensions for values.

In [5]:
torch.manual_seed(123)

print(embedded_sentence.shape) # [8, 2]
d = embedded_sentence.shape[1] # [2]

d_q, d_k, d_v = 3, 3, 4

W_query = torch.nn.Parameter(torch.rand(d_q, d))  # Shape: [3, 2]
W_key = torch.nn.Parameter(torch.rand(d_k, d))  # Shape: [3, 2]
W_value = torch.nn.Parameter(torch.rand(d_v, d)) # Shape: [4, 2]   

torch.Size([8, 2])


In [6]:
embedded_sentence.shape

torch.Size([8, 2])

In [7]:
print(W_query)
print(W_query.shape)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)
torch.Size([3, 2])


In [8]:
print(W_key)
print(W_key.shape)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)
torch.Size([3, 2])


In [9]:
print(W_value)
print(W_value.shape)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274],
        [0.3821, 0.6605]], requires_grad=True)
torch.Size([4, 2])
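
As a quick check of the dimension argument above, here is a minimal sketch with random stand-in weights (the seed, `k_bad`, and `mismatch_raised` are illustrative and not part of the notebook). The query-key dot product works when `d_q == d_k`, fails when the sizes differ, and places no constraint on `d_v`:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so this sketch is repeatable
d, d_q, d_k, d_v = 2, 3, 3, 4
x = torch.randn(d)

q = torch.rand(d_q, d) @ x  # shape [3]
k = torch.rand(d_k, d) @ x  # shape [3]
v = torch.rand(d_v, d) @ x  # shape [4], independent of d_q and d_k

score = q.dot(k)  # works because d_q == d_k
print(score.shape, v.shape)

# A key of a different size cannot be dotted with the query.
k_bad = torch.rand(5, d) @ x
try:
    q.dot(k_bad)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print("mismatch raises RuntimeError:", mismatch_raised)
```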


# Calculate Attention Weights

Now, let’s suppose we are interested in computing the attention vector for the second input element, which acts as the query here:

In [10]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(f"x_2: {x_2} \n x_2.shape: {x_2.shape}")
print(f"W_query: {W_query} \n ... query_2: {query_2} \n")
print(f"W_key: {W_key} \n ... key_2: {key_2} \n")
print(f"W_value: {W_value} \n ... value_2: {value_2} \n")
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

x_2: tensor([0.1794, 1.8951]) 
 x_2.shape: torch.Size([2])
W_query: Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True) 
 ... query_2: tensor([1.0321, 1.3501, 1.6555], grad_fn=<MvBackward0>) 

W_key: Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True) 
 ... key_2: tensor([0.2187, 1.4097, 1.3587], grad_fn=<MvBackward0>) 

W_value: Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274],
        [0.3821, 0.6605]], requires_grad=True) 
 ... value_2: tensor([0.3862, 0.8181, 1.5893, 1.3203], grad_fn=<MvBackward0>) 

torch.Size([3])
torch.Size([3])
torch.Size([4])


In [45]:
# Check the first element of query_2 by hand: W_query[0] · x_2
(0.1794*0.2961) + (1.8951*0.5166)

1.0321289999999999

- These three matrices are used to project the embedded input tokens, $x^{(i)}$, into query, key, and value vectors via matrix multiplication:

  - Query vector: $q^{(i)} = W_q \,x^{(i)}$
  - Key vector: $k^{(i)} = W_k \,x^{(i)}$
  - Value vector: $v^{(i)} = W_v \,x^{(i)}$

In [46]:
# Compute Query, Key, and Value vectors
queries = W_query.matmul(embedded_sentence.T).T
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("queries.shape", queries.shape) # [8, 3]
print("keys.shape:", keys.shape) # [8, 3]
print("values.shape:", values.shape) # [8, 4]

queries.shape torch.Size([8, 3])
keys.shape: torch.Size([8, 3])
values.shape: torch.Size([8, 4])
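
The double transpose above follows the column-vector convention $q^{(i)} = W_q \,x^{(i)}$. As a sketch with random stand-in tensors (`X` plays the role of `embedded_sentence`; `queries_a`, `queries_b`, and `queries_c` are illustrative names), the same result can be written in row-vector form, which is also what `torch.nn.functional.linear` computes when no bias is passed:

```python
import torch

torch.manual_seed(123)
X = torch.randn(8, 2)       # stand-in for embedded_sentence
W_query = torch.rand(3, 2)  # [d_q, d]

queries_a = W_query.matmul(X.T).T                   # form used in the notebook
queries_b = X @ W_query.T                           # row-vector convention
queries_c = torch.nn.functional.linear(X, W_query)  # x @ W^T, no bias

print(queries_a.shape)
print(torch.allclose(queries_a, queries_b))
print(torch.allclose(queries_a, queries_c))
```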


In [47]:
queries

tensor([[ 0.0081, -0.0375, -0.1291],
        [ 1.0321,  1.3501,  1.6555],
        [ 0.4443,  0.5424,  0.5980],
        [ 0.2858,  0.3100,  0.2699],
        [-0.5214, -0.7949, -1.1699],
        [-0.3889, -0.6280, -0.9766],
        [ 0.4443,  0.5424,  0.5980],
        [-0.2609, -0.3164, -0.3448]], grad_fn=<PermuteBackward0>)

In [48]:
keys

tensor([[ 0.0279, -0.0671, -0.0158],
        [ 0.2187,  1.4097,  1.3587],
        [ 0.1153,  0.5439,  0.5636],
        [ 0.0953,  0.2867,  0.3412],
        [-0.0491, -0.8956, -0.7485],
        [-0.0174, -0.7251, -0.5775],
        [ 0.1153,  0.5439,  0.5636],
        [-0.0689, -0.3159, -0.3298]], grad_fn=<PermuteBackward0>)

We can then generalize this to compute the remaining key and value elements for all inputs as well, since we will need them in the next step when we compute the unnormalized attention weights.

In [49]:
values

tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
        [ 0.3862,  0.8181,  1.5893,  1.3203],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [ 0.0904,  0.2649,  0.2815,  0.3671],
        [-0.2244, -0.3454, -1.0836, -0.6643],
        [-0.1765, -0.2364, -0.8957, -0.4945],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [-0.0912, -0.2218, -0.3398, -0.3344]], grad_fn=<PermuteBackward0>)

In [50]:
embedded_sentence

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])

In [51]:
embedded_sentence.T

tensor([[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486, -0.2196],
        [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603, -0.3792]])

Let's compute the unnormalized attention weights $\omega$.

We compute $\omega_{ij}$ as the dot product between the query vector $q^{(i)}$ and the key vector $k^{(j)}$:
$\omega_{ij} = {q^{(i)}}^{\top} k^{(j)}$

For example, we can compute the unnormalized attention weight for the query and 5th input element (corresponding to index position 4) as follows:

In [52]:
query_2

tensor([1.0321, 1.3501, 1.6555], grad_fn=<MvBackward0>)

In [53]:
keys

tensor([[ 0.0279, -0.0671, -0.0158],
        [ 0.2187,  1.4097,  1.3587],
        [ 0.1153,  0.5439,  0.5636],
        [ 0.0953,  0.2867,  0.3412],
        [-0.0491, -0.8956, -0.7485],
        [-0.0174, -0.7251, -0.5775],
        [ 0.1153,  0.5439,  0.5636],
        [-0.0689, -0.3159, -0.3298]], grad_fn=<PermuteBackward0>)

In [54]:
# query_2.dot(keys[4])
test_omega_24 = -0.0491 * 1.0321 + -0.8956 * 1.3501 +  -0.7485 * 1.6555
print(test_omega_24)

-2.49896742


In [55]:
# omega_20 = [0.0279, -0.0671, -0.0158] · [1.0321, 1.3501, 1.6555]

omega_20 = 0.0279 * 1.0321 + -0.0671 * 1.3501 + -0.0158 * 1.6555
print(omega_20)

-0.08795302000000002


In [56]:
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(-2.4988, grad_fn=<DotBackward0>)


In [57]:
omega_2 = query_2.matmul(keys.T)  # query_2 [3] @ keys.T [3, 8] -> [8]
print(omega_2)  # Scores of query_2 ("shoes") against all 8 keys

tensor([-0.0879,  4.3783,  1.7863,  1.0502, -2.4988, -1.9530,  1.7863, -1.0434],
       grad_fn=<SqueezeBackward4>)
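
The same product can be computed for every query at once. The sketch below hard-codes the rounded `queries` and `keys` printed earlier (so results agree with the notebook only to about three decimals) and builds the full 8 × 8 matrix of unnormalized weights. The names `omega` and `expected_row_1` are illustrative.

```python
import torch

# Rounded values copied from the printed `queries` and `keys` tensors.
queries = torch.tensor([
    [0.0081, -0.0375, -0.1291], [1.0321, 1.3501, 1.6555],
    [0.4443, 0.5424, 0.5980], [0.2858, 0.3100, 0.2699],
    [-0.5214, -0.7949, -1.1699], [-0.3889, -0.6280, -0.9766],
    [0.4443, 0.5424, 0.5980], [-0.2609, -0.3164, -0.3448],
])
keys = torch.tensor([
    [0.0279, -0.0671, -0.0158], [0.2187, 1.4097, 1.3587],
    [0.1153, 0.5439, 0.5636], [0.0953, 0.2867, 0.3412],
    [-0.0491, -0.8956, -0.7485], [-0.0174, -0.7251, -0.5775],
    [0.1153, 0.5439, 0.5636], [-0.0689, -0.3159, -0.3298],
])

# omega[i, j] is the dot product of query i with key j.
omega = queries @ keys.T  # [8, 3] @ [3, 8] -> [8, 8]

expected_row_1 = torch.tensor(
    [-0.0879, 4.3783, 1.7863, 1.0502, -2.4988, -1.9530, 1.7863, -1.0434]
)
print(omega.shape)
print(torch.allclose(omega[1], expected_row_1, atol=2e-3))
```

Row 1 of `omega` corresponds to `omega_2` above. The matrix is generally not symmetric, because queries and keys come from different projections.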


# Computing the Attention Scores

In [58]:
import torch.nn.functional as F

# Compute attention scores for a specific token (e.g., "shoes")
query_2 = queries[1].unsqueeze(0)  # Query for "shoes" (2nd token)
omega_2 = query_2.matmul(keys.T)   # Dot product between query and keys
print("Raw Attention Scores (Omega):\n", omega_2)

# Scale and apply softmax
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=1)
print("Attention Weights (Softmax):\n", attention_weights_2)

Raw Attention Scores (Omega):
 tensor([[-0.0879,  4.3783,  1.7863,  1.0502, -2.4988, -1.9530,  1.7863, -1.0434]],
       grad_fn=<MmBackward0>)
Attention Weights (Softmax):
 tensor([[0.0432, 0.5687, 0.1273, 0.0832, 0.0107, 0.0147, 0.1273, 0.0249]],
       grad_fn=<SoftmaxBackward0>)
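
To see what the scaling by $\sqrt{d_k}$ does, here is a small standard-library sketch that recomputes the softmax from the rounded raw scores printed above. The helper `softmax` and the names `scaled` and `unscaled` are illustrative. Without scaling, the distribution should be noticeably more peaked on the largest score.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Rounded raw scores for query_2, copied from the output above.
omega_2 = [-0.0879, 4.3783, 1.7863, 1.0502, -2.4988, -1.9530, 1.7863, -1.0434]
d_k = 3

scaled = softmax([w / math.sqrt(d_k) for w in omega_2])
unscaled = softmax(omega_2)

print([round(a, 4) for a in scaled])
print("largest weight, scaled vs. unscaled:",
      round(max(scaled), 4), round(max(unscaled), 4))
```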


In [59]:
values

tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
        [ 0.3862,  0.8181,  1.5893,  1.3203],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [ 0.0904,  0.2649,  0.2815,  0.3671],
        [-0.2244, -0.3454, -1.0836, -0.6643],
        [-0.1765, -0.2364, -0.8957, -0.4945],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [-0.0912, -0.2218, -0.3398, -0.3344]], grad_fn=<PermuteBackward0>)

In [60]:
print(attention_weights_2.shape)
print(values.shape)

torch.Size([1, 8])
torch.Size([8, 4])


In [61]:
# Compute context vector for "shoes"
context_vector_2 = attention_weights_2.matmul(values)  # attention_weights_2 [1, 8] @ values [8, 4] -> [1, 4]
print("Context Vector Shape:", context_vector_2.shape)  # [1, 4]
print("Context Vector:\n", context_vector_2)

Context Vector Shape: torch.Size([1, 4])
Context Vector:
 tensor([[0.2593, 0.5718, 1.0390, 0.9041]], grad_fn=<MmBackward0>)


In [62]:
# [0.0432, 0.5687, 0.1273, 0.0832, 0.0107, 0.0147, 0.1273, 0.0249] * tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
#                                                                            [ 0.3862,  0.8181,  1.5893,  1.3203],
#                                                                            [ 0.1562,  0.3756,  0.5877,  0.5693],
#                                                                            [ 0.0904,  0.2649,  0.2815,  0.3671],
#                                                                            [-0.2244, -0.3454, -1.0836, -0.6643],
#                                                                            [-0.1765, -0.2364, -0.8957, -0.4945],
#                                                                            [ 0.1562,  0.3756,  0.5877,  0.5693],
#                                                                            [-0.0912, -0.2218, -0.3398, -0.3344]]

context_vector_query2_1 = (0.0432 * -0.0094) + (0.5687 *  0.3862) +  (0.1273 * 0.1562) + (0.0832 * 0.0904) + (0.0107 * -0.2244) + (0.0147 * -0.1765) + (0.1273 * 0.1562) + (0.0249 * -0.0912)
context_vector_query2_2 = (0.0432 * 0.0353) + (0.5687 *  0.8181) +  (0.1273 * 0.3756) + (0.0832 * 0.2649) + (0.0107 * -0.3454) + (0.0147 * -0.2364) + (0.1273 * 0.3756) + (0.0249 * -0.2218)

In [63]:
print(f"{context_vector_query2_1}, {context_vector_query2_2}")

0.25924915, 0.57175219
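
The steps so far handle one query at a time. As a sketch with random stand-in tensors (not the notebook's actual values; `attention_weights`, `context_vectors`, and `single` are illustrative names), the context vectors for all 8 tokens can be computed in one pass, and row 1 should agree with the single-query computation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
queries = torch.randn(8, 3)  # stand-ins for the notebook's tensors
keys = torch.randn(8, 3)
values = torch.randn(8, 4)
d_k = keys.shape[-1]

# One row of attention weights per query; each row sums to 1.
attention_weights = F.softmax(queries @ keys.T / d_k**0.5, dim=-1)  # [8, 8]
context_vectors = attention_weights @ values                        # [8, 4]

# Single-query path for the second token, as done step by step above.
single = F.softmax(queries[1].unsqueeze(0) @ keys.T / d_k**0.5, dim=1) @ values

print(context_vectors.shape)
print(torch.allclose(context_vectors[1], single[0], atol=1e-6))
```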


# Multi-Head Attention

In [64]:
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=2, num_heads=2):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Linear layers for Q, K, V
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # Final output projection (after concatenating heads)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        """
        x: Tensor of shape (batch_size, seq_len, d_model)
        Returns: Tensor of shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape
        
        # 1) Compute Q, K, V => each shape: (batch_size, seq_len, d_model)
        Q = self.W_Q(x)  # [1, 8, 2]
        K = self.W_K(x)  # [1, 8, 2]
        V = self.W_V(x)  # [1, 8, 2]
        print("Before reshaping:")
        print(f"Q: {Q} \nK: {K} \nV:{V}")
        
        # 2) Reshape for multi-head attention:
        #    (batch_size, seq_len, num_heads, head_dim) => then transpose
        #    to (batch_size, num_heads, seq_len, head_dim)
        Q = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Now Q, K, V = (1, 2, 8, 1)
        print("After reshaping:")
        print(f"Q: {Q} \nK: {K} \nV:{V}")
        
        # 3) Compute attention scores: Q x K^T, scaled by sqrt(head_dim)
        #    (batch_size, num_heads, seq_len, head_dim) x
        #    (batch_size, num_heads, head_dim, seq_len) => (batch_size, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)  # => [1, 2, 8, 8]
        print("scores:", scores)
        
        # 4) Softmax along the "seq_len" of the keys
        attn_weights = F.softmax(scores, dim=-1)  # => [1, 2, 8, 8]
        print("attn_weights:", attn_weights)
        
        # 5) Compute context: multiply attn_weights by V
        context = torch.matmul(attn_weights, V)  # => [1, 2, 8, 1]
        print("context (matmul by V):", context)
        
        # 6) Transpose back + reshape to (batch_size, seq_len, d_model)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, d_model)  # => [1, 8, 2]
        print("context (after reshape):", context)
        
        # 7) Final projection to d_model
        out = self.out_proj(context)  # => [1, 8, 2]
        
        return out
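
The reshape-and-transpose in step 2 of `forward()` is the part that is easiest to get wrong. The following sketch uses a toy tensor of consecutive integers (an assumption for readability, not the notebook's data) to show that head `h` receives feature columns `h*head_dim:(h+1)*head_dim` of every token, and that step 6 undoes the split:

```python
import torch

batch_size, seq_len, num_heads, head_dim = 1, 8, 2, 1
d_model = num_heads * head_dim

# Toy input with recognizable entries: token t has features (2t, 2t + 1).
Q = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
Q = Q.reshape(batch_size, seq_len, d_model)

# Split the feature axis into heads, then move the head axis forward.
Q_heads = Q.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(Q_heads.shape)  # (batch, heads, seq, head_dim)

# Head h holds feature columns [h*head_dim, (h+1)*head_dim) of every token.
for h in range(num_heads):
    assert torch.equal(Q_heads[0, h], Q[0, :, h * head_dim:(h + 1) * head_dim])

# Undo the split, as step 6 of forward() does.
Q_merged = Q_heads.transpose(1, 2).reshape(batch_size, seq_len, d_model)
print(torch.equal(Q_merged, Q))
```

With `d_model=2` and `num_heads=2`, each head works on a single feature, so every attention score is just a product of two scalars. This is fine for tracing shapes but far smaller than what real models use.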


In [65]:
print(embedded_sentence)
# Make it batch-size=1 => shape (1, 8, 2)
x_in = embedded_sentence.unsqueeze(0)  # => (1,8,2)
# adding another dimension
print(x_in)
x_in.shape

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])
tensor([[[ 0.3374, -0.1778],
         [ 0.1794,  1.8951],
         [ 0.3486,  0.6603],
         [ 0.4954,  0.2692],
         [ 0.6984, -1.4097],
         [ 0.7671, -1.1925],
         [ 0.3486,  0.6603],
         [-0.2196, -0.3792]]])


torch.Size([1, 8, 2])

In [66]:
# Make it batch-size=1 => shape (1, 8, 2)
x_in = embedded_sentence.unsqueeze(0)  # => (1,8,2)

mha = MultiHeadSelfAttention(d_model=2, num_heads=2)
mha_out = mha(x_in)  # => (1,8,2)

Before reshaping:
Q: tensor([[[-0.1736,  0.1679],
         [ 0.0204,  1.5522],
         [-0.0576,  0.7421],
         [-0.0357,  0.5035],
         [-0.1554, -0.6032],
         [-0.0924, -0.4416],
         [-0.0576,  0.7421],
         [-0.4787, -0.0773]]], grad_fn=<ViewBackward0>) 
K: tensor([[[ 0.4076,  0.6441],
         [ 1.5049, -0.7966],
         [ 0.8276,  0.1040],
         [ 0.5838,  0.4420],
         [-0.3285,  1.6511],
         [-0.2403,  1.5483],
         [ 0.8276,  0.1040],
         [ 0.4811,  0.4608]]], grad_fn=<ViewBackward0>) 
V:tensor([[[-0.5458,  0.0120],
         [ 0.0089, -0.2744],
         [-0.2984, -0.0596],
         [-0.3670,  0.0632],
         [-0.7932,  0.3392],
         [-0.7088,  0.3594],
         [-0.2984, -0.0596],
         [-0.7761, -0.2966]]], grad_fn=<ViewBackward0>)
After reshaping:
Q: tensor([[[[-0.1736],
          [ 0.0204],
          [-0.0576],
          [-0.0357],
          [-0.1554],
          [-0.0924],
          [-0.0576],
          [-0.4787]],

     

In [67]:
# Residual + Norm
residual_1 = x_in + mha_out
layer_norm_1 = nn.LayerNorm(normalized_shape=2)
normed_1 = layer_norm_1(residual_1)

print("\n-- After Multi-Head Self-Attention --")
print("MHA output shape:", mha_out.shape)
print("MHA output:", mha_out)
print("Add & Norm shape:", normed_1.shape)
print("Add & Norm:", normed_1)


-- After Multi-Head Self-Attention --
MHA output shape: torch.Size([1, 8, 2])
MHA output: tensor([[[0.1839, 0.7218],
         [0.1803, 0.6445],
         [0.1805, 0.6862],
         [0.1774, 0.6997],
         [0.1777, 0.7618],
         [0.1746, 0.7531],
         [0.1805, 0.6862],
         [0.2011, 0.7406]]], grad_fn=<AddBackward0>)
Add & Norm shape: torch.Size([1, 8, 2])
Add & Norm: tensor([[[-0.9634,  0.9634],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000],
         [-0.9998,  0.9998],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [-0.9999,  0.9999]]], grad_fn=<NativeLayerNormBackward0>)
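
The Add & Norm output is almost exactly ±1 everywhere. This is a consequence of normalizing over only two features: for a pair $(a, b)$ the deviations from the mean are $\pm(a-b)/2$, so dividing by the standard deviation gives roughly $(\mp 1, \pm 1)$ unless $a$ and $b$ are nearly equal. The standard-library sketch below reproduces this with a hand-written `layer_norm` helper (illustrative, assuming the `nn.LayerNorm` defaults of `eps=1e-5`, unit scale, and zero shift) and the rounded values printed above.

```python
import math

def layer_norm(xs, eps=1e-5):
    """Normalize features to zero mean and unit (biased) variance."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

# First token: x_in[0] + mha_out[0]. The two features are nearly equal,
# so eps is not negligible and the result stays slightly below 1 in magnitude.
row0 = [0.3374 + 0.1839, -0.1778 + 0.7218]

# Second token: x_in[1] + mha_out[1]. The features differ a lot, so the
# result is essentially (-1, +1).
row1 = [0.1794 + 0.1803, 1.8951 + 0.6445]

print([round(v, 4) for v in layer_norm(row0)])
print([round(v, 4) for v in layer_norm(row1)])
```

So with `d_model=2`, layer normalization keeps little more than which of the two features is larger. With a realistic `d_model` this collapse does not occur.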


In [68]:
class FeedForward(nn.Module):
    def __init__(self, d_model=2, hidden_dim=4):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, d_model)
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

ffn = FeedForward(d_model=2, hidden_dim=4)
ffn_out = ffn(normed_1)  # => (1,8,2)
print("ffn", ffn)
print("ffn_out", ffn_out)

ffn FeedForward(
  (linear1): Linear(in_features=2, out_features=4, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=4, out_features=2, bias=True)
)
ffn_out tensor([[[-0.0663,  0.2402],
         [-0.0638,  0.2437],
         [-0.0638,  0.2437],
         [-0.0638,  0.2437],
         [-0.2012, -0.0018],
         [-0.2012, -0.0018],
         [-0.0638,  0.2437],
         [-0.0638,  0.2437]]], grad_fn=<ViewBackward0>)


In [69]:
residual_2 = normed_1 + ffn_out
layer_norm_2 = nn.LayerNorm(normalized_shape=2)
encoder_out = layer_norm_2(residual_2)

print("\n-- After Feed-Forward --")
print("FFN output shape:", ffn_out.shape)
print("Final encoder output shape:", encoder_out.shape)
print("Final encoder output:", encoder_out)


-- After Feed-Forward --
FFN output shape: torch.Size([1, 8, 2])
Final encoder output shape: torch.Size([1, 8, 2])
Final encoder output: tensor([[[-1.0000,  1.0000],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000]]], grad_fn=<NativeLayerNormBackward0>)


Let's go with 3 heads now. 
- d_q, d_k, d_v = 3, 3, 4
- d = 2

In [16]:
# Number of heads
num_heads = 3

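# Split the value vectors column-wise into one slice per head.
# Note: d_v = 4 is not divisible by 3, so head_dim = 1 and the 4th value column is dropped.
head_dim = d_v // num_heads  # Dimension per head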
heads = [values[:, i*head_dim:(i+1)*head_dim] for i in range(num_heads)]

# Compute context vectors for each head
context_vectors = []
for head in heads:
    # Attention weights (simplified for illustration: every head reuses the same query and keys)
    attention_scores = F.softmax(query_2.matmul(keys.T) / d_k**0.5, dim=1)
    context_vector = attention_scores.matmul(head)
    context_vectors.append(context_vector)

print("context_vectors: ", context_vectors)
# Concatenate context vectors from all heads
combined_context = torch.cat(context_vectors, dim=1)
print("Combined Context Vector Shape:", combined_context.shape)  # [1, 3]
print("Combined Context Vector:\n", combined_context)

context_vectors:  [tensor([[0.2593]], grad_fn=<MmBackward0>), tensor([[0.5718]], grad_fn=<MmBackward0>), tensor([[1.0390]], grad_fn=<MmBackward0>)]
Combined Context Vector Shape: torch.Size([1, 3])
Combined Context Vector:
 tensor([[0.2593, 0.5718, 1.0390]], grad_fn=<CatBackward0>)


In [18]:
heads

[tensor([[-0.0094],
         [ 0.3862],
         [ 0.1562],
         [ 0.0904],
         [-0.2244],
         [-0.1765],
         [ 0.1562],
         [-0.0912]], grad_fn=<SliceBackward0>),
 tensor([[ 0.0353],
         [ 0.8181],
         [ 0.3756],
         [ 0.2649],
         [-0.3454],
         [-0.2364],
         [ 0.3756],
         [-0.2218]], grad_fn=<SliceBackward0>),
 tensor([[-0.1071],
         [ 1.5893],
         [ 0.5877],
         [ 0.2815],
         [-1.0836],
         [-0.8957],
         [ 0.5877],
         [-0.3398]], grad_fn=<SliceBackward0>)]

In [30]:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

In [32]:
print(multihead_W_query)
print(multihead_W_query.shape)

Parameter containing:
tensor([[[0.8536, 0.5932],
         [0.6367, 0.9826],
         [0.2745, 0.6584]],

        [[0.2775, 0.8573],
         [0.8993, 0.0390],
         [0.9268, 0.7388]],

        [[0.7179, 0.7058],
         [0.9156, 0.4340],
         [0.0772, 0.3565]]], requires_grad=True)
torch.Size([3, 3, 2])


(Here, let’s keep the focus on the 2nd element, corresponding to index position 1.)

In [35]:
x_2

tensor([0.1794, 1.8951])

In [34]:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2)
print(multihead_query_2.shape)

tensor([[1.2772, 1.9764, 1.2970],
        [1.6745, 0.2353, 1.5663],
        [1.4664, 0.9867, 0.6895]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 3])


In [36]:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)
print("Keys:")
print(multihead_key_2)
print(multihead_key_2.shape)
print("Values:")
print(multihead_value_2)
print(multihead_value_2.shape)

Keys:
tensor([[1.0367, 0.5123, 1.9268],
        [1.0603, 1.1722, 1.6966],
        [1.2750, 1.4245, 0.1450]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 3])
Values:
tensor([[1.7906, 0.2750, 1.8344, 0.3146],
        [1.0396, 0.9122, 1.3936, 1.6952],
        [0.2099, 1.0958, 0.1481, 1.7002]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 4])
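
Calling `matmul` on a 3-dimensional weight tensor and a vector applies each head's matrix to the same input. As a sketch with random stand-in weights (`batched` and `per_head` are illustrative names), the batched call should match an explicit loop over heads:

```python
import torch

torch.manual_seed(123)
h, d_q, d = 3, 3, 2
multihead_W_query = torch.rand(h, d_q, d)  # one [d_q, d] matrix per head
x_2 = torch.tensor([0.1794, 1.8951])

batched = multihead_W_query.matmul(x_2)  # [3, 3, 2] @ [2] -> [3, 3]
per_head = torch.stack([multihead_W_query[i] @ x_2 for i in range(h)])

print(batched.shape)
print(torch.allclose(batched, per_head))
```

Row `i` of the result is the query vector that head `i` produces for `x_2`.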


Now, these key and value elements are specific to the query element. But, similar to earlier, we will also need the keys and values for the other sequence elements in order to compute the attention scores for the query. We can do this by repeating the input sequence embeddings 3 times, i.e., once per attention head:

In [42]:
print(embedded_sentence)
print(embedded_sentence.shape)
print("\nTranspose...")
print(embedded_sentence.T)
print(embedded_sentence.T.shape)

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])
torch.Size([8, 2])

Transpose...
tensor([[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486, -0.2196],
        [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603, -0.3792]])
torch.Size([2, 8])


Since we have 3 attention heads, we duplicate the input embeddings 3 times, once per head.

In [43]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs)
print(stacked_inputs.shape)

tensor([[[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]],

        [[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]],

        [[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]]])
torch.Size([3, 2, 8])


Now, we can compute all the keys and values via `torch.bmm()` (batch matrix multiplication):

In [44]:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 3, 8])
multihead_values.shape: torch.Size([3, 4, 8])
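
Repeating the inputs is not strictly required. As a sketch with random stand-in tensors (`X` plays the role of `embedded_sentence`; `keys_bmm`, `keys_broadcast`, and `keys_einsum` are illustrative names), `torch.matmul` broadcasts the 2-dimensional input across the head dimension, and `torch.einsum` can produce the `[heads, tokens, d_k]` layout directly:

```python
import torch

torch.manual_seed(123)
X = torch.randn(8, 2)                  # stand-in for embedded_sentence
multihead_W_key = torch.rand(3, 3, 2)  # [heads, d_k, d]

# Notebook approach: explicit copies of the inputs, then bmm.
stacked_inputs = X.T.repeat(3, 1, 1)                   # [3, 2, 8]
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs)  # [3, 3, 8]

# Broadcasting: [3, 3, 2] @ [2, 8] -> [3, 3, 8] without copying X.
keys_broadcast = multihead_W_key @ X.T

# einsum: contract over the input dimension d and emit [heads, tokens, d_k].
keys_einsum = torch.einsum("hkd,nd->hnk", multihead_W_key, X)  # [3, 8, 3]

print(torch.allclose(keys_bmm, keys_broadcast))
print(torch.allclose(keys_bmm.permute(0, 2, 1), keys_einsum))
```

The explicit `repeat` plus `bmm` version makes the per-head structure easier to see, while the broadcast form avoids materializing the copies.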


We now have tensors that represent the three attention heads in their first dimension.
- The third dimension refers to the number of words, and
- The second dimension refers to the size of the projected key or value vectors (3 for keys, 4 for values).

To make the values and keys more intuitive to interpret, we will swap the second and third dimensions, resulting in tensors with the same dimensional structure as the original input sequence, `embedded_sentence`:

In [45]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 8, 3])
multihead_values.shape: torch.Size([3, 8, 4])


Now,
- the first dimension represents the number of heads
- the second dimension represents the number of tokens
- the third dimension represents the size of the key or value vectors

In [46]:
print("multihead_keys:")
print(multihead_keys)

print("multihead_values:")
print(multihead_values)

multihead_keys:
tensor([[[-0.0449,  0.0960, -0.0198],
         [ 1.0367,  0.5123,  1.9268],
         [ 0.4035,  0.2948,  0.8014],
         [ 0.2168,  0.2639,  0.4873],
         [-0.6482, -0.0428, -1.0552],
         [-0.5222,  0.0355, -0.8125],
         [ 0.4035,  0.2948,  0.8014],
         [-0.2346, -0.1772, -0.4690]],

        [[ 0.0637,  0.0395,  0.1757],
         [ 1.0603,  1.1722,  1.6966],
         [ 0.5012,  0.5292,  0.8617],
         [ 0.3671,  0.3648,  0.6854],
         [-0.4056, -0.5210, -0.4756],
         [-0.2619, -0.3663, -0.2355],
         [ 0.5012,  0.5292,  0.8617],
         [-0.2968, -0.3121, -0.5132]],

        [[ 0.1204,  0.0852,  0.1406],
         [ 1.2750,  1.4245,  0.1450],
         [ 0.6381,  0.6731,  0.1751],
         [ 0.4995,  0.4927,  0.2252],
         [-0.3849, -0.5457,  0.2543],
         [-0.2061, -0.3526,  0.2919],
         [ 0.6381,  0.6731,  0.1751],
         [-0.3796, -0.3985, -0.1090]]], grad_fn=<PermuteBackward0>)
multihead_values:
tensor([[[-0.1004,  