# Neural Translation with Self-Attention

지난 실습에서는 seq2seq 및 attention을 사용하여 날짜 언어 -> 날짜 포맷으로의 번역을 시도해 보았습니다.  
이번 실습에서는 동일한 task를 self-attention 모델을 이용해 다시 처리해 보겠습니다.

(참고)  

In [1]:
import os
import json
import pandas as pd
import numpy as np
import random
import unicodedata
import re
import time
import shutil
from collections import Counter

# Start by importing all the things we'll need.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer, Dot, Concatenate, Input, Activation, LSTM, Dense, Embedding, CuDNNLSTM, Flatten, TimeDistributed, Dropout, LSTMCell, RNN, Bidirectional
from keras.layers.recurrent import Recurrent
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.utils import tf_utils
from tensorflow.keras import backend as K
from keras.optimizers import *

# This enables the Jupyter backend on some matplotlib installations.
%matplotlib notebook
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
# Seed the random number generators of the libraries imported above so that
# Python-level and NumPy-level randomness repeat from run to run.
# Previously only the `random` module was seeded; the NumPy global RNG was not.
# NOTE: TensorFlow's own random seed is not set here.
SEED = 1984
random.seed(SEED)
np.random.seed(SEED)

## Load Dataset

In [16]:
# Paths of the CSV-format dataset files.
train_dataset_filepath = 'datasets/nmt_date/nmt_date_train.csv'
test_dataset_filepath = 'datasets/nmt_date/nmt_date_test.csv'

# The training file is read without a header row; its two columns are named X and Y.
df = pd.read_csv(train_dataset_filepath, names=['X', 'Y'], header=None)

# Pull each column out by position, then keep a plain-list copy of its values.
x_corpus, y_corpus = df.iloc[:, 0], df.iloc[:, 1]
x_corpus_list = x_corpus.values.tolist()
y_corpus_list = y_corpus.values.tolist()

# Split every sentence into characters and join them into one flat array per side.
x_char_list = np.concatenate([list(sentence) for sentence in x_corpus_list])
y_char_list = np.concatenate([list(sentence) for sentence in y_corpus_list])

## Build Vocabulary

In [17]:
# Count how often each character occurs on the X side and on the Y side.
counter_x = Counter(x_char_list)
counter_y = Counter(y_char_list)

# Each vocabulary starts with the same four reserved tokens, followed by the
# observed characters ordered from most to least frequent.
special_tokens = ['<PAD>', '<UNK>', '<S>', '</S>']
x_vocab = special_tokens + [char for char, _ in counter_x.most_common()]
y_vocab = special_tokens + [char for char, _ in counter_y.most_common()]

# Lookup tables in both directions (index -> character, character -> index).
idx2char_x = {index: char for index, char in enumerate(x_vocab)}
char2idx_x = dict(zip(x_vocab, range(len(x_vocab))))
idx2char_y = {index: char for index, char in enumerate(y_vocab)}
char2idx_y = dict(zip(y_vocab, range(len(y_vocab))))

## Prepare Train Dataset

In [19]:
def convert_sentence_to_indexed_corpus(corpus, char2idx):
    """Encode one sentence as a list of vocabulary indices wrapped in <S> ... </S>.

    Args:
        corpus: The sentence to encode; it is iterated character by character.
        char2idx: Mapping from character to integer index. It must contain the
            '<S>' and '</S>' entries, and '<UNK>' whenever the sentence holds a
            character that is missing from the mapping.

    Returns:
        A list of indices: the '<S>' index, one index per character (characters
        missing from ``char2idx`` become the '<UNK>' index), then the '</S>' index.

    Raises:
        KeyError: If a required special token is absent from ``char2idx``.
    """
    indexed_corpus = [char2idx['<S>']]
    # Fixed: the fallback key used to be 'UNK', which the vocabularies never
    # contain (they store '<UNK>'), so any unseen character raised KeyError.
    indexed_corpus.extend(
        char2idx[char] if char in char2idx else char2idx['<UNK>']
        for char in corpus
    )
    # Fixed: read '</S>' from the mapping that was passed in rather than from
    # the module-level char2idx_x, so the function works for any vocabulary.
    indexed_corpus.append(char2idx['</S>'])
    return indexed_corpus

In [20]:
# Encode every X-side sentence with the X-side vocabulary.
indexed_x_corpus_list = [
    convert_sentence_to_indexed_corpus(doc, char2idx_x) for doc in x_corpus_list
]

In [21]:
# Encode every Y-side sentence with the Y-side vocabulary.
indexed_y_corpus_list = [
    convert_sentence_to_indexed_corpus(doc, char2idx_y) for doc in y_corpus_list
]

In [22]:
# Longest encoded sequence on each side (the <S> and </S> markers are included).
max_x_corpus_length = max(map(len, indexed_x_corpus_list))
max_y_corpus_length = max(map(len, indexed_y_corpus_list))

In [23]:
pad_sequences = tf.keras.preprocessing.sequence.pad_sequences

# Right-pad every encoded sequence to the longest length on its side.
input_data = pad_sequences(indexed_x_corpus_list, padding="post", maxlen=max_x_corpus_length)
output_data = pad_sequences(indexed_y_corpus_list, padding="post", maxlen=max_y_corpus_length)
teacher_data = output_data

# The target is the teacher sequence shifted left by one step: drop the first
# token of each row, pad back to the full length, then add a trailing axis of size 1.
shifted_rows = [row[1:].tolist() for row in teacher_data]
target_data = pad_sequences(shifted_rows, padding="post", maxlen=max_y_corpus_length)
target_data = target_data.reshape(target_data.shape + (1,))

for prepared in (input_data, teacher_data, target_data):
    print(prepared.shape)

(500000, 68)
(500000, 12)
(500000, 12, 1)


In [10]:
# Configuration values; the sizes below are derived from the data prepared above.
BUFFER_SIZE = len(x_corpus_list)
BATCH_SIZE = 32
embedding_dim = 16
units = 32
# Fixed: the X-side vocabulary size was computed from the Y-side table
# (idx2char_y), which gave the wrong size whenever the two vocabularies differ.
x_vocab_size = len(idx2char_x)
y_vocab_size = len(idx2char_y)
# Padded sequence lengths on the input and target side.
len_input = max_x_corpus_length
len_target = max_y_corpus_length

## Build Transformer model

In [11]:
from transformer import Transformer, LRSchedulerPerStep

In [12]:
# Value handed to both Transformer(d_model=...) and LRSchedulerPerStep(...) further down.
d_model = 32   # open question left by the author: use 256 instead?

In [13]:
class TokenList:
    """Two-way lookup between tokens and integer ids.

    Ids 0-3 are reserved for '<PAD>', '<UNK>', '<S>' and '</S>'; the tokens
    given to the constructor follow from id 4 onwards, in the order given.
    """

    def __init__(self, token_list):
        """Build both lookup tables from ``token_list`` (a list of tokens)."""
        reserved = ['<PAD>', '<UNK>', '<S>', '</S>']
        self.id2t = reserved + token_list
        self.t2id = dict(zip(self.id2t, range(len(self.id2t))))

    def id(self, x):
        """Return the id of token ``x``, or 1 (the '<UNK>' slot) if it is unknown."""
        return self.t2id.get(x, 1)

    def token(self, x):
        """Return the token stored at id ``x``."""
        return self.id2t[x]

    def num(self):
        """Return the total number of tokens, reserved ones included."""
        return len(self.id2t)

    def startid(self):
        """Return the id of the '<S>' token."""
        return 2

    def endid(self):
        """Return the id of the '</S>' token."""
        return 3

In [14]:
# Token tables for the model: the characters of each side, most frequent first.
# TokenList prepends the reserved tokens itself.
itokens = TokenList([char for char, _ in counter_x.most_common()])
otokens = TokenList([char for char, _ in counter_y.most_common()])

In [15]:
# Assemble the Transformer from the two token tables.
s2s = Transformer(
    itokens,
    otokens,
    len_limit=70,
    d_model=d_model,
    d_inner_hid=512,
    n_head=8,
    layers=2,
    dropout=0.1
)

# Created here but not handed to any call in this cell.
lr_scheduler = LRSchedulerPerStep(d_model, 4000)
# Checkpointing is currently switched off:
# model_saver = ModelCheckpoint(mfile, save_best_only=True, save_weights_only=True)

optimizer = Adam(0.001, 0.9, 0.98, epsilon=1e-9)
s2s.compile(optimizer)
s2s.model.summary()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None)         0           input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
lambda_5 (Lambda)    

                                                                 layer_normalization_2[0][0]      
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, None, 32)     64          add_3[0][0]                      
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, None, 32)     1024        layer_normalization_1[0][0]      
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, None, 32)     1024        layer_normalization_1[0][0]      
__________________________________________________________________________________________________
lambda_15 (Lambda)              (None, None, None)   0           dense_5[0][0]                    
__________________________________________________________________________________________________
lambda_16 

conv1d_3 (Conv1D)               (None, None, 512)    16896       layer_normalization_4[0][0]      
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 32)     1056        lambda_34[0][0]                  
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, None, 32)     16416       conv1d_3[0][0]                   
__________________________________________________________________________________________________
dropout_14 (Dropout)            (None, None, 32)     0           time_distributed_3[0][0]         
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, None, 32)     0           conv1d_4[0][0]                   
__________________________________________________________________________________________________
add_8 (Add

lambda_46 (Lambda)              (None, None, None)   0           lambda_25[0][0]                  
__________________________________________________________________________________________________
lambda_47 (Lambda)              (None, None, None)   0           lambda_43[0][0]                  
                                                                 lambda_44[0][0]                  
__________________________________________________________________________________________________
lambda_48 (Lambda)              (None, None, None)   0           lambda_46[0][0]                  
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None)   0           lambda_47[0][0]                  
                                                                 lambda_48[0][0]                  
__________________________________________________________________________________________________
activation

__________________________________________________________________________________________________
time_distributed_7 (TimeDistrib (None, None, 15)     480         layer_normalization_8[0][0]      
Total params: 206,688
Trainable params: 204,448
Non-trainable params: 2,240
__________________________________________________________________________________________________


In [26]:
# Alternative settings kept by the author: batch_size=64, epochs=30
# No separate target array is passed (y is None); the last 20% of the samples
# are held out for validation.
# NOTE(review): lr_scheduler created earlier is not passed as a callback here;
# confirm whether that is intended.
s2s.model.fit(
    [input_data, output_data],
    None,
    batch_size=8,
    epochs=1,
    validation_split=0.2
)

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Train on 400000 samples, validate on 100000 samples
Epoch 1/1


  2280/400000 [..............................] - ETA: 80:41:57 - loss: 3.0538 - ppl: 21.1952 - accu: 0.113 - ETA: 27:12:28 - loss: 2.8629 - ppl: 17.7311 - accu: 0.162 - ETA: 16:29:45 - loss: 2.7420 - ppl: 15.8637 - accu: 0.152 - ETA: 11:54:46 - loss: 2.6944 - ppl: 15.0922 - accu: 0.162 - ETA: 9:22:11 - loss: 2.6545 - ppl: 14.4867 - accu: 0.164 - ETA: 7:45:59 - loss: 2.5931 - ppl: 13.6979 - accu: 0.16 - ETA: 6:38:19 - loss: 2.5513 - ppl: 13.1578 - accu: 0.17 - ETA: 5:48:57 - loss: 2.5259 - ppl: 12.8173 - accu: 0.17 - ETA: 5:11:16 - loss: 2.4980 - ppl: 12.4700 - accu: 0.17 - ETA: 4:41:30 - loss: 2.4772 - ppl: 12.2088 - accu: 0.18 - ETA: 4:17:22 - loss: 2.4503 - ppl: 11.9023 - accu: 0.18 - ETA: 3:57:17 - loss: 2.4219 - ppl: 11.5950 - accu: 0.18 - ETA: 3:40:31 - loss: 2.4023 - ppl: 11.3737 - accu: 0.19 - ETA: 3:26:09 - loss: 2.3854 - ppl: 11.1829 - accu: 0.19 - ETA: 3:13:45 - loss: 2.3656 - ppl: 10.9752 - accu: 0.19 - ETA: 3:02:56 - loss: 2.3487 - ppl: 10.7960 - accu: 0.20 - ETA: 2:53:33 -

  4616/400000 [..............................] - ETA: 42:05 - loss: 1.3083 - ppl: 4.2301 - accu: 0.509 - ETA: 41:58 - loss: 1.3062 - ppl: 4.2197 - accu: 0.509 - ETA: 41:50 - loss: 1.3037 - ppl: 4.2082 - accu: 0.510 - ETA: 41:44 - loss: 1.3013 - ppl: 4.1972 - accu: 0.511 - ETA: 41:37 - loss: 1.2990 - ppl: 4.1866 - accu: 0.511 - ETA: 41:32 - loss: 1.2966 - ppl: 4.1758 - accu: 0.512 - ETA: 41:27 - loss: 1.2943 - ppl: 4.1650 - accu: 0.513 - ETA: 41:21 - loss: 1.2921 - ppl: 4.1547 - accu: 0.513 - ETA: 41:14 - loss: 1.2902 - ppl: 4.1453 - accu: 0.514 - ETA: 41:08 - loss: 1.2879 - ppl: 4.1352 - accu: 0.514 - ETA: 41:01 - loss: 1.2859 - ppl: 4.1254 - accu: 0.514 - ETA: 40:55 - loss: 1.2837 - ppl: 4.1155 - accu: 0.515 - ETA: 40:48 - loss: 1.2815 - ppl: 4.1055 - accu: 0.515 - ETA: 40:41 - loss: 1.2796 - ppl: 4.0962 - accu: 0.516 - ETA: 40:34 - loss: 1.2777 - ppl: 4.0871 - accu: 0.516 - ETA: 40:29 - loss: 1.2755 - ppl: 4.0775 - accu: 0.517 - ETA: 40:24 - loss: 1.2736 - ppl: 4.0685 - accu: 0.517 -

  6952/400000 [..............................] - ETA: 32:52 - loss: 1.1255 - ppl: 3.3960 - accu: 0.548 - ETA: 32:50 - loss: 1.1248 - ppl: 3.3931 - accu: 0.548 - ETA: 32:48 - loss: 1.1242 - ppl: 3.3901 - accu: 0.548 - ETA: 32:45 - loss: 1.1236 - ppl: 3.3874 - accu: 0.549 - ETA: 32:43 - loss: 1.1230 - ppl: 3.3846 - accu: 0.549 - ETA: 32:41 - loss: 1.1223 - ppl: 3.3816 - accu: 0.549 - ETA: 32:39 - loss: 1.1216 - ppl: 3.3788 - accu: 0.549 - ETA: 32:36 - loss: 1.1210 - ppl: 3.3760 - accu: 0.549 - ETA: 32:35 - loss: 1.1203 - ppl: 3.3731 - accu: 0.549 - ETA: 32:33 - loss: 1.1196 - ppl: 3.3701 - accu: 0.549 - ETA: 32:31 - loss: 1.1190 - ppl: 3.3674 - accu: 0.549 - ETA: 32:29 - loss: 1.1184 - ppl: 3.3647 - accu: 0.550 - ETA: 32:27 - loss: 1.1178 - ppl: 3.3619 - accu: 0.550 - ETA: 32:25 - loss: 1.1172 - ppl: 3.3592 - accu: 0.550 - ETA: 32:23 - loss: 1.1165 - ppl: 3.3563 - accu: 0.550 - ETA: 32:22 - loss: 1.1159 - ppl: 3.3537 - accu: 0.550 - ETA: 32:20 - loss: 1.1154 - ppl: 3.3513 - accu: 0.550 -

  9288/400000 [..............................] - ETA: 29:30 - loss: 1.0555 - ppl: 3.0963 - accu: 0.565 - ETA: 29:29 - loss: 1.0551 - ppl: 3.0948 - accu: 0.565 - ETA: 29:28 - loss: 1.0548 - ppl: 3.0934 - accu: 0.565 - ETA: 29:28 - loss: 1.0544 - ppl: 3.0919 - accu: 0.565 - ETA: 29:26 - loss: 1.0540 - ppl: 3.0903 - accu: 0.565 - ETA: 29:26 - loss: 1.0537 - ppl: 3.0889 - accu: 0.565 - ETA: 29:25 - loss: 1.0533 - ppl: 3.0875 - accu: 0.566 - ETA: 29:25 - loss: 1.0529 - ppl: 3.0859 - accu: 0.566 - ETA: 29:24 - loss: 1.0525 - ppl: 3.0844 - accu: 0.566 - ETA: 29:23 - loss: 1.0521 - ppl: 3.0828 - accu: 0.566 - ETA: 29:22 - loss: 1.0516 - ppl: 3.0812 - accu: 0.566 - ETA: 29:21 - loss: 1.0513 - ppl: 3.0798 - accu: 0.566 - ETA: 29:20 - loss: 1.0509 - ppl: 3.0783 - accu: 0.566 - ETA: 29:19 - loss: 1.0505 - ppl: 3.0768 - accu: 0.567 - ETA: 29:18 - loss: 1.0502 - ppl: 3.0754 - accu: 0.567 - ETA: 29:17 - loss: 1.0499 - ppl: 3.0741 - accu: 0.567 - ETA: 29:17 - loss: 1.0495 - ppl: 3.0727 - accu: 0.567 -

 11624/400000 [..............................] - ETA: 27:52 - loss: 1.0072 - ppl: 2.9146 - accu: 0.583 - ETA: 27:51 - loss: 1.0070 - ppl: 2.9136 - accu: 0.583 - ETA: 27:51 - loss: 1.0067 - ppl: 2.9127 - accu: 0.583 - ETA: 27:50 - loss: 1.0065 - ppl: 2.9117 - accu: 0.583 - ETA: 27:50 - loss: 1.0062 - ppl: 2.9107 - accu: 0.583 - ETA: 27:49 - loss: 1.0059 - ppl: 2.9096 - accu: 0.583 - ETA: 27:49 - loss: 1.0056 - ppl: 2.9086 - accu: 0.583 - ETA: 27:49 - loss: 1.0053 - ppl: 2.9077 - accu: 0.583 - ETA: 27:49 - loss: 1.0051 - ppl: 2.9067 - accu: 0.583 - ETA: 27:48 - loss: 1.0048 - ppl: 2.9058 - accu: 0.584 - ETA: 27:48 - loss: 1.0045 - ppl: 2.9048 - accu: 0.584 - ETA: 27:48 - loss: 1.0043 - ppl: 2.9039 - accu: 0.584 - ETA: 27:47 - loss: 1.0041 - ppl: 2.9030 - accu: 0.584 - ETA: 27:47 - loss: 1.0038 - ppl: 2.9021 - accu: 0.584 - ETA: 27:46 - loss: 1.0036 - ppl: 2.9012 - accu: 0.584 - ETA: 27:46 - loss: 1.0033 - ppl: 2.9003 - accu: 0.584 - ETA: 27:45 - loss: 1.0030 - ppl: 2.8992 - accu: 0.584 -

 13960/400000 [>.............................] - ETA: 26:49 - loss: 0.9713 - ppl: 2.7895 - accu: 0.597 - ETA: 26:48 - loss: 0.9711 - ppl: 2.7887 - accu: 0.598 - ETA: 26:48 - loss: 0.9709 - ppl: 2.7880 - accu: 0.598 - ETA: 26:47 - loss: 0.9706 - ppl: 2.7872 - accu: 0.598 - ETA: 26:47 - loss: 0.9704 - ppl: 2.7864 - accu: 0.598 - ETA: 26:47 - loss: 0.9702 - ppl: 2.7858 - accu: 0.598 - ETA: 26:46 - loss: 0.9700 - ppl: 2.7850 - accu: 0.598 - ETA: 26:46 - loss: 0.9698 - ppl: 2.7843 - accu: 0.598 - ETA: 26:46 - loss: 0.9695 - ppl: 2.7836 - accu: 0.598 - ETA: 26:45 - loss: 0.9694 - ppl: 2.7829 - accu: 0.598 - ETA: 26:45 - loss: 0.9692 - ppl: 2.7823 - accu: 0.598 - ETA: 26:44 - loss: 0.9689 - ppl: 2.7815 - accu: 0.598 - ETA: 26:44 - loss: 0.9687 - ppl: 2.7807 - accu: 0.598 - ETA: 26:43 - loss: 0.9684 - ppl: 2.7799 - accu: 0.598 - ETA: 26:43 - loss: 0.9682 - ppl: 2.7791 - accu: 0.599 - ETA: 26:43 - loss: 0.9680 - ppl: 2.7785 - accu: 0.599 - ETA: 26:42 - loss: 0.9677 - ppl: 2.7776 - accu: 0.599 -

 16152/400000 [>.............................] - ETA: 26:00 - loss: 0.9411 - ppl: 2.6920 - accu: 0.611 - ETA: 26:00 - loss: 0.9409 - ppl: 2.6913 - accu: 0.611 - ETA: 26:00 - loss: 0.9407 - ppl: 2.6908 - accu: 0.611 - ETA: 26:00 - loss: 0.9406 - ppl: 2.6904 - accu: 0.611 - ETA: 26:00 - loss: 0.9404 - ppl: 2.6898 - accu: 0.611 - ETA: 25:59 - loss: 0.9401 - ppl: 2.6891 - accu: 0.611 - ETA: 25:59 - loss: 0.9400 - ppl: 2.6885 - accu: 0.611 - ETA: 25:59 - loss: 0.9397 - ppl: 2.6878 - accu: 0.611 - ETA: 25:59 - loss: 0.9395 - ppl: 2.6871 - accu: 0.611 - ETA: 25:59 - loss: 0.9393 - ppl: 2.6865 - accu: 0.612 - ETA: 25:58 - loss: 0.9391 - ppl: 2.6859 - accu: 0.612 - ETA: 25:58 - loss: 0.9389 - ppl: 2.6853 - accu: 0.612 - ETA: 25:58 - loss: 0.9387 - ppl: 2.6846 - accu: 0.612 - ETA: 25:58 - loss: 0.9385 - ppl: 2.6841 - accu: 0.612 - ETA: 25:58 - loss: 0.9383 - ppl: 2.6834 - accu: 0.612 - ETA: 25:58 - loss: 0.9381 - ppl: 2.6828 - accu: 0.612 - ETA: 25:57 - loss: 0.9378 - ppl: 2.6821 - accu: 0.612 -

KeyboardInterrupt: 