# **Neural Machine Translation**
This notebook allows you to train and interactively use a neural network model for sequence-to-sequence translation of German sentences into English. The IWSLT14 German-English dataset is used to train the model and to evaluate its performance.

The translation quality of the trained model is measured with the BLEU score. The method presented here is a state-of-the-art model for German-to-English translation on this benchmark.

The same code can be used to train models for any other language pair. However, some language pairs are harder to learn and require a more complex network architecture and a larger training dataset. In such cases a single GPU is not enough: training an English-to-German model, for example, typically requires around 8 GPUs.

> ***Note***: There will be no mathematical expressions included in this notebook to make it easy to comprehend and follow.



##**1. Evolution of Machine Translation**
The [timeline below](https://towardsdatascience.com/evolution-of-machine-translation-5524f1c88b25) shows a concise history of machine translation. However, since only Neural Machine Translation (NMT) is within the scope of this notebook, we will not go into detail about the machine translation approaches that preceded the adoption of neural networks. 

The introduction of neural networks was a significant step beyond Statistical Machine Translation (SMT). NMT has shown better results than SMT and is widely considered the future of machine translation.
SMT uses statistical methods to represent translation patterns, whereas NMT, as the name suggests, trains a deep neural network to learn a statistical model of the translation.

<center><img src=https://miro.medium.com/max/1400/1*XuR_iuPOuY-8i5A3cGmcBw.png width="700"/></center>


There exist [various types of neural architectures](http://www.cse.iitd.ac.in/~mausam/courses/col772/spring2018/lectures/21-seq2seq.pdf) that handle different combinations of sequence inputs and outputs, e.g., CNNs or BiLSTM acceptors that map a sequence to a single decision, Siamese networks that map two sequences to a single decision, and BiLSTM transducers that map a sequence to another sequence of the same length. For machine translation, however, the model must produce an output sequence whose length can differ from that of the input sequence.

The following [timeline](https://medium.com/@bgg/seq2seq-pay-attention-to-self-attention-part-1-d332e85e9aad) highlights the most important milestones in NMT, which we will briefly review and explain.

<center><img src=https://miro.medium.com/max/1400/1*C_ERqLXFW0MZ8IUuPbR19A.png width="700"/></center>

###**[1.1. Recurrent Neural Networks](https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/supervised_learning/recurrent_neural_networks)**
Among the first neural methods to be widely used, and now considered less capable, are multi-layered Recurrent Neural Networks (RNNs). This structure can process sequential input and output data, as it has a **notion of time or sequence**. An RNN can be unrolled over time into a feedforward network whose steps share the same weights. Because the same parameters are used at every time step, the model learns one general transformation that applies to all steps. The hidden state computed at step $i$ is passed on to step $i+1$, which makes each output dependent on all previous elements in the sequence. 
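To make the idea of shared weights concrete, here is a minimal sketch of an unrolled vanilla RNN forward pass in plain Python. For readability it uses scalar inputs and a scalar hidden state, and the parameters `w_x`, `w_h` and `b` are hypothetical values; real RNNs use vectors and weight matrices, but the structure is the same: one set of parameters reused at every time step, with the hidden state carried from step $i$ to step $i+1$.

```python
import math

def rnn_forward(inputs, w_x, w_h, b, h0=0.0):
    """Unrolled vanilla RNN with a scalar hidden state.

    The same parameters (w_x, w_h, b) are applied at every time step,
    and the hidden state of step i is fed into step i + 1.
    """
    h = h0
    hidden_states = []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h + b)
        hidden_states.append(h)
    return hidden_states

# Hypothetical parameters and a 3-step input sequence
states = rnn_forward([1.0, 0.5, -0.3], w_x=0.8, w_h=0.5, b=0.0)
print([round(s, 3) for s in states])
```

Changing only the first input also changes the last hidden state, which is exactly the dependency on all previous elements described above.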

*   **Recurrent Neural Networks**  
Because the same weights are multiplied over and over again during backpropagation through time, vanilla RNNs suffer from vanishing and exploding gradients when learning long-term dependencies. The latter is easier to solve, using gradient clipping. In addition, RNNs are slow to train, because their sequential processing over time cannot be parallelized.

*   **Long Short-Term Memory**  
The vanishing gradient problem of vanilla RNNs was alleviated by Long Short-Term Memory (LSTM) modules, whose additive cell-state updates play a role similar to the skip connections of a ResNet. LSTMs can retain information from much earlier time steps and thus help maintain long-term dependencies. The additional cell state carries information through the unit and provides a highway along which the gradient can flow.
LSTMs are therefore more powerful learners of sequential data than simple RNNs.  
A tutorial on a simple LSTM model that, like this notebook, uses the fairseq toolkit can be found [here](https://fairseq.readthedocs.io/en/latest/tutorial_simple_lstm.html#).

*   **Gated Recurrent Unit**  
Gated Recurrent Unit (GRU) cells are a simpler and more computationally efficient variant of the LSTM module: they use two gates instead of three and have no separate cell state.
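The effect of repeatedly multiplying by the same weight, and the gradient clipping remedy mentioned above, can be illustrated with a small sketch. It uses a scalar recurrent weight as a stand-in for the recurrent weight matrix (a simplification: in a real network the relevant quantity is a product of Jacobian matrices) and implements clipping by global norm, which is one common variant.

```python
import math

def gradient_factor(w_h, steps):
    """Factor by which a gradient is scaled after flowing back `steps`
    time steps through a linear recurrence with scalar weight w_h."""
    return w_h ** steps

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradients so that their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads]
    return list(grads)

print(gradient_factor(0.5, 50))   # shrinks towards zero: vanishing gradient
print(gradient_factor(1.5, 50))   # grows very large: exploding gradient
print(clip_by_global_norm([30.0, 40.0], max_norm=5.0))  # rescaled to norm 5
```

Clipping keeps the direction of the gradient but bounds its size, which is why it helps against exploding gradients but not against vanishing ones. In PyTorch the corresponding utility is `torch.nn.utils.clip_grad_norm_`, and fairseq exposes it through its `--clip-norm` training option.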

The cells of the different types of Recurrent Neural Networks can be seen [here](http://dprogrammer.org/rnn-lstm-gru):
<center><img src=http://dprogrammer.org/wp-content/uploads/2019/04/RNN-vs-LSTM-vs-GRU.png width="1000"/></center>

The [illustration below](https://www.analyticsvidhya.com/blog/2019/06/understanding-transformers-nlp-state-of-the-art-models/) shows an example of an unrolled sequence-to-sequence RNN used for translation. Such a simple encoder-decoder model does not perform well on sentences longer than a few words. A more in-depth explanation of the model, as well as various ways to improve it, can be found [here](https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/supervised_learning/recurrent_neural_networks/machine-translation-using-rnn).

<center><img src=https://cdn.analyticsvidhya.com/wp-content/uploads/2019/06/seq2seq.gif width="600"/></center>

###**[1.2. Attention Models](https://medium.com/@gautam.karmakar/attention-for-neural-connectionist-machine-translation-b833d1e085a3)**
One big disadvantage of this RNN encoder-decoder is that all the information perceived by the encoder is stored in a single intermediate (encoder) vector, which then serves as the initial hidden state of the decoder. As a result, RNNs cannot produce good translations of long sentences.  
Attention enables neural machine translation to handle long source sentences. To achieve this, we no longer build just one encoder vector from the encoder's last hidden state. Instead, we keep the encoder hidden states of all source words. As can be seen [below](https://medium.com/@umerfarooq_26378/neural-machine-translation-with-code-68c425044bbd), at each decoding step attention computes a context vector, a weighted sum of all encoder hidden states whose weights (the attention scores) are produced by a learned scoring function, and passes it to the decoder.
By weighting the elements of the source sentence differently and incorporating information from earlier inputs, the model captures much more information from the encoder side.    
Today, all state-of-the-art architectures use some form of attention, which has proven very effective for sequence-to-sequence tasks. Attention also makes models more interpretable, because the attention weights show which source words the model focused on.
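A minimal sketch of one decoding step of encoder-decoder attention, assuming simple dot-product scoring (Luong-style; other scoring functions, such as Bahdanau's additive attention, exist). The encoder hidden states and the decoder state are made-up 2-dimensional vectors.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_context(decoder_state, encoder_states):
    """Return (attention weights, context vector) for one decoding step."""
    # Score each source position by its dot product with the decoder state
    scores = [sum(d * e for d, e in zip(decoder_state, h)) for h in encoder_states]
    weights = softmax(scores)
    # Context vector = attention-weighted sum of all encoder hidden states
    dim = len(encoder_states[0])
    context = [sum(w * h[i] for w, h in zip(weights, encoder_states)) for i in range(dim)]
    return weights, context

# Hypothetical encoder hidden states for a 3-word source sentence
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, context = attention_context([2.0, 0.0], encoder_states)
print([round(w, 3) for w in weights], [round(c, 3) for c in context])
```

A new context vector is computed at every decoding step from the current decoder state, so each target word can focus on different source words.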

<center><img src=https://miro.medium.com/max/2100/1*75Jb0q3sX1GDYmJSfl-gOw.gif width="600"/></center>

###**[1.3. Convolutional Neural Networks](https://engineering.fb.com/ml-applications/a-novel-approach-to-neural-machine-translation/)**

A convolution slides a small filter, called the kernel, over its input and computes a weighted sum at each position. Convolutional layers were mostly used with image data, e.g., for image classification, object detection and semantic segmentation.
However, Convolutional Neural Networks (CNNs) were later also adopted in place of RNNs for translation, by treating a sentence like a 1D image. CNN-based models not only improve translation performance thanks to their hierarchical processing of information, but are also parallelizable, which makes them computationally more efficient. 
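To illustrate treating a sentence like a 1D image, here is a sketch of a 1D convolution with a kernel of width 3 sliding over a sequence of word embeddings. The embeddings and kernel weights are hypothetical, there is a single output channel, and zero padding keeps the output as long as the input. Unlike in an RNN, every output position depends only on a local window, so all positions can be computed independently.

```python
def conv1d(embeddings, kernel, bias=0.0):
    """Single-output-channel 1D convolution over a sequence of embeddings.

    embeddings: list of T vectors of size d
    kernel: list of k vectors of size d (k odd); zero padding keeps length T
    """
    k = len(kernel)
    pad = k // 2
    d = len(embeddings[0])
    padded = [[0.0] * d] * pad + embeddings + [[0.0] * d] * pad
    outputs = []
    for t in range(len(embeddings)):
        window = padded[t:t + k]
        # Each output depends only on its local window, so positions are independent
        outputs.append(bias + sum(w_i * x_i
                                  for w_vec, x_vec in zip(kernel, window)
                                  for w_i, x_i in zip(w_vec, x_vec)))
    return outputs

sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]  # 4 hypothetical word embeddings
kernel = [[0.5, 0.0], [1.0, 1.0], [0.0, 0.5]]               # width-3 filter
print(conv1d(sentence, kernel))
```

Stacking several such layers widens the context each output sees, which is the hierarchical processing mentioned above.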

The [following illustration](https://engineering.fb.com/ml-applications/a-novel-approach-to-neural-machine-translation/) shows the architecture designed by the Facebook AI Research team, which owes its improved results to a multi-hop attention mechanism. Here, the encoder CNN processes all the words of the source sentence simultaneously. The decoder then outputs the translated words one after another using CNNs and attention.
<center><img src=https://engineering.fb.com/wp-content/uploads/2017/05/translation_illustration.gif?w=640 width="700"/></center>


##**2. Self-Attention and Transformers** 
To understand how the model architecture we are using works, we must first understand what Transformers are and how they work. But let's start with self-attention, since Transformers rely entirely on this mechanism.

###**2.1. From Attention to Self-Attention**

We have already encountered encoder-decoder attention. The disadvantages of this attention mechanism in RNN-based models are that it cannot be parallelized and that it ignores the relationships between words within the source sentence and within the target sentence. So, how is self-attention different?

The first important thing to know is that the encoding and decoding blocks are usually stacks of encoders and decoders, respectively, with the same number of layers in each ([see illustration](http://jalammar.github.io/illustrated-transformer/)).
<center><img src=http://jalammar.github.io/images/t/The_transformer_encoder_decoder_stack.png width="500"/></center> 

Let's take a look at [this example](http://jalammar.github.io/illustrated-transformer/) for a second. We, as humans, can easily see that the word "it" in this sentence refers to "animal" and not to "street". An algorithm, however, needs a little help, and that's where self-attention comes in. The self-attention mechanism encodes each input word better by enriching its representation with information from the other positions in the input sequence.

<center><img src=http://jalammar.github.io/images/t/transformer_self-attention_visualization.png width="400"/></center>

Self-attention is a special case of attention in which the queries, keys and values all come from the same sequence, namely the output of the previous encoder layer or decoder layer, respectively. On the encoder side, self-attention is computed between each input token and all other input tokens, whereas on the decoder side, the attention scores are calculated between an output word and the previously produced output words.
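The following sketch computes single-head scaled dot-product self-attention over a short sequence in plain Python. For brevity it uses the input vectors directly as queries, keys and values (a real Transformer first applies learned projection matrices), and the `causal` flag reproduces the decoder-side behaviour in which each position attends only to itself and earlier positions.

```python
import math

def self_attention(xs, causal=False):
    """Single-head scaled dot-product self-attention.

    Simplification: queries, keys and values are the inputs themselves
    (no learned projection matrices).
    """
    d = len(xs[0])
    outputs, all_weights = [], []
    for i, q in enumerate(xs):
        scores = []
        for j, k in enumerate(xs):
            if causal and j > i:
                scores.append(float("-inf"))  # future position: masked out
            else:
                scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d))
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights = [e / total for e in exps]
        outputs.append([sum(w * v[c] for w, v in zip(weights, xs)) for c in range(d)])
        all_weights.append(weights)
    return outputs, all_weights

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical token embeddings
_, enc_weights = self_attention(tokens)               # encoder: every token sees every token
_, dec_weights = self_attention(tokens, causal=True)  # decoder: no access to future tokens
print([round(w, 2) for w in dec_weights[0]])
```

Because the scores of all positions can be computed at once, self-attention avoids the step-by-step recurrence of RNNs.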

###**2.2. From Self-Attention to Transformers**

The encoder–decoder architecture that is currently dominating in many NLP tasks is the [Transformer](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) model [displayed below](https://medium.com/inside-machine-learning/what-is-a-transformer-d07dd1fbec04).

<center><img src=https://miro.medium.com/max/1400/1*BHzGVskWGS_3jEcYYi6miQ.png width="400"/></center>

As we mentioned before, the encoder block consists of several encoders with identical structure. Each encoder is composed of a self-attention layer followed by a feed-forward neural network, together with residual connections and normalization layers. Note that the individual encoders do not share weights.  
The decoder is built similarly, with a self-attention layer, a feed-forward neural network, residual connections and normalization layers. One difference is that the decoder's self-attention layer may only attend to previously generated target tokens, which is enforced by masking future positions.  
In addition, each decoder layer contains an encoder-decoder attention layer that attends to the output of the last encoder layer.  
In the architecture figure above, the attention blocks are labeled Multi-Head Attention: attention is computed several times in parallel by independent heads, each with its own learned projections, and the outputs of the heads are concatenated and projected back to the model dimension.  
After the last decoder layer, the output vectors must be converted into words. A linear layer first projects each vector to a much larger logits vector that contains one score for each word in the output vocabulary. The Softmax layer then turns these scores into probabilities, and, with greedy decoding, the output word for that time step is the word with the highest probability.
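The final projection and softmax can be sketched as follows. The vocabulary, the decoder output vector and the weight matrix are hypothetical toy values; a real model uses a vocabulary of tens of thousands of (sub)words, and the argmax shown here corresponds to greedy decoding (in practice, beam search, which keeps several candidates, is commonly used instead).

```python
import math

vocab = ["<eos>", "the", "house", "is", "small"]   # toy output vocabulary
decoder_output = [0.2, -1.0, 0.5]                   # final decoder vector (d_model = 3)
# Linear layer: one weight row per vocabulary entry (|V| x d_model), plus a bias
W = [[0.1, 0.0, 0.0],
     [0.5, -0.2, 0.3],
     [1.0, 0.1, 2.0],
     [0.0, 0.4, -0.1],
     [-0.3, 0.2, 0.1]]
b = [0.0, 0.1, 0.0, 0.0, 0.0]

# Logits: one score per vocabulary word
logits = [sum(w * x for w, x in zip(row, decoder_output)) + bias for row, bias in zip(W, b)]
# Softmax turns the scores into a probability distribution
m = max(logits)
exps = [math.exp(z - m) for z in logits]
total = sum(exps)
probs = [e / total for e in exps]
# Greedy choice: the most probable word
best = max(range(len(vocab)), key=lambda i: probs[i])
print(vocab[best], round(probs[best], 3))
```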


##**3. Joint Source-Target Self Attention with Locality Constraints** 
This notebook is based on the work of José A. R. Fonollosa, Noe Casas, and Marta R. Costa-jussà in ["Joint Source-Target Self Attention with Locality Constraints"](http://arxiv.org/abs/1905.06596) and uses their code, available at https://github.com/jarfo/joint.

> ***Note***: This takes around 5 hours to train on Google Colab using the provided GPU.



###**3.1. Model Architecture**###
The following [model architecture](https://arxiv.org/abs/1905.06596), like the title of the paper, may seem a little complicated at first sight. It combines the ideas of two papers: ["*Layer-Wise Coordination between Encoder and Decoder for Neural Machine Translation*"](https://pdfs.semanticscholar.org/005c/d149aa86b1be8a22706c8d29095bbf46d192.pdf?_ga=2.75612565.1449333594.1596232318-210385577.1596232318) and ["*Pay Less Attention with Lightweight and Dynamic Convolutions*"](https://arxiv.org/pdf/1901.10430.pdf). Let's now go through this architecture step by step. 

<center><img src=https://d3i71xaburhd42.cloudfront.net/b0f0a5a21619d70748a4dc007983cc111f1b301e/2-Figure1-1.png width="700"/></center>

As input to the network, we provide the source and target sequences concatenated together. The idea behind this is to train the model like a language model. You might ask: but what is a language model?  
Most of us use a language model every day without realizing it: smartphone keyboards suggest the next word based on what has been typed so far. A language model computes the probability of the next word given the previous ones, and the keyboard suggests the most probable words, as can be [seen below](http://jalammar.github.io/illustrated-word2vec/).  
The advantage of this approach is that the model can learn joint source-target representations from the early layers onwards.

<center><img src=http://jalammar.github.io/images/word2vec/language_model_blackbox_output_vector.png width="600"/></center>
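The keyboard analogy can be made concrete with the simplest possible language model, a count-based bigram model that estimates the probability of the next word from the previous word only. This is purely an illustration of the concept with a made-up corpus; the model in this notebook is a neural network that conditions on the whole preceding sequence.

```python
from collections import Counter, defaultdict

def train_bigram_lm(sentences):
    """Count-based bigram language model: P(next | previous)."""
    counts = defaultdict(Counter)
    for sentence in sentences:
        words = ["<s>"] + sentence.split()
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
    return counts

def next_word_probs(counts, prev):
    """Relative frequencies of the words that followed `prev` in training."""
    total = sum(counts[prev].values())
    return {w: c / total for w, c in counts[prev].items()}

corpus = ["thank you very much", "thank you so much", "thank you very much indeed"]
lm = train_bigram_lm(corpus)
probs = next_word_probs(lm, "you")
print(max(probs, key=probs.get), probs)  # the "keyboard suggestion" and its distribution
```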

The next question you might ask yourself is: How do we ensure this language model-like training?  
The answer is to combine the encoder and decoder into a single block, so that we no longer use an independent encoder. 
A layer-wise coordination of the Transformer results in two modifications. First, each decoder layer now attends to its corresponding encoder layer, whereas in the original Transformer every decoder layer attends to the last encoder layer. Second, to ensure that the encoder and decoder outputs of the same level are also at the same semantic level, the parameters of the attention and feed-forward layers are shared between the encoder and the decoder.

As a consequence of the merging of the encoder and decoder, we no longer need an encoder-decoder attention mechanism to extract information from the source sequence. Instead, we use a so-called mix-attention on the decoder side, which is a combination of both the encoder-decoder attention and self-attention. This mechanism allows target tokens from the decoder to attend to all source tokens as well as the preceding target tokens ([see figure](https://pdfs.semanticscholar.org/005c/d149aa86b1be8a22706c8d29095bbf46d192.pdf?_ga=2.75612565.1449333594.1596232318-210385577.1596232318)).

<center><img src=https://d3i71xaburhd42.cloudfront.net/b12ccd118974839db290f15c989649b2b5188636/4-Figure1-1.png width="600"/></center>
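The attention pattern described above can be written as a boolean visibility mask over the concatenated source-target sequence. The sketch below builds it for a hypothetical sentence pair with `src_len` source tokens followed by `tgt_len` target tokens (`True` means "may attend"). It only illustrates the visibility pattern and is not the authors' implementation.

```python
def mix_attention_mask(src_len, tgt_len):
    """Visibility mask for a concatenated [source; target] sequence.

    Row i is the attending position, column j the attended position.
    - source tokens see every source token (bidirectional, like an encoder)
    - target tokens see every source token and the target tokens up to themselves
    """
    n = src_len + tgt_len
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if j < src_len:
                mask[i][j] = True          # every token may look at the source
            elif i >= src_len and j <= i:
                mask[i][j] = True          # target tokens: causal over the target
    return mask

for row in mix_attention_mask(src_len=3, tgt_len=2):
    print("".join("x" if v else "." for v in row))
```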

Because Transformers are not based on recurrent operations like RNNs, and because the joint representation and the layer-wise coordinated Transformer share weights between source and target, new problems arise. 
How will the network be able to remember the word ordering in the sentence and differentiate the source language (German tokens) from the target ones (English tokens)?  
The first problem is solved by giving the model explicit information about the position of each word in the sequence. This is known as a position embedding and is computed with sinusoidal functions. Here, a resettable position embedding is used: positions start at zero for the source tokens and are reset to zero at the first token of the target sequence.  
However, the position embedding alone cannot tell which language a token belongs to. For this reason, a second embedding is used: unlike the positional embedding, this source/target embedding is learned end-to-end during training.
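A small sketch of the two extra inputs described above for a concatenated sequence: resettable position indices and a source/target indicator that would be looked up in a learned embedding table. The sinusoidal formula is the one from the original Transformer paper; the function names and the 0/1 language ids are illustrative assumptions, not the authors' code.

```python
import math

def positions_and_language_ids(src_len, tgt_len):
    """Positions restart at 0 at the first target token; ids mark source (0) vs target (1)."""
    positions = list(range(src_len)) + list(range(tgt_len))
    lang_ids = [0] * src_len + [1] * tgt_len
    return positions, lang_ids

def sinusoidal_embedding(pos, dim):
    """Sinusoidal position embedding (Vaswani et al., 2017) for one position."""
    emb = []
    for i in range(0, dim, 2):
        angle = pos / (10000 ** (i / dim))
        emb.append(math.sin(angle))
        emb.append(math.cos(angle))
    return emb[:dim]

positions, lang_ids = positions_and_language_ids(src_len=4, tgt_len=3)
print(positions)   # [0, 1, 2, 3, 0, 1, 2]
print(lang_ids)    # [0, 0, 0, 0, 1, 1, 1]
# Assumption: the model sums token, position and source/target embeddings,
# as is usual for Transformer input embeddings.
```

The first source token and the first target token receive the same position embedding, so only the learned source/target embedding distinguishes them.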

If you look closely at the model architecture figure at the beginning of this section, you can see that its self-attention and mixed-attention blocks differ from the figure above. This is because we also impose locality constraints on the receptive field of the self-attention layers. Such a locality constraint enhances the ability to capture useful local context by attending only to a reduced number of tokens in the vicinity of the current token. On the source side, the restricted receptive field is centered on each token, whereas on the target side only previous tokens are attended to. This is implemented by masking, similar to the decoder of the Transformer, which also ignores future target tokens. The band, whose width is given by the specified receptive field size, grows after each attention layer.
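The locality constraint can be expressed as a band mask over a single sequence. This sketch assumes an odd receptive field of `width` tokens that is centered on each position on the source side and covers the current token plus the `width - 1` preceding ones on the target side; `local_band_mask` is a hypothetical helper, and the authors' implementation may handle details such as the source/target boundary differently.

```python
def local_band_mask(n, width, causal):
    """Boolean mask (True = may attend) restricting attention to a local band.

    causal=False: window of `width` tokens centered on each position (source side)
    causal=True:  current position plus the width - 1 previous ones (target side)
    """
    half = width // 2
    mask = []
    for i in range(n):
        if causal:
            row = [i - width < j <= i for j in range(n)]
        else:
            row = [abs(i - j) <= half for j in range(n)]
        mask.append(row)
    return mask

def show(mask):
    return ["".join("x" if v else "." for v in row) for row in mask]

print("\n".join(show(local_band_mask(5, width=3, causal=False))))
print()
print("\n".join(show(local_band_mask(5, width=3, causal=True))))
```

Positions outside the band receive a score of $-\infty$ before the softmax, exactly like masked future tokens.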

During inference, the model generates the translation one token at a time, from left to right. This works like a language model that takes the source sequence as its starting point.
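The left-to-right generation loop can be sketched as greedy decoding. Here `toy_model` is a hypothetical stand-in for the trained network (a fixed lookup table keyed by the last generated token), and fairseq uses beam search by default, so this only illustrates the principle of feeding each generated token back in.

```python
def greedy_decode(source, next_token_probs, eos="</s>", max_len=20):
    """Generate target tokens one at a time, feeding back previous outputs."""
    target = []
    for _ in range(max_len):
        probs = next_token_probs(source, target)   # distribution over the next token
        token = max(probs, key=probs.get)           # pick the most probable one
        if token == eos:
            break
        target.append(token)
    return target

# Hypothetical "model": a lookup table keyed by the last generated token
TABLE = {
    None: {"the": 0.7, "a": 0.3},
    "the": {"house": 0.6, "home": 0.4},
    "house": {"is": 0.8, "</s>": 0.2},
    "is": {"small": 0.9, "</s>": 0.1},
    "small": {"</s>": 0.95, ".": 0.05},
}

def toy_model(source, target):
    return TABLE[target[-1] if target else None]

print(greedy_decode("das haus ist klein", toy_model))
```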



###**3.2.   Requirements**###

For model building, training and translation we use fairseq, a sequence-to-sequence modeling toolkit implemented in PyTorch by Facebook AI Research. It allows you to build and train custom models for translation, among other tasks. 

  1.   Check the Python version. Required: **Python >= 3.6**.
  2.   Check the PyTorch version. Required: **PyTorch >= 1.4.0**.
  3.   Install **fairseq >= 0.6.2**.
  4.   A **GPU** is needed for reasonably fast training. Turn on the GPU on Google Colab: `Edit > Notebook settings > Hardware accelerator: GPU`. For efficient multi-GPU training, you can also install NVIDIA's **NCCL** library.

In [1]:
!python --version

import torch
print(torch.__version__)

# Install fairseq from source in the same location as the notebook
%cd "/content/"
!git clone https://github.com/pytorch/fairseq
%cd fairseq
!pip install --editable .

Python 3.6.9
1.6.0+cu101
/content
Cloning into 'fairseq'...
remote: Enumerating objects: 31, done.
remote: Counting objects: 100% (31/31), done.
remote: Compressing objects: 100% (21/21), done.
remote: Total 16937 (delta 14), reused 22 (delta 10), pack-reused 16906
Receiving objects: 100% (16937/16937), 7.82 MiB | 26.97 MiB/s, done.
Resolving deltas: 100% (12482/12482), done.
/content/fairseq
Obtaining file:///content/fairseq
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
    Preparing wheel metadata ... done
Collecting sacrebleu
  Downloading https://files.pythonhosted.org/packages/23/d3/be980ad7cda7c4bbfa97ee3de062fb3014fc1a34d6dd5b82d7b92f8d6522/sacrebleu-1.4.13-py3-none-any.whl (43kB)
     |████████████████████████████████| 51kB 2.9MB/s 
Collecting portalocker
  Downloading https://files.pythonhosted.org/packages/89/a6/3814b7107e07

###**3.3. Dataset Download and Preparation**

As mentioned above, we train our model on the standard IWSLT14 German-English benchmark dataset, which contains $160$K training sentence pairs. This relatively small dataset was chosen because of the limited processing power available. 

An important step before training on any possibly unstructured data, such as text, is pre-processing. Fairseq already provides a pre-processing script, `prepare-iwslt14.sh`, for the IWSLT14 German-English corpus. 
The authors modified this script to extend the vocabulary to $31$K joint source-target BPE tokens.

This pre-processing file includes:

*   **Text Cleaning**  
This step includes converting the text to lower case, removing empty lines and redundant space characters and only accepting sentences that are within a limited sentence length. It is realized with the [`lowercase.perl`](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/lowercase.perl) and [`clean-corpus-n.perl`](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/training/clean-corpus-n.perl) scripts from the [Moses](https://github.com/moses-smt/mosesdecoder) toolkit.

*   **Tokenization**  
The Moses [`tokenizer.perl`](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl) separates punctuation from words, with the exception of special cases such as URLs or dates. A text string is thus segmented into a list of tokens (mostly words); for example, tokenizing the sentence `'Hello World!'` returns `['Hello', 'World', '!']`.

*   **Byte Pair Encoding (BPE)**  
A second type of tokenization, BPE, is applied afterwards. It is one of several subword tokenization methods.  
In any language, some words occur much more often than others. We therefore do not want rare words to get their own identifiers, mainly because a larger vocabulary slows down the model. The idea of all subword tokenizers is to decompose less frequent words into subword units that still carry meaning. For example, a subword tokenizer might split the word "loving" into "lov" and "ing", and the word "loved" into "lov" and "ed". BPE learns such units by starting from single characters and repeatedly merging the most frequent pair of adjacent symbols in the training data. As a result, the model has a reasonably sized vocabulary and can still represent words it has never seen before.  
For more information about the main tokenization algorithms, you can check out this [post](https://mlexplained.com/2019/11/06/a-deep-dive-into-the-wonderful-world-of-preprocessing-in-nlp/). The repository utilized for this step can also be found [here](https://github.com/rsennrich/subword-nmt.git).
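The core of BPE learning, as popularized by Sennrich et al. and implemented in subword-nmt, is a loop that repeatedly merges the most frequent pair of adjacent symbols. The following is a simplified sketch on a toy word-frequency dictionary; it is not the subword-nmt code and omits its optimizations, such as incremental updates of the pair counts.

```python
from collections import Counter

def count_pairs(vocab):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in vocab.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[a, b] += freq
    return pairs

def merge_pair(pair, vocab):
    """Replace every occurrence of `pair` by the concatenated symbol."""
    merged = {}
    for symbols, freq in vocab.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

def learn_bpe(word_freqs, num_merges):
    # Start from single characters, plus an end-of-word marker
    vocab = {tuple(word) + ("</w>",): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        pairs = count_pairs(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair (ties: first seen)
        vocab = merge_pair(best, vocab)
        merges.append(best)
    return merges, vocab

merges, vocab = learn_bpe({"loving": 5, "loved": 3, "loves": 2, "sing": 4}, num_merges=5)
print(merges)
print(list(vocab))
```

With these five merges, "loving" is segmented as `lov` + `ing</w>`, and applying the same merges in order to the unseen word "dancing" yields `d a n c ing</w>`, so the frequent suffix is still recognized.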

We first download the IWSLT14 German-English pre-processing script from the GitHub repository of the paper's authors and then run it. To better understand the individual pre-processing steps, you may want to take a closer look at the `prepare-iwslt14-31K.sh` file.

In [2]:
%cd /content/

# Download pre-processing script
!wget https://raw.githubusercontent.com/jarfo/joint/master/examples/prepare-iwslt14-31K.sh

# Dataset download and preparation 
!bash /content/prepare-iwslt14-31K.sh

/content
--2020-08-03 20:52:24--  https://raw.githubusercontent.com/jarfo/joint/master/examples/prepare-iwslt14-31K.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2970 (2.9K) [text/plain]
Saving to: ‘prepare-iwslt14-31K.sh’


2020-08-03 20:52:24 (52.1 MB/s) - ‘prepare-iwslt14-31K.sh’ saved [2970/2970]

Cloning Moses github repository (for tokenization scripts)...
Cloning into 'mosesdecoder'...
remote: Enumerating objects: 58, done.
remote: Counting objects: 100% (58/58), done.
remote: Compressing objects: 100% (48/48), done.
remote: Total 147572 (delta 29), reused 21 (delta 9), pack-reused 147514
Receiving objects: 100% (147572/147572), 129.76 MiB | 22.13 MiB/s, done.
Resolving deltas: 100% (114014/114014), done.
Cloning Subword NMT repository (fo

###**3.4.   Dataset Binarization**

In this next step, the `fairseq-preprocess` command-line tool builds the vocabulary and converts the previously pre-processed training data into a binary format. With `--joined-dictionary`, a single vocabulary of all words and subwords appearing in the source and target data is built, and each sentence is converted into a sequence of integers by giving each vocabulary entry a unique number. Naturally, the sentences have different lengths. Since the tensors fed to the network must have the same length within a mini-batch, shorter sequences are padded with a special padding symbol (`<pad>`) when the batches are formed during training.
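A minimal sketch of what binarization and batching amount to: building a vocabulary, mapping each token to an integer id and padding a batch to a common length. The special symbols and their ids mirror fairseq's default dictionary (`<s>`=0, `<pad>`=1, `</s>`=2, `<unk>`=3), but the code itself is only an illustration, not fairseq's implementation (which, for example, also sorts the vocabulary by frequency).

```python
SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]  # fairseq's default special-symbol order

def build_vocab(sentences):
    """Assign a unique integer id to every token, after the special symbols."""
    vocab = {sym: i for i, sym in enumerate(SPECIALS)}
    for sentence in sentences:
        for token in sentence.split():
            vocab.setdefault(token, len(vocab))
    return vocab

def encode(sentence, vocab):
    """Map tokens to ids (unknown tokens -> <unk>) and append end-of-sentence."""
    unk = vocab["<unk>"]
    return [vocab.get(tok, unk) for tok in sentence.split()] + [vocab["</s>"]]

def pad_batch(sequences, pad_id):
    """Right-pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_id] * (max_len - len(s)) for s in sequences]

train = ["das haus ist klein", "das haus ist sehr klein"]
vocab = build_vocab(train)
batch = pad_batch([encode(s, vocab) for s in ["das haus", "das haus ist klein", "ein haus"]],
                  pad_id=vocab["<pad>"])
for row in batch:
    print(row)
```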

The [following figure](https://towardsdatascience.com/nlp-preparing-text-for-deep-learning-model-using-tensorflow2-461428138657) shows the different steps a text sequence goes through during pre-processing.
<center><img src=https://miro.medium.com/max/1218/1*zsIXWoN0_CE9PXzmY3tIjQ.png width="500"/></center>


The binarized data used later for model training is written to `data-bin/iwslt14.joined-dictionary.31K.de-en`.

In [3]:
# Dataset binarization
TEXT='iwslt14.tokenized.31K.de-en'
!fairseq-preprocess --joined-dictionary --source-lang de --target-lang en \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/iwslt14.joined-dictionary.31K.de-en

2020-08-03 20:54:40 | INFO | fairseq_cli.preprocess | Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, bf16=False, bpe=None, checkpoint_suffix='', cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='data-bin/iwslt14.joined-dictionary.31K.de-en', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=True, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer='nag', padding_factor=8, profile=False, quantization_config_path=None, seed=None, source_lang='de', srcdict=None, target_lang='en', task='translation', tensorboard_logdir='', testpref='iwslt14.tokenized.31K.de-en/test', tgtdict=None, threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=Fals

###**3.5.   Model Building**

As mentioned at the beginning, this notebook relies on the functions and models already implemented in the fairseq toolkit to build our neural machine translation model. Therefore, most of the code provided here either modifies already implemented modules to achieve the desired behaviour or wraps existing classes to adapt them to the new model architecture.

In [4]:
%cd "/content/fairseq"

/content/fairseq


####**3.5.1. Build the Multihead Attention Model**####
The fairseq toolkit already implements multi-head attention in `fairseq.modules.multihead_attention`. However, since we want to add locality constraints to the receptive field of the self-attention layers, we must adapt this implementation. The self-attention layer has to mask both the tokens that lie outside a specific band and the padding symbols used to bring the sequences of a batch to the same length. This is achieved by setting the corresponding attention scores to $-\infty$ before the softmax.  
However, when combined with padding masking, such local attention masking can produce attention rows in which every entry is $-\infty$, for which the softmax is undefined. This adapted version detects and corrects this situation.
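Why fully masked rows are a problem, and one way to guard against them, is shown in the sketch below. A softmax over a row in which every score is $-\infty$ evaluates to NaN, which would then propagate through the network. The guard used here (returning all-zero weights for such rows) is an assumption for illustration; the `ProtectedMultiheadAttention` class below may handle the case differently.

```python
import math

NEG_INF = float("-inf")

def naive_softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # -inf - (-inf) is NaN
    total = sum(exps)
    return [e / total for e in exps]

def masked_softmax(scores, allowed):
    """Softmax over `scores`, with disallowed positions set to -inf first."""
    masked = [s if ok else NEG_INF for s, ok in zip(scores, allowed)]
    m = max(masked)
    if m == NEG_INF:
        # Every position is masked (e.g. band mask combined with padding mask):
        # return zero weights instead of the NaNs a plain softmax would give.
        return [0.0] * len(scores)
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

print(masked_softmax([0.5, 1.0, 2.0], [True, True, False]))
print(masked_softmax([0.5, 1.0, 2.0], [False, False, False]))
print(naive_softmax([NEG_INF, NEG_INF]))
```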

For a better understanding of the math behind multi-head attention, I recommend reading through the ["*Attention Is All You Need*"](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) paper.

In [5]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from fairseq.incremental_decoding_utils import with_incremental_state

from fairseq import utils


@with_incremental_state
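class ProtectedMultiheadAttention(nn.Module):
    """Multi-head attention adapted from fairseq's MultiheadAttention.

    Supports the local (band) masks used for the locality constraints and
    is meant to detect and correct attention rows that become fully masked
    (all -inf) when band masks are combined with padding masks.
    """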

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False, attn_mask=None):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()
                ).type_as(attn_weights)
            else:
                attn_weights = attn_weights.float().masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
            all_inf = torch.isinf(attn_weights).all(dim=-1)
            if all_inf.any():
                attn_weights = attn_weights.float().masked_fill(
                    all_inf.unsqueeze(-1),
                    0,
                ).type_as(attn_weights)  # FP16 support: cast to float and back


        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return self.get_incremental_state(
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        return self.set_incremental_state(
            incremental_state,
            'attn_state',
            buffer,
        )
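One detail of `forward()` above is easy to miss. If every key of a query is masked, that query's row of scores is entirely `-inf`, and the softmax returns NaN. This can happen, for example, when padding and an attention mask together exclude all keys. The sketch below uses made-up scores and a simplified single-head `(batch, tgt_len, src_len)` layout. It shows the problem and the guard that `forward()` applies: rows that are entirely `-inf` are reset to zero before the softmax, so they receive uniform weights instead of NaN.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Made-up scores with a simplified layout: (batch, tgt_len, src_len) = (2, 1, 3)
attn_weights = torch.randn(2, 1, 3)
# True marks a padded key: sentence 0 has one padded key, sentence 1 is all padding
key_padding_mask = torch.tensor([[False, False, True],
                                 [True, True, True]])

masked = attn_weights.masked_fill(key_padding_mask.unsqueeze(1), float('-inf'))
naive = F.softmax(masked, dim=-1)  # the row of sentence 1 is NaN (softmax over only -inf)

# Guard as in forward(): rows that are entirely -inf are reset to 0 before the softmax
all_inf = torch.isinf(masked).all(dim=-1)
safe = F.softmax(masked.masked_fill(all_inf.unsqueeze(-1), 0), dim=-1)
print(safe)  # sentence 0 ignores its padded key; sentence 1 gets uniform weights of 1/3
```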


#### **3.5.2. Adapt some Functions** ####

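We wrap a few PyTorch building blocks in small helper functions that also set their initialization: `Embedding` (maps each word of the vocabulary to a vector of real numbers), `LanguageEmbedding` (a learned vector that is added to every position of a sentence), `LayerNorm` (the normalization layer) and `Linear` (a linear layer with Xavier-initialized weights and zero bias).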



In [6]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options
from fairseq import utils


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def LanguageEmbedding(embedding_dim):
    m = nn.Parameter(torch.Tensor(embedding_dim))
    nn.init.normal_(m, mean=0, std=embedding_dim ** -0.5)
    return m


def LayerNorm(embedding_dim):
    m = nn.LayerNorm(embedding_dim)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m
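The initialization in `Embedding()` pairs with the `embed_scale = math.sqrt(embed_dim)` factor that the encoder and decoder below apply to the token embeddings. Weights drawn with standard deviation `embed_dim ** -0.5` and then multiplied by `sqrt(embed_dim)` have roughly unit standard deviation, while the padding row stays exactly zero. The following quick check uses illustrative sizes, not the values used for training:

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim, padding_idx = 1000, 512, 1  # illustrative sizes only

# Same initialization as the Embedding() helper above
emb = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
nn.init.constant_(emb.weight[padding_idx], 0)

tokens = torch.tensor([[5, 42, 7, padding_idx]])  # the last position is padding
x = math.sqrt(embed_dim) * emb(tokens)            # embed_scale, as in the encoder

scaled_std = x[0, :3].std().item()    # close to 1 for the real tokens
pad_norm = x[0, 3].abs().sum().item()  # exactly 0 for the padding token
print(scaled_std, pad_norm)
```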

#### **3.5.3. Build the Encoder Block**

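As discussed in the model architecture section, the joint attention model does not use an independent encoder: the source sentence is processed by the same layers as the target, inside the decoder. The encoder is therefore only responsible for computing the source embeddings (token, positional and, optionally, language embeddings) and the padding mask. This mask marks which positions of the padded, same-length source batch are padding.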
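The padding mask can be illustrated in isolation. The sketch below assumes a pad index of 1, which is the index fairseq's default `Dictionary` reserves for `<pad>`; in the notebook the value comes from `embed_tokens.padding_idx`. The token ids are made up. Sentences of different lengths are padded to a common length so they fit in one `(batch, src_len)` tensor, and the mask is computed exactly as in `JointAttentionEncoder.forward()`:

```python
import torch

pad = 1  # assumed pad index (fairseq's default Dictionary uses 1 for <pad>)

# three hypothetical, already-numericalized source sentences of different lengths
sentences = [[4, 9, 12, 2], [7, 2], [5, 6, 8, 10, 2]]
max_len = max(len(s) for s in sentences)

# right-pad every sentence to the same length (the mask works the same with left padding)
src_tokens = torch.tensor([s + [pad] * (max_len - len(s)) for s in sentences])

# True where the position is padding, as in JointAttentionEncoder.forward()
encoder_padding_mask = src_tokens.eq(pad)
src_lengths = (~encoder_padding_mask).sum(dim=1)
print(encoder_padding_mask)
```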


In [7]:
from fairseq.models import FairseqEncoder
from fairseq.modules import PositionalEmbedding


class JointAttentionEncoder(FairseqEncoder):
    """
    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        left_pad (bool): whether the input is left-padded
    """

    def __init__(self, args, dictionary, embed_tokens, left_pad):
        super().__init__(dictionary)
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None
        self.embed_language = LanguageEmbedding(embed_dim) if args.language_embeddings else None

        self.register_buffer('version', torch.Tensor([2]))

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            dict:
                - **encoder_out** (Tensor): embedding output of shape
                  `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        # language embedding
        if self.embed_language is not None:
            lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1)
            x += lang_emb
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)
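`reorder_encoder_out()` matters during beam search, where each source sentence is expanded into several hypotheses. The sketch below uses random tensors and an order vector of the kind a beam search builds, repeating each batch index `beam` times. It applies the same two `index_select` calls as the method above: the batch dimension is dim 1 of the `T x B x C` output and dim 0 of the `B x T` padding mask.

```python
import torch

torch.manual_seed(0)
src_len, bsz, embed_dim, beam = 5, 2, 4, 3
encoder_out = {
    'encoder_out': torch.randn(src_len, bsz, embed_dim),                 # T x B x C
    'encoder_padding_mask': torch.zeros(bsz, src_len, dtype=torch.bool),  # B x T
}

# Each sentence is repeated once per beam hypothesis: [0, 0, 0, 1, 1, 1]
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam).view(-1)

# Same calls as reorder_encoder_out()
reordered = {
    'encoder_out': encoder_out['encoder_out'].index_select(1, new_order),
    'encoder_padding_mask': encoder_out['encoder_padding_mask'].index_select(0, new_order),
}
print(reordered['encoder_out'].shape)  # (src_len, bsz * beam, embed_dim)
```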


#### **3.5.4. Adapt the Decoder** ####

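We start from the Transformer decoder layer implemented in fairseq (`fairseq/models/transformer.py`) and adapt it to use the `ProtectedMultiheadAttention` module implemented above. Each sublayer (self-attention, the optional encoder attention and the feed-forward network) is wrapped with dropout, a residual connection and layer normalization. Depending on `args.decoder_normalize_before`, the normalization is applied either after the residual connection, as in the original Transformer paper, or before the sublayer.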
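The two orderings handled by `maybe_layer_norm()` can be compared with a small standalone wrapper. In this sketch, `residual_block` is a hypothetical helper, not part of fairseq, and a `Linear` layer stands in for the attention or feed-forward sublayer. Dropout is disabled so the outputs can be checked exactly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 8
norm = nn.LayerNorm(embed_dim)
sublayer = nn.Linear(embed_dim, embed_dim)  # stand-in for attention or the FFN


def residual_block(x, sublayer, norm, normalize_before, dropout=0.0, training=False):
    """Wrap a sublayer with dropout, a residual connection and layer normalization."""
    residual = x
    if normalize_before:       # tensor2tensor style: layernorm -> sublayer -> dropout -> add
        x = norm(x)
    x = sublayer(x)
    x = F.dropout(x, p=dropout, training=training)
    x = residual + x
    if not normalize_before:   # original paper: sublayer -> dropout -> add -> layernorm
        x = norm(x)
    return x


x = torch.randn(3, 2, embed_dim)  # (seq_len, batch, embed_dim)
post = residual_block(x, sublayer, norm, normalize_before=False)
pre = residual_block(x, sublayer, norm, normalize_before=True)
```

With `normalize_before=False` the output is `norm(x + sublayer(x))`. With `normalize_before=True` the output is `x + sublayer(norm(x))`, so the residual path itself is never normalized inside the block.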

In [8]:
class ProtectedTransformerDecoderLayer(nn.Module):
    """Decoder layer block.
    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.
    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = ProtectedMultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = ProtectedMultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            tuple: the layer output of shape `(seq_len, batch, embed_dim)` and the
            encoder-attention weights (``None`` when there is no encoder attention)
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


#### **3.5.5. Build the Decoder Block**
We will now build the decoder block, which stacks `args.decoder_layers` layers of the class `ProtectedTransformerDecoderLayer`.
Here, we implement incremental decoding with the `FairseqIncrementalDecoder` interface. It is preferred over the plain `FairseqDecoder` interface because it makes generation faster at inference time. At each step, only the most recently generated token is fed to the decoder. The states computed for previous time steps, such as the attention keys and values, are cached in the `incremental_state` argument and reused.

For a detailed explanation of this choice of decoder, refer to this [tutorial](https://fairseq.readthedocs.io/en/latest/tutorial_simple_lstm.html#making-generation-faster).
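The caching idea can be checked with a toy example. The following sketch uses a single attention head with random, bias-free projection matrices; it is not fairseq's implementation. It feeds one position at a time, stores the keys and values under `'prev_key'` and `'prev_value'`, mirroring the names used in `ProtectedMultiheadAttention`, and confirms that the result matches a full pass that uses a future mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tgt_len, dim = 5, 4
x = torch.randn(tgt_len, dim)  # one sequence of target states
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))


def full_causal_attention(x):
    """Training-style pass: all positions at once, future positions masked with -inf."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.t() / dim ** 0.5
    future = torch.triu(torch.ones(len(x), len(x), dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float('-inf'))
    return F.softmax(scores, dim=-1) @ v


def incremental_step(x_t, state):
    """Inference-style step: only the newest position is fed; keys/values are cached."""
    q, k, v = x_t @ w_q, x_t @ w_k, x_t @ w_v  # each of shape (1, dim)
    if 'prev_key' in state:
        k = torch.cat([state['prev_key'], k], dim=0)
        v = torch.cat([state['prev_value'], v], dim=0)
    state['prev_key'], state['prev_value'] = k, v
    scores = q @ k.t() / dim ** 0.5  # no mask needed: the cache only holds past positions
    return F.softmax(scores, dim=-1) @ v


state = {}
incremental = torch.cat([incremental_step(x[t:t + 1], state) for t in range(tgt_len)], dim=0)
```

Each incremental step projects a single position instead of recomputing the whole prefix, which is where the speed-up at inference time comes from.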

In [9]:
from fairseq.models import FairseqIncrementalDecoder

class JointAttentionDecoder(FairseqIncrementalDecoder):
    """
    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``False``
    """

    def __init__(self, args, dictionary, embed_tokens, left_pad=False, final_norm=True):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.kernel_size_list = args.kernel_size_list

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None

        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, padding_idx,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.embed_language = LanguageEmbedding(embed_dim) if args.language_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            ProtectedTransformerDecoderLayer(args, no_encoder_attn=True)
            for _ in range(args.decoder_layers)
        ])

        self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
            if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict): output of the encoder, containing the source
                embeddings (``'encoder_out'``) and the source padding mask
                (``'encoder_padding_mask'``)
            incremental_state (dict, optional): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`
                - a dict with the attention weights (``'attn'``) and the
                  intermediate states of all layers (``'inner_states'``)
        """
        tgt_len = prev_output_tokens.size(1)

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        # language embedding
        if self.embed_language is not None:
            lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1)
            x += lang_emb

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]
        source = encoder_out['encoder_out']
        process_source = incremental_state is None or len(incremental_state) == 0

        # extended padding mask
        source_padding_mask = encoder_out['encoder_padding_mask']
        if source_padding_mask is not None:
            target_padding_mask = source_padding_mask.new_zeros((source_padding_mask.size(0), tgt_len))
            self_attn_padding_mask = torch.cat((source_padding_mask, target_padding_mask), dim=1)
        else:
            self_attn_padding_mask = None

        # transformer layers
        for i, layer in enumerate(self.layers):

            if self.kernel_size_list is not None:
                target_mask = self.local_mask(x, self.kernel_size_list[i], causal=True, tgt_len=tgt_len)
            elif incremental_state is None:
                target_mask = self.buffered_future_mask(x)
            else:
                target_mask = None

            if target_mask is not None:
                zero_mask = target_mask.new_zeros((target_mask.size(0), source.size(0)))
                self_attn_mask = torch.cat((zero_mask, target_mask), dim=1)
            else:
                self_attn_mask = None

            state = incremental_state
            if process_source:
                if state is None:
                    state = {}
                if self.kernel_size_list is not None:
                    source_mask = self.local_mask(source, self.kernel_size_list[i], causal=False)
                else:
                    source_mask = None
                source, attn = layer(
                    source,
                    None,
                    None,
                    state,
                    self_attn_mask=source_mask,
                    self_attn_padding_mask=source_padding_mask
                )
                inner_states.append(source)

            x, attn = layer(
                x,
                None,
                None,
                state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

        pred = x
        info = {'attn': attn, 'inner_states': inner_states}

        return pred, info

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Cached future mask."""
        dim = tensor.size(0)
        #pylint: disable=access-member-before-definition, attribute-defined-outside-init
        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def local_mask(self, tensor, kernel_size, causal, tgt_len=None):
        """Locality constraint mask."""
        rows = tensor.size(0)
        cols = tensor.size(0) if tgt_len is None else tgt_len
        if causal:
            if rows == 1:
                mask = utils.fill_with_neg_inf(tensor.new(1, cols))
                mask[0, -kernel_size:] = 0
                return mask
            else:
                diag_u, diag_l = 1, kernel_size
        else:
            diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \
                else (kernel_size // 2, kernel_size // 2 + 1)
        mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)), diag_u)
        mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)), -diag_l)

        return mask1 + mask2
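The masks built in `forward()` can be reproduced on small sizes. In the sketch below, `future_mask` and `causal_local_mask` are hypothetical standalone versions of `buffered_future_mask()` and of `local_mask(..., causal=True)`. Each target query may attend to every source position, which is the zero block concatenated in front, and to the current and earlier target positions. With a kernel size `k`, the local variant further restricts a target query to its `k` most recent target positions.

```python
import torch


def future_mask(n):
    """Upper-triangular -inf mask: position i may not look at positions j > i."""
    return torch.triu(torch.full((n, n), float('-inf')), diagonal=1)


def causal_local_mask(n, kernel_size):
    """Position i may only look at positions j in [i - kernel_size + 1, i]."""
    upper = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
    lower = torch.tril(torch.full((n, n), float('-inf')), diagonal=-kernel_size)
    return upper + lower  # the two -inf regions are disjoint


src_len, tgt_len = 3, 4
target_mask = future_mask(tgt_len)
# Target queries see every source position (zeros) plus past target positions
self_attn_mask = torch.cat((torch.zeros(tgt_len, src_len), target_mask), dim=1)

local = causal_local_mask(tgt_len, kernel_size=2)
print((self_attn_mask == 0).int())
print((local == 0).int())
```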


#### **3.5.6. Register the Encoder-Decoder Model**

Now that the encoder and decoder are defined, we can extend fairseq with our neural network model as a [plug-in](https://fairseq.readthedocs.io/en/latest/overview.html). We do this by registering it with the function decorator `register_model()` provided by fairseq. This is required in order to use the fairseq command-line tools to train the model and later evaluate its performance by computing its BLEU score.
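Conceptually, registration stores the class in a lookup table keyed by name, so that the command-line tools can find the model from a string. The sketch below is a simplified, hypothetical version of this registry pattern. `toy_register_model` and `TOY_MODEL_REGISTRY` are illustrative names and are not fairseq's code, which performs additional checks.

```python
TOY_MODEL_REGISTRY = {}


def toy_register_model(name):
    """Decorator that stores a model class under `name` (simplified illustration)."""
    def decorator(cls):
        if name in TOY_MODEL_REGISTRY:
            raise ValueError(f'Cannot register duplicate model ({name})')
        TOY_MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged
    return decorator


@toy_register_model('toy_joint_attention')
class ToyJointAttentionModel:
    pass


# A second registration under the same name is rejected
try:
    @toy_register_model('toy_joint_attention')
    class AnotherToyModel:
        pass
except ValueError as err:
    print(err)
```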

The class [`BaseFairseqModel`](https://fairseq.readthedocs.io/en/latest/models.html#fairseq.models.BaseFairseqModel) serves as the base class for all fairseq models. Since we are dealing with a sequence-to-sequence model, we subclass the [`FairseqEncoderDecoderModel`](https://fairseq.readthedocs.io/en/latest/models.html#fairseq.models.FairseqEncoderDecoderModel) interface, which derives from it, and implement the two methods `add_args()` and `build_model()`.  
On the one hand, `add_args()` extends the command-line parser with new model-specific arguments. In this case, these include the dropout probabilities, the kernel sizes and the embedding dimensions.
On the other hand, `build_model()` creates the embeddings, initializes the encoder and decoder, and returns a `JointAttentionModel` instance.
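The sketch below shows the `add_args()` pattern with plain `argparse`. It assumes, as we understand fairseq to do, that model-specific arguments are added to an argument group whose default is `argparse.SUPPRESS`. Under that setting, options the user does not pass are left off the namespace entirely, so the architecture functions can fill them in later. `parse_int_list` is a hypothetical stand-in for `options.eval_str_list(x, int)`, and only three of the model's arguments are reproduced.

```python
import argparse
import ast


def parse_int_list(value):
    """Hypothetical stand-in for options.eval_str_list(x, int): '[3, 5]' -> [3, 5]."""
    parsed = ast.literal_eval(value)
    return [int(v) for v in (parsed if isinstance(parsed, (list, tuple)) else [parsed])]


def add_toy_args(parser):
    """A few of the arguments from JointAttentionModel.add_args()."""
    parser.add_argument('--decoder-layers', type=int, metavar='N', help='num layers')
    parser.add_argument('--kernel-size-list', type=parse_int_list,
                        help='list of kernel size (default: None)')
    parser.add_argument('--language-embeddings', action='store_true',
                        help='use language embeddings')


parser = argparse.ArgumentParser()
# SUPPRESS: arguments that are not given on the command line are not set at all
group = parser.add_argument_group('Model-specific configuration',
                                  argument_default=argparse.SUPPRESS)
add_toy_args(group)

args = parser.parse_args(['--decoder-layers', '4', '--kernel-size-list', '[3, 5, 7, 9]'])
print(args)  # language_embeddings is absent because it was not passed
```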


In [10]:
from fairseq.models import (FairseqEncoderDecoderModel, register_model)

@register_model('joint_attention')
class JointAttentionModel(FairseqEncoderDecoderModel):
    """
    Args:
        encoder (JointAttentionEncoder): the encoder
        decoder (JointAttentionDecoder): the decoder

    The joint source-target model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.joint_attention_parser
        :prog:
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num layers')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='embedding dimension for FFN')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num attention heads')
        parser.add_argument('--kernel-size-list', type=lambda x: options.eval_str_list(x, int),
                            help='list of kernel size (default: None)')
        parser.add_argument('--language-embeddings', action='store_true',
                            help='use language embeddings')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    'The joint_attention model requires --encoder-embed-dim to match --decoder-embed-dim')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = JointAttentionEncoder(args, src_dict, encoder_embed_tokens, left_pad=args.left_pad_source)
        decoder = JointAttentionDecoder(args, tgt_dict, decoder_embed_tokens, left_pad=args.left_pad_target)
        return JointAttentionModel(encoder, decoder)


#### **3.5.7. Register the Model Architecture**

Having registered the new model plug-in in the previous step, we can now define the desired architecture and register it as well. In the `register_model_architecture()` function decorator, the first argument is the name of the model registered above, i.e., `'joint_attention'`, and the second argument is the name of the new model architecture. The decorated function takes a single argument, *args*, and modifies it in place: every architecture parameter that the user did not set on the command line is filled in with the architecture's default value.

As explained before, the local joint attention model we are building here is an extension of the [joint attention model](http://papers.nips.cc/paper/8019-layer-wise-coordination-between-encoder-and-decoder-for-neural-machine-translation.pdf). For this reason we register three separate architectures: `'local_joint_attention_iwslt_de_en'` extends `'joint_attention_iwslt_de_en'`, which in turn builds on the base architecture `'joint_attention'`. Because each architecture is defined in its own function, we can also train the plain joint attention architecture (`'joint_attention_iwslt_de_en'`) and compare its performance with that of the local variant.

> **Model Parameters**  
To later compare the translation quality of this architecture with that of the well-known Big Transformer, we ensure that both have a similar number of trainable parameters by choosing the number of layers of our joint self-attention model accordingly.
 * Number of layers: $14$
 * Embedding size: $256$
 * Feedforward expansion size: $1024$
 * Attention heads: $4$
 * Attention window sizes from input layers to output layers: $3, 5, 7, 9, 11, 13, 15, 17, 21, 25, 29, 33, 37, 41$

Registering the model and the new architecture lets us use the fairseq command-line tools to train and evaluate the new model, selecting the stored architecture with `--arch local_joint_attention_iwslt_de_en`.

In [11]:
from fairseq.models import register_model_architecture

@register_model_architecture('joint_attention', 'joint_attention')
def base_architecture(args):
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)

    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
    args.decoder_layers = getattr(args, 'decoder_layers', 14)

    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
    args.kernel_size_list = getattr(args, 'kernel_size_list', None)
    assert args.kernel_size_list is None or len(args.kernel_size_list) == args.decoder_layers, "kernel_size_list doesn't match decoder_layers"
    args.language_embeddings = getattr(args, 'language_embeddings', True)


@register_model_architecture('joint_attention', 'joint_attention_iwslt_de_en')
def joint_attention_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.dropout = getattr(args, 'dropout', 0.3)
    base_architecture(args)


@register_model_architecture('joint_attention', 'local_joint_attention_iwslt_de_en')
def local_joint_attention_iwslt_de_en(args):
    args.kernel_size_list = getattr(args, 'kernel_size_list', [3, 5, 7, 9, 11, 13, 15, 17, 21, 25, 29, 33, 37, 41])
    joint_attention_iwslt_de_en(args)
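
Because each function above only calls `getattr` with a default, a value is filled in only when the attribute is still missing, so settings flow from the most specific architecture down to `base_architecture`. The sketch below reproduces that cascade with trimmed, hypothetical fairseq-free copies (`base`, `iwslt`, `local_iwslt`), using `argparse.Namespace` as a stand-in for the parsed command-line arguments.

```python
from argparse import Namespace

# Hypothetical, fairseq-free copies of the three architecture functions above,
# trimmed to a few fields. Each getattr() only fills in a value that is missing.
def base(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 14)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.kernel_size_list = getattr(args, 'kernel_size_list', None)
    assert args.kernel_size_list is None or len(args.kernel_size_list) == args.decoder_layers


def iwslt(args):
    # More specific defaults are set first, so base() no longer overrides them
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.dropout = getattr(args, 'dropout', 0.3)
    base(args)


def local_iwslt(args):
    args.kernel_size_list = getattr(args, 'kernel_size_list',
                                    [3, 5, 7, 9, 11, 13, 15, 17, 21, 25, 29, 33, 37, 41])
    iwslt(args)


# No architecture flags given: every value comes from the cascade of defaults
args = Namespace()
local_iwslt(args)
print(args.encoder_embed_dim, args.dropout, len(args.kernel_size_list))  # 256 0.3 14

# A value the user set explicitly (e.g. a different dropout) is left untouched
user_args = Namespace(dropout=0.2)
local_iwslt(user_args)
print(user_args.dropout)  # 0.2
```

This ordering is why the most specific function sets its own defaults before delegating: calling `base` first would lock in the 512-dimensional embeddings before the IWSLT defaults could apply.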

###**3.6.   Train the Built Model**

Run the `fairseq-train` command-line tool from the fairseq library to train a new model from scratch. We must specify the registered model architecture in the command-line argument `--arch local_joint_attention_iwslt_de_en`. `fairseq-train` will use all GPUs if available.

> **Hyperparameters**, defined based on [Wu et al., 2019](https://arxiv.org/abs/1901.10430):
The following hyperparameters are passed to `fairseq-train` as command-line arguments:
 * Optimization policy: Adam optimizer with $\epsilon = 10^{-9}$, $\beta_1 = 0.9$ and $\beta_2 = 0.98$, and a weight decay of $10^{-4}$
 * Batch size: $4$K source tokens. The batch size is specified as the maximum number of tokens per batch, `--max-tokens`
 * Training steps: $85$K
 * Learning rate: linearly warmed up over the first $4$K steps from $10^{-7}$ to a maximum learning rate of $10^{-3}$, then decayed with an inverse square root scheduler. Training stops after $85$K updates (`--max-update`), or earlier if the learning rate falls below the minimum of $10^{-9}$ (`--min-lr`).
 * Loss function: [cross entropy with label smoothing](https://medium.com/@nainaakash012/when-does-label-smoothing-help-89654ec75326) with a label smoothing of $0.1$.
 * Gradient clipping: disabled (`--clip-norm 0`).
 * Logging interval: every $100$ batches.
 * Checkpointing: keep only the checkpoints of the last $10$ epochs (`--keep-last-epochs 10`).
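
The learning-rate bullet can be made concrete with a small function. This is a minimal sketch, assuming the usual inverse square root formulation (linear warmup, then a decay proportional to $1/\sqrt{\text{step}}$ that is continuous at the end of warmup), which is how we understand fairseq's `inverse_sqrt` scheduler; `inverse_sqrt_lr` is a hypothetical helper, not a fairseq function.

```python
import math


def inverse_sqrt_lr(step: int, warmup_updates: int = 4000,
                    warmup_init_lr: float = 1e-7, peak_lr: float = 1e-3) -> float:
    """Learning rate after `step` updates: linear warmup, then inverse square root decay."""
    if step < warmup_updates:
        # Linear interpolation from warmup_init_lr up to peak_lr
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # Decay proportional to 1/sqrt(step); equals peak_lr exactly at step == warmup_updates
    return peak_lr * math.sqrt(warmup_updates / step)


for step in (0, 2000, 4000, 16000, 85000):
    print(step, f"{inverse_sqrt_lr(step):.2e}")
```

With these settings the rate is still about $2.2 \times 10^{-4}$ after $85$K updates, far above $10^{-9}$, so in practice it is `--max-update` that ends training.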

> ***Note***: The `!fairseq-train` command runs in a separate process, so it cannot see the classes and functions defined in this notebook's Python session. For this reason, all the code we just saw has to live in Python files that fairseq can import through the `--user-dir` argument. Since the authors of the paper already provide this code on GitHub, we simply download the folder containing those Python files and train with them.

In [12]:
%cd /content/
!apt install subversion
!svn checkout https://github.com/jarfo/joint/trunk/models

/content
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  libapr1 libaprutil1 libserf-1-1 libsvn1
Suggested packages:
  db5.3-util libapache2-mod-svn subversion-tools
The following NEW packages will be installed:
  libapr1 libaprutil1 libserf-1-1 libsvn1 subversion
0 upgraded, 5 newly installed, 0 to remove and 35 not upgraded.
Need to get 2,237 kB of archives.
After this operation, 9,910 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/main amd64 libapr1 amd64 1.6.3-2 [90.9 kB]
Get:2 http://archive.ubuntu.com/ubuntu bionic/main amd64 libaprutil1 amd64 1.6.1-2 [84.4 kB]
Get:3 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libserf-1-1 amd64 1.3.9-6 [44.4 kB]
Get:4 http://archive.ubuntu.com/ubuntu bionic/univ

In [13]:
# Create a Folder to save the Checkpoints
SAVE="/content/checkpoints/local_joint_attention_iwslt_de_en"
!mkdir -p $SAVE

# Use the train function from the fairseq library to train the new model on the IWSLT 2014 dataset
!fairseq-train data-bin/iwslt14.joined-dictionary.31K.de-en \
    --user-dir /content/models/ \
    --arch local_joint_attention_iwslt_de_en \
    --clip-norm 0 --optimizer adam --lr 0.001 --dropout 0.3 \
    --source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \
    --log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --lr-scheduler inverse_sqrt \
    --ddp-backend=no_c10d \
    --max-update 85000 --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --adam-betas '(0.9, 0.98)' --adam-eps '1e-09' --keep-last-epochs 10 \
    --share-all-embeddings \
    --save-dir $SAVE

2020-08-03 20:56:08 | INFO | fairseq_cli.train | Namespace(adam_betas='(0.9, 0.98)', adam_eps=1e-09, all_gather_list_size=16384, arch='local_joint_attention_iwslt_de_en', attention_dropout=0.1, best_checkpoint_metric='loss', bf16=False, bpe=None, broadcast_buffers=False, bucket_cap_mb=25, checkpoint_suffix='', clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', curriculum=0, data='data-bin/iwslt14.joined-dictionary.31K.de-en', data_buffer_size=10, dataset_impl=None, ddp_backend='no_c10d', decoder_attention_heads=4, decoder_embed_dim=256, decoder_embed_path=None, decoder_ffn_embed_dim=1024, decoder_input_dim=256, decoder_layers=14, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=256, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=1, distributed_wrapper='DDP', dropout=0.3, empty_cache_freq=0, encoder_em
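
The `--criterion label_smoothed_cross_entropy --label-smoothing 0.1` flags in the command above replace the one-hot target with a mixture of the one-hot target (weight $1-\epsilon$) and a uniform distribution over the vocabulary (weight $\epsilon$). Below is a minimal pure-Python sketch of this standard formulation for a single token; `log_softmax` and `label_smoothed_nll` are hypothetical helper names, and fairseq's own implementation works on batched tensors and has used slightly different constants across versions.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-probabilities from raw scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def label_smoothed_nll(logits: List[float], target: int, epsilon: float = 0.1) -> float:
    """Cross entropy against (1 - epsilon) * one_hot(target) + epsilon * uniform."""
    lprobs = log_softmax(logits)
    vocab_size = len(lprobs)
    nll = -lprobs[target]    # ordinary cross entropy for the reference token
    smooth = -sum(lprobs)    # cross entropy summed over every vocabulary entry
    return (1.0 - epsilon) * nll + (epsilon / vocab_size) * smooth


print(label_smoothed_nll([10.0, 0.0, 0.0], target=0, epsilon=0.0))
print(label_smoothed_nll([10.0, 0.0, 0.0], target=0, epsilon=0.1))
```

The second value is much larger than the first: label smoothing penalizes the model for putting almost all probability on one token, which discourages over-confident predictions.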

###**3.7.   Average the Last 10 Trained Models**

Because the model is optimized with a stochastic gradient method (here Adam), each update is computed on a mini-batch instead of the entire training set. The parameters therefore keep fluctuating from one step to the next, and the last saved model may be biased toward the most recently seen mini-batches.

Checkpoint averaging is a technique commonly used in machine translation to counter this: averaging the parameters of several checkpoints smooths out these fluctuations, yields more robust parameters and usually improves translation quality slightly.

The `average_checkpoints.py` script from fairseq averages the trainable parameters of a user-defined number of saved checkpoints, in this case the last 10 epoch checkpoints.

In [14]:
!python /content/fairseq/scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output {SAVE}"/checkpoint_last10_avg.pt"

Namespace(checkpoint_upper_bound=None, inputs=['/content/checkpoints/local_joint_attention_iwslt_de_en'], num_epoch_checkpoints=10, num_update_checkpoints=None, output='/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint_last10_avg.pt')
averaging checkpoints:  ['/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint83.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint82.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint81.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint80.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint79.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint78.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint77.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint76.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/checkpoint75.pt', '/content/checkpoints/local_joint_attention_iwslt_de_en/ch
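
fairseq's script loads each checkpoint file and averages the model parameter tensors that share the same name. The sketch below shows only that core arithmetic on plain Python lists; `average_state_dicts` is a hypothetical helper, and the parameter names and values are made up for illustration.

```python
from typing import Dict, List

# Parameter name -> flattened parameter values (stand-in for a PyTorch state dict)
StateDict = Dict[str, List[float]]


def average_state_dicts(state_dicts: List[StateDict]) -> StateDict:
    """Element-wise mean of parameters that share the same name across checkpoints."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise KeyError("checkpoints have different parameter names")
    n = len(state_dicts)
    averaged: StateDict = {}
    for name in keys:
        # zip(*...) walks the same position of this parameter across all checkpoints
        averaged[name] = [sum(values) / n for values in zip(*(sd[name] for sd in state_dicts))]
    return averaged


checkpoints = [{"w": [1.0, 3.0], "b": [0.0]}, {"w": [3.0, 5.0], "b": [2.0]}]
print(average_state_dicts(checkpoints))  # {'w': [2.0, 4.0], 'b': [1.0]}
```

Averaging only makes sense for checkpoints of the same architecture, which is why the sketch refuses state dicts whose parameter names differ.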

###**3.8.   Evaluation of the Trained Model on the [BLEU Benchmark](https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213)** 


>BLEU stands for **Bilingual Evaluation Understudy**.
It is the most widely used metric for evaluating models that solve sequence-to-sequence problems, i.e., problems whose output is not merely a class label but a sequence of words that may differ in length from the input sequence.  
The resulting score is a value between **0 and 100**: a score of 100 means the translation is identical to the human reference translation, while a score close to 0 means it has almost no n-grams in common with the reference.  
>>The score is based on **the geometric mean of the modified unigram, bigram, trigram and 4-gram precisions**, where an n-gram is a sequence of n words that occur next to each other in a text, and each n-gram count in the translation is clipped to its count in the reference. The score is also penalized if the translation produced by the model is shorter than the reference translation; this is known as the **brevity penalty**.  

The fairseq toolkit also includes the `fairseq-generate` command-line tool, which translates the test set and reports the BLEU score of the trained model.
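
To make the BLEU definition above concrete, here is a minimal sentence-level sketch with a single reference, whitespace tokenization and no smoothing. `sentence_bleu` and `ngrams` are hypothetical helpers; the BLEU reported by `fairseq-generate` is computed at corpus level (n-gram matches and lengths are summed over all test sentences before being combined), so its value will not equal an average of sentence scores.

```python
import math
from collections import Counter
from typing import Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count every run of n consecutive tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(hypothesis: str, reference: str, max_n: int = 4) -> float:
    hyp, ref = hypothesis.split(), reference.split()
    if not hyp:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        hyp_counts, ref_counts = ngrams(hyp, n), ngrams(ref, n)
        total = sum(hyp_counts.values())
        # Modified precision: clip each n-gram count by its count in the reference
        matches = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
        if total == 0 or matches == 0:
            return 0.0  # without smoothing, any zero precision gives BLEU = 0
        log_precision_sum += math.log(matches / total)
    geometric_mean = math.exp(log_precision_sum / max_n)
    c, r = len(hyp), len(ref)
    # Brevity penalty: only translations shorter than the reference are penalized
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return 100.0 * brevity_penalty * geometric_mean


print(sentence_bleu("the cat sat on the mat", "the cat sat on the mat"))  # 100.0
print(sentence_bleu("the cat sat on", "the cat sat on the mat"))
```

The second translation has perfect n-gram precision but is two words too short, so the brevity penalty $e^{1 - 6/4}$ alone pulls its score down to about 60.7.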

The model trained above should reach a BLEU score close to the **35.7** reported in the paper, shown on the right side of the [table below](https://arxiv.org/abs/1905.06596).

Even though these scores may sound low, keep in mind that two human translators rarely produce the same translation of a sentence, so even a correct translation will usually not match a single reference exactly. BLEU scores of **state-of-the-art models** usually lie **between 20 and 40**, depending on the dataset and language pair. This makes the result of this model very strong: 35.7 was among the highest scores reported for IWSLT14 German-to-English translation at the time the paper was published.

**Checkpoint averaging** should also yield a slightly higher BLEU score than the single checkpoint that performed best on the validation set (`checkpoint_best.pt`).

<center><img src=https://d3i71xaburhd42.cloudfront.net/b0f0a5a21619d70748a4dc007983cc111f1b301e/4-Table1-1.png width="450"></center>




In [15]:
!echo 'Evaluating the best-performing trained model'
!fairseq-generate /content/data-bin/iwslt14.joined-dictionary.31K.de-en --user-dir /content/models/ \
    --path {SAVE}"/checkpoint_best.pt" \
    --batch-size 32 --beam 5 --remove-bpe --lenpen 1.7 --gen-subset test --quiet

!echo 'Evaluating the averaged model'
!fairseq-generate /content/data-bin/iwslt14.joined-dictionary.31K.de-en --user-dir /content/models \
    --path {SAVE}"/checkpoint_last10_avg.pt" \
    --batch-size 32 --beam 5 --remove-bpe --lenpen 1.7 --gen-subset test --quiet

Evaluating the best-performing trained model
2020-08-04 04:20:06 | INFO | fairseq_cli.generate | Namespace(all_gather_list_size=16384, beam=5, bf16=False, bpe=None, broadcast_buffers=False, bucket_cap_mb=25, checkpoint_suffix='', cpu=False, criterion='cross_entropy', data='/content/data-bin/iwslt14.joined-dictionary.31K.de-en', data_buffer_size=10, dataset_impl=None, ddp_backend='c10d', decoding_format=None, device_id=0, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=1, distributed_wrapper='DDP', diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, empty_cache_freq=0, eval_bleu=False, eval_bleu_args=None, eval_bleu_detok='space', eval_bleu_detok_args=None, eval_bleu_print_samples=False, eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, fast_stat_sync=False, find_unused_parameters=False, fix_batches_to_gpus=False, force_anneal=None, fp16=False, fp16_init_scale=12
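
With subword-nmt BPE, a word split into several units carries an `@@` marker on every unit except the last (for example `wel@@ t`). The `--remove-bpe` flag in the commands above makes `fairseq-generate` join these units back into words before computing BLEU. A minimal sketch of that post-processing step, assuming the default `@@ ` separator (`remove_bpe` is a hypothetical name):

```python
def remove_bpe(sentence: str, separator: str = "@@ ") -> str:
    """Join BPE units back into words by deleting every marker plus its trailing space."""
    # The appended space lets a marker at the very end of the line be removed too
    return (sentence + " ").replace(separator, "").rstrip()


print(remove_bpe("hal@@ lo wel@@ t !"))  # hallo welt !
```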

###**3.9.   Interactive Translation**

Using the averaged checkpoint file, we can now feed a German sentence to the trained model and generate its English translation interactively.

In [16]:
# first install some libraries
!pip install sacremoses
!pip install subword-nmt

Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)

In [17]:
# from_pretrained() needs the BPE codes file, passed as 'bpe_codes'
# The file must be in the directory given as the first argument (SAVE),
# which is where the checkpoints were saved, so we download it there
%cd $SAVE
!wget https://raw.githubusercontent.com/wejdene14/NMT-Joint/master/code

# Define the model and restore the averaged checkpoint
de2en = JointAttentionModel.from_pretrained(
  SAVE,
  checkpoint_file='checkpoint_last10_avg.pt',
  data_name_or_path='/content/data-bin/iwslt14.joined-dictionary.31K.de-en/',
  tokenizer='moses',
  source_lang='de',
  target_lang='en',
  bpe='subword_nmt',
  bpe_codes='code'
)

# Switch to evaluation mode so that dropout is disabled while translating
de2en.eval()

/content/checkpoints/local_joint_attention_iwslt_de_en
--2020-08-04 04:25:08--  https://raw.githubusercontent.com/wejdene14/NMT-Joint/master/code
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 353021 (345K) [text/plain]
Saving to: ‘code’


2020-08-04 04:25:08 (9.84 MB/s) - ‘code’ saved [353021/353021]



GeneratorHubInterface(
  (models): ModuleList(
    (0): JointAttentionModel(
      (encoder): JointAttentionEncoder(
        (embed_tokens): Embedding(30760, 256, padding_idx=1)
        (embed_positions): SinusoidalPositionalEmbedding()
      )
      (decoder): JointAttentionDecoder(
        (embed_tokens): Embedding(30760, 256, padding_idx=1)
        (embed_positions): SinusoidalPositionalEmbedding()
        (layers): ModuleList(
          (0): ProtectedTransformerDecoderLayer(
            (self_attn): ProtectedMultiheadAttention(
              (out_proj): Linear(in_features=256, out_features=256, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=256, out_features=1024, bias=True)
            (fc2): Linear(in_features=1024, out_features=256, bias=True)
            (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          )
          (1): ProtectedTransformer
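
The `code` file downloaded above holds the BPE merge operations learned on the training data, and the `subword_nmt` BPE option uses them to split each input word into subword units. The sketch below is a simplified illustration of how such merges are applied to one word: starting from characters (with an end-of-word marker on the last one), it repeatedly merges the adjacent pair that was learned earliest. `apply_bpe` and the tiny merge list are hypothetical; the real subword-nmt implementation handles more cases, such as vocabulary filtering and glossaries.

```python
from typing import List, Tuple


def apply_bpe(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Split a non-empty word into BPE units, marking non-final units with '@@'."""
    ranks = {pair: i for i, pair in enumerate(merges)}  # lower rank = learned earlier
    symbols = list(word[:-1]) + [word[-1] + "</w>"]
    while len(symbols) > 1:
        # Rank of every adjacent pair; unknown pairs get infinity
        candidates = [(ranks.get(pair, float("inf")), i)
                      for i, pair in enumerate(zip(symbols, symbols[1:]))]
        best_rank, i = min(candidates)
        if best_rank == float("inf"):
            break  # no learned merge applies any more
        symbols = symbols[:i] + [symbols[i] + symbols[i + 1]] + symbols[i + 2:]
    symbols[-1] = symbols[-1][:-len("</w>")]  # drop the end-of-word marker
    return [s + "@@" for s in symbols[:-1]] + [symbols[-1]]


merges = [("w", "e"), ("we", "l"), ("l", "t</w>")]
print(" ".join(apply_bpe("welt", merges)))  # wel@@ t
```

Because `("we", "l")` was learned before `("l", "t</w>")`, it wins the competition for the shared `l`, which is why the word ends up as `wel@@ t` rather than `we@@ lt`.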

In [18]:
# Translate a German sentence
de = 'Hallo Welt!'
# The model was trained on lowercased text, so lowercase the input as well
en = de2en.translate(de.lower())
print(en)

hello world!


##**4. Next Steps**

This new architecture produces state-of-the-art results in German-English translation and outperforms the other models compared in the paper by at least 0.5 BLEU.

With this notebook, you can experiment with training on a different or larger dataset, such as the [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) dataset, or with tuning the model's hyperparameters.