# Transformer (NMT) Translation
- Based on: https://pytorch.org/hub/pytorch_fairseq_translation/
    *Author: Facebook AI (fairseq Team)*

## Translate CNN Daily Mail
We translate the English CNN/Daily Mail dataset into German so that it can be used for German text summarization.

### Model Description

The Transformer, introduced in the paper [Attention Is All You Need][1], is a
powerful sequence-to-sequence modeling architecture capable of producing
state-of-the-art neural machine translation (NMT) systems.

Recently, the fairseq team has explored large-scale semi-supervised training of
Transformers using back-translated data, further improving translation quality
over the original model. More details can be found in [this blog post][2].


### Requirements

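We require a few additional Python dependencies for preprocessing (the code below also imports `segtok` and, on the first run, `tensorflow` and `tensorflow_datasets`):
- `pip install fastBPE regex requests sacremoses subword_nmt Cython`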
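
Before running the notebook, it can help to confirm that these packages are importable. The following is a minimal sketch, assuming the import names match the pip package names listed above; the helper `missing_modules` is hypothetical and not part of fairseq:

```python
import importlib.util

# Import names assumed to correspond to the pip packages listed above.
REQUIRED_MODULES = ["fastBPE", "regex", "requests", "sacremoses", "subword_nmt", "Cython"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All preprocessing dependencies are available.")
```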

In [1]:
! nvidia-smi

Thu Jul  9 11:37:15 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 435.21       Driver Version: 435.21       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce RTX 208...  Off  | 00000000:08:00.0  On |                  N/A |
| 35%   33C    P8    21W / 260W |    444MiB / 11016MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

### English-to-German Translation

Semi-supervised training with back-translation is an effective way of improving
translation systems. In the paper [Understanding Back-Translation at Scale][4],
we back-translate over 200 million German sentences to use as additional
training data. An ensemble of five of these models was the winning submission to
the [WMT'18 English-German news translation competition][5].

We can further improve this approach through [noisy-channel reranking][6]. More
details can be found in [this blog post][7]. An ensemble of models trained with
this technique was the winning submission to the [WMT'19 English-German news
translation competition][8].
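
As a rough illustration of noisy-channel reranking, each candidate translation y of a source sentence x can be rescored by combining a direct model score log P(y|x), a channel model score log P(x|y), and a language model score log P(y). The sketch below uses made-up log-probabilities and hypothetical weights (`channel_weight`, `lm_weight`); it omits details such as length normalization and is not the fairseq implementation:

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    text: str
    direct: float   # log P(y | x) from the forward (En-De) model
    channel: float  # log P(x | y) from the reverse (De-En) model
    lm: float       # log P(y) from a target-side language model


def rerank(candidates, channel_weight=1.0, lm_weight=1.0):
    """Sort candidates by the weighted sum of the three log-probabilities, best first."""
    def score(c):
        return c.direct + channel_weight * c.channel + lm_weight * c.lm
    return sorted(candidates, key=score, reverse=True)


# Illustrative numbers only; real scores would come from trained models.
candidates = [
    Candidate("Hypothese A", direct=-1.0, channel=-4.0, lm=-3.0),
    Candidate("Hypothese B", direct=-1.5, channel=-2.0, lm=-2.5),
]
best = rerank(candidates)[0]
print(best.text)
```

With both weights set to zero, the ranking falls back to the direct model alone, which makes it easy to see how much the channel and language model terms change the outcome.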

To translate from English to German using one of the models from the winning submission, see the Translation Model section below.

### References

- [Attention Is All You Need][1]
- [Scaling Neural Machine Translation][3]
- [Understanding Back-Translation at Scale][4]
- [Facebook FAIR's WMT19 News Translation Task Submission][6]


[1]: https://arxiv.org/abs/1706.03762
[2]: https://code.fb.com/ai-research/scaling-neural-machine-translation-to-bigger-data-sets-with-faster-training-and-inference/
[3]: https://arxiv.org/abs/1806.00187
[4]: https://arxiv.org/abs/1808.09381
[5]: http://www.statmt.org/wmt18/translation-task.html
[6]: https://arxiv.org/abs/1907.06616
[7]: https://ai.facebook.com/blog/facebook-leads-wmt-translation-competition/
[8]: http://www.statmt.org/wmt19/translation-task.html

In [2]:
import torch

import re
import time
from pathlib import Path

from segtok.segmenter import split_single

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
first_run = False

ds_name = "train"

### Write tfds Dataset to files

In [6]:
if first_run:
    import tensorflow as tf
    import tensorflow_datasets as tfds

    cnn_dailymail = tfds.load(name="cnn_dailymail")

    train_tfds = cnn_dailymail['train']
    test_tfds = cnn_dailymail['test']
    val_tfds = cnn_dailymail['validation']

    train_ds_iter = tfds.as_numpy(train_tfds)
    val_ds_iter = tfds.as_numpy(val_tfds)
    test_ds_iter = tfds.as_numpy(test_tfds)


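    def write_data(iter_dataset, name, path="../data/"):
        """Write articles and highlights to parallel files, one item per line."""
        # Note: read_files and FileWriter below use "../../data/"; adjust `path` if needed.
        out_dir = Path(path) / name
        out_dir.mkdir(parents=True, exist_ok=True)

        # Context managers flush and close both files when writing finishes.
        with (out_dir / "article").open("w", encoding="utf-8") as articles_file, \
                (out_dir / "highlights").open("w", encoding="utf-8") as highlights_file:
            for item in iter_dataset:
                articles_file.write(item["article"].decode("utf-8") + "\n")
                # Highlights span several lines; join them so each item stays on one line.
                highlights_file.write(item["highlights"].decode("utf-8").replace("\n", " ") + "\n")

    write_data(train_ds_iter, "train")
    write_data(test_ds_iter, "test")
    write_data(val_ds_iter, "val")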

## Translation Model

In [7]:
# Load an En-De Transformer model trained on WMT'19 data:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')

# Access the underlying TransformerModel
assert isinstance(en2de.models[0], torch.nn.Module)

# Translate from En-De
de = en2de.translate('PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.')
assert de == 'PyTorch Hub ist ein vorgefertigtes Modell-Repository, das die Reproduzierbarkeit der Forschung erleichtern soll.'

# to gpu
en2de = en2de.to(device)

Using cache found in /home/yannik/.cache/torch/hub/pytorch_fairseq_master


## Load Dataset

In [8]:
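def read_files(name):
    """Read the parallel article and highlights files of a split, one item per line."""
    article_path = "../../data/%s/article" % name
    highlights_path = "../../data/%s/highlights" % name

    # Context managers ensure the files are closed after reading.
    with open(article_path, encoding="utf-8") as f:
        articles = [x.rstrip() for x in f]
    with open(highlights_path, encoding="utf-8") as f:
        highlights = [x.rstrip() for x in f]

    # Line i of both files must belong to the same dataset item.
    assert len(articles) == len(highlights)
    return articles, highlights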
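
The length check in `read_files` only holds if every item occupies exactly one line in both files. One way to guarantee this before writing is to collapse each text onto a single line. This is a small sketch; the helper name `to_single_line` is an assumption and does not appear in the notebook:

```python
def to_single_line(text):
    """Join all non-empty lines of `text` with single spaces so it occupies one line."""
    return " ".join(part.strip() for part in text.splitlines() if part.strip())


items = ["First paragraph.\nSecond paragraph.", "Only one line."]
lines = [to_single_line(item) for item in items]

# One output line per input item keeps parallel files aligned.
assert len(lines) == len(items)
print(lines)
```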

In [9]:
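def split_in_sentences(text):
    return split_single(text)

class MyDataset(torch.utils.data.Dataset):
    """Serves each text as a list of sentences, ready to pass to en2de.translate."""

    def __init__(self, articles):
        self.x = articles

    def __getitem__(self, index):
        sentences = split_in_sentences(self.x[index])
        # Truncate each sentence to 1024 characters to bound the input length.
        return [sent[:1024] for sent in sentences]

    @staticmethod
    def transform(x):
        # Optional normalization (not used by __getitem__): lowercase the text
        # and remove an enclosing pair of single quotes.
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x

    def __len__(self):
        return len(self.x)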
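
To make the per-item output concrete: each dataset item is a list of sentences, each cut to at most 1024 characters (characters, not tokens). The sketch below mimics this with a naive regex splitter standing in for segtok's `split_single`. The stand-in is an assumption made only so the example runs without extra dependencies; the real splitter treats abbreviations and other edge cases differently:

```python
import re

MAX_CHARS = 1024


def naive_split(text):
    """Split after '.', '!' or '?' when followed by whitespace (stand-in for split_single)."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def article_to_sentences(article, max_chars=MAX_CHARS):
    """Return the article as a list of sentences, each truncated to `max_chars` characters."""
    return [sentence[:max_chars] for sentence in naive_split(article)]


sentences = article_to_sentences("The model is large. It runs on a GPU! Does it fit?")
print(sentences)
```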

In [10]:
articles, highlights = read_files(ds_name)

In [11]:
articles_ds = MyDataset(articles)

highlights_ds = MyDataset(highlights)

In [12]:
class FileWriter:
    """Appends translations to <path><ds_name>/<name>_german, one "<index>; <text>" line per item."""

    def __init__(self, ds_name, name, path="../../data/"):
        self.path = path + ds_name + "/"+ name + "_german"
        # Append mode keeps earlier output so an interrupted run can resume.
        self.file = Path(self.path).open("a")

    def write_translated(self, i, list_str):
        result = str(i) + "; "
        for item_str in list_str:
            result += item_str + " "
        self.file.write(result.replace("\n", " ") + "\n")
        self.file.flush()

    def get_last_index(self):
        """Return the index of the last written item, or 0 if the file is empty."""
        with open(self.path) as fileObj:
            ret_list = list(fileObj)
            if len(ret_list) > 0:
                return int(ret_list[-1].split(";")[0])
            else:
                return 0
        
    
file_writer = FileWriter(ds_name, "articles")  
file_writer.get_last_index()

287112

In [13]:
file_writer = FileWriter(ds_name, "highlights")  
file_writer.get_last_index()

287112

In [14]:
def translate(ds, ds_name, name, log_interval=1000):
    len_ds = len(ds)

    file_writer = FileWriter(ds_name, name)
    # Resume after the last written item; an empty output file means starting at 0.
    if Path(file_writer.path).stat().st_size > 0:
        first_index = file_writer.get_last_index() + 1
    else:
        first_index = 0
    start_time = time.time()
    for i in range(first_index, len_ds):
        predictions = en2de.translate(ds[i])
        file_writer.write_translated(i, predictions)
        if ((i+1) % log_interval) == 0:
            elapsed = time.time() - start_time
            # Report progress against the dataset being translated, not the global articles_ds.
            print("| [{:5d}/{:5d}] | ms/ds_point {:5.2f} |".format(i, len_ds, (elapsed * 1000 / log_interval)))
            start_time = time.time()
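
Because a full pass takes a long time, a translation loop like this benefits from being safely restartable. The sketch below shows one possible way to resume without re-translating finished items, assuming the output file holds exactly one line per item starting at index 0. The function `translate_resumable` and the stub `fake_translate` are hypothetical and stand in for the notebook's `translate` and `en2de.translate`:

```python
import tempfile
from pathlib import Path


def translate_resumable(items, translate_fn, out_path):
    """Translate items not yet present in `out_path`; return how many were written."""
    out_path = Path(out_path)
    done = 0
    if out_path.exists():
        # Assumption: one non-empty line per finished item, so the count is the next index.
        with out_path.open(encoding="utf-8") as f:
            done = sum(1 for line in f if line.strip())
    with out_path.open("a", encoding="utf-8") as f:
        for i in range(done, len(items)):
            translated = " ".join(translate_fn(items[i]))
            f.write("{}; {}\n".format(i, translated.replace("\n", " ")))
            f.flush()  # keep the file consistent if the run is interrupted
    return len(items) - done


def fake_translate(sentences):
    """Stub translator used only for this demonstration."""
    return [s.upper() for s in sentences]


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "demo_german"
    data = [["hello world."], ["second item.", "two sentences."]]
    first = translate_resumable(data[:1], fake_translate, out)  # writes item 0
    second = translate_resumable(data, fake_translate, out)     # writes only item 1
    written = out.read_text(encoding="utf-8").splitlines()
    print(first, second, written)
```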


In [15]:
translate(articles_ds, ds_name, "articles")

In [27]:
translate(highlights_ds, ds_name, "highlights")

| [  999/13368] | ms/ds_point 522.94 |
| [ 1999/13368] | ms/ds_point 504.64 |
| [ 2999/13368] | ms/ds_point 511.83 |
| [ 3999/13368] | ms/ds_point 520.32 |
| [ 4999/13368] | ms/ds_point 538.62 |
| [ 5999/13368] | ms/ds_point 526.41 |
| [ 6999/13368] | ms/ds_point 533.31 |
| [ 7999/13368] | ms/ds_point 526.18 |
| [ 8999/13368] | ms/ds_point 537.70 |
| [ 9999/13368] | ms/ds_point 535.05 |
| [10999/13368] | ms/ds_point 535.51 |
| [11999/13368] | ms/ds_point 525.89 |
| [12999/13368] | ms/ds_point 526.42 |
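
The log can be turned into a rough runtime estimate. The sketch below averages the logged `ms/ds_point` values and scales them to the 13368 items of this run, which comes out at roughly 527 ms per item and about two hours in total. This is a back-of-the-envelope figure that ignores model loading and the final partial interval:

```python
# ms/ds_point values copied from the log above (one value per 1000 items).
ms_per_item = [
    522.94, 504.64, 511.83, 520.32, 538.62, 526.41, 533.31,
    526.18, 537.70, 535.05, 535.51, 525.89, 526.42,
]
total_items = 13368

mean_ms = sum(ms_per_item) / len(ms_per_item)
estimated_hours = mean_ms * total_items / 1000 / 3600

print("mean: {:.1f} ms/item, estimated total: {:.2f} h".format(mean_ms, estimated_hours))
```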
