In [1]:
import tensorflow as tf

from gensim.models.keyedvectors import KeyedVectors
from konlpy.tag import Mecab

# from googletrans import Translator
from models.transformer import * 

import time
import numpy as np

import os
import sys
import urllib.request
import requests
import datetime
import pickle
import json

In [2]:
with open('./data/ko_noun_dict.pkl', "rb") as f:
    ko_dict = pickle.load(f)
    
with open('./data/en_noun_dict.pkl', "rb") as f:
    en_dict = pickle.load(f)

In [27]:
def get_noun_data(ko_dict, en_dict):
    for i in range(len(en_dict)):
        ko_vector = list(ko_dict.values())[i]
        en_vector = list(en_dict.values())[i]
        
#         yield (ko_word, ko_vector, en_word, en_vector)
        yield (ko_vector, en_vector)
    
def get_noun_data_2(ko_vec, en_vec):
    for i in range(len(en_dict)):
        ko_vector = ko_vec[i]
        en_vector = en_vec[i]
        
        yield ko_vector, en_vector

In [38]:
dataset = tf.data.Dataset.from_generator(get_noun_data_2, 
                              (tf.float64, tf.float64),
                              (tf.TensorShape([300]), tf.TensorShape([300])),
                               args=(list(ko_dict.values()), list(en_dict.values())))

dataset = dataset.batch(128, drop_remainder=True)

In [85]:
encoder_88 = Encoder(num_layers=1, d_model=8, num_heads=8, dff=512, input_vocab_size=0, maximum_position_encoding=0)

In [86]:
EPOCHS = 200

num_layers = 1
d_model = 8
dff = 512
num_head = 8
dropout_rate = 0.1


In [87]:
loss_object = tf.keras.losses.Huber()

def loss_function(real, pred):
    loss = loss_object(real, pred)
    
    return loss

In [88]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [89]:
checkpoint_path = "./checkpoints/train_88"

ckpt = tf.train.Checkpoint(encoder=encoder_100,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [92]:
@tf.function()
def train_step(inp, real):
    
    with tf.GradientTape() as tape:
        inp = inp * 100 
        real = real * 100
        output = encoder(inp, training=True, mask=None)
#         output = output / 100
        loss = loss_function(real, output)

        
    gradients = tape.gradient(loss, encoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, encoder.trainable_variables))
    
    train_loss(loss)

In [93]:
for epoch in range(EPOCHS):
    tic = time.time()
    
    train_loss.reset_states()
    
    for (batch, (inp, real)) in enumerate(dataset):
        train_step(inp, real)
        
        if batch % 50 == 0:
            print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result() : .4f}')
                  
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print(f'Saving checkpoint for epoch {epoch + 1} at {ckpt_save_path}')

    print(f'Time taken for 1 epoch: {time.time() - tic:.2f} secs\n')

Model Input shape (128, 300, 1)
Scaled_attention Shape :  (128, 1, 300, 1)
Scaled_attention Shape :  (128, 300, 1, 1)
Concat attention Shape : (128, 300, 1)
Epoch 1 Batch 0 Loss  21.3269
Time taken for 1 epoch: 0.63 secs

Epoch 2 Batch 0 Loss  21.3269
Time taken for 1 epoch: 0.08 secs

Epoch 3 Batch 0 Loss  21.3269
Time taken for 1 epoch: 0.08 secs

Epoch 4 Batch 0 Loss  21.3269
Time taken for 1 epoch: 0.08 secs

Epoch 5 Batch 0 Loss  21.3268
Saving checkpoint for epoch 5 at ./checkpoints/train_88/ckpt-49
Time taken for 1 epoch: 0.09 secs

Epoch 6 Batch 0 Loss  21.3268
Time taken for 1 epoch: 0.08 secs

Epoch 7 Batch 0 Loss  21.3268
Time taken for 1 epoch: 0.08 secs

Epoch 8 Batch 0 Loss  21.3268
Time taken for 1 epoch: 0.08 secs

Epoch 9 Batch 0 Loss  21.3267
Time taken for 1 epoch: 0.08 secs

Epoch 10 Batch 0 Loss  21.3267
Saving checkpoint for epoch 10 at ./checkpoints/train_88/ckpt-50
Time taken for 1 epoch: 0.08 secs

Epoch 11 Batch 0 Loss  21.3267
Time taken for 1 epoch: 0.08 sec

KeyboardInterrupt: 

In [None]:
encoder()