In [1]:
import pymysql, os, copy, json, time, openpyxl
import pandas as pd
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import re
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from kss import kss
from pororo import Pororo
from konlpy.tag import Mecab
import pickle

In [2]:
from transformers import BertForTokenClassification
from tokenization_kobert import KoBertTokenizer
# Shared Mecab tagger instance; the tagging loop later calls mecab.pos() on each word.
mecab = Mecab()

In [3]:
def token_generate(sent, tok, MAX_LEN):
    """Encode one sentence with ``tok`` and return its three model inputs.

    ``encode_plus`` is asked to add special tokens, to truncate, and to pad to
    ``max_length=MAX_LEN``.

    Args:
        sent: Text to encode.
        tok: Tokenizer object exposing ``encode_plus``.
        MAX_LEN: Value forwarded as ``max_length``.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids)`` read from the
        mapping returned by ``encode_plus``.
    """
    encoded = tok.encode_plus(
        text=sent,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    wanted_keys = ('input_ids', 'attention_mask', 'token_type_ids')
    return tuple(encoded[key] for key in wanted_keys)

def _read_file(input_file):
    with open(input_file, "r", encoding="utf-8") as f:
        sentences = []
        labels = []
        for line in f:
            split_line = line.strip().split('\t')
            sentences.append(split_line[0])
            labels.append(split_line[1])
        return sentences, labels
    
def eval_input(test, token, args, pad_token_label_id, mask_padding_with_zero = True):
    """Build the model inputs and the word-start mask for one sentence.

    The sentence is split on whitespace and every word is tokenized on its own
    so that the first sub-token of each word can be marked. The mask holds
    ``0`` at the first sub-token of a word and ``pad_token_label_id``
    everywhere else (continuation sub-tokens, the two special-token slots and
    the padding).

    Args:
        test: Sentence to encode.
        token: Tokenizer exposing ``tokenize``, ``unk_token`` and ``encode_plus``.
        args: Namespace providing ``max_seq_len``.
        pad_token_label_id: Mask value for positions that must be ignored.
        mask_padding_with_zero: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, slot_label_mask)``
        of ``torch.long`` tensors, each reshaped to ``(1, -1)``.
    """
    slot_label_mask = []
    for word in test.split():
        word_tokens = token.tokenize(word.strip())
        if not word_tokens:
            # Badly encoded word: fall back to the tokenizer's own UNK token.
            # (The original referenced an undefined global `unk_token`.)
            word_tokens = [token.unk_token]
        slot_label_mask.extend([0] + [pad_token_label_id] * (len(word_tokens) - 1))

    # Leave room for the two special tokens ([CLS] and [SEP]). The mask has one
    # entry per sub-token, so slicing is a no-op when the sentence already fits.
    special_tokens_count = 2
    slot_label_mask = slot_label_mask[:(args.max_seq_len - special_tokens_count)]

    # Slots for [CLS] (front) and [SEP] (back), then pad up to max_seq_len.
    slot_label_mask = [pad_token_label_id] + slot_label_mask + [pad_token_label_id]
    padding_length = args.max_seq_len - len(slot_label_mask)
    slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)

    # The ids themselves come from encoding the whole sentence in one call.
    input_id, attention_mask, token_type_id = token_generate(test, token, args.max_seq_len)

    input_ids = torch.tensor(input_id, dtype=torch.long).reshape(1,-1)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long).reshape(1,-1)
    token_type_ids = torch.tensor(token_type_id, dtype=torch.long).reshape(1,-1)
    slot_label_mask = torch.tensor(slot_label_mask, dtype=torch.long).reshape(1,-1)

    return input_ids, attention_mask, token_type_ids, slot_label_mask

def eval_ft(test_sent, label_lst, model_, tok, dev, args):
    """Run the token-classification model on one sentence and return its tags.

    Args:
        test_sent: Sentence to tag (words separated by whitespace).
        label_lst: Label names; index ``i`` is the name of class ``i``.
        model_: Model called with ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``labels=None``; its first output is used
            as the logits.
        tok: Tokenizer forwarded to ``eval_input``.
        dev: Device the input tensors are moved to. All three tensors now go to
            this device (the original sent two of them to the global ``device``).
        args: Namespace providing ``max_seq_len``.

    Returns:
        A string with one predicted label per word, each label preceded by a
        single space (so the result starts with a space when non-empty).
    """
    pad_token = torch.nn.CrossEntropyLoss().ignore_index
    tmp_input, tmp_attention, tmp_token, tmp_slot = eval_input(test_sent, tok, args, pad_token)

    with torch.no_grad():
        output = model_(
            input_ids=tmp_input.to(dev),
            attention_mask=tmp_attention.to(dev),
            token_type_ids=tmp_token.to(dev),
            labels=None,
        )
    logits = output[0]

    # Best class per position.
    preds = np.argmax(logits.detach().cpu().numpy(), axis=2)
    slot_label_map = {i : label for i, label in enumerate(label_lst)}

    # Keep only positions flagged as the first sub-token of a word.
    preds_list = [
        slot_label_map[preds[0][j]]
        for j in range(preds.shape[1])
        if tmp_slot[0, j] != pad_token
    ]

    # zip() stops at the shorter sequence, so at most one label per word.
    return "".join(" {}".format(p) for _, p in zip(test_sent.split(), preds_list))

def get_labels(label_path):
    """Read a label file and return its lines, stripped, in file order.

    Args:
        label_path: Path of a UTF-8 text file with one label per line.

    Returns:
        List of label strings (one entry per line, blank lines included as '').
    """
    # `with` closes the handle; the original left the file open.
    with open(os.path.join(label_path), 'r', encoding='utf-8') as label_file:
        return [label.strip() for label in label_file]

def _call_db_info():
    return pymysql.connect(
        host = '183.111.204.69',
        port= 13306,
        user = 'newsbot1',
        password='lgensol2020!',
        db = 'trend',
        charset = 'utf8')
def extract_parenthese(str):
    """Return the texts that follow each '(' up to the next ')'.

    Only captures of at least two characters are kept. An opening parenthesis
    without a closing one still captures the text up to the end of the string.

    Args:
        str: Text to scan (parameter name kept for backward compatibility,
            although it shadows the builtin inside this function).

    Returns:
        List of captured strings in order of appearance.
    """
    # Raw string: the same regex as before, without the invalid '\(' escape.
    items_lst = re.findall(r'\(([^)]+)', str)
    return [x for x in items_lst if len(x) >= 2]

def extract_quotes(str):
    """Return every substring enclosed in a pair of double quotes, in order.

    Empty quotations (``""``) are reported as empty strings.
    """
    return [match.group(1) for match in re.finditer('"([^"]*)"', str)]

def parentheses_(tmp_input_sent):
    """Replace parentheses with spaces and drop backslashes.

    Steps, in order:
      1. every run of '(' becomes one space,
      2. every run of ')' becomes one space,
      3. runs of spaces are collapsed to a single space,
      4. backslashes are removed.

    Leading/trailing spaces produced by steps 1-2 are not stripped.

    Args:
        tmp_input_sent: Sentence to clean.

    Returns:
        The cleaned sentence.
    """
    # Raw strings keep the regexes unchanged while avoiding invalid escape
    # sequences such as '\(' in ordinary string literals.
    tmp_input_sent = re.sub(pattern=r'\(+', repl=' ', string=tmp_input_sent)
    tmp_input_sent = re.sub(pattern=r'\)+', repl=' ', string=tmp_input_sent)
    tmp_input_sent = re.sub(pattern=' +', repl=' ', string=tmp_input_sent)
    input_sent = re.sub(pattern=r'\\', repl='', string=tmp_input_sent)
    return input_sent

In [4]:
# Reference dictionary used by the dictionary-based tagging step.
# The string literal below is disabled code: it built `ref_dic` from the first
# two sheets of an Excel workbook, keeping the '대분류' and '명칭' columns.
# NOTE(review): the tagging loop also reads a '중분류' column, which the disabled
# code does not select — presumably ref_dic.csv was produced differently; confirm.
'''
tmp_file_raw = openpyxl.load_workbook('./20210628_레퍼런스2.4.xlsx')
shee_name = tmp_file_raw.sheetnames
ref_dic = pd.DataFrame()
for f in range(0, 2):
    tmp_file_pd = pd.DataFrame(tmp_file_raw[shee_name[f]].values).copy()
    tmp_col_name1 = list(tmp_file_pd.iloc[0,0:])
    tmp_file = tmp_file_pd.iloc[1:,:].copy().reset_index(drop=True)
    tmp_file.columns = tmp_col_name1
    tmp_ref_dic = tmp_file[['대분류','명칭']]
    ref_dic = pd.concat((ref_dic, tmp_ref_dic))
'''
# Active path: load the dictionary from a CSV in the working directory.
ref_dic = pd.read_csv('ref_dic.csv')#ref_dic.reset_index(drop=True)

In [5]:
# Inference setup: choose the device, load the fine-tuned model and tokenizer,
# and read the label names the model's class indices map to.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = BertForTokenClassification.from_pretrained('./kobert')
token = KoBertTokenizer.from_pretrained('monologg/kobert')

# Module.to() and Module.eval() both return the module, so they chain.
model.to(device).eval()

label_lst = get_labels('./kobert/label.txt')
args = argparse.Namespace(max_seq_len = 128)
print("Model Loaded")

Model Loaded


In [6]:
# Load the whole `word_dic` table into a DataFrame.
# try/finally guarantees the connection is closed even if the query fails;
# no commit is issued because the statement is a read-only SELECT.
conn = _call_db_info()
try:
    with conn.cursor() as curs:
        tmp_insert_sql = "select * from word_dic"
        curs.execute(tmp_insert_sql)
        tmp_rslt = pd.DataFrame(curs.fetchall())
finally:
    conn.close()

In [7]:
# Load the whole `content` table (the articles to tag) into a DataFrame.
# try/finally guarantees the connection is closed even if the query fails;
# no commit is issued because the statement is a read-only SELECT.
conn = _call_db_info()
try:
    with conn.cursor() as curs:
        tmp_insert_sql = "select * from content" #where date >=20210714"
        curs.execute(tmp_insert_sql)
        tmp_article = pd.DataFrame(curs.fetchall())
finally:
    conn.close()
tmp_article.shape

(42, 5)

In [8]:
# ---------------------------------------------------------------------------
# Sentence splitting + dictionary-based tagging of every article.
#
# For each article: clean the text, split it into sentences, run the NER model
# (its output determines how many tag slots a sentence gets) and fill the
# slots from the reference dictionary: a word whose Mecab analysis contains an
# NNP morpheme listed in ref_dic['명칭'] receives that entry's '대분류' (level 1)
# and '중분류' (level 2); every other slot stays 'O'.
# ---------------------------------------------------------------------------

# Single-quote variants, rewritten to a plain double quote first.
step0_ptn = r"['\‘\’]"
# Symbols, bullets, quotes, special spaces and newlines that are deleted.
# NOTE(review): literal kept byte-for-byte (it is not a raw string and mixes
# both quote styles); convert to a raw string separately and with care.
step1_ptn= '[\u00a0\u3000①②③④⑤⑥⑦⑧⑨⑩』◦※→®↑↓‣★▶■△◇◆▲○●\{\}\[\]\/?,+;:‧·ᆞ…》ⓒ|*~`\""“”!^_<>@\#&\\\=\'\n]'
# Full stops, removed only after the sentence splitter has used them.
step2_ptn = r'[\.]'

# eval_ft is fed at most this many whitespace-separated words per call.
MAX_WORDS_PER_CHUNK = 30

# Dictionary name -> row position of its FIRST occurrence (what list.index()
# returned before), built once instead of once per article and searched in O(1).
ref_first_row = {}
for row_pos, ref_name in enumerate(ref_dic['명칭'].values.tolist()):
    ref_first_row.setdefault(ref_name, row_pos)
ref_level1 = ref_dic['대분류']
ref_level2 = ref_dic['중분류']

article_sent = []
article_ner_l1, article_ner_l2 = [], []
for a in range(0, tmp_article.shape[0]):
    # Column 4 is read as the article text (positional access, as before).
    tmp_a = re.sub(pattern=step0_ptn, repl='"', string=tmp_article.iloc[a, 4])
    tmp_a = re.sub(pattern=step1_ptn, repl='', string=tmp_a)
    tmp_sentence = kss.split_sentences(tmp_a)
    tmp_sentence = [re.sub(pattern=step2_ptn, repl='', string=s) for s in tmp_sentence]
    tmp_sentence = [parentheses_(s) for s in tmp_sentence]

    for raw_sent in tmp_sentence:
        # Split on any whitespace, exactly like eval_ft does internally.
        # The previous split(' ') produced empty "words" for leading or
        # repeated whitespace, which shifted every tag against its word.
        sent_split = raw_sent.split()

        # Tag the sentence in chunks of MAX_WORDS_PER_CHUNK words; a short
        # sentence is a single chunk and an empty one needs no model call.
        tmp_ner = []
        for start in range(0, len(sent_split), MAX_WORDS_PER_CHUNK):
            chunk = ' '.join(sent_split[start:start + MAX_WORDS_PER_CHUNK])
            tmp_ner.extend(eval_ft(chunk, label_lst, model, token, device, args).split())

        # NOTE(review): only the NUMBER of model tags is used below; it can be
        # smaller than the word count when eval_ft returns fewer labels than
        # words (e.g. after truncation to max_seq_len).
        tmp_mecab_pos = [mecab.pos(w) for w in sent_split]
        tmp_ner_l1 = ['O'] * len(tmp_ner)
        tmp_ner_l2 = ['O'] * len(tmp_ner)
        for k, morphemes in enumerate(tmp_mecab_pos[:len(tmp_ner)]):
            for morph, pos_tag in morphemes:
                if pos_tag != 'NNP':
                    continue
                ner_ind = ref_first_row.get(morph)
                if ner_ind is None:
                    continue
                # ner_ind is a row POSITION, hence iloc; a later matching
                # morpheme of the same word overrides an earlier one.
                tmp_ner_l1[k] = ref_level1.iloc[ner_ind]
                tmp_ner_l2[k] = ref_level2.iloc[ner_ind]

        article_sent.append(' '.join(sent_split))
        article_ner_l1.append(' '.join(tmp_ner_l1))
        article_ner_l2.append(' '.join(tmp_ner_l2))

    if a % 100 == 0:
        print('Article %s'%a)

Article 0


In [9]:
# Sanity check: every sentence needs exactly one level-1 and one level-2 tag line.
print("Sent : %d"%len(article_sent))
print("Sent tagging : %d"%len(article_ner_l1))
assert len(article_sent) == len(article_ner_l1) == len(article_ner_l2), \
    "sentence / tag-line counts differ"

Sent : 569
Sent tagging : 569


In [10]:
# Persist the sentences and both tag levels, one pickle file each.
for dump_path, payload in (
    ('LGES_sent.dta', article_sent),
    ('LGES_sent_ner_level1.dta', article_ner_l1),
    ('LGES_sent_ner_level2.dta', article_ner_l2),
):
    with open(dump_path, 'wb') as dump_file:
        pickle.dump(payload, dump_file)

In [11]:
# Collect every tag that is neither 'O' nor empty, then reduce each level to
# its sorted list of distinct category names.
label_ner_l1_ = [
    tag
    for tag_line in article_ner_l1
    for tag in tag_line.split(' ')
    if tag not in ('O', '')
]
label_ner_l2_ = [
    tag
    for tag_line in article_ner_l2
    for tag in tag_line.split(' ')
    if tag not in ('O', '')
]
tmp_label_ner_l1 = sorted(set(label_ner_l1_))
tmp_label_ner_l2 = sorted(set(label_ner_l2_))

In [12]:
# Label vocabularies: 'UNK' first, then a '-B' / '-I' pair for every category
# in sorted order. Both vocabularies are pickled for later use.
bi_suffixes = ('-B', '-I')
label_ner_l1 = ['UNK'] + [name + suffix for name in tmp_label_ner_l1 for suffix in bi_suffixes]
label_ner_l2 = ['UNK'] + [name + suffix for name in tmp_label_ner_l2 for suffix in bi_suffixes]

for dump_path, vocab in (
    ('LGES_label_l1.dta', label_ner_l1),
    ('LGES_label_l2.dta', label_ner_l2),
):
    with open(dump_path, 'wb') as dump_file:
        pickle.dump(vocab, dump_file)

In [10]:
# Display the level-1 label vocabulary.
# NOTE(review): this cell's execution count (10) is lower than the preceding
# cell's (12), so the captured output comes from an out-of-order run; re-run the
# notebook top to bottom before relying on it.
label_ner_l1

['UNK',
 '기관-B',
 '기관-I',
 '기술-B',
 '기술-I',
 '기업-B',
 '기업-I',
 '동향-B',
 '동향-I',
 '서비스-B',
 '서비스-I',
 '소재-B',
 '소재-I',
 '전지-B',
 '전지-I',
 '제품-B',
 '제품-I',
 '트렌드-B',
 '트렌드-I']