In [73]:
# https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention

In [2]:
sentence = 'Life is short, eat dessert first'

mapping = {s: i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}

mapping

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [12]:
sentence_int = torch.tensor([mapping[s] for s in sentence.replace(',','').split()])
sentence_int

tensor([0, 4, 5, 2, 1, 3])

In [11]:
import torch

vocab_size = 50_000

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])


In [17]:
torch.manual_seed(124)

d = embedded_sentence.shape[1]
d

3

In [19]:
d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

W_query, W_value

(Parameter containing:
 tensor([[0.7310, 0.4033, 0.5970],
         [0.3808, 0.7053, 0.0211]], requires_grad=True),
 Parameter containing:
 tensor([[0.5950, 0.7909, 0.0994],
         [0.1674, 0.3942, 0.0355],
         [0.5912, 0.3594, 0.2125],
         [0.0139, 0.6765, 0.3368]], requires_grad=True))

In [20]:
x_2 = embedded_sentence[1]
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

query_2, key_2, value_2

torch.Size([2])
torch.Size([2])
torch.Size([4])


In [26]:
keys = embedded_sentence @ W_key.T
values = embedded_sentence @ W_value.T

keys

tensor([[-0.0300, -0.0029],
        [ 1.5483,  0.7627],
        [-0.5884, -0.0187],
        [ 0.1944, -0.1504],
        [ 0.1739,  0.0260],
        [-1.7673, -0.0923]], grad_fn=<MmBackward0>)

In [31]:
omega_2 = query_2 @ keys.T
omega_2

tensor([-0.0398,  2.9238, -0.7273,  0.0188,  0.2440, -2.2358],
       grad_fn=<SqueezeBackward4>)

In [33]:
!pip install torch



In [72]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=-1)
attention_weights_2

tensor([0.0818, 0.6652, 0.0503, 0.0853, 0.1000, 0.0173],
       grad_fn=<SoftmaxBackward0>)

In [41]:
context_vector_2 = attention_weights_2 @ values
context_vector_2

tensor([1.0670, 0.5186, 0.5652, 0.9726], grad_fn=<SqueezeBackward4>)

In [69]:
import torch.nn as nn

class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))        
        
    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        
        omega = queries @ keys.T
        print(omega / d_out_kq ** 0.5)
        attention_weights = F.softmax(omega / d_out_kq ** 0.5, dim=-1)
        
        return attention_weights @ values

In [70]:
torch.manual_seed(123)

# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[ 0.0434, -0.2469,  0.1020, -0.0309, -0.0921,  0.0761],
        [-0.4246,  2.4542, -1.0623,  0.3529,  0.9124, -0.9456],
        [ 0.1720, -0.9853,  0.4150, -0.1309, -0.3671,  0.3345],
        [-0.0562,  0.3173, -0.1278,  0.0366,  0.1186, -0.0846],
        [-0.1068,  0.6100, -0.2544,  0.0786,  0.2274, -0.1971],
        [ 0.3072, -1.7704,  0.7594, -0.2481, -0.6587,  0.6551]],
       grad_fn=<DivBackward0>)
tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2627, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)
