# BertForQuestionAnswering (ALBERT: `AlbertForQuestionAnswering` + `BertTokenizer`)

## 加載

In [1]:
import torch

In [2]:
def _check_has_skip_token(check_tokens,skip_tokens):
    for check_token in check_tokens:
        for skip_token in skip_tokens:
            if check_token == skip_token:
                return True
    return False

def _check_segment_type_is_a(start_index,end_index,segment_embeddings):
    tag_segment_embeddings = segment_embeddings[start_index]
    if 0 in tag_segment_embeddings:
        return True
    return False

def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes

In [3]:
def use_model(model_name="clue/albert_chinese_small"):
    """Load an ALBERT question-answering model together with its tokenizer.

    A ``BertTokenizer`` is paired with ``AlbertForQuestionAnswering``. This is
    presumably because the checkpoint ships a BERT-style vocabulary (confirm on
    the model card before swapping checkpoints).

    Args:
        model_name: model identifier or local path, used for both the tokenizer
            and the model. Defaults to the checkpoint this notebook was run with.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    from transformers import BertTokenizer, AlbertForQuestionAnswering
    tokenizer = BertTokenizer.from_pretrained(model_name)
    albert = AlbertForQuestionAnswering.from_pretrained(model_name)
    return albert, tokenizer

# ALBERT
model, tokenizer = use_model()

I0430 03:48:07.135764 140367158015808 file_utils.py:41] PyTorch version 1.3.0+cu100 available.
I0430 03:48:07.924898 140367158015808 tokenization_utils.py:420] Model name 'clue/albert_chinese_small' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming 'clue/albert_chinese_small' is a path, a model identifier, or url to a directory containing tokenizer files.
I0430 03:48:11.769205 140367158015808 tokenization_utils.py:504] loading file https://s3.amazonaws.com/mo

## 輸入(context、question)

In [4]:
# Segment A is the question and segment B is the context passage (argument order below).
question = '王大明擔任什麼'
context = '王大明是校長'
input_encode = tokenizer.encode_plus(
    question,
    context,
    add_special_tokens=True,
    return_tensors='pt',
)
print(input_encode)

{'input_ids': tensor([[ 101, 4374, 1920, 3209, 3085,  818,  784, 7938,  102, 4374, 1920, 3209,
         3221, 3413, 7269,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
# Indices of segment-A tokens (token_type_id == 0), i.e. [CLS] + question + [SEP].
# For a single sequence, nonzero() returns an (N, 1) index column; .t().squeeze(0) flattens it to (N,).
segment_a_ids = (
    (input_encode['token_type_ids'].squeeze(0) == 0)
    .nonzero()
    .t()
    .squeeze(0)
)
len(segment_a_ids)

9

In [6]:
# Offset of segment B within input_ids: all segment-A tokens ([CLS] + question + [SEP]) precede it.
answer_padding = len(segment_a_ids)
# Locate the gold answer in the context instead of hard-coding character offsets.
# Assumes the tokenizer emits one token per Chinese character (as in the input_ids printed above),
# so a character offset in `context` equals a token offset within segment B.
answer_text = '校長'
answer_char_start = context.index(answer_text)  # raises ValueError if the answer is not in the context
answer_start_position = answer_padding + answer_char_start
answer_end_position = answer_start_position + len(answer_text) - 1  # inclusive end position
print(answer_start_position,answer_end_position)

13 14


In [7]:
# Sanity check: map the gold positions back to token ids and decode them;
# the printed tokens should match the intended answer.
answer_start_id, answer_end_id = (
    input_encode['input_ids'][0][position].item()
    for position in (answer_start_position, answer_end_position)
)
print(tokenizer.decode(answer_start_id), tokenizer.decode(answer_end_id))

校 長


In [8]:
# Gold span labels, one entry per example in the batch (this batch holds a single example).
start_position_labels, end_position_labels = (
    torch.tensor([position])
    for position in (answer_start_position, answer_end_position)
)
print(start_position_labels)
print(end_position_labels)

tensor([13])
tensor([14])


In [9]:
# Forward pass with gold labels. The model call is unpacked into (loss, start scores, end scores).
# attention_mask is passed explicitly so that any padding tokens are masked out
# (it is all ones for this unpadded example, so the result is unchanged here).
loss, start_scores, end_scores = model(
    input_encode['input_ids'],
    attention_mask=input_encode['attention_mask'],
    token_type_ids=input_encode['token_type_ids'],
    start_positions=start_position_labels,
    end_positions=end_position_labels,
)
print(start_scores)
print(end_scores)

tensor([[ 0.0334, -0.3516, -0.0872, -0.1833,  0.0492,  0.0119,  0.1604,  0.2262,
          0.1590, -0.4066, -0.2251, -0.1369, -0.3766,  0.0882, -0.1281, -0.2077]],
       grad_fn=<SqueezeBackward1>)
tensor([[ 0.4998,  0.2118,  0.3676,  0.5608,  0.3600, -0.2253, -0.1838,  0.4214,
         -0.0127,  0.0307,  0.5352,  0.7715,  0.3497,  0.3239,  0.4949, -0.0497]],
       grad_fn=<SqueezeBackward1>)


## 輸出、挑選答案

In [10]:
# One score per input token for each head: (batch, sequence length).
print(start_scores.shape, end_scores.shape, sep='\n')

torch.Size([1, 16])
torch.Size([1, 16])


In [11]:
# Top-10 candidate start / end token positions, highest score first.
predict_start_positions, predict_end_positions = (
    _get_best_indexes(head_scores.squeeze(0), 10)
    for head_scores in (start_scores, end_scores)
)
print(predict_start_positions, predict_end_positions, sep='\n')

[7, 6, 8, 13, 4, 0, 5, 2, 14, 11]
[11, 3, 10, 0, 14, 7, 2, 4, 12, 13]


In [12]:
# Drop the leading batch dimension (size 1) so scores can be indexed by token position.
start_scores, end_scores = start_scores.squeeze(0), end_scores.squeeze(0)
answer_results = []

In [13]:
# Pair every top start candidate with every top end candidate and keep the valid answer spans.
# Each entry: (start_index, start_score, end_index, end_score, answer_text).
answer_results = []  # reset so re-running this cell does not accumulate duplicate results
input_ids = input_encode['input_ids'].squeeze(0)
token_type_ids = input_encode['token_type_ids'].squeeze(0)
# reshape(-1) works whether or not the scores were already squeezed (batch size 1).
start_logits = start_scores.reshape(-1)
end_logits = end_scores.reshape(-1)
max_answer_tokens = 30
special_tokens = ['[CLS]', '[SEP]', '[PAD]']

for start_index in predict_start_positions:
    for end_index in predict_end_positions:
        answer_ids = input_ids[start_index:end_index + 1]
        answer_token = tokenizer.convert_ids_to_tokens(answer_ids)

        # An empty slice (end before start) or an overly long span is not an answer.
        if not answer_token or len(answer_token) > max_answer_tokens:
            continue
        # Spans containing special tokens run into the sequence markers or padding.
        if _check_has_skip_token(check_tokens=answer_token, skip_tokens=special_tokens):
            continue
        # The answer must be taken from the context (segment B), not the question.
        if _check_segment_type_is_a(start_index, end_index, token_type_ids):
            continue
        # Re-attach WordPiece continuation pieces ("##xx") to the preceding piece.
        answer = "".join(token[2:] if token.startswith("##") else token for token in answer_token)
        print(answer)
        answer_result = (
            start_index,
            start_logits[start_index].item(),
            end_index,
            end_logits[end_index].item(),
            answer,
        )
        answer_results.append(answer_result)

校長
校
長
明
明是校長
明是
明是校
