<a href="https://colab.research.google.com/github/paulxuereb/ML/blob/master/Seq2SeqAllenNLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

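English→Mandarin Chinese (`eng_cmn`) sequence-to-sequence translation with AllenNLP, following the tutorial [Building Seq2Seq Machine Translation Models using AllenNLP](http://www.realworldnlpbook.com/blog/building-seq2seq-machine-translation-models-using-allennlp.html).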

In [1]:
!pip install AllenNLP



In [2]:
    
import itertools
import os
import random

import torch
import torch.optim as optim
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.activations import Activation
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.modules.attention import LinearAttention, BilinearAttention, DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, StackedSelfAttentionEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.predictors import SimpleSeq2SeqPredictor
from allennlp.training.trainer import Trainer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


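Set up data files: download the Tatoeba sentence and link exports, build an English–Mandarin (`eng_cmn`) parallel corpus, and split it into train/dev/test files.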

In [0]:
!mkdir data

In [0]:
# Enter data/ unless we are already in it, so re-running this cell (or Run All
# after a partial run) does not try to descend into data/data.
if os.path.basename(os.getcwd()) != 'data':
    os.chdir('data')
os.getcwd()

In [4]:
!wget http://downloads.tatoeba.org/exports/sentences.tar.bz2

--2019-04-24 11:51:23--  http://downloads.tatoeba.org/exports/sentences.tar.bz2
Resolving downloads.tatoeba.org (downloads.tatoeba.org)... 94.130.77.194
Connecting to downloads.tatoeba.org (downloads.tatoeba.org)|94.130.77.194|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.tatoeba.org/exports/sentences.tar.bz2 [following]
--2019-04-24 11:51:23--  https://downloads.tatoeba.org/exports/sentences.tar.bz2
Connecting to downloads.tatoeba.org (downloads.tatoeba.org)|94.130.77.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 118227705 (113M) [application/octet-stream]
Saving to: ‘sentences.tar.bz2’


2019-04-24 11:51:29 (22.2 MB/s) - ‘sentences.tar.bz2’ saved [118227705/118227705]



In [5]:
!wget http://downloads.tatoeba.org/exports/links.tar.bz2

URL transformed to HTTPS due to an HSTS policy
--2019-04-24 11:51:31--  https://downloads.tatoeba.org/exports/links.tar.bz2
Resolving downloads.tatoeba.org (downloads.tatoeba.org)... 94.130.77.194
Connecting to downloads.tatoeba.org (downloads.tatoeba.org)|94.130.77.194|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 80800484 (77M) [application/octet-stream]
Saving to: ‘links.tar.bz2’


2019-04-24 11:51:35 (20.3 MB/s) - ‘links.tar.bz2’ saved [80800484/80800484]



In [14]:
# List the working directory contents (sanity check of the downloaded files).
!ls

create_bitext.py  links.csv  links.tar.bz2  sentences.csv  sentences.tar.bz2


In [6]:
!tar -xvf sentences.tar.bz2

sentences.csv


In [7]:
!tar -xvf links.tar.bz2

links.csv


In [8]:
!wget https://raw.githubusercontent.com/mhagiwara/realworldnlp/master/examples/mt/create_bitext.py

--2019-04-24 11:52:11--  https://raw.githubusercontent.com/mhagiwara/realworldnlp/master/examples/mt/create_bitext.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4870 (4.8K) [text/plain]
Saving to: ‘create_bitext.py’


2019-04-24 11:52:11 (73.2 MB/s) - ‘create_bitext.py’ saved [4870/4870]



In [0]:
!python create_bitext.py eng_cmn sentences.csv links.csv \
    | cut -f3,6 > tatoeba.eng_cmn.tsv

In [0]:
!cat tatoeba.eng_cmn.tsv | awk 'NR%10==1' > tatoeba.eng_cmn.test.tsv
!cat tatoeba.eng_cmn.tsv | awk 'NR%10==2' > tatoeba.eng_cmn.dev.tsv
!cat tatoeba.eng_cmn.tsv | awk 'NR%10!=1&&NR%10!=2' > tatoeba.eng_cmn.train.tsv

In [11]:
# List the working directory; the tatoeba.eng_cmn.{train,dev,test}.tsv splits
# written by the previous cell should now be present.
!ls

create_bitext.py  sample_data		   tatoeba.eng_cmn.test.tsv
data		  sentences.csv		   tatoeba.eng_cmn.train.tsv
links.csv	  sentences.tar.bz2	   tatoeba.eng_cmn.tsv
links.tar.bz2	  tatoeba.eng_cmn.dev.tsv


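Build the model: read the train/dev splits, build the source and target vocabularies, and define a `SimpleSeq2Seq` model with a self-attention encoder and dot-product attention.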

In [12]:
# Model / training hyperparameters.
EN_EMBEDDING_DIM = 256
ZH_EMBEDDING_DIM = 256
HIDDEN_DIM = 256
MIN_COUNT = 3      # min_count for both the 'tokens' and 'target_tokens' vocabularies
BEAM_SIZE = 8
BATCH_SIZE = 32
SEED = 42
# Use GPU 0 when one is available, otherwise -1 (CPU), instead of assuming a GPU.
CUDA_DEVICE = 0 if torch.cuda.is_available() else -1

# Seed the Python and torch RNGs so repeated Run-All executions are comparable.
random.seed(SEED)
torch.manual_seed(SEED)

# English source is split into words, Chinese target into single characters;
# the target side uses its own 'target_tokens' vocabulary namespace.
reader = Seq2SeqDatasetReader(
    source_tokenizer=WordTokenizer(),
    target_tokenizer=CharacterTokenizer(),
    source_token_indexers={'tokens': SingleIdTokenIndexer()},
    target_token_indexers={'tokens': SingleIdTokenIndexer(namespace='target_tokens')})
train_dataset = reader.read('tatoeba.eng_cmn.train.tsv')
validation_dataset = reader.read('tatoeba.eng_cmn.dev.tsv')

# NOTE(review): the vocabulary is built from train + dev instances, so dev-only
# tokens that reach MIN_COUNT are in-vocabulary; build it from train_dataset
# alone for a stricter evaluation.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset,
                                  min_count={'tokens': MIN_COUNT, 'target_tokens': MIN_COUNT})

en_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                         embedding_dim=EN_EMBEDDING_DIM)
source_embedder = BasicTextFieldEmbedder({"tokens": en_embedding})

# Single-layer self-attention encoder. LSTM alternative (also imported above):
#   PytorchSeq2SeqWrapper(torch.nn.LSTM(EN_EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
encoder = StackedSelfAttentionEncoder(input_dim=EN_EMBEDDING_DIM,
                                      hidden_dim=HIDDEN_DIM,
                                      projection_dim=128,
                                      feedforward_hidden_dim=128,
                                      num_layers=1,
                                      num_attention_heads=8)

# Decoder attention. Alternatives (also imported above):
#   LinearAttention(HIDDEN_DIM, HIDDEN_DIM, activation=Activation.by_name('tanh')())
#   BilinearAttention(HIDDEN_DIM, HIDDEN_DIM)
attention = DotProductAttention()

max_decoding_steps = 20   # TODO: make this variable
model = SimpleSeq2Seq(vocab, source_embedder, encoder, max_decoding_steps,
                      target_embedding_dim=ZH_EMBEDDING_DIM,
                      target_namespace='target_tokens',
                      attention=attention,
                      beam_size=BEAM_SIZE,
                      use_bleu=True)

# Place the model on its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
if CUDA_DEVICE >= 0:
    model = model.cuda(CUDA_DEVICE)
optimizer = optim.Adam(model.parameters())

# Batch instances of similar source_tokens length together (sorting_keys).
iterator = BucketIterator(batch_size=BATCH_SIZE, sorting_keys=[("source_tokens", "num_tokens")])
iterator.index_with(vocab)


36204it [00:06, 5825.75it/s]
4526it [00:00, 8821.02it/s]
100%|██████████| 40730/40730 [00:00<00:00, 40903.98it/s]


In [0]:
# Train for NUM_EPOCHS epochs, one Trainer.train() call (num_epochs=1) at a
# time so each epoch is announced, then print sample translations.
NUM_EPOCHS = 50
NUM_SAMPLES = 10   # validation instances to translate after training

# Use GPU 0 when available, otherwise -1 (CPU). This value, not a hard-coded 0,
# is what the Trainer must receive, or CPU-only runtimes fail.
cuda_device = 0 if torch.cuda.is_available() else -1
CUDA_DEVICE = cuda_device  # keep the constant consistent with the device in use
if cuda_device >= 0:
    model = model.cuda(cuda_device)

trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  num_epochs=1,
                  cuda_device=cuda_device)

# The predictor holds a reference to the same model object, so it sees the
# weights as they are updated; build it once rather than on every epoch.
predictor = SimpleSeq2SeqPredictor(model, reader)

for epoch in range(NUM_EPOCHS):
    print('Epoch: {}'.format(epoch))
    trainer.train()

# Sample predictions after the final epoch.
for instance in itertools.islice(validation_dataset, NUM_SAMPLES):
    print('SOURCE:', instance.fields['source_tokens'].tokens)
    print('GOLD:', instance.fields['target_tokens'].tokens)
    print('PRED:', predictor.predict_instance(instance)['predicted_tokens'])

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


  0%|          | 0/1132 [00:00<?, ?it/s][A[A

Epoch: 0




loss: 1.2275 ||:   0%|          | 1/1132 [00:00<17:21,  1.09it/s][A[A

loss: 1.3891 ||:   0%|          | 5/1132 [00:01<12:16,  1.53it/s][A[A

loss: 1.4390 ||:   1%|          | 9/1132 [00:01<08:43,  2.14it/s][A[A

loss: 1.5832 ||:   1%|          | 12/1132 [00:01<06:19,  2.95it/s][A[A

loss: 1.5184 ||:   1%|▏         | 16/1132 [00:01<04:34,  4.07it/s][A[A

loss: 1.4644 ||:   2%|▏         | 21/1132 [00:01<03:19,  5.57it/s][A[A

loss: 1.4193 ||:   2%|▏         | 25/1132 [00:01<02:27,  7.49it/s][A[A

loss: 1.4388 ||:   3%|▎         | 29/1132 [00:01<01:52,  9.82it/s][A[A

loss: 1.4949 ||:   3%|▎         | 33/1132 [00:01<01:28, 12.46it/s][A[A

loss: 1.4880 ||:   3%|▎         | 37/1132 [00:01<01:11, 15.33it/s][A[A

loss: 1.4892 ||:   4%|▎         | 42/1132 [00:02<00:58, 18.58it/s][A[A

loss: 1.5011 ||:   4%|▍         | 46/1132 [00:02<00:51, 21.28it/s][A[A

loss: 1.5502 ||:   4%|▍         | 50/1132 [00:02<00:49, 21.93it/s][A[A

loss: 1.5513 ||:   5%|▍         | 54/11

Epoch: 1



loss: 1.3305 ||:   0%|          | 1/1132 [00:01<30:44,  1.63s/it][A
loss: 1.6823 ||:   0%|          | 4/1132 [00:01<21:40,  1.15s/it][A
loss: 1.6099 ||:   1%|          | 8/1132 [00:01<15:16,  1.23it/s][A
loss: 1.5091 ||:   1%|          | 12/1132 [00:01<10:48,  1.73it/s][A
loss: 1.4723 ||:   1%|▏         | 16/1132 [00:02<07:42,  2.41it/s][A
loss: 1.4859 ||:   2%|▏         | 20/1132 [00:02<05:31,  3.35it/s][A
loss: 1.4655 ||:   2%|▏         | 24/1132 [00:02<04:00,  4.60it/s][A
loss: 1.4321 ||:   2%|▏         | 28/1132 [00:02<02:56,  6.26it/s][A
loss: 1.3787 ||:   3%|▎         | 33/1132 [00:02<02:10,  8.40it/s][A
loss: 1.3746 ||:   3%|▎         | 37/1132 [00:02<01:40, 10.91it/s][A
loss: 1.4276 ||:   4%|▎         | 41/1132 [00:02<01:21, 13.33it/s][A
loss: 1.4016 ||:   4%|▍         | 45/1132 [00:02<01:05, 16.59it/s][A
loss: 1.4195 ||:   4%|▍         | 49/1132 [00:03<00:56, 19.27it/s][A
loss: 1.4057 ||:   5%|▍         | 53/1132 [00:03<00:47, 22.66it/s][A
loss: 1.4236 ||:   5%|

Epoch: 2



loss: 1.6327 ||:   0%|          | 1/1132 [00:00<16:53,  1.12it/s][A
loss: 1.6793 ||:   0%|          | 5/1132 [00:01<11:58,  1.57it/s][A
loss: 1.7309 ||:   1%|          | 8/1132 [00:01<08:34,  2.18it/s][A
loss: 1.7514 ||:   1%|          | 11/1132 [00:01<06:11,  3.02it/s][A
loss: 1.5137 ||:   1%|▏         | 16/1132 [00:01<04:26,  4.18it/s][A
loss: 1.4881 ||:   2%|▏         | 20/1132 [00:01<03:16,  5.66it/s][A
loss: 1.4285 ||:   2%|▏         | 24/1132 [00:01<02:26,  7.58it/s][A
loss: 1.3958 ||:   2%|▏         | 28/1132 [00:01<01:50, 10.01it/s][A
loss: 1.3925 ||:   3%|▎         | 32/1132 [00:01<01:25, 12.80it/s][A
loss: 1.3745 ||:   3%|▎         | 36/1132 [00:01<01:09, 15.79it/s][A
loss: 1.3613 ||:   4%|▎         | 40/1132 [00:02<00:57, 19.14it/s][A
loss: 1.4166 ||:   4%|▍         | 44/1132 [00:02<00:54, 20.07it/s][A
loss: 1.3936 ||:   4%|▍         | 48/1132 [00:02<00:46, 23.43it/s][A
loss: 1.3875 ||:   5%|▍         | 52/1132 [00:02<00:42, 25.67it/s][A
loss: 1.3628 ||:   5%|

In [18]:
# Translate the first 10 validation sentences with the trained model.
# Rebuild the predictor here so this cell does not depend on a variable that
# merely leaked out of the training loop, and set eval mode explicitly.
# NOTE(review): SimpleSeq2Seq presumably only runs beam search (which yields
# 'predicted_tokens') when not in training mode — confirm for the pinned version.
model.eval()
predictor = SimpleSeq2SeqPredictor(model, reader)
for instance in itertools.islice(validation_dataset, 10):
    print('SOURCE:', instance.fields['source_tokens'].tokens)
    print('GOLD:', instance.fields['target_tokens'].tokens)
    print('PRED:', predictor.predict_instance(instance)['predicted_tokens'])

SOURCE: [@start@, I, have, to, go, to, sleep, ., @end@]
GOLD: [@start@, 我, 该, 去, 睡, 觉, 了, 。, @end@]
PRED: ['我', '睡', '觉', '了', '。']
SOURCE: [@start@, I, just, do, n't, know, what, to, say, ., @end@]
GOLD: [@start@, 我, 就, 是, 不, 知, 道, 說, 些, 什, 麼, 。, @end@]
PRED: ['我', '不', '懂', '这', '句', '话', '。']
SOURCE: [@start@, I, may, give, up, soon, and, just, nap, instead, ., @end@]
GOLD: [@start@, 也, 许, 我, 会, 马, 上, 放, 弃, 然, 后, 去, 睡, 一, 觉, 。, @end@]
PRED: ['我', '累', '死', '了', '，', '再', '也', '走', '了', '。']
SOURCE: [@start@, I, 'm, going, to, go, ., @end@]
GOLD: [@start@, 我, 要, 走, 了, 。, @end@]
PRED: ['我', '要', '去', '。']
SOURCE: [@start@, That, 's, MY, line, !, @end@]
GOLD: [@start@, 那, 是, 我, 的, 台, 词, ！, @end@]
PRED: ['那', '是', '安', '靜', '！']
SOURCE: [@start@, It, does, n't, surprise, me, ., @end@]
GOLD: [@start@, 这, 并, 不, 让, 我, 惊, 讶, 。, @end@]
PRED: ['不', '要', '問', '我', '。']
SOURCE: [@start@, I, 'm, not, a, real, fish, ,, I, 'm, just, a, mere, plushy, ., @end@]
GOLD: [@start@, 我, 不, 是, 一, 条, 真, 的, 鱼