In [1]:
# Import modules
import torch
import torch.nn as nn
from math import sqrt

In [2]:
# Example of a transformer model from scratch using PyTorch: https://github.com/Khaliladib11/Transformer-from-scratch/
# Illustration: https://jalammar.github.io/illustrated-transformer/
# Mathematical Visualization: https://people.tamu.edu/~sji/classes/attn-slides.pdf
# Another good set of diagrams to understand attention: https://towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34
#   and https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853

In [3]:
# Tokens to vectors: Token embeddings using word2vec like MLP
# [READ] https://discuss.huggingface.co/t/the-inputs-into-bert-are-token-ids-how-do-we-get-the-corresponding-input-token-vectors/11273
# Above tokenizer outputs token IDs, which map (internally) to embedding vectors, which could be accessed using https://discuss.huggingface.co/t/generate-raw-word-embeddings-using-transformer-models-like-bert-for-downstream-process/2958
# The word to token to embedding is part of the bigger transfomer and is trained alongwith the complete model, instead of using an 'off the shelf' word2vec trained MLP

In [4]:
# Only to understand, we define tensors for different aspects

D = 6   # Embedding Vector length => # of dimensions in word2vec terms
L = 8    # Length of sequence. If tokenized input is shorter, pad it so that resulting length is L

# Word embedding tensor - 1 column for each token
WE = torch.rand((D, L))

In [5]:
# Positional Encoding
# Since transformer doesn't use any recurrence, to provide importance to nearby tokens, positional embedding is added to token embedding.
# Refer to attention paper or https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
# TODO: Add code to generate position embedding lookup and add to word embedding?
PE = torch.rand((D, L))

In [6]:
# The input is sum of word and position embeddings
I = WE + PE
I.size()

torch.Size([6, 8])

In [7]:
# The input is linearly transformed (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) to create query and key matrices
# Essentially, we're transforming each higher dimension (D) embedding into a smaller dimension (M = key-query space dimension) vector, e.g Q = WQ dot I
# https://www.youtube.com/watch?v=eMlx5fFNoYc&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=6&t=269s
M = 10

# TODO: Use
WQ = torch.rand((M, D))
WK = torch.rand((M, D))

WQ.size()

torch.Size([10, 6])

In [8]:
# Pytorch way of calculating dot-product
Q = torch.einsum('ij, jk -> ik', WQ, I)
K = torch.einsum('ij, jk -> ik', WK, I)

Q.size()
K.size()
Q

tensor([[3.1693, 2.1927, 2.8809, 2.5130, 2.9900, 2.2685, 3.7322, 2.8889],
        [2.6887, 2.1459, 2.6003, 2.4136, 2.7068, 2.6250, 3.6825, 2.7982],
        [2.9437, 2.6521, 2.5546, 1.9819, 2.7460, 2.2680, 3.2884, 2.6828],
        [3.3613, 2.7055, 2.9411, 2.6842, 3.8369, 3.1979, 4.5251, 3.4753],
        [2.6514, 2.1662, 2.4622, 2.2088, 2.9289, 2.4920, 3.5213, 2.7258],
        [4.5704, 3.4810, 4.4151, 3.7938, 4.2986, 3.6176, 5.5293, 4.3299],
        [3.0719, 2.4255, 2.6300, 2.2707, 3.3615, 2.6705, 3.8966, 3.0434],
        [1.4318, 1.2466, 1.4742, 1.0371, 1.2894, 1.3181, 1.7367, 1.4644],
        [3.7080, 2.8286, 3.8334, 2.8846, 2.8095, 2.5994, 4.0383, 3.3360],
        [5.4759, 4.5118, 4.9074, 3.6270, 4.7713, 3.9199, 5.9387, 4.9061]])

In [9]:
# Attention: Each column in 'Q' is trying to ask a question, and when we calculate K' dot Q, the closely matching Q and K vectors result in large "attention" outputs
# Each column in attention matrix A tells how that other tokens in the input attend to this token in the input
A = torch.einsum('ij, jk -> ik', torch.transpose(K, 0, 1), Q)

# TODO: To avoid later tokens to influence earlier tokens (acausal behaviour), we mask the lower diagonal values in A, so that every key with index higher than query index is set to -INF, so that softmax calculation is not impacted
A

tensor([[122.9445,  98.1300, 114.3399,  95.5700, 119.6243, 102.5016, 150.3902,
         118.9190],
        [ 95.6930,  76.3270,  88.7811,  74.1526,  93.1516,  79.5312, 116.8576,
          92.4227],
        [112.2869,  89.4369, 104.3041,  87.1116, 109.1998,  93.3173, 137.1531,
         108.4709],
        [ 94.6003,  75.2705,  87.7044,  73.1604,  91.9579,  78.3324, 115.3166,
          91.2332],
        [122.1398,  97.1727, 113.4539,  94.3436, 118.0311, 100.6807, 148.3893,
         117.5325],
        [103.3264,  82.2776,  95.5361,  79.3966,  99.9396,  84.8448, 125.2358,
          99.2093],
        [152.2400, 121.3286, 141.1500, 117.5810, 147.5880, 125.8220, 185.2449,
         146.6374],
        [119.7703,  95.4603, 111.1152,  92.5885, 116.1174,  99.0762, 145.8140,
         115.4140]])

In [10]:
# Now we perform softmax on A to convert each column of A, to find the most likelihood
A = nn.functional.softmax(A / sqrt(M), dim=0)
A.size()

torch.Size([8, 8])

In [11]:
# Now we adjust each input embedding using the self-attention calculated above, to add context information: https://www.youtube.com/watch?v=eMlx5fFNoYc&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=6&t=790s
# Conceptually, we use a D X D matrix to generate linearly transformed input embedding, and then dot product with attention A to find required changes in input embedding to add context
# Bit it can be more efficiently done by using a product of two weight matrix:

WVu = torch.rand((D, M))
WVd = torch.rand((M, D))

# First, we down-size input embedding from D to M
Vd = torch.einsum('ij, jk -> ik', WVd, I)
Vd.size()

# Then, we up-size Vd back to D space
V = torch.einsum('ij, jk -> ik', WVu, Vd)
V.size()

torch.Size([6, 8])

In [12]:
# Finally, we perform V x A to get change in embeddings: essentially, A gives the weight for each vector, and we average weighted value vectors to get change needed
dI = torch.einsum('ij, jk -> ik', V, A)
dI.size()

torch.Size([6, 8])

In [13]:
# And add dI to input embedding to get transformed output embedding
O = I + dI
O.size()

torch.Size([6, 8])