<a href="https://colab.research.google.com/github/straxFromIbr/NLP_with_BERT/blob/main/Section5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Upgrade pip, then install pinned versions of transformers, fugashi and ipadic.
# `2>&1 >/dev/null` discards stdout only; stderr is still shown in the cell output.
!pip install -U pip 2>&1 >/dev/null
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 2>&1 >/dev/null 


In [None]:
import numpy as np
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

# Name of the pretrained checkpoint used for both the tokenizer and the model.
MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

# Use the GPU when one is available, otherwise fall back to the CPU so that
# loading the model does not raise on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
bert_mlm = BertForMaskedLM.from_pretrained(MODEL_NAME)
bert_mlm = bert_mlm.to(device)

In [None]:
# Example sentence containing a single [MASK] placeholder.
text = '今日は[MASK]へ行く。'
# Split the sentence into tokens and display them as the cell output.
tokens = tokenizer.tokenize(text)
tokens

In [None]:
# Encode the sentence into a tensor of token IDs and place it on whatever
# device the model lives on (instead of assuming a GPU is present).
input_ids = tokenizer.encode(text, return_tensors='pt').to(bert_mlm.device)

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    output = bert_mlm(input_ids=input_ids)
    # Per-position vocabulary scores; indexed below as scores[batch, position].
    scores = output.logits

In [None]:
# Locate the [MASK] token using the ID reported by the tokenizer rather than
# a hard-coded number.
mask_position = input_ids[0].tolist().index(tokenizer.mask_token_id)
# ID of the highest-scoring vocabulary entry at the masked position.
id_best = scores[0, mask_position].argmax(-1).tolist()
token_best = tokenizer.convert_ids_to_tokens(id_best)
# Strip the '##' subword-continuation marker before inserting into the text.
token_best = token_best.replace('##', '')

# Store the result under a new name so `text` keeps its [MASK] and the
# cells above can be re-run without failing to find the mask token.
text_filled = text.replace('[MASK]', token_best)
text_filled


In [None]:
def predict_mask_topk(text, tokenizer, bert_mlm, num_topk):
    """
    Fill the first mask token in ``text`` with each of the ``num_topk``
    highest-scoring tokens predicted by the masked language model.

    Args:
        text: Sentence containing at least one mask token.
        tokenizer: Tokenizer providing ``encode``, ``convert_ids_to_tokens``,
            ``mask_token`` and ``mask_token_id``.
        bert_mlm: Masked-LM model. It is called as ``bert_mlm(input_ids=...)``
            and must return an object exposing ``logits``; its ``device``
            attribute decides where the inputs are placed.
        num_topk: Number of candidate tokens to return.

    Returns:
        A tuple ``(text_topk, scores_topk)``: the list of candidate sentences
        (only the first mask token is replaced in each) ordered from the
        highest to the lowest score, and a NumPy array holding the
        corresponding raw scores.

    Raises:
        ValueError: If the encoded text contains no mask token.
    """
    mask_token = tokenizer.mask_token
    mask_token_id = tokenizer.mask_token_id

    input_ids = tokenizer.encode(text, return_tensors='pt')

    # Fail early, with a readable message, before spending time on the model.
    token_ids = input_ids[0].tolist()
    if mask_token_id not in token_ids:
        raise ValueError(f"text contains no {mask_token} token: {text!r}")
    mask_position = token_ids.index(mask_token_id)  # first mask only

    # Follow the model's device instead of assuming CUDA is available.
    input_ids = input_ids.to(bert_mlm.device)
    with torch.no_grad():
        output = bert_mlm(input_ids=input_ids)
    scores = output.logits

    topk = scores[0, mask_position].topk(num_topk)
    scores_topk = topk.values.cpu().numpy()
    # Hand plain Python ints to the tokenizer rather than a (possibly GPU) tensor.
    tokens_topk = tokenizer.convert_ids_to_tokens(topk.indices.tolist())

    # Drop the '##' subword marker and substitute only the first mask token.
    text_topk = [
        text.replace(mask_token, token.replace('##', ''), 1)
        for token in tokens_topk
    ]

    return text_topk, scores_topk



In [None]:
# Show the 20 best single-token completions for the masked sentence.
text = '今日は[MASK]へ行く。'
text_topk, _ = predict_mask_topk(
    text=text, tokenizer=tokenizer, bert_mlm=bert_mlm, num_topk=20
)
# One candidate sentence per line.
print('\n'.join(text_topk))

In [None]:
def greedy_prediction(text, tokenizer, bert_mlm):
    """
    Fill several [MASK] tokens greedily.

    Starting from the first [MASK], each one is replaced in turn by the
    single highest-scoring candidate returned by ``predict_mask_topk``.
    The number of rounds is fixed up front from the [MASK] count of the
    input text.
    """
    rounds_left = text.count('[MASK]')
    filled = text
    while rounds_left > 0:
        candidates, _scores = predict_mask_topk(filled, tokenizer, bert_mlm, 1)
        filled = candidates[0]
        rounds_left -= 1
    return filled


In [None]:
# Sentence with two [MASK] placeholders.
text = '明日は[MASK]が[MASK]かな。'
# A single predict_mask_topk call replaces only the first [MASK].
print(predict_mask_topk(text, tokenizer, bert_mlm, 1)[0][0])
# greedy_prediction repeats that step once per [MASK] in the input.
print(greedy_prediction(text, tokenizer, bert_mlm))


In [None]:
def beam_search(text, tokenizer, bert_mlm, num_topk):
    """
    Fill every [MASK] in ``text`` using beam search.

    One [MASK] is filled per round. In each round every text currently in
    the beam is expanded with ``predict_mask_topk``; the candidate scores are
    added to the score accumulated by their parent, and only the
    ``num_topk`` candidates with the largest accumulated score are kept.

    Returns the surviving texts ordered from the highest to the lowest
    accumulated score. When ``text`` contains no [MASK], ``[text]`` is
    returned unchanged.
    """
    beam_texts = [text]
    beam_scores = np.array([0])

    for _round in range(text.count('[MASK]')):
        expanded_texts = []
        expanded_scores = []
        for parent_text, parent_score in zip(beam_texts, beam_scores):
            child_texts, child_scores = predict_mask_topk(
                parent_text, tokenizer, bert_mlm, num_topk
            )
            expanded_texts += child_texts
            # Accumulate: parent's running score plus each child's score.
            expanded_scores.append(parent_score + child_scores)

        flat_scores = np.hstack(expanded_scores)
        # Ascending argsort reversed -> indices of the best candidates first.
        keep = flat_scores.argsort()[::-1][:num_topk]
        beam_texts = [expanded_texts[i] for i in keep]
        beam_scores = flat_scores[keep]

    return beam_texts



In [None]:
# Sentence with two adjacent [MASK] placeholders.
text = '今日は[MASK][MASK]へ行く。'
print('# with beam search')
# Beam width 10; the texts returned by beam_search are printed one per line.
print(*beam_search(text, tokenizer, bert_mlm, 10), sep='\n')
print('# with greedy method')
# Single result obtained by filling each [MASK] with its top-1 candidate in turn.
print(greedy_prediction(text, tokenizer, bert_mlm))
