# Build your own Transformer

Created by [Gerard I. Gállego](https://www.linkedin.com/in/gerard-gallego/) for the [Postgraduate Course in Artificial Intelligence with Deep Learning](https://www.talent.upc.edu/ing/estudis/formacio/curs/310400/postgrau-artificial-intelligence-deep-learning/) ([UPC School](https://www.talent.upc.edu/ing/), 2021).

In this lab we will learn about the [Transformer](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html), a popular architecture that revolutionized Deep Learning a few years ago.

This architecture was first designed for Machine Translation. In this field, Recurrent Neural Networks (e.g. LSTMs) with attention had been the state of the art since [the introduction of the Attention mechanism](https://arxiv.org/abs/1409.0473) in 2015. The Transformer surpassed them by introducing a key idea: getting rid of recurrence entirely and relying mainly on the **(Self-)Attention** mechanism, as the title of the paper states: [Attention is All you Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html).

Today, Transformer-based architectures are used in many fields beyond Machine Translation. First, many models arose for other text-related tasks (e.g. [BERT](https://arxiv.org/abs/1810.04805) or [GPT-3](https://arxiv.org/abs/2005.14165)), but they are now also used in other fields, such as [Speech processing](https://proceedings.neurips.cc/paper/2020/hash/92d1e1eb1cd6f9fba3227870bb6d7f07-Abstract.html) and [Computer Vision](https://arxiv.org/abs/2010.11929).

Throughout this notebook, we will build our own Transformer, and you'll understand module by module how this architecture works. Once it's finished, we will train it with a dummy dataset and we will try to interpret the results.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Transformer architecture

Here we go!

In the left part of the figure below, you can see the Transformer architecture. Take a look at it carefully to get used to all this new terminology.

You can also see a breakdown of the most important module in the Transformer: the Multi-Head Attention, which is based on the Scaled Dot-Product Attention.

<p align="center">
<img src="https://lilianweng.github.io/lil-log/assets/images/transformer.png" width="1000px" alt="Zoom in to the Transformer"/>
</p>

Don't panic! ;) We will start from the simplest structure and build upon it, little by little. Let's start with the Scaled Dot-Product Attention!

### Scaled Dot-Product Attention

The first key idea to understand how the Transformer works is the Scaled Dot-Product Attention (we will call it SDPA from now on).

<p align="center">
<img src="https://paperswithcode.com/media/methods/SCALDE.png" height="300px" alt="Scaled Dot-Product Attention"/>
</p>

This is the equation:

$$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

You are probably trying to figure out what $Q$, $K$ and $V$ are, right? They are the queries, the keys and the values. A good example to understand these concepts is the one given by [@dontloo](https://stats.stackexchange.com/users/95569/dontloo) on StackExchange:

>[...]
>
>The key/value/query concepts come from retrieval systems. For example, when you type a query to search for some video on Youtube, the search engine will map your query against a set of keys (video title, description, etc.) associated with candidate videos in the database, then present you the best matched videos (values).
>
>The attention operation turns out can be thought of as a retrieval process as well, so the key/value/query concepts also apply here.
>
>[...]
>
>[[Read the full answer]](https://stats.stackexchange.com/a/424127)

These queries, keys and values can represent tokens in a sentence. Furthermore, the attention function can be computed on a set of queries simultaneously, so the queries are packed together into a matrix $Q$, and likewise the keys into $K$ and the values into $V$. Note that you need the same number of keys and values, but the number of queries can differ.

Also, why is the scaling needed? According to the paper, for high-dimensional keys (large $d_k$) the dot products grow large in magnitude, pushing the softmax into regions where it has extremely small gradients. Dividing by $\sqrt{d_k}$ counteracts this.
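
To see why the scaling helps, here is a small numerical sketch, assuming (as the paper's argument does) that the components of the queries and keys are independent with mean 0 and variance 1. Under that assumption the dot product $q \cdot k = \sum_{j=1}^{d_k} q_j k_j$ has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. The helper name `dot_product_variances` is just for illustration.

```python
import torch

def dot_product_variances(d_k, n_samples=10_000, seed=0):
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    gen = torch.Generator().manual_seed(seed)
    q = torch.randn(n_samples, d_k, generator=gen)  # entries ~ N(0, 1)
    k = torch.randn(n_samples, d_k, generator=gen)
    dots = (q * k).sum(dim=-1)                      # one dot product per row
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

for d_k in (4, 64, 512):
    raw, scaled = dot_product_variances(d_k)
    print(f"d_k={d_k:3d}  var(q.k)={raw:7.1f}  var(q.k/sqrt(d_k))={scaled:.2f}")
```

Without the scaling, large logits make the softmax output nearly one-hot, which is where its gradients vanish.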

Finally, note that the figure includes a block called "Mask (opt.)" which doesn't appear in the equation. It lets us control which keys each query can "attend" to. This will be useful, for example, to avoid attending to padding tokens, which we use to batch sentences of different lengths.
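
A minimal sketch of how such a mask works, assuming the same convention as the `scaled_dot_product` function in the next cell (a boolean mask where `True` marks positions to ignore). Filling the masked logits with $-\infty$ makes their softmax weight exactly 0, while each row still sums to 1. The toy batch below is made up for illustration. Note that if *every* key of a row is masked, the softmax of a row of $-\infty$ values is NaN, so at least one key must stay visible.

```python
import torch

torch.manual_seed(0)

# Toy batch: 2 key sequences padded to length 4 (True = padding token).
padding_mask = torch.tensor([[False, False, False, True],
                             [False, False, True, True]])  # B x T_k
logits = torch.randn(2, 3, 4)                              # B x T_q x T_k

# Add a query dimension (B x 1 x T_k) so the mask broadcasts over all queries,
# then give the padded keys a logit of -inf.
masked_logits = logits.masked_fill(padding_mask.unsqueeze(-2), float("-inf"))
weights = torch.softmax(masked_logits, dim=-1)

print(weights[1])           # the last two columns are exactly 0
print(weights.sum(dim=-1))  # every row still sums to 1
```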

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """ Computes the Scaled Dot-Product Attention

    Args:
        q (torch.FloatTensor):  Query Tensor   (... x T_q x d_q)
        k (torch.FloatTensor):  Key Tensor     (... x T_k x d_k)
        v (torch.FloatTensor):  Value Tensor   (... x T_v x d_v)
        mask (torch.BoolTensor): Attention mask (... x T_q x T_k), True = masked out

    Returns:
        torch.FloatTensor: Result of the SDPA  (... x T_q x d_v)
        torch.FloatTensor: Attention map       (... x T_q x T_k)

    """
    assert q.size(-1) == k.size(-1), "Query and Key dimensions must coincide"

    # TODO: Matrix multiplication of the queries and the keys (use torch.matmul)
    attn_logits = ...

    # TODO: Scale attn_logits (see the SDPA formula, d_k is the last dim of k)
    attn_logits = ...

    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask, -float("inf"))

    # TODO: Compute the attention weights (see the SDPA formula, use dim=-1)
    attention = ...

    output = torch.matmul(attention, v)

    return output, attention
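
Once you have completed the TODOs, one way to sanity-check your function is to compare it with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0; this sketch assumes such a version). Be careful with the mask: for a boolean `attn_mask`, the built-in function treats `True` as "take part in attention", which is the opposite of the convention used in this notebook, so the hypothetical wrapper `reference_sdpa` below inverts it. The built-in function only returns the output, not the attention map.

```python
import torch
import torch.nn.functional as F

def reference_sdpa(q, k, v, mask=None):
    """Hypothetical helper: SDPA output computed by PyTorch's built-in function.

    `mask` follows this notebook's convention (True = do NOT attend), while
    F.scaled_dot_product_attention expects True = DO attend, hence the `~`.
    """
    attn_mask = None if mask is None else ~mask
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

torch.manual_seed(0)
q_test, k_test, v_test = torch.randn(2, 5, 4), torch.randn(2, 8, 4), torch.randn(2, 8, 4)
mask_test = torch.zeros(2, 5, 8, dtype=torch.bool)
mask_test[..., -2:] = True  # hide the last two keys from every query

ref = reference_sdpa(q_test, k_test, v_test, mask_test)
print(ref.shape)  # torch.Size([2, 5, 4])

# After completing the TODOs, this comparison should print True:
# print(torch.allclose(scaled_dot_product(q_test, k_test, v_test, mask_test)[0], ref, atol=1e-5))
```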

In [None]:
def plot_attention(attention, queries, keys, xtitle="Keys", ytitle="Queries"):
    """ Plots the attention map
    
    Args:
        attention (torch.FloatTensor): Attention map (T_q x T_k)
        queries (List[str]): Labels for the queries (y-axis)
        keys (List[str]): Labels for the keys (x-axis)
        xtitle (str): Title of the x-axis
        ytitle (str): Title of the y-axis
    """

    sns.set(rc={'figure.figsize':(12, 8)})
    ax = sns.heatmap(
        attention.detach().cpu(),
        linewidth=0.5,
        xticklabels=keys,
        yticklabels=queries,
        cmap="coolwarm")

    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    ax.set_xlabel(xtitle) 
    ax.set_ylabel(ytitle)

    plt.show()

Let's create some random queries, keys and values, with the following dimensions:
- $T_Q=5$
- $T_K=T_V=8$
- $d_Q=d_K=d_V=4$

We will use them to test our SDPA function.

In [None]:
q = torch.randn(5, 4)
k = torch.randn(8, 4)
v = torch.randn(8, 4)

In [None]:
output, attention = scaled_dot_product(q, k, v)

print(f"Output:\n{output}\n{output.shape}\n")
print(f"Attention weights:\n{attention}\n{attention.shape}\n")

plot_attention(
    attention,
    [str([round(float(q__), 1) for q__ in q_]) for q_ in q],
    [str([round(float(k__), 1) for k__ in k_]) for k_ in k],
)

After computing the SDPA, we get:
- The output, of dimensions $T_Q \times d_V$
- The attention weights, which relate the queries and the keys, of dimensions $T_Q \times T_K$

But you've probably also heard about Self-Attention, right?

Basically, Self-Attention consists of using the same set of vectors as queries, keys and values. Let's try it:

In [None]:
x = torch.randn(5, 4)
output, attention = scaled_dot_product(q=x, k=x, v=x)

print(f"Output:\n{output}\n{output.shape}\n")
print(f"Attention weights:\n{attention}\n{attention.shape}\n")

plot_attention(
    attention,
    [str([round(float(x__), 1) for x__ in x_]) for x_ in x],
    [str([round(float(x__), 1) for x__ in x_]) for x_ in x],
)

But then... that's all, "Attention is all you need"?

Well, no, **attention alone is not enough**... We need learnable parameters somewhere!

In practice, the inputs of the SDPA are first projected with Linear layers (one for the queries, one for the keys and one for the values). This way we can get different representations from the same input vectors. We do something like this:

<p align="center">
<img src="https://miro.medium.com/max/1578/1*_92bnsMJy8Bl539G4v93yg.gif" height="600px" alt="Self-Attention"/>
</p>
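
One way to see what the projections add: without them, self-attention compares the vectors through $xx^T$, which is always a symmetric matrix, so token $i$ scores token $j$ exactly as $j$ scores $i$. With separate (here randomly initialized) query and key projections, the logits $(xW^Q)(xW^K)^T$ are no longer symmetric in general. A small sketch with made-up toy tensors:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_demo = torch.randn(5, 4)  # 5 tokens, 4 features (toy values)

# Plain self-attention logits: symmetric by construction.
plain_logits = x_demo @ x_demo.T

# Learnable self-attention: separate projections give different "views" of x.
proj_q, proj_k = nn.Linear(4, 4), nn.Linear(4, 4)
learned_logits = proj_q(x_demo) @ proj_k(x_demo).T

print(torch.allclose(plain_logits, plain_logits.T))      # True
print(torch.allclose(learned_logits, learned_logits.T))  # False (in general)
```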

Here you can find an implementation of this "learnable" SDPA. Be aware that this class will not be used by the model we will implement; it's just an intermediate step we use now for didactic purposes.

In [None]:
class LearnableScaledDotProductAttention(nn.Module):
    def __init__(self, embed_dim):
        super(LearnableScaledDotProductAttention, self).__init__()
        self.proj_q = nn.Linear(embed_dim, embed_dim)
        self.proj_k = nn.Linear(embed_dim, embed_dim)
        self.proj_v = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, q, k, v, mask=None):
        q = self.proj_q(q)
        k = self.proj_k(k)
        v = self.proj_v(v)
        output, _ = scaled_dot_product(q, k, v, mask)
        return output

Let's test that it can learn now, by trying to reconstruct the input tensor with self-attention:

In [None]:
sdpa = LearnableScaledDotProductAttention(embed_dim=4)
optimizer = optim.Adam(sdpa.parameters())

losses_sdpa = []
n_epochs = 10000
for i in range(n_epochs):
    optimizer.zero_grad()
    output = sdpa(q=x, k=x, v=x)    # Self-attention
    loss = F.mse_loss(output, x)    # Reconstruct the input
    loss.backward()
    optimizer.step()
    losses_sdpa.append(loss.item())
    if (i + 1) % 1000 == 0:
        print(f"Loss ({i+1}/{n_epochs}): {loss.item()}")


print(f"\nOutput:\n{output}\n")
print(f"Query:\n{x}\n")

Ok, looks good, we can train it. It's starting to make sense now, right?

### Multi-Head Attention

To further exploit this attention mechanism, the original paper introduced the Multi-Head Attention mechanism (MHA).

Instead of performing the attention mechanism just once, they found it beneficial to project the input multiple times into different "attention heads". This way, multiple attentions can be learned at the same time.

<p align="center">
<img src="https://paperswithcode.com/media/methods/multi-head-attention_l1A3G7a.png" height="400px" alt="Multi-Head Attention"/>
</p>

To combine the outputs of each head, they are concatenated and projected by a Linear transformation $W^O$, as defined by the following equation:

$$
MultiHead(Q, K, V ) = Concat(head_1, ..., head_h)W^O
$$
$$
where\ head_i = Attention(QW^Q_i, KW^K_i, V W^V_i)
$$

Note that $W^Q_i$, $W^K_i$, $W^V_i$ are the Linear projections we've seen in the previous section.
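
In code, the per-head projections $W^Q_i$ do not need separate layers: the `MultiheadAttention` class in the next cell uses one `nn.Linear(embed_dim, embed_dim)` per input and then splits its output into `num_heads` chunks of size `head_dim`. The sketch below (toy sizes, names chosen for illustration) checks that each head's chunk is exactly what a smaller projection built from the corresponding rows of the big weight matrix would produce. Since `nn.Linear` computes $xW^T + b$, the $W^Q_i$ of the equation corresponds to the transpose of the weight slice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2  # toy sizes
head_dim = embed_dim // num_heads

x_demo = torch.randn(3, embed_dim)        # 3 tokens
proj_q = nn.Linear(embed_dim, embed_dim)  # one projection shared by all heads
q_all = proj_q(x_demo)                    # 3 x embed_dim

q_heads = []
for i in range(num_heads):
    rows = slice(i * head_dim, (i + 1) * head_dim)
    W_i, b_i = proj_q.weight[rows], proj_q.bias[rows]  # head_dim x embed_dim, head_dim
    q_heads.append(x_demo @ W_i.T + b_i)               # 3 x head_dim

# Concatenating the per-head projections gives back the single big projection.
print(torch.allclose(torch.cat(q_heads, dim=-1), q_all, atol=1e-5))  # True
```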

To better understand this, see the following illustration created by [Jay Alammar](https://jalammar.github.io/) for his famous blog post [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/):

<p align="center">
<img src="https://jalammar.github.io/images/t/transformer_multi-headed_self-attention-recap.png" height="600px" alt="Self-Attention"/>
</p>

In [None]:
class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert embed_dim % num_heads == 0, \
            "Embedding dimension must be multiple of the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.proj_q = nn.Linear(embed_dim, embed_dim)
        self.proj_k = nn.Linear(embed_dim, embed_dim)
        self.proj_v = nn.Linear(embed_dim, embed_dim)
        self.proj_o = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization
        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)
        nn.init.xavier_uniform_(self.proj_o.weight)
        self.proj_q.bias.data.fill_(0)
        self.proj_k.bias.data.fill_(0)
        self.proj_v.bias.data.fill_(0)
        self.proj_o.bias.data.fill_(0)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(1)

        q = self.proj_q(q)
        k = self.proj_k(k)
        v = self.proj_v(v)

        # TODO: Split the tensors into multiple heads
        #  T x B x embed_dim -> T x B x num_heads x head_dim
        q = q.reshape(...)
        k = k.reshape(...)
        v = v.reshape(...)

        # The last two dimensions must be sequence length and the head dimension,
        # to make it work with the scaled dot-product function. 
        # TODO: Rearrange the dimensions
        # T x B x num_heads x head_dim -> B x num_heads x T x head_dim
        q = q.permute(...)
        k = k.permute(...)
        v = v.permute(...)

        # Apply the same mask to all the heads
        if mask is not None:
            mask = mask.unsqueeze(1)
 
        # TODO: Call the scaled dot-product function (remember to pass the mask!)
        output_heads, attn_w = ...

        # B x num_heads x T x head_dim -> T x B x num_heads x head_dim
        output_heads = output_heads.permute(2, 0, 1, 3)

        # T x B x num_heads x head_dim -> T x B x embed_dim
        output_cat = output_heads.reshape(-1, batch_size, self.embed_dim)
        output = self.proj_o(output_cat)

        return output, attn_w

Let's test the same dummy example as before, trying to reconstruct the input tensor with self-attention, this time with an MHA module:

In [None]:
mha = MultiheadAttention(embed_dim=4, num_heads=2)
optimizer = optim.Adam(mha.parameters())

losses_mha = []
n_epochs = 10000
for i in range(n_epochs):
    optimizer.zero_grad()
    output = mha(                # Self-attention
        q=x.unsqueeze(1),
        k=x.unsqueeze(1),
        v=x.unsqueeze(1)
    )[0].squeeze(1)
    loss = F.mse_loss(output, x) # Reconstruct input
    loss.backward()
    optimizer.step()
    losses_mha.append(loss.item())
    if (i + 1) % 1000 == 0:
        print(f"Loss ({i+1}/{n_epochs}): {loss.item()}")

print(f"\nOutput:\n{output}\n")
print(f"Query:\n{x}\n")

Ok, seems fine, it learns!

At this point, you already know all you need about attention in the Transformer :D Now, we need to know where to apply it.

But before looking again at the Transformer architecture, we need to make a quick stop at the Positional Encoding.

### Positional Encoding

As you may know, RNNs process the input tokens one after another, so the order of the tokens is built into the computation. This is very important when processing sequences such as text sentences.

However, Self-Attention has no built-in notion of order: if you shuffle the input tokens, the outputs are simply shuffled in the same way. For this reason, the Transformer authors needed to give positional information to the model explicitly.

They used a fixed (not learned) table of sinusoidal encodings, defined by the following equations, which is added to the input embeddings.

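$$
PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)
$$
$$
PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)
$$
where $pos$ is the position, $2i$ and $2i+1$ index the embedding dimensions, and $d_{model}$ is the embedding dimensionality (`embed_dim` in the code).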
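
As a worked example of the formula, here is a plain-Python sketch (the function name `sinusoidal_pe` is hypothetical) that computes the encoding of a single position. For $d_{model} = 4$ there are two frequencies, $1/10000^{0} = 1$ and $1/10000^{2/4} = 1/100$, so position 1 is encoded as $[\sin 1, \cos 1, \sin 0.01, \cos 0.01]$. The `PositionalEncoding` class computes the same frequencies in log space, which is what the last line checks.

```python
import math

def sinusoidal_pe(pos, d_model):
    """Positional encoding of one position, computed directly from the formula."""
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        pe[2 * i] = math.sin(angle)      # even dimension 2i
        pe[2 * i + 1] = math.cos(angle)  # odd dimension 2i+1
    return pe

print(sinusoidal_pe(1, 4))  # ≈ [0.841, 0.540, 0.010, 1.000]

# exp(2i * -ln(10000) / d_model), as in the class, equals 10000 ** (-2i / d_model):
print(math.isclose(math.exp(2 * -math.log(10000.0) / 4), 10000 ** (-2 / 4)))  # True
```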

The resulting positional encoding table looks like this:
<p align="center">
<img src="https://d33wubrfki0l68.cloudfront.net/ef81ee3018af6ab6f23769031f8961afcdd67c68/3358f/img/transformer_architecture_positional_encoding/positional_encoding.png" height="350px" alt="Positional Encoding"/>
</p>

For more information about the positional encoding, check [this post](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/).

In [None]:
class PositionalEncoding(nn.Module):

    def __init__(self, embed_dim, max_len=5000):
        """
        Args:
            embed_dim (int): Embedding dimensionality
            max_len (int): Maximum length of a sequence to expect
        """
        super(PositionalEncoding, self).__init__()

        # Create matrix of (T x embed_dim) representing the positional encoding
        # for max_len inputs
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)

        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x
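
The original paper motivates the sinusoidal choice with the hypothesis that it lets the model attend by relative positions, because for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$. For each frequency $\omega$, that linear function is a $2 \times 2$ rotation that depends only on $k$ and $\omega$:

$$
\begin{pmatrix} \sin((pos+k)\omega) \\ \cos((pos+k)\omega) \end{pmatrix}
=
\begin{pmatrix} \cos(k\omega) & \sin(k\omega) \\ -\sin(k\omega) & \cos(k\omega) \end{pmatrix}
\begin{pmatrix} \sin(pos\,\omega) \\ \cos(pos\,\omega) \end{pmatrix}
$$

A quick numerical check of this identity in plain Python (helper names and the example frequency are chosen for illustration):

```python
import math

def pe_pair(pos, omega):
    """(sin, cos) pair of the positional encoding at frequency omega."""
    return math.sin(pos * omega), math.cos(pos * omega)

def shift(pair, k, omega):
    """Rotate a (sin, cos) pair by k positions; depends on k and omega, not on pos."""
    s, c = pair
    return (math.cos(k * omega) * s + math.sin(k * omega) * c,
            -math.sin(k * omega) * s + math.cos(k * omega) * c)

omega = 1 / 10000 ** (2 * 1 / 8)  # example: i = 1 with d_model = 8, so omega ≈ 0.1

for pos in (0, 3, 17):
    shifted = shift(pe_pair(pos, omega), k=5, omega=omega)
    target = pe_pair(pos + 5, omega)
    print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(shifted, target)))  # True
```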

### Encoder

Now the time has come to put everything together and build the Transformer Encoder!

It consists of a stack of identical layers like the one below:

<p align="center">
<img src="https://jalammar.github.io/images/t/transformer_resideual_layer_norm_2.png" height="600px" alt="Transformer Encoder"/>
</p>

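The [Layer Normalization](https://arxiv.org/abs/1607.06450) differs from Batch Normalization in the axis it normalizes over: Batch Normalization computes statistics for each feature across the samples of a batch, while Layer Normalization computes them for each sample (here, each token) across its features, so it does not depend on the other elements of the batch. This can be seen in the following figure: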

<p align="center">
<img src="https://paperswithcode.com/media/methods/Screen_Shot_2020-05-19_at_4.24.42_PM.png" height="250px" alt="LayerNormalization"/>
</p>
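
A small sketch of the difference, using PyTorch's `nn.LayerNorm` and `nn.BatchNorm1d` without their learnable affine parameters, on a toy batch of 6 token vectors with 4 features each. After LayerNorm every row (token) has mean 0; after BatchNorm (in training mode, using batch statistics) every column (feature) has mean 0. The last check shows that LayerNorm of a token does not depend on the other tokens in the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_demo = torch.randn(6, 4) * 3 + 1  # toy batch: 6 tokens x 4 features

ln = nn.LayerNorm(4, elementwise_affine=False)  # statistics over the features of each token
bn = nn.BatchNorm1d(4, affine=False)            # statistics over the batch (training mode)

ln_out, bn_out = ln(x_demo), bn(x_demo)
print(ln_out.mean(dim=1))  # ~0 for every token (row)
print(bn_out.mean(dim=0))  # ~0 for every feature (column)

# LayerNorm by hand: each row uses its own mean and (biased) variance.
mean = x_demo.mean(dim=-1, keepdim=True)
var = x_demo.var(dim=-1, unbiased=False, keepdim=True)
print(torch.allclose((x_demo - mean) / torch.sqrt(var + ln.eps), ln_out, atol=1e-5))  # True

# LayerNorm of a token does not depend on the rest of the batch.
print(torch.allclose(ln(x_demo[:2]), ln_out[:2], atol=1e-6))  # True
```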

The Feed Forward block consists of a projection to a higher dimension (`ffn_dim`), a ReLU activation and another projection back to the original dimension (`embed_dim`). It is defined by the following equation:

$$
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
$$
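
The paper calls this block *position-wise*: the same two linear layers are applied to every token independently. In PyTorch this comes for free, because `nn.Linear` only acts on the last dimension of its input. A sketch with toy sizes (and without the dropout used in the notebook's layers):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, ffn_dim = 8, 32  # toy sizes (the paper uses 512 and 2048)

# FFN(x) = max(0, x W_1 + b_1) W_2 + b_2
ffn = nn.Sequential(
    nn.Linear(embed_dim, ffn_dim),
    nn.ReLU(),
    nn.Linear(ffn_dim, embed_dim),
)

x_demo = torch.randn(5, 2, embed_dim)  # T x B x embed_dim
out = ffn(x_demo)                      # nn.Linear acts on the last dim only

print(out.shape)  # torch.Size([5, 2, 8])
# The same weights are applied to every position independently:
print(torch.allclose(out[3, 1], ffn(x_demo[3, 1]), atol=1e-5))  # True
```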

In [None]:
class TransformerEncoderLayer(nn.Module):

    def __init__(self, embed_dim, ffn_dim, num_heads, dropout=0.0):
        """
        Args:
            embed_dim (int): Embedding dimensionality (input, output & self-attention)
            ffn_dim (int): Inner dimensionality in the FFN
            num_heads (int): Number of heads of the multi-head attention block
            dropout (float): Dropout probability
        """
        super(TransformerEncoderLayer, self).__init__()

        self.self_attn = MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_att=False):
        src_len, batch_size, _ = x.shape
        if mask is None:
            mask = torch.zeros(batch_size, src_len, dtype=torch.bool, device=x.device)

        selfattn_mask = mask.unsqueeze(-2)

        # TODO: Self-Attention block
        selfattn_out, selfattn_w = ...
        selfattn_out = self.dropout(selfattn_out)

        # TODO: Add + normalize block (1)
        x = ...

        # TODO: FFN block
        ffn_out = ...
        ffn_out = self.dropout(ffn_out)

        # TODO: Add + normalize block (2)
        x = ...

        if return_att:
            return x, selfattn_w
        else:
            return x
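
The encoder layer above expects `mask` as a `B x T` boolean tensor with `True` at padded positions, and `mask.unsqueeze(-2)` turns it into `B x 1 x T` so that it broadcasts over the queries (the `MultiheadAttention` class then adds a head dimension with `unsqueeze(1)`). The dataset used later in this notebook provides a `padding_mask` that is passed directly as this argument; if you ever need to build one from sequence lengths, a hypothetical helper could look like this:

```python
import torch

def lengths_to_padding_mask(lengths):
    """Hypothetical helper: B x T bool mask, True at padded positions."""
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # T
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)      # B x T

pad_mask_demo = lengths_to_padding_mask(torch.tensor([3, 1]))
print(pad_mask_demo.tolist())  # [[False, False, False], [False, True, True]]

# What TransformerEncoderLayer does before calling self-attention:
# B x T -> B x 1 x T, which broadcasts against the B x T_q x T_k logits.
print(pad_mask_demo.unsqueeze(-2).shape)  # torch.Size([2, 1, 3])
```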

In [None]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, embed_dim, ffn_dim, num_heads, vocab_size, dropout=0.0):
        super(TransformerEncoder, self).__init__()

        # Create an embedding table (T x B -> T x B x embed_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Create the positional encoding with the class defined before
        self.pos_enc = PositionalEncoding(embed_dim)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, ffn_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None, return_att=False):
        x = self.embedding(x)
        x = self.pos_enc(x)

        selfattn_ws = []
        for l in self.layers:
            if return_att:
                x, selfattn_w = l(x, mask=mask, return_att=True)
                selfattn_ws.append(selfattn_w)
            else:
                x = l(x, mask=mask, return_att=False)

        if return_att:
            selfattn_ws = torch.stack(selfattn_ws, dim=1)
            return x, selfattn_ws
        else:
            return x

We have our Transformer Encoder implemented now! Let's try to do a forward pass:

In [None]:
transformer_encoder_cfg = {
    "num_layers": 6,
    "embed_dim": 512,
    "ffn_dim": 2048,
    "num_heads": 8,
    "vocab_size": 8000,
    "dropout": 0.1,
}

transformer_encoder = TransformerEncoder(**transformer_encoder_cfg)

src_batch_example = torch.randint(transformer_encoder_cfg['vocab_size'], (20, 4))

encoder_out, attn_ws = transformer_encoder(src_batch_example, return_att=True)

print(f"Encoder output: {encoder_out.shape}")
print(f"Self-Attention weights: {attn_ws.shape}")

We have built a random batch ($T \times B$) containing $B=4$ sentences of length $T=20$.

The output we get is ($T \times B \times embed\_dim$) and the self-attention weights are ($B \times num\_layers \times num\_heads \times T \times T$).

### Decoder

The Decoder has a similar structure to the Encoder, but with two main differences.

First, in addition to self-attention, it needs to attend to the encoder outputs. With this purpose, it includes an Encoder-Decoder attention block between the Self-Attention and the FFN. This new module, also based on MHA, uses the encoder outputs (also known as `memory`) as the keys and values.

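Second, the self-attention of the decoder must not attend to "future" tokens, because at inference time the decoder works autoregressively (it generates one token at a time and only knows the tokens generated so far). For this reason, we use an upper-triangular (causal) mask in the self-attention.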

<p align="center">
<img src="https://jalammar.github.io/images/t/transformer_resideual_layer_norm_3.png" height="600px" alt="Transformer Encoder & Decoder"/>
</p>
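
For a target of length 4, the causal mask built with `torch.triu(..., 1)` looks as follows, with `True` (printed as 1) meaning "do not attend". The sketch also combines it with a padding mask through a logical OR, as the decoder layer does, so that a query is blocked from a key if *either* mask says so. The toy padding mask is made up for illustration.

```python
import torch

tgt_len = 4
# True above the diagonal: the query at position t cannot see keys at positions > t.
causal = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
print(causal.int().tolist())
# [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]

# Toy padding mask (B = 1): the last target token is padding.
padding = torch.tensor([[False, False, False, True]])             # B x T
selfattn_mask_demo = causal.unsqueeze(0) | padding.unsqueeze(-2)  # B x T_q x T_k
print(selfattn_mask_demo[0].int().tolist())
# [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 1]]
```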

In [None]:
class TransformerDecoderLayer(nn.Module):

    def __init__(self, embed_dim, ffn_dim, num_heads, dropout=0.0):
        """
        Args:
            embed_dim (int): Embedding dimensionality (input, output & self-attention)
            ffn_dim (int): Inner dimensionality in the FFN
            num_heads (int): Number of heads of the multi-head attention block
            dropout (float): Dropout probability
        """
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiheadAttention(embed_dim, num_heads)
        self.encdec_attn = MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, mask=None, memory_mask=None, return_att=False):
        tgt_len, batch_size, _ = x.shape
        src_len, _, _ = memory.shape
        if mask is None:
            mask = torch.zeros(x.shape[1], x.shape[0])
            mask = mask.bool().to(x.device)
        if memory_mask is None:
            memory_mask = torch.zeros(memory.shape[1], memory.shape[0])
            memory_mask = memory_mask.bool().to(memory.device)


        subsequent_mask = torch.triu(torch.ones(batch_size, tgt_len, tgt_len), 1)
        subsequent_mask = subsequent_mask.bool().to(mask.device)
        selfattn_mask = subsequent_mask | mask.unsqueeze(-2)
        
        attn_mask = memory_mask.unsqueeze(-2)

        # TODO: Self-Attention block
        selfattn_out, selfattn_w = ...
        selfattn_out = self.dropout(selfattn_out)

        # TODO: Add + normalize block (1)
        x = ...

        # TODO: Encoder-Decoder Attention block
        attn_out, attn_w = ...
        attn_out = self.dropout(attn_out)

        # TODO: Add + normalize block (2)
        x = ...

        # TODO: FFN block
        ffn_out = ...
        ffn_out = self.dropout(ffn_out)

        # TODO: Add + normalize block (3)
        x = ...

        if return_att:
            return x, selfattn_w, attn_w
        else:
            return x

In [None]:
class TransformerDecoder(nn.Module):

    def __init__(self, num_layers, embed_dim, ffn_dim, num_heads, vocab_size, dropout=0.0):
        super(TransformerDecoder, self).__init__()

        # Create an embedding table (T x B -> T x B x embed_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Create the positional encoding with the class defined before
        self.pos_enc = PositionalEncoding(embed_dim)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, ffn_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Add a projection layer (T x B x embed_dim -> T x B x vocab_size)
        self.proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, memory, mask=None, memory_mask=None, return_att=False):
        x = self.embedding(x)
        x = self.pos_enc(x)

        selfattn_ws = []
        attn_ws = []
        for l in self.layers:
            if return_att:
                x, selfattn_w, attn_w = l(
                    x, memory, mask=mask, memory_mask=memory_mask, return_att=True
                )
                selfattn_ws.append(selfattn_w)
                attn_ws.append(attn_w)
            else:
                x = l(
                    x, memory, mask=mask, memory_mask=memory_mask, return_att=False
                )

        x = self.proj(x)
        x = F.log_softmax(x, dim=-1)

        if return_att:
            selfattn_ws = torch.stack(selfattn_ws, dim=1)
            attn_ws = torch.stack(attn_ws, dim=1)
            return x, selfattn_ws, attn_ws
        else:
            return x

And now we also have our Transformer Decoder implemented! Let's try to do a forward pass:

In [None]:
transformer_decoder_cfg = {
    "num_layers": 6,
    "embed_dim": 512,
    "ffn_dim": 2048,
    "num_heads": 8,
    "vocab_size": 8000,
    "dropout": 0.1,
}

transformer_decoder = TransformerDecoder(**transformer_decoder_cfg)

tgt_batch_example = torch.randint(transformer_decoder_cfg['vocab_size'], (15, 4))

decoder_out, selfattn_ws, attn_ws  = transformer_decoder(
    tgt_batch_example,
    memory=encoder_out,
    return_att=True
)

print(f"Decoder output: {decoder_out.shape}")
print(f"Self-Attention weights: {selfattn_ws.shape}")
print(f"Enc-Dec Attention weights: {attn_ws.shape}")

We have built a random target batch ($T_{tgt} \times B$) containing $B=4$ sentences of length $T_{tgt}=15$, and we already had the encoder output of size ($T_{src} \times B \times embed\_dim$).

The output we get from the decoder is ($T_{tgt} \times B \times vocab\_size$), the self-attention weights are ($B \times num\_layers \times num\_heads \times T_{tgt} \times T_{tgt}$), and the enc-dec attention weights are ($B \times num\_layers \times num\_heads \times T_{tgt} \times T_{src}$).

### Transformer

We already have all the components of the Transformer, it's time to put them all together!

Note that we will implement two methods to generate results: one to be used during training (whole sequence in parallel) and another to be used during inference (autoregressive generation). The first one is depicted in the original Transformer figure, while the second can be seen in the animation below:

<p align="center">
<img src="https://paperswithcode.com/media/methods/new_ModalNet-21.jpg" height="600px" alt="Transformer"/>
</p>

<p align="center">
<img src="https://jalammar.github.io/images/t/transformer_decoding_2.gif" height="600px" alt="Autoregressive decoding"/>
</p>
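
The difference between the two methods can be sketched on toy tensors. During training the target is known, so the decoder receives it shifted right by one position (starting with BOS) and predicts every next token in parallel; this is the `tgt['ids'][:-1]` / `tgt['ids'][1:]` split used in the training loop later on. During inference, tokens are produced one at a time and appended to the decoder input. Here `toy_next_token_logits` is a hypothetical stand-in for the decoder, and the special-token ids are arbitrary.

```python
import torch

BOS, EOS = 0, 1  # arbitrary special-token ids for this toy example

# Training (teacher forcing): the decoder input is the target shifted right by one.
tgt_demo = torch.tensor([BOS, 5, 6, 7, EOS]).unsqueeze(1)  # T x B, with B = 1
decoder_input = tgt_demo[:-1]  # BOS 5 6 7  -> fed to the decoder
labels = tgt_demo[1:]          # 5 6 7 EOS  -> what each position must predict

# Inference (greedy decoding): feed back the tokens generated so far.
def toy_next_token_logits(prefix, vocab_size=8):
    """Hypothetical stand-in for the decoder: favours (last token + 2) mod vocab_size."""
    batch = prefix.size(1)
    logits = torch.zeros(batch, vocab_size)
    logits[torch.arange(batch), (prefix[-1] + 2) % vocab_size] = 1.0
    return logits  # B x vocab_size, scores for the next token only

generated = torch.full((1, 1), BOS, dtype=torch.long)  # T x B, starts with BOS
for _ in range(3):
    new_token = toy_next_token_logits(generated).argmax(-1)  # B
    generated = torch.cat([generated, new_token.unsqueeze(0)], dim=0)

print(generated.squeeze(1).tolist())  # [0, 2, 4, 6]
```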

In [None]:
class Transformer(nn.Module):
    def __init__(self, encoder_config, decoder_config):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder(**encoder_config)
        self.decoder = TransformerDecoder(**decoder_config)
        
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """ Forward method
        
        Method used at training time, when the target is known. The target tensor
        passed to the decoder is shifted to the right (starting with BOS
        symbol). Then, the output of the decoder starts directly with the first
        token of the sentence.
        """

        # TODO: Compute the encoder output
        encoder_out = ...

        # TODO: Compute the decoder output
        decoder_out = self.decoder(
            x=...,
            memory=...,
            mask=...,
            memory_mask=...,
        )
        
        return decoder_out

    def generate(self, src, src_mask=None, bos_idx=0, max_len=50):
        """ Generate method
        
        Method used at inference time, when the target is unknown. It
        iteratively passes to the decoder the sequence generated so far
        and appends the new token to the input again. It uses a Greedy
        decoding (argmax).
        """

        # TODO: Compute the encoder output
        encoder_out = ...

        output = torch.LongTensor([bos_idx])\
                    .expand(1, encoder_out.size(1)).to(src.device)
        for i in range(max_len):
            # TODO: Get the new token
            new_token = self.decoder(
                x=...,
                memory=...,
                memory_mask=...,
            )[-1].argmax(-1)

            output = torch.cat([output, new_token.unsqueeze(0)], dim=0)

        return output

In [None]:
transformer = Transformer(transformer_encoder_cfg, transformer_decoder_cfg)

In [None]:
transformer(src_batch_example, tgt_batch_example).shape

You got it! You have built your own Transformer!

**NOTE: Most of the modules we've implemented are available in `torch.nn`, you don't need to copy all this code the next time you want to use a Transformer ;D**
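
For reference, here is a rough equivalent using PyTorch's built-in `nn.Transformer`. This is a sketch with toy sizes: unlike the `Transformer` class above, `nn.Transformer` does not include token embeddings, positional encoding or the final projection, so it takes already-embedded tensors. With the default `batch_first=False` it uses the same `T x B x embed_dim` layout as this notebook, its `*_key_padding_mask` arguments use the same convention (`True` = ignore), and `generate_square_subsequent_mask` builds a float causal mask with $-\infty$ above the diagonal.

```python
import torch
import torch.nn as nn

d_model_demo, batch_demo = 64, 3  # toy sizes
builtin_transformer = nn.Transformer(
    d_model=d_model_demo, nhead=4,
    num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=128, dropout=0.1,
)  # batch_first=False (default): tensors are T x B x d_model, as in this notebook

# Inputs must already be embedded (no embedding table or positional encoding inside).
src_demo = torch.randn(10, batch_demo, d_model_demo)  # T_src x B x d_model
tgt_demo = torch.randn(7, batch_demo, d_model_demo)   # T_tgt x B x d_model

causal_demo = builtin_transformer.generate_square_subsequent_mask(tgt_demo.size(0))
src_padding_demo = torch.zeros(batch_demo, src_demo.size(0), dtype=torch.bool)  # True = padding

out_demo = builtin_transformer(
    src_demo, tgt_demo,
    tgt_mask=causal_demo,
    src_key_padding_mask=src_padding_demo,
    memory_key_padding_mask=src_padding_demo,
)
print(out_demo.shape)  # torch.Size([7, 3, 64])
```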

## Training your new Transformer

We will train our Transformer on a simple task: translating numbers written with digits into their written English form.

In [None]:
!pip install -q git+https://github.com/gegallego/seq2seq-numbers-dataset.git

In [None]:
from seq2seq_numbers_dataset import generate_dataset_pytorch, Seq2SeqNumbersCollater

numbers_dataset = generate_dataset_pytorch()

# Downsample the dataset to reduce training time (remove for better performance)
numbers_dataset['train'].src_sents = numbers_dataset['train'].src_sents[:25000]
numbers_dataset['train'].tgt_sents = numbers_dataset['train'].tgt_sents[:25000]

collater = Seq2SeqNumbersCollater(
    numbers_dataset['train'].src_dict,
    numbers_dataset['train'].tgt_dict,
)

### Training

In [None]:
lr = 5e-4
batch_size = 32
log_interval = 50
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

numbers_loader_train = DataLoader(
    numbers_dataset['train'],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collater,
)

src_dict = numbers_dataset['train'].src_dict
tgt_dict = numbers_dataset['train'].tgt_dict

transformer_encoder_cfg = {
    "num_layers": 3,
    "embed_dim": 256,
    "ffn_dim": 1024,
    "num_heads": 4,
    "vocab_size": len(src_dict),
    "dropout": 0.1,
}
transformer_decoder_cfg = {
    "num_layers": 3,
    "embed_dim": 256,
    "ffn_dim": 1024,
    "num_heads": 4,
    "vocab_size": len(tgt_dict),
    "dropout": 0.1,
}
model = Transformer(transformer_encoder_cfg, transformer_decoder_cfg)
model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = F.nll_loss

print("Training model...")

loss_avg = 0
for i, (src, tgt) in enumerate(numbers_loader_train):
    src = {k: v.to(device) for k, v in src.items()}
    tgt = {k: v.to(device) for k, v in tgt.items()}

    optimizer.zero_grad()

    output = model(
        src['ids'],
        tgt['ids'][:-1],
        src['padding_mask'],
        tgt['padding_mask'][:, :-1],
    )

    loss = criterion(
        output.reshape(-1, output.size(-1)),
        tgt['ids'][1:].flatten()
    )
    loss.backward()
    optimizer.step()

    loss_avg += loss.item()
    if (i+1) % log_interval == 0:
        loss_avg /= log_interval
        print(f"{i+1}/{len(numbers_loader_train)}\tLoss: {loss_avg}")
        loss_avg = 0  # reset the running sum for the next logging window

### Testing

In [None]:
batch_size_test = 128
log_interval_test = 50

numbers_loader_test = DataLoader(
    numbers_dataset['test'],
    batch_size=batch_size_test,
    shuffle=False,
    collate_fn=collater,
)

model.eval()

print("\nTesting model...")

n_correct = 0
n_total = 0
# Gradients are not needed for evaluation; disabling them saves memory and time
with torch.no_grad():
    for i, (src, tgt) in enumerate(numbers_loader_test):
        src = {k: v.to(device) for k, v in src.items()}
        tgt = {k: v.to(device) for k, v in tgt.items()}

        output = model.generate(
            src['ids'],
            src_mask=src['padding_mask'],
            bos_idx=numbers_dataset['test'].tgt_dict.bos_idx(),
        )
        # Truncate the generated sequence to the length of the reference
        output = output[:tgt['ids'].size(0)]

        # Token-level accuracy; .item() keeps the counters as plain Python numbers
        n_correct += torch.eq(tgt['ids'], output).sum().item()
        n_total += tgt['ids'].numel()
        if (i+1) % log_interval_test == 0:
            print(f"{i+1}/{len(numbers_loader_test)}")

print(f"Test Accuracy: {100 * n_correct / n_total:.2f}%")
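The test accuracy is computed over every target position (`tgt['ids'].numel()`), so padded positions count towards it as well. The sketch below shows two stricter variants: token accuracy over non-padded positions only, and sequence-level (exact-match) accuracy. It assumes a `(seq_len, batch)` layout for the ids and a `(batch, seq_len)` boolean mask in which `True` marks padding. These conventions, the helper name `masked_accuracies` and the toy tensors are hypothetical.

```python
import torch


def masked_accuracies(pred, ref, padding_mask):
    """pred, ref: (seq_len, batch) token ids; padding_mask: (batch, seq_len), True = pad."""
    valid = ~padding_mask.T                 # (seq_len, batch), True on real tokens
    correct = (pred == ref) & valid
    token_acc = correct.sum().item() / valid.sum().item()
    # A sequence counts as correct only if every real (non-padded) token matches
    seq_correct = (correct | ~valid).all(dim=0)
    seq_acc = seq_correct.float().mean().item()
    return token_acc, seq_acc


ref = torch.tensor([[1, 1], [7, 9], [8, 2], [2, 0]])
pred = torch.tensor([[1, 1], [7, 4], [8, 2], [2, 5]])
padding_mask = torch.tensor([[False, False, False, False],
                             [False, False, False, True]])

token_acc, seq_acc = masked_accuracies(pred, ref, padding_mask)
print(token_acc, seq_acc)  # 6 of 7 real tokens correct; 1 of 2 sequences fully correct
```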

### Inference

Check how the model works by selecting any number.
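`model.generate`, used in the test loop and in the cell below, is implemented earlier in the notebook. The standalone sketch below illustrates greedy autoregressive decoding, the general idea behind it. Decoding starts from a BOS token. At each step, a step function returns log-probabilities for the next token; the arg-max is appended, and the loop stops at EOS or after `max_len` steps. The toy step function is hypothetical and always favors `last token + 1`. In a real model, the decoder conditioned on the encoder output would play that role.

```python
import torch


def greedy_decode(step_fn, bos_idx, eos_idx, max_len=20):
    """Greedy autoregressive decoding for a single sequence.

    step_fn(prefix) takes a (t, 1) LongTensor of the tokens generated so far
    and returns log-probabilities of shape (vocab_size,) for the next token.
    """
    tokens = torch.tensor([[bos_idx]])  # (1, 1): sequence-first, batch of 1
    for _ in range(max_len):
        next_token = step_fn(tokens).argmax().item()
        tokens = torch.cat([tokens, torch.tensor([[next_token]])], dim=0)
        if next_token == eos_idx:
            break
    return tokens


VOCAB, BOS, EOS = 8, 1, 5  # hypothetical toy vocabulary


def toy_step(prefix):
    # Puts almost all probability on (last token + 1), capped at EOS
    scores = torch.full((VOCAB,), -10.0)
    scores[min(prefix[-1, 0].item() + 1, EOS)] = 0.0
    return torch.log_softmax(scores, dim=-1)


print(greedy_decode(toy_step, BOS, EOS).flatten().tolist())  # [1, 2, 3, 4, 5]
```

Greedy decoding keeps only the locally most likely token at each step; beam search is a common alternative. This section does not show which strategy `model.generate` uses.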

In [None]:
#@title  { run: "auto", vertical-output: true }
#@markdown Select a number to pass to the model:
input_num = 29284.3 #@param {type:"slider", min:-100000, max:100000, step:0.1}

src_dict = numbers_dataset['train'].src_dict
tgt_dict = numbers_dataset['train'].tgt_dict

input_num_str = "{:,.2f}".format(input_num)
input_num_enc = torch.LongTensor(
    src_dict.encode(input_num_str)
).unsqueeze(-1).to(device)

output_word_enc = model.generate(
    input_num_enc,
    bos_idx=tgt_dict.bos_idx()
)

output_word = tgt_dict.decode(
    output_word_enc.flatten().tolist()
)

print(f"Input: {input_num_str}")
print(f"Output: {output_word}")
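Before encoding, the input number is converted to a string with `"{:,.2f}"`. This inserts commas as thousands separators and always keeps exactly two decimals. Presumably this matches the textual form of the training inputs, which is why the slider value is passed through it. A quick check of what the format produces for a few values:

```python
values = [29284.3, -100000, 1234.5, 7]

# Comma as thousands separator, exactly two digits after the decimal point
formatted = ["{:,.2f}".format(v) for v in values]

print(formatted)  # ['29,284.30', '-100,000.00', '1,234.50', '7.00']
```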

### Attention visualization

Analyze the attention weights with the following tool. Are they what you expected?

In [None]:
#@title  { run: "auto", vertical-output: true }
#@markdown Select a sample from the dataset:
dataset_index = 0 #@param {type:"integer"}
#@markdown Select the attention to visualize:
attention = "encoder-decoder attention" #@param ["encoder-decoder attention", "encoder self-attention", "decoder self-attention"]
#@markdown Select the layer:
layer = "2" #@param ["3", "2", "1"]
#@markdown Select a head (or average them):
head = "avg" #@param ["avg", "1", "2", "3", "4"]

src, tgt = collater(
    [numbers_dataset["train"][dataset_index]]
)
src = {k: v.to(device) for k, v in src.items()}
tgt = {k: v.to(device) for k, v in tgt.items()}

src_dict = numbers_dataset['train'].src_dict
tgt_dict = numbers_dataset['train'].tgt_dict

# No gradients are needed just to inspect the attention weights
with torch.no_grad():
    enc_output, enc_selfattn = model.encoder(
        src['ids'],
        src['padding_mask'],
        return_att=True,
    )

    dec_output, dec_selfattn, encdec_attn = model.decoder(
        tgt['ids'][:-1],
        enc_output,
        tgt['padding_mask'][:, :-1],
        src['padding_mask'],
        return_att=True,
    )

if attention=="encoder-decoder attention":
    attention_w = encdec_attn
    queries = [tgt_dict[i] for i in dec_output.argmax(-1)[:,0].tolist()]
    keys = [src_dict[i] for i in src['ids'][:, 0].tolist()]
    ytitle = "Output tokens"

elif attention=="encoder self-attention":
    attention_w = enc_selfattn
    queries = [src_dict[i] for i in src['ids'][:, 0].tolist()]
    keys = [src_dict[i] for i in src['ids'][:, 0].tolist()]
    ytitle = "Input tokens"  # queries are source tokens in this case

elif attention=="decoder self-attention":
    attention_w = dec_selfattn
    queries = [tgt_dict[i] for i in dec_output.argmax(-1)[:,0].tolist()]
    keys = [tgt_dict[i] for i in tgt['ids'][:-1, 0].tolist()]
    ytitle = "Output tokens"

if head == "avg":
    attention_w = attention_w[0][int(layer)-1].mean(0)
else:
    attention_w = attention_w[0][int(layer)-1][int(head)-1]

plot_attention(
    attention_w,
    queries,
    keys,
    ytitle=ytitle,
)
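Each plotted attention map is a matrix whose rows (one per query) are probability distributions over the keys, because they are produced by a softmax. An average of such matrices has the same property, so averaging over heads (`head = "avg"`) still yields a valid attention map. The sketch below checks this with toy scaled dot-product attention, $\mathrm{softmax}(QK^\top / \sqrt{d_k})$, computed independently for each head. The shapes and random inputs are assumptions for illustration and do not come from the notebook's model.

```python
import math

import torch

torch.manual_seed(0)
n_heads, n_queries, n_keys, d_k = 4, 5, 7, 16

# Random per-head queries and keys: (heads, length, d_k)
q = torch.randn(n_heads, n_queries, d_k)
k = torch.randn(n_heads, n_keys, d_k)

# Scaled dot-product attention weights for every head: (heads, queries, keys)
weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)

# Average over heads, as done when head == "avg"
avg_weights = weights.mean(0)

print(avg_weights.shape)    # torch.Size([5, 7])
print(avg_weights.sum(-1))  # every row sums to 1 (up to float rounding)
```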


## References

The images are from:
- https://jalammar.github.io/illustrated-transformer/
- https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
- https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
- https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
- https://paperswithcode.com/method/layer-normalization
- https://paperswithcode.com/media/methods/multi-head-attention_l1A3G7a.png
- https://paperswithcode.com/method/scaled

The code is partially inspired by:
- https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
- https://nlp.seas.harvard.edu/2018/04/03/attention.html