In [185]:
import utils.constant as config
from tqdm import tqdm
import torch
import numpy as np

import glob,re
from konlpy.tag import Mecab, Okt, Kkma
from models.cnn_lstm import CNNBiLSTM
import utils.constant as config


device = config.device
okt = Okt()
word_vocab_dict = torch.load('./data/word_vocab.pt')
char_vocab_dict = torch.load('./data/char_vocab.pt')
pos_vocab_dict = torch.load('./data/pos_vocab.pt')
entitiy_to_index = torch.load('./data/processed_data/entity_to_index.pt')
num_class = len(entitiy_to_index)

In [186]:
class Vocabulary():
    def __init__(self):
        self.vocab = set()
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        # self.add_special()
        
    def add_special(self):
        special_tokens = ['<pad>', '<sos>', '<eos>', '<unk>']
        for word in special_tokens:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.vocab.add(word)
            self.idx += 1
        
    def add_word(self, tokenized_text, char=False, pos_tag=False):
        for word in tokenized_text:
            
            if not char and not pos_tag:
                if word not in self.vocab:
                    self.word2idx[word] = self.idx
                    self.idx2word[self.idx] = word
                    self.vocab.add(word)
                    self.idx += 1
                    
            elif char and not pos_tag:
                for c in word:
                    if c not in self.vocab:
                        self.word2idx[c] = self.idx
                        self.idx2word[self.idx] = c
                        self.vocab.add(c)
                        self.idx += 1
                        
            elif not char and pos_tag:
                for pos in word[1]:
                    if pos not in self.vocab:
                        self.word2idx[pos] = self.idx
                        self.idx2word[self.idx] = pos
                        self.vocab.add(pos)
                        self.idx += 1
                            
    def convert_tokens_to_idx(self, list_of_tokens, add_special=False):
        list_of_idx = []
        for w in list_of_tokens:
            try:
                idx = self.word2idx[w]
            except:
                idx = self.word2idx['<unk>']
            list_of_idx.append(idx)
            
        return list_of_idx
 
    def convert_chars_to_idx(self, list_of_tokens, add_special=False):
        list_of_idx = []
        for w in list_of_tokens:
            char= []
            for c in w:
                try:
                    idx = self.word2idx[c]
                except:
                    idx = self.word2idx['<unk>']
                char.append(idx)
            list_of_idx.append(char)
            
        return list_of_idx
    
    def convert_pos_to_idx(self, raw_text, add_special=False):
        list_of_idx = []
        pos_tag = okt.pos(raw_text)

        for p in pos_tag:
            pos = p[1]
            try:
                idx = self.word2idx[pos]
            except:
                idx = self.word2idx['<unk>']
            list_of_idx.append(idx)
            
        return list_of_idx
                    
    def __len__(self):
        return len(self.word2idx)

In [187]:
def tokenizer(raw_text):
    return okt.morphs(raw_text)


def convert_to_label(pred_label):
    def get_idx_to_entity():
        idx2ent = {}
        for k,v in entitiy_to_index.items():
            idx2ent[v]=k
        return idx2ent
    idx_to_entity = get_idx_to_entity()
    tag = [idx_to_entity[i] for i in pred_label]

    return tag


def load_model(model, path):
    model_files = glob.glob(model_save_path+'/*.pt')
    best_model = model_files[0]
    ckpt = torch.load(best_model)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval().to(device)
    print('Loading checkpoint from {}'.format(best_model))
    return model


def load_vocab(vocab_name):
    vocab = Vocabulary()
    
    if vocab_name =='word' or vocab_name=='token':
        vocab.word2idx = word_vocab_dict
    elif vocab_name =='char':
        vocab.word2idx = char_vocab_dict
    elif vocab_name == 'pos':
        vocab.word2idx = pos_vocab_dict
    else:
        raise Exception('Unknwon vocab type')
    return vocab


def pad_char_idx(char_idx, token_len):
    maxlen = max(token_len, max(len(i) for i in char_idx))
    padded = []
    for char in char_idx:
        char += [0]*(maxlen-len(char))
        padded.append(char)
    return padded


def transform_to_model_input(text):
    
    token_vocab, char_vocab, pos_vocab = load_vocab('token'), load_vocab('char'), load_vocab('pos')
    
    def _to_tensor(x):
        return torch.tensor(x).long().unsqueeze(0).to(device)

    tokenized_text = tokenizer(text)
    token_idx = token_vocab.convert_tokens_to_idx(tokenized_text)
    char_idx = char_vocab.convert_chars_to_idx(tokenized_text)
    pos_idx = pos_vocab.convert_pos_to_idx(text)
    
    # Pad & Tensor
    token_tensor = _to_tensor(token_idx)
    char_tensor = _to_tensor(pad_char_idx(char_idx, len(token_idx)))
    pos_tensor = _to_tensor(pos_idx)
    
    return token_tensor, char_tensor, pos_tensor


def get_entity(token, pred):
    answer = []
    
    token = tokenizer(token)
    assert len(token)==len(pred)

    answer = []
    for i in range(len(token)):
        if pred[i][0]=='B':  
            pref, suf = pred[i].split('-')[0], pred[i].split('-')[1]
            value=token[i]

            try:
                for j in range(i+1, len(pred)):
                    if pred[j]=='I-'+suf:
                        value += token[i+1]
            except:
                break
            answer.append((value, suf))     
            
    return answer

In [188]:
model = CNNBiLSTM(config, num_class, len(word_vocab_dict), len(char_vocab_dict), len(pos_vocab_dict))
model_save_path='./result/epoch_60_batch_64_ch_in_1_ch_out_32'

model = load_model(model, model_save_path)

Loading checkpoint from ./result/epoch_60_batch_64_ch_in_1_ch_out_32/epoch_55_step_19745_tr_acc_0.431_eval_acc_0.891.pt


In [194]:
text = '다음주 수요일 인천국제공항에서 5시30분에 8인승으로 예약해주세요, 이름은 김진수입니다'

print(tokenizer(text))
token_tensor, char_tensor, pos_tensor = transform_to_model_input(text)
logit = model(token_tensor, char_tensor, pos_tensor)

# To Labels
pred_label = logit.argmax(-1)
pred_label_list = pred_label.tolist()[0]
result = convert_to_label(pred_label_list)
print(result)
print(' ')

# Answer
ans = get_entity(text, result)
print(ans)

['다음주', '수요일', '인천', '국제공항', '에서', '5시', '30분', '에', '8', '인승', '으로', '예약', '해주세요', ',', '이름', '은', '김진수', '입니다']
['O', 'B-DAT', 'B-LOC', 'I-LOC', 'O', 'B-NOH', 'I-NOH', 'O', 'B-NOH', 'I-NOH', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O']
 
[('수요일', 'DAT'), ('인천국제공항', 'LOC'), ('5시30분30분', 'NOH'), ('8인승', 'NOH'), ('김진수', 'PER')]
