# Introduction to transformers

The transformer {cite}`vaswani2017attention` is a deep learning architecture which has powered many of the recent advances across a range of machine learning applications, including text modelling, image modelling {cite}`dosovitskiy2021image`, and many others.
This is an overview of the transformer architecture, including a self-contained mathematical description of the architectural details, and a concise implementation.
All of this exposition is based on an excellent introductory paper on transformers by Rich Turner {cite}`turner2023introduction`.

## Modelling with tokens
__One architecture, many applications.__
The purpose of the transformer architecture was, originally, to model sequence data such as text.
The approach for achieving this was to first convert individual words, or characters, into one-dimensional arrays called _tokens_, and then operate on these tokens with a neural network.
This approach however extends beyond word modelling.
For example, the transformer can be applied to tasks as diverse as modelling of images and video, proteins, or weather.
In all these applications, the data are first converted into sets of tokens.
After this step, the transformer can be applied in roughly the same way, irrespective of the original representation of the data.
This versatility, together with its empirical performance, is one of the main appeals of the transformer.

__Inputs as tokens.__
For the moment, we will assume that the input data have already been converted into tokens, and defer the details of this tokenisation until later.
More concretely, let us assume that each data example, e.g. a sentence, image, or protein, has been converted into a set of tokens $\{x_n\}_{n=1}^N,$ where each $x_n$ is a $D$-dimensional array $x_n \in \mathbb{R}^D.$
We can collect these tokens into a single $D \times N$ array $X^{(0)} \in \mathbb{R}^{D \times N},$ forming a single data input for the transformer.

## Transformer block
Much like in other deep architectures, the transformer maintains a representation of the input data, and progressively refines it using a sequence of so-called _transformer blocks_.
In particular, given an initial representation $X^{(0)},$ the architecture comprises $M$ transformer blocks, i.e. for each $m = 1, \dots, M,$ it computes

$$X^{(m)} = \texttt{TransformerBlock}(X^{(m-1)}).$$

Each of these blocks consists of two main operations, namely a self-attention operation and a token-wise multi-layer perceptron (MLP) operation.
The self-attention operation has the role of combining the representations of different tokens in a sequence, in order to model dependencies between the tokens.
It is applied collectively to all tokens within the transformer block.
The MLP operation has the role of refining the representation of each token.
It is applied separately to each token and is shared across all tokens within a transformer block.
Let's look at these two operations in detail.

### Self-attention

__Attention.__
The role of the first operation in a transformer block is to combine the representations of different tokens in order to model dependencies between them.
Given a $D \times N$ input array $X^{(m-1)} = (x_1^{(m-1)}, \dots, x_N^{(m-1)}),$ the output of the self-attention layer is another $D \times N$ array $Y^{(m)} = (y_1^{(m)}, \dots, y_N^{(m)}),$ where each column is simply a weighted average of the input features, that is

$$y^{(m)}_n = \sum_{n' = 1}^N x^{(m - 1)}_{n'} A_{n', n}^{(m)}.$$

The weighting array $A^{(m)}$ is of size $N \times N$ and has the property that its columns normalise to one, that is $\sum_{n'=1}^N A_{n', n}^{(m)} = 1.$
It is referred to as the attention matrix because it weighs the extent to which the feature $y^{(m)}_n$ should depend on each $x^{(m-1)}_{n'},$ i.e. it determines the extent to which each $y^{(m)}_n$ should attend to each $x^{(m-1)}_{n'}.$
For compactness, we can collect these equations to a single linear operation, that is

$$Y^{(m)} = X^{(m - 1)} A^{(m)}.$$
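
To make the weighted-average view concrete, here is a minimal JAX sketch. The sizes and the randomly generated weighting matrix are illustrative assumptions only: any non-negative matrix whose columns sum to one would do at this stage, since we have not yet said how the weights are computed.

```python
import jax
import jax.numpy as jnp

D, N = 3, 4  # illustrative sizes, not taken from the text

key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (D, N))  # tokens stored as columns

# Random positive entries, normalised so that every column sums to one.
A = jax.random.uniform(key_a, (N, N), minval=0.1, maxval=1.0)
A = A / A.sum(axis=0, keepdims=True)

# Matrix form: Y[:, n] = sum over n' of X[:, n'] * A[n', n].
Y = X @ A

# The same quantity for a single output token, written as an explicit sum.
n = 2
y_n = sum(X[:, m] * A[m, n] for m in range(N))
```

Here `Y[:, n]` and `y_n` agree up to floating-point error, which is exactly the statement that the matrix product collects the per-token weighted averages.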

But what about the attention weights themselves?
We have not specified how these are computed, and their precise definition is one important factor that differentiates transformers from other architectures.
In fact, many operations forming the core of other architectures, such as convolution layers in convolutional neural networks (CNNs), can be written as similar weighted sums.
Let's next look at the specifics of the transformer attention weights.

__Self-attention.__
One of the innovations within the transformer architecture is that the attention weights are adaptive, meaning that they are computed based on the input itself.
This is in contrast with other deep learning architectures such as CNNs, where weighted sums are also used, but the weights are fixed and shared across all inputs.
One straightforward way to compute attention weights would be to compare pairs of tokens using a simple similarity metric, such as an inner product.
For example, given two tokens $x_n$ and $x_{n'},$ we can compute the dot product between them, which acts as a similarity metric, exponentiate the result to make it positive, and then normalise to ensure that each column sums to one, that is

$$A^{(m)}_{n, n'} = \frac{\exp(x_n^\top x_{n'})}{\sum_{n'' = 1}^N \exp(x_{n''}^\top x_{n'})}.$$

An alternative, slightly more flexible approach is to transform each token in the sequence by a linear map, say by applying a matrix $U \in \mathbb{R}^{K \times D}$ to each token first, that is

$$A^{(m)}_{n, n'} = \frac{\exp(x_n^\top U^\top U x_{n'})}{\sum_{n'' = 1}^N \exp(x_{n''}^\top U^\top U x_{n'})}.$$

This allows the tokens to be compared in a different space.
For example, if $K < D$ this approach automatically projects out some of the components of the tokens, comparing them in a lower-dimensional space.
However, this approach still has an important limitation, namely symmetry.
Specifically, the similarity $x_n^\top U^\top U x_{n'}$ above is symmetric in $n$ and $n',$ which means that, before normalisation, any two tokens would attend to each other with equal strength.
This might be undesirable because, for example, we could imagine that one token might be important for informing the representation of another token, but not the other way around.
To address this, we can apply different linear operations, say $U_k$ and $U_q$ to each of the tokens being compared, and instead compute

$$A^{(m)}_{n, n'} = \frac{\exp(x_n^\top U_k^\top U_q x_{n'})}{\sum_{n'' = 1}^N \exp(x_{n''}^\top U_k^\top U_q x_{n'})}.$$

In this way, the resulting attention matrix is not necessarily symmetric, and the overall architecture is more expressive.
Tokens no longer have to attend to each other with the same strength.
This weighting is known as self-attention, since each token in the sequence attends to every other token of the same sequence.
It is also possible to generalise this to attention between different sequences, which might be useful for some applications such as, for example, joint modelling of text and images.
This generalisation is called cross-attention, and we defer its discussion for later.
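
The three variants of the attention weights discussed above differ only in how the pairwise scores are formed. The following JAX sketch, with randomly drawn tokens and projection matrices as purely illustrative assumptions, compares the shared-projection scores with the key/query scores, and then normalises the latter over the first index.

```python
import jax
import jax.numpy as jnp

D, K, N = 4, 3, 5  # illustrative sizes

k_x, k_u, k_k, k_q = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(k_x, (D, N))    # tokens stored as columns
U = jax.random.normal(k_u, (K, D))    # shared projection
U_k = jax.random.normal(k_k, (K, D))  # key projection
U_q = jax.random.normal(k_q, (K, D))  # query projection

# Shared projection: scores[n, n'] = x_n^T U^T U x_{n'}, symmetric in n and n'.
shared_scores = (U @ X).T @ (U @ X)

# Separate projections: scores[n, n'] = x_n^T U_k^T U_q x_{n'}.
kq_scores = (U_k @ X).T @ (U_q @ X)

# Softmax over the first index, so that every column sums to one.
A = jax.nn.softmax(kq_scores, axis=0)
```

With a shared projection the score matrix equals its transpose, whereas with separate key and query projections it generally does not, so one token can attend strongly to another without the reverse being true.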

__Multi-head self-attention.__
In order to increase the capacity of the self-attention layer, the transformer block includes $H$ separate self-attention operations with different parameters, in parallel.
The results of these operations are then projected down to a single $D \times N$ array again, which is required for further processing.
In particular, we have

```{margin}
As a recap of the notation in these equations: the $m$ superscript runs from $1$ to $M$ and is the index of the transformer block, the $n, n'$ and $n''$ subscripts run from $1$ to $N$ and index the tokens in the sequence within the current block, and the $h$ subscript runs from $1$ to $H$ and denotes a particular self-attention head in the block.
Finally the $k$ and $q$ subscripts are not indices, but symbols distinguishing the two different kinds of matrices $U_k$ and $U_q.$
```

$$\begin{align}
Y^{(m)} = \texttt{MHSA}(X^{(m - 1)}) &= \sum^H_{h = 1} V_h^{(m)} X^{(m - 1)} A_h^{(m)}, \text{ where } \\
\left[A^{(m)}_h\right]_{n, n'} &= \frac{\exp\left(k_{h, n}^{(m)\top} q_{h, n'}^{(m)}\right)}{\sum_{n'' = 1}^N \exp\left(k_{h, n''}^{(m)\top} q_{h, n'}^{(m)}\right)} 
\end{align}$$

where $q_{h, n}^{(m)} = U^{(m)}_{q, h} x_n^{(m-1)}$ and $k_{h, n}^{(m)} = U^{(m)}_{k, h} x_n^{(m-1)}.$
At this point we should note that, due to the nonlinearity of $A^{(m)}_h,$ together with the multiplication by $V^{(m)}_h$ and summation across $h,$ multi-head self-attention performs not just inter-feature but also intra-feature processing, i.e. each token interacts with and changes its own representation.
However, the capacity of this intra-feature processing is limited, and it is the job of the second stage, the MLP, to address this.
Let's next look at the MLP stage.
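
Before moving on, here is a compact functional sketch of the multi-head self-attention equations in plain JAX. Stacking the per-head matrices along a leading axis and using `einsum` is a choice made here for brevity, not something prescribed by the equations, and the sizes and random parameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def mhsa(X, U_k, U_q, V):
    """Multi-head self-attention on a (D, N) array of tokens.

    U_k, U_q: (H, K, D) key and query projections, one per head.
    V:        (H, D, D) output maps V_h, one per head.
    """
    k = jnp.einsum("hkd,dn->hkn", U_k, X)  # keys,    (H, K, N)
    q = jnp.einsum("hkd,dn->hkn", U_q, X)  # queries, (H, K, N)

    # scores[h, n, n'] = k_{h, n}^T q_{h, n'}
    scores = jnp.einsum("hkn,hkm->hnm", k, q)

    # Normalise over the first token index: columns of each A_h sum to one.
    A = jax.nn.softmax(scores, axis=1)  # (H, N, N)

    # Sum over heads of V_h X A_h.
    return jnp.einsum("hed,dn,hnm->em", V, X, A)


H, D, K, N = 2, 4, 3, 5  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(keys[0], (D, N))
U_k = jax.random.normal(keys[1], (H, K, D))
U_q = jax.random.normal(keys[2], (H, K, D))
V = jax.random.normal(keys[3], (H, D, D))

Y = mhsa(X, U_k, U_q, V)  # shape (D, N)
```

The output has the same $D \times N$ shape as the input, which is what allows blocks to be stacked.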

### Multi-layer perceptron
The self-attention layer has the role of aggregating information across tokens in a sequence to model joint dependencies.
In order to refine the representations themselves, a simple MLP is applied to each token in isolation, in a relatively simple step

$$x^{(m)}_n = \texttt{MLP}(y^{(m)}_n).$$

Note that this MLP is shared across all input locations, i.e. all tokens, within a given layer.
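
One way to express "applied separately to each token, with shared parameters" in JAX is to write the MLP for a single token and map it over the token axis with `jax.vmap`. The one-hidden-layer MLP below, its sizes and its initialisation are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def mlp(params, y):
    """One-hidden-layer MLP acting on a single token y of shape (D,)."""
    W1, b1, W2, b2 = params
    return W2 @ jax.nn.relu(W1 @ y + b1) + b2


D, hidden, N = 4, 8, 5  # illustrative sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (hidden, D)) / jnp.sqrt(D),
    jnp.zeros(hidden),
    jax.random.normal(k2, (D, hidden)) / jnp.sqrt(hidden),
    jnp.zeros(D),
)
Y = jax.random.normal(k3, (D, N))  # tokens stored as columns

# Map over columns (tokens); the parameters are shared, not mapped.
X_out = jax.vmap(mlp, in_axes=(None, 1), out_axes=1)(params, Y)
```

Each column of `X_out` depends only on the corresponding column of `Y`, so this step does not mix information between tokens.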

### Residuals and normalisation
Before putting together the $\texttt{MHSA}$ and $\texttt{MLP}$ operations, we will add two ubiquitous deep learning operations to improve the stability and ease of training of the model, namely residual connections and normalisation.

__Residual connections.__
Residual connections {cite}`he2015deep` are widely used across deep learning architectures, because they simplify model initialisation, stabilise learning and provide a useful inductive bias toward simpler functions {cite}`szegedy2017inception`.
Instead of specifying a mapping of the form $x^{(m)} = f(x^{(m-1)}),$ a residual connection amounts to specifying a function involving an identity function plus a residual term

$$x^{(m)} = x^{(m-1)} + g(x^{(m-1)}).$$

This can be equivalently thought of as learning to model differences between the representations at different blocks, that is $x^{(m)} - x^{(m-1)} = g(x^{(m-1)}).$
If we do not use residual connections and compose multiple blocks together, the activations in each can become more extreme as we go deeper in the network, resulting in either zero or extremely large gradients, which can be problematic during training.
One motivation for using residual connections is that, if we initialise the parameters of $g$ such that its outputs are close to zero, then $x^{(m)}$ will be approximately constant across $m = 1, \dots, M.$
This can improve training ease and stability because all blocks in the network, even the deeper ones, receive an input close to $x^{(0)},$ and the parameters will tend to receive less extreme gradients.
Residual connections are used both in the $\texttt{MHSA}$ and $\texttt{MLP}$ layers of the transformer.
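
The effect of a near-zero residual branch at initialisation can be seen in a small numerical sketch. The branch `g` below (a linear map followed by `tanh`), the weight scale and the depth are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

D, M = 4, 10  # illustrative feature size and number of blocks


def g(W, x):
    """A hypothetical residual branch: one linear map followed by tanh."""
    return jnp.tanh(W @ x)


keys = jax.random.split(jax.random.PRNGKey(0), M + 1)
x0 = jax.random.normal(keys[0], (D,))

# Small weights make the outputs of g close to zero at initialisation.
weights = [1e-3 * jax.random.normal(k, (D, D)) for k in keys[1:]]

x_res, x_plain = x0, x0
for W in weights:
    x_res = x_res + g(W, x_res)  # with a residual connection
    x_plain = g(W, x_plain)      # without a residual connection

# Relative change of the representation after M residual blocks.
drift = jnp.linalg.norm(x_res - x0) / jnp.linalg.norm(x0)
```

With the residual connection, the representation after $M$ blocks stays close to $x^{(0)}$, whereas without it the signal has all but vanished under this particular initialisation.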

__Token normalisation.__
Another ubiquitous and extremely useful operation in deep learning is normalisation.
There are various different kinds of normalisation, including LayerNorm {cite}`ba2016layer`, BatchNorm {cite}`ioffe2015batch`, GroupNorm {cite}`wu2018group` and InstanceNorm {cite}`ulyanov2016instance`.
Normalisation has been widely found to improve learning stability and overall model performance.
One reason for this is that normalisation typically prevents the inputs to a layer from becoming extremely large, which can result in extreme or saturated outputs, which in turn means that the gradients with respect to the network parameters can be close to zero or extremely large.
The transformer architecture uses LayerNorm which, when applied to the tokens, amounts to per-token normalisation.
Specifically, when applied to an array $X$ of input tokens, LayerNorm amounts to

$$\texttt{LayerNorm}(X)_{d, n} = \bar{x}_{d, n} = \frac{x_{d, n} - \mu(x_n)}{\sigma(x_n)} \gamma_d + \beta_d,$$

where $\mu$ and $\sigma$ denote operations that compute the mean and the standard deviation respectively, and $\gamma_d$ and $\beta_d$ are a learnt scale and a learnt shift.
In other words, within a transformer, LayerNorm separately normalises each token within each sequence within each batch.
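
A minimal JAX sketch of this per-token normalisation follows. The small constant `eps` is an assumption added for numerical safety and does not appear in the equation above; the input values are arbitrary.

```python
import jax
import jax.numpy as jnp


def layer_norm(X, gamma, beta, eps=1e-5):
    """Normalise each token (column) of a (D, N) array over its D features."""
    mu = X.mean(axis=0, keepdims=True)    # per-token mean, shape (1, N)
    sigma = X.std(axis=0, keepdims=True)  # per-token std,  shape (1, N)
    # eps guards against division by zero (not part of the equation).
    return (X - mu) / (sigma + eps) * gamma[:, None] + beta[:, None]


D, N = 6, 4  # illustrative sizes
X = 3.0 + 5.0 * jax.random.normal(jax.random.PRNGKey(0), (D, N))
gamma, beta = jnp.ones(D), jnp.zeros(D)  # identity scale and shift

X_bar = layer_norm(X, gamma, beta)
```

With $\gamma_d = 1$ and $\beta_d = 0$, every column of `X_bar` has mean approximately zero and standard deviation approximately one, regardless of the scale of the input.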


### Putting it together
In summary, we can collect these operations into the following equations

$$\begin{align}
\bar{X}^{(m-1)} &= \texttt{LayerNorm}\left(X^{(m-1)}\right) \\
Y^{(m)} &= X^{(m-1)} + \texttt{MHSA}\left(\bar{X}^{(m-1)}\right) \\
\bar{Y}^{(m)} &= \texttt{LayerNorm}\left(Y^{(m)}\right) \\
X^{(m)} &= Y^{(m)} + \texttt{MLP}(\bar{Y}^{(m)})
\end{align}$$

These make up the entirety of the transformer block, which is repeated $M$ times to compute the output of the transformer.
An important detail we have not discussed thus far is how to build the tokens themselves.
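
The block can be sketched end to end in a few lines of plain JAX. This sketch assumes the common pre-normalisation layout, in which each residual is added to the un-normalised input of its step; the parameter dictionary, the bias-free one-hidden-layer MLP and the initialiser `init_block` are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def layer_norm(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0, keepdims=True)
    sigma = X.std(axis=0, keepdims=True)
    return (X - mu) / (sigma + eps) * gamma[:, None] + beta[:, None]


def mhsa(X, U_k, U_q, V):
    k = jnp.einsum("hkd,dn->hkn", U_k, X)
    q = jnp.einsum("hkd,dn->hkn", U_q, X)
    A = jax.nn.softmax(jnp.einsum("hkn,hkm->hnm", k, q), axis=1)
    return jnp.einsum("hed,dn,hnm->em", V, X, A)


def mlp(X, W1, W2):
    # Token-wise: the same weights act on every column of X.
    return W2 @ jax.nn.relu(W1 @ X)


def transformer_block(p, X):
    X_bar = layer_norm(X, p["gamma1"], p["beta1"])
    Y = X + mhsa(X_bar, p["U_k"], p["U_q"], p["V"])
    Y_bar = layer_norm(Y, p["gamma2"], p["beta2"])
    return Y + mlp(Y_bar, p["W1"], p["W2"])


def init_block(key, D, K, H, hidden, scale=0.1):
    """Hypothetical initialiser: small Gaussian weights, identity LayerNorm."""
    ks = jax.random.split(key, 5)
    return {
        "U_k": scale * jax.random.normal(ks[0], (H, K, D)),
        "U_q": scale * jax.random.normal(ks[1], (H, K, D)),
        "V": scale * jax.random.normal(ks[2], (H, D, D)),
        "W1": scale * jax.random.normal(ks[3], (hidden, D)),
        "W2": scale * jax.random.normal(ks[4], (D, hidden)),
        "gamma1": jnp.ones(D), "beta1": jnp.zeros(D),
        "gamma2": jnp.ones(D), "beta2": jnp.zeros(D),
    }


D, K, H, hidden, N, M = 4, 3, 2, 8, 5, 3  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), M + 1)
blocks = [init_block(k, D, K, H, hidden) for k in keys[:M]]

X0 = jax.random.normal(keys[M], (D, N))
X = X0
for p in blocks:  # X^{(m)} = TransformerBlock(X^{(m-1)})
    X = transformer_block(p, X)
```

Because every block maps a $D \times N$ array to another $D \times N$ array, the loop over blocks needs no reshaping between steps.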


### Tokens and embeddings

__Tokenisation.__
Tokenisation is an application-specific detail but, generally, there are two main approaches, depending on whether the inputs are continuous or discrete.
As a reminder, in both cases, we want to convert each input element in our sequence, say $s_n,$ to a $D$-dimensional array $x^{(0)}_n.$
We will specify a map $\texttt{tokenise}$ that performs the operation $x^{(0)}_n = \texttt{tokenise}(s_n),$ treating separately the cases where the inputs $s_n$ are discrete or continuous.

__Discrete or continuous inputs.__
In text modelling the raw inputs are integers representing unique words or characters.
In such applications, i.e. whenever we have discrete inputs, we can use a look-up table containing learnable vectors.
That is, if $s_n \in \{1, \dots, K\},$ we can define $K$ arrays, each of length $D$, say $z_1, \dots, z_K \in \mathbb{R}^D,$ and let

$$x^{(0)}_n = \texttt{tokenise}(s_n) = z_{s_n}.$$

This allows us to map each word into a continuous space and operate on the resulting arrays with the transformer architecture.
In other applications, such as vision, the inputs are typically treated as continuous, that is $s_n \in \mathbb{R}^{D_s}.$
In such cases, we can apply a simple operation, such as a linear transformation, to map each $s_n$ into a $D$-dimensional array.
For example, letting $W \in \mathbb{R}^{D\times D_s},$ we can define

$$x^{(0)}_n = W s_n,$$

giving a $D$-dimensional token which is ready for use in the transformer.
We have now covered almost all parts of the transformer, except one final, but very important point concerning the embeddings.
Thus far, we have glossed over the fact that the transformer block has no notion of position, which is a very important issue that we look into next.
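
Both kinds of tokenisation reduce to a line or two of array code. In the JAX sketch below, the vocabulary size, the input dimension and the random parameters are illustrative assumptions, and the discrete symbols are numbered from zero to match Python indexing.

```python
import jax
import jax.numpy as jnp

D = 4  # token dimension (illustrative)
k_z, k_w, k_s = jax.random.split(jax.random.PRNGKey(0), 3)

# Discrete inputs: a look-up table with one learnable D-vector per symbol.
# Python indexing is zero-based, so the symbols here are 0, ..., K - 1.
K = 10
Z = jax.random.normal(k_z, (K, D))
s_discrete = jnp.array([3, 1, 3, 7])
X_discrete = Z[s_discrete].T  # shape (D, N) with N = 4

# Continuous inputs: a linear map W applied to each input s_n.
D_s = 6
W = jax.random.normal(k_w, (D, D_s))
S = jax.random.normal(k_s, (D_s, 5))  # five inputs stored as columns
X_continuous = W @ S  # shape (D, N) with N = 5
```

Repeated symbols, such as the two occurrences of `3` above, map to identical tokens, which is one way of seeing that tokenisation alone carries no information about position.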

__Positional embeddings.__
Specifically, the $\texttt{MHSA}$ operation, the token-wise $\texttt{MLP}$ operation, as well as $\texttt{LayerNorm}$ and residual additions are all permutation equivariant operations: permuting the tokens and applying any one of these operations gives the same result as first applying the operation and then permuting the resulting tokens.
Composing these operations retains permutation equivariance, meaning that permuting the elements of the original sequence and applying the transformer will yield exactly the same result as first applying the transformer and then permuting the resulting features.
This is undesirable because, for example in text modelling, the phrases "Arsenal beats Chelsea" and "Chelsea beats Arsenal" are composed of identical words but have opposite meanings, and we would like the resulting features produced by the transformer to reflect this.
One way to get around this issue is augmenting the tokens with information about the position of an input feature within the sequence.
For example, we could set up an additional embedding which directly maps each position to a learnable array and concatenate the result with the tokenised feature, that is

$$x^{(0)}_n = \texttt{tokenise}_1(s_n) \odot \texttt{tokenise}_2(n),$$

where $\odot$ denotes concatenation, and we have used different tokenisation functions for the sequence features and the positions.
This approach is often used in vision transformers.
Another approach is to apply sinusoidal functions with different frequencies to the position index, for example

$$\texttt{tokenise}(n) = [\sin(\omega_1 n), \dots, \sin(\omega_D n)],$$

which are then concatenated to the tokenised features as described above.
Other approaches bake in positional information directly into the $\texttt{MHSA}$ layer, for example by making the attention weights depend on the position difference of pairs of tokens.
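
Permutation equivariance, and the way positional features break it, can be checked numerically. The JAX sketch below uses a single self-attention head; the geometrically spaced frequencies, the sizes and the random parameters are assumptions made for illustration, since the choice of $\omega_d$ is left open above.

```python
import jax
import jax.numpy as jnp


def self_attention(X, U_k, U_q):
    A = jax.nn.softmax((U_k @ X).T @ (U_q @ X), axis=0)
    return X @ A


def sinusoidal(N, D_pos):
    """Features sin(omega_d * n) for n = 1, ..., N, shape (D_pos, N)."""
    # Geometrically spaced frequencies: an illustrative assumption.
    omega = 1.0 / (100.0 ** (jnp.arange(D_pos) / D_pos))
    n = jnp.arange(1, N + 1)
    return jnp.sin(omega[:, None] * n[None, :])


D, D_pos, K, N = 4, 4, 3, 5  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 5)
X = jax.random.normal(keys[0], (D, N))
perm = jnp.array([2, 0, 4, 1, 3])  # a fixed permutation of the tokens

# Without positions: permuting the inputs just permutes the outputs.
U_k = jax.random.normal(keys[1], (K, D))
U_q = jax.random.normal(keys[2], (K, D))
out = self_attention(X, U_k, U_q)
out_perm = self_attention(X[:, perm], U_k, U_q)

# With positions concatenated along the feature axis, the n-th slot always
# carries the n-th positional feature, whatever token sits in it.
P = sinusoidal(N, D_pos)
U_k_pos = jax.random.normal(keys[3], (K, D + D_pos))
U_q_pos = jax.random.normal(keys[4], (K, D + D_pos))
out_pos = self_attention(jnp.concatenate([X, P]), U_k_pos, U_q_pos)
out_pos_perm = self_attention(
    jnp.concatenate([X[:, perm], P]), U_k_pos, U_q_pos
)
```

Without positional features, `out_perm` matches `out[:, perm]` up to floating-point error. Once positions are concatenated, the outputs for the reordered sequence are in general no longer a reordering of the original outputs, so the model can distinguish the two orderings.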


## Implementation

Now that we've covered all the details, let's implement a small transformer!

In [6]:
from typing import List

import jax
import equinox as eqx
from jaxtyping import Float, Array


In [15]:
class SelfAttention(eqx.Module):

    Uk: Float[Array, "K D"]
    Uq: Float[Array, "K D"]

    def __init__(
        self,
        key: jax.random.PRNGKey,
        input_dim: int,
        projection_dim: int,
    ):

        # Set up keys for projection matrices
        key1, key2 = jax.random.split(key)

        # Initialise projection matrices Uk and Uq with a Glorot (Xavier)
        # normal initialiser from jax.nn.initializers
        init = jax.nn.initializers.glorot_normal()
        self.Uk = init(key1, (projection_dim, input_dim))
        self.Uq = init(key2, (projection_dim, input_dim))

    def self_attention_weights(
        self,
        x: Float[Array, "D N"],
    ) -> Float[Array, "N N"]:
        """
        Compute self-attention weights for tokens in a sequence

        Args:
            x: input sequence of tokens, shape (D, N)
        
        Returns:
            attention weights, shape (N, N)
        """

        # Compute the keys and queries
        k = jax.numpy.matmul(self.Uk, x)
        q = jax.numpy.matmul(self.Uq, x)

        # Compute inner product of keys and queries
        kq = jax.numpy.matmul(k.T, q)

        # Attention weights are the softmax of the inner products
        return jax.nn.softmax(kq, axis=0)
    
    def __call__(self, x: Float[Array, "D N"]) -> Float[Array, "D N"]:
        """
        Apply self-attention to a sequence of tokens

        Args:
            x: input sequence of tokens, shape (D, N)

        Returns:
            output sequence of tokens, shape (D, N)
        """

        a = self.self_attention_weights(x)
        return jax.numpy.matmul(x, a)


class MultiHeadSelfAttention(eqx.Module):
    
    self_attention_layers: List[SelfAttention]
    linear_layers: List[eqx.nn.Linear]

    def __init__(
        self,
        key: jax.random.PRNGKey,
        input_dim: int,
        projection_dim: int,
        num_heads: int,
    ):

        keys = jax.random.split(key, 2*num_heads)
        self.self_attention_layers = [
            SelfAttention(
                key=key,
                input_dim=input_dim,
                projection_dim=projection_dim,
            ) for key in keys[::2]
        ]

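        # One linear map V_h per head (no bias, matching the MHSA equation).
        # eqx.nn.Linear takes the random key as a keyword argument.
        self.linear_layers = [
            eqx.nn.Linear(input_dim, input_dim, use_bias=False, key=key)
            for key in keys[1::2]
        ]

    def __call__(self, x: Float[Array, "D N"]) -> Float[Array, "D N"]:
        """
        Apply multi-head self-attention to a sequence of tokens

        Args:
            x: input sequence of tokens, shape (D, N)

        Returns:
            output sequence of tokens, shape (D, N)
        """

        # Compute tokens for each head
        heads = [layer(x) for layer in self.self_attention_layers]

        # Apply each head's linear map to every token (column) of its output;
        # eqx.nn.Linear acts on a single vector, so we vmap over the N axis
        heads = [
            jax.vmap(linear, in_axes=1, out_axes=1)(h)
            for h, linear in zip(heads, self.linear_layers)
        ]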

        # Stack and sum across heads
        heads = jax.numpy.stack(heads, axis=0)
        heads = jax.numpy.sum(heads, axis=0)

        return heads

In [16]:
class MLP(eqx.Module):

    layers: List[eqx.nn.Linear]

    def __init__(
        self,
        key: jax.random.PRNGKey,
        num_hidden: int,
        num_layers: int,
        num_features: int,
    ):
        # Set up input and output dimensions of linear layers
        in_feats = [num_features] + [num_hidden] * num_layers
        out_feats = [num_hidden] * num_layers + [num_features]

        # Split the random key into one sub-key per linear layer; there are
        # num_layers + 1 linear layers (input, hidden and output maps)
        keys = jax.random.split(key, num_layers + 1)
        
        # Create linear layers with different random keys
        self.layers = [
            eqx.nn.Linear(
                key=key,
                in_features=in_feat,
                out_features=out_feat,
            )
            for key, in_feat, out_feat in zip(keys, in_feats, out_feats)
        ]

    def __call__(self, x: Float[Array, "D"]) -> Float[Array, "D"]:
        """
        Compute forward pass through the MLP.

        Args:
            x: input tensor of shape (in_features,)
        
        Returns:
            output tensor of shape (out_features,)
        """
        for layer in self.layers[:-1]:
            x = layer(x)
            x = jax.nn.relu(x)
        return self.layers[-1](x)

## Extensions

## Relations to other architectures

## References

```{bibliography}
:filter: docname in docnames
```