In [1]:
import os
import csv
import numpy as np
import tensorflow as tf

from model.seq2seq.Seq2Seq import *
from model.seq2seq_attn.Seq2Seq_Attn import *
from utils.utils import *

Loading JIT Compiled ChatSpace Model


In [2]:
def evaluate():
    
    # Load tokenizer
    enc_tokenizer = load_tokenizer('enc-tokenizer')
    dec_tokenizer = load_tokenizer('dec-tokenizer')
    enc_vocab_size = enc_tokenizer.vocab_size + 1
    dec_vocab_size = dec_tokenizer.vocab_size + 2
    
    # Define the optimizer and the loss function
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    
    # Define seq2seq model
    config = {'batch_size': 1,
              'enc_max_len': enc_max_len+1,
              'dec_max_len': dec_max_len+2,
              'enc_unit': enc_unit,
              'dec_unit': dec_unit,
              'embed_dim': embed_dim,
              'dropout_rate': dropout_rate,
              'enc_vocab_size': enc_vocab_size,
              'dec_vocab_size': dec_vocab_size,
              'dec_sos_token': dec_tokenizer.vocab_size}

    #model = seq2seq(config)
    model = seq2seq_attn(config)
    
    # checkpoint
    checkpoint_dir = 'checkpoint/daily-korean/seq2seq_attn'
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    
    while(True):
        
        input_text = input(">: ")
        
        if input_text == 'q':
            break
        
        enc_input = tf.keras.preprocessing.sequence.pad_sequences([enc_encode(input_text, enc_tokenizer)], 
                                                             maxlen=enc_max_len+1, padding='post')

        model.load_weights(filepath=tf.train.latest_checkpoint(checkpoint_dir))
        
        enc_tokens = idx2word(enc_input[0], enc_tokenizer)
        print(enc_tokens)
        preds, attn_weights = model(enc_input, training=False)

        pred_str, pred_tokens = decoding_from_result(preds, dec_tokenizer)
        print("<: ", pred_str)
        
        print(attn_weights.shape, len(enc_tokens), len(pred_tokens))
        print(attn_weights)
        # plotting the attention weights
        plot_attention(attn_weights, enc_tokens, pred_tokens)

In [3]:
if __name__ == '__main__':
    evaluate()

>:  안녕하세요


<:  안녕하세요 .
(1, 26, 26) 2 2
tf.Tensor(
[[[0.00789661 0.01597753 0.97612584 0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.09487355 0.12389449 0.781232   0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.2252377  0.21163902 0.5631233  0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.01750437 0.02704174 0.9554539  0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.  

NameError: name 'dec_len' is not defined