In [32]:
import os
import pickle
from tqdm import tqdm

import numpy as np

import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM
import gensim.models
from gensim.test.utils import datapath

import MeCab
import ipadic

In [17]:
b_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
b_model = BertForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
m_tagger = MeCab.Tagger(ipadic.MECAB_ARGS)

cm = gensim.models.KeyedVectors.load_word2vec_format("assets/cc.ja.300.vec.gz")

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
m_tokenizer = MeCab.Tagger('-O wakati ' + ipadic.MECAB_ARGS)

In [81]:
def masking_text(text, mecab_tokenizer, bert_tokenizer):

    """文を受け取りMeCabとBERTトークナイザの結果からマスク位置を決定，マスクされた文のリストを出力"""
    
    mask_labels = [] # ['r:そのまま', 'm:マスクする', 'h:まとめてマスクする', 's:まとめられる語']
    ogl_mecab_tokens = []
    masked_texts = []
    target_tokens = []
    mecab_tokens_list = []
    zure = 0
    
    node = mecab_tokenizer.parseToNode(text)
    tokenized_text = bert_tokenizer.tokenize(text)
    
    node = node.next #文頭，文末形態素を無視
    while node.next:
        
        #print(node.feature)
        
        if node.surface not in  tokenized_text: #サブワード化された語を変換対象から外すため
            mask_labels.append('r')
        elif node.feature.split(',')[0] in ['助詞', '助動詞', '連体詞', '形容詞', '記号']:
            mask_labels.append('r')
        elif node.feature.split(',')[1] == 'サ変接続':
            mask_labels.append('h')
        elif node.feature.split(',')[4] == 'サ変・スル' and mask_labels[-1] == 'h':
            mask_labels.append('s')
        else:
            mask_labels.append('m')
        ogl_mecab_tokens.append(node.surface)
        node = node.next
        
    #print(mask_labels)
    
    memory_masked = ogl_mecab_tokens[:]
    
    for target_idx, mask_label in enumerate(mask_labels):
        mecab_tokens = ogl_mecab_tokens[:]
        
        target_token = mecab_tokens[target_idx]
        
        if mask_label == 'm':
            mecab_tokens[target_idx] = '[MASK]'
            memory_masked[target_idx-zure] = '[MASK]'
        elif mask_label == 'h':
            mecab_tokens[target_idx] = '[MASK]'
            memory_masked[target_idx-zure] = '[MASK]'
            for m_label in mask_labels[target_idx+1:]:
                if m_label == "s":
                    mecab_tokens.pop(target_idx+1)
                    memory_masked.pop(target_idx+1)
                    zure = zure + 1
                else:
                    break
        elif mask_label in ['r','s']:
            continue
        
        mecab_tokens_list.append(mecab_tokens[:])
        
        masked_text = "".join(mecab_tokens)
        
        masked_texts.append(masked_text)
        target_tokens.append(target_token)
        
        mecab_tokens[target_idx] = target_token
        
    
    return masked_texts, mecab_tokens_list, target_tokens, memory_masked

In [4]:
def get_candidates(text, masked_text, bert_tokenizer, bert_model):
    
    """マスクされた文と原文を受け取って候補を10語出力"""
    
    encoded_dict = bert_tokenizer(text, masked_text)

    masked_idx = encoded_dict["input_ids"].index(4)
    tokens_tensor = torch.tensor([encoded_dict["input_ids"]])
    segments_tensors = torch.tensor([encoded_dict["token_type_ids"]])

    bert_model.eval()
    if torch.cuda.is_available():
        tokens_tensor = tokens_tensor.to('cuda')
        segments_tensors = segments_tensors.to('cuda')
        bert_model.to('cuda')
    with torch.no_grad():
        outputs = bert_model(tokens_tensor, token_type_ids=segments_tensors)
        predictions = outputs[0]

    topk_score, topk_index = torch.topk(predictions[0, masked_idx], 10)
    topk_tokens = bert_tokenizer.convert_ids_to_tokens(topk_index.tolist())
    
    bert_rank = np.array([i for i in range(len(topk_tokens))])
    
    topk_data = {"tokens": topk_tokens, "score": topk_score, "index": topk_index, "rank": bert_rank}
    
    return topk_data

In [8]:
def frequency_ranking(candidates):
    
    """頻度表を検索，高い順からランキングします"""
    
    if os.path.isfile("assets/bccwj_frequency.pickle") == 'True':
        bccwj_frequency = {}
        with open("assets/BCCWJ_goihyo_utf8.txt") as BCCWJ:
            for BCCWJ_line in BCCWJ:
                BCCWJ_line_list = BCCWJ_line.split('\t')
                if BCCWJ_line_list[0] != 'ID_BCCWJ':
                    bccwj_frequency[BCCWJ_line_list[3]] = sum(int(i) for i in BCCWJ_line_list[9:15])
        
        with open("assets/bccwj_frequency.pickle",'wb') as f:
            pickle.dump(bccwj_frequency, f)
        
    with open("assets/bccwj_frequency.pickle", 'rb') as f:
        bccwj_frequency = pickle.load(f)
    
    candidates_frequency = [bccwj_frequency.get(cand, 0) for cand in candidates]
    frequency_rank = np.array([np.argsort([-int(freq) for freq in candidates_frequency]).tolist().index(idx) for idx in range(10)]) #ランク付けの呪文
    
    frequency_data = {'frequency': candidates_frequency, 'rank': frequency_rank}
    
    return frequency_data

In [97]:
def frequency_ranking_twc(candidates, tkz_mkd_text, mecab_tagger):
    cand_genkeis = []
    
    if os.path.isfile("./assets/nlt_frequency.pickle"):
        with open("assets/nlt_frequency.pickle", 'rb') as f:
            nlt_frequency = pickle.load(f)
    else:
        nlt_frequency = {}
        with open("assets/NLT_freq_list_split.txt") as NLT:
            for NLT_line in NLT.read().splitlines():
                NLT_line_list = NLT_line.split(',')
                num = NLT_line_list[1]
                nlt_frequency[NLT_line_list[0]] = num
        
        with open("assets/nlt_frequency.pickle", 'wb') as f:
            pickle.dump(nlt_frequency, f)
    
    masked_idx = tkz_mkd_text.index("[MASK]")
    
    for cand in candidates:
        wak_idx = 0
        
        tkz_mkd_text[masked_idx] = cand
        
        node = mecab_tagger.parseToNode("".join(tkz_mkd_text))
        node = node.next
        while node.next:
            if wak_idx == masked_idx:
                cand_genkeis.append(node.feature.split(',')[6])
            wak_idx = wak_idx + 1
            node = node.next    
    
    candidates_frequency = [nlt_frequency.get(cand, 0) for cand in cand_genkeis]
    frequency_rank = np.array([np.argsort([-int(freq) for freq in candidates_frequency]).tolist().index(idx) for idx in range(10)])
    
    frequency_data = {'frequency': candidates_frequency, 'rank': frequency_rank}
    
    return frequency_data

In [6]:
def fasttext_similarity_ranking(target_token, candidates, fasttext_vec):
    
    """Cos類似度を求めてランク付けする．gensimで読み込んでるんだったら他も使えるかも"""
    
    cosine_sim = [cm.similarity(target_token, cand) if target_token in cm and cand in cm else 0 for cand in candidates]
    
    sim_rank = np.array([np.argsort([-sim for sim in cosine_sim]).tolist().index(idx) for idx in range(10)]) #ランク付けの呪文
    
    sim_data = {"cos_sim": cosine_sim, "rank": sim_rank}
    
    return sim_data

In [100]:
def heiika(text):
    simple_words = []
    
    masked_list, mkb_tok_list, target_tokens, memory_masked = masking_text(text, m_tagger, b_tokenizer)
    
    for text_idx, masked_text in enumerate(masked_list):
        topk_dict = get_candidates(text, masked_text, b_tokenizer, b_model)

        #freq_data = frequency_ranking(topk_dict["tokens"])
        freq_data = frequency_ranking_twc(topk_dict["tokens"], mkb_tok_list[text_idx], m_tagger)

        sim_dict = fasttext_similarity_ranking(target_tokens[text_idx], topk_dict["tokens"], cm)

        avg_idxs = np.argsort(topk_dict["rank"] + freq_data["rank"] + sim_dict["rank"]) #合算したランクが同じになったらBERTスコアが優先されます
        
        sorted_candidates = [topk_dict["tokens"][a_idx] for a_idx in avg_idxs]

        simple_words.append(sorted_candidates[0])
        
        print(topk_dict["tokens"])
        print(topk_dict["score"])
        print(freq_data["frequency"])
        print(sim_dict["cos_sim"])
    
    kantannabun = []
    sw_idx = 0
    for tok in memory_masked:
        if tok == "[MASK]":
            kantannabun.append(simple_words[sw_idx])
            sw_idx = sw_idx + 1
        else:
            kantannabun.append(tok)

    print("".join(kantannabun))

In [101]:
heiika("若者が未来を担う")

['若者', '大人', '子供', '子ども', 'アスリート', '青年', '企業', '女性', 'スポーツ', '老人']
tensor([12.6665, 10.3220,  9.9363,  9.8606,  9.3604,  9.3049,  9.1614,  8.9949,
         8.9760,  8.8435])
['44767', '85155', '676310', 0, '1', '34469', '463441', '341491', '68111', '35918']
[1.0, 0.49588507, 0.5431058, 0.5366304, 0.37994212, 0.54304886, 0.4340958, 0.5138575, 0.3189348, 0.5010498]
['未来', '過去', '次世代', '現代', '時代', '明日', '現在', '自然', '日常', '人生']
tensor([13.4072, 11.7603, 10.8013, 10.4733, 10.3210,  9.6666,  9.6420,  9.5709,
         9.1504,  9.0598])
['69384', '132759', '13938', '90122', '360825', '51071', '452559', '301544', '97181', '155473']
[1.0, 0.45295668, 0.5192402, 0.41967815, 0.33084735, 0.3713258, 0.21572639, 0.32597736, 0.2618374, 0.39183265]
['担う', '支える', '担い', '伝える', '引き継ぐ', '育てる', '語る', '考える', '演じる', '生きる']
tensor([17.1996, 12.2899, 11.6165, 11.5818, 11.2846, 11.1931, 11.1063, 10.7688,
        10.6419, 10.6276])
['45554', '70429', '45554', '177123', '17360', '79384', '109779', '1195718', '2409