# BertForMaskedLM

## 加載

In [1]:
import torch

In [2]:
def use_model(model_name, config_file_path, model_file_path, vocab_file_path, num_labels):
    """Load a pretrained masked-language model together with its tokenizer.

    Args:
        model_name: Model family to load. Only ``'bert'`` is supported.
        config_file_path: Hub identifier or local path of the model config.
        model_file_path: Hub identifier or local path of the model weights.
            A path containing ``'.ckpt'`` is loaded as a TensorFlow checkpoint.
        vocab_file_path: Hub identifier or local path of the tokenizer vocab.
        num_labels: Forwarded to the config as ``num_labels``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a supported model family.
    """
    # Select the model family and load its configuration.
    if model_name == 'bert':
        # Imported lazily so the notebook only needs transformers when a model is loaded.
        from transformers import BertConfig, BertForMaskedLM, BertTokenizer
        config = BertConfig.from_pretrained(config_file_path, num_labels=num_labels)
        # Decide TF-vs-PyTorch weights from the actual weights path, so that
        # TensorFlow checkpoints ("*.ckpt") are converted on load.
        model = BertForMaskedLM.from_pretrained(
            model_file_path,
            from_tf='.ckpt' in str(model_file_path),
            config=config,
        )
        tokenizer = BertTokenizer.from_pretrained(vocab_file_path)
        return model, tokenizer
    # Fail loudly instead of returning None, which would break tuple unpacking at the call site.
    raise ValueError("Unsupported model_name: {!r}".format(model_name))

In [3]:
# BERT
model_setting = {
    "model_name":"bert", 
    "config_file_path":"bert-base-chinese", 
    "model_file_path":"bert-base-chinese", 
    "vocab_file_path":"bert-base-chinese",
    "num_labels":149  # 分幾類 
}
model, tokenizer = use_model(**model_setting)

I0430 03:30:56.475142 140628562782016 file_utils.py:41] PyTorch version 1.3.0+cu100 available.
I0430 03:30:58.205695 140628562782016 configuration_utils.py:283] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json from cache at /root/.cache/torch/transformers/8a3b1cfe5da58286e12a0f5d7d182b8d6eca88c08e26c332ee3817548cf7e60a.f12a4f986e43d8b328f5b067a641064d67b91597567a06c7b122d1ca7dfd9741
I0430 03:30:58.208685 140628562782016 configuration_utils.py:319] Model config BertConfig {
  "_num_labels": 149,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "

I0430 03:30:59.160732 140628562782016 modeling_utils.py:507] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin from cache at /root/.cache/torch/transformers/b1b5e295889f2d0979ede9a78ad9cb5dc6a0e25ab7f9417b315f0a2c22f4683d.929717ca66a3ba9eb9ec2f85973c6398c54c38a4faa464636a491d7a705f7eb6
I0430 03:31:00.850485 140628562782016 modeling_utils.py:601] Weights of BertForMaskedLM not initialized from pretrained model: ['cls.predictions.decoder.bias']
I0430 03:31:00.851286 140628562782016 modeling_utils.py:607] Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
I0430 03:31:01.801495 140628562782016 tokenization_utils.py:504] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /root/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a86

## 處理輸入

In [4]:
context = '今天天[MASK]如何?'
masked_label = '氣'
# Encode the expected answer without [CLS]/[SEP]; yields a list of ids ([3706] for '氣').
masked_label_id = tokenizer.encode(masked_label, add_special_tokens=False)
print(masked_label_id)

context_encode = tokenizer.encode_plus(context, add_special_tokens=True, return_tensors='pt')
# add_special_tokens=True prepends one [CLS] token before the text.
special_taken_padding = 1
# Locate [MASK] in the encoded ids instead of hard-coding its offset, so editing
# `context` cannot silently point the label at the wrong token.
# Raises ValueError if `context` contains no [MASK].
masked_index = context_encode['input_ids'][0].tolist().index(tokenizer.mask_token_id)

print(context_encode['input_ids'])

[3706]
tensor([[ 101,  791, 1921, 1921,  103, 1963,  862,  136,  102]])


> [CLS]=101, [SEP]=102, [MASK]=103；因為開頭多了一個 [CLS]，[MASK] 在 input_ids 中位於 index 4。

In [5]:
# Per-token MLM labels, already batched like input_ids (1, seq_len): -100 marks
# positions excluded from the loss; only the [MASK] slot holds the expected id.
masked_lm_labels = torch.full(context_encode['input_ids'].shape, -100, dtype=torch.long)
masked_lm_labels[0, masked_index] = torch.LongTensor(masked_label_id)
print(masked_lm_labels)

tensor([[-100, -100, -100, -100, 3706, -100, -100, -100, -100]])


> 標籤設為 -100 的位置不參與 loss 計算（PyTorch CrossEntropyLoss 預設 ignore_index=-100），只在需要預測的 [MASK] 位置填入正確答案的 vocab_id。

In [6]:
# Forward pass with labels; the first two outputs are (loss, prediction_scores).
outputs = model(
    input_ids=context_encode['input_ids'],
    masked_lm_labels=masked_lm_labels,
)
loss = outputs[0]
prediction_scores = outputs[1]
print(loss)  # only the [MASK] position contributes; every other label is -100

tensor(0.8557, grad_fn=<NllLossBackward>)


## 處理輸出、解碼

In [7]:
print(prediction_scores.size())  # (batch, seq_len, V); the last axis is indexed by vocab id

torch.Size([1, 9, 21128])


In [8]:
# Drop the size-1 batch axis; squeeze is a no-op if dim 0 is not 1, so re-running is safe.
prediction_scores = torch.squeeze(prediction_scores, dim=0)
print(prediction_scores.size())

torch.Size([9, 21128])


In [9]:
# Keep only the score vector for the [MASK] token position.
prediction_score = prediction_scores.select(0, masked_index)
print(prediction_score.size())

torch.Size([21128])


In [10]:
# Highest-scoring vocabulary id at the masked position, as a plain Python int.
vocab_id = int(prediction_score.argmax())
print(vocab_id)

3698


In [11]:
# decode expects a sequence of ids; passing a bare int makes it iterate over the
# characters of the single token string, which garbles multi-character tokens.
tokenizer.decode([vocab_id])

'气'