In [2]:
import torch
from transformers import BertForMaskedLM, BertJapaneseTokenizer

In [3]:
pretrain_path = "../../corpus/"
PRETRAINED_MODEL = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [26]:
tokenizer = BertJapaneseTokenizer.from_pretrained(pretrain_path+PRETRAINED_MODEL)
model = BertForMaskedLM.from_pretrained(pretrain_path+PRETRAINED_MODEL)
model.to('cpu')
model.eval()

Some weights of the model checkpoint at ../../corpus/cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [7]:
from pyknp import Juman
import spacy
import ginza
nlp = spacy.load('ja_ginza')

In [45]:
prev = "あなたの好きな飲み物は？"
# prev = ""
ans_templete = "私は緑茶が好きです"
# ans_templete = ""

In [46]:

def noun2masked(sentence:str, mask_token="[MASK]") -> str:
    doc = nlp(sentence)
    masked = ""
    for token in doc:
        tag_split = token.tag_.split("-")
        # print(token, tag_split)
        if len(tag_split)>2 and tag_split[0]=="名詞" and tag_split[2] == "一般":
            masked += mask_token
        else:
            masked += token.orth_
    return masked


In [47]:
noun2masked(ans_templete)

'私は[MASK]が好きです'

In [48]:
mask_index = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

def search_mask_index(indexed_tokens):
    mask_idx = []
    for i, id_ in enumerate(indexed_tokens):
        if id_ == mask_index:
            mask_idx.append(i)
    return mask_idx

def predict_base(prev, masked:str):
    sentence =  "[CLS]{0}[SEP]{1}[SEP]".format(prev, masked)
    print(sentence)
    tokenized_text = tokenizer.tokenize(sentence)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    mask_idx = search_mask_index(indexed_tokens)

    with torch.no_grad():
        outputs = model(  torch.tensor( [indexed_tokens] ))
        predictions = outputs[0]
        if len(mask_idx)==1:
            _, predicted_indexes = torch.topk(predictions[0, mask_idx[0]], k=5)
            predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
            print(predicted_tokens)
        else:
            predicted_tokens_list = []
            for i in mask_idx:
                _, predicted_indexes = torch.topk(predictions[0, i], k=5)
                predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
                predicted_tokens_list.append(predicted_tokens)
            print(predicted_tokens_list)

In [49]:
persona = "私はボクサーですが，"

In [50]:
predict_base(prev, persona+noun2masked(ans_templete))

[CLS]あなたの好きな飲み物は？[SEP]私はボクサーですが，私は[MASK]が好きです[SEP]
['ワイン', '酒', 'コーヒー', 'ビール', '野球']
