<a href="https://colab.research.google.com/github/paulxuereb/ML/blob/master/Seq2SeqAllenNLP_ChitChat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

http://www.realworldnlpbook.com/blog/building-seq2seq-machine-translation-models-using-allennlp.html

In [1]:
!pip install AllenNLP



In [0]:
    
import itertools

import torch
import torch.optim as optim
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.activations import Activation
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.modules.attention import LinearAttention, BilinearAttention, DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, StackedSelfAttentionEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.predictors import SimpleSeq2SeqPredictor
from allennlp.training.trainer import Trainer
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder

Setup data files

In [0]:
!mkdir data

In [0]:
%cd data

/content/data


In [0]:
!wget https://raw.githubusercontent.com/mhagiwara/realworldnlp/master/examples/mt/create_bitext.py

--2019-04-28 09:34:34--  https://raw.githubusercontent.com/mhagiwara/realworldnlp/master/examples/mt/create_bitext.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4870 (4.8K) [text/plain]
Saving to: ‘create_bitext.py’


2019-04-28 09:34:34 (97.4 MB/s) - ‘create_bitext.py’ saved [4870/4870]



Grab the Microsoft chat files

In [0]:
!wget https://qnamakerstore.blob.core.windows.net/qnamakerdata/editorial/qna_chitchat_the_professional.tsv

--2019-04-28 09:34:48--  https://qnamakerstore.blob.core.windows.net/qnamakerdata/editorial/qna_chitchat_the_professional.tsv
Resolving qnamakerstore.blob.core.windows.net (qnamakerstore.blob.core.windows.net)... 13.88.144.240
Connecting to qnamakerstore.blob.core.windows.net (qnamakerstore.blob.core.windows.net)|13.88.144.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 62473 (61K) [text/tab-separated-values]
Saving to: ‘qna_chitchat_the_professional.tsv’


2019-04-28 09:34:49 (584 KB/s) - ‘qna_chitchat_the_professional.tsv’ saved [62473/62473]



In [5]:
import pandas as pd
df = pd.read_csv("qna_chitchat_the_professional.tsv",sep='\t')
#df.drop(['Source', 'Metadata'])
#df1 = df[['user1','user2']]
#df1

FileNotFoundError: ignored

In [0]:
df1 = df[['Question','Answer']]
df1
df1.to_csv('chat.tsv', sep = '\t', header=False, index=False)

In [0]:
!head chat.tsv

What's your age?	Age doesn't really apply to me.
Are you young?	Age doesn't really apply to me.
When were you born?	Age doesn't really apply to me.
What age are you?	Age doesn't really apply to me.
Are you old?	Age doesn't really apply to me.
How old are you?	Age doesn't really apply to me.
How long ago were you born?	Age doesn't really apply to me.
Ask me anything	I'm better at answering questions.
Ask me a question	I'm better at answering questions.
Can you ask me a question?	I'm better at answering questions.


In [0]:
!cat chat.tsv | awk 'NR%10==1' > chat.eng_cmn.test.tsv
!cat chat.tsv | awk 'NR%10==2' > chat.eng_cmn.dev.tsv
!cat chat.tsv | awk 'NR%10!=1&&NR%10!=2' > chat.eng_cmn.train.tsv

In [0]:
!ls
#!head chat.eng_cmn.dev.tsv

chat.eng_cmn.dev.tsv	chat.tsv
chat.eng_cmn.test.tsv	create_bitext.py
chat.eng_cmn.train.tsv	qna_chitchat_the_professional.tsv


Build the model

Seq2Seq with attention (non BERT) - seems to be working better than when using BERT

In [0]:
#!wget -O config.json 
#!ls
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [0]:
%cd ..

/content


In [0]:
#this is the shareable link of the file temps.csv in my google drive
link = 'https://drive.google.com/open?id=1MkC2EEMqihJPMuw4V4VahLbQO4bdi_B4'
fluff, id = link.split('=')
#print (id) # Verify that you have everything after '='
downloaded = drive.CreateFile({'id':id}) 
downloaded.GetContentFile('config.json')  #name of file is irrelevant

In [0]:
#!head config.json
!mv config.json /content/data


In [0]:
EN_EMBEDDING_DIM = 256
ZH_EMBEDDING_DIM = 256
HIDDEN_DIM = 256
CUDA_DEVICE = 0

serialization_dir = '/content/model/'


reader = Seq2SeqDatasetReader(
        source_tokenizer=WordTokenizer(),
        target_tokenizer=CharacterTokenizer(),
        source_token_indexers={'tokens': SingleIdTokenIndexer()},
        target_token_indexers={'tokens': SingleIdTokenIndexer(namespace='target_tokens')})
train_dataset = reader.read('/content/data/chat.eng_cmn.train.tsv')
validation_dataset = reader.read('/content/data/chat.eng_cmn.dev.tsv')

vocab = Vocabulary.from_instances(train_dataset + validation_dataset,
                                  min_count={'tokens': 3, 'target_tokens': 3})

en_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                         embedding_dim=EN_EMBEDDING_DIM)
# encoder = PytorchSeq2SeqWrapper(
#     torch.nn.LSTM(EN_EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
encoder = StackedSelfAttentionEncoder(input_dim=EN_EMBEDDING_DIM, hidden_dim=HIDDEN_DIM, projection_dim=128, feedforward_hidden_dim=128, num_layers=1, num_attention_heads=8)

source_embedder = BasicTextFieldEmbedder({"tokens": en_embedding})



# attention = LinearAttention(HIDDEN_DIM, HIDDEN_DIM, activation=Activation.by_name('tanh')())
# attention = BilinearAttention(HIDDEN_DIM, HIDDEN_DIM)
attention = DotProductAttention()

max_decoding_steps = 100   # TODO: make this variable
model = SimpleSeq2Seq(vocab, source_embedder, encoder, max_decoding_steps,
                      target_embedding_dim=ZH_EMBEDDING_DIM,
                      target_namespace='target_tokens',
                      attention=attention,
                      beam_size=8,
                      use_bleu=True)
optimizer = optim.Adam(model.parameters())
iterator = BucketIterator(batch_size=32, sorting_keys=[("source_tokens", "num_tokens")])

iterator.index_with(vocab)

CUDA_DEVICE = 0

if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
    

trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_dataset,
                      validation_dataset=validation_dataset,
                      num_epochs=1,
                      cuda_device=CUDA_DEVICE)
                      #serialization_dir=serialization_dir,                 
                          #num_serialized_models_to_keep=2,
                          #keep_serialized_model_every_num_seconds=1)



0it [00:00, ?it/s][A[A

525it [00:00, 10903.74it/s][A[A

0it [00:00, ?it/s][A[A

66it [00:00, 6379.61it/s][A[A

  0%|          | 0/591 [00:00<?, ?it/s][A[A

100%|██████████| 591/591 [00:00<00:00, 27521.80it/s][A[AYou provided a validation dataset but patience was set to None, meaning that early stopping is disabled


In [0]:
vocab.get_vocab_size('tokens')

159

In [0]:
for i in range(4):
    print('Epoch: {}'.format(i))
    metrics = trainer.train()

    predictor = SimpleSeq2SeqPredictor(model, reader)

    #do predictions on last epoch
    if i >= 49:
      for instance in itertools.islice(validation_dataset, 10):
          print('SOURCE:', instance.fields['source_tokens'].tokens)
          print('GOLD:', instance.fields['target_tokens'].tokens)
          print('PRED:', predictor.predict_instance(instance)['predicted_tokens'])
          
metrics



  0%|          | 0/17 [00:00<?, ?it/s][A[A

Epoch: 0



loss: 3.9738 ||:   6%|▌         | 1/17 [00:00<00:04,  3.53it/s][A
loss: 3.7976 ||:  18%|█▊        | 3/17 [00:00<00:03,  4.64it/s][A
loss: 3.6487 ||:  29%|██▉       | 5/17 [00:00<00:02,  5.89it/s][A
loss: 3.5319 ||:  41%|████      | 7/17 [00:00<00:01,  7.09it/s][A
loss: 3.4468 ||:  53%|█████▎    | 9/17 [00:00<00:00,  8.27it/s][A
loss: 3.4014 ||:  65%|██████▍   | 11/17 [00:00<00:00,  9.86it/s][A
loss: 3.3575 ||:  76%|███████▋  | 13/17 [00:01<00:00, 10.92it/s][A
loss: 3.3196 ||:  88%|████████▊ | 15/17 [00:01<00:00, 12.19it/s][A
loss: 3.2835 ||: 100%|██████████| 17/17 [00:01<00:00, 13.16it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
BLEU: 0.0000, loss: 3.0496 ||:  33%|███▎      | 1/3 [00:00<00:00,  3.37it/s][A
BLEU: 0.0000, loss: 3.0700 ||:  67%|██████▋   | 2/3 [00:00<00:00,  3.30it/s][A
BLEU: 0.0000, loss: 3.0542 ||: 100%|██████████| 3/3 [00:00<00:00,  4.03it/s][A
[A
  0%|          | 0/17 [00:00<?, ?it/s][A

Epoch: 1



loss: 2.9837 ||:  12%|█▏        | 2/17 [00:00<00:01, 14.55it/s][A
loss: 3.0088 ||:  24%|██▎       | 4/17 [00:00<00:00, 14.74it/s][A
loss: 3.0331 ||:  35%|███▌      | 6/17 [00:00<00:00, 14.58it/s][A
loss: 3.0362 ||:  47%|████▋     | 8/17 [00:00<00:00, 14.66it/s][A
loss: 3.0389 ||:  59%|█████▉    | 10/17 [00:00<00:00, 15.26it/s][A
loss: 3.0336 ||:  71%|███████   | 12/17 [00:00<00:00, 15.63it/s][A
loss: 3.0310 ||:  82%|████████▏ | 14/17 [00:00<00:00, 15.68it/s][A
loss: 3.0249 ||:  94%|█████████▍| 16/17 [00:01<00:00, 15.77it/s][A
loss: 3.0215 ||: 100%|██████████| 17/17 [00:01<00:00, 15.44it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
BLEU: 0.0000, loss: 2.9610 ||:  33%|███▎      | 1/3 [00:00<00:00,  3.05it/s][A
BLEU: 0.0000, loss: 2.9760 ||:  67%|██████▋   | 2/3 [00:00<00:00,  2.98it/s][A
BLEU: 0.0000, loss: 2.9837 ||: 100%|██████████| 3/3 [00:00<00:00,  3.71it/s][A
[A
  0%|          | 0/17 [00:00<?, ?it/s][A
loss: 2.9953 ||:  12%|█▏        | 2/17 [00:00<00:01, 14.69it/s]

Epoch: 2



loss: 2.9730 ||:  24%|██▎       | 4/17 [00:00<00:00, 15.25it/s][A
loss: 2.9579 ||:  35%|███▌      | 6/17 [00:00<00:00, 14.78it/s][A
loss: 2.9500 ||:  47%|████▋     | 8/17 [00:00<00:00, 14.56it/s][A
loss: 2.9509 ||:  59%|█████▉    | 10/17 [00:00<00:00, 15.06it/s][A
loss: 2.9511 ||:  71%|███████   | 12/17 [00:00<00:00, 14.88it/s][A
loss: 2.9448 ||:  82%|████████▏ | 14/17 [00:00<00:00, 15.16it/s][A
loss: 2.9342 ||:  94%|█████████▍| 16/17 [00:01<00:00, 15.57it/s][A
loss: 2.9296 ||: 100%|██████████| 17/17 [00:01<00:00, 15.07it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
BLEU: 0.0000, loss: 2.8552 ||:  33%|███▎      | 1/3 [00:00<00:00,  3.26it/s][A
BLEU: 0.0000, loss: 2.8664 ||:  67%|██████▋   | 2/3 [00:00<00:00,  3.23it/s][A
BLEU: 0.0000, loss: 2.8909 ||: 100%|██████████| 3/3 [00:00<00:00,  3.92it/s][A
[A
  0%|          | 0/17 [00:00<?, ?it/s][A
loss: 2.8491 ||:  12%|█▏        | 2/17 [00:00<00:00, 15.77it/s][A

Epoch: 3



loss: 2.8516 ||:  24%|██▎       | 4/17 [00:00<00:00, 15.67it/s][A
loss: 2.8474 ||:  35%|███▌      | 6/17 [00:00<00:00, 16.54it/s][A
loss: 2.8226 ||:  47%|████▋     | 8/17 [00:00<00:00, 15.83it/s][A
loss: 2.8155 ||:  59%|█████▉    | 10/17 [00:00<00:00, 16.38it/s][A
loss: 2.8017 ||:  71%|███████   | 12/17 [00:00<00:00, 15.66it/s][A
loss: 2.7854 ||:  82%|████████▏ | 14/17 [00:00<00:00, 15.31it/s][A
loss: 2.7845 ||:  94%|█████████▍| 16/17 [00:01<00:00, 15.02it/s][A
loss: 2.7789 ||: 100%|██████████| 17/17 [00:01<00:00, 15.44it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
BLEU: 0.0000, loss: 2.6512 ||:  33%|███▎      | 1/3 [00:00<00:00,  3.12it/s][A
BLEU: 0.0000, loss: 2.6688 ||:  67%|██████▋   | 2/3 [00:00<00:00,  3.13it/s][A
BLEU: 0.0062, loss: 2.7077 ||: 100%|██████████| 3/3 [00:00<00:00,  3.83it/s][A
[A

{'best_epoch': 0,
 'best_validation_BLEU': 0.0062080515814421575,
 'best_validation_loss': 2.707716464996338,
 'epoch': 0,
 'peak_cpu_memory_MB': 3264.216,
 'peak_gpu_0_memory_MB': 825,
 'training_cpu_memory_MB': 3264.216,
 'training_duration': '00:00:01',
 'training_epochs': 0,
 'training_gpu_0_memory_MB': 825,
 'training_loss': 2.778914044885074,
 'training_start_epoch': 0,
 'validation_BLEU': 0.0062080515814421575,
 'validation_loss': 2.707716464996338}

In [12]:
!allennlp train -r -s /content/model4 /content/data/config.json 



2019-04-28 11:10:06,771 - INFO - pytorch_pretrained_bert.modeling - Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
2019-04-28 11:10:07,817 - INFO - allennlp.common.params - random_seed = 13370
2019-04-28 11:10:07,818 - INFO - allennlp.common.params - numpy_seed = 1337
2019-04-28 11:10:07,818 - INFO - allennlp.common.params - pytorch_seed = 133
2019-04-28 11:10:07,824 - INFO - allennlp.common.checks - Pytorch version: 1.0.1.post2
2019-04-28 11:10:07,825 - INFO - allennlp.training.util - Recovering from prior training at /content/model4.
2019-04-28 11:10:07,839 - INFO - allennlp.common.params - evaluate_on_test = False
2019-04-28 11:10:07,839 - INFO - allennlp.common.from_params - instantiating class <class 'allennlp.data.dataset_readers.dataset_reader.DatasetReader'> from params {'lazy': True, 'source_token_indexers': {'tokens': {'namespace': 'source_tokens', 'type': 'single_id'}}, 'source_tokenizer': {'type': 'word'}, 'target_token_indexers':

In [15]:
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

path = '/content/model4/model.tar.gz'
archive = load_archive(path)

#archive = load_archive(serialization_dir+'model.tar.gz')
predictor = Predictor.from_archive(archive, 'simple_seq2seq') 
''.join(predictor.predict("What are you doing tomorrow?")['predicted_tokens'])

'I apologize.'

In [16]:
for instance in itertools.islice(train_dataset, 10):
    print('SOURCE:', instance.fields['source_tokens'].tokens)
    print('GOLD:', instance.fields['target_tokens'].tokens)
    print('PRED:', ''.join(predictor.predict_instance(instance)['predicted_tokens']))

NameError: ignored

In [0]:
''.join(predictor.predict_instance(instance)['predicted_tokens'])

'Ok.'

In [0]:
predictor.predict("whats your name?")

{'class_log_probabilities': [-4.412934303283691,
  -13.126253128051758,
  -18.1739444732666,
  -22.792346954345703,
  -23.302249908447266,
  -27.92108154296875,
  -29.777690887451172,
  -125.58854675292969],
 'predicted_tokens': ['O', 'k', '.'],
 'predictions': [[38,
   29,
   9,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11],
  [17,
   2,
   19,
   5,
   7,
   18,


In [0]:
!ls -l

total 16512
-rw-r--r-- 1 root root 4222127 Apr 28 09:49 best.th
-rw-r--r-- 1 root root    1708 Apr 28 09:53 config.json
drwxr-xr-x 4 root root    4096 Apr 28 09:45 log
-rw-r--r-- 1 root root     480 Apr 28 09:49 metrics_epoch_0.json
-rw-r--r-- 1 root root 4222127 Apr 28 09:49 model_state_epoch_0.th
-rw-r--r-- 1 root root      30 Apr 28 09:50 model.tar.gz
-rw-r--r-- 1 root root 8441816 Apr 28 09:49 training_state_epoch_0.th
drwxr-xr-x 2 root root    4096 Apr 28 09:45 vocabulary


In [0]:
import os
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

In [0]:
%cd vocabulary
!ls

/content/model/vocabulary
non_padded_namespaces.txt  target_tokens.txt  tokens.txt


In [0]:
%cd model

/content/model
