In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from langdetect import detect
import re
import os
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt


from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, GRU, Embedding, Layer
from tensorflow.keras.losses import SparseCategoricalCrossentropy, CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import pydot

from src.helpme import clean_text,start_end_tagger,max_length,tokenize,preprocess,preprocess_sentence

In [3]:
df_test = pd.read_csv('data/kor.txt', sep='\t', names=['eng','kor','drop_me'])
df_test = df_test.drop(columns='drop_me')

In [4]:
df_test

Unnamed: 0,eng,kor
0,Go.,가.
1,Hi.,안녕.
2,Run!,뛰어!
3,Run.,뛰어.
4,Who?,누구?
...,...,...
3313,Tom always cried when his sister took away his...,"톰은 누나가 자기 장난감을 빼앗아 갔을 때마다 울음을 터뜨렸고, 누나는 바로 그런 ..."
3314,Science fiction has undoubtedly been the inspi...,공상 과학 소설은 의심의 여지 없이 오늘날 존재하는 많은 기술에 영감을 주었어.
3315,I started a new blog. I'll do my best not to b...,난 블로그를 시작했어. 블로그를 초반에만 반짝 많이 하다가 관두는 사람처럼은 되지 ...
3316,I think it's a shame that some foreign languag...,몇몇 외국어 선생님이 한 번도 원어민과 공부해본 적도 없으면서 대학을 나올 수 있었...


In [5]:
eng = preprocess(df_test['eng'])
kor = preprocess(df_test['kor'])

input_tensor, input_lang_tokenizer = tokenize(eng)
target_tensor, target_lang_tokenizer = tokenize(kor)

In [6]:
eng_vocab_size = len(input_lang_tokenizer.word_index)
kor_vocab_size = len(target_lang_tokenizer.word_index)

In [7]:
print(f'English vocab size: {eng_vocab_size}')
print(f'Korean vocab size: {kor_vocab_size}')

English vocab size: 2354
Korean vocab size: 5103


In [8]:
eng_max_length = len(input_tensor[0])
kor_max_length = len(target_tensor[0])

In [9]:
print(f'Longest English Sentence: {eng_max_length}')
print(f'Longest Korean Sentence: {kor_max_length}')

Longest English Sentence: 103
Longest Korean Sentence: 91


In [10]:
encoder_input = Input(shape=(None,),name='Encoder_input')
embedding_dim=50
embedded_input = Embedding(input_dim=eng_vocab_size,
                           output_dim=embedding_dim,
                           name='Embedding_layer')(encoder_input)
encoder_lstm = LSTM(units=100,
                   activation='relu',
                   return_sequences=False,
                   return_state=True,
                   name='Encoder_lstm')
_, last_h_state, last_c_state = encoder_lstm(embedded_input)

decoder_input = Input(shape=(None,1), name='Decoder_input')
decoder_lstm = LSTM(units=100,
                   activation='relu',
                   return_sequences=True,
                   return_state=True,
                   name='Decoder_lstm')
all_h_decoder, _, _ = decoder_lstm(decoder_input,initial_state=[last_h_state,last_c_state])

final_dense = Dense(kor_vocab_size,activation='softmax',name='Final_dense_layer')
logits = final_dense(all_h_decoder)

model = Model([encoder_input,decoder_input],logits)

model.compile(loss=SparseCategoricalCrossentropy(),optimizer=Adam(),metrics=['acc'])

In [11]:
decoder_kor_input = target_tensor.reshape((-1,kor_max_length,1))[:,:-1,:]

In [12]:
decoder_kor_target = target_tensor.reshape((-1,kor_max_length,1))[:,1:,:]

In [13]:
model.fit([input_tensor,decoder_kor_input],decoder_kor_target,
         epochs=5,
         batch_size=10)

Train on 3318 samples
Epoch 1/5
  10/3318 [..............................] - ETA: 18:39

InternalError: 2 root error(s) found.
  (0) Internal:  Blas GEMM launch failed : a.shape=(10, 1), b.shape=(1, 400), m=10, n=400, k=1
	 [[{{node model/Decoder_lstm/while/body/_142/MatMul}}]]
	 [[Reshape_18/_78]]
  (1) Internal:  Blas GEMM launch failed : a.shape=(10, 1), b.shape=(1, 400), m=10, n=400, k=1
	 [[{{node model/Decoder_lstm/while/body/_142/MatMul}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_distributed_function_3943]

Function call stack:
distributed_function -> distributed_function
