In [1]:
import logging
from typing import List, Optional, Tuple

import torch
from transformers import BertForMaskedLM, BertJapaneseTokenizer

logger = logging.getLogger(__name__)


class BertProofreader:
    """Proofread a sentence with a BERT masked language model.

    Each token of the sentence is replaced by ``[MASK]`` in turn and the model
    is asked to predict it. A token that the model does not rank among its
    likely candidates is reported as suspicious.
    """

    def __init__(self, pretrained_model: str, cache_dir: Optional[str] = None,
                 device: Optional[str] = None):
        """Load the tokenizer and the masked-LM weights and move the model to ``device``.

        Args:
            pretrained_model: Model name or path handed to ``from_pretrained``.
            cache_dir: Optional cache directory handed to ``from_pretrained``.
            device: Torch device (e.g. ``'cuda'`` or ``'cpu'``). When omitted,
                CUDA is used if it is available and the CPU otherwise, so the
                class also works on machines without a GPU.
        """
        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)

        # Load pre-trained model (weights)
        self.model = BertForMaskedLM.from_pretrained(pretrained_model, cache_dir=cache_dir)

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.model.to(self.device)

        # The model is only used for prediction, never trained.
        self.model.eval()

    def mask_prediction(self, sentence: str) -> Tuple[List[str], torch.Tensor]:
        """Run the model once per token, with that token replaced by ``[MASK]``.

        Args:
            sentence: Raw sentence without special tokens.

        Returns:
            A ``(tokens, predictions)`` pair. ``tokens`` is the tokenized
            sentence including the leading ``[CLS]`` and trailing ``[SEP]``.
            ``predictions`` is the first element of the model output; row ``i``
            is the result for the input whose token ``i + 1`` was masked, so
            there are ``len(tokens) - 2`` rows. For a sentence without any
            token the model is not called and an empty tensor of shape
            ``(0, len(tokens), vocab_size)`` is returned.
        """
        # Add the special tokens
        sentence = f'[CLS]{sentence}[SEP]'

        tokenized_text = self.tokenizer.tokenize(sentence)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens], device=self.device)

        # Number of real tokens, i.e. everything except [CLS] and [SEP]
        repeat_num = max(tokens_tensor.shape[1] - 2, 0)
        if repeat_num == 0:
            # Nothing to mask: skip the forward pass instead of feeding the
            # model a batch with zero rows.
            empty = torch.empty(
                (0, tokens_tensor.shape[1], self.model.config.vocab_size),
                device=self.device,
            )
            return tokenized_text, empty

        # Look up the id of [MASK]
        mask_index = self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

        # Build one copy of the sentence per token; copy i has token i + 1
        # (the offset skips [CLS]) replaced by [MASK].
        tokens_tensor = tokens_tensor.repeat(repeat_num, 1)
        rows = torch.arange(repeat_num, device=tokens_tensor.device)
        tokens_tensor[rows, rows + 1] = mask_index

        # Predict all tokens
        with torch.no_grad():
            outputs = self.model(tokens_tensor, token_type_ids=None)
            predictions = outputs[0]

        return tokenized_text, predictions

    def _masked_position_scores(self, sentence: str) -> Tuple[List[str], torch.Tensor]:
        """Return the real tokens and, per token, the scores predicted at its masked slot."""
        tokens, predictions = self.mask_prediction(sentence)
        inner_tokens = tokens[1:-1]  # drop [CLS] / [SEP]
        rows = torch.arange(len(inner_tokens), device=predictions.device)
        # Row i of `predictions` masked position i + 1, so pick exactly that slot.
        return inner_tokens, predictions[rows, rows + 1]

    def check_topk(self, sentence: str, topk: int = 10):
        """Check that every token is among the model's top-k guesses for its slot.

        Args:
            sentence: Raw sentence without special tokens.
            topk: Number of highest-scoring candidates to accept per position.
                Values larger than the vocabulary are clamped to its size.

        Returns:
            True when each original token appears in the top ``topk``
            predictions for its masked position (also True for a sentence
            without tokens), False otherwise. The verdict for every token is
            logged at INFO level and the candidate list at DEBUG level.
        """
        inner_tokens, scores = self._masked_position_scores(sentence)
        if not inner_tokens:
            return True

        # torch.topk rejects k outside [0, vocab_size]; clamp instead of raising.
        k = max(0, min(topk, scores.shape[-1]))
        _, top_ids = torch.topk(scores, k, dim=-1)

        judges = []
        for token, ids in zip(inner_tokens, top_ids.tolist()):
            pred_top_k_word = self.tokenizer.convert_ids_to_tokens(ids)
            judges.append(token in pred_top_k_word)
            logger.info('%s: %s', token, judges[-1])
            logger.debug('top k word=%s', pred_top_k_word)

        return all(judges)

    def check_threshold(self, sentence: str, threshold: float = 0.01):
        """Check that every token has at least ``threshold`` probability at its slot.

        Args:
            sentence: Raw sentence without special tokens.
            threshold: Minimum softmax probability a candidate needs in order
                to be accepted.

        Returns:
            True when each original token is among the candidates whose
            probability at its masked position is ``>= threshold`` (also True
            for a sentence without tokens), False otherwise. The verdict for
            every token is logged at INFO level.
        """
        inner_tokens, scores = self._masked_position_scores(sentence)
        if not inner_tokens:
            return True

        # Normalise over the vocabulary axis to turn scores into probabilities.
        probabilities = scores.softmax(dim=-1)

        judges = []
        for token, row in zip(inner_tokens, probabilities):
            # Flatten to a plain list of ints before handing ids to the tokenizer.
            indices = (row >= threshold).nonzero().flatten().tolist()
            pred_top_word = self.tokenizer.convert_ids_to_tokens(indices)
            judges.append(token in pred_top_word)
            logger.info('%s: %s', token, judges[-1])

        return all(judges)



In [5]:
import logging
# from models.bert_proofreader import BertProofreader

# Model identifier that BertProofreader forwards to `from_pretrained`.
PRETRAINED_MODEL = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# Show the per-token INFO verdicts that BertProofreader logs.
logging.basicConfig(level=logging.INFO)

proofreader = BertProofreader(pretrained_model=PRETRAINED_MODEL)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Top-k check with k=100: each token is masked in turn and must appear among the
# model's 100 best candidates. Per-token verdicts are logged at INFO level and
# the cell displays the overall result (True only if every token passes).
proofreader.check_topk('今彼女がいないんで、私も是非上げたいです。', topk=100)

INFO:__main__:今: True
INFO:__main__:彼女: True
INFO:__main__:が: True
INFO:__main__:い: True
INFO:__main__:ない: True
INFO:__main__:ん: True
INFO:__main__:##で: True
INFO:__main__:、: True
INFO:__main__:私: True
INFO:__main__:も: True
INFO:__main__:是非: False
INFO:__main__:上げ: False
INFO:__main__:たい: True
INFO:__main__:です: True
INFO:__main__:。: True


False

In [20]:
# Same top-100 check on a second sentence; the cell displays True only when
# every token is among the model's 100 best candidates for its masked slot.
proofreader.check_topk('元気は良いんですか？？元気ですかは元気か．', topk=100)

INFO:__main__:元気: True
INFO:__main__:は: True
INFO:__main__:良い: True
INFO:__main__:ん: True
INFO:__main__:です: True
INFO:__main__:か: True
INFO:__main__:?: True
INFO:__main__:##?: True
INFO:__main__:元気: True
INFO:__main__:です: True
INFO:__main__:か: True
INFO:__main__:は: True
INFO:__main__:元気: True
INFO:__main__:か: True
INFO:__main__:.: True


True