Ref: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [1]:
from importlib.metadata import version

print("torch version:", version("torch"))

import torch

torch version: 2.5.1


# Sentence Embedding

In [2]:
sentence = 'My shoes are small, my feet are big.'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'My': 0, 'are': 2, 'big.': 3, 'feet': 4, 'my': 5, 'shoes': 6, 'small': 7}
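Note that `are` occurs twice in the sorted word list. The comprehension first assigns it index 1 and then overwrites that with 2, so index 1 ends up unused. This is harmless here, because `Embedding(8, 2)` below still has a row for every index up to 7. If a contiguous vocabulary is preferred, one option is to deduplicate with `set()` before sorting. This is a minimal sketch, not what the notebook does, and it would change which embedding rows are selected below:

```python
sentence = 'My shoes are small, my feet are big.'
words = sentence.replace(',', '').split()

# Original approach: the second 'are' overwrites the first index, leaving a gap at 1
dc_with_gap = {s: i for i, s in enumerate(sorted(words))}

# Deduplicated vocabulary: indices 0..n-1 with no gaps
dc_contiguous = {s: i for i, s in enumerate(sorted(set(words)))}

print(dc_with_gap)
print(dc_contiguous)
```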


- Map each word in the sentence to its integer index:

In [3]:
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 6, 2, 7, 5, 4, 2, 3])


- Now, using the integer-vector representation of the input sentence, we can use an embedding layer to encode the inputs into real-valued vectors. Here, we use a 2-dimensional embedding, so each input word is represented by a 2-dimensional vector. Since the sentence consists of 8 words, this results in an 8 × 2 embedding matrix:

In [4]:
torch.manual_seed(123)
embed = torch.nn.Embedding(8, 2)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])
torch.Size([8, 2])
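`torch.nn.Embedding` is a trainable lookup table: calling it on an index tensor returns the corresponding rows of its `weight` matrix. The short check below re-creates the same seeded layer to illustrate this. It also shows why rows 2 and 6 of the output above are identical: both positions hold the token `are`, index 2.

```python
import torch

torch.manual_seed(123)
embed = torch.nn.Embedding(8, 2)
sentence_int = torch.tensor([0, 6, 2, 7, 5, 4, 2, 3])

via_layer = embed(sentence_int).detach()            # forward pass of the layer
via_indexing = embed.weight[sentence_int].detach()  # plain row lookup in the weight matrix

print(torch.equal(via_layer, via_indexing))    # same values either way
print(torch.equal(via_layer[2], via_layer[6]))  # both positions are token index 2
```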


# Weight Matrices

Now, let’s discuss the widely used self-attention mechanism known as scaled dot-product attention, which is a core part of the transformer architecture.

Self-attention uses three weight matrices, $W_q$, $W_k$, and $W_v$, which are learned as model parameters during training. These matrices project the inputs into the query, key, and value components of the sequence, respectively.


Since we compute the dot product between the query and key vectors, these two vectors must contain the same number of elements. However, the number of elements in the value vector $v^{(i)}$, which determines the size of the resulting context vector, is arbitrary.

Here, we set the dimension of the queries and keys to 3 and the dimension of the values to 4.

In [5]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 3, 3, 4

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))
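To make the dimension requirement concrete: the dot product needs matching query and key sizes, while the value size only determines the size of the output. The sketch below uses random stand-in tensors (not the notebook's weights) with the same shapes as the cell above.

```python
import torch

torch.manual_seed(0)
d, d_q, d_k, d_v = 2, 3, 3, 4

# Random stand-ins with the same shapes as W_query, W_key, W_value
W_q, W_k, W_v = torch.rand(d_q, d), torch.rand(d_k, d), torch.rand(d_v, d)
x_i, x_j = torch.rand(d), torch.rand(d)

q_i = W_q @ x_i  # shape (3,)
k_j = W_k @ x_j  # shape (3,)
v_j = W_v @ x_j  # shape (4,)

omega_ij = torch.dot(q_i, k_j)  # scalar; works because d_q == d_k
print(omega_ij.shape, v_j.shape)

# A query/value dot product fails, because 3 != 4
try:
    torch.dot(q_i, v_j)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)
```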

In [6]:
embedded_sentence.shape

torch.Size([8, 2])

In [7]:
print(W_query)
print(W_query.shape)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)
torch.Size([3, 2])


In [8]:
print(W_key)
print(W_key.shape)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)
torch.Size([3, 2])


In [9]:
print(W_value)
print(W_value.shape)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274],
        [0.3821, 0.6605]], requires_grad=True)
torch.Size([4, 2])


# Calculate Attention Weights

Now, let’s suppose we want to compute the attention vector for the second input element, which acts as the query here:

In [10]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(f"x_2: {x_2} \n x_2.shape: {x_2.shape}")
print(f"W_query: {W_query} \n ... query_2: {query_2} \n")
print(f"W_key: {W_key} \n ... key_2: {key_2} \n")
print(f"W_value: {W_value} \n ... value_2: {value_2} \n")
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

x_2: tensor([0.1794, 1.8951]) 
 x_2.shape: torch.Size([2])
W_query: Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True) 
 ... query_2: tensor([1.0321, 1.3501, 1.6555], grad_fn=<MvBackward0>) 

W_key: Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True) 
 ... key_2: tensor([0.2187, 1.4097, 1.3587], grad_fn=<MvBackward0>) 

W_value: Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274],
        [0.3821, 0.6605]], requires_grad=True) 
 ... value_2: tensor([0.3862, 0.8181, 1.5893, 1.3203], grad_fn=<MvBackward0>) 

torch.Size([3])
torch.Size([3])
torch.Size([4])


In [11]:
# Check the first element of query_2 by hand: row 0 of W_query dotted with x_2
(0.1794*0.2961) + (1.8951*0.5166)

1.0321289999999999

- These three matrices are used to project the embedded input tokens, $x^{(i)}$, into query, key, and value vectors via matrix multiplication:

  - Query vector: $q^{(i)} = W_q \,x^{(i)}$
  - Key vector: $k^{(i)} = W_k \,x^{(i)}$
  - Value vector: $v^{(i)} = W_v \,x^{(i)}$
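Stacking the embedded tokens as the rows of a matrix $X$ (here $8 \times 2$), the three formulas above can be applied to all tokens at once, for example $K = X W_k^\top$, so that row $i$ of $K$ is $k^{(i)}$ (likewise for $Q$ and $V$). The next cell writes this as `W_key.matmul(embedded_sentence.T).T`, which is the same product. A quick check with random stand-in tensors:

```python
import torch

torch.manual_seed(123)
X = torch.randn(8, 2)     # stand-in for embedded_sentence (8 tokens, d = 2)
W_key = torch.rand(3, 2)  # stand-in for W_key (d_k = 3)

keys_transposed = W_key.matmul(X.T).T                # form used in the notebook
keys_rowwise = X @ W_key.T                           # K = X W_k^T
keys_loop = torch.stack([W_key @ x_i for x_i in X])  # k^(i) = W_k x^(i), one token at a time

print(keys_rowwise.shape)
print(torch.allclose(keys_transposed, keys_rowwise), torch.allclose(keys_rowwise, keys_loop))
```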

In [12]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([8, 3])
values.shape: torch.Size([8, 4])


In [13]:
keys

tensor([[ 0.0279, -0.0671, -0.0158],
        [ 0.2187,  1.4097,  1.3587],
        [ 0.1153,  0.5439,  0.5636],
        [ 0.0953,  0.2867,  0.3412],
        [-0.0491, -0.8956, -0.7485],
        [-0.0174, -0.7251, -0.5775],
        [ 0.1153,  0.5439,  0.5636],
        [-0.0689, -0.3159, -0.3298]], grad_fn=<PermuteBackward0>)

We can then generalize this to compute the remaining keys and values for all inputs as well, since we will need them in the next step, when we compute the unnormalized attention weights.

In [14]:
values

tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
        [ 0.3862,  0.8181,  1.5893,  1.3203],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [ 0.0904,  0.2649,  0.2815,  0.3671],
        [-0.2244, -0.3454, -1.0836, -0.6643],
        [-0.1765, -0.2364, -0.8957, -0.4945],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [-0.0912, -0.2218, -0.3398, -0.3344]], grad_fn=<PermuteBackward0>)

In [15]:
embedded_sentence

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])

In [16]:
embedded_sentence.T

tensor([[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486, -0.2196],
        [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603, -0.3792]])

Let's compute the unnormalized attention weights $\omega$.

We compute $\omega_{ij}$ as the dot product between the $i$-th query vector and the $j$-th key vector:
$\omega_{ij} = {q^{(i)}}^\top k^{(j)}$

For example, we can compute the unnormalized attention weight between the query (the second input element) and the 5th input element (index position 4) as follows:

In [17]:
query_2

tensor([1.0321, 1.3501, 1.6555], grad_fn=<MvBackward0>)

In [18]:
keys

tensor([[ 0.0279, -0.0671, -0.0158],
        [ 0.2187,  1.4097,  1.3587],
        [ 0.1153,  0.5439,  0.5636],
        [ 0.0953,  0.2867,  0.3412],
        [-0.0491, -0.8956, -0.7485],
        [-0.0174, -0.7251, -0.5775],
        [ 0.1153,  0.5439,  0.5636],
        [-0.0689, -0.3159, -0.3298]], grad_fn=<PermuteBackward0>)

In [19]:
# query_2.dot(keys[4])
test_omega_24 = -0.0491 * 1.0321 + -0.8956 * 1.3501 +  -0.7485 * 1.6555
print(test_omega_24)

-2.49896742


In [20]:
# omega_20: keys[0] dot query_2, i.e. [0.0279, -0.0671, -0.0158] . [1.0321, 1.3501, 1.6555]

omega_20 = 0.0279 * 1.0321 + -0.0671 * 1.3501 + -0.0158 * 1.6555
print(omega_20)

-0.08795302000000002


In [21]:
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(-2.4988, grad_fn=<DotBackward0>)


In [22]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([-0.0879,  4.3783,  1.7863,  1.0502, -2.4988, -1.9530,  1.7863, -1.0434],
       grad_fn=<SqueezeBackward4>)
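The same idea extends to every query: with all queries stacked as the rows of $Q$, the full matrix of unnormalized weights is $\Omega = Q K^\top$, and its row $i$ is the vector of weights for query $i$, like `omega_2` above. A sketch with random stand-ins for the embeddings and weights:

```python
import torch

torch.manual_seed(123)
X = torch.randn(8, 2)  # stand-in for embedded_sentence
W_query, W_key = torch.rand(3, 2), torch.rand(3, 2)

queries = X @ W_query.T  # (8, 3), row i is q^(i)
keys = X @ W_key.T       # (8, 3), row j is k^(j)

omega = queries @ keys.T        # (8, 8), omega[i, j] = q^(i) . k^(j)
omega_2 = queries[1] @ keys.T   # single-query version, as in the cell above

print(omega.shape)
print(torch.allclose(omega[1], omega_2))
```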


# Computing the Attention Scores

In [23]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.0432, 0.5687, 0.1273, 0.0832, 0.0107, 0.0147, 0.1273, 0.0249],
       grad_fn=<SoftmaxBackward0>)
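The cell above normalizes the unnormalized weights with a softmax after scaling them by $\sqrt{d_k}$: $\alpha_{2j} = \exp(\omega_{2j}/\sqrt{d_k}) \,/\, \sum_{j'} \exp(\omega_{2j'}/\sqrt{d_k})$. The plain-Python sketch below reproduces this from the rounded values printed for `omega_2`, so it agrees with the tensor output only to about 3 to 4 decimal places. Subtracting the maximum before exponentiating is a standard numerical-stability trick that does not change the result.

```python
import math

# Unnormalized weights omega_2, copied (rounded) from the output above
omega_2 = [-0.0879, 4.3783, 1.7863, 1.0502, -2.4988, -1.9530, 1.7863, -1.0434]
d_k = 3

scaled = [w / math.sqrt(d_k) for w in omega_2]
m = max(scaled)                         # subtract the max for numerical stability
exps = [math.exp(s - m) for s in scaled]
attention_weights_2 = [e / sum(exps) for e in exps]

print([round(a, 4) for a in attention_weights_2])
print(sum(attention_weights_2))
```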


In [24]:
values

tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
        [ 0.3862,  0.8181,  1.5893,  1.3203],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [ 0.0904,  0.2649,  0.2815,  0.3671],
        [-0.2244, -0.3454, -1.0836, -0.6643],
        [-0.1765, -0.2364, -0.8957, -0.4945],
        [ 0.1562,  0.3756,  0.5877,  0.5693],
        [-0.0912, -0.2218, -0.3398, -0.3344]], grad_fn=<PermuteBackward0>)

In [25]:
print(attention_weights_2.shape)
print(values.shape)

torch.Size([8])
torch.Size([8, 4])


In [26]:
context_vector_2 = attention_weights_2.matmul(values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([4])
tensor([0.2593, 0.5718, 1.0390, 0.9041], grad_fn=<SqueezeBackward4>)
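In equation form, this matmul computes the context vector as the attention-weighted sum of the value vectors, where $\alpha_{2j}$ are the normalized weights in `attention_weights_2` and $v^{(j)}$ is row $j$ of `values`:

$$z^{(2)} = \sum_{j} \alpha_{2j}\, v^{(j)}, \qquad z^{(2)} \in \mathbb{R}^{d_v} = \mathbb{R}^{4}$$

Each element of `context_vector_2` is therefore the dot product of `attention_weights_2` with one column of `values`.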


In [27]:
print(values.T)
(values.T).shape

tensor([[-0.0094,  0.3862,  0.1562,  0.0904, -0.2244, -0.1765,  0.1562, -0.0912],
        [ 0.0353,  0.8181,  0.3756,  0.2649, -0.3454, -0.2364,  0.3756, -0.2218],
        [-0.1071,  1.5893,  0.5877,  0.2815, -1.0836, -0.8957,  0.5877, -0.3398],
        [ 0.0115,  1.3203,  0.5693,  0.3671, -0.6643, -0.4945,  0.5693, -0.3344]],
       grad_fn=<PermuteBackward0>)


torch.Size([4, 8])

In [28]:
# [0.0432, 0.5687, 0.1273, 0.0832, 0.0107, 0.0147, 0.1273, 0.0249] * tensor([[-0.0094,  0.0353, -0.1071,  0.0115],
#                                                                            [ 0.3862,  0.8181,  1.5893,  1.3203],
#                                                                            [ 0.1562,  0.3756,  0.5877,  0.5693],
#                                                                            [ 0.0904,  0.2649,  0.2815,  0.3671],
#                                                                            [-0.2244, -0.3454, -1.0836, -0.6643],
#                                                                            [-0.1765, -0.2364, -0.8957, -0.4945],
#                                                                            [ 0.1562,  0.3756,  0.5877,  0.5693],
#                                                                            [-0.0912, -0.2218, -0.3398, -0.3344]]

context_vector_query2_1 = (0.0432 * -0.0094) + (0.5687 *  0.3862) +  (0.1273 * 0.1562) + (0.0832 * 0.0904) + (0.0107 * -0.2244) + (0.0147 * -0.1765) + (0.1273 * 0.1562) + (0.0249 * -0.0912)
context_vector_query2_2 = (0.0432 * 0.0353) + (0.5687 *  0.8181) +  (0.1273 * 0.3756) + (0.0832 * 0.2649) + (0.0107 * -0.3454) + (0.0147 * -0.2364) + (0.1273 * 0.3756) + (0.0249 * -0.2218)

In [29]:
print(f"{context_vector_query2_1}, {context_vector_query2_2}")

0.25924915, 0.57175219
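Putting the steps together, the sketch below computes context vectors for all 8 queries at once and, as a cross-check, compares them with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0; its default scale is $1/\sqrt{E}$, with $E$ the query's last dimension, matching the $\sqrt{d_k}$ scaling above). The helper name `self_attention` and the random stand-in tensors are illustrative assumptions, not part of the original notebook. With the real embeddings and weights, row 1 of the result would correspond to `context_vector_2`.

```python
import torch
import torch.nn.functional as F

def self_attention(x, W_query, W_key, W_value):
    """Hypothetical helper: single-head scaled dot-product self-attention.

    x: (seq_len, d); W_query, W_key: (d_k, d); W_value: (d_v, d).
    Returns context vectors of shape (seq_len, d_v).
    """
    queries = x @ W_query.T                     # (seq_len, d_k)
    keys = x @ W_key.T                          # (seq_len, d_k)
    values = x @ W_value.T                      # (seq_len, d_v)
    d_k = keys.shape[-1]
    omega = queries @ keys.T                    # unnormalized weights, (seq_len, seq_len)
    attn = F.softmax(omega / d_k**0.5, dim=-1)  # each row sums to 1
    return attn @ values                        # (seq_len, d_v)

torch.manual_seed(123)
x = torch.randn(8, 2)  # stand-in for embedded_sentence
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(4, 2)

context = self_attention(x, W_query, W_key, W_value)

# Built-in reference; a leading batch dimension of 1 is added and then removed
reference = F.scaled_dot_product_attention(
    (x @ W_query.T).unsqueeze(0),
    (x @ W_key.T).unsqueeze(0),
    (x @ W_value.T).unsqueeze(0),
).squeeze(0)

print(context.shape)
print(torch.allclose(context, reference, atol=1e-5))
```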


# Multi-Head Attention

Let's use 3 attention heads now, keeping the same dimensions as before:
- `d_q, d_k, d_v = 3, 3, 4`
- `d = 2`

In [30]:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

In [32]:
print(multihead_W_query)
print(multihead_W_query.shape)

Parameter containing:
tensor([[[0.8536, 0.5932],
         [0.6367, 0.9826],
         [0.2745, 0.6584]],

        [[0.2775, 0.8573],
         [0.8993, 0.0390],
         [0.9268, 0.7388]],

        [[0.7179, 0.7058],
         [0.9156, 0.4340],
         [0.0772, 0.3565]]], requires_grad=True)
torch.Size([3, 3, 2])


(Here, let’s keep the focus on the 2nd input element, corresponding to index position 1, i.e., `x_2` from before.)

In [35]:
x_2

tensor([0.1794, 1.8951])

In [34]:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2)
print(multihead_query_2.shape)

tensor([[1.2772, 1.9764, 1.2970],
        [1.6745, 0.2353, 1.5663],
        [1.4664, 0.9867, 0.6895]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 3])


In [36]:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)
print("Keys:")
print(multihead_key_2)
print(multihead_key_2.shape)
print("Values:")
print(multihead_value_2)
print(multihead_value_2.shape)

Keys:
tensor([[1.0367, 0.5123, 1.9268],
        [1.0603, 1.1722, 1.6966],
        [1.2750, 1.4245, 0.1450]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 3])
Values:
tensor([[1.7906, 0.2750, 1.8344, 0.3146],
        [1.0396, 0.9122, 1.3936, 1.6952],
        [0.2099, 1.0958, 0.1481, 1.7002]], grad_fn=<UnsafeViewBackward0>)
torch.Size([3, 4])
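`multihead_W_query.matmul(x_2)` treats the leading dimension of the 3-D weight tensor as a batch of heads, so row $h$ of the result is simply head $h$'s own weight matrix applied to `x_2`. A quick check with random stand-in weights:

```python
import torch

torch.manual_seed(0)
h, d_q, d = 3, 3, 2
multihead_W_query = torch.rand(h, d_q, d)  # random stand-in, same shape as above
x_2 = torch.tensor([0.1794, 1.8951])       # the second embedded token from above

batched = multihead_W_query.matmul(x_2)    # (h, d_q): all heads at once
per_head = torch.stack([multihead_W_query[i] @ x_2 for i in range(h)])  # one head at a time

print(batched.shape)
print(torch.allclose(batched, per_head))
```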


These keys and values belong only to the query element. But, as before, we also need the keys and values for the other sequence elements in order to compute the attention scores for the query. We can do this by expanding the input sequence embeddings to size 3, i.e., the number of attention heads:

In [42]:
print(embedded_sentence)
print(embedded_sentence.shape)
print("\nTranspose...")
print(embedded_sentence.T)
print(embedded_sentence.T.shape)

tensor([[ 0.3374, -0.1778],
        [ 0.1794,  1.8951],
        [ 0.3486,  0.6603],
        [ 0.4954,  0.2692],
        [ 0.6984, -1.4097],
        [ 0.7671, -1.1925],
        [ 0.3486,  0.6603],
        [-0.2196, -0.3792]])
torch.Size([8, 2])

Transpose...
tensor([[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486, -0.2196],
        [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603, -0.3792]])
torch.Size([2, 8])


Since we have 3 attention heads, we replicate the transposed input embeddings 3 times along a new leading (head) dimension:

In [43]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs)
print(stacked_inputs.shape)

tensor([[[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]],

        [[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]],

        [[ 0.3374,  0.1794,  0.3486,  0.4954,  0.6984,  0.7671,  0.3486,
          -0.2196],
         [-0.1778,  1.8951,  0.6603,  0.2692, -1.4097, -1.1925,  0.6603,
          -0.3792]]])
torch.Size([3, 2, 8])
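`torch.bmm()` requires both operands to be 3-D with the same batch size, which is why the inputs are repeated once per head. Two alternatives are sketched below with random stand-in tensors: `expand()` adds the head dimension as a view without copying the data, and `torch.matmul` broadcasts a 2-D right-hand operand across the batch by itself. All three give the same keys; the choice is a readability and memory trade-off rather than a correctness issue.

```python
import torch

torch.manual_seed(0)
h, d_k, d = 3, 3, 2
multihead_W_key = torch.rand(h, d_k, d)  # random stand-in, shape (h, d_k, d)
X = torch.randn(8, d)                    # stand-in for embedded_sentence

# 1) Explicit copy, as in the cell above, then batch matrix multiplication
stacked_inputs = X.T.repeat(h, 1, 1)     # (h, d, 8)
keys_repeat = torch.bmm(multihead_W_key, stacked_inputs)

# 2) expand(): same shape, but a view with stride 0 along the head dimension
expanded_inputs = X.T.expand(h, -1, -1)  # (h, d, 8), no data copied
keys_expand = torch.bmm(multihead_W_key, expanded_inputs)

# 3) torch.matmul broadcasts the 2-D X.T across the head dimension
keys_broadcast = multihead_W_key.matmul(X.T)

print(keys_repeat.shape)
print(torch.allclose(keys_repeat, keys_expand), torch.allclose(keys_repeat, keys_broadcast))
```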


Now, we can compute all the keys and values via `torch.bmm()` (batch matrix multiplication):