In [1]:
# Import modules
import torch
import torch.nn as nn
from math import sqrt

In [2]:
# Example of a transformer model from scratch using PyTorch: https://github.com/Khaliladib11/Transformer-from-scratch/
# Illustration: https://jalammar.github.io/illustrated-transformer/
# Mathematical Visualization: https://people.tamu.edu/~sji/classes/attn-slides.pdf
# Another good set of diagrams to understand attention: https://towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34
#   and https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853

In [3]:
# Tokens to vectors: Token embeddings using word2vec like MLP
# [READ] https://discuss.huggingface.co/t/the-inputs-into-bert-are-token-ids-how-do-we-get-the-corresponding-input-token-vectors/11273
# The tokenizer (see the link above) outputs token IDs, which map (internally) to embedding vectors; these can be accessed as described in https://discuss.huggingface.co/t/generate-raw-word-embeddings-using-transformer-models-like-bert-for-downstream-process/2958
# The word -> token -> embedding mapping is part of the bigger transformer and is trained along with the complete model, instead of using an 'off the shelf' word2vec-trained MLP

In [4]:
# Toy sizes and random stand-in tensors, purely to illustrate the shapes involved

# Seed the global RNG before the first torch.rand call, so this tensor (and any later
# torch.rand draws that follow in run order) is repeatable across Restart & Run All
SEED = 0
torch.manual_seed(SEED)

D = 6   # Embedding vector length => # of dimensions in word2vec terms
L = 8   # Length of sequence. If tokenized input is shorter, pad it so that resulting length is L

# Word embedding tensor - one column (of length D) for each of the L tokens
WE = torch.rand((D, L))

In [5]:
# Positional encoding
# A transformer has no recurrence, so token order (and hence which tokens are nearby) has to be
# injected explicitly: a positional embedding is added to each token embedding.
# See the attention paper or https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
# TODO: Add code to generate position embedding lookup and add to word embedding?
# Until then, a random (D, L) tensor stands in for the real positional encoding
PE = torch.rand(D, L)

In [6]:
# Model input: element-wise sum of the word embeddings and the positional embeddings
I = WE.add(PE)
I.shape

torch.Size([6, 8])

In [7]:
# Query and key matrices are linear transforms of the input (cf. https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
# Each length-D embedding column is mapped into the key-query space of dimension M, e.g. Q = WQ dot I
# https://www.youtube.com/watch?v=eMlx5fFNoYc&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=6&t=269s
# NOTE: the key-query space is usually described as the smaller one (M < D); with the toy values
# used in this notebook (D = 6, M = 10) it is actually the larger one
M = 10

# TODO: Use nn.Linear (linked above) instead of raw weight tensors -- presumably the intent; the original note was left unfinished
# Random stand-ins for the learned projection weights, each of shape (M, D)
WQ = torch.rand(M, D)
WK = torch.rand(M, D)

WQ.shape

torch.Size([10, 6])

In [8]:
# Project the input into query / key space. The einsum equation 'ij, jk -> ik' is an ordinary
# matrix product: (M, D) x (D, L) -> (M, L), i.e. one length-M query / key column per token
Q = torch.einsum('ij, jk -> ik', WQ, I)
K = torch.einsum('ij, jk -> ik', WK, I)

# A notebook cell only echoes its last expression, so bare Q.size() / K.size() lines in the
# middle of the cell would be silently discarded -- check the shapes explicitly instead
assert Q.shape == K.shape == (M, L)
Q

tensor([[2.3897, 1.7479, 2.6692, 1.6097, 1.6942, 1.7014, 1.5254, 1.6739],
        [3.4826, 2.8201, 5.3857, 3.2962, 3.9728, 2.4400, 4.0953, 3.4223],
        [2.3004, 1.5296, 2.9333, 1.5701, 1.9988, 1.6108, 1.8064, 1.7154],
        [2.0118, 1.5871, 3.3924, 1.9076, 2.4824, 1.4944, 2.3505, 2.0305],
        [2.5622, 1.9513, 3.5046, 1.9457, 2.4713, 1.5285, 2.8915, 2.2452],
        [3.0498, 2.0576, 3.6320, 2.3153, 2.4279, 1.9450, 2.5457, 2.4038],
        [2.8447, 2.1060, 4.1158, 2.5028, 2.9352, 2.2065, 2.5113, 2.4870],
        [3.6852, 2.8830, 4.5143, 2.5416, 3.3001, 2.1887, 4.0459, 2.9172],
        [2.4057, 1.9910, 3.1818, 1.7673, 2.3491, 1.7989, 2.2642, 1.8796],
        [3.6407, 3.3560, 4.7156, 2.6356, 3.3087, 2.7745, 3.2263, 2.8537]])

In [9]:
# Attention scores: each column of Q asks a question; in K' dot Q, a key and a query that match
# closely produce a large "attention" value
# Entry A[i, j] is key i dotted with query j, so column j of A tells how the other tokens in the
# input attend to token j
A = torch.einsum('ij, jk -> ik', K.t(), Q)

# TODO: To stop later tokens from influencing earlier ones (acausal behaviour), mask the entries below
# the diagonal of A -- every key whose index is higher than the query index -- by setting them to -INF,
# so that they do not contribute to the softmax
A

tensor([[ 89.9745,  70.6371, 121.2894,  70.2107,  86.4011,  62.3521,  88.2116,
          75.3157],
        [ 67.1082,  52.2188,  90.4301,  52.3070,  64.3178,  46.4574,  65.4251,
          56.0926],
        [112.9106,  87.9641, 152.2493,  88.2319, 108.3713,  78.2108, 110.2667,
          94.5058],
        [ 64.7450,  50.4106,  87.2677,  50.5651,  62.1189,  44.7724,  63.3186,
          54.1950],
        [ 78.2327,  60.9336, 105.3769,  61.0377,  74.9949,  54.1628,  76.3433,
          65.4064],
        [ 56.3546,  43.6808,  75.5455,  43.6546,  53.6566,  38.8353,  54.7700,
          46.8999],
        [ 87.5789,  68.9833, 118.6289,  68.7970,  84.6466,  60.9697,  86.1163,
          73.6193],
        [ 72.2746,  56.4820,  97.6490,  56.6189,  69.5675,  50.1244,  70.8019,
          60.6220]])

In [10]:
# Scale the scores by sqrt(M), then softmax along dim=0 so that every column of A becomes a set of
# weights summing to 1
# NOTE: A is rebound to its own softmax here, so running this cell a second time without first
# recomputing the raw scores would normalise the already-normalised weights
A = A.div(sqrt(M)).softmax(dim=0)
A.shape

torch.Size([8, 8])

In [11]:
# Now we adjust each input embedding using the self-attention calculated above, to add context information: https://www.youtube.com/watch?v=eMlx5fFNoYc&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=6&t=790s
# Conceptually, a single D x D value matrix linearly transforms the input embeddings, and the result is combined with attention A to find the changes needed in the input embeddings to add context
# Here that D x D map is instead factored into a product of two weight matrices that pass through the M-dimensional space
# NOTE: the factored form only saves parameters when M is well below D (2*M*D < D*D); with the toy
# sizes in this notebook (D = 6, M = 10) it does not, and is shown purely to illustrate the structure

WVu = torch.rand((D, M))   # "up" projection: M-space -> D-space
WVd = torch.rand((M, D))   # "down" projection: D-space -> M-space

# First, map the input embeddings from D-space into M-space
Vd = torch.einsum('ij, jk -> ik', WVd, I)
# Checked with an assert because a bare Vd.size() in the middle of a cell is never displayed
assert Vd.shape == (M, L)

# Then, map Vd back into D-space
V = torch.einsum('ij, jk -> ik', WVu, Vd)
V.shape

torch.Size([6, 8])

In [12]:
# Finally, V x A gives the change in embeddings: each column of A supplies one weight per value
# vector (column of V), and the matching column of dI is the resulting weighted sum
# einsum labels: d = embedding dimension, k = key/token index, q = query/token index
dI = torch.einsum('dk, kq -> dq', V, A)
dI.shape

torch.Size([6, 8])

In [13]:
# Output embedding: the input embedding plus the attention-derived change dI
O = I.add(dI)
O.shape

torch.Size([6, 8])