<a href="https://colab.research.google.com/github/rahiakela/deep-learning-research-and-practice/blob/main/deep-learning-fundamentals/unit08-NLP/01-bag-of-words/02_self_attention_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## How Self-Attention Works: Coding It from Scratch

**Reference**

[Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html)


## Setup

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from sklearn.feature_extraction.text import CountVectorizer
from torch.utils.data import Dataset, DataLoader

## 1) Embedding the Input Sentence

In [2]:
sentence = "Life is short, eat dessert first"

# Build the vocabulary: drop the comma, split on whitespace, and give each
# *unique* token an index in sorted order. Deduplicating with set() keeps the
# indices contiguous (0 .. len(vocab)-1), which is what nn.Embedding expects;
# without it, a word repeated in the sentence would leave a gap in the range.
unique_tokens = sorted(set(sentence.replace(",", "").split()))
vocab = {w: i for i, w in enumerate(unique_tokens)}
vocab

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [3]:
# let's make integer-vector representation
# Encode the sentence as a 1-D tensor of vocabulary indices, one per token,
# using the same comma-stripping + whitespace tokenization as the vocab cell.
sentence_vec = torch.tensor(
    [vocab[token] for token in sentence.replace(",", "").split()]
)
sentence_vec

tensor([0, 4, 5, 2, 1, 3])

In [4]:
torch.manual_seed(123)

# One trainable 16-dim vector per vocabulary entry. The table is sized from the
# vocabulary rather than a hard-coded 6, so it stays valid if the sentence (and
# therefore the vocabulary) changes. Same seed + same call => same init values.
embedding = torch.nn.Embedding(num_embeddings=len(vocab), embedding_dim=16)

# detach(): use the looked-up vectors as fixed inputs for this walkthrough,
# without tracking gradients back into the embedding table.
embedded_sentence = embedding(sentence_vec).detach()

print(embedded_sentence.shape)
print(embedded_sentence)

torch.Size([6, 16])
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])


## 2) Defining the Weight Matrices

In [5]:
torch.manual_seed(123)

# Size of each input embedding vector (the embedding's second dimension).
d = embedded_sentence.shape[1]

# Projection sizes. Queries and keys are compared via a dot product, so the
# query and key vectors must have the same number of elements (d_q == d_k);
# the value size d_v can be chosen independently.
d_q = 24
d_k = 24
d_v = 28

# Each matrix maps a d-dim embedding x to its projection via W @ x.
# Creation order (Q, K, V) matters: it fixes which random draws each gets.
W_Q = torch.nn.Parameter(torch.rand(d_q, d))
W_K = torch.nn.Parameter(torch.rand(d_k, d))
W_V = torch.nn.Parameter(torch.rand(d_v, d))

In [6]:
# Sanity check: each projection matrix is (output size, embedding size).
tuple(w.shape for w in (W_Q, W_K, W_V))

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

## 3) Computing the Unnormalized Attention Weights

In [7]:
# Use the second input element (the token "is") as the running example query.
x_2 = embedded_sentence[1]

# Project that single embedding into query, key and value space.
query_2 = W_Q @ x_2
key_2 = W_K @ x_2
value_2 = W_V @ x_2

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


In [8]:
# Generalize to every input element at once: project the whole embedded
# sentence (one row per token), then transpose back so that each row holds
# one token's key / value vector.
keys = (W_K @ embedded_sentence.T).T
values = (W_V @ embedded_sentence.T).T

print(f"keys shape: {keys.shape}")
print(f"values shape: {values.shape}")

keys shape: torch.Size([6, 24])
values shape: torch.Size([6, 28])


In [9]:
# Inspect the projected keys: one row per input token, d_k columns each.
keys

tensor([[ 1.3274,  0.9452,  0.5531,  0.1000, -1.5909,  0.8779,  0.8645,  1.1643,
          2.2148,  1.3088,  0.8315, -1.5550, -0.2360,  2.0408, -0.3657,  1.7449,
          0.1180,  2.2120, -0.2037, -0.6591, -0.1771, -0.1563,  0.7547,  1.0151],
        [-0.0809, -1.2746, -2.3948, -0.3425,  1.5967,  0.5399,  0.9113,  0.0962,
          0.7300, -1.0553,  1.2533, -0.2113,  1.0208, -0.7470,  1.5171,  0.2773,
         -0.3173,  0.2698,  1.5237, -1.0970,  1.3849,  0.4400, -2.4926,  0.3594],
        [-2.1744, -2.7340, -1.0410, -1.8867, -4.0902, -0.3303, -3.1343, -2.4864,
         -1.1285, -3.5427, -4.7195, -6.1590, -0.9058, -3.2316, -1.3660, -3.5371,
         -2.7504, -2.3356, -1.2755, -3.0702, -2.3168,  0.4209, -2.5422, -3.8700],
        [-0.4005,  0.6766, -1.7351,  1.0082, -1.1248, -3.2161,  0.5959, -2.3485,
         -1.7592, -1.2618, -1.8644, -0.5954, -1.4934, -3.0239,  0.3801, -1.9108,
         -0.6866, -2.1498, -1.9909, -0.8837, -1.6336, -2.1131, -0.4434, -0.1749],
        [ 1.1230,  1.301

In [10]:
# Unnormalized attention weight between query 2 and the 5th input element
# (index position 4): simply the dot product of the two vectors.
omega_24 = torch.dot(query_2, keys[4])
print(omega_24)

tensor(11.1466, grad_fn=<DotBackward0>)


In [11]:
# All unnormalized weights for query 2 in one shot: dot the query against
# every key (rows of `keys`), giving one score ω per input element.
omega_2 = query_2 @ keys.T
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)


## 4) Normalizing the Attention Weights (Scaled Softmax)

In [12]:
# Scaled dot-product attention: divide the raw scores by sqrt(d_k) before the
# softmax so large dot products do not saturate it; dim=0 normalizes across
# the input elements, so the resulting weights sum to 1.
attention_weights_2 = F.softmax(
    omega_2 / d_k ** 0.5,
    dim=0,
)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)


## 5) Computing the Context Vector

In [13]:
# Context vector for query 2: the attention-weighted sum of all value vectors.
context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


In [17]:
# 1-D weights times a matrix: (6,) @ (6, 28) -> (28,)
tuple(t.shape for t in (attention_weights_2, values, context_vector_2))

(torch.Size([6]), torch.Size([6, 28]), torch.Size([28]))

## 6) Multi-Head Attention