## Self-attention

In [26]:
import torch
import torch.nn.functional as F
import math

In [27]:
# Random input: a batch of 1 sequence with 4 tokens, each token a 6-dimensional vector.
batch_size, seq_len, d_model = 1, 4, 6
x = torch.randn(batch_size, seq_len, d_model)
print(x)

tensor([[[ 1.6424, -0.9907,  0.6301,  0.8595, -0.6456, -0.5459],
         [-0.4458, -0.8357,  1.9928, -1.6449,  0.6589,  0.8060],
         [-0.8718, -0.2471, -2.5186, -1.3519, -1.0562,  0.4714],
         [-0.9784,  0.6653, -1.4547,  0.6962,  0.6511, -0.7816]]])


In [28]:
### Projection weights for query, key and value (6 -> 6).
# These are plain random tensors (requires_grad=False). In a real model they would be
# learnable parameters, e.g. nn.Linear(6, 6, bias=False).
# NOTE(review): unscaled randn weights make the Q/K entries fairly large. This is
# probably why the attention weights printed further down are nearly one-hot.
# Draw order is query, key, value.
W_q, W_k, W_v = (torch.randn(6, 6) for _ in range(3))

In [29]:
### Project the input into query, key and value spaces (the heart of self-attention).
# Each projection is (1, 4, 6) @ (6, 6) -> (1, 4, 6).
Q, K, V = (x @ W for W in (W_q, W_k, W_v))
Q.shape, K.shape, V.shape

(torch.Size([1, 4, 6]), torch.Size([1, 4, 6]), torch.Size([1, 4, 6]))

In [30]:
### Raw attention scores: the dot product of every query with every key.
# K.transpose(-2, -1) swaps the last two axes, (1, 4, 6) -> (1, 6, 4).
# So (1, 4, 6) @ (1, 6, 4) -> (1, 4, 4): one similarity score per (query token, key token) pair.
scores = torch.matmul(Q, K.transpose(-2, -1))
scores.shape

torch.Size([1, 4, 4])

In [31]:
## Scale by 1/sqrt(d_k), as in "Attention Is All You Need" (Vaswani et al., 2017).
# Dot products of d_k-dimensional vectors grow in magnitude as d_k grows. Large scores
# push softmax into saturated regions where its gradients are extremely small
# (a vanishing-gradient problem, not exploding gradients). Dividing by sqrt(d_k)
# keeps the scores in a range where softmax stays "soft".
# d_k is the key dimension (6 here, the last axis of K).
d_k = K.size(-1)
scores = scores / math.sqrt(d_k)

In [32]:
## Softmax over the last (key) axis: each query's row of weights sums to 1.
attn = scores.softmax(dim=-1)

In [33]:
# Weighted sum of the value vectors: (1, 4, 4) @ (1, 4, 6) -> (1, 4, 6).
output = torch.matmul(attn, V)
output.shape

torch.Size([1, 4, 6])

In [34]:
# Row i shows how strongly query token i attends to each of the 4 key tokens.
print(f"attention weights: \n{attn!s}")

attention weights: 
tensor([[[4.2019e-06, 1.7077e-04, 9.0447e-01, 9.5351e-02],
         [8.2040e-01, 1.7792e-01, 1.6562e-03, 2.1478e-05],
         [9.9986e-01, 1.3975e-04, 9.8872e-08, 8.1741e-07],
         [7.9510e-01, 1.9682e-01, 8.5944e-04, 7.2197e-03]]])


In [35]:
# Compare each token's original embedding with its self-attention output.
# The second print was previously mislabeled "input :"; it shows the attention output.
print(f"input :\n{x}")
print(f"output :\n{output}")

input :
tensor([[[ 1.6424, -0.9907,  0.6301,  0.8595, -0.6456, -0.5459],
         [-0.4458, -0.8357,  1.9928, -1.6449,  0.6589,  0.8060],
         [-0.8718, -0.2471, -2.5186, -1.3519, -1.0562,  0.4714],
         [-0.9784,  0.6653, -1.4547,  0.6962,  0.6511, -0.7816]]])
input :
tensor([[[-5.2841, -3.4050,  0.0996, -6.2014, -1.1419,  4.8094],
         [ 0.8883,  1.5031,  0.7012,  2.9978,  1.9457, -0.3712],
         [ 1.5253,  1.8126,  1.3404,  4.2435,  1.3543,  0.2843],
         [ 0.8129,  1.4514,  0.6299,  2.8269,  1.9622, -0.4338]]])
