In [3]:
import torch
import torch.nn as nn

In [4]:
# SO to begin with I have a sentence
sentence = 'SO to begin with I have a sentence'.split(' ')
sentence

['SO', 'to', 'begin', 'with', 'I', 'have', 'a', 'sentence']

In [5]:
# we need to convert sentence to numbers before converting to tensor
word2idx = {}
count = 0
for word in sentence:
    if word in word2idx:
        continue
    else:
        word2idx[word] = count
        count += 1
    
word2idx

{'SO': 0,
 'to': 1,
 'begin': 2,
 'with': 3,
 'I': 4,
 'have': 5,
 'a': 6,
 'sentence': 7}

In [6]:
# then there was some kind of embedding
import torch
import torch.nn as nn

x = torch.tensor([[word2idx[word] for word in sentence]]) # batch 1 for sentence
print(x)
embed = nn.Embedding(20, 5) #vocab 20 and vec size 5
x = embed(x)
print(x)
# so each word -> number -> vector, sure, that works

tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
tensor([[[-0.7012,  0.9257,  0.3939,  0.7522,  0.5292],
         [-0.2357, -1.1488, -0.0339,  0.9371, -0.8041],
         [-1.1951,  1.1564, -0.5757, -0.5994, -0.8532],
         [-0.2493, -0.8708, -2.2939,  1.1188,  0.9900],
         [-0.2316, -1.3752, -1.1277,  1.1448,  1.4320],
         [ 1.5041, -0.5158,  1.9260, -0.8586, -1.0676],
         [ 0.4384, -0.1313, -1.1147,  0.2073, -0.2337],
         [-0.4124, -0.6262,  0.0381,  0.3523, -0.9138]]],
       grad_fn=<EmbeddingBackward>)


In [7]:
# now there was this positional encoding thing, but let's skip it since it's added, ie assume 0
# then we got multiheaded attention
# So attention can be seen as this heuristic:
# IN ESSENCE -- add more changeable parameters/degree of freedom to the most BASIC approach
# 1. We want seq -> seq, and somehow the output seq captures the correlations with input seq
# 2. A natural way to do this is just weighted sum
# 3. A natural way for weights is just dot products
# 4. To make things easier, we normalize, with softmax
# 5. The raw vectors may not be in the right latent space, let's add an extra linear transf.
# 6. AND, let each 'raw vec term' in the weighted sum have its own customizable projection
# 7. Voila! These transformed raw vectors are key, quary and values! 
#    And we get the output seq by good old weighted sum (weights as dot product)

In [8]:
emb_dim = 5
latent_dim = 7
#init the matrices to get K, Q, V vecotors
M_K, M_Q, M_V = [torch.rand(emb_dim, latent_dim) for _ in range(3)]
# 

In [9]:
M_K

tensor([[0.4851, 0.9903, 0.7989, 0.6851, 0.6765, 0.3137, 0.6337],
        [0.3811, 0.2715, 0.2483, 0.5626, 0.2533, 0.5040, 0.2638],
        [0.3329, 0.1015, 0.6938, 0.3812, 0.1823, 0.2009, 0.5130],
        [0.3601, 0.2882, 0.0179, 0.9687, 0.8276, 0.3333, 0.3122],
        [0.2598, 0.6326, 0.2150, 0.5138, 0.3100, 0.7412, 0.4021]])

In [10]:
x@M_K #this is how one transforms the whole batch of words!
# the trick to applying matrix to a batch is really just put x first 
# it can be hard to make that mental switch from the math background
# since in math we are so used to put the matrix before x, but nothing special is going on here

tensor([[[ 0.5521,  0.1484,  0.0702,  1.1911,  0.6184,  0.9685,  0.4495],
         [-0.4349, -0.7874, -0.6532, -0.3261,  0.0697, -0.9434, -0.5006],
         [-0.7682, -1.6405, -1.2613, -1.4068, -1.3810, -0.7399, -1.2778],
         [-0.5563,  0.2326, -1.7740,  0.0575,  0.4254,  0.1287, -0.8170],
         [-0.2276,  0.5186, -0.9805,  0.4826,  0.6806,  0.4505, -0.1548],
         [ 0.5877,  0.6222,  2.1649,  0.0942,  0.1966, -0.4785,  1.1077],
         [-0.1945,  0.1972, -0.5023, -0.1176,  0.1592, -0.2567, -0.3579],
         [-0.5367, -1.0512, -0.6488, -0.7487, -0.4224, -0.9973, -0.6645]]],
       grad_fn=<UnsafeViewBackward>)

In [11]:
import torch.nn.functional as F
K, Q, V = x@M_K, x@M_Q, x@M_V 
# to get all pairs of dot products between Q_i and K_j, we multiply the "batch matrices" like so
# here matrices is truly a bookkeeping device
# and the matrix multiplication here is just a neat notation to get all pairs
# almost like list comprehension is a neat notation for for loop, the essence remains basic
# W_ij = Q_i . K_j
W_raw = Q@(K.transpose(1,2))
# when we do the final sum, we are summing over (or contracting the index of;) j -- the keys
# and we softmax these weights, i.e softmax all the rows of W
W = F.softmax(W_raw, dim=1)

In [12]:
# sanity check softmax does indeed makes probabilities
torch.sum(W, dim = 1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<SumBackward1>)

In [13]:
# putting it together: y_i = \sum_j{ W_i_j * v_j}
# in matrix form -- one could check by hand, it is
Y = W@V
# so yeah, batching can use matrix operations to bookkeep things and results elegantly!
Y
# num_words x latent dim -- shape checks out!

tensor([[[ 0.9371,  0.8236,  1.2489,  1.8378,  1.4066,  0.6687,  0.4845],
         [-0.0982, -0.1767, -0.2936, -0.8520, -0.4210, -0.2800, -0.3303],
         [-0.3392, -0.8299, -1.0265, -1.5608, -1.9907, -1.6549, -0.9218],
         [-0.2412, -0.2690, -0.4956, -1.7746, -0.5718, -0.1423, -0.5431],
         [-0.0947, -0.0421, -0.0998, -1.0533,  0.0788,  0.4430, -0.2389],
         [ 0.1258,  0.1192,  0.1740,  0.0719,  0.2455,  0.2015,  0.0255],
         [-0.0516, -0.0684, -0.1397, -0.5674, -0.1472, -0.0428, -0.1848],
         [-0.1221, -0.3061, -0.4342, -1.0544, -0.7245, -0.5557, -0.4898]]],
       grad_fn=<UnsafeViewBackward>)

In [14]:
# Multiheaded attention (wide) is just repeating the process a few times to get different Y's
# this way we can have multiple customizable "contexts" in these Y's
# the hope is that backprop will tune it such that each Y represents a disentabgled context
# Well then, let's make attention a function and call it repeatedly
def self_attention(emb_dim, latent_dim):
    M_K, M_Q, M_V = [torch.rand(emb_dim, latent_dim) for _ in range(3)]
    K, Q, V = x@M_K, x@M_Q, x@M_V 
    W_raw = Q@(K.transpose(1,2))
    W = F.softmax(W_raw, dim=1)
    Y = W@V
    return Y
num_heads = 9
Ys = [self_attention(5,7) for _ in range(num_heads)]

In [15]:
Ys
# well, this is definitely not the best way to do it for pytorch since this could cause 
# gradient flowing issues, nonetheless, we do it here for understanding the essence

[tensor([[[ 0.1381,  0.1175,  0.2261,  0.0956,  0.1979,  0.1973,  0.1280],
          [ 0.0314, -0.6409,  0.1912, -0.3636, -0.8466,  0.0484, -1.0450],
          [-1.7420, -2.7032, -1.9559, -2.8386, -5.1617, -2.1685, -4.6580],
          [-0.0170, -0.2379,  0.0393, -0.1323, -0.3201,  0.0066, -0.3718],
          [ 0.0767, -0.0820,  0.1082, -0.0438, -0.0887,  0.0348, -0.1689],
          [ 0.4044,  1.1314,  0.9159,  1.5698,  2.3812,  1.7393,  2.4376],
          [ 0.0523, -0.0546,  0.0709, -0.0306, -0.0591,  0.0208, -0.1136],
          [-0.1072, -0.8400,  0.0406, -0.5487, -1.2029, -0.0816, -1.3557]]],
        grad_fn=<UnsafeViewBackward>),
 tensor([[[ 3.5052e-01,  4.1408e-01,  4.7629e-01,  5.0054e-01,  2.1207e-01,
            3.7885e-01,  6.6828e-01],
          [ 2.6067e-02,  3.5265e-02, -3.4085e-01,  3.1393e-01,  5.8668e-02,
            1.6751e-01, -1.6683e-02],
          [-4.1660e-01, -2.4759e-01, -6.4185e-01, -3.2129e-01, -3.3837e-01,
           -1.3594e-01, -6.0418e-01],
          [-2.437

In [17]:
# Now, two common tricks for neural net training:
# 1. Residual connection:
# if we want to add a residual connection, i.e. add Xs to Ys, we need to make them 
# the same dimension, i.e. latent dim = emb dim
# residual connection helps the gradient flow through deep neural nets, preventing the 
# gradient update from getting stuck 
# 2. Normalization: standardize data, stablize and accelerate training
# in this case we use layer norm on embedding dimension
# We don't want to worry about layer norm here since that's not essnetial to the 
# transfromer architechture and can be changed for different contexts (e.g. vision)
m = nn.LayerNorm(emb_dim)
m(x)

tensor([[[-1.8949,  0.9564,  0.0245,  0.6524,  0.2615],
         [ 0.0298, -1.2436,  0.3113,  1.6654, -0.7629],
         [-0.9579,  1.9236, -0.1989, -0.2279, -0.5390],
         [ 0.0093, -0.4829, -1.6098,  1.0927,  0.9907],
         [-0.1745, -1.1719, -0.9561,  1.0260,  1.2765],
         [ 1.0381, -0.5669,  1.3733, -0.8392, -1.0053],
         [ 1.1394,  0.0668, -1.7847,  0.7044, -0.1260],
         [-0.2198, -0.6898,  0.7704,  1.4611, -1.3219]]],
       grad_fn=<NativeLayerNormBackward>)

In [18]:
x

tensor([[[-0.7012,  0.9257,  0.3939,  0.7522,  0.5292],
         [-0.2357, -1.1488, -0.0339,  0.9371, -0.8041],
         [-1.1951,  1.1564, -0.5757, -0.5994, -0.8532],
         [-0.2493, -0.8708, -2.2939,  1.1188,  0.9900],
         [-0.2316, -1.3752, -1.1277,  1.1448,  1.4320],
         [ 1.5041, -0.5158,  1.9260, -0.8586, -1.0676],
         [ 0.4384, -0.1313, -1.1147,  0.2073, -0.2337],
         [-0.4124, -0.6262,  0.0381,  0.3523, -0.9138]]],
       grad_fn=<EmbeddingBackward>)