# Transformers

<hr style="border:2px solid gray">

# Index <a id='index'></a>
1. [Introduction](#intro)
1. [Tokenization and transformer data structures](#tokenization)
1. [seq2seq problems and the introduction of attention](#attention)
1. [What are transformers?](#transformers)
1. [Transformers in PyTorch](#pytorch-transformers)


<hr style="border:2px solid gray">

# Introduction [^](#index) <a id='intro'></a>

So far, we have described different types of neural networks in terms of how they define locality:

* Fully connected neural networks are global, as each neuron connects to all neurons in the next layer

* Convolutional neural networks are very local, as pixels are convolved with adjacent pixels

* Graph neural networks use the graph structure to define locality

The figures below illustrate the different levels of locality between these types of neural network.


<div style="display: flex; justify-content: flex-start; gap: 80px; align-items: flex-start;">
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 320px; margin: 0;">
<img src='nn_locality.png' width=320 style="align-self: center;"/>
<div style="margin-top: 10px; text-align: justify; max-width: 320px; font-style: italic; line-height: 1.2;">

<strong>(a)</strong>: in a standard fully-connected neural network, all nodes in one layer are connected to every node in the next layer; in other words, the model only has global connections.
</div>
</div>
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 320px; margin: 0;">
<img src='cnn_locality.png' width=320 style="align-self: center;"/>
<div style="margin-top: 10px; text-align: justify; max-width: 320px; font-style: italic; line-height: 1.2;">

<strong>(b)</strong>: for a convolutional neural network, each convolutional layer aggregates adjacent pixels based on the size of the kernel. The model is very local.
</div>
</div>
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 320px; margin: 0;">
<img src='gnn_locality.png' width=320 style="align-self: center;"/>
<div style="margin-top: 10px; text-align: justify; max-width: 320px; font-style: italic; line-height: 1.2;">

<strong>(c)</strong>: for a graph neural network, each graph layer aggregates the set of neighbouring nodes. Locality is set by the graph structure, which we can define ourselves.
</div>
</div>
</div>

Now, we will consider something that was briefly mentioned in the GNN exercises: **attention mechanisms**. To describe how the locality of these mechanisms works, we can interpret them as the model *learning its own definition of locality*. We will discuss this in more detail shortly.
<br></br>

The main parts of this notebook will be as follows:

* A brief discussion of data structures relevant for transformers and the process of **tokenization**

* The historical introduction of attention mechanisms for so-called **sequence to sequence** (**seq2seq**) problems and some basic principles as to how they work

* One of the most influential developments in machine learning architectures, which relies on attention mechanisms: the **transformer**. This is the development that underpins many of the modern AI models, including large language models like ChatGPT, image generation models like DALL-E, and AlphaFold, a model that predicts the structure of proteins. 

* An overview of how we can use transformers practically, including building the architecture ourselves using PyTorch



<hr style="border:2px solid gray">

# Tokenization and transformer data structures [^](#index) <a id='tokenization'></a>

You will likely have heard of modern ML models working with **tokens**. Tokenization is how we break up a sequence into pieces and convert them to numbers, so that a model can learn to solve sequence problems. There are often many ways to tokenize a given data type. For example, for sentences we could use the following tokenization schemes:


* Word-level tokenization: each individual word is a token, and the numerical vector we construct is the index of each word in our dictionary

* Character-level tokenization: each individual *character* is a token, and each character is assigned a numerical value

* Subword-level tokenization: each token is not necessarily a whole word or an individual character, but instead may be a small part of a word. Each subword may have specific meanings, e.g. we could see prefixes such as "pre" or "post" as tokens separate from words they may otherwise be a part of.

There are of course other options for tokenization in language processing problems, but these are a few examples. See the schematic below for an illustration of these tokenization schemes.


<center>
<img src='tokenization-schematic.png' width=800/>
</center>

<div style="text-align:center;">
<div style='width:780px;display:inline-block;vertical-align:top;margin-top:10px;line-height:1.2;'>
<div style="text-align: justify;">

*Three different tokenization schemes for the sentence "The cat isn't black". **(a)**: the original sentence. **(b)**: word-level tokenization. **(c)**: character-level tokenization. **(d)**: subword-level tokenization.*
</div></div></div>
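As a minimal Python sketch of the first two schemes applied to the sentence in the schematic: word-level tokenization here is plain whitespace splitting (real tokenizers also handle punctuation), and each toy vocabulary is built from this one sentence only, purely for illustration. Subword tokenization needs a vocabulary learned from a corpus, so it is not shown.

```python
sentence = "The cat isn't black"

# Word-level: split on whitespace, so "isn't" stays a single token
word_tokens = sentence.split()
# Toy vocabulary: each distinct word gets an integer index (sorted for reproducibility)
word_vocab = {tok: i for i, tok in enumerate(sorted(set(word_tokens)))}
word_ids = [word_vocab[tok] for tok in word_tokens]

# Character-level: every character (including spaces and apostrophes) is a token
char_tokens = list(sentence)
char_vocab = {ch: i for i, ch in enumerate(sorted(set(char_tokens)))}
char_ids = [char_vocab[ch] for ch in char_tokens]

print(word_tokens)  # ['The', 'cat', "isn't", 'black']
print(word_ids)     # [0, 2, 3, 1]
print(len(char_ids), len(char_vocab))  # 19 14
```

The trade-off is visible even here: the word-level sequence is short but in general needs a very large vocabulary, while the character-level sequence needs only a small vocabulary but is much longer.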


After tokenization, each token usually has to be mapped to a vector with the number of dimensions the model expects. This is typically done with a learnable embedding layer (equivalent to a linear layer applied to one-hot token vectors), and is referred to as finding embeddings of the sequence, similar to the node, edge, and graph embeddings we discussed for GNNs.
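In PyTorch, a common way to do this projection is `nn.Embedding`, a learnable lookup table with one row per token in the vocabulary. The sketch below uses arbitrary sizes and made-up token indices, and checks that the lookup gives the same result as multiplying one-hot token vectors by the embedding matrix, i.e. a bias-free linear layer.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
vocab_size, d_model = 10, 4  # arbitrary sizes for illustration

# Learnable table of shape (vocab_size, d_model)
embedding = nn.Embedding(vocab_size, d_model)

token_ids = torch.tensor([0, 2, 3, 1])  # a tokenized 4-token sequence
vectors = embedding(token_ids)          # (4, d_model): one row per token

# Same result as a bias-free linear map applied to one-hot vectors
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()
assert torch.allclose(vectors, one_hot @ embedding.weight)
print(vectors.shape)  # torch.Size([4, 4])
```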

## Data structures relevant for transformers

For the vast majority of our discussion of transformers we will focus on text sequence datasets, as these were the initial class of problems the architecture was proposed to solve. However, the transformer architecture has been applied successfully to many different data types, including but not limited to:

* Image datasets, with a class of models called [**vision transformers**](https://arxiv.org/abs/2010.11929)

* Graph datasets, with models often referred to as [**graph transformers**](https://arxiv.org/abs/2407.09777)

* Audio sequence datasets, such as the [convolution-augmented transformer for speech recognition](https://arxiv.org/abs/2005.08100)

The principles behind the transformer architecture have proved very successful generally, resulting in improved performance on many standard benchmark tasks compared to traditional architectures, and have spurred on significant developments in other architectures to incorporate similar ideas.


<hr style="border:2px solid gray">

# seq2seq problems and the introduction of attention [^](#index) <a id='attention'></a>

One of the first notably successful applications of attention mechanisms was in natural
language processing (NLP), in particular an approach called **sequence-to-sequence** (**seq2seq**) where
NLP problems are framed as mapping one sequence to another sequence, via some
intermediate representation that contains all the information necessary to reconstruct the output
sequence. We call this intermediate representation the **context vector**.

In fact, we can consider this process as using an encoder to find an embedding of the input
sequence and using a decoder to go from the embedding to the target sequence, where in this
case the embedding is the context vector. We can consider this for an English-to-French
translation problem: 

* The input sequence is a sentence in English, which is represented by a numerical vector
  which may e.g. just be indices of words in a dictionary

* The encoder converts the input sequence into the context vector

* The context vector is then passed to the decoder, which decodes the context vector to a
  different numerical vector where each value corresponds to a word in the target language,
  in this case French

This is illustrated in the schematic below.

<center>
<img src='seq2seq-schematic.png' width=1000></img>
</center>  
<div style="text-align:center;">
<div style='width:950px;display:inline-block;vertical-align:top;margin-top:10px;line-height:1.2;'>
<div style="text-align: justify;">

*A schematic illustrating a seq2seq problem, for English-to-French translation. The input English phrase "the cat is black" is first represented as a numerical vector and then transformed into a context vector by the encoder. The context vector is passed to the decoder, which transforms it to some new numerical vector corresponding to the French translation of the original sentence, "le chat est noir".*
</div></div></div>


Historically, both the encoder and decoder were so-called **recurrent neural networks** (**RNNs**), which are designed to handle sequential data where the previous point in the sequence is relevant for the next point in the sequence. This can include:

*  NLP: sentences are sequences, as earlier words in the sentence are important for understanding the context of later words in the sentence (and indeed,
   vice-versa); we can also see this as we read sentences as a sequence, one word after another.
   
* Time series data: values at previous times influence values in the future

In general, an RNN takes both the current input and its previous hidden state as inputs, e.g. to find the hidden state (and output) at time $t = 1$ we give the model both the input for time $t = 1$ and the hidden state from time $t = 0$.

You can read more about the general RNN encoder-decoder structure in [this paper](https://arxiv.org/abs/1406.1078).

However, there is a key problem with this approach, related to how much information can be conveyed by the context vector:

* Whatever its size, the context vector must capture *all* the information contained in the input sequence

* If we want to handle longer input sequences, the amount of information that has to be captured by the context vector increases

* The context vector length is constant regardless of the length of input sequence, so it is difficult to be able to summarise enough information for long sequences without wasting resources for short sequences

* For longer sequences, RNNs struggle to give words early in the sequence as much weight as later ones, so information from the start of a long sentence is lost

This is referred to as the **bottleneck** problem, as the decoder only sees the context vector. 
<br></br>

To put this in context, assume the average person has a vocabulary of around 20,000 words (as suggested in [this paper](https://pmc.ncbi.nlm.nih.gov/articles/PMC4965448/#sec15)). For a sentence with 10 words, if we assume any word could be in any place in the sentence (i.e. disregarding grammar) but do not allow any repeated words, the number of possible sequences is the number of ordered selections of 10 distinct words:

\begin{equation*}
\text{Number of sequences} = \frac{20{,}000!}{(20{,}000 - 10)!} \approx 1.02 \times 10^{43}
\end{equation*}

Even if we ignore word order entirely, the number of possible sets of 10 words is still given by the binomial coefficient ${20{,}000 \choose 10} \approx 2.82 \times 10^{36}$.

Of course, grammar restricts the actual number of possible sentences, while allowing repeated words would increase it. Whatever our context vector is, it must be able to convey any possible sentence of the given length and capture its relationship to the output language of choice. This is a massive amount of information to convey through a single vector!
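These counts are easy to check with Python's standard library (Python 3.8 or later): `math.comb` counts unordered sets of 10 distinct words, `math.perm` counts ordered sequences of 10 distinct words, and `vocab ** length` counts sequences when words may repeat.

```python
import math

vocab, length = 20_000, 10

unordered = math.comb(vocab, length)  # sets of 10 distinct words (order ignored)
ordered = math.perm(vocab, length)    # sequences of 10 distinct words
with_repeats = vocab ** length        # sequences where words may repeat

print(f"{unordered:.2e}")     # 2.82e+36
print(f"{ordered:.2e}")       # 1.02e+43
print(f"{with_repeats:.2e}")  # 1.02e+43
```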


In order to overcome this, we need some way to pass more information from the encoder to the decoder. What information can we get? The general RNN encoder operation goes as follows:

* For the 1st step of the input sequence, find some hidden state by passing through the encoder model, like an embedding in GNNs

* For the 2nd step of the input sequence, input both the 2nd step input and the 1st step hidden state to find the 2nd step hidden state

* Repeat this process until the end of the sequence is reached; the final hidden state is the context vector
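The encoder loop above can be sketched with PyTorch's `nn.RNNCell`, assuming a simple (Elman) RNN, random stand-in embeddings for the input sequence, and arbitrary sizes. Storing every hidden state in a list makes both the context vector and the full set of encoder hidden states available.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, seq_len = 8, 16, 5

# h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
cell = nn.RNNCell(input_size, hidden_size)

x = torch.randn(1, seq_len, input_size)  # (batch, seq, features): embedded input sequence
h = torch.zeros(1, hidden_size)          # initial hidden state

hidden_states = []
for t in range(seq_len):
    h = cell(x[:, t], h)  # step t uses input t and the hidden state from step t - 1
    hidden_states.append(h)

encoder_states = torch.stack(hidden_states, dim=1)  # (batch, seq, hidden): every hidden state
context = hidden_states[-1]                         # final hidden state = context vector
print(encoder_states.shape)  # torch.Size([1, 5, 16])
```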

Instead of passing only the context vector to the decoder (alongside previous decoder hidden states), we can use the information from every step of the encoding, i.e. all the hidden states rather than just the last one. To do this, we use **attention mechanisms**.

## Attention mechanisms

When we say "attention mechanisms", what we actually mean is some method of determining the relative importance of different parts of the input, and then making the model **attend** to the important parts of the input and disregard the unimportant parts. In the context of machine translation, we can describe this as working out which words in the input sequence are most relevant to each word in the output sequence.

How does this actually work? In general, we need these parts:

* Some representation of the output we want to predict

* Some representation of the input to our model

* An **alignment model** that scores how well a given single input value matches a single output value
<br></br>

To understand how we will use this, let us consider the RNN encoder-decoder model. We want to calculate some new quantity we can use to improve the performance of the decoder, based on the encoder hidden states and the alignment scores between the encoder hidden states and our decoder outputs. For decoder sequence step $t$, we do the following steps:

* Calculate the alignment score between the decoder hidden state from step $t - 1$ and each of the encoder hidden states

* Find the softmax of the alignment scores to get attention weights for decoder sequence step $t$

* Take the weighted sum of the encoder hidden states using the attention weights

The output of the weighted sum is used as an input to the decoder for sequence step $t$, alongside the decoder hidden state and the target (or predicted) sequence entry from step $t - 1$. In this example, the alignment calculated is between encoder hidden states and decoder hidden states, to find what encoder hidden states are most relevant to each decoder hidden state. This is illustrated in the schematic below.

<center>
<img src='attn-schematic.png' width=900></img>
</center>
<div style="text-align:center;">
<div style='width:900px;display:inline-block;vertical-align:top'>
<div style="text-align: justify;">

*The previous machine translation task, now with added attention. For each decoder step $i$, the attention weights $\alpha_{ij}$
between each encoder hidden state $j$ and the previous decoder hidden state $i - 1$ are found (by taking the softmax of the alignment scores), and the attention-weighted sum of encoder hidden 
states is calculated as the attention mechanism output. The decoder output is then determined by its previous value 
and the attention mechanism output.*
</div></div></div>
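The three steps for a single decoder step can be sketched in PyTorch as below. As an assumption for illustration, a plain dot product stands in for the alignment model (so encoder and decoder hidden states must have the same size), random tensors stand in for real hidden states, and `attention_step` is a hypothetical helper name.

```python
import torch

def attention_step(s_prev, encoder_states):
    """One decoder step of attention (hypothetical helper).

    s_prev: (hidden,) decoder hidden state from step t - 1
    encoder_states: (src_len, hidden) all encoder hidden states
    """
    # 1. Alignment scores (placeholder alignment model: plain dot product)
    scores = encoder_states @ s_prev          # (src_len,)
    # 2. Softmax turns the scores into attention weights that sum to 1
    weights = torch.softmax(scores, dim=0)    # (src_len,)
    # 3. Weighted sum of the encoder hidden states
    attn_output = weights @ encoder_states    # (hidden,)
    return attn_output, weights

torch.manual_seed(0)
encoder_states = torch.randn(6, 16)  # 6 source positions, hidden size 16
s_prev = torch.randn(16)
attn_output, weights = attention_step(s_prev, encoder_states)
print(attn_output.shape)  # torch.Size([16])
```

The returned `attn_output` is what would be fed into the decoder at step $t$, alongside its previous hidden state and previous output.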

What function we use to calculate the alignment scores can have a significant effect on the performance of this approach.
A couple of the early examples included:

* [Bahdanau attention](https://arxiv.org/abs/1409.0473): use a single-layer neural network, with independent learnable 
weight matrices for the encoder and decoder hidden states, a tanh activation, and another learnable weight vector to 
project to a single value

* [Luong attention](https://arxiv.org/abs/1508.04025): take a weighted dot product between the encoder and decoder 
hidden states, where the weights are a learnable matrix
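The two scoring functions can be sketched as small PyTorch modules. This is a simplified reading of the papers, assuming unbatched tensors and bias-free layers; the class names are our own. Luong et al. also describe other variants (such as a plain dot product) and use a different decoder state from Bahdanau et al., so here `s` is simply "a decoder hidden state".

```python
import torch
from torch import nn

class BahdanauScore(nn.Module):
    """Additive alignment score: v^T tanh(W_s s + W_h h_j)."""
    def __init__(self, dec_dim, enc_dim, attn_dim):
        super().__init__()
        self.W_s = nn.Linear(dec_dim, attn_dim, bias=False)  # for the decoder state
        self.W_h = nn.Linear(enc_dim, attn_dim, bias=False)  # for the encoder states
        self.v = nn.Linear(attn_dim, 1, bias=False)          # projects to a single value

    def forward(self, s, H):
        # s: (dec_dim,), H: (src_len, enc_dim) -> scores: (src_len,)
        return self.v(torch.tanh(self.W_s(s) + self.W_h(H))).squeeze(-1)

class LuongGeneralScore(nn.Module):
    """Multiplicative ('general') alignment score: s^T W h_j."""
    def __init__(self, dec_dim, enc_dim):
        super().__init__()
        self.W = nn.Linear(enc_dim, dec_dim, bias=False)  # learnable matrix W

    def forward(self, s, H):
        # (src_len, dec_dim) @ (dec_dim,) -> scores: (src_len,)
        return self.W(H) @ s

torch.manual_seed(0)
H = torch.randn(6, 32)  # 6 encoder hidden states of size 32
s = torch.randn(24)     # a decoder hidden state of size 24
results = []
for score_fn in (BahdanauScore(24, 32, 16), LuongGeneralScore(24, 32)):
    weights = torch.softmax(score_fn(s, H), dim=0)  # attention weights
    context = weights @ H                            # weighted sum of encoder states
    results.append((weights, context))
```

Both modules plug into the same softmax and weighted-sum steps; only the alignment score changes.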

Of course, there are many other types of attention we might consider, and we will discuss an important one later.


<div style="background-color: #FFF8C6">

To give another example, we will consider how we might use attention in a GNN; you can read about this in more detail in
the [Graph Attention Networks paper](https://arxiv.org/abs/1710.10903).

You briefly saw this concept last time, when we used the `GATConv` layer to try and improve the performance of our GNN on the Cora dataset.
In essence, this approach includes attention in the neighbourhood aggregation:

* Rather than using a simple aggregation procedure, we can instead take some weighted sum (or other weighted aggregation method) of the neighbours

* Weights are derived according to some attention mechanism between the node of interest and each of its neighbours

* Effectively, we learn which of the neighbours are most important for prediction at the target node

In the case of this model, to find the alignment score between two nodes the two node embeddings are transformed by a 
single weight matrix (as is common in other GNN layers), concatenated, and passed through a single-layer neural network
with a LeakyReLU activation function that maps the concatenated vector to a single value. Attention weights are then 
computed as the softmax of these alignment scores over all nodes in the neighbourhood.

This is illustrated in the schematic below. 



<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 350px; margin: 0;">
<img src='gat-attn-mech.png' width=239 style="align-self: center;"/>
<div style="margin-top: 10px; text-align: justify; max-width: 350px; font-style: italic; line-height: 1.2;">

<strong>Left</strong>: the attention mechanism between a node $i$ and a node $j$ in the
neighbourhood of $i$. Node features are transformed by the same weight matrix, concatenated, and projected
by a learnable vector $\mathbf{a}$. A LeakyReLU activation is applied, and then the softmax over all
nodes in the neighbourhood is found to get the attention weights.
</div>
</div>
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 504px; margin: 0;">
<img src='gat-node-pred.png' width=480 style="align-self: center;"/>
<div style="margin-top: 10px; text-align: justify; max-width: 504px; font-style: italic; line-height: 1.2;">

<strong>Right</strong>: finding the next layer node embedding using an attention mechanism. Attention weights $\alpha_{ij}$ are calculated between the node of interest $i$ and each of its neighbours $j$ (as well as itself). All the weighted node embeddings are then aggregated to produce the next embedding for the node of interest. 
</div>
</div>
</div>

<center>

*Schematics illustrating a graph attention mechanism.  Adapted from [[source](https://arxiv.org/abs/1710.10903)].*
</center>
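The single-head attention computation described above can be sketched in PyTorch with dense tensors on a small made-up graph. This is an illustrative version under several assumptions: the weight matrix $\mathbf{W}$ and vector $\mathbf{a}$ are random stand-ins for learnable parameters, the adjacency matrix is toy data, and the nonlinearity the paper applies after aggregation is omitted (in practice you would use a library layer such as `GATConv`). The trick is that $\mathbf{a}^T[\mathbf{W}h_i \,\|\, \mathbf{W}h_j]$ splits into $\mathbf{a}_1^T\mathbf{W}h_i + \mathbf{a}_2^T\mathbf{W}h_j$, so all pairwise scores can be computed at once; the LeakyReLU slope of 0.2 follows the paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, in_dim, out_dim = 4, 5, 3
h = torch.randn(num_nodes, in_dim)  # input node features

# Toy undirected graph as a boolean adjacency matrix, plus self-loops
adj = torch.tensor([[0, 1, 1, 0],
                    [1, 0, 0, 1],
                    [1, 0, 0, 0],
                    [0, 1, 0, 0]], dtype=torch.bool)
adj = adj | torch.eye(num_nodes, dtype=torch.bool)

# Random stand-ins for the learnable parameters
W = torch.randn(in_dim, out_dim)  # shared weight matrix
a = torch.randn(2 * out_dim)      # attention vector a = [a_1 ; a_2]

Wh = h @ W  # (num_nodes, out_dim): transformed node features
# e[i, j] = a_1 . Wh_i + a_2 . Wh_j, which equals a^T [Wh_i || Wh_j]
e = (Wh @ a[:out_dim]).unsqueeze(1) + (Wh @ a[out_dim:]).unsqueeze(0)
e = F.leaky_relu(e, negative_slope=0.2)

# Softmax only over each node's neighbourhood (non-neighbours get -inf, i.e. weight 0)
e = e.masked_fill(~adj, float("-inf"))
alpha = torch.softmax(e, dim=1)  # alpha[i, j]: attention weight of node i on node j

h_new = alpha @ Wh  # weighted aggregation gives the next-layer embeddings
```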

In fact, the original paper shows that this change to a GNN architecture leads to a significant improvement in 
performance on both transductive and inductive benchmark tasks relative to previous state-of-the-art GNN models, including:

* Matching or exceeding state-of-the-art performance on the three major citation network benchmark datasets: Cora, Citeseer, and Pubmed

* An improvement of about 20% over the previous best result on the protein-protein interaction (PPI) dataset, an inductive task that classifies the functions of proteins in protein-protein interaction graphs

For more details on this, please see the [corresponding paper](https://arxiv.org/abs/1710.10903).

## Summary

In this section, we have discussed attention mechanisms and their historical introduction, including:

* seq2seq problems including machine translation, and the difficulties faced by traditional RNN encoder-decoder methods

* The introduction of attention to machine translation

* Graph attention networks

In the next section, we will use our newfound understanding of attention mechanisms to discuss one of the most influential developments in machine learning in recent years: **transformers**.



<hr style="border:2px solid gray">

# What are transformers? [^](#index) <a id='transformers'></a>


First proposed in the 2017 paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) by Vaswani et al., 
the transformer architecture (and its components) is a highly influential model that completely changed the approach 
to sequence-based ML tasks, and indeed has found great success in many different applications. 

As opposed to the RNN (or sometimes CNN) models commonly used for seq2seq tasks, which are often supplemented by 
attention mechanisms, the transformer architecture relies solely on attention mechanisms combined with regular linear (feed-forward) layers, with no recurrence or convolutions. We will first introduce some of the language used in the original paper and then discuss the design of the transformer architecture.

## Queries, keys, and values

As a way to describe attention mechanisms, the terms queries, keys, and values were borrowed from database terminology and in fact were popularised to describe attention mechanisms by the original transformer paper. We can break this down as follows:

* A database stores key-value pairs, e.g. a list of words with a value assigned to each word

* When we look something up, we pass a query to the database to get a value back

* For a given query, an attention mechanism returns a weighted combination of values, based on how well their corresponding keys match the query
<br></br>

In other words, we find the **alignment score** between the query and each of our keys, and return a sum of the values weighted by the softmax of the **alignment scores**. This is of course the same type of mechanism we discussed in the previous section; we can identify each of these in our previous examples:
<br></br>

* RNN encoder-decoder with attention:
    * Find the alignment between the encoder hidden states and the decoder hidden states

    * Get a weighted sum of our encoder hidden states

    * Both the keys and the values are the encoder hidden states, and the decoder hidden states are the queries

* Graph attention networks: 
    * The weight vector $\mathbf{a}$ and the LeakyReLU function return the alignment scores between the projected embedding of the node of interest and each of its neighbouring nodes (and itself)

    * The weighted sum is of the projected node embeddings for each neighbouring node 

    * Both the keys and values are the projected node embeddings for neighbouring (and self) nodes, and the query is the embedding for the node of interest


To illustrate these in a more practical sense, consider searching for a paper in a database:

* We provide some search term that describes what we are looking for, which is the query

* The query is compared to the titles and metadata of the papers in the database, which are the keys

* The papers corresponding to the keys with the best match to the query (the highest attention weights) are returned, ranked by the attention weights; the papers are the values

While in this example the queries, keys, and values are different, they do not have to be.
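To make the analogy concrete, here is a minimal sketch with made-up numbers comparing a hard, database-style lookup, which returns only the value whose key best matches the query, with the soft lookup performed by attention, which returns a weighted combination of all values. Dot products stand in for the alignment model.

```python
import torch

keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [0.5, 0.5]])
values = torch.tensor([10.0, 20.0, 30.0])
query = torch.tensor([1.0, 0.0])

scores = keys @ query  # alignment score between the query and each key

# Hard lookup: return the single value whose key matches best
hard = values[torch.argmax(scores)]

# Soft (attention) lookup: every value contributes, weighted by softmax(scores)
weights = torch.softmax(scores, dim=0)
soft = weights @ values

print(hard)     # tensor(10.)
print(weights)  # largest weight on the first key, but the others still contribute
print(soft)     # pulled towards 10 but mixed with 20 and 30
```

Because the soft lookup is differentiable in the queries and keys, a model can learn what to attend to by gradient descent, which a hard `argmax` lookup would not allow.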

## The transformer architecture

The figure below shows the transformer architecture as illustrated in the original paper:

<center>
<img src="transformer-architecture.png" width=600></img>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 580px; font-style: italic; line-height: 1.2;">

*The architecture of the original transformer model, taken from the [original paper](https://arxiv.org/abs/1706.03762).*
</div></div></div>

Currently this may seem a bit incomprehensible, but we will break each part down in turn and then see how it all fits together.

While a transformer is built from an encoder and a decoder like previous RNN models, three key things set it apart from earlier seq2seq models (apart from the use of linear layers in place of recurrent or convolutional layers). These are:

* The choice of attention function: **scaled dot-product attention**

* **Multi-head attention**

* **Self attention**

It is the combination of these things, and where they are used in the model, that enabled the jump in performance seen using this model. 



## Scaled dot-product attention

While we briefly mentioned two attention functions earlier (Bahdanau and Luong), transformers use **scaled dot-product attention**. It is computationally efficient, since it reduces to matrix multiplications that are highly optimised on modern hardware, and it performs comparably to other attention mechanisms in the literature. It is defined as follows:

* Pack a set of queries into a single matrix $Q$

* Similarly pack the keys and values into matrices $K$ and $V$

* We denote the dimension of the queries and keys as $d_k$, and the dimension of the values as $d_v$

* The attention output is then given according to

$$\text{Attention(}Q, K, V\text{)} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

The normalising factor $\frac{1}{\sqrt{d_k}}$ is important as when we have large $d_k$ the dot product can reach very large values, and thus push the softmax output into regions with very small gradients. This can subsequently cause problems with training (similar to the exploding and vanishing gradient problems we have discussed before). This is discussed in a little more detail in the [original paper](https://arxiv.org/abs/1706.03762).
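The formula translates almost line for line into PyTorch. The sketch below uses random inputs with arbitrary sizes; the function name mirrors the formula but is our own. If you have PyTorch 2.0 or later, `torch.nn.functional.scaled_dot_product_attention` implements the same computation and is used as a cross-check. The last lines illustrate the scaling argument: for query and key entries with unit variance, $q \cdot k$ has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V.

    Q: (..., n_q, d_k), K: (..., n_k, d_k), V: (..., n_k, d_v) -> (..., n_q, d_v)
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., n_q, n_k)
    weights = torch.softmax(scores, dim=-1)             # each query's weights sum to 1
    return weights @ V

torch.manual_seed(0)
Q = torch.randn(2, 3, 64)  # batch of 2, 3 queries each, d_k = 64
K = torch.randn(2, 5, 64)  # 5 keys
V = torch.randn(2, 5, 32)  # 5 values, d_v = 32
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 3, 32])

# Built-in version (PyTorch 2.0+), with a dummy head dimension added and removed
builtin = F.scaled_dot_product_attention(Q[:, None], K[:, None], V[:, None])[:, 0]
print(torch.allclose(out, builtin, atol=1e-5))

# Why scale? For unit-variance entries, q . k has variance d
d = 512
q, k = torch.randn(10_000, d), torch.randn(10_000, d)
dots = (q * k).sum(dim=-1)
print(dots.var().item())                   # close to 512
print((dots / math.sqrt(d)).var().item())  # close to 1
```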

In fact, this attention function is very similar to the Luong attention mechanism, but with the extra normalising factor and without an explicit weight matrix in the dot product. The weight projection is instead applied to the inputs before they enter scaled dot-product attention, rather than being included in the attention function itself.



## Multi-head attention

So far, we have considered a single attention mechanism in a model. What would happen if we used several at once, how would we go about it, and would it even be useful?

In fact, the original transformers paper proposed the idea of **multi-head attention**, a way of using multiple attention functions in parallel to incorporate information from different representations simultaneously. In the original paper, this goes as follows:

* Typical attention mechanisms would use a single attention function for keys, values and queries with $d_\text{model}$ dimensions

* Instead, do $h$ independent linear projections of the queries, keys, and values, to $d_k$, $d_k$, and $d_v$ dimensions respectively - we refer to each of these sets of projections as an **attention head**

* For each attention head, apply the attention function in parallel to produce a $d_v$-dimensional output for each head

* Concatenate the outputs from all attention heads and finally project them to the desired number of dimensions $d_\text{model}$

We can then learn the parameters of the linear projections for each attention head, allowing us to use not just one representation of the keys, queries, and values, but as many as we may want to. This allows us to incorporate more information than we might otherwise, including different ways of looking at the information. 

In an NLP task, this could be thought of as considering different possible meanings of a single word and then seeing how important other words are to understanding the sentence for each possible meaning of that word. 

**Note**: while these linear projections operate in the same way as a linear layer in a neural network, they are explicitly *not* followed by an activation function; instead, their outputs pass directly into the attention mechanism. However, the attention mechanism is itself nonlinear (through the softmax), so it plays a role somewhat like an activation function here.

<div style="background-color: #FFF8C6">

We have seen the term "heads" before in machine learning, specifically when we talked about graph neural networks; we had different prediction heads for different graph tasks.

In fact, in general a head in machine learning refers to applying a different set of operations to the same base data, to reach different outputs. For example:

* In graph neural networks, if we wanted e.g. to classify nodes and edges simultaneously:

    * We would have a main GNN that finds node embeddings

    * We can then feed the node embeddings from the main GNN into two different sets of transformations to get the node and edge classifications

    * This is having multiple prediction heads
<br></br>
* In multi-head attention:

    * Each head is a different linear projection of the input queries, keys, and values

    * Each head learns a different representation of the queries, keys, and values, allowing us to learn different relations in our data simultaneously



The multi-head attention mechanism used in the original transformers paper is illustrated in the figure below.

<center>
<img src='mha-schematic.png' width=600></img>
</center>
<div style="text-align:center;">


<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 700px; font-style: italic; line-height: 1.2;">

*The structure of multi-head attention used in "Attention is All You Need". For a number of heads $h$, the input queries, keys, and values are projected $h$ times and the scaled dot-product attention is calculated for each projection. The attention outputs are then concatenated and projected once more to the desired dimensionality, producing the multi-head attention output. Adapted from the [original paper](https://arxiv.org/abs/1706.03762).*
</div></div></div>

While in principle the queries/keys and the values can have different dimensions, in practice we generally set $d_k = d_v$. The original paper uses $h = 8$ heads with $d_\text{model} = 512$ and $d_k = d_v = d_\text{model}/h = 64$.

<div style="background-color:#FFCCCB">

If we want to write out a mathematical expression for multi-head attention, it goes as follows:

$$\text{MultiHead}(Q,\,K,\,V) = \text{Concat}(\text{head}_1,\,\dots,\,\text{head}_h)\,\mathbf{W}^O,$$
$$\text{where head}_i = \text{Attention}\left(Q\,\mathbf{W}^Q_i, \,\, K\,\mathbf{W}^K_i, \,\, V\,\mathbf{W}^V_i\right).$$

Individual symbols are defined as follows:

* $\mathbf{W}^Q_i$ denotes the queries linear projection parameter matrix for attention head $i$, which is a $d_\text{model} \times d_k$ matrix

* $\mathbf{W}^K_i$ denotes the keys linear projection parameter matrix for attention head $i$, which is a $d_\text{model} \times d_k$ matrix

* $\mathbf{W}^V_i$ denotes the values linear projection parameter matrix for attention head $i$, which is a $d_\text{model} \times d_v$ matrix

* $\text{Attention}$ denotes an arbitrary attention function of queries, keys, and values; in the case of transformers, this is the scaled dot-product attention described before

* $\text{Concat}$ denotes concatenation of the head outputs along the feature dimension, so each query position gets a single $h d_v$-dimensional vector

* $\mathbf{W}^O$ denotes the linear projection matrix from the concatenated attention head outputs to the final model dimension, which is an $h d_v \times d_\text{model}$ matrix

Because each attention head is independent, any calculations across different heads can be parallelised. This can help greatly with the computational cost of training multi-head attention.
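A minimal from-scratch sketch of these equations in PyTorch follows. It assumes $d_k = d_v = d_\text{model}/h$ and packs the $h$ per-head projections $\mathbf{W}^Q_i$ (and likewise for keys and values) into a single `nn.Linear`, whose output is split into $h$ chunks; this is mathematically the same as $h$ separate projections, and lets all heads run in parallel. `nn.Linear` also adds a bias term, which the formula above omits. The class name `MultiHeadAttention` is our own; the final lines cross-check it against PyTorch's built-in `nn.MultiheadAttention` by copying weights across.

```python
import math
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """From-scratch multi-head attention with d_k = d_v = d_model // num_heads."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.h = num_heads
        self.d_k = d_model // num_heads
        # Each d_model -> d_model layer holds all h per-head projections side by side
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # W^O: (h * d_v) -> d_model

    def _split_heads(self, x):
        # (batch, seq, d_model) -> (batch, h, seq, d_k)
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.h, self.d_k).transpose(1, 2)

    def forward(self, query, key, value):
        q = self._split_heads(self.w_q(query))
        k = self._split_heads(self.w_k(key))
        v = self._split_heads(self.w_v(value))
        # Scaled dot-product attention for all heads in parallel
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        heads = torch.softmax(scores, dim=-1) @ v  # (batch, h, seq_q, d_k)
        # Concat: (batch, h, seq_q, d_k) -> (batch, seq_q, h * d_k)
        batch, _, seq_q, _ = heads.shape
        concat = heads.transpose(1, 2).reshape(batch, seq_q, self.h * self.d_k)
        return self.w_o(concat)


torch.manual_seed(0)
d_model, num_heads = 16, 4
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(2, 5, d_model)  # (batch, seq, d_model)
out = mha(x, x, x)
print(out.shape)  # torch.Size([2, 5, 16])

# Cross-check against PyTorch's built-in layer by copying our weights into it
ref = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
with torch.no_grad():
    ref.in_proj_weight.copy_(torch.cat([mha.w_q.weight, mha.w_k.weight, mha.w_v.weight]))
    ref.in_proj_bias.copy_(torch.cat([mha.w_q.bias, mha.w_k.bias, mha.w_v.bias]))
    ref.out_proj.weight.copy_(mha.w_o.weight)
    ref.out_proj.bias.copy_(mha.w_o.bias)
ref_out, _ = ref(x, x, x)
print(torch.allclose(out, ref_out, atol=1e-5))
```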


## Self attention

So far, we have discussed attention between the output sequence and the input sequence. However, there is no reason why we couldn't find attention between any two sequences; in fact, we can consider so-called **self attention**, where we find attention weights between each element in the input sequence and all elements in the same sequence.

This way, we can find what words in the input sequence are most important to understanding each word in the same sequence. 

This is illustrated in the schematic below.

<center>
<img src='self-attn-schematic.png' width=800>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 700px; font-style: italic; line-height: 1.2;">

*Schematic illustrating self-attention, for a single query. The input sequence acts as the queries, keys, and values to calculate the output.*
</div></div></div>

In fact, we have seen self-attention already in the Graph Attention Networks example - each node attends to all of the nodes in its neighbourhood, and itself, to learn which nodes are most important for finding an embedding for a given node. 

Note that we don't pass the same sequence in directly as the queries, keys, and values. Instead, we pass **different** learned linear projections of the same sequence. The attention function then scores every key against each query, and the output for that query is the weighted sum of the corresponding values.
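Here is a minimal single-head sketch of this. It assumes random data, bias-free projections, and illustrative sizes. The same tensor `x` passes through three different learned projections, and each row of the resulting weight matrix sums to one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, L, d_model = 2, 6, 16           # batch size, sequence length, embedding size (illustrative)
x = torch.randn(N, L, d_model)     # one input sequence per batch element

# Three *separate* learned projections of the same sequence
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
q, k, v = W_q(x), W_k(x), W_v(x)

# Attention weights: one row per query position, one column per key position
scores = q @ k.transpose(-2, -1) / d_model ** 0.5   # (N, L, L)
weights = scores.softmax(dim=-1)
out = weights @ v                                    # (N, L, d_model)

print(weights.shape)  # torch.Size([2, 6, 6])
```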

Similar ideas had previously been introduced for RNN models. The transformer paper, however, was the first to build a sequence-to-sequence model entirely around highly parallelisable self-attention, without any recurrence, and this approach has been a major success for modern models.

## Applications of attention in the model

The transformer architecture uses attention in several ways. These are highlighted in the transformer architecture diagram below:

<center>
<img src='transformer-mha-highlight.png' width=800/>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 700px; font-style: italic; line-height: 1.2;">

*The three applications of attention mechanisms in the transformer architecture. **(a)**: encoder self-attention, **(b)**: encoder-decoder attention (sometimes called **cross-attention**), **(c)**: decoder (masked) self-attention*
</div></div></div>


These are described as follows:

* **Encoder self-attention**: each position in an encoder layer's output attends to all positions in the output of the previous encoder layer (for the first encoder layer, this is the embedded input sequence)

* **Encoder-decoder attention**: just like in the RNN attention models, this allows each position in the decoder to attend to all positions in the encoder output sequence, i.e. the decoder attends over the whole encoded input sequence. In the context of transformers, this is sometimes referred to as **cross-attention**

* **Decoder self-attention**: same principle as for the encoder, but with an additional constraint. Each element in the output sequence must only be determined by earlier elements in the sequence (and itself), not later ones. This is done by masking keys that are not allowed for a given query, i.e. masking out all positions later in the sequence than the query
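Here is a small illustration, with random scores standing in for real query-key scores. The mask allows query position $i$ to attend only to key positions $j \le i$. Setting the disallowed scores to $-\infty$ before the softmax gives them exactly zero weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 4                                  # sequence length (illustrative)
scores = torch.randn(L, L)             # stand-in query-key scores (rows: queries, cols: keys)

# causal_mask[i, j] is True when key position j is at or before query position i
causal_mask = torch.ones(L, L, dtype=torch.bool).tril()

weights = scores.masked_fill(~causal_mask, float('-inf')).softmax(dim=-1)
print(weights)  # upper triangle is exactly zero; row 0 puts all its weight on position 0

# PyTorch's is_causal=True applies the same lower-triangular mask
x = torch.randn(1, L, 8)
print(torch.allclose(F.scaled_dot_product_attention(x, x, x, is_causal=True),
                     F.scaled_dot_product_attention(x, x, x, attn_mask=causal_mask),
                     atol=1e-5))  # True
```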


## Layer normalisation

<center>
<img src='transformer-layer-norm-highlight.png' width=400/>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 450px; font-style: italic; line-height: 1.2;">

*Layer normalisation in the transformer architecture. After each sub-layer of the encoder or decoder, the outputs are summed with the inputs and passed through a layer normalisation to standardise the distribution of features across the layer.*
</div></div></div>


Previously, we have discussed how the weights in one layer of a neural network are strongly dependent on the outputs of the neurons in the previous layer, and how we can handle this dependency with methods such as **batch normalisation**. 

The transformer architecture uses a related but different normalisation approach, called **layer normalisation**. Rather than normalising activations across the batch, this approach normalises activations across the layer, i.e. based on the statistics *within the layer*. To compare:

* Batch normalisation: computes correction factors (a mean and standard deviation) that are common to all samples in the batch but different for each hidden neuron

* Layer normalisation: computes correction factors that are common to all hidden neurons in the layer but different for each sample
<br></br>

As a result, layer normalisation can be applied regardless of batch size. This has some particular benefits for sequence problems:

* We generally want to treat each sequence of tokens independently of the others, so that we can process e.g. one sentence at a time

* Batch normalisation computes its statistics across the different sequences in a batch, separately for each token position. This causes problems for test sequences longer than the training sequences, as there are no statistics for the extra positions

* Layer normalisation computes its statistics over the features of each token individually, so it works the same irrespective of batch size or sequence length
<br></br>
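The sketch below contrasts the two approaches on random data. `nn.LayerNorm` gives the same output for a sequence whether or not the rest of the batch is present, while `nn.BatchNorm1d` in training mode does not. For simplicity, this `nn.BatchNorm1d` shares its statistics across all sequence positions rather than computing them per token position. The point is only that its output for one sequence depends on the other sequences in the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, L, d_model = 8, 5, 16
x = torch.randn(N, L, d_model)          # (batch, seq_len, features)

# LayerNorm: statistics over the feature dimension of each token separately
ln = nn.LayerNorm(d_model)
ln_full = ln(x)
ln_single = ln(x[:1])                   # the first sequence on its own
print(torch.allclose(ln_full[:1], ln_single, atol=1e-6))  # True: unaffected by the rest of the batch

# BatchNorm1d expects (N, C, L) and, in training mode, uses statistics from the whole batch
bn = nn.BatchNorm1d(d_model)
bn_full = bn(x.transpose(1, 2)).transpose(1, 2)
bn_single = bn(x[:1].transpose(1, 2)).transpose(1, 2)
print(torch.allclose(bn_full[:1], bn_single))  # False: depends on the other sequences
```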

<div style="background-color:#FFCCCB">

To express layer normalisation mathematically, start by defining the following:

* $a^l$: the input to the $l$-th hidden layer of a deep feed-forward neural network, referred to as the **activation** of the previous layer
* $H$: the number of hidden units in a layer
* $\gamma^l$ and $\beta^l$: learnable parameters for layer $l$, with the same dimensions as the desired output shape

Now we can define the mean and standard deviation in hidden layer $l$ according to

$$\mu^l = \frac{1}{H}\sum_{i = 1}^H a^l_i,\qquad\quad \sigma^l = \sqrt{\frac{1}{H}\sum_{i = 1}^H(a^l_i - \mu^l)^2},$$

where $a^l_i$ denotes the $i$-th element of $a^l$, and $\mu^l$ and $\sigma^l$ denote the layer mean and standard deviation respectively.

Finally, we can write the layer normalisation output according to

$$\text{LayerNorm}(a^l) = \bar{a}^l = \frac{a^l - \mu^l}{\sigma^l}\cdot\gamma^l + \beta^l$$

Note that all multiplication is element-wise, such that $\bar{a}^l_i = \frac{a^l_i - \mu^l}{\sigma^l} \cdot \gamma^l_i + \beta^l_i$.
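The formulas above can be checked against `nn.LayerNorm` with a short sketch on random activations. One assumption is needed: `nn.LayerNorm` adds a small constant `eps` to the variance for numerical stability, which the formula above omits, so the manual version includes it too. $\gamma^l$ and $\beta^l$ are set to the values `nn.LayerNorm` uses at initialisation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 16                                   # number of hidden units
a = torch.randn(4, H)                    # 4 samples, each an activation vector a^l
eps = 1e-5                               # stability constant used by nn.LayerNorm (not in the formula above)

mu = a.mean(dim=-1, keepdim=True)                        # mu^l: one mean per sample
var = a.var(dim=-1, unbiased=False, keepdim=True)        # (sigma^l)^2, dividing by H
gamma = torch.ones(H)                                    # gamma^l at nn.LayerNorm's initial value
beta = torch.zeros(H)                                    # beta^l at nn.LayerNorm's initial value
manual = (a - mu) / torch.sqrt(var + eps) * gamma + beta # element-wise, as in the formula

ln = nn.LayerNorm(H, eps=eps)
print(torch.allclose(manual, ln(a), atol=1e-5))  # True
```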

## Transformer encoder and decoder

Now we will consider how the encoder and decoder are constructed in the transformer architecture:

<center>
<img src='transformer-enc-dec.png' width=400/>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 400px; font-style: italic; line-height: 1.2;">

*The transformer encoder and decoder structure, from the original architecture diagram.*
</div></div></div>

Both the transformer encoder and decoder are built from a stack of $N = 6$ identical layers, each made up of sub-layers. The sub-layers differ slightly between encoder and decoder layers. These are structured as follows:

**Encoder layer**: 
    
* Contains two sub-layers:

    1. multi-head self-attention

    1. two-layer position-wise feed-forward (i.e. non-recurrent) neural network with a ReLU activation function and hidden dimension $d_{ff}$, applied to each sequence position separately and identically
<br></br>
* Sums output of each sub-layer with the input to that sub-layer i.e. includes a **residual connection** (like skip connections we saw in GNNs)
<br></br>
* Passes the sum of the input and the sub-layer output to layer normalisation, such that the final sub-layer output is given as $\text{LayerNorm}(x + \text{Sublayer}(x))$, where $\text{Sublayer}(x)$ denotes the function of the sub-layer itself

<center>
<img src='enc-layer-schematic.png' height=230>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 700px; font-style: italic; line-height: 1.2;">

*Schematic of a single transformer encoder layer. This is built of two sub-layers: multi-head self-attention and a feed-forward neural network. Residual connections are included around each sub-layer and layer normalisation is applied after each sub-layer. Adapted from the [original paper](https://arxiv.org/abs/1706.03762).*
</div></div></div>
<!-- encoder layer schematic -->
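The pattern $\text{LayerNorm}(x + \text{Sublayer}(x))$ can be written once and wrapped around any sub-layer. Below is a minimal sketch. The `PostLNResidual` class is a hypothetical helper, a plain `nn.Linear` stands in for the real attention or feed-forward sub-layer, and dropout is omitted.

```python
import torch
import torch.nn as nn

class PostLNResidual(nn.Module):
    """Hypothetical helper computing LayerNorm(x + Sublayer(x)) around any sub-layer."""
    def __init__(self, sublayer: nn.Module, d_model: int):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual sum requires Sublayer(x) to have the same shape as x
        return self.norm(x + self.sublayer(x))

torch.manual_seed(0)
d_model = 512
block = PostLNResidual(nn.Linear(d_model, d_model), d_model)  # stand-in shape-preserving sub-layer
x = torch.randn(2, 10, d_model)                               # (batch, seq_len, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 512])
```

The residual connection only works because every sub-layer preserves the $d_\text{model}$ dimension, which is why the whole stack keeps the same width.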

**Decoder layer**:

* Contains three sub-layers, two similar to the encoder layers:

    1. masked multi-head self-attention, to prevent earlier sequence entries attending to later sequence entries

    1. multi-head attention over encoder output, i.e. the encoder output as the keys and values and decoder self-attention output as the queries

    1. two-layer feed-forward neural network with a ReLU activation function and hidden dimension $d_{ff}$
<br></br>
* Like the encoder layers, uses residual connections around each sub-layer followed by layer normalisation
<br></br>
<center>
<img src='dec-layer-schematic.png' height=250>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 850px; font-style: italic; line-height: 1.2;">

*Schematic of a single transformer decoder layer. This is built of three sub-layers: masked multi-head self-attention, multi-head attention over the encoder outputs, and a feed-forward neural network. Residual connections are included around each sub-layer and layer normalisation is applied after each sub-layer. Adapted from the [original paper](https://arxiv.org/abs/1706.03762).*
</div></div></div>

For both encoder and decoder layers in the original (base) model, the hidden dimension of the feed-forward network is $d_{ff} = 2048$, four times the model dimension $d_\text{model} = 512$.

The original paper also applies **dropout** with probability $p = 0.1$ to the output of each sub-layer, before it is added to the sub-layer input and normalised. Dropout is also applied to the sums of the embeddings and positional encodings. This regularises training and reduces overfitting.
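The feed-forward sub-layer and the dropout placement described above can be sketched as follows. The class name `PositionwiseFFN` and the random input are illustrative assumptions. The formula $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$ and the sizes $d_\text{model} = 512$, $d_{ff} = 2048$ follow the original paper.

```python
import torch
import torch.nn as nn

class PositionwiseFFN(nn.Module):
    """Sketch of FFN(x) = max(0, x W1 + b1) W2 + b2 (hypothetical class name)."""
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # W1, b1
        self.linear2 = nn.Linear(d_ff, d_model)   # W2, b2

    def forward(self, x):
        # nn.Linear acts on the last dimension, so every position is transformed independently
        return self.linear2(torch.relu(self.linear1(x)))

torch.manual_seed(0)
d_model = 512
ffn = PositionwiseFFN(d_model)
dropout = nn.Dropout(p=0.1)
norm = nn.LayerNorm(d_model)

x = torch.randn(2, 10, d_model)          # (batch, seq_len, d_model)
# Residual dropout: drop the sub-layer output, then add the input and normalise
y = norm(x + dropout(ffn(x)))

# Position-wise: position 3 gives the same result whether or not other positions are present
print(torch.allclose(ffn(x)[:, 3], ffn(x[:, 3]), atol=1e-5))  # True
```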

## Positional encoding

The transformer has no recurrence or convolution, and self-attention treats its input as an unordered set: reordering the input tokens just reorders the outputs in the same way. As a result, the model has no information about the order of the sequence unless we do something about it, i.e. manually add some information about the relative or absolute positions of tokens in the sequence. This is done with **positional encodings**.
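To see why order is lost, the sketch below applies self-attention to a random sequence and to a shuffled copy of it. It simplifies by using $x$ directly as the queries, keys, and values, a single-head assumption. This does not change the conclusion, since the learned projections also act on each position independently.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, d = 5, 8
x = torch.randn(1, L, d)       # one sequence of L token embeddings
perm = torch.randperm(L)       # a random reordering of the tokens

def self_attention(x):
    # Simplified single-head self-attention with x used directly as queries, keys, and values
    return F.scaled_dot_product_attention(x, x, x)

out = self_attention(x)
out_perm = self_attention(x[:, perm])
# Shuffling the input tokens only shuffles the outputs in the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```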

<center>
<img src='transformer-pos-enc-highlight.png' width=480/>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 460px; font-style: italic; line-height: 1.2;">

*Positional encoding in the transformer architecture. These are added to the input embeddings to preserve information about the position of each token in the sequence, which would otherwise be lost because there are no recurrent layers.*
</div></div></div>

Positional encodings are added to the input embeddings of both the encoder and the decoder, so they have the same dimension, $d_\text{model}$, as the embeddings.

For the transformer model, each dimension of the positional encoding is a $\sin$ or $\cos$ of the token position. The wavelength of the sinusoid is determined by the index of the dimension and the total number of dimensions, and the wavelengths form a geometric progression from $2\pi$ to $10000\cdot 2\pi$. Shifting the position by a fixed offset corresponds to a fixed linear transformation of these sinusoids, so the model can learn relative positions of tokens in the sequence.

We can think of this as adding back in explicit locality. Our sequence data has a defined order, but self-attention and the feed-forward networks treat every position in the same way, so the transformer layers would otherwise lose the information about order. Adding positional encodings effectively adds locality back into our data, so the model knows which tokens are close to which other tokens.

<div style="background-color:#FFCCCB">


Explicitly, the positional encodings from the original paper are given as

\begin{align*}
\text{PE}_{(\text{pos},\,i)} &= \sin\left(\text{pos}/10000^{i/d_{\text{model}}}\right)\qquad &i \text{ even}\\
\text{PE}_{(\text{pos},\,i)} &= \cos\left(\text{pos}/10000^{(i-1)/d_{\text{model}}}\right)\qquad &i \text{ odd}
\end{align*}
where $\text{pos}$ denotes the position of the token, $i$ denotes the index of the embedding dimension (from $0$ to $d_\text{model} - 1$), and $d_{\text{model}}$ is the number of dimensions of the embeddings.
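Here is a minimal sketch of these formulas in PyTorch. It assumes $d_\text{model}$ is even, and the function name and random embeddings are illustrative. It uses the fact that dimension $i$ and dimension $i+1$ (for even $i$) share the same frequency $1/10000^{i/d_\text{model}}$.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Returns a (max_len, d_model) tensor of sinusoidal encodings (d_model assumed even)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    even_i = torch.arange(0, d_model, 2, dtype=torch.float32)        # i = 0, 2, 4, ...
    freq = torch.exp(-math.log(10000.0) * even_i / d_model)          # 1 / 10000^(i / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * freq)   # even i: sin(pos / 10000^(i / d_model))
    pe[:, 1::2] = torch.cos(pos * freq)   # odd i: cos(pos / 10000^((i - 1) / d_model))
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=512)
embeddings = torch.randn(2, 50, 512)      # random stand-in for (batch, seq_len, d_model) embeddings
x = embeddings + pe                       # pe broadcasts over the batch dimension
```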

The reason for proposing this positional encoding is that, for any fixed offset $k$, $\text{PE}_{\text{pos} + k}$ is a linear function of $\text{PE}_{\text{pos}}$. For even $i$, $\text{PE}_{(\text{pos}+k,\,i)}$ is a linear combination of $\text{PE}_{(\text{pos},\,i)}$ and $\text{PE}_{(\text{pos},\,i+1)}$, and the same holds for $\text{PE}_{(\text{pos}+k,\,i+1)}$, with coefficients that depend only on $k$. This can allow the model to learn relative positions of tokens in the sequence.
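To make this linear relationship explicit, write $\omega_i = 1/10000^{i/d_\text{model}}$ for even $i$ and apply the angle-addition identities:

\begin{align*}
\text{PE}_{(\text{pos}+k,\,i)} = \sin(\omega_i(\text{pos}+k)) &= \sin(\omega_i\,\text{pos})\cos(\omega_i k) + \cos(\omega_i\,\text{pos})\sin(\omega_i k)\\
\text{PE}_{(\text{pos}+k,\,i+1)} = \cos(\omega_i(\text{pos}+k)) &= \cos(\omega_i\,\text{pos})\cos(\omega_i k) - \sin(\omega_i\,\text{pos})\sin(\omega_i k)
\end{align*}

In matrix form, the $2\times 2$ matrix is a rotation that depends only on the offset $k$, not on $\text{pos}$:

$$
\begin{pmatrix} \text{PE}_{(\text{pos}+k,\,i)} \\ \text{PE}_{(\text{pos}+k,\,i+1)} \end{pmatrix}
=
\begin{pmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{pmatrix}
\begin{pmatrix} \text{PE}_{(\text{pos},\,i)} \\ \text{PE}_{(\text{pos},\,i+1)} \end{pmatrix}
$$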


## Putting it all together

Finally, we can look at the architecture diagram from the original paper:

<center>
<img src="transformer-architecture.png" width=600></img>
</center>
<div style="text-align:center">
<div style="display: flex; justify-content: center; gap: 80px; align-items: flex-start;">
<div style="margin-top: 10px; text-align: justify; max-width: 580px; font-style: italic; line-height: 1.2;">

*The architecture of the original transformer model, taken from the [original paper](https://arxiv.org/abs/1706.03762).*
</div></div></div>

Let's break down how each part of this model works. Starting with the encoder:

<div style="display: flex; justify-content: flex-start; gap: 80px; align-items: flex-start;">

<div style="display: flex; flex-direction: column; align-items: flex-start; width: 700px; margin-left: 0px;">
<div style="margin-top: 0px; text-align: justify; max-width: 700px; line-height: 1.2;">

* The input sequence is converted to a sequence of numerical token IDs, which are passed through learned embeddings to project each token to $d_\text{model}$ dimensions

* Position encodings are added to the input embeddings

* The embeddings are passed through $N = 6$ encoder layers, each containing multi-head self-attention, a feedforward neural network, and layer normalisation, to produce the encoder outputs

Note that unlike in RNNs, the entire input sequence is processed in one go.

</div>
</div>
<div style="display: flex; flex-direction: column; align-items: flex-start; width: 200px; margin-top: -60px;">
<img src="transformer-enc-zoom.png" width=200/>
</div>
</div>

<!-- <center>
<img src='transformer-enc-zoom.png' width = 300/>
</center> -->



<br></br>
For the decoder:

<div style="display: flex; justify-content: flex-center; gap: 80px; align-items: flex-start;">
<div style="display: flex; flex-direction: column; align-items: center; width: 700px; margin-left: 0px;">
<div style="margin-top: 0px; text-align: justify; max-width: 700px; line-height: 1.2;">

* The output sequence (more on this in a moment) is converted to a sequence of numerical token IDs, which are passed through learned embeddings to project each token to $d_\text{model}$ dimensions

* Position encodings are added to these embeddings

* Output sequence embeddings are passed through $N = 6$ decoder layers which contain the following operations, all followed by adding the residuals and applying layer normalisation:

    * Masked multi-head self-attention, so each sequence element only attends to earlier sequence elements

    * Multi-head attention over the encoder outputs as keys and values, and decoder embeddings as the queries (i.e. the cross-attention)

    * Feed-forward neural network, with the same structure as in the encoder (but with its own, separately learned weights)

</div>
</div>
<div style="display: flex; flex-direction: column; align-items: center; width: 200px; margin-top: -60px;">
<img src="transformer-dec-zoom.png" width=200/>
</div>

</div>

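After $N$ decoder layers, the output is passed through a final learned linear transformation, which produces a score for every token in the vocabulary, and then a softmax layer. The softmax outputs are non-negative and sum to one, so they are often interpreted as probabilities for the next token. They are not necessarily well-calibrated probabilities, however.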

<center>
<img src='transformer-output-zoom.png' width = 200/>
</center>
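Here is a minimal sketch of this final step. A random tensor stands in for the decoder output, and the vocabulary size is arbitrary. The linear layer maps each position to one score per vocabulary token, and the softmax turns those scores into a distribution over the vocabulary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab_size = 512, 1000                # vocab_size is an arbitrary illustrative value
decoder_out = torch.randn(2, 7, d_model)       # stand-in decoder output: (batch, target_len, d_model)

to_vocab = nn.Linear(d_model, vocab_size)      # final learned linear transformation
logits = to_vocab(decoder_out)                 # one score per vocabulary token, at every position
probs = logits.softmax(dim=-1)                 # non-negative and summing to 1 over the vocabulary

next_token = probs[:, -1].argmax(dim=-1)       # highest-scoring token after the last position
print(probs.shape, next_token.shape)           # torch.Size([2, 7, 1000]) torch.Size([2])
```

During training, `F.cross_entropy` is usually applied to `logits` directly rather than to `probs`, because it already includes the (log-)softmax.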

<br></br>

**Note**: the operation of the decoder is slightly different during training and prediction, as follows:

* During training, the decoder input is the true output sequence shifted right by one position, i.e. with the so-called **start-of-sequence token** (sometimes referred to as `<SOS>`) prepended at the start

    * The whole target sequence is passed through the decoder in one go, with masked self-attention ensuring each position can only see earlier tokens. The loss is computed between the predictions and the true sequence, often with loss functions like cross entropy

* During prediction, the decoder input initially contains only the start-of-sequence token

    * This sequence is passed through the decoder, and the token with the maximum value in the softmax output at the last position is assigned to be the next sequence entry $\hat{y}_1$

    * The updated sequence, containing the start-of-sequence token and our first predicted entry $\hat{y}_1$, is passed into the decoder to predict the next entry $\hat{y}_2$

    * This procedure is repeated until an end-of-sequence token is predicted or a maximum output length is reached
<br></br>

In training, the whole true target sequence is used as the decoder input at once (a technique known as **teacher forcing**), so any mistakes the model makes in early sequence entries do not cause errors in later sequence entries. In contrast, during prediction we don't know what the output sequence should be, so we must feed the model's own outputs back in as input to predict later entries in the output sequence.
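The two modes can be sketched using token IDs alone. The special-token IDs, the `greedy_decode` helper, and `toy_step` (a stand-in for a trained decoder plus output layer) are all hypothetical. Picking the highest-scoring token at each step is called **greedy decoding**, which is one of several possible decoding strategies.

```python
import torch

SOS, EOS = 1, 2   # hypothetical special-token IDs

# Training: the decoder input is the true target shifted right by one, with SOS at the front
target = torch.tensor([[5, 9, 4, EOS]])                                 # (batch=1, T)
decoder_input = torch.cat([torch.full((1, 1), SOS), target[:, :-1]], dim=1)
# decoder_input is [[SOS, 5, 9, 4]]: position t is trained to predict target[:, t]

def greedy_decode(step_fn, max_len=20):
    """step_fn: hypothetical callable mapping (1, t) token IDs to (1, t, vocab) scores."""
    seq = torch.tensor([[SOS]])
    for _ in range(max_len):
        scores = step_fn(seq)                                    # decoder run on the sequence so far
        next_tok = scores[:, -1].argmax(dim=-1, keepdim=True)    # best token at the last position
        seq = torch.cat([seq, next_tok], dim=1)                  # feed the prediction back in
        if next_tok.item() == EOS:
            break
    return seq

def toy_step(seq, vocab_size=10):
    # Stand-in "model": always scores (last token + 2) highest, or EOS past the vocabulary
    nxt = seq[0, -1].item() + 2
    nxt = nxt if nxt < vocab_size else EOS
    scores = torch.zeros(1, seq.shape[1], vocab_size)
    scores[0, -1, nxt] = 1.0
    return scores

print(greedy_decode(toy_step))  # tensor([[1, 3, 5, 7, 9, 2]])
```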


## Summary

In this section, we have covered the transformer architecture, including:

* design of the encoder and decoder, and the applications of attention

* layer normalisation

* positional encodings

* operation of the model during training and prediction

In the next section, we will discuss more practical elements of transformers including how to implement them in PyTorch, and considerations that are necessary during training.

<hr style="border:2px solid gray">

# Transformers in PyTorch <a id='pytorch-transformers'></a>

The layers in a transformer are built from standard operations (linear layers, softmax, layer normalisation, and so on), so we don't need any additional library to implement them and can just use `torch`.

PyTorch does have implementations of the original transformer architecture (and its components) available in `torch.nn`, including but not limited to:

* `Transformer`

* `TransformerEncoderLayer`

* `TransformerDecoderLayer`

* `MultiheadAttention`

However, the package maintainers specifically recommend not using these and instead implementing things yourself, through a new set of more flexible tools. You can read [this article](https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html) for more details on this, but we will **not** use the pre-defined transformer layers/architecture from `torch.nn`.

**Note**: everything we use to build the transformer will be from PyTorch, so all our data structures are Tensors that can store gradients. This is necessary so our model can train and propagate gradients back through, so we will **not** use any `numpy` functions (or anything else that isn't PyTorch) to ensure our model can train.

However, PyTorch does helpfully provide a nice efficient implementation of the attention mechanism from the original paper, as `scaled_dot_product_attention`. Let's explore this function in some detail, and then we'll try implementing some of the components of the transformer architecture ourselves.

We can find `scaled_dot_product_attention` in `torch.nn.functional`. First, let's do our starting imports:

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Now we can look at the function signature for `scaled_dot_product_attention`:

In [2]:
help(F.scaled_dot_product_attention)

Help on built-in function scaled_dot_product_attention in module torch._C._nn:

scaled_dot_product_attention(...)
    scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> Tensor:
    
    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
    specified as a keyword argument.
    
    .. code-block:: python
    
        # Efficient implementation equivalent to the following:
        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
            L, S = query.size(-2), key.size(-2)
            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
            attn_bias = torch.zeros(L, S, dtype=query.dtyp

This has a lot of information, so let's pick out the most important parts: the arguments for the function. These are as follows:

* `query`, `key`, and `value` : tensors containing the queries, keys, and values respectively

* `attn_mask` : an optional tensor specifying which query-key pairs are allowed to take part in the attention calculation. For a boolean mask, `True` means the pair takes part and `False` means it is masked out, i.e. **not** used to calculate the attention weight. A float mask is instead added to the attention scores

* `dropout_p` : probability of dropout, which is applied to the attention weights before finding the weighted sum of the values. Note: this **always** applies regardless of whether your model is in training or evaluation mode, so you need to make sure you set `dropout_p` to 0 when you are in evaluation mode.

* `is_causal` : whether the attention should be causal, i.e. whether each sequence element should be prevented from attending to later sequence elements. Setting this to `True` while also passing `attn_mask` will cause an error.

* `scale` : scaling factor applied prior to the softmax, with the default value of $\frac{1}{{\sqrt{d_k}}}$ (or $\frac{1}{\sqrt{E}}$, using the PyTorch documentation notation), where $d_k$ (or $E$) refers to the dimensionality of the keys and queries

There is an additional argument called `enable_gqa`, which enables an experimental feature called Grouped Query Attention to reduce memory usage and compute time. We won't worry about this here, especially as it is currently only supported by some of the underlying implementations.
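`dropout_p` is applied whenever it is non-zero, so one common pattern is to choose it from the module's `training` flag, which `.train()` and `.eval()` toggle. The wrapper class below is a hypothetical sketch of this pattern, not part of PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAWithDropout(nn.Module):
    """Hypothetical wrapper: only applies attention dropout in training mode."""
    def __init__(self, dropout_p=0.1):
        super().__init__()
        self.dropout_p = dropout_p

    def forward(self, q, k, v, attn_mask=None, is_causal=False):
        # self.training is True after .train() and False after .eval()
        p = self.dropout_p if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                              dropout_p=p, is_causal=is_causal)

torch.manual_seed(0)
q = k = v = torch.randn(2, 4, 10, 16)       # (N, h, L, d_k) with L = S and d_k = d_v
attn = SDPAWithDropout(0.1).eval()
print(torch.allclose(attn(q, k, v), attn(q, k, v)))  # True: no dropout in eval mode
```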

Let's look specifically at the main inputs we will specify: `query`, `key`, and `value`. These are tensors that take the following shapes:

\begin{align*}
\texttt{query}\text{ : shape} &= (N,\,h_q,\,L,\,d_k) \\
\texttt{key}\text{ : shape} &= (N,\,h,\,S,\,d_k) \\
\texttt{value}\text{ : shape} &= (N,\,h,\,S,\,d_v),
\end{align*}
where each symbol is defined as follows:

* $N$ : batch size

* $h_q$, $h$ : number of heads for queries and keys/values respectively; generally, these are equal (they only differ for Grouped Query Attention, which is beyond the scope of this notebook)

* $L$ : length of the target sequence

* $S$ : length of the source sequence

* $d_k$ : dimensionality of the embedded queries and keys, referred to as $E$ in the PyTorch documentation

* $d_v$ : dimensionality of the embedded values, referred to as $E_v$ in the PyTorch documentation

Here we have mostly followed the notation from the PyTorch documentation, but we use $d_k$ and $d_v$ in place of $E$ and $E_v$.
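To make these shape conventions concrete, here is a minimal reference version of the computation. It assumes a boolean mask and no dropout, and it is a sketch for checking shapes rather than how PyTorch's fused kernels work. $L \neq S$ and $d_k \neq d_v$ are chosen deliberately, so each dimension of the output can be traced back to its input.

```python
import math
import torch
import torch.nn.functional as F

def manual_sdpa(query, key, value, attn_mask=None):
    """Unoptimised reference for softmax(Q K^T / sqrt(d_k)) V, supporting boolean masks only."""
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)     # (N, h, L, S)
    if attn_mask is not None:
        scores = scores.masked_fill(~attn_mask, float('-inf'))   # False: pair is masked out
    weights = scores.softmax(dim=-1)                             # normalise over the S keys
    return weights @ value                                       # (N, h, L, d_v)

torch.manual_seed(0)
N, h, L, S, d_k, d_v = 2, 4, 3, 5, 16, 8    # deliberately L != S and d_k != d_v
q = torch.randn(N, h, L, d_k)
k = torch.randn(N, h, S, d_k)
v = torch.randn(N, h, S, d_v)

out = manual_sdpa(q, k, v)
print(out.shape)  # torch.Size([2, 4, 3, 8])
print(torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-5))  # True
```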

Let's now try using this, to see what the output looks like. We'll start by specifying our embedding dimensions, source and target sequence lengths, batch size, and number of heads, and generating some random data for queries, keys, and values:

In [3]:
batch_size = 32 # 32 sequences in a batch
h = h_q = 4 # 4 attention heads
S = L = 10 # source and target sequence are both length 10
d_k = d_v = 16 # same embedding dimensionality for queries/keys and values, 16-dim vectors
torch.manual_seed(0) # fix random seed for value generation

query = torch.rand(batch_size, h_q, L, d_k)
key = torch.rand(batch_size, h, S, d_k)
value = torch.rand(batch_size, h, S, d_v)

Now that we've defined our data and our parameters, let's try calculating the scaled dot product attention and look at the output shape:

In [4]:
sdpa = F.scaled_dot_product_attention(query, key, value)
print(sdpa.shape)

torch.Size([32, 4, 10, 16])


We can see our attention output has shape `[32, 4, 10, 16]`. Because $S = L$ and $d_k = d_v$ here, though, we can't tell which input each of these dimensions comes from. What happens if these are not equal?

In principle we could also have $h \neq h_q$, but that is the subject of the `enable_gqa` keyword, which is beyond the scope of this course.

Let's consider a simple text summarisation problem, where we want to summarise a 10 token sequence with 1 output token, i.e. $S = 10$ and $L = 1$. We'll also set the value embedding dimensionality to 8 instead of 16. 

In [5]:
batch_size = 32
h = h_q = 4
S = 10
L = 1
d_k = 16
d_v = 8
torch.manual_seed(0)

query = torch.rand(batch_size, h_q, L, d_k)
key = torch.rand(batch_size, h, S, d_k)
value = torch.rand(batch_size, h, S, d_v)
sdpa = F.scaled_dot_product_attention(query, key, value)
print(f'Queries shape: {query.shape}\n')
print(f'Keys shape: {key.shape}\n')
print(f'Values shape: {value.shape}\n')
print(f'SDPA output shape: {sdpa.shape}\n')

Queries shape: torch.Size([32, 4, 1, 16])

Keys shape: torch.Size([32, 4, 10, 16])

Values shape: torch.Size([32, 4, 10, 8])

SDPA output shape: torch.Size([32, 4, 1, 8])



Now we can properly see the shape of the attention output: $(N, h, L, d_v)$ i.e. batch size and number of heads are the same as all the inputs, it has the sequence length from the queries, and the embedding dimensionality of the values. 

Note: this function is written with multi-head attention in mind. It can take inputs from every head simultaneously, so we don't need to run it separately for each attention head.

## Implementing the transformer

We will now step through implementing the transformer, piece by piece. You will get to fill in a lot of the code yourself, using what you have learnt on this course so far.

### Multi-head attention in PyTorch

Firstly, let's step through calculating an example multihead attention output for a given set of queries, keys, and values. We start by defining our tensors and the dimensionality of our model; we'll use the same model dimensions as in the original paper, and consider we want to look at 10-length source and target sequences.

Note: in line with existing implementations, we will allow queries, keys, and values to have **different** dimensions when we pass them into the multi-head attention block, and we will project each of these to the same embedding dimension. This means we are assuming $d_k = d_v$, which we previously noted is the most common choice in practice. We will denote the input dimensions as $E_q$, $E_k$, and $E_v$ respectively.

Also, because we are projecting to $h$ different heads, we will project each tensor to $d_\text{model}$ and assume each head has a dimension of $d_\text{head} = d_\text{model}/h$. We then can say that $d_k = d_v = d_\text{head}$.

In [6]:
torch.manual_seed(0)

E_q = E_k = E_v = 512 # input query, key, value embedding dimensions
d_model = 512 # model dimensions
h = 8 # number of heads
S = L = 10 # source and target sequence lengths
d_head = d_model // h # dimension per head, // means we want an integer result

Q = torch.randn(batch_size, L, E_q)
K = torch.randn(batch_size, S, E_k)
V = torch.randn(batch_size, S, E_v)

Our query, key, and value inputs have dimensions $E_q$, $E_k$, and $E_v$ (all equal to $d_{\text{model}}$ here), which we need to project to the relevant shapes. We will also assume that we want the final output of our multi-head attention to have the same dimensionality as the query input.

Now, let's define our linear projections to project to each attention head. To do this, we project from each input dimension to $d_\text{model} = h \cdot d_\text{head}$, i.e. to the per-head dimension $h$ times over. We can do this as follows:

In [7]:
Q_proj = nn.Linear(E_q, d_model)
K_proj = nn.Linear(E_k, d_model)
V_proj = nn.Linear(E_v, d_model)
q, k, v = Q_proj(Q), K_proj(K), V_proj(V)

Each of these linear layers projects from the relevant input dimension ($E_q$, $E_k$, or $E_v$) to $h \cdot d_\text{head} = d_\text{model}$. In other words, all of our attention heads are effectively in one tensor. This is more efficient than looping $h$ times, but it does mean we need to reshape our tensors before applying our attention mechanism. Let's do that:

In [8]:
# reshape queries
q = q.unflatten(-1, [h, d_head]) # (batch_size, L, h, d_head)
q = q.transpose(1, 2) # (batch_size, h, L, d_head)
# reshape keys
k = k.unflatten(-1, [h, d_head]) # (batch_size, S, h, d_head)
k = k.transpose(1, 2) # (batch_size, h, S, d_head)
# reshape values
v = v.unflatten(-1, [h, d_head]) # (batch_size, S, h, d_head)
v = v.transpose(1, 2) # (batch_size, h, S, d_head)

`unflatten(-1, [h, d_head])` is used to reshape the last dimension of each tensor into two dimensions, with shape `[h, d_head]`. We can double-check the shape of each to be sure:

In [9]:
print(f'Queries shape: {q.shape}\n')
print(f'Keys shape: {k.shape}\n')
print(f'Values shape: {v.shape}\n')

Queries shape: torch.Size([32, 8, 10, 64])

Keys shape: torch.Size([32, 8, 10, 64])

Values shape: torch.Size([32, 8, 10, 64])



We can see that each of these has the right shape: the batch size first, then the number of heads, then the sequence length, and finally the dimension per head.
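Here is a quick check of what this reshaping does, on random data with the same sizes as above. `unflatten` on the last dimension is equivalent to a `view` with the leading sizes written out. Head $i$ owns the contiguous slice `i*d_head:(i+1)*d_head` of the projected features. The `transpose(1, 2)` followed by `flatten(-2)`, used below to concatenate the heads, undoes the split exactly.

```python
import torch

torch.manual_seed(0)
N, L, h, d_head = 2, 10, 8, 64
x = torch.randn(N, L, h * d_head)        # projected features, all heads packed together

split = x.unflatten(-1, [h, d_head]).transpose(1, 2)   # (N, h, L, d_head)
same = x.view(N, L, h, d_head).transpose(1, 2)          # equivalent view with the sizes written out
merged = split.transpose(1, 2).flatten(-2)              # concatenating heads: back to (N, L, h * d_head)

print(torch.equal(split, same), torch.equal(merged, x))  # True True
# Head 2 sees exactly the features in slice 2*d_head:3*d_head
print(torch.equal(split[:, 2], x[..., 2 * d_head:3 * d_head]))  # True
```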

Now we've got these in the right shape, we can find the attention output using `scaled_dot_product_attention`:

In [10]:
attn_output = F.scaled_dot_product_attention(q, k, v)
print(f'Attention output shape: {attn_output.shape}\n')

Attention output shape: torch.Size([32, 8, 10, 64])



Now we have the attention output, we need to concatenate each head together so we can apply the final linear transformation. We want to get the shape to be $(N, L, h\cdot d_\text{head})$, i.e. we need to combine the axes for the heads and for the dimensions. We can do this as follows:

In [11]:
attn_output = attn_output.transpose(1, 2) # (batch_size, L, h, d_head)
attn_output = attn_output.flatten(-2) # (batch_size, L, h * d_head)
print(f'Attention output shape after concatenating heads: {attn_output.shape}\n')

Attention output shape after concatenating heads: torch.Size([32, 10, 512])



Finally, we need to apply the final linear transformation, so the attention output from each head can be aggregated to produce new output. We want the final output to have the same final dimension as the input, so we can easily chain layers together. We will therefore set our output dimension equal to $E_q$. As a result, we do our final projection as follows:

In [12]:
d_out = E_q
out_proj = nn.Linear(h * d_head, d_out)
final_output = out_proj(attn_output)

Finally, let's compare the input shapes and the final output shape:

In [13]:
print(f'Input query shape: {Q.shape}')
print(f'Input key shape: {K.shape}')
print(f'Input value shape: {V.shape}')
print(f'Final output shape: {final_output.shape}')

Input query shape: torch.Size([32, 10, 512])
Input key shape: torch.Size([32, 10, 512])
Input value shape: torch.Size([32, 10, 512])
Final output shape: torch.Size([32, 10, 512])


And so we can see the final output shape is exactly what we would expect, ready to be passed into the next layer in our model. 

<div style="background-color:#C2F5DD">

### Exercise

Now you have seen how we can do multihead attention in PyTorch, fill in the gaps in the class definition below to define a multihead attention layer that we can use in PyTorch models. Remember the key steps of multi-head attention:

* Project input queries, keys, and values to the total embedding dimension size $d_\text{model}$

* Reshape each projected input to separate out each head in the tensor

* Find the scaled dot product attention

* Reshape the attention output

* Apply the output linear transformation to the attention output

Note the `attn_mask` and `is_causal` arguments for the `forward` function. `attn_mask` is necessary if we have a batch of sequences of different lengths, such that some of the sequence elements are just padding and should be ignored for calculating attention, and `is_causal` is necessary if we want to ensure earlier sequence elements cannot attend to later sequence elements (like we will need for decoder self-attention).
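Here is a sketch of how such a padding mask might be built for `scaled_dot_product_attention`. It assumes padding tokens sit at the end of each sequence, and the `lengths` values are illustrative. The boolean mask has shape $(N, 1, 1, S)$, so it broadcasts over heads and query positions, and `True` marks keys that may be attended to.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, h, L, d_head = 2, 4, 5, 16
lengths = torch.tensor([5, 3])    # hypothetical true lengths: sequence 1 ends in 2 padding tokens

# keep[b, j] is True when key position j of sequence b is a real (non-padding) token
keep = torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)   # (N, S)
attn_mask = keep[:, None, None, :]                            # (N, 1, 1, S)

q, k, v = torch.randn(3, N, h, L, d_head).unbind(0)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Changing the padded keys and values has no effect on the output
k2, v2 = k.clone(), v.clone()
k2[1, :, 3:] += 10.0
v2[1, :, 3:] += 10.0
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=attn_mask)
print(torch.allclose(out, out2, atol=1e-6))  # True
```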

In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, E_q, E_k, E_v, d_model, num_heads):
        super().__init__()
        # Require d_model = d_head * num_heads, i.e. d_model divisible by num_heads
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        # Define projections of queries, keys, and values
        self.q_proj = nn.Linear(E_q, d_model)
        self.k_proj = nn.Linear(E_k, d_model)
        self.v_proj = nn.Linear(E_v, d_model)
        # Define output projection
        out_dim = E_q
        self.out_proj = nn.Linear(d_model, out_dim)
        # Define dimensions per head
        self.d_head = d_model // num_heads
        

    def forward(self, queries, keys, values, attn_mask = None, is_causal = False):
        Q = self.q_proj(queries)
        K = self.k_proj(keys)
        V = self.v_proj(values)
        # Reshape to correct shapes for scaled_dot_product_attention
        Q = Q.unflatten(-1,[self.num_heads, self.d_head]).transpose(1,2) # (N, h, L, d_head)
        K = K.unflatten(-1,[self.num_heads, self.d_head]).transpose(1,2) # (N, h, S, d_head)
        V = V.unflatten(-1,[self.num_heads, self.d_head]).transpose(1,2) # (N, h, S, d_head)
        # Find attention output
        attn_output = F.scaled_dot_product_attention(Q, K, V, attn_mask = attn_mask, is_causal = is_causal)
        attn_output = attn_output.transpose(1,2).flatten(-2)
        # Output projection
        return self.out_proj(attn_output)

How can we make this into multi-head self-attention? That is in fact straightforward: we just pass the same tensor as our queries, keys, and values!

In [17]:
d_input = 64
self_attn = MultiHeadAttention(E_q=d_input, E_k=d_input, E_v=d_input, d_model=512, num_heads=8)
x = torch.rand(32, 10, d_input)  # (batch_size, seq_length, d_input)
output = self_attn(x, x, x)  # queries, keys and values are all x; output shape (32, 10, d_input)
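One consequence of self-attention is worth seeing concretely: on its own, it has no notion of order. If we shuffle the sequence elements, the outputs are shuffled in exactly the same way (the operation is *permutation-equivariant*), which is why transformers need positional encodings. The sketch below checks this using three separate `nn.Linear` projections and `F.scaled_dot_product_attention`, a single-head stand-in chosen to keep the example self-contained rather than the `MultiHeadAttention` class above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 16  # illustrative embedding size
proj_q, proj_k, proj_v = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

def self_attention(x):
    # Self-attention: queries, keys and values are all projections of the same tensor x
    return F.scaled_dot_product_attention(proj_q(x), proj_k(x), proj_v(x))

x = torch.randn(1, 6, d)   # (batch_size, seq_len, d)
perm = torch.randperm(6)   # a random reordering of the 6 positions
with torch.no_grad():
    out = self_attention(x)
    out_perm = self_attention(x[:, perm])

# Permuting the input just permutes the output in the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-6))  # True
```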

### Transformer layers

Now that we have a multi-head attention implementation, we can implement the other parts of the transformer. Let's first go through the steps of a single encoder layer, then we can again implement it as an `nn.Module` subclass.

From the architecture diagram above, a single encoder layer has the following steps:

* Input batch of sequences with dimension $d_\text{model} = 512$

* Pass into multi-head self-attention with 8 heads, apply dropout after the attention layer

* Sum the attention input with the attention output and pass through a layer normalisation
 
* Pass the layer-normalised output into a two-layer feed-forward network (FFN) with hidden dimension $d_{ff} = 2048$: a linear layer, a ReLU activation, dropout, and a second linear layer, followed by dropout on the FFN output

* Sum the FFN input and output and apply layer normalisation

* Return the second layer normalisation output

All dropout layers use $p = 0.1$.
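The steps above can be summarised in equations. Writing $\text{MHA}(x, x, x)$ for multi-head self-attention, and $W_1, b_1, W_2, b_2$ for the FFN weights and biases (notation introduced here for convenience), a single encoder layer maps its input $x$ to its output $y$ as:

$$
\begin{aligned}
z &= \text{LayerNorm}\big(x + \text{Dropout}(\text{MHA}(x, x, x))\big) \\
\text{FFN}(z) &= W_2\, \text{Dropout}\big(\text{ReLU}(W_1 z + b_1)\big) + b_2 \\
y &= \text{LayerNorm}\big(z + \text{Dropout}(\text{FFN}(z))\big)
\end{aligned}
$$

with $W_1 \in \mathbb{R}^{d_{ff} \times d_\text{model}}$ and $W_2 \in \mathbb{R}^{d_\text{model} \times d_{ff}}$, applied independently to each sequence element. This places the layer normalisation after each residual sum ("post-LN"), as in the original paper; many later implementations normalise before each sub-layer instead ("pre-LN").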

We proceed by first generating our data:

In [18]:
torch.manual_seed(0)

d_model = 512 # model dimensions
d_ff = 2048 # FFN hidden dimension
dropout_p = 0.1 # dropout probability
num_heads = 8 # number of heads
seq_len = 10 # sequence length
batch_size = 32 # batch size

x = torch.randn(batch_size, seq_len, d_model)
print(f'Input shape: {x.shape} : (batch_size, seq_len, d_model)')

Input shape: torch.Size([32, 10, 512]) : (batch_size, seq_len, d_model)


Now we can define the components necessary for the attention part of the encoder layer:

In [19]:
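# If your MultiHeadAttention class isn't working, comment out the line below and uncomment the
# built-in nn.MultiheadAttention instead. Note that it returns a tuple (attn_output, attn_weights),
# so the call below then becomes: attn_output, _ = self_attn(x, x, x)
# self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, 
#                                   dropout=dropout_p, batch_first=True)
self_attn = MultiHeadAttention(E_q=d_model, E_k=d_model, E_v=d_model, d_model=d_model, num_heads=num_heads)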
attn_dropout = nn.Dropout(p=dropout_p)
attn_norm = nn.LayerNorm(d_model)

attn_output = self_attn(x, x, x)
attn_output = attn_dropout(attn_output)
x = attn_norm(x + attn_output)
print(f'Output shape after MHA, dropout, and layer norm: {x.shape}')

Output shape after MHA, dropout, and layer norm: torch.Size([32, 10, 512])


Note that we kept `x` unchanged until after finding the attention output, and then we summed `x` and `attn_output`. This is the residual connection we described before.

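`nn.LayerNorm(d_model)` is a layer normalisation module: for each sequence element, it normalises the last dimension (of size `d_model`) to zero mean and unit variance, then applies a learnable elementwise scale and shift. This is why we pass it the size of that dimension when we instantiate it.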
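To make this concrete, here is a small check (with arbitrary illustrative sizes) that, at initialisation, `nn.LayerNorm` simply standardises each sequence element over its last dimension, using the biased variance plus a small `eps` (1e-5 by default). Its learnable scale and shift start at 1 and 0, so they do not change anything yet.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 8)  # (batch_size, seq_len, d_model) with illustrative sizes
norm = nn.LayerNorm(8)    # normalise over the last dimension, of size d_model = 8

# Manual version: per-element mean and biased variance over the last dimension
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + norm.eps)

print(torch.allclose(norm(x), manual, atol=1e-5))  # True
```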

Next, we need to pass `x` through a feed-forward neural network:

In [20]:
# Define FFN layers
ffn_layer1 = nn.Linear(d_model, d_ff)
ffn_activation = nn.ReLU()
ffn_dropout1 = nn.Dropout(p=dropout_p)

ffn_layer2 = nn.Linear(d_ff, d_model)
ffn_dropout2 = nn.Dropout(p=dropout_p)

ffn_norm = nn.LayerNorm(d_model)

# Feedforward network
ffn_output = ffn_dropout1(ffn_activation(ffn_layer1(x)))
ffn_output = ffn_dropout2(ffn_layer2(ffn_output))
x = ffn_norm(x + ffn_output) # FFN residual connection
print(f'Output shape after FFN, dropout, and layer norm: {x.shape}')

Output shape after FFN, dropout, and layer norm: torch.Size([32, 10, 512])


And those are the necessary steps to pass through a single transformer encoder layer. We can see the final output is the same shape as our input, ready to feed into another encoder layer.
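As a cross-check, PyTorch also provides a built-in `nn.TransformerEncoderLayer`. With its defaults `activation='relu'` and `norm_first=False`, it uses the same post-LN arrangement of attention, FFN, dropout, residual connections and layer normalisation, so it should produce the same output shape and have the same number of parameters as the layer we assembled above. Its random weights differ, so the sketch below compares only shapes and parameter counts, using the same sizes as above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads, d_ff, dropout_p = 512, 8, 2048, 0.1
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
                                   dropout=dropout_p, batch_first=True)

x = torch.randn(32, 10, d_model)  # (batch_size, seq_len, d_model)
y = layer(x)
print(y.shape)  # torch.Size([32, 10, 512])

# Attention: 4 projections of size d_model x d_model (+ biases); FFN: 2 linear layers; 2 LayerNorms
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)  # 3152384
```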

<div style="background-color:#C2F5DD">

### Exercise

Now, fill in the gaps in the class definition below to define a transformer encoder layer for use in a PyTorch model. Remember the key steps:

* Multi-head self-attention, including:
    
    * A multi-head attention layer where queries, keys, and values are the same tensor
    
    * Dropout layer after multi-head attention

    * Residual connection after the dropout layer

    * Layer normalisation

* A two-layer feedforward network, including:

    * 1st linear layer projecting $d_\text{model}$ to $d_{ff}$

    * ReLU activation function

    * Dropout layer after activation

    * 2nd linear layer projecting $d_{ff}$ to $d_\text{model}$

    * Dropout layer after 2nd linear layer

    * Residual connection after the dropout layer

    * Layer normalisation

Note the `src_mask` and `is_causal` arguments of the `forward` method, which should be passed through to the multi-head attention call.

In [21]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        """
        Initialize a Transformer encoder layer.
        Parameters:
        - d_model: dimension of the model
        - num_heads: number of attention heads
        - d_ff: dimension of the feed-forward network
        - dropout: dropout rate
        """
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = dropout

        # Attention layer
        self.self_attn = MultiHeadAttention(d_model, d_model, d_model, d_model, num_heads)
        self.attn_dropout = nn.Dropout(dropout)
        self.attn_norm = nn.LayerNorm(d_model)

        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, src, src_mask = None, is_causal = False):
        """
        Forward pass of the Transformer encoder layer.
        Parameters:
        - src: input tensor of shape (batch_size, seq_length, d_model)
        - src_mask: mask for the tensor input, to mask out keys that should be excluded
                    from attention calculations
        - is_causal: whether to apply causal masking in self-attention
        """
        x = src
        # Self-attention
        attn_output = self.self_attn(x, x, x, attn_mask = src_mask, is_causal = is_causal)
        x = self.attn_norm(x + self.attn_dropout(attn_output))
        # Feed-forward network
        ffn_output = self.linear2(self.dropout1(self.activation(self.linear1(x))))
        x = self.ffn_norm(x + self.dropout2(ffn_output))
        return x


Now that we have an encoder layer, we also need a decoder layer. This is mostly the same as the encoder layer, with two differences: the self-attention is masked, so it needs `is_causal = True`, and there is an additional attention block (with its own dropout, residual connection, and layer normalisation) that computes cross-attention, using the decoder hidden states as the queries and the encoder outputs as the keys and values.
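The key point of cross-attention is that the queries and the keys/values can have different sequence lengths: each of the $L$ target positions attends over all $S$ source positions, and the output keeps the target length $L$. The sketch below shows this with arbitrary sizes, using PyTorch's built-in `nn.MultiheadAttention` only so that the example is self-contained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 64, 4  # illustrative sizes
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

enc_out = torch.randn(2, 12, d_model)  # encoder output: source length S = 12
dec_x = torch.randn(2, 7, d_model)     # decoder hidden states: target length L = 7

# Queries come from the decoder; keys and values come from the encoder output
out, weights = cross_attn(dec_x, enc_out, enc_out)
print(out.shape)      # torch.Size([2, 7, 64]): one output per target position
print(weights.shape)  # torch.Size([2, 7, 12]): weights over the 12 source positions (averaged over heads)
```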

<div style="background-color:#C2F5DD">

### Exercise

Fill in the gaps in the class definition below to define a transformer decoder layer for use in a PyTorch model. Remember the key steps:

* Multi-head masked self-attention, including:
    
    * A multi-head attention layer where queries, keys, and values are the same tensor, with `is_causal = True`
    
    * Dropout layer after multi-head attention

    * Residual connection after the dropout layer

    * Layer normalisation
<br></br>
* Multi-head cross-attention, including:    

    * A multi-head attention layer where queries are the input to the decoder layer and the keys and values are the final encoder output
    
    * Dropout layer after multi-head attention

    * Residual connection after the dropout layer

    * Layer normalisation
<br></br>
* A two-layer feedforward network, including:

    * 1st linear layer projecting $d_\text{model}$ to $d_{ff}$

    * ReLU activation function

    * Dropout layer after activation

    * 2nd linear layer projecting $d_{ff}$ to $d_\text{model}$

    * Dropout layer after 2nd linear layer

    * Residual connection after the dropout layer

    * Layer normalisation

Note the extra arguments of the `forward` method: `tgt_mask` and `enc_mask` specify any masking needed for the decoder self-attention and the cross-attention respectively, while `tgt_is_causal` and `enc_out_is_causal` indicate whether each scaled dot product attention call should be causal.
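To see what `is_causal=True` does, the sketch below (random tensors, illustrative sizes) checks that it gives the same result as an explicit lower-triangular boolean mask, and that the first position can only attend to itself. One practical caveat: in recent PyTorch versions, `F.scaled_dot_product_attention` raises an error if an explicit `attn_mask` is passed together with `is_causal=True`, so if you supply a `tgt_mask` you would set `tgt_is_causal=False` and build the causal structure into the mask yourself.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, h, L, d_head = 2, 4, 5, 8  # illustrative sizes
q, k, v = (torch.randn(N, h, L, d_head) for _ in range(3))

# Lower-triangular boolean mask (True = allowed): query i may attend only to keys 0..i
causal_mask = torch.ones(L, L, dtype=torch.bool).tril()
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(torch.allclose(out_flag, out_mask, atol=1e-6))  # True

# The first position can only see itself, so its output is exactly its own value vector
print(torch.allclose(out_flag[:, :, 0], v[:, :, 0]))  # True
```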

In [22]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = dropout

        # Masked self-attention block
        self.self_attn = MultiHeadAttention(d_model, d_model, d_model, d_model, num_heads)
        self.self_attn_dropout = nn.Dropout(p = dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Cross attention block
        self.cross_attn = MultiHeadAttention(d_model, d_model, d_model, d_model, num_heads)
        self.cross_attn_dropout = nn.Dropout(p = dropout)
        self.cross_attn_norm = nn.LayerNorm(d_model)

        # FFN block
        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.ffn_dropout1 = nn.Dropout(p = dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.ffn_dropout2 = nn.Dropout(p = dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, tgt, enc_out, tgt_mask = None, enc_mask = None, tgt_is_causal = True, enc_out_is_causal = False):
        """
        Forward pass of the Transformer decoder layer.
        
        Parameters:
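        - tgt: input tensor of shape (batch_size, seq_length, d_model)
        - enc_out: encoder output tensor of shape (batch_size, enc_seq_length, d_model)
        - tgt_mask: mask for the masked self-attention over the tgt input
        - enc_mask: mask for the cross-attention over the encoder output (keys/values)
        - tgt_is_causal: whether to apply causal masking in the self-attention
        - enc_out_is_causal: whether to apply causal masking in the cross-attention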
        """
        x = tgt
        # Masked self-attention
        self_attn_output = self.self_attn(x, x, x, attn_mask = tgt_mask, is_causal = tgt_is_causal)
        x = self.self_attn_norm(x + self.self_attn_dropout(self_attn_output))
        
        # Cross-attention
        cross_attn_output = self.cross_attn(x, enc_out, enc_out, attn_mask = enc_mask, is_causal = enc_out_is_causal)
        x = self.cross_attn_norm(x + self.cross_attn_dropout(cross_attn_output))
        
        # Feed-forward network
        ffn_output = self.linear2(self.ffn_dropout1(self.activation(self.linear1(x))))
        x = self.ffn_norm(x + self.ffn_dropout2(ffn_output))
        return x

Now we have implemented the main components of the transformer architecture!

However, we still need to put everything together into a single architecture. To do this, we also need:

* Layers to embed input and output sequences to pass into the encoder and decoder respectively

* Positional encodings to add to the embedded sequences

* The output linear transformation and softmax layers

We will illustrate each of these, and then you will write an `nn.Module` for a whole transformer model. To start with, we can use `nn.Embedding` from PyTorch to embed our input/output sequences. For demonstration, let's consider an NLP problem:

* Assume we have a limited vocabulary consisting of just 5 words, from the sentence "my new cat is black"; this means our vocabulary size is $d_\text{vocab} = 5$

* We need to convert this to a numerical vector, so we will assign an integer to each unique word in the sentence

* Finally, we can pass the integer sequence through `nn.Embedding` to embed to $d_\text{model}$ dimensions

For this example, we will use $d_\text{model} = 16$:

In [23]:
# Define example sequence
sequence = "My new cat is black".split()

# Encode words: one integer per unique word (dict.fromkeys removes duplicates, keeping first-seen order)
word_mapping = {word : i for i, word in enumerate(dict.fromkeys(sequence))}
mapped_sequence = torch.tensor([word_mapping[word] for word in sequence])

# Find embeddings
embedding_dim = 16
d_vocab = len(word_mapping)
embeddings = nn.Embedding(num_embeddings=d_vocab, embedding_dim=embedding_dim)
embedded_sequence = embeddings(mapped_sequence)
print(embedded_sequence.shape)  # Should be (sequence_length, embedding_dim)

torch.Size([5, 16])


Inside the transformer itself we would only include the `nn.Embedding` layer, since that is the part of the model with learnable parameters; mapping the vocabulary to integers is a data pre-processing task. If our sequences already consist of numerical values, such as integer category labels, we don't need this mapping step.
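It can help to see that `nn.Embedding` is nothing more than a learnable lookup table: row `i` of its `weight` matrix is the vector for token `i`, and a batch of integer sequences of shape `(batch_size, seq_len)` becomes a tensor of shape `(batch_size, seq_len, embedding_dim)`. The token IDs below are arbitrary examples.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=5, embedding_dim=16)
tokens = torch.tensor([[0, 1, 2, 3, 4],
                       [4, 4, 0, 1, 2]])  # (batch_size, seq_len) integer token IDs
out = emb(tokens)
print(out.shape)  # torch.Size([2, 5, 16])

# Embedding a token is just indexing the rows of the weight matrix
print(torch.equal(out, emb.weight[tokens]))  # True
```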

To compute the sinusoidal positional encodings described in the original transformer paper, we can use the following module:

In [37]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        # Dropout after adding positional encoding
        self.dropout = nn.Dropout(p=dropout)
        # Create positional encoding matrix up to max_len to save
        # computation time when calling forward()
        position = torch.arange(max_len).unsqueeze(1)
        # equivalent to e^(i*-ln(10000)/d_model) = e^(ln(10000^(-i/d_model))) = 1/10000^(i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        # fill even embedding dims with sine, odd embedding dims with cosine
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        # Need to include this line to ensure the positional encoding is transferred
        # to the same device as the model, in case we use GPU
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)
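The module above is a vectorised form of the sinusoidal encodings from the original paper, $PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_\text{model}}\right)$ and $PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_\text{model}}\right)$. The sketch below evaluates these formulas directly with plain loops and checks that they match the `div_term` construction used in `__init__`. Note that this construction assumes `d_model` is even.

```python
import math
import torch

d_model, max_len = 16, 10

# Direct, loop-based evaluation of the formulas
pe_loop = torch.zeros(max_len, d_model)
for pos in range(max_len):
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe_loop[pos, 2 * i] = math.sin(angle)
        pe_loop[pos, 2 * i + 1] = math.cos(angle)

# Vectorised version, as in PositionalEncoding.__init__
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe_vec = torch.zeros(max_len, d_model)
pe_vec[:, 0::2] = torch.sin(position * div_term)
pe_vec[:, 1::2] = torch.cos(position * div_term)

print(torch.allclose(pe_loop, pe_vec, atol=1e-5))  # True
```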

Let's test it on our embedded sequence:

In [39]:
# Assume our max sequence length is 10
pos_encode = PositionalEncoding(d_model=16, dropout=0.1, max_len=10)
embed_seq_batch = embedded_sequence.unsqueeze(0)  # (1, seq_length, embedding_dim)
pos_encoded_seq = pos_encode(embed_seq_batch)
print(pos_encoded_seq.shape)  # Should be (1, sequence_length, embedding_dim)

torch.Size([1, 5, 16])


The final decoder output needs to be passed through a linear transformation and then a softmax to obtain probabilities over all possible tokens for every element in the sequence. This means the final linear transformation has to map from $d_\text{model}$ to $d_\text{vocab}$, i.e. in the opposite direction to the embedding. We can therefore simply add a linear layer after our decoder, followed by a softmax, to get the final output.
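A practical point when training: `nn.CrossEntropyLoss` expects raw, un-normalised scores (logits) and applies a log-softmax internally, so the loss is usually computed on the output of the final linear layer, with the softmax applied only when probabilities are actually needed (for example, at inference). The sketch below, with random illustrative logits, checks that the loss on logits equals the negative log-probability of the correct token after an explicit softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)             # (batch, d_vocab): raw output of the final linear layer
targets = torch.tensor([0, 3, 1, 4])   # correct token index for each row

probs = F.softmax(logits, dim=-1)      # probabilities over the vocabulary; each row sums to 1
ce = nn.CrossEntropyLoss()(logits, targets)
manual = -torch.log(probs[torch.arange(4), targets]).mean()
print(torch.allclose(ce, manual))  # True
```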

<div style="background-color:#C2F5DD">

### Exercise

Now that we have all of the pieces we need to construct a full transformer architecture in PyTorch, fill in the gaps in the class definition below to do so. Remember the key steps:

* Embed the input sequence to $d_\text{model}$ dimensions

* Pass the embedded input sequence through $N$ transformer encoder layers, with $d_\text{model}$ embedding dimensions, $h$ heads, and an FFN hidden dimension of $d_{ff}$

* Embed the output (target) sequence to $d_\text{model}$ dimensions

* Pass the embedded output sequence and the final encoder output through $N$ transformer decoder layers, with $d_\text{model}$ embedding dimensions, $h$ heads for both attention blocks, and an FFN hidden dimension of $d_{ff}$

* Pass the final decoder output through the output linear transformation and softmax

In [40]:
class TransformerModel(nn.Module):
    def __init__(self, d_vocab, d_model, num_heads, d_ff, num_encoder_layers, num_decoder_layers, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.input_embedding = nn.Embedding(d_vocab, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.encoder = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_decoder_layers)
        ])
        self.output_linear = nn.Linear(d_model, d_vocab)

    def forward(self, src, tgt, src_mask = None, tgt_mask = None, src_is_causal = False, tgt_is_causal = True):
        # Embed and add positional encoding to source
        src = self.input_embedding(src)
        src = self.pos_encoder(src)
        # Pass through encoder layers
        for layer in self.encoder:
            src = layer(src, src_mask, src_is_causal)
        enc_out = src
        # Embed and add positional encoding to target
        tgt = self.input_embedding(tgt)
        tgt = self.pos_encoder(tgt)
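        # Pass through decoder layers; cross-attention over the encoder output is not causal
        for layer in self.decoder:
            tgt = layer(tgt, enc_out, tgt_mask=tgt_mask, enc_mask=None,
                        tgt_is_causal=tgt_is_causal, enc_out_is_causal=False)
        # Output linear transformation & softmax (nn.CrossEntropyLoss would instead
        # expect the logits from output_linear, before the softmax)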
        output = self.output_linear(tgt)
        output = F.softmax(output, dim=-1)
        return output

Now we've got a full transformer architecture! While we have gone through constructing this, the example problems you can practically solve in these sessions simply don't need a full transformer encoder-decoder architecture. Instead, we will use simpler architectures to solve these problems.
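For comparison, PyTorch ships an `nn.Transformer` module that implements just the encoder-decoder stack; as with our `TransformerModel`, the embeddings and output projection are left to the user. The sketch below wires it up with illustrative sizes and uses `nn.Transformer.generate_square_subsequent_mask` to build the causal target mask. For brevity it skips positional encodings, which a real model would add to the embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_vocab, d_model = 20, 32  # illustrative sizes
embed = nn.Embedding(d_vocab, d_model)
transformer = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                             num_decoder_layers=2, dim_feedforward=64,
                             dropout=0.1, batch_first=True)
to_vocab = nn.Linear(d_model, d_vocab)

src = torch.randint(d_vocab, (3, 9))  # (batch_size, src_len) token IDs
tgt = torch.randint(d_vocab, (3, 6))  # (batch_size, tgt_len) token IDs
# Float mask of shape (6, 6): 0 on and below the diagonal, -inf above it
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))

# NOTE: positional encodings omitted here for brevity
hidden = transformer(embed(src), embed(tgt), tgt_mask=tgt_mask)
logits = to_vocab(hidden)
print(logits.shape)  # torch.Size([3, 6, 20])
```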


## An aside on learning rate warm-up

A common technique for training transformers is **learning rate warm-up**: the learning rate starts at 0 and is gradually increased to the desired value over the first few iterations, so early updates take much smaller steps. Without warm-up, transformer gradients are often very large early in training, which can lead to poor results. You can read more about this in the literature, e.g. [this paper](https://arxiv.org/pdf/1908.03265) on how warm-up improves performance with the Adam optimizer.

The code snippet below defines a learning rate scheduler that combines warm-up with a cosine learning rate decay. This implementation is borrowed from the [University of Amsterdam tutorial on transformers](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html). 

The code block that follows is an example of how this might be used in a training loop, but **won't run**.

In [41]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + math.cos(math.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor
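If you would rather not subclass the private `_LRScheduler`, a similar schedule can be sketched with PyTorch's public `LambdaLR`, which multiplies the base learning rate by a user-supplied factor of the step count. The sketch below reuses the same formula as `get_lr_factor` (the helper name `warmup_cosine_factor` and the single dummy parameter are illustrative assumptions) and records the learning rate at each step: it rises linearly from 0 over the first `warmup` steps, peaks, then decays along a cosine.

```python
import math
import torch

def warmup_cosine_factor(step, warmup, max_iters):
    # Same formula as CosineWarmupScheduler.get_lr_factor
    factor = 0.5 * (1 + math.cos(math.pi * step / max_iters))
    if step <= warmup:
        factor *= step / warmup
    return factor

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter for illustration
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: warmup_cosine_factor(step, warmup=10, max_iters=100))

lrs = []
for step in range(100):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()   # one scheduler step per optimizer step

# lrs[0] is 0.0, and the learning rate peaks at the end of warm-up (step 10)
print(lrs[0], lrs.index(max(lrs)))
```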

In [None]:
# THIS BLOCK WILL NOT RUN, AND IS FOR ILLUSTRATIVE PURPOSES ONLY

# LR scheduler usage example
model = some_model
optimizer = optim.Adam(model.parameters(), lr = some_lr)
scheduler = CosineWarmupScheduler(optimizer, warmup = 10, max_iters=100)
model.train()
for epoch in range(n_epochs):
    for batch in data_loader:
        # forward pass
        ...
        # compute loss
        ...
        # backward pass and optimization step
        ...
    scheduler.step()

In other words, the learning rate scheduler is attached to our optimizer, and in this example we call `.step()` on the scheduler once after iterating over all of our training batches in an epoch, which updates the learning rate. The scheduler simply counts calls to `.step()`, so `warmup` and `max_iters` are measured in whatever unit we step in: here, epochs.

Now, we have all the tools we need to train a simple transformer for a relatively straightforward problem: reversing the order of a list. 

This exercise is borrowed from the same [UvA tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html) as the learning rate scheduler.

<div style="background-color:#C2F5DD">

### Exercise

The code cell below defines a dataset consisting of a sequence of random numbers in a given range, where the objective is to learn how to reverse the order of the sequence. We will tackle this with a very simple transformer architecture, just using a single encoder layer with only 1 head. 

Firstly, we generate the data:

In [42]:
import torch.utils.data as data

torch.manual_seed(0)

class ReverseDataset(data.Dataset):
    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size

        self.data = torch.randint(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = torch.flip(inp_data, dims=(0,))
        return inp_data, labels
    
n_categories = 10
seq_length = 16
dataset = ReverseDataset(num_categories=n_categories, seq_len=seq_length, size=60000)

<div style="background-color:#C2F5DD">

Split this dataset into a training and test dataset, with 50000 training samples and 10000 test samples.

After the dataset split, make DataLoaders for each with a batch size of 128. Make sure that the training DataLoader shuffles data during iteration.

In [43]:
# Your code for training/test split and DataLoaders
torch.manual_seed(0)

n_train = 50000
n_test = 10000
train_dataset, test_dataset = data.random_split(dataset, [n_train, n_test])

train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last = True)
test_loader = data.DataLoader(test_dataset, batch_size=128)
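It is worth checking how many batches these loaders produce, since `len(train_loader)` feeds into `max_iters` later. The sketch below uses a `TensorDataset` of random integers as a stand-in for `ReverseDataset`, with the same sizes. Because `drop_last=True`, the final partial batch is dropped, so each epoch sees 390 × 128 = 49,920 of the 50,000 training samples; dividing a correct-prediction count by `len(train_loader.dataset)` therefore slightly underestimates the training accuracy.

```python
import torch
import torch.utils.data as data

torch.manual_seed(0)
dataset = data.TensorDataset(torch.randint(10, (60000, 16)))  # stand-in for ReverseDataset
train_ds, test_ds = data.random_split(dataset, [50000, 10000])

train_loader = data.DataLoader(train_ds, batch_size=128, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_ds, batch_size=128)

# 50000 / 128 = 390.6 -> 390 full batches (last partial batch dropped)
# 10000 / 128 = 78.1  -> 79 batches (last batch has 16 samples)
print(len(train_loader), len(test_loader))  # 390 79
```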

<div style="background-color:#C2F5DD">

Now we have our data, we need to define a model. Fill out the class template below to define a simple transformer model for this problem, with the following architecture:

* Embedding layer to go from the sequence input dimension `input_dim` to the model dimension `d_model`

* Positional encoding layer

* `n_layers` transformer encoder layers, where the FFN hidden dimension `d_ff` = 2*`d_model`

* An output linear transformation from `d_model` to `n_classes`

After writing the class, make an instance of the model with the following parameter values:

* `input_dim` = `n_categories`

* `d_model` = 32

* `n_heads` = 1

* `n_classes` = `n_categories`

* `n_layers` = 1

* `dropout` = 0.0

In [44]:
class ReversePredictor(nn.Module):
    def __init__(self, input_dim, d_model, n_heads, n_classes, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, d_model*2, dropout) 
            for _ in range(n_layers)
        ])
        self.fc_out = nn.Linear(d_model, n_classes)

    def forward(self, src):
        x = self.embedding(src)
        x = self.pos_encoder(x)
        for layer in self.transformer:
            x = layer(x)
        out = self.fc_out(x)
        return out

In [45]:
# Instantiate your model

model = ReversePredictor(
    input_dim=n_categories,
    d_model = 32,
    n_heads = 1,
    n_classes = n_categories,
    n_layers = 1,
    dropout = 0.0)

<div style="background-color:#C2F5DD">

Now, train the model using the following training parameters:

* Train for 10 epochs

* The Adam optimizer, with a learning rate of 0.001

* The `CosineWarmupScheduler` learning rate scheduler, with `warmup` = 50 and `max_iters` = `n_epochs` $\times$ `len(train_loader)`

* Use CrossEntropyLoss

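Your training loop should be the same as usual, except that the learning rate scheduler is stepped once after all batches in an epoch have been processed. Be aware that `CosineWarmupScheduler` counts calls to `.step()`, while `max_iters` here is measured in batches: with one scheduler step per epoch, the learning rate is 0 for the whole first epoch and never leaves the warm-up phase, staying below a fifth of the base learning rate of 0.001 for all 10 epochs. Stepping the scheduler once per batch instead would follow the intended warm-up and cosine decay.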

**Note**: the model output has shape (`batch_size`, `seq_length`, `n_categories`), and the target has shape (`batch_size`, `seq_length`). For inputs with extra dimensions, `CrossEntropyLoss` expects the class dimension to come second, i.e. a model output of shape (`batch_size`, `n_categories`, `seq_length`), so make sure to transpose the last two axes of the model output before passing it to the loss function.
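The sketch below, with random illustrative tensors, shows that transposing to (`batch_size`, `n_categories`, `seq_length`) gives the same loss as flattening the batch and sequence dimensions together, since with the default `reduction='mean'` both average over all `batch_size` × `seq_length` predictions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_length, n_categories = 4, 16, 10
output = torch.randn(batch_size, seq_length, n_categories)       # model output: (N, L, C)
targets = torch.randint(n_categories, (batch_size, seq_length))  # targets: (N, L)

criterion = nn.CrossEntropyLoss()
loss_transposed = criterion(output.transpose(1, 2), targets)     # input (N, C, L)
loss_flat = criterion(output.reshape(-1, n_categories), targets.reshape(-1))  # input (N*L, C)
print(torch.allclose(loss_transposed, loss_flat))  # True
```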



In [46]:
# Your training loop here

n_epochs = 10
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = CosineWarmupScheduler(optimizer, warmup=50, max_iters = n_epochs * len(train_loader))

def train_epoch():
    model.train()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    n_correct = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.transpose(1,2), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_correct += (output.argmax(dim=-1) == targets).sum().item()
    scheduler.step()
    accuracy = n_correct / (len(train_loader.dataset) * seq_length)
    return total_loss / len(train_loader), accuracy

train_losses = []
for epoch in range(n_epochs):
    train_loss, train_acc = train_epoch()
    train_losses.append(train_loss)
    print(f'Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2%}')

Epoch 1/10, Train Loss: 2.4366, Train Accuracy: 9.80%
Epoch 2/10, Train Loss: 2.3912, Train Accuracy: 9.88%
Epoch 3/10, Train Loss: 2.3293, Train Accuracy: 10.85%
Epoch 4/10, Train Loss: 2.2848, Train Accuracy: 14.02%
Epoch 5/10, Train Loss: 2.1950, Train Accuracy: 20.39%
Epoch 6/10, Train Loss: 1.8657, Train Accuracy: 34.12%
Epoch 7/10, Train Loss: 1.2247, Train Accuracy: 54.42%
Epoch 8/10, Train Loss: 0.6454, Train Accuracy: 75.88%
Epoch 9/10, Train Loss: 0.3021, Train Accuracy: 89.62%
Epoch 10/10, Train Loss: 0.1430, Train Accuracy: 94.74%


<div style="background-color:#C2F5DD">

Finally, evaluate the performance of your trained model on the test dataset. To get the predicted sequence from the model outputs, take the `torch.argmax` of the model output along the final axis. Calculate both the test loss and the test accuracy. How accurate is the model prediction?

In [47]:
# Your evaluation code here

def eval_model():
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    n_correct = 0
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(test_loader):
            output = model(data)
            loss = criterion(output.transpose(1,2), targets)
            total_loss += loss.item()
            n_correct += (output.argmax(dim=-1) == targets).sum().item()
    accuracy = n_correct / (len(test_loader.dataset) * seq_length)

    return total_loss / len(test_loader), accuracy

test_loss, test_accuracy = eval_model()
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2%}')

Test Loss: 0.0997, Test Accuracy: 96.36%


The next notebook will include more exercises looking at different problems using transformers.