# BertForQuestionAnswering(以 AlbertForQuestionAnswering 示範)

## 加載

In [1]:
import torch
import logging

In [2]:
# Raise the level of these transformers loggers to WARNING so that their
# lower-severity records are suppressed in the notebook output.
for noisy_logger_name in (
    "transformers.file_utils",
    "transformers.tokenization_utils",
    "transformers.modeling_utils",
    "transformers.configuration_utils",
):
    logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)

In [3]:
def _check_has_skip_token(check_tokens,skip_tokens):
    for check_token in check_tokens:
        for skip_token in skip_tokens:
            if check_token == skip_token:
                return True
    return False

def _check_segment_type_is_a(start_index,end_index,segment_embeddings):
    tag_segment_embeddings = segment_embeddings[start_index]
    if 0 in tag_segment_embeddings:
        return True
    return False

def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes

In [4]:
def use_model(model_name="clue/albert_chinese_small"):
    """Load the ALBERT question-answering model together with its tokenizer.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both the tokenizer and the model. Defaults to the checkpoint this
            notebook was written for, so ``use_model()`` behaves as before.

    Returns:
        A ``(model, tokenizer)`` tuple: an ``AlbertForQuestionAnswering``
        instance and a ``BertTokenizer`` instance loaded from ``model_name``.
    """
    # transformers is only needed for loading, so it is imported here.
    from transformers import BertTokenizer, AlbertForQuestionAnswering

    # Tokenizer and model are loaded from the same checkpoint name so they
    # cannot drift apart.
    tokenizer = BertTokenizer.from_pretrained(model_name)
    albert = AlbertForQuestionAnswering.from_pretrained(model_name)
    return albert, tokenizer

# ALBERT: load the question-answering model and its tokenizer for the cells below.
model, tokenizer = use_model()

## 輸入(context、question)

In [5]:
question = '王大明擔任什麼'
context = '王大明是校長'

# Encode the (question, context) pair as a single sequence with special
# tokens added. The question is passed first, so it becomes the first
# segment and the context the second.
input_encode = tokenizer.encode_plus(
    question,
    context,
    add_special_tokens=True,
    return_tensors='pt',
)
print(input_encode)

{'input_ids': tensor([[ 101, 4374, 1920, 3209, 3085,  818,  784, 7938,  102, 4374, 1920, 3209,
         3221, 3413, 7269,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [6]:
# Collect the positions whose token type id is 0 (the first segment, i.e. the
# question plus its surrounding special tokens); their count is its length.
token_type_row = input_encode['token_type_ids'].squeeze(0)
segment_a_ids = (token_type_row == 0).nonzero().t().squeeze(0)
len(segment_a_ids)

9

In [7]:
# Offset introduced by the [CLS] SEGMENT_A [SEP] prefix that precedes the context.
answer_padding = len(segment_a_ids)
# 4 and 5 are the positions of the first and last answer tokens inside the
# context; adding the prefix length turns them into positions in the full input.
answer_start_position, answer_end_position = 4 + answer_padding, 5 + answer_padding
print(answer_start_position, answer_end_position)

13 14


In [8]:
# Sanity check: read the vocabulary ids stored at the two label positions and
# decode them to confirm the labels point at the intended answer tokens.
input_id_row = input_encode['input_ids'][0]
answer_start_id = input_id_row[answer_start_position].item()
answer_end_id = input_id_row[answer_end_position].item()
print(tokenizer.decode(answer_start_id), tokenizer.decode(answer_end_id))

校 長


In [9]:
# Wrap each gold position in a one-element tensor, matching the single
# example encoded above.
start_position_labels = torch.tensor([answer_start_position])
end_position_labels = torch.tensor([answer_end_position])
print(start_position_labels, end_position_labels, sep='\n')

tensor([13])
tensor([14])


In [10]:
# Forward pass with the gold span labels supplied, so the model returns the
# training loss in addition to the start/end scores.
qa_outputs = model(
    input_encode['input_ids'],
    # All ones for this un-padded input; passed so padded inputs are handled too.
    attention_mask=input_encode['attention_mask'],
    token_type_ids=input_encode['token_type_ids'],
    start_positions=start_position_labels,
    end_positions=end_position_labels,
)
# Index by position instead of tuple-unpacking: this works both when the
# model returns a plain tuple and when it returns a dict-like output object
# (iterating the latter would yield its keys rather than the tensors).
loss, start_scores, end_scores = qa_outputs[0], qa_outputs[1], qa_outputs[2]
print(start_scores)
print(end_scores)

tensor([[ 0.1783, -0.2392,  0.0243,  0.1509, -0.5761, -0.3853, -0.1867, -0.3440,
         -0.3300, -0.0576,  0.1205,  0.2807,  0.3168,  0.1942,  0.0434,  0.0638]],
       grad_fn=<SqueezeBackward1>)
tensor([[ 0.0542,  0.0188, -0.1365,  0.3662,  0.0576, -0.1260, -0.1121, -0.1194,
          0.1059, -0.0329, -0.2258,  0.4193, -0.2055,  0.1314,  0.2550,  0.2596]],
       grad_fn=<SqueezeBackward1>)


## 輸出、挑選答案

In [11]:
# Show the shapes of the start and end score tensors, one per line.
print(start_scores.shape, end_scores.shape, sep='\n')

torch.Size([1, 16])
torch.Size([1, 16])


In [12]:
# Keep the 10 highest-scoring candidate positions for the answer start and
# for the answer end, each ordered best first.
predict_start_positions = _get_best_indexes(start_scores.squeeze(0), n_best_size=10)
predict_end_positions = _get_best_indexes(end_scores.squeeze(0), n_best_size=10)
print(predict_start_positions, predict_end_positions, sep='\n')

[12, 11, 13, 0, 3, 10, 15, 14, 2, 9]
[11, 3, 15, 14, 13, 8, 4, 0, 1, 9]


In [13]:
# Squeeze dimension 0 of both score tensors so scores can be looked up by
# token position, and start with an empty list of candidate answers.
start_scores, end_scores = start_scores.squeeze(0), end_scores.squeeze(0)
answer_results = []

In [14]:
# Pair every candidate start with every candidate end and keep the spans that
# form a plausible answer.
# The result list is reset here so that re-running this cell does not append
# the same candidates a second time.
answer_results = []

# Loop-invariant values, computed once instead of on every iteration.
input_id_row = input_encode['input_ids'].squeeze(0)
token_type_row = input_encode['token_type_ids'].squeeze(0)
skip_tokens = ['[CLS]', '[SEP]', '[PAD]']
max_answer_length = 30  # longest answer, in tokens, that is still accepted

for start_index in predict_start_positions:
    for end_index in predict_end_positions:
        # A span whose end precedes its start is empty; skip it up front.
        if end_index < start_index:
            continue

        answer_ids = input_id_row[start_index:end_index + 1]
        answer_token = tokenizer.convert_ids_to_tokens(answer_ids)

        # Reject empty or overly long spans.
        if len(answer_token) == 0 or len(answer_token) > max_answer_length:
            continue
        # Reject spans that contain a special token.
        if _check_has_skip_token(check_tokens=answer_token, skip_tokens=skip_tokens):
            continue
        # Reject spans that lie in segment A (token type id 0) rather than the context.
        if _check_segment_type_is_a(start_index, end_index, token_type_row):
            continue

        answer = "".join(answer_token)
        print(answer)
        # (start position, start score, end position, end score, answer text)
        answer_results.append((
            start_index,
            start_scores[start_index].item(),
            end_index,
            end_scores[end_index].item(),
            answer,
        ))

是校長
是校
明
明是校長
明是校
校長
校
大明
大明是校長
大明是校
長
王大明
王大明是校長
王大明是校
王
