In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch

from transformers import BertJapaneseTokenizer, BertForMaskedLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Checkpoint identifier shared by the tokenizer and the masked-LM model so
# that both are guaranteed to use the same vocabulary.
model_name = "cl-tohoku/bert-base-japanese-whole-word-masking"

# Load the tokenizer first, then the model with its masked-language-model head.
tokenizer, bert_mlm = (
    BertJapaneseTokenizer.from_pretrained(model_name),
    BertForMaskedLM.from_pretrained(model_name),
)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Example sentence containing a single masked slot to be predicted later.
text = "今日は[MASK]へ行く。"
tokens = tokenizer.tokenize(text)

# Show the full encoding (IDs, segment IDs, attention mask) followed by the
# plain token strings, one per line.
print(tokenizer(text), tokens, sep="\n")

{'input_ids': [2, 3246, 9, 4, 118, 3488, 8, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
['今日', 'は', '[MASK]', 'へ', '行く', '。']


In [5]:
# Encode the sentence into a batch-of-one tensor of token IDs.
input_ids = tokenizer.encode(text, return_tensors="pt")
print(input_ids)

# Inference only: disable gradient tracking for the forward pass.
with torch.no_grad():
    output = bert_mlm(input_ids=input_ids)
    scores = output.logits

# Find the position of the first '[MASK]' in the ID sequence. The ID is taken
# from the tokenizer instead of being hard-coded (it is 4 for this vocabulary),
# so the lookup stays correct if the checkpoint is swapped.
mask_position = input_ids[0].tolist().index(tokenizer.mask_token_id)

print(scores.shape)

tensor([[   2, 3246,    9,    4,  118, 3488,    8,    3]])
torch.Size([1, 8, 32000])


In [6]:

# Vocabulary ID with the highest score at the masked position.
id_best = scores[0, mask_position].argmax(-1).item()
token_best = tokenizer.convert_ids_to_tokens(id_best)
# Drop the "##" sub-word continuation marker so the token reads naturally.
token_best = token_best.replace("##", "")

# Keep the masked sentence in `text` and store the result under a new name:
# overwriting `text` would remove the mask and make the previous cell fail
# when re-run. Only the first mask is filled, because only that position
# was scored.
text_filled = text.replace(tokenizer.mask_token, token_best, 1)

print(text_filled)

今日は東京へ行く。


In [11]:
def predict_mask_topk(text, tokenizer, bert_mlm, num_topk):
    """Fill the first mask token in ``text`` with each of the top-scoring tokens.

    Args:
        text: Sentence containing at least one mask token (e.g. ``[MASK]``).
        tokenizer: Tokenizer providing ``encode``, ``convert_ids_to_tokens``,
            ``mask_token`` and ``mask_token_id``.
        bert_mlm: Masked-language model; called as ``bert_mlm(input_ids=...)``
            and expected to return an object with a ``logits`` attribute.
        num_topk: Number of candidate tokens to return.

    Returns:
        A tuple ``(text_topk, scores_topk)``: ``text_topk`` is a list of
        ``num_topk`` sentences in which the first mask has been replaced by a
        candidate token, and ``scores_topk`` is a NumPy array with the
        corresponding logits, in the order returned by ``topk``.

    Raises:
        ValueError: If the encoded text contains no mask token.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Use the tokenizer's own mask ID rather than a hard-coded vocabulary
    # index, and fail with a clear message before running the model.
    mask_id = tokenizer.mask_token_id
    token_ids = input_ids[0].tolist()
    if mask_id not in token_ids:
        raise ValueError(
            f"text contains no {tokenizer.mask_token} token: {text!r}"
        )
    mask_position = token_ids.index(mask_id)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = bert_mlm(input_ids=input_ids)
    scores = output.logits

    topk = scores[0, mask_position].topk(num_topk)
    # Convert to a plain list of ints before the vocabulary lookup.
    ids_topk = topk.indices.tolist()
    tokens_topk = tokenizer.convert_ids_to_tokens(ids_topk)
    scores_topk = topk.values.numpy()

    # Strip the "##" sub-word marker and substitute only the first mask,
    # since only that position was scored.
    text_topk = [
        text.replace(tokenizer.mask_token, token.replace("##", ""), 1)
        for token in tokens_topk
    ]
    return text_topk, scores_topk

In [12]:
# Masked example sentence; ask for the ten best completions and list them,
# one candidate sentence per line (scores are not needed here).
text = "今日は[MASK]へ行く。"
text_topk, _ = predict_mask_topk(text, tokenizer, bert_mlm, 10)
print("\n".join(text_topk))

今日は東京へ行く。
今日はハワイへ行く。
今日は学校へ行く。
今日はニューヨークへ行く。
今日はどこへ行く。
今日は空港へ行く。
今日はアメリカへ行く。
今日は病院へ行く。
今日はそこへ行く。
今日はロンドンへ行く。
