# The StatQuest Illustrated Guide to Neural Networks and AI
## Chapter 12 - Transformers!!!

Copyright 2024, Joshua Starmer

---- 

This tutorial is from the book, **[The StatQuest Illustrated Guide to Neural Networks and AI](https://www.amazon.com/dp/B0DRS71QVQ)**.

In this tutorial, we will use **[PyTorch](https://pytorch.org/) + [Lightning](https://www.lightning.ai/)** to create and optimize an encoder-decoder transformer, like the one shown in the picture below.

<img src="./images/enc_dec_transformer.png" alt="an encoder-decoder neural network" style="width: 800px;">

In this tutorial, you will...

- **[Code a Position Encoder Class From Scratch!!!](#position)** The position encoder class will give the encoder and the decoder a way to keep track of the order of their input tokens.

- **[Code an Attention Class From Scratch!!!](#attention)** The attention class will allow us to keep track of how words in the input and output are related to each other.

- **[Code an Encoder Class From Scratch!!!](#encoder)** The encoder will process the input.

- **[Code a Decoder Class From Scratch!!!](#decoder)** The decoder will generate the output.

- **[Code a Transformer Class From Scratch!!!](#transformer)** The transformer class will connect all the pieces: the position encoder, attention, the encoder and the decoder.

- **[Train the Transformer!!!](#train)** We'll train the transformer to translate simple English phrases into Spanish.

- **[Use the Trained Transformer!!!](#use)** Finally, we'll use the trained transformer to translate simple English phrases into Spanish.

#### NOTE:
This tutorial assumes that you already know the basics of coding in **Python** and are familiar with the theory behind **[Encoder-Decoder Transformers CORRECT THIS LINK]()** and **[Backpropagation](https://youtu.be/IN2XmBhILt4)**. If not, check out the **StatQuests** by clicking on the links for each topic.

#### ALSO NOTE:
I strongly encourage you to play around with the code. Playing with the code is the best way to learn from it.

----

# Import the modules that will do all the work

The very first thing we need to do is load a bunch of Python modules. Python itself is just a basic programming language. These modules give us extra functionality to create and train a Neural Network.

In [None]:
%%capture 
# %%capture prevents this cell from printing a ton of STDERR stuff to the screen

## First, check to see if lightning is installed, if not, install it.
##
## NOTE: If you **do** need to install something, just know that you may need to
##       restart your session for python to find the new module(s).
##
##       To restart your session:
##       - In Google Colab, click on the "Runtime" menu and select
##         "Restart Session" from the pulldown menu
##       - In a local jupyter notebook, click on the "Kernel" menu and select
##         "Restart Kernel" from the pulldown menu
# import pip
# try:
#   __import__("lightning")
# except ImportError:
#   pip.main(['install', "lightning"])

In [None]:
import torch ## torch lets us create tensors and also provides helper functions, like argmax()
import torch.nn as nn ## torch.nn gives us nn.Module, nn.Embedding() and nn.Linear()
import torch.nn.functional as F # This gives us softmax()
from torch.optim import Adam # This is the optimizer we will use

import lightning as L # Lightning makes it easier to write, optimize and scale our code
from torch.utils.data import TensorDataset, DataLoader # We'll store our data in DataLoaders

## NOTE: If you get an error running this block of code, it is probably
##       because you installed a new package earlier and forgot to
##       restart your session for python to find the new module(s).
##
##       To restart your session:
##       - In Google Colab, click on the "Runtime" menu and select
##         "Restart Session" from the pulldown menu
##       - In a local jupyter notebook, click on the "Kernel" menu and select
##         "Restart Kernel" from the pulldown menu

----

# The input and output vocabularies and data

In this tutorial we will build a simple encoder-decoder transformer that can translate simple English phrases into Spanish. Specifically, we'll be able to translate **Let's go** to **vamos** and **to go** to **ir**. In order to keep track of things, we'll create dictionaries for the input and output vocabularies, and then create a **Dataloader** that contains the English phrases mapped to their Spanish translations. Ultimately, we'll use the **Dataloader** to train the transformer.

In [None]:
## first, a dictionary for the input vocabulary
input_vocab = {'<SOS>': 0, ## <SOS> = start of sequence.
               'lets': 1,
               'to': 2,
               'go': 3}

## Now a dictionary for the output vocabulary
output_vocab = {'<SOS>': 0,
                'ir': 1,
                'vamos': 2,
                'y': 3,
                '<EOS>': 4}

## Here are the English phrases, encoded using the
## input vocabulary
## NOTE: our transformer will prepend the <SOS> token to these inputs
inputs = torch.tensor([[1, 3],
                       [2, 3]])

## Here are the Spanish translations, encoded using
## the output vocabulary.
## NOTE: our transformer will prepend the <SOS> token to these outputs
labels = torch.tensor([[2],
                      [1]])

dataset = TensorDataset(inputs, labels) 
dataloader = DataLoader(dataset)

Now that we have created the input and output datasets and the **Dataloader** to train the model, let's start building it.
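If you'd like to double-check the encoding before going further, here is an optional sanity check (not part of the original tutorial). It inverts the two vocabulary dictionaries and walks through the **Dataloader**, turning each batch of token ids back into words. The names `input_id_to_word`, `output_id_to_word` and `decoded` are just illustrative.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

input_vocab = {'<SOS>': 0, 'lets': 1, 'to': 2, 'go': 3}
output_vocab = {'<SOS>': 0, 'ir': 1, 'vamos': 2, 'y': 3, '<EOS>': 4}

## Invert each dictionary so we can look up the word for a token id.
input_id_to_word = {token_id: word for word, token_id in input_vocab.items()}
output_id_to_word = {token_id: word for word, token_id in output_vocab.items()}

inputs = torch.tensor([[1, 3],
                       [2, 3]])
labels = torch.tensor([[2],
                       [1]])
dataloader = DataLoader(TensorDataset(inputs, labels))  ## batch_size defaults to 1, shuffle to False

decoded = []
for batch_inputs, batch_labels in dataloader:
    ## With batch_size=1, batch_inputs has shape (1, 2) and batch_labels has shape (1, 1),
    ## so [0] pulls out the single phrase in each batch.
    english = [input_id_to_word[token_id] for token_id in batch_inputs[0].tolist()]
    spanish = [output_id_to_word[token_id] for token_id in batch_labels[0].tolist()]
    decoded.append((" ".join(english), " ".join(spanish)))

print(decoded)  ## [('lets go', 'vamos'), ('to go', 'ir')]
```

The same inverted output dictionary is handy later, when we want to turn the token ids that the trained transformer predicts back into Spanish words.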

----

<a id="position"></a>
# Position Encoding

Position Encoding helps the transformer keep track of the order of the words in the input and the output. For example, in the picture below, we see that the two phrases **Squatch eats pizza** and **Pizza eats Squatch** both have the exact same words, but, due to differences in the word order, have very different meanings. Thus, keeping track of word order is very important.

<img src="./images/squatch_eats_pizza.png" alt="Squatch eats pizza is very different from Pizza eats Squatch" style="width: 800px;">

There are a bunch of ways for a transformer to keep track of word order, but one popular method is to use a series of alternating sine and cosine curves (seen below). The number of sine and cosine squiggles depends on how many numbers, or word embedding values, we use to represent each token. In the context of transformers, the number of word embedding values we use to represent each token is the **dimension** of the transformer. So, if the transformer's dimension is 2, meaning that it uses 2 numbers to represent each token, then we only need one sine and one cosine squiggle. 

<img src="./images/pos_encoding_1.png" alt="Sine and cosine squiggles for position encoding" style="width: 800px;">

In contrast, as we see in the illustration below, if the transformer's dimension is 4, then we'll need 2 sine squiggles alternating with 2 cosine squiggles, for a total of 4 squiggles.

<img src="./images/pos_encoding_2.png" alt="More sine and cosine squiggles for position encoding" style="width: 800px;">

As we see in the illustration above, the additional pair of sine and cosine squiggles have a wider period (they repeat less frequently) than the first pair. Increasing the period for each additional pair of squiggles ensures that each position is represented by a unique combination of values.
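Written out, the squiggles come from the position encoding equations in "Attention Is All You Need" (the same ones quoted in the code comments below), where $pos$ is a token's position, $2i$ and $2i+1$ are the even and odd columns of the encoding, and $d_{model}$ is the dimension of the transformer. Plugging in $d_{model}=4$ shows why the second pair of squiggles has a wider period: since $10000^{2/4} = 100$, the second pair repeats about every $2\pi \cdot 100 \approx 628$ positions instead of every $2\pi \approx 6.28$. (With $d_{model}=2$, only the $i=0$ pair exists.)

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

$$
d_{model}=4: \quad i=0 \Rightarrow \sin(pos),\ \cos(pos); \qquad
i=1 \Rightarrow \sin\left(\frac{pos}{100}\right),\ \cos\left(\frac{pos}{100}\right)
$$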

**NOTE:** The reason why we are bothering to create a class to do positional encoding, instead of just adding this code directly to the transformer, is that both the Encoder and Decoder need to use it. So, by creating a class that does positional encoding, we can code it once, and then just create as many instances of it as we need (which, in this case, is twice).

**ALSO NOTE:** Since the position encoding values never change, meaning that the first token always uses the same position encoding values regardless of what that token is, we can precompute them and save them in a lookup table. This makes adding position encoding values super fast.
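To make the lookup-table idea concrete, here is a minimal sketch that builds the table twice: once with plain loops that follow the equations directly, and once with the vectorized `exp()`/`log()` form that shows up in many other implementations. The function names `pe_table_loops()` and `pe_table_vectorized()` are made up for this example, and both assume an even `d_model`.

```python
import math
import torch

def pe_table_loops(max_len, d_model):
    ## Lookup table built with explicit loops, straight from
    ## PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).
    table = torch.zeros(max_len, d_model)
    for pos in range(max_len):
        for two_i in range(0, d_model, 2):  ## two_i is the even column index, 2i
            angle = pos / (10000 ** (two_i / d_model))
            table[pos, two_i] = math.sin(angle)
            table[pos, two_i + 1] = math.cos(angle)  ## assumes d_model is even
    return table

def pe_table_vectorized(max_len, d_model):
    ## The same table, computed all at once with the exp()/log() form of the divisor.
    position = torch.arange(max_len).float().unsqueeze(1)  ## shape (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * -(math.log(10000.0) / d_model))  ## shape (d_model / 2,)
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

loops = pe_table_loops(max_len=3, d_model=4)
vectorized = pe_table_vectorized(max_len=3, d_model=4)
print(loops)
print(torch.allclose(loops, vectorized, atol=1e-6))  ## True

## Because the table is precomputed, adding position encoding is just a row lookup.
word_embeddings = torch.randn(2, 4)              ## hypothetical embeddings for a 2-token phrase
position_encoded = word_embeddings + loops[:2]   ## add the first 2 rows of the table
```

Either way, the table only has to be computed once. After that, adding position encoding to a phrase with `n` tokens is just `word_embeddings + table[:n]`, which is the same slicing that the `forward()` method in the class uses.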

Now that we understand the ideas that we want to implement in the Position Encoding class, let's code it!

In [None]:
class PositionEncoding(nn.Module):
    def __init__(self, d_model=2, max_len=3):
        ## d_model = The dimension of the transformer, which is also the number of embedding values per token.
        ##           In the transformer I used in the StatQuest: Transformer Neural Networks Clearly Explained!!!
        ##           d_model=2, so that's what we'll use as a default for now.
        ##           However, in "Attention Is All You Need" d_model=512
        ## max_len = maximum number of tokens we allow as input.
        ##           Since we are precomputing the position encoding values and storing them in a lookup table
        ##           we can use d_model and max_len to determine the number of rows and columns in that
        ##           lookup table.
        ##
        ##           In this simple example, we are only using 2 word phrases + <SOS>, so we are using
        ##           max_len=3 as the default setting.
        ##           However, in The Annotated Transformer, they set the default value for max_len to 5000
        
        super().__init__()
        ## We call the super's init because by creating our own init method, we overwrite the one
        ## we inherited from nn.Module. So we have to explicitly call nn.Module's __init__(), otherwise it
        ## won't get initialized. NOTE: If we didn't write our own __init__(), then we would not have
        ## to call super().__init__(). Alternatively, if we didn't want to access any of nn.Module's methods, 
        ## we wouldn't have to call it then either.

        ## Now we create a lookup table, pe, of position encoding values and initialize all of them to 0.
        ## To do this, we will make a matrix of 0s that has max_len rows and d_model columns.
        ## for example...
        ## torch.zeros(2, 3)
        ## ...returns a matrix of 0s with 2 rows and 3 columns...
        ## tensor([[0., 0., 0.],
        ##         [0., 0., 0.]])
        pe = torch.zeros(max_len, d_model)

        ## Now we create a sequence of numbers for each position that a token can have in the input (or output).
        ## For example, if the input tokens were "I'm happy today!", then "I'm" would get the first
        ## position, 0, "happy" would get the second position, 1, and "today!" would get the third position, 2.
        ## NOTE: Since we are going to be doing math with these position indices to create the 
        ## positional encoding for each one, we need them to be floats rather than ints.
        ## 
        ## NOTE: Two ways to create floats are...
        ##
        ## torch.arange(start=0, end=3, step=1, dtype=torch.float)
        ##
        ## ...and...
        ##
        ## torch.arange(start=0, end=3, step=1).float()
        ##
        ## ...but the latter is just as clear and requires less typing.
        ##
        ## Lastly, .unsqueeze(1) converts the single list of numbers that torch.arange creates into a matrix with
        ## one row for each index, and all of the indices in a single column. So if "max_len" = 3, then we
        ## would create a matrix with 3 rows and 1 column like this...
        ##
        ## torch.arange(start=0, end=3, step=1, dtype=torch.float).unsqueeze(1)
        ##
        ## ...returns...
        ##
        ## tensor([[0.],
        ##         [1.],
        ##         [2.]])        
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)


        ## Here is where we start doing the math to determine the y-axis coordinates on the
        ## sine and cosine curves.
        ##
        ## The positional encoding equations used in "Attention is all you need" are...
        ##
        ## PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        ## PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        ##
        ## ...and we see, within the sin() and cos() functions, we divide "pos" by some number that depends
        ## on the index and number of PE values we want per token (d_model). So, pretty much everyone
        ## calculates the term we use to divide "pos" by first, and they do it with code that looks like this...
        ##
        # div_term = torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(math.log(10000.0) / d_model))
        ##
        ## NOTE: The fact that div_term = 1/(10000^(2i/d_model)) is not immediately clear for a few reasons: 
        ##
        ##    1) div_term wraps everything in a call to torch.exp() 
        ##    2) It uses log()
        ##    3) The order of the terms is different 
        ##
        ## The reason for these differences is, presumably, trying to prevent underflow (getting too close to 0).
        ## So, to show that div_term = 1/(10000^(2i/d_model))...
        ##
        ## 1) Swap out math.log() for torch.log() (doing this requires converting 10000.0 to a tensor, which is my
        ##    guess for why they used math.log() instead of torch.log())...
        ## torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        ##
        ## 2) Rearrange the terms...
        ## torch.exp(-1 * (torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model))
        ##
        ## 3) Pull out the -1 using exp(-1 * x) = 1/exp(x)...
        ## 1/torch.exp(torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model)
        ##
        ## 4) Use exp(a * b) = exp(a)^b to pull out the 2i/d_model term...
        ## 1/torch.exp(torch.log(torch.tensor(10000.0)))^(torch.arange(start=0, end=d_model, step=2).float() / d_model)
        ##
        ## 5) Use exp(log(x)) = x to get the original form of the denominator...
        ## 1/torch.tensor(10000.0)^(torch.arange(start=0, end=d_model, step=2).float() / d_model)
        ##
        ## 6) Bam.
        ## 
        ## So, that being said, I don't think underflow is actually that big an issue. In fact, some coder at Hugging Face
        ## also doesn't think so, and their code for positional encoding in DistilBERT (a streamlined version of BERT, which
        ## is a transformer model)
        ## calculates the values directly, using the form of the equation found in the original Attention is all you need
        ## manuscript. See...
        ## https://github.com/huggingface/transformers/blob/455c6390938a5c737fa63e78396cedae41e4e87e/src/transformers/modeling_distilbert.py#L53
        ## So I think we can simplify the code, but also have some notes that show how it is equivalent to what
        ## you'll see in the wild...
        div_term = 1/torch.tensor(10000.0)**(torch.arange(start=0, end=d_model, step=2).float() / d_model)
        
        ## Now we calculate the actual positional encoding values. Remember 'pe' was initialized as a matrix of 0s
        ## with max_len (max number of input tokens) rows and d_model (number of embedding values per token) columns.
        pe[:, 0::2] = torch.sin(position * div_term) ## every other column, starting with the 1st, has sin() values
        pe[:, 1::2] = torch.cos(position * div_term) ## every other column, starting with the 2nd, has cos() values
        ## NOTE: If the notation for indexing 'pe[]' looks cryptic to you, read on...
        ##
        ## First, let's look at the general indexing notation:
        ##
        ## i:j:k = select elements between i and j with stepsize = k.
        ##
        ## i defaults to 0
        ## j defaults to the number of elements in the row, column or whatever.
        ## k defaults to 1
        ##
        ## Now that we have looked at the general notation, let's look at specific
        ## examples so that we can understand it.
        ##
        ## We'll start with: pe[:, 0::2]
        ##
        ## The stuff that comes before the comma refers to the rows we want to select.
        ## In this case, we have ':' before the comma, and that means "select all rows".
        ## This is because we are using the default values for i, j and k.
        ##
        ## The stuff after the comma refers to the columns we want to select.
        ## In this case, we have '0::2', and that means we start with
        ## the first column (column =  0) and go to the end (using the default value for j)
        ## and we set the stepsize to 2, which means we skip every other column.
        ##
        ## Now to understand pe[:, 1::2]
        ##
        ## Again, the stuff before the comma refers to the rows, and, just like before
        ## we use default values for i,j and k, so we select all rows.
        ##
        ## The stuff that comes after the comma refers to the columns.
        ## In this case, we start with the 2nd column (column = 1), and go to the end
        ## (using the default value for 'j') and we set the stepsize to 2, which
        ## means we skip every other column.
        ##
        ## NOTE: using this ':' based notation is called "indexing" and also called "slicing"
        
        ## Now we "register" 'pe'.
        self.register_buffer('pe', pe) ## using "register_buffer()" prevents "pe" from getting
                                       ## passed to the optimizer. Thus, ultimately, these tensors will
                                       ## not get optimized, which is what we want.
                                       ##
                                       ## NOTE: If, instead, we had set "requires_grad=False", then 
                                       ## "pe" would still get passed to the optimizer (when we pass it
                                       ## model.parameters()) and the optimizer would have to skip over them
                                       ## so doing this with "register_buffer()" is a little cleaner.

    ## The forward() method is what is called by default when we use a PositionEncoding() object.
    ## In other words, after we create a PositionEncoding() object, pe = PositionEncoding(), we
    ## can add position encoding values to the word embeddings with pe(word_embeddings).
    def forward(self, x):
        # x = word embedding values
        # x.size(0) returns the number of tokens (rows) in x, so we only add the
        # position encoding values for the positions that are actually used
        return x + self.pe[:x.size(0), :] ## NOTE: That second ':' is optional and we could re-write it these ways
                                          ## self.pe[:x.size(0)] = self.pe[:x.size(0), :]
                                          ## The first is the least amount of typing, the last is the most explicit
                                          ## which is what I prefer.
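The comments about `register_buffer()` in the class above can be checked with a tiny module. `BufferDemo` is hypothetical and made up for this example: it stores one fixed table as a buffer (like `pe`) and another as a frozen `nn.Parameter`, then lists what `model.parameters()` would hand to an optimizer and what would be saved with the model.

```python
import torch
import torch.nn as nn

class BufferDemo(nn.Module):
    ## A hypothetical module that stores the same kind of fixed table two different ways.
    def __init__(self):
        super().__init__()
        ## 1) As a buffer, like 'pe' in PositionEncoding()
        self.register_buffer('fixed_table', torch.zeros(3, 2))
        ## 2) As a Parameter that has been frozen with requires_grad=False
        self.frozen_table = nn.Parameter(torch.zeros(3, 2), requires_grad=False)

demo = BufferDemo()

## named_parameters() lists what model.parameters() would pass to the optimizer.
param_names = [name for name, _ in demo.named_parameters()]
buffer_names = [name for name, _ in demo.named_buffers()]
## state_dict() lists what gets saved (and loaded) with the model.
saved_names = list(demo.state_dict().keys())

print(param_names)   ## only the frozen Parameter
print(buffer_names)  ## only the buffer
print(saved_names)   ## both of them
```

So the frozen Parameter still gets handed to the optimizer (which then has to skip it), while the buffer never does, yet both are saved with the model. Buffers also move along with the model when you call `.to(device)`, which keeps `self.pe` on the same device as the word embeddings.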

----

<a id="attention"></a>
# Attention
We're going to code an `Attention` class to do all of the types of attention that a transformer might need: **Self-Attention**, **Masked Self-Attention** (which is used by the Decoder during training) and **Encoder-Decoder Attention**.

**Self-Attention** is a type of attention used in Encoder-Decoder transformers and Encoder-Only transformers. It allows every word in a phrase to define a relationship with any other word in the phrase, regardless of the order of the words. In other words, if the phrase is **The pizza came out of the oven and it tasted good!**, then the word **it** can define its relationship with every word in that phrase, including words that came after it, like **tasted** and **good**, as illustrated by the blue arrows in the figure below.

<img src="./images/self_attention_1.png" alt="An illustration of self-attention" style="width: 800px;">

**Masked Self-Attention** is used by Encoder-Decoder and Decoder-Only transformers, and it allows each word in a phrase to define a relationship with itself and the words that came before it. In other words, **Masked Self-Attention** prevents the transformer from "looking ahead". This is illustrated below, where the word **it** can define relationships with itself and everything that came before it. In Encoder-Decoder transformers, **Masked Self-Attention** is used by the decoder during training, when we know what the output should be, but we still force the decoder to generate it one token at a time, thus limiting attention to only the output words that came earlier. In contrast, Decoder-Only transformers use **Masked Self-Attention** all the time, on the input and the output, during training and during inference. Thus, even though a Decoder-Only transformer doesn't have to generate the input, and could see all of it during training and during inference, it still only allows the attention values for each word to depend on the words that came before it.

<img src="./images/masked_attention_1.png" alt="An illustration of Masked Self-Attention" style="width: 800px;">

**Encoder-Decoder Attention** is only used in Encoder-Decoder transformers, where there is a distinct separation between the part of the transformer that processes the input (the encoder) and the part that generates the output (the decoder). **Encoder-Decoder Attention** lets each word in the output (in the decoder) define relationships with all the words in the input (in the encoder), as illustrated in the figure below.

<img src="./images/enc_dec_attention_1.png" alt="An illustration of Encoder-Decoder Attention" style="width: 800px;">

Now that we have a general sense of the three types of attention used in transformers, we can talk about how they're calculated. 

First, the general equations for the different types of attention are almost identical, as seen in the figure below. In the equations, **Q** is for the **Query** matrix, **K** is for the **Key** matrix and **V** is for the **Value** matrix. On the left, we have the equation for Self-Attention and Encoder-Decoder Attention. As we see, the differences between these types of attention come not from the equation we use, but from how the **Q**, **K**, and **V** matrices are computed. On the right, we see the equation for Masked Self-Attention, and the only difference from the equation on the left is the addition of a **Mask** matrix, **M**, that prevents words that come after a specific **Query** from being included in the final attention scores. 
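In the notation used in "Attention Is All You Need", the two equations can be written as follows, where $d_k$ is the number of columns in $K$ (which equals $d_{model}$ in this tutorial) and $M$ holds $0$ wherever a **Query** is allowed to look at a **Key** and $-\infty$ wherever the **Key** comes after the **Query**. The layout of the symbols in the figure may differ slightly.

$$
\text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

$$
\text{Masked Self-Attention}(Q, K, V, M) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V
$$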

<img src="./images/attention_equations.png" alt="Equations for computing attention" style="width: 800px;">

**NOTE:** Since both equations are very similar, we'll go through one example and point out the key differences when we get to them.

First, given the word embedding values for each word/token in the input phrase **\<SOS> let's go** in matrix form, we multiply them by matrices of weights to create the **Queries**, **Keys** and **Values**.

<img src="./images/attention_compute_q.png" alt="Computing the Q matrix" style="width: 800px;">

<img src="./images/attention_compute_k.png" alt="Computing the K matrix" style="width: 800px;">

<img src="./images/attention_compute_k.png" alt="Computing the V matrix" style="width: 800px;">
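In the `Attention()` class, these weight matrices are stored in `nn.Linear()` layers with `bias=False`. One detail that can be confusing when comparing the code to the illustrations is that `nn.Linear()` stores its weights transposed, with shape `(out_features, in_features)`. A minimal check, using random stand-in values for the position-encoded embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 2

## Hypothetical position-encoded embeddings for "<SOS> let's go": one row per token, d_model columns.
encodings = torch.randn(3, d_model)

W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

q_from_layer = W_q(encodings)
## nn.Linear stores its weights with shape (out_features, in_features),
## so the layer multiplies the encodings by the TRANSPOSE of W_q.weight.
q_by_hand = encodings @ W_q.weight.T

print(q_from_layer.shape)  ## torch.Size([3, 2])
```

The same holds for the **Key** and **Value** weights, so each row of `q`, `k` and `v` still corresponds to one token, just like the rows in the figures.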

We then multiply the **Queries** by the transpose of the **Keys** so that the query for each word calculates a similarity value with the keys for all of the words. **NOTE:** As seen in the illustration below, Masked Self-Attention calculates the values for all **Query/Key** pairs, but, ultimately, ignores the values for which a token's **Query** comes before other tokens' **Keys**. For example, if the **Query** is for the first token **\<SOS>**, then Masked Self-Attention will ignore the values calculated with the **Keys** for **Let's** and **go**, because those tokens come after **\<SOS>**.

<img src="./images/attention_q_times_kt.png" alt="Calculating the similarities between Queries and Keys" style="width: 800px;">

The next step is to scale the similarity scores by dividing them by the square root of the number of columns in the **Key** matrix, which is the number of values used to represent each token. In this case, we divide by the square root of 2.

<img src="./images/attention_scaling_scores.png" alt="Scaling the similarities" style="width: 800px;">

Now, if we were doing Masked Self-Attention, we would mask out the values we want to ignore by adding -infinity to them, as seen below. This step is the only difference between Self-Attention and Masked Self-Attention. 

<img src="./images/attention_masking.png" alt="Masking out scaled similarities for Masked Self-Attention" style="width: 800px;">
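One common way to build this mask in PyTorch is with `torch.tril()`, which keeps the lower triangle of a matrix. Here is a minimal sketch with made-up scaled similarities for **\<SOS> let's go**. It uses `True` to mark the entries to ignore, which is the same convention that `masked_fill()` uses in the `Attention()` class. That class fills with `-1e9` instead of a true `-infinity`; after **SoftMax()**, both give (essentially) 0.

```python
import torch

num_tokens = 3  ## "<SOS> let's go"

## torch.tril() keeps the lower triangle, where the Key's position <= the Query's position.
## Comparing with 0 flips it, so True marks the entries to ignore:
## the Keys that come after each row's Query.
mask = torch.tril(torch.ones(num_tokens, num_tokens)) == 0

## Hypothetical scaled similarities (rows = Queries, columns = Keys)
scaled_sims = torch.tensor([[ 0.5, 1.2, -0.3],
                            [ 0.8, 0.1,  0.9],
                            [-0.2, 0.4,  0.7]])

masked_sims = scaled_sims.masked_fill(mask, float('-inf'))

## SoftMax() turns every -infinity into 0%, and each row still adds up to 100%.
masked_percents = torch.softmax(masked_sims, dim=1)
print(mask)
print(masked_percents)
```

Because the first row only has one unmasked entry, **\<SOS>** ends up paying 100% of its attention to itself.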

The next step is to apply the **SoftMax()** function to each row in the scaled similarities. We'll do this first for Self-Attention, without a mask (below)...

<img src="./images/attention_softmax.png" alt="Applying the SoftMax() function to each row in the scaled similarity matrix" style="width: 800px;">

...and we'll also do it for Masked Self-Attention (below).

<img src="./images/attention_softmax_masked.png" alt="Applying the SoftMax() function to each row in the masked scaled similarity matrix" style="width: 800px;">

The SoftMax() function gives us percentages that tell us how much the **Values** for each token should contribute to the attention score for a specific token. Thus, we can get the final attention scores by multiplying the percentages by the **Values** in matrix **V**. First, we'll do this with the un-masked percentages...

<img src="./images/attention_final_scores.png" alt="Calculating the final attention scores" style="width: 800px;">

...and then we'll calculate the final Masked Self-Attention scores.

<img src="./images/attention_final_scores_masked.png" alt="Calculating the final attention scores" style="width: 800px;">
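To tie the steps together, here is a minimal sketch of the whole calculation for one unbatched phrase, using plain tensor math and random stand-in values. The helper name `attention_scores()` is made up for this example, and, unlike `nn.Linear()`, it multiplies the encodings by the weight matrices directly, as in the illustrations. It also checks one consequence of masking: the first token can only pay attention to itself, so its Masked Self-Attention scores are exactly its own **Values**.

```python
import torch
import torch.nn.functional as F

def attention_scores(encodings, W_q, W_k, W_v, mask=None):
    ## A sketch of (Masked) Self-Attention for one unbatched phrase.
    ## encodings: (num_tokens, d_model); W_q, W_k, W_v: (d_model, d_model)
    q = encodings @ W_q                          ## Queries
    k = encodings @ W_k                          ## Keys
    v = encodings @ W_v                          ## Values
    sims = q @ k.T                               ## similarity of every Query with every Key
    scaled_sims = sims / k.size(1) ** 0.5        ## divide by sqrt(number of columns in K)
    if mask is not None:
        scaled_sims = scaled_sims.masked_fill(mask, float('-inf'))
    percents = F.softmax(scaled_sims, dim=1)     ## each row adds up to 1
    return percents @ v, percents, v

torch.manual_seed(42)
d_model, num_tokens = 2, 3
encodings = torch.randn(num_tokens, d_model)                       ## stand-in for "<SOS> let's go"
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))  ## stand-in weights

self_attention, percents, v = attention_scores(encodings, W_q, W_k, W_v)

causal_mask = torch.tril(torch.ones(num_tokens, num_tokens)) == 0  ## True = Key comes after Query
masked_attention, masked_percents, _ = attention_scores(encodings, W_q, W_k, W_v, mask=causal_mask)

print(self_attention)
print(masked_attention)
```

Both versions share every step except the mask, which matches the point above that the mask is the only difference between Self-Attention and Masked Self-Attention.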

# BAM!

Now that we know how to calculate the different types of attention, let's code the `Attention()` class.

In [None]:
class Attention(nn.Module): ## NOTE: We only need to inherit from L.LightningModule 
                            ##       in the class that puts all the pieces together.
    def __init__(self, d_model=2):
        ## d_model = the number of embedding values per token.
        ##           In the transformer I used in the StatQuest: Transformer Neural Networks Clearly Explained!!!
        ##           d_model=2, so that's what we'll use as a default for now.
        ##           However, in "Attention Is All You Need" d_model=512

        
        super().__init__()
        
        ## Initialize the Weights (W) that we'll use to create the
        ## query (q), key (k) and value (v) for each token
        ## NOTE: Most implementations that I looked at include the bias terms
        ##       but I didn't use them in my video (since they are not in the 
        ##       original Attention is All You Need paper).
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        
        ## We'll keep track of which dimension specifies the rows in a matrix 
        ## and which specifies the columns. We're doing this because usually
        ## people train with batches of data, and the first dimension, dimension 0
        ## is used for the batch, dimension 1 is for rows and dimension 2 is for columns.
        ## In this example, we are not using batches. However, by using
        ## variables, we can easily change things if we wanted to use batches.
        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        ## Create the Query, Key and Values using the encodings
        ## associated with each token
        ## For normal Self-Attention and Masked Self-Attention...
        ##
        ## encodings_for_q == encodings_for_k == encodings_for_v
        ##
        ## ...however, for Encoder-Decoder Attention, encodings_for_q comes from the Decoder
        ## and encodings_for_k and encodings_for_v are different and come from the Encoder.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        
        ## Compute attention scores
        ## the equation is (q * k^T)/sqrt(d_model)
        ## NOTE: It seems most people use "reverse indexing" for the dimensions when transposing k
        ##       k.transpose(dim0, dim1) will transpose k by swapping dim0 and dim1
        ##       In standard matrix notation, we would want to swap rows (dim=0) with columns (dim=1)
        ##       If we have 3 dimensions, because of batching, and the batch was the first dimension
        ##       And thus dims are defined batch = 0, rows = 1, columns = 2
        ##       then dim0=-2 = 3 - 2 = 1. dim1=-1 = 3 - 1 = 2.
        ##       Alternatively, we could put the batches in the last dimension (dim 2), and thus, dim 0 would still be rows
        ##       and dim 1 would still be columns. I'm not sure why batches are put in dim 0...
        ##
        ##       Likewise, q.size(-1) uses negative indexing to refer to the number of columns in the query,
        ##       which tells us d_model. Alternatively, we could use q.size(2) if we have batches in the first
        ##       dimension or q.size(1) if we have batches in the 3rd dimension.
        ##
        ##       Since there are a bunch of ways to index things, I think the best thing to do is use
        ##       variables "row_dim" and "col_dim" instead of numbers...
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(q.size(self.col_dim)**0.5)
        
        if mask is not None:
            ## Here we are masking out things we don't want to pay attention to
            ## (like the <PAD> (which is used when we have a batch of inputs sequences
            ## and they are not all the exact same length... Because the batch is passed
            ## in as a matrix, each input sequence has to have the same length, so we
            ## add <PAD> to the shorter sequences so that they are all as long as the
            ## longest sequence.))
            ##
            ## We replace <PAD> and other things we wanted masked out
            ## with a very large negative number (to approximate -infinity) so that the SoftMax() function
            ## will give all masked elements an output value (or "probability") of 0.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9) # I've also seen -1e20 and -9e15 used in masking
        
        ## Apply softmax to determine what percent of each token's value to
        ## use in the final attention values.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        
        ## Scale the values by their associated percentages and add them up.
        attention_scores = torch.matmul(attention_percents, v)
        
        return attention_scores
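For Encoder-Decoder Attention, `encodings_for_q` comes from the decoder while `encodings_for_k` and `encodings_for_v` come from the encoder, so the similarity matrix no longer has to be square. A minimal sketch with random stand-in encodings (one decoder token, three encoder tokens) shows the shapes that `forward()` works with in that case; the variable names are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 2

encoder_output = torch.randn(3, d_model)  ## stand-in encodings for the 3 input tokens "<SOS> lets go"
decoder_input = torch.randn(1, d_model)   ## stand-in encoding for the decoder's first token, "<SOS>"

W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

q = W_q(decoder_input)    ## Queries come from the decoder: shape (1, d_model)
k = W_k(encoder_output)   ## Keys come from the encoder:    shape (3, d_model)
v = W_v(encoder_output)   ## Values come from the encoder:  shape (3, d_model)

scaled_sims = q @ k.transpose(0, 1) / d_model ** 0.5  ## shape (1, 3): one row per decoder token
percents = F.softmax(scaled_sims, dim=1)              ## how much each input token contributes
enc_dec_attention = percents @ v                      ## shape (1, d_model)

print(scaled_sims.shape, enc_dec_attention.shape)
```

The output has one row per decoder token and `d_model` columns, so it can be combined with the decoder's other values no matter how long the input phrase is.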

----

<a id="encoder"></a>
# The Encoder Class

Now that we have coded up the `PositionEncoding()` and `Attention()` classes, we're ready to put all the pieces together to build an `Encoder()` class, as seen in the figure below. 

<img src="./images/encoder_diagram.png" alt="A diagram of an encoder" style="width: 800px;">


A basic Encoder simply brings together...

- Word Embedding
- Position Encoding
- Self-Attention
- Residual Connections
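Here is a minimal functional sketch of how these four pieces fit together for the input **\<SOS> lets go** (token ids 0, 1 and 3 in `input_vocab`). The weights are random stand-ins, the position encoding table is written out for `d_model=2` only (where the equations reduce to sin(pos) and cos(pos)), and the residual connection is assumed to be the usual "add the sublayer's input to its output".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
num_tokens, d_model, max_len = 4, 2, 3

## 1) Word Embedding
word_embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

## 2) Position Encoding: for d_model=2, column 0 = sin(pos) and column 1 = cos(pos).
positions = torch.arange(max_len).float().unsqueeze(1)
pe_table = torch.cat([torch.sin(positions), torch.cos(positions)], dim=1)  ## shape (max_len, 2)

## 3) Self-Attention weights (stand-ins, no bias terms)
W_q, W_k, W_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

token_ids = torch.tensor([0, 1, 3])  ## "<SOS> lets go"
word_embeddings = word_embedding(token_ids)                            ## shape (3, d_model)
position_encoded = word_embeddings + pe_table[:token_ids.size(0)]

q, k, v = W_q(position_encoded), W_k(position_encoded), W_v(position_encoded)
percents = F.softmax(q @ k.T / d_model ** 0.5, dim=1)
self_attention_values = percents @ v

## 4) Residual Connection: add the attention sublayer's input back onto its output.
##    This only works because both have the same shape, (number of tokens, d_model).
encoder_output = position_encoded + self_attention_values
print(encoder_output)
```

The residual connection is one reason every piece keeps `d_model` values per token: the attention output has to line up with the position-encoded embeddings so the two can be added.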

In [None]:
class Encoder(nn.Module):
    
    def __init__(self, num_tokens=4, d_model=2, max_len=3):
        
        super().__init__()
        
        ## We set the seed so we start with the same random numbers each time.
        ## This means, in theory, you should get the exact same results as me.
        L.seed_everything(seed=42)

        ## NOTE: In this simple example, we are just using a "single layer" encoder.
        ##       If we wanted to have multiple layers of encoders, then we would
        ##       take the output of one encoder module and use it as input to
        ##       the next module.
        self.we = nn.Embedding(num_embeddings=num_tokens,
                               embedding_dim=d_model)     
        
        self.pe = PositionEncoding(d_model=d_model, 
                                   max_len=max_len)

        self.self_attention = Attention(d_model=d_model)
        ## NOTE: In this simple example, we are just doing plain old vanilla attention.
        ## If we wanted to do multi-head attention, we could
        ## initialize more Attention objects like this...
        ##
        ## self.self_attention_2 = Attention(d_model=d_model)
        ## self.self_attention_3 = Attention(d_model=d_model)
        ##
        ## If d_model=2, then using 3 self_attention objects (num_attention_heads=3) would
        ## result in d_model*3 = 6 self-attention values per token,
        ## so we would need to initialize
        ## a matrix of weights to reduce the dimension of the
        ## self-attention values back down to d_model...
        ##
        ## self.reduce_attention_dim = nn.Linear(in_features=(num_attention_heads*d_model), out_features=d_model)

        
    def forward(self, token_ids):

        ## Get word embeddings
        word_embeddings = self.we(token_ids)
        
        ## Add positional encoding to the word embeddings
        position_encoded = self.pe(word_embeddings)
        
        ## Calculate the self attention values
        self_attention_values = self.self_attention(position_encoded, 
                                                    position_encoded, 
                                                    position_encoded)
        ## NOTE: If we were doing multi-head attention, we would
        ## calculate the self-attention values with the other attention objects
        ## like this...
        ##
        ## self_attention_values_2 = self.self_attention_2(...)
        ## self_attention_values_3 = self.self_attention_3(...)
        ##
        ## ...then we would concatenate all the self-attention values for each token...
        ##
        ## all_self_attention_values = torch.cat((self_attention_values, self_attention_values_2, self_attention_values_3), dim=1)
        ##
        ## ...and then run them through reduce_attention_dim to get back to d_model values per token
        ##
        ## final_self_attention_values = self.reduce_attention_dim(all_self_attention_values)
        
        ## Add the position encoded values to the self-attention values
        ## (a residual connection) to get the output values.
        output_values = position_encoded + self_attention_values

        return output_values
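The comments in `Encoder()` describe how multi-head attention could be added. Here is a minimal, self-contained sketch of that idea. `TinyAttention`, `heads`, and `num_attention_heads` are hypothetical names, and `TinyAttention` is only a simplified stand-in for the `Attention()` class (no masking). The sketch runs 3 heads, concatenates their outputs with `torch.cat(..., dim=1)` into 6 values per token, and then uses `nn.Linear` to reduce them back to `d_model=2` before the residual connection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyAttention(nn.Module):
    """Hypothetical single-head attention, a simplified stand-in for Attention()."""
    def __init__(self, d_model=2):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x_q, x_k, x_v):
        q, k, v = self.W_q(x_q), self.W_k(x_k), self.W_v(x_v)
        # Scaled dot-product similarities: one row per query token, one column per key token.
        scaled_sims = q @ k.transpose(0, 1) / k.size(1) ** 0.5
        return F.softmax(scaled_sims, dim=1) @ v

d_model, num_attention_heads = 2, 3
heads = nn.ModuleList([TinyAttention(d_model) for _ in range(num_attention_heads)])
reduce_attention_dim = nn.Linear(num_attention_heads * d_model, d_model)

position_encoded = torch.randn(3, d_model)  # hypothetical position encoded values for 3 tokens

# Concatenate the heads' outputs column-wise: 3 heads * 2 values = 6 values per token.
all_self_attention_values = torch.cat(
    [head(position_encoded, position_encoded, position_encoded) for head in heads], dim=1)

# Reduce back to d_model values per token, then add the residual connection.
final_self_attention_values = reduce_attention_dim(all_self_attention_values)
output_values = position_encoded + final_self_attention_values
```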

# BAM!

Now let's code the `Decoder()` class.

----

<a id="decoder"></a>
# The Decoder Class

As we see in the figure below, the decoder is almost the same as the encoder, except that it also includes Encoder-Decoder Attention, where the **Keys** and **Values** are created from the outputs of the Encoder; a fully connected layer with 5 outputs, one per token in the output vocabulary; and a SoftMax() function to select the output token.

<img src="./images/decoder_diagram.png" alt="A diagram of a decoder" style="width: 800px;">
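Encoder-Decoder Attention uses values from the Decoder to make the **Queries** and the Encoder's outputs to make the **Keys** and **Values**, so the two inputs can have different numbers of tokens. This sketch leaves out the learned query, key, and value weights for brevity and uses random, hypothetical values just to show the shapes: 2 output tokens attending to 3 input tokens gives a 2-by-3 matrix of percentages and one row of `d_model` attention values per output token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 2

decoder_values = torch.randn(2, d_model)  # Queries come from 2 output tokens: <SOS> vamos
encoder_values = torch.randn(3, d_model)  # Keys and Values come from 3 input tokens: <SOS> lets go

# One row per output token, one column per input token.
scaled_sims = decoder_values @ encoder_values.T / d_model ** 0.5
attention_percents = F.softmax(scaled_sims, dim=1)

# Weighted sum of the Encoder's values: back to d_model values per output token.
enc_dec_attention_values = attention_percents @ encoder_values
```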

A basic Decoder simply brings together...

- Word Embedding
- Position Encoding
- Masked Self-Attention
- Residual Connections
- Encoder-Decoder Attention
- A fully connected layer
- SoftMax. However, because the loss function we are using, `nn.CrossEntropyLoss()`, applies SoftMax for us, we will not include it here.
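To confirm that `nn.CrossEntropyLoss()` takes care of SoftMax for us, this sketch compares it with applying log-SoftMax explicitly followed by the negative log-likelihood loss, and with averaging -log(probability of the correct token) by hand. The fully connected layer outputs and target ids are made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical fully connected layer outputs for 2 positions and 5 output tokens.
fc_layer_output = torch.tensor([[0.1, 0.2, 2.0, -0.3, 0.0],
                                [0.5, -1.0, 0.3, 0.2, 1.5]])
expected_output = torch.tensor([2, 4])  # vamos <EOS>

# CrossEntropyLoss works directly on the raw (pre-SoftMax) values.
loss = nn.CrossEntropyLoss()(fc_layer_output, expected_output)

# The same loss with log-SoftMax applied explicitly, followed by the negative log-likelihood loss.
manual_loss = F.nll_loss(F.log_softmax(fc_layer_output, dim=1), expected_output)

# And by hand: the average of -log(SoftMax probability of the correct token).
probabilities = F.softmax(fc_layer_output, dim=1)
by_hand = -torch.log(probabilities[torch.arange(2), expected_output]).mean()
```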

In [None]:
class Decoder(nn.Module):
    def __init__(self, num_tokens=4, d_model=2, max_len=3):
        
        super().__init__()
        
        ## Just like in the encoder, we are setting the seed
        ## so that you can get the same results as me.
        ## NOTE: This time we are using seed=43, which
        ## is different from what we did in the encoder.
        ## We're using a different seed number so that we
        ## will start out with different embedding values
        L.seed_everything(seed=43)
        
        ## NOTE: Just like for the encoder, we are just using a "single layer" decoder.
        self.we = nn.Embedding(num_embeddings=num_tokens, 
                               embedding_dim=d_model)     
        
        self.pe = PositionEncoding(d_model=d_model, 
                                   max_len=max_len)

        ## NOTE: Just like for the encoder, we are only using a single head for
        ##       self-attention and encoder-decoder attention
        self.self_attention = Attention(d_model=d_model)

        self.enc_dec_attention = Attention(d_model=d_model)

        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
        
        self.row_dim = 0
        self.col_dim = 1
        
        
    def forward(self, token_ids, encoder_values):
        
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        ## For the decoder, we need to use "masked self-attention" so that
        ## when we are training, the decoder can't cheat and look ahead at
        ## what words come after the current word it is working on.
        ## To create the mask, we create a matrix where the lower triangle
        ## (including the diagonal) is filled with 1s, and everything above the diagonal is filled with 0s.
        mask = torch.tril(torch.ones((token_ids.size(dim=self.row_dim), token_ids.size(dim=self.row_dim))))
        ## We then replace the 0s above the diagonal, which represent the words
        ## we want to be masked out, with "True", and replace the 1s in the lower
        ## triangle, which represent the words we want to include when we calculate
        ## self-attention for a specific word in the output, with "False".
        mask = mask == 0
        
        self_attention_values = self.self_attention(position_encoded, 
                                                    position_encoded, 
                                                    position_encoded, 
                                                    mask=mask)  
        
        residual_connection_values = position_encoded + self_attention_values
        
        enc_dec_attention_values = self.enc_dec_attention(residual_connection_values,
                                                          encoder_values,
                                                          encoder_values)
        
        residual_connection_values = enc_dec_attention_values + residual_connection_values
        
        
        fc_layer_output = self.fc_layer(residual_connection_values)

        ## NOTE: We are not passing the fc_layer_output to a SoftMax like in the illustration because
        ## the loss function we're using, nn.CrossEntropyLoss(), will apply it for us.
        
        return fc_layer_output
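Here is the mask from `Decoder.forward()` on its own, for 3 tokens. After `mask == 0`, `True` marks the positions above the diagonal, so the first token can only see itself, the second token can see the first two tokens, and so on. As a cross-check, the sketch builds the same mask a second way with `torch.triu(..., diagonal=1)`.

```python
import torch

num_tokens = 3

# 1s on and below the diagonal (tokens we may look at), 0s above it (future tokens).
mask = torch.tril(torch.ones((num_tokens, num_tokens)))

# Convert to booleans: True = masked out (a future token), False = allowed.
mask = mask == 0
print(mask)

# Equivalent construction: 1s strictly above the diagonal, converted to booleans.
same_mask = torch.triu(torch.ones((num_tokens, num_tokens)), diagonal=1).bool()
```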

# BAM!

Now that we have coded up the `Encoder()` and `Decoder()` classes, all that's left is to code up a `Transformer()` that connects the two.

----

<a id="transformer"></a>
# The Transformer Class

The `Transformer()` class simply connects the outputs from the Encoder to the Decoder, as seen in the figure below.

<img src="./images/enc_dec_transformer.png" alt="an encoder-decoder neural network" style="width: 800px;">


In [None]:
class Transformer(L.LightningModule):

    def __init__(self, input_size, output_size, d_model=2, max_len=3):
        
        super().__init__()
        
        self.encoder = Encoder(num_tokens=input_size, d_model=d_model, max_len=max_len)
        self.decoder = Decoder(num_tokens=output_size, d_model=d_model, max_len=max_len)
        
        self.loss = nn.CrossEntropyLoss()
        
        
    def forward(self, inputs, labels): 
        
        encoder_values = self.encoder(inputs)
        output_presoftmax = self.decoder(labels, encoder_values)
        
        return(output_presoftmax)


    def configure_optimizers(self): 
        
        return Adam(self.parameters(), lr=0.1)
    
    
    def training_step(self, batch, batch_idx): 
        
        input_i, label_i = batch # unpack the input and label token ids from the batch
        
        ## First, let's prepend the <SOS> token to the tokens used as input to the Encoder...
        input_tokens = torch.cat((torch.tensor([0]), input_i[0]))
        
        ## ...and to the tokens used as input to the Decoder (teacher forcing).
        teacher_forcing = torch.cat((torch.tensor([0]), label_i[0]))
        
        ## Now let's add the <EOS> token to the end of the known output
        expected_output = torch.cat((label_i[0], torch.tensor([4])))
                
        output_i = self.forward(input_tokens, teacher_forcing)
        loss = self.loss(output_i, expected_output)
                    
        return loss
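To make the token bookkeeping in `training_step()` concrete, this sketch applies the same `torch.cat()` calls to one hypothetical training example, **lets go** translated to **vamos**. It assumes the dataset stores phrases without **\<SOS>** or **\<EOS>** and that each batch holds a single example (which is why `training_step()` indexes `[0]`). Notice that the Decoder's input and the expected output end up the same length, with the expected output shifted one token to the left.

```python
import torch

SOS, EOS = 0, 4  # token ids from the vocabularies used in this tutorial

# Hypothetical batch with a single example: "lets go" -> "vamos".
input_i = torch.tensor([[1, 3]])
label_i = torch.tensor([[2]])

input_tokens = torch.cat((torch.tensor([SOS]), input_i[0]))     # <SOS> lets go
teacher_forcing = torch.cat((torch.tensor([SOS]), label_i[0]))  # <SOS> vamos
expected_output = torch.cat((label_i[0], torch.tensor([EOS])))  # vamos <EOS>

print(input_tokens, teacher_forcing, expected_output)
```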

# BAM!

Now that we've built the `Transformer()` class, let's see if it works correctly without training. To use the transformer, we encode an input phrase, either **\<SOS> let's go** or **\<SOS> to go**, with the Encoder, and then pass the outputs to a Decoder. The Decoder itself is in a loop that will continue to create output until it creates the **\<EOS>** (End of Sequence) token.

In [None]:
## First, a reminder of our input and output vocabularies...
# input_vocab = {'<SOS>': 0, # Start
#                'lets': 1,
#                'to': 2,
#                'go': 3}

# output_vocab = {'<SOS>': 0, # Start
#                 'ir': 1,
#                 'vamos': 2,
#                 'y': 3,
#                 '<EOS>': 4} # End
max_length = 3

## Create a transformer object...
transformer = Transformer(len(input_vocab), len(output_vocab), d_model=2, max_len=max_length)

## Encode the user input...
encoder_values = transformer.encoder(torch.tensor([0, 1, 3])) # <SOS> let's go # Expecting: 0, 2, 4 = <SOS> vamos <EOS>
# encoder_values = transformer.encoder(torch.tensor([0, 2, 3])) # <SOS> to go  # Expecting: 0, 1, 4 = <SOS> ir <EOS>
    
## Since we initialize the decoder with the <SOS> token, we
## can consider that <SOS> to be the first predicted token
predicted_ids = torch.tensor([0]) # set the first predicted token to <SOS> to initialize the decoder
for i in range(max_length):
    ## given the current predicted tokens and the encoded input, 
    ## predict the next token with the decoder
    ## NOTE: "prediction" is the output from the fully connected layer,
    ##      not from a softmax() function. We could, if we wanted to,
    ##      run "prediction" through a softmax() function, but
    ##      since we're going to select the item with the largest value,
    ##      we can just use argmax() instead...
    prediction = transformer.decoder(predicted_ids, encoder_values)

    ## Now use argmax() to select the id of the predicted token
    ## NOTE: The first time we call decoder(), with just the <SOS> token 
    ##       to initialize things, prediction
    ##       will be a matrix with a single row of values, 
    ##       the output from the fully connected layer, like this...
    ##
    ##       tensor([[ 0.0417,  0.0945,  0.2714, -0.0105,  0.0902]])
    ##       
    ##       We then take the index for the element with the largest
    ##       value (the 3rd element in this case, which is "vamos") and then
    ##       loop around and call decoder() a second time, this time with two tokens: <SOS> vamos
    ##       This will return a prediction for <SOS> and a prediction for vamos like this...
    ##
    ##       [[ 0.0417,  0.0945,  0.2714, -0.0105,  0.0902],
    ##        [ 0.8693, -1.0257, -0.6939,  0.1791, -0.6546]]
    ##
    ##       NOTE: The first row is the same as the row that was returned earlier.
    ##       Since we already figured out that the first row predicts vamos, we
    ##       only need to get the prediction from the second row. So, to make sure
    ##       we always apply argmax() to the final row in the matrix, we index
    ##       the final row with -1.
    ##
    ##       ALSO NOTE: If you're wondering why we need to make a prediction for
    ##       <SOS> every time we call the decoder... We do this because
    ##       the decoder has self-attention, so the prediction made
    ##       from "vamos" requires keys and values from <SOS>.
    ##       Could we optimize this and just store things in a table? Probably.
    ##       But that's another project...
    predicted_id = torch.tensor([torch.argmax(prediction[-1,:])])
    ## add the predicted token id to the list of predicted ids.
    predicted_ids = torch.cat((predicted_ids, predicted_id))
        
    if (predicted_id == 4): ## if the prediction is <EOS> then we are done.
        break
        
print("\npredicted_ids:", predicted_ids)

And, without training, the transformer predicts **\<SOS> vamos \<SOS> ir**, but we wanted it to predict **\<SOS> vamos \<EOS>**. Since the transformer didn't correctly translate the English phrase into Spanish, we'll have to train it.
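Because SoftMax never changes which value in a row is the largest, running `argmax()` on the raw fully connected layer output picks the same token as running it on the SoftMax probabilities. This sketch checks that with the two rows of untrained output quoted in the comments of the decoding loop: both approaches pick index 2 (**vamos**) for the first row and index 0 (**\<SOS>**) for the second.

```python
import torch
import torch.nn.functional as F

# The two rows of untrained decoder output quoted in the decoding loop's comments.
prediction = torch.tensor([[ 0.0417,  0.0945,  0.2714, -0.0105,  0.0902],
                           [ 0.8693, -1.0257, -0.6939,  0.1791, -0.6546]])

# Optional: convert each row into probabilities that sum to 1.
probabilities = F.softmax(prediction, dim=1)

# argmax() gives the same token ids either way.
ids_from_raw_values = torch.argmax(prediction, dim=1)
ids_from_probabilities = torch.argmax(probabilities, dim=1)
print(ids_from_raw_values, ids_from_probabilities)
```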

----

<a id="train"></a>
# Train the Transformer!!!

To train a transformer, we simply create an object from the `Transformer()` class...

In [None]:
transformer = Transformer(len(input_vocab), len(output_vocab), d_model=2, max_len=3)

...and then create a Lightning `Trainer()` and train the transformer with the `dataloader` that we created earlier.

In [None]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(transformer, train_dataloaders=dataloader)
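Under the hood, `trainer.fit()` repeatedly calls `training_step()` and then takes care of the usual PyTorch optimization steps for us. The sketch below is only a rough, plain-PyTorch approximation of that idea, not how Lightning is actually implemented. It uses a hypothetical toy model (`nn.Linear`) and made-up data in place of the `Transformer()` and `dataloader`: compute the loss, backpropagate, and update the weights with `Adam`, once per epoch for 30 epochs.

```python
import torch
import torch.nn as nn
from torch.optim import Adam

torch.manual_seed(0)

# Hypothetical toy model and data standing in for the Transformer and dataloader.
model = nn.Linear(in_features=2, out_features=5)
inputs = torch.randn(4, 2)
targets = torch.tensor([2, 4, 1, 4])
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.1)

losses = []
for epoch in range(30):
    optimizer.zero_grad()                   # clear gradients from the previous step
    loss = loss_fn(model(inputs), targets)  # roughly what training_step() returns
    loss.backward()                         # backpropagation
    optimizer.step()                        # update the weights
    losses.append(loss.item())

print(losses[0], losses[-1])
```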

# Double BAM!!!

Now that we've trained the transformer, let's use it!

----

<a id="use"></a>
# Use the Trained Transformer!!!

To use the transformer that we just trained, we encode an input phrase, either **\<SOS> let's go** or **\<SOS> to go**, with the Encoder, and then pass the outputs to a Decoder. The Decoder itself is in a loop that will continue to create output until it creates the **\<EOS>** (End of Sequence) token.

In [None]:
## First, a reminder of our input and output vocabularies...
# input_vocab = {'<SOS>': 0, # Start
#                'lets': 1,
#                'to': 2,
#                'go': 3}

# output_vocab = {'<SOS>': 0, # Start
#                 'ir': 1,
#                 'vamos': 2,
#                 'y': 3,
#                 '<EOS>': 4} # End

max_length = 3
row_dim = 0
col_dim = 1

## Encode the user input...
encoder_values = transformer.encoder(torch.tensor([0, 1, 3])) # <SOS> let's go # Expecting: 0, 2, 4 = <SOS> vamos <EOS>
# encoder_values = transformer.encoder(torch.tensor([0, 2, 3])) # <SOS> to go  # Expecting: 0, 1, 4 = <SOS> ir <EOS>
    
## Since we initialize the decoder with the <SOS> token, we
## can consider that <SOS> to be the first predicted token
predicted_ids = torch.tensor([0]) # set the first predicted token to <SOS> to initialize the decoder
for i in range(max_length):
    ## given the current predicted tokens and the encoded input, 
    ## predict the next token with the decoder
    ## NOTE: "prediction" is the output from the fully connected layer,
    ##      not from a softmax() function. We could, if we wanted to,
    ##      run "prediction" through a softmax() function, but
    ##      since we're going to select the item with the largest value,
    ##      we can just use argmax() instead...
    prediction = transformer.decoder(predicted_ids, encoder_values)

    ## Use argmax() to select the id of the predicted token
    predicted_id = torch.tensor([torch.argmax(prediction[-1,:])])
    ## add the predicted token id to the list of predicted ids.
    predicted_ids = torch.cat((predicted_ids, predicted_id))
        
    if (predicted_id == 4): # if the prediction is <EOS>, then we are done
        break
        
print("\npredicted_ids:", predicted_ids)

And the output is **\<SOS> vamos \<EOS>**, which is exactly what we want.
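Since we ran the same decoding loop twice, once before and once after training, it could be wrapped in a small helper. `greedy_decode()` below is a hypothetical function, not part of the original code, that implements the same greedy strategy of always keeping the single highest-scoring token. To keep the sketch self-contained, it is exercised with a made-up `fake_decoder()`; with the trained model, you would call something like `greedy_decode(transformer.decoder, transformer.encoder(torch.tensor([0, 1, 3])))`.

```python
import torch

def greedy_decode(decoder, encoder_values, sos_id=0, eos_id=4, max_length=3):
    """Hypothetical helper: call decoder() repeatedly, keeping the top token each time."""
    predicted_ids = torch.tensor([sos_id])  # start with <SOS>
    for _ in range(max_length):
        prediction = decoder(predicted_ids, encoder_values)
        # Only the final row predicts the next token.
        predicted_id = torch.argmax(prediction[-1, :]).reshape(1)
        predicted_ids = torch.cat((predicted_ids, predicted_id))
        if predicted_id.item() == eos_id:  # stop once we generate <EOS>
            break
    return predicted_ids

def fake_decoder(token_ids, encoder_values):
    """Hypothetical stand-in: scores 'vamos' (2) highest after any token except 'vamos',
    and '<EOS>' (4) highest after 'vamos'."""
    scores = torch.zeros(len(token_ids), 5)
    for row, token in enumerate(token_ids.tolist()):
        scores[row, 4 if token == 2 else 2] = 1.0
    return scores

print(greedy_decode(fake_decoder, encoder_values=None))  # <SOS> vamos <EOS>
```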

# TRIPLE BAM!!!