In [4]:
# SO to begin with I have a sentence
sentence = 'SO to begin with I have a sentence'.split(' ')
sentence

['SO', 'to', 'begin', 'with', 'I', 'have', 'a', 'sentence']

In [7]:
# we need to convert sentence to numbers before converting to tensor
word2idx = {}
count = 0
for word in sentence:
    if word in word2idx:
        continue
    else:
        word2idx[word] = count
        count += 1
    
word2idx

{'SO': 0,
 'to': 1,
 'begin': 2,
 'with': 3,
 'I': 4,
 'have': 5,
 'a': 6,
 'sentence': 7}

In [9]:
# then there was some kind of embedding
import torch
import torch.nn as nn

x = torch.tensor([[word2idx[word] for word in sentence]]) # batch 1 for sentence
print(x)
embed = nn.Embedding(20, 5) #vocab 20 and vec size 5
x = embed(x)
print(x)
# so each word -> number -> vector, sure, that works

tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
tensor([[[ 0.6287,  1.6665, -0.3331, -0.4593, -1.6786],
         [ 0.2869, -0.0948, -0.3191, -0.9615,  2.1615],
         [ 1.3007,  0.2417, -1.2106, -2.0156,  1.5267],
         [ 1.5645, -0.7544,  0.0069, -1.7901, -1.2235],
         [ 1.9531, -0.7038,  0.3065, -1.0574, -0.7996],
         [ 1.6464,  0.4403,  0.4414, -0.1570, -0.3825],
         [-0.3843, -0.6210, -0.1699,  0.5268, -1.1570],
         [-1.3945, -0.3285, -0.6545, -0.3870, -0.5344]]],
       grad_fn=<EmbeddingBackward>)


In [11]:
# now there was this positional encoding thing, but let's skip it since it's added, ie assume 0
# then we got multiheaded attention
# So attention can be seen as this heuristic:
# IN ESSENCE -- add more changeable parameters/degree of freedom to the most BASIC approach
# 1. We want seq -> seq, and somehow the output seq captures the correlations with input seq
# 2. A natural way to do this is just weighted sum
# 3. A natural way for weights is just dot products
# 4. To make things easier, we normalize, with softmax
# 5. The raw vectors may not be in the right latent space, let's add an extra linear transf.
# 6. AND, let each 'raw vec term' in the weighted sum have its own customizable projection
# 7. Voila! These transformed raw vectors are key, quary and values! 
#    And we get the output seq by good old weighted sum (weights as dot product)

In [15]:
emb_dim = 5
latent_dim = 7
#init the matrices to get K, Q, V vecotors
M_K, M_Q, M_V = [torch.rand(emb_dim, latent_dim) for _ in range(3)]
# 

In [16]:
M_K

tensor([[0.3447, 0.8092, 0.3961, 0.4156, 0.9257, 0.4706, 0.7407],
        [0.5121, 0.5361, 0.9129, 0.5644, 0.2540, 0.2610, 0.1731],
        [0.6374, 0.6320, 0.2554, 0.5370, 0.6227, 0.6127, 0.1856],
        [0.2040, 0.9485, 0.3082, 0.9250, 0.6534, 0.0748, 0.3550],
        [0.2369, 0.1740, 0.7922, 0.8342, 0.0204, 0.1259, 0.6711]])

In [17]:
x@M_K #this is how one transforms the whole batch of words!
# the trick to applying matrix to a batch is really just put x first 
# it can be hard to make that mental switch from the math background
# since in math we are so used to put the matrix before x, but nothing special is going on here

tensor([[[ 0.3665,  0.4639,  0.2141, -0.8021,  0.4635,  0.2811, -0.5971],
         [ 0.1629, -0.5562,  1.3616,  0.8082, -0.5413,  0.1150,  1.2462],
         [-0.2491, -1.2292,  1.0150, -0.5639, -0.7743, -0.0252,  1.0896],
         [-0.4977, -1.0448, -1.5880, -2.4483,  0.0663,  0.2555, -0.4271],
         [ 0.1031,  0.2549, -0.7498, -1.0659,  1.1129,  0.7434,  0.4696],
         [ 0.9517,  1.6318,  0.8155,  0.7055,  1.8004,  1.1003,  1.0651],
         [-0.7254, -0.4530, -1.5167, -1.0794, -0.2987, -0.5533, -1.0132],
         [-1.2717, -2.1783, -1.5621, -1.9204, -2.0457, -1.2393, -1.7073]]],
       grad_fn=<UnsafeViewBackward>)

In [34]:
import torch.nn.functional as F
K, Q, V = x@M_K, x@M_Q, x@M_V 
# to get all pairs of dot products between Q_i and K_j, we multiply the "batch matrices" like so
# here matrices is truly a bookkeeping device
# and the matrix multiplication here is just a neat notation to get all pairs
# almost like list comprehension is a neat notation for for loop, the essence remains basic
# W_ij = Q_i . K_j
W_raw = Q@(K.transpose(1,2))
# when we do the final sum, we are summing over (or contracting the index of;) j -- the keys
# and we softmax these weights, i.e softmax all the rows of W
W = F.softmax(W_raw, dim=1)

In [35]:
# sanity check softmax does indeed makes probabilities
torch.sum(W, dim = 1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<SumBackward1>)

In [38]:
# putting it together: y_i = \sum_j{ W_i_j * v_j}
# in matrix form -- one could check by hand, it is
Y = W@V
# so yeah, batching can use matrix operations to bookkeep things and results elegantly!
Y
# num_words x latent dim -- shape checks out!

tensor([[[ 0.4231, -0.6297, -0.1127, -0.1433, -0.3924,  0.1836, -0.0788],
         [ 0.2189, -0.1688,  0.0991,  0.1640, -0.1528,  0.0889, -0.0864],
         [ 0.6795, -0.8617,  0.1346,  0.2751, -0.7951,  0.1563, -0.5765],
         [ 0.4446, -0.6852, -0.0218, -0.0228, -0.5302,  0.0906, -0.3964],
         [ 0.4071, -0.5548,  0.0055, -0.0293, -0.3627,  0.1148, -0.2745],
         [ 2.1575, -0.5151,  1.7189,  1.3461,  0.4325,  1.2230,  0.4569],
         [ 0.0795, -0.1965, -0.0770, -0.0895, -0.1347,  0.0139, -0.0796],
         [-2.4305, -3.1821, -4.7890, -5.0264, -2.9587, -2.2738, -3.3613]]],
       grad_fn=<UnsafeViewBackward>)