# The StatQuest Illustrated Guide to Neural Networks and AI
## Chapter 10 - Seq2Seq and Encoder-Decoder Models with LSTMs

Copyright 2024, Joshua Starmer

This tutorial is from the book, **[The StatQuest Illustrated Guide to Neural Networks and AI](https://www.amazon.com/dp/B0DRS71QVQ)**.

In this notebook, we will build and train a Seq2Seq, or Encoder-Decoder, model with 2 layers of LSTMs, each layer with 2 stacks of LSTMs, as seen in the picture below.

<img src="https://github.com/StatQuest/signa/blob/main/chapter_10/images/full_model.png?raw=1" alt="an encoder-decoder model with 2 layers of LSTMs, each layer with 2 stacks of LSTMs" style="width: 800px;">

In this tutorial, you will...

#### NOTE:
This tutorial assumes that you have read through the chapter on **Seq2Seq and Encoder-Decoder Models** in **The StatQuest Illustrated Guide to Neural Networks and AI**.

----

# Import the modules that will do all the work

The very first thing we need to do is load a bunch of Python modules. Python itself is just a basic programming language. These modules give us extra functionality to create and train a Neural Network.

In [None]:
%%capture
# %%capture prevents this cell from printing a ton of STDERR stuff to the screen

## First, check to see if lightning is installed. If not, install it.
##
## NOTE: If you **do** need to install something, just know that you may need to
##       restart your session for Python to find the new module(s).
##
##       To restart your session:
##       - In Google Colab, click on the "Runtime" menu and select
##         "Restart Session" from the pulldown menu
##       - In a local jupyter notebook, click on the "Kernel" menu and select
##         "Restart Kernel" from the pulldown menu
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("lightning") is None:
    ## Run pip with the same interpreter as this notebook (pip.main() is not a supported API)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lightning"])

In [None]:
import torch ## torch lets us create tensors and also provides helper functions, like argmax()
import torch.nn as nn ## torch.nn gives us nn.Module(), nn.Embedding(), nn.LSTM() and nn.Linear()
import torch.nn.functional as F # This gives us functions like softmax() (not strictly needed in this notebook)
from torch.optim import Adam ## We will use the Adam optimizer, which is a version of stochastic gradient
                             ## descent that adapts the step size for each parameter during training.
from torch.utils.data import TensorDataset, DataLoader ## We'll store our data in DataLoaders

import lightning as L ## Lightning makes it easier to write, optimize and scale our code

## NOTE: If you get an error running this block of code, it is probably
##       because you installed a new package earlier and forgot to
##       restart your session for python to find the new module(s).
##
##       To restart your session:
##       - In Google Colab, click on the "Runtime" menu and select
##         "Restart Session" from the pulldown menu
##       - In a local jupyter notebook, click on the "Kernel" menu and select
##         "Restart Kernel" from the pulldown menu

----

# Create the datasets that we will use to train the Encoder-Decoder model

To make the model at least a little bit interesting, we will translate two English phrases, **Let's go** and **to go**, into Spanish. **Let's go** should translate to **vamos \<EOS\>** and **to go** should translate to **ir \<EOS\>**.

In [None]:
## first, we create a dictionary that maps vocabulary tokens to id numbers...
english_token_to_id = {'lets': 0,
                       'to': 1,
                       'go': 2,
                       '<EOS>': 3 ## <EOS> = end of sequence
                      }
## ...then we create a dictionary that maps the ids to tokens. This will help us interpret the output.
## We use the "map()" function to apply the "reversed()" function to each tuple (e.g. ('lets', 0)) stored
## in the english_token_to_id dictionary. We then use dict() to make a new dictionary from the
## reversed tuples.
english_id_to_token = dict(map(reversed, english_token_to_id.items()))

spanish_token_to_id = {'ir': 0,
                       'vamos': 1,
                       'y': 2,
                       '<EOS>': 3}
spanish_id_to_token = dict(map(reversed, spanish_token_to_id.items()))

inputs = torch.tensor([[english_token_to_id["lets"],
                        english_token_to_id["go"]],

                       [english_token_to_id["to"],
                        english_token_to_id["go"]]])

labels = torch.tensor([[spanish_token_to_id["vamos"],
                        spanish_token_to_id["<EOS>"]],

                       [spanish_token_to_id["ir"],
                        spanish_token_to_id["<EOS>"]]])
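The `dict(map(reversed, ...))` idiom flips each `(token, id)` pair into an `(id, token)` pair. Here is a minimal pure-Python sketch showing that an ordinary dict comprehension builds the same reverse lookup, plus two hypothetical helpers, `encode_phrase()` and `decode_ids()` (illustrative names, not part of the original notebook), that turn a space-separated phrase into ids and back.

```python
english_token_to_id = {'lets': 0, 'to': 1, 'go': 2, '<EOS>': 3}

# The notebook's idiom: reversed(('lets', 0)) yields 0 then 'lets', and dict() pairs them up
id_to_token_map = dict(map(reversed, english_token_to_id.items()))

# An equivalent, arguably more readable, dict comprehension
id_to_token_comp = {token_id: token for token, token_id in english_token_to_id.items()}
assert id_to_token_map == id_to_token_comp

def encode_phrase(phrase, token_to_id):
    """Hypothetical helper: split a phrase on spaces and look up each token's id."""
    return [token_to_id[token] for token in phrase.split()]

def decode_ids(ids, id_to_token):
    """Hypothetical helper: map a list of ids back to their tokens."""
    return [id_to_token[i] for i in ids]

print(encode_phrase("lets go", english_token_to_id))  # [0, 2]
print(decode_ids([1, 2], id_to_token_comp))           # ['to', 'go']
```

Note that this simple `split()` assumes the input is already lowercased and uses `lets` rather than `Let's`, matching the keys in the vocabulary.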

Now that we have created the data that we want to train the model with, we'll store it in a `DataLoader`. Since our dataset is so small, using a `DataLoader` is a bit of overkill, but it's easy to do, and it will allow us to easily scale up to a much larger dataset when the time comes.

In [None]:
## Now let's package everything up into a DataLoader...
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)
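Because `DataLoader` uses `batch_size=1` by default (and does not shuffle unless asked to), each batch holds a single phrase with an extra leading batch dimension. That is why `training_step()` indexes `input_tokens[0]` and `labels[0]`. A minimal sketch that rebuilds the same tensors (ids hard-coded to match the dictionaries above) and prints what each batch looks like:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Same ids as the notebook: "lets go" -> [0, 2], "to go" -> [1, 2]
inputs = torch.tensor([[0, 2], [1, 2]])
# "vamos <EOS>" -> [1, 3], "ir <EOS>" -> [0, 3]
labels = torch.tensor([[1, 3], [0, 3]])

dataloader = DataLoader(TensorDataset(inputs, labels))  # batch_size defaults to 1

for input_tokens, label_tokens in dataloader:
    # Each batch has shape (1, 2): one phrase of two tokens, so [0] strips the batch dimension
    print(input_tokens.shape, input_tokens[0].tolist(), label_tokens[0].tolist())
```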

----

# Build and Train a Seq2Seq/Encoder-Decoder Model from Scratch

In [None]:
class seq2seq(L.LightningModule):

    def __init__(self, max_len=2):

        super().__init__()

        self.max_output_length = max_len

        L.seed_everything(seed=420)

        #################################
        ##
        ## ENCODING
        ##
        #################################
        self.encoder_we = nn.Embedding(num_embeddings=4, # num_embeddings = # of words in input vocabulary
                                       embedding_dim=2)  # embedding_dim = 2 numbers per embedding

        self.encoder_lstm = nn.LSTM(input_size=2, # input_size = number of inputs (2 numbers per word)
                                    hidden_size=2,# hidden_size = number of outputs (2 per word per layer)
                                    num_layers=2) # num_layers = how many LSTMs to stack
                                                  #          If there are 2 layers, then the short-term memory from the
                                                  #          first layer is used as input to the second layer

        #################################
        ##
        ## DECODING
        ##
        #################################
        self.decoder_we = nn.Embedding(num_embeddings=4,
                                       embedding_dim=2)

        self.decoder_lstm = nn.LSTM(input_size=2,
                                    hidden_size=2,
                                    num_layers=2)

        self.output_fc = nn.Linear(in_features=2,  # in_features = # of outputs per LSTM
                                   out_features=4) # out_features = # of words in the output vocabulary

        #################################
        ##
        ## Training
        ##
        #################################
        self.loss = nn.CrossEntropyLoss()


    def forward(self, input, output=None):

        #################################
        ##
        ## ENCODING
        ##
        #################################
        ## first, use the encoder stage to create an intermediate encoding of the input text
        encoder_embeddings = self.encoder_we(input)
        encoder_lstm_output, (encoder_lstm_hidden, encoder_lstm_cell) = self.encoder_lstm(encoder_embeddings)

        #################################
        ##
        ## DECODING
        ##
        #################################
        ## We start by initializing the decoder with the <EOS> token...
        decoder_token_id = torch.tensor([spanish_token_to_id["<EOS>"]])
        decoder_embeddings = self.decoder_we(decoder_token_id)

        decoder_lstm_output, (decoder_lstm_hidden, decoder_lstm_cell) = self.decoder_lstm(decoder_embeddings,
                                                                                          (encoder_lstm_hidden,
                                                                                           encoder_lstm_cell))

        output_values = self.output_fc(decoder_lstm_output)
        outputs = output_values

        predicted_id = torch.tensor([torch.argmax(output_values)])
        predicted_ids = predicted_id

        for i in range(1, self.max_output_length):

            if output is None: # using the model...
                if (predicted_id == spanish_token_to_id["<EOS>"]): # if the prediction is <EOS>, then we are done
                    break
                decoder_embeddings = self.decoder_we(predicted_id)
            else:
                ## run this when training the model: feed in the correct previous token from the label
                decoder_embeddings = self.decoder_we(output[i-1:i])

            decoder_lstm_output, (decoder_lstm_hidden, decoder_lstm_cell) = self.decoder_lstm(decoder_embeddings,
                                                                                              (decoder_lstm_hidden,
                                                                                               decoder_lstm_cell))

            output_values = self.output_fc(decoder_lstm_output)
            outputs = torch.cat((outputs, output_values), 0)
            predicted_id = torch.tensor([torch.argmax(output_values)])
            predicted_ids = torch.cat((predicted_ids, predicted_id))

        return outputs


    def configure_optimizers(self): # this configures the optimizer we want to use for backpropagation.
        return Adam(self.parameters(), lr=0.1) ## NOTE: Setting the learning rate to 0.1 trains way faster than
                                               ## using the default learning rate, lr=0.001


    def training_step(self, batch, batch_idx): # take a step during gradient descent.
        input_tokens, labels = batch # collect input
        output = self.forward(input_tokens[0], labels[0]) # run input through the neural network
        loss = self.loss(output, labels[0]) ## self.loss = cross entropy
        ###################
        ##
        ## Logging the loss
        ##
        ###################
        # self.log("train_loss", loss)

        return loss
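The loop in `forward()` behaves differently depending on whether `output` (the known translation) is passed in. During training it uses what is usually called *teacher forcing*: the next decoder input is the correct previous token from the label, not the model's own guess. During inference it uses greedy decoding: it feeds back its own `argmax()` and stops at `<EOS>`. The sketch below mimics only that control flow in plain Python. `fake_decoder_step()` is a hypothetical stand-in for the embedding, LSTM and Linear layers that returns scores from a made-up table, and, unlike the real decoder, it ignores the hidden and cell states.

```python
EOS = 3  # spanish_token_to_id["<EOS>"] in the notebook

def fake_decoder_step(prev_token_id):
    """Hypothetical stand-in for decoder_we -> decoder_lstm -> output_fc.
    Returns 4 made-up scores, one per Spanish id (0=ir, 1=vamos, 2=y, 3=<EOS>)."""
    table = {
        EOS: [0.1, 2.0, 0.0, 0.5],  # after the starting <EOS>, "vamos" scores highest
        1: [0.0, 0.2, 0.1, 3.0],    # after "vamos", "<EOS>" scores highest
        0: [0.0, 0.1, 0.2, 2.5],    # after "ir", "<EOS>" scores highest
    }
    return table.get(prev_token_id, [0.0, 0.0, 0.0, 1.0])

def argmax(scores):
    return max(range(len(scores)), key=lambda i: scores[i])

def run_decoder(max_len=2, labels=None):
    """Return (predicted ids, ids fed into the decoder).
    labels=None -> greedy decoding; otherwise teacher forcing with labels."""
    fed = [EOS]                                   # the decoder always starts from <EOS>
    predicted = [argmax(fake_decoder_step(EOS))]
    for i in range(1, max_len):
        if labels is None:                        # inference: feed back our own prediction
            if predicted[-1] == EOS:              # stop as soon as we predict <EOS>
                break
            next_input = predicted[-1]
        else:                                     # training: feed in the correct previous token
            next_input = labels[i - 1]
        fed.append(next_input)
        predicted.append(argmax(fake_decoder_step(next_input)))
    return predicted, fed

print(run_decoder())               # ([1, 3], [3, 1])
print(run_decoder(labels=[0, 3]))  # ([1, 3], [3, 0])
```

With teacher forcing, the second step is fed `ir` (id 0) from the label even though the first prediction was `vamos`, so a wrong early guess does not derail the rest of the training sequence.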

Now that we have created the `seq2seq()` class, let's just run the phrase **Let's go** through it to see what it gets translated into.

In [None]:
model = seq2seq()
outputs = model.forward(input=torch.tensor([english_token_to_id["lets"],
                                            english_token_to_id["go"]]), ## translate "lets go", we should get "vamos <EOS>"
                        output=None)

print("Translated text:")
predicted_ids = torch.argmax(outputs, dim=1)
for token_id in predicted_ids:
    print("\t", spanish_id_to_token[token_id.item()])

And we see that **Let's go** was translated to **\<EOS\>** instead of what we wanted, which was **vamos \<EOS\>**. So let's train the model!
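Training minimizes the cross-entropy loss computed in `training_step()`. For each output position, `nn.CrossEntropyLoss` treats the 4 raw scores from `output_fc` as logits, applies softmax, takes the negative log of the probability assigned to the correct token, and (with its default `reduction='mean'`) averages over positions. A pure-Python sketch of that arithmetic, using made-up scores:

```python
import math

def cross_entropy(scores_per_position, target_ids):
    """Mean over positions of -log(softmax(scores)[target])."""
    losses = []
    for scores, target in zip(scores_per_position, target_ids):
        # log-sum-exp, shifted by the max score for numerical stability
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_sum_exp - scores[target])  # equals -log softmax(scores)[target]
    return sum(losses) / len(losses)

# Made-up scores for two output positions (ir, vamos, y, <EOS>)
scores = [[0.0, 0.0, 0.0, 0.0],   # uniform scores: loss is log(4)
          [0.0, 0.0, 0.0, 10.0]]  # confident and correct: loss is close to 0
loss = cross_entropy(scores, [1, 3])  # targets: "vamos", "<EOS>"
print(round(loss, 4))
```

An untrained model produces roughly uniform scores, so its loss starts near `log(4)` per position; training pushes the correct token's score up until the loss approaches 0.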

In [None]:
trainer = L.Trainer(max_epochs=40, accelerator="cpu")
trainer.fit(model, train_dataloaders=dataloader)

Now let's see if the model correctly translates **Let's go** into **vamos \<EOS\>**...

In [None]:
outputs = model.forward(input=torch.tensor([english_token_to_id["lets"],
                                            english_token_to_id["go"]]), ## translate "lets go", we should get "vamos <EOS>"
                        output=None)

print("Translated text:")
predicted_ids = torch.argmax(outputs, dim=1)
for token_id in predicted_ids:
    print("\t", spanish_id_to_token[token_id.item()])

...and it does!

### BAM!

Now let's see if the model correctly translates **to go** into **ir \<EOS\>**...

In [None]:
outputs = model.forward(input=torch.tensor([english_token_to_id["to"],
                                            english_token_to_id["go"]]), ## translate "to go", we should get "ir <EOS>"
                        output=None)

print("Translated text:")
predicted_ids = torch.argmax(outputs, dim=1)
for token_id in predicted_ids:
    print("\t", spanish_id_to_token[token_id.item()])

...and it does!

## DOUBLE BAM!!

Now that we have a model that works, let's count the number of parameters we had to train, which are the parameters where `requires_grad` is set to `True`.

In [None]:
## count the number of parameters...
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Total number of trainable parameters:", total_trainable_params)

So this model has 220 parameters. Compared to the models used in practical applications, that's not very many. However, we can still spare ourselves the agony of retraining the model every time we want to use it by saving the trained parameters to a file and then loading them when needed. So let's do that.
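The 220 can be checked by hand. An `nn.Embedding` has one weight per (token, dimension) pair. Each `nn.LSTM` layer has input weights of shape `(4*hidden_size, input_size)`, recurrent weights of shape `(4*hidden_size, hidden_size)`, and two bias vectors of length `4*hidden_size` (PyTorch keeps separate `bias_ih` and `bias_hh`); the factor of 4 comes from the input, forget, cell and output gates. An `nn.Linear` layer has `in_features * out_features` weights plus `out_features` biases. A quick pure-Python tally, using the layer sizes from `seq2seq.__init__()`:

```python
def embedding_params(num_embeddings, embedding_dim):
    return num_embeddings * embedding_dim

def lstm_params(input_size, hidden_size, num_layers):
    total = 0
    for layer in range(num_layers):
        # the first layer reads the embeddings; later layers read the previous layer's output
        layer_input = input_size if layer == 0 else hidden_size
        gates = 4 * hidden_size  # input, forget, cell (candidate) and output gates
        total += gates * layer_input + gates * hidden_size + 2 * gates  # W_ih, W_hh, b_ih, b_hh
    return total

def linear_params(in_features, out_features):
    return in_features * out_features + out_features

counts = {
    "encoder_we": embedding_params(4, 2),
    "encoder_lstm": lstm_params(2, 2, 2),
    "decoder_we": embedding_params(4, 2),
    "decoder_lstm": lstm_params(2, 2, 2),
    "output_fc": linear_params(2, 4),
}
print(counts)               # {'encoder_we': 8, 'encoder_lstm': 96, 'decoder_we': 8, 'decoder_lstm': 96, 'output_fc': 12}
print(sum(counts.values()))  # 220
```

Most of the parameters live in the two LSTMs (96 each), which is why increasing `hidden_size` or `num_layers` grows a model like this much faster than enlarging the vocabulary does.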

----

# Saving and loading the trained model weights...

In [None]:
## First, save the weights...
trainer.save_checkpoint("seq2seq_en2es_220_trained.ckpt") ## NOTE: You can specify a path as part of the filename

In [None]:
## Now let's create a new model and load in the saved weights...
new_model = seq2seq.load_from_checkpoint("seq2seq_en2es_220_trained.ckpt")

outputs = new_model.forward(input=torch.tensor([english_token_to_id["lets"],
                                                english_token_to_id["go"]]),
                            output=None)

print("Translated text:")
predicted_ids = torch.argmax(outputs, dim=1)
for token_id in predicted_ids:
    print("\t", spanish_id_to_token[token_id.item()])
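Lightning's `.ckpt` file stores the weights together with extra training state, such as the optimizer state. If you only need the weights, plain PyTorch's `state_dict()` with `torch.save()` and `torch.load()` also works. A minimal sketch on a small stand-in module (`TinyDecoder` is hypothetical, not the notebook's `seq2seq` class) that writes to an in-memory buffer instead of a file:

```python
import io
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in with the same layer sizes as the notebook's decoder."""
    def __init__(self):
        super().__init__()
        self.we = nn.Embedding(num_embeddings=4, embedding_dim=2)
        self.lstm = nn.LSTM(input_size=2, hidden_size=2, num_layers=2)
        self.fc = nn.Linear(in_features=2, out_features=4)

torch.manual_seed(0)
original = TinyDecoder()

buffer = io.BytesIO()  # pass a filename such as "weights.pt" instead to save to disk
torch.save(original.state_dict(), buffer)
buffer.seek(0)

restored = TinyDecoder()  # freshly initialized with different random weights
restored.load_state_dict(torch.load(buffer))

same = all(torch.equal(a, b) for a, b in zip(original.state_dict().values(),
                                             restored.state_dict().values()))
print(same)  # True
```

Unlike `load_from_checkpoint()`, this approach requires you to construct the model yourself before loading, so the layer sizes must match the ones used when saving.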

# TRIPLE BAM!!!