# BertForMaskedLM

## 加載

In [1]:
import inspect
import logging

import torch

In [2]:
# Quiet the transformers loaders: raise each of these loggers to WARNING so
# that their INFO/DEBUG messages are not printed into the notebook.
for _noisy_logger in (
    "transformers.file_utils",
    "transformers.tokenization_utils",
    "transformers.modeling_utils",
    "transformers.configuration_utils",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

In [3]:
def use_model(model_name, config_file_path, model_file_path, vocab_file_path, num_labels):
    """Load a pretrained masked-language model together with its tokenizer.

    Args:
        model_name: Architecture selector. Only ``'bert'`` is supported.
        config_file_path: Name or path passed to ``BertConfig.from_pretrained``.
        model_file_path: Name or path passed to ``BertForMaskedLM.from_pretrained``.
            When it contains ``'.ckpt'`` the weights are loaded with ``from_tf=True``.
        vocab_file_path: Name or path passed to ``BertTokenizer.from_pretrained``.
        num_labels: Forwarded to the config as ``num_labels``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Fail fast with a clear message instead of implicitly returning None,
    # which would only surface later as an unpacking error at the call site.
    if model_name != 'bert':
        raise ValueError("Unsupported model_name: {!r} (expected 'bert')".format(model_name))

    # Imported lazily so transformers is only required when a model is actually loaded.
    from transformers import BertConfig, BertForMaskedLM, BertTokenizer

    config = BertConfig.from_pretrained(config_file_path, num_labels=num_labels)
    model = BertForMaskedLM.from_pretrained(
        model_file_path,
        # Decide from the path that is actually being loaded, not from a fixed literal.
        from_tf='.ckpt' in str(model_file_path),
        config=config,
    )
    tokenizer = BertTokenizer.from_pretrained(vocab_file_path)
    return model, tokenizer

In [4]:
# BERT: config, weights and vocabulary are all taken from "bert-base-chinese".
model_setting = dict(
    model_name="bert",
    config_file_path="bert-base-chinese",
    model_file_path="bert-base-chinese",
    vocab_file_path="bert-base-chinese",
    num_labels=149,  # number of classes; use_model forwards it to the config
)
model, tokenizer = use_model(**model_setting)

## 處理輸入

In [5]:
context = '今天天[MASK]如何?'
masked_label = '氣'
# Encode the expected character without special tokens so only its own
# vocabulary id is returned (as a one-element list).
masked_label_id = tokenizer.encode(masked_label, add_special_tokens=False) # 3706
print(masked_label_id)

context_encode = tokenizer.encode_plus(context, add_special_tokens=True, return_tensors='pt')
# Find the [MASK] position in the encoded ids instead of hard-coding it, so the
# index stays correct if `context` changes. list.index raises ValueError when
# the context contains no [MASK] token.
masked_index = context_encode['input_ids'][0].tolist().index(tokenizer.mask_token_id)

print(context_encode['input_ids'])

[3706]
tensor([[ 101,  791, 1921, 1921,  103, 1963,  862,  136,  102]])


> [CLS] = 101, [SEP] = 102, [MASK] = 103

In [6]:
# One label per input token, all initialised to -100 (the value excluded from
# the loss); only the masked position receives the real target id.
masked_lm_labels = torch.full((len(context_encode['input_ids'][0]),), -100, dtype=torch.long)
masked_lm_labels[masked_index] = torch.tensor(masked_label_id, dtype=torch.long)
masked_lm_labels = masked_lm_labels[None, :]  # add the leading batch dimension
print(masked_lm_labels)

tensor([[-100, -100, -100, -100, 3706, -100, -100, -100, -100]])


> 設置為-100忽略loss計算，只提供需要預測文字的正確vocab_id

In [7]:
# The masked-LM target keyword was renamed from `masked_lm_labels` to `labels`
# in later transformers releases; use whichever name the installed model's
# forward() declares so the cell runs on both.
label_kwarg = 'labels' if 'labels' in inspect.signature(model.forward).parameters else 'masked_lm_labels'
outputs = model(context_encode['input_ids'], **{label_kwarg: masked_lm_labels})
# With labels supplied, the first two outputs are the loss and the per-token vocabulary scores.
loss, prediction_scores = outputs[:2]
print(loss) # the loss is computed only at the masked_label position

tensor(0.8557, grad_fn=<NllLossBackward>)


## 處理輸出、解碼

In [8]:
# Inspect the raw score tensor; the leading dimension is the batch.
print(prediction_scores.size())

torch.Size([1, 9, 21128])


In [9]:
# Remove the size-1 batch dimension so rows can be indexed by token position.
prediction_scores = torch.squeeze(prediction_scores, dim=0)
print(prediction_scores.size())

torch.Size([9, 21128])


In [10]:
# Keep only the row of vocabulary scores that belongs to the [MASK] position.
prediction_score = prediction_scores.select(0, masked_index)
print(prediction_score.size())

torch.Size([21128])


In [11]:
# The predicted token is the vocabulary entry with the highest score.
vocab_id = prediction_score.argmax().item()
print(vocab_id)

3698


In [12]:
# decode() is documented to take a sequence of token ids, so wrap the single id in a list.
tokenizer.decode([vocab_id])

'气'