# Training a Chinese Machine Reading Comprehension Model with BERT on DRCD

<br>
<br>

<font color='gray'>

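This example shows how to train a transformer-based Chinese machine reading comprehension model with PyTorch and Hugging Face Transformers,

and walks through the implementation of model inference and a RESTful API.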

<br>

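Machine reading comprehension, in the extractive form used here, is a kind of question answering: given a paragraph and a question, the model must find the answer inside the paragraph.

In essence, the model predicts the index span (start_index, end_index) of the paragraph where the answer is located.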

<br>

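As shown in the figure below, if the paragraph contains 256 tokens, the model looks among these 256 tokens for the token best suited to be the answer's start position (start_index) and the token best suited to be its end position (end_index).

For example, start_index=172 and end_index=178 denote an answer made up of the seven tokens at positions 172 through 178 (both ends inclusive).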
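A toy sketch of the selection step described above, assuming the model has already produced one start score and one end score per token. The scores are made up for illustration, not real model outputs. The sketch takes the highest-scoring start and end independently and slices inclusively, so positions 172 to 178 yield seven tokens. `extract_span` is a hypothetical helper name.

```python
def extract_span(tokens, start_scores, end_scores):
    """Pick the argmax start and end positions and return the inclusive token span."""
    start_index = max(range(len(start_scores)), key=lambda i: start_scores[i])
    end_index = max(range(len(end_scores)), key=lambda i: end_scores[i])
    if start_index > end_index:
        return None  # the two argmaxes do not form a valid span
    # end_index is inclusive, hence the + 1
    return start_index, end_index, tokens[start_index:end_index + 1]


# Hypothetical 256-token paragraph with made-up scores peaking at 172 and 178
tokens = [f"tok{i}" for i in range(256)]
start_scores = [0.0] * 256
end_scores = [0.0] * 256
start_scores[172] = 9.0
end_scores[178] = 9.0

start_index, end_index, answer = extract_span(tokens, start_scores, end_scores)
print(start_index, end_index, len(answer))  # 172 178 7
```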

<br>

This example uses the [Delta Reading Comprehension Dataset](https://github.com/DRCKnowledgeTeam/DRCD) to show how to train a Chinese machine reading comprehension model.

</font>
<br>
<br>
<div>
<img style="float: left;" src="attachment:image.png" width="500"/>
</div>


# Delta Reading Comprehension Dataset
<br>
<font color='gray'>
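The Delta Reading Comprehension Dataset (DRCD) is a general-domain Traditional Chinese machine reading comprehension dataset, intended to serve as a standard Chinese reading comprehension dataset for transfer learning. It contains 10,014 paragraphs drawn from 2,108 Wikipedia articles, with more than 30,000 questions annotated on those paragraphs.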
</font>
<br>

## Data format




    
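    - version
    - data
        - title : article title
        - id : article ID
        - paragraphs
            - id : {article ID}_{paragraph number}
            - context : paragraph text
            - qas
                - question : question text
                - id : {article ID}_{paragraph number}_{question number}
                - answers
                    - answer_start : character offset of text within context
                    - id : "1" marks the human-annotated answer; "2" and above are answers written by human answerers
                    - text : answer text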
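
`answer_start` is a character offset into `context`, so slicing `context` at that offset should reproduce `text`. The sketch below runs this check over a hypothetical one-question record written in the layout above. The record and the helper name `check_answer_offsets` are illustrative only and are not part of DRCD or of this notebook.

```python
import json

# Hypothetical miniature record in the DRCD layout (not taken from the dataset)
sample = json.loads("""
{
  "version": "toy",
  "data": [{
    "title": "toy", "id": "1",
    "paragraphs": [{
      "id": "1_1",
      "context": "小明今天早上八點出門。",
      "qas": [{
        "question": "小明幾點出門？",
        "id": "1_1_1",
        "answers": [{"answer_start": 6, "id": "1", "text": "八點"}]
      }]
    }]
  }]
}
""")


def check_answer_offsets(dataset):
    """Return the ids of questions whose answer text does not sit at answer_start."""
    bad = []
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                for ans in qa["answers"]:
                    start = ans["answer_start"]
                    if context[start:start + len(ans["text"])] != ans["text"]:
                        bad.append(qa["id"])
    return bad


print(check_answer_offsets(sample))  # []
```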
               

## Example
<br>
<div>
<img style="float: center;" src="attachment:drcd__sample.png" width="500"/>
</div>
<br> 

## Reference

[Shao, Chih Chieh, et al. "Drcd: a chinese machine reading comprehension dataset." arXiv preprint
arXiv:1806.00920 (2018).](https://arxiv.org/abs/1806.00920)


In [21]:
# This example is implemented in PyTorch on top of the open-source Hugging Face Transformers API.
# See requirements.txt for the version of each package.
# !pip install -r requirements.txt

# PyTorch
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Hugging Face Transformers: import only the classes this notebook uses
from transformers import BertConfig, BertModel, BertTokenizer

# Data processing
import numpy as np
import pandas as pd

# Loading and using the DRCD JSON files
import json

# Timestamp used in saved model file names
import datetime

# Traditional/Simplified Chinese conversion
from zhconv import convert

# Progress bar (optional)
# import tqdm

# Model hyperparameters
import argparse


# Raw data loading class

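- One paragraph, one question and one answer make up one training example
- When a question has several answers, the first is the human-annotated answer and the rest are answers written by human answerers
- This example uses only the human-annotated answer
- (See the introduction above for the JSON structure of DRCD)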

In [22]:
class DRCDRawData:
    """Load DRCD JSON files into [context, question, answer_text] triples."""
    def __init__(self, train_path=None, test_path=None, dev_path=None):
        if train_path is not None:
            self.train = self.load_data(train_path)
        if test_path is not None:
            self.test = self.load_data(test_path)
        if dev_path is not None:
            self.dev = self.load_data(dev_path)

    def load_data(self, path):
        # Use a context manager so the file handle is closed after loading
        with open(path, encoding='utf-8') as f:
            dataset = json.load(f)
        result = []
        for article in dataset['data']:
            for paragraph in article['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    question = qa['question']
                    # Keep only the first (human-annotated) answer
                    if qa['answers']:
                        result.append([context, question, qa['answers'][0]['text']])
        return result

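# Custom DRCD Dataset class

- Define a custom DRCD dataset that bundles the required preprocessing steps and output data types
- Subclass torch.utils.data.Dataset and define what the constructor ("__init__") does
- A map-style Dataset is indexed by integer and must override the "__getitem__" and "__len__" functions
- Preprocessing converts strings into tokens and then into ids, and prepares the answer's start_index and end_index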

In [23]:

class DRCD(Dataset):
    '''
    - Initialize a BertTokenizer; the tokenizer prepares the inputs for the model.
    - device is either the CPU or a GPU.
    - lang: 'zh-cn' converts every example to Simplified Chinese; other values keep the text as-is.
    '''
    def __init__(self, data, model_type, device, lang):
        self.data = data
        self.tokenizer = BertTokenizer.from_pretrained(model_type)
        self.device = device
        self.lang = lang

    def __getitem__(self, idx):
        '''
        - Split the strings into word / sub-word tokens and convert the tokens to ids.
        - Add the special tokens ([CLS], [SEP]) and pad to 512 tokens.
        - Build the start/end label vectors for the answer.
        '''
        paragraph, question, ans = self.data[idx][0], self.data[idx][1], self.data[idx][2]

        if self.lang == 'zh-cn':
            paragraph, question, ans = convert(paragraph, 'zh-cn'), convert(question, 'zh-cn'), convert(ans, 'zh-cn')

        # Tokenize and prepare the (question, paragraph) pair for the model.
        # Unlike `encode`, which returns only ids, calling the tokenizer also returns
        # token_type_ids and attention_mask. Only the paragraph is truncated.
        token_tensor = self.tokenizer(question, paragraph, max_length=512,
                                      truncation='only_second', padding='max_length')

        # Encode the answer without [CLS]/[SEP] so it can be matched inside the pair
        ans_encode = self.tokenizer.encode(ans, add_special_tokens=False)
        s_tensor, e_tensor = self.find_ans_index(token_tensor, ans_encode)

        return {'input_ids': torch.tensor(token_tensor['input_ids']).to(self.device),
                'token_type_ids': torch.tensor(token_tensor['token_type_ids']).to(self.device),
                'attention_mask': torch.tensor(token_tensor['attention_mask']).to(self.device),
                's_tensor': s_tensor.to(self.device),
                'e_tensor': e_tensor.to(self.device)}

    def __len__(self):
        return len(self.data)

    def find_ans_index(self, token_tensor, ans_tokens):
        input_ids = token_tensor['input_ids']
        token_type_ids = token_tensor['token_type_ids']
        start_idx, end_idx = [0] * len(input_ids), [0] * len(input_ids)

        if ans_tokens:
            # Candidate starts: paragraph tokens (segment 1) equal to the answer's first token
            match_ans_prefix_list = [i for i, x in enumerate(input_ids)
                                     if x == ans_tokens[0] and token_type_ids[i] == 1]

            # Label the first position where the whole answer matches
            for start_pos in match_ans_prefix_list:
                end_pos = start_pos + len(ans_tokens)
                if input_ids[start_pos:end_pos] == ans_tokens:
                    start_idx[start_pos] = 1
                    end_idx[end_pos - 1] = 1
                    break

        # If the answer was empty or truncated away, both label vectors stay all-zero
        return torch.tensor(start_idx, dtype=torch.float), torch.tensor(end_idx, dtype=torch.float)
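
`find_ans_index` labels the first place where the answer's token ids occur in the paragraph. A repeated answer string can therefore be labelled at the wrong position, because `DRCDRawData` discards DRCD's `answer_start`. A common alternative, used in Hugging Face's question-answering examples, maps the character span to tokens through the offset mapping that a fast tokenizer returns (for example `BertTokenizerFast` called with `return_offsets_mapping=True`). The sketch below assumes you already have those offsets and the per-token sequence ids as plain lists. `char_span_to_token_span` is a hypothetical helper, and the toy offsets imitate one-token-per-character Chinese tokenization. The returned indices are where the one-hot start/end labels would go.

```python
def char_span_to_token_span(offsets, sequence_ids, answer_start, answer_text):
    """Map a character span of the context to an inclusive (start, end) token span.

    offsets:      (char_start, char_end) per token; special tokens use (0, 0)
    sequence_ids: segment per token: 0 = question, 1 = context, None = special/padding
    Returns None when the answer is not fully inside the (possibly truncated) context.
    """
    answer_end = answer_start + len(answer_text)  # exclusive character end
    start_token = end_token = None
    for i, (char_start, char_end) in enumerate(offsets):
        if sequence_ids[i] != 1:
            continue  # skip question, [CLS], [SEP] and padding
        if start_token is None and char_start <= answer_start < char_end:
            start_token = i
        if char_start < answer_end <= char_end:
            end_token = i
    if start_token is None or end_token is None:
        return None
    return start_token, end_token


# Toy layout: [CLS] 幾 點 [SEP] + one token per context character + [SEP]
context = "小明今天早上八點出門。"
offsets = [(0, 0), (0, 1), (1, 2), (0, 0)] + [(k, k + 1) for k in range(len(context))] + [(0, 0)]
sequence_ids = [None, 0, 0, None] + [1] * len(context) + [None]
print(char_span_to_token_span(offsets, sequence_ids, 6, "八點"))  # (10, 11)
```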
    

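# Machine Reading Comprehension Model
- Outputs one start score and one end score (logit) per token; the train/test functions apply a softmax over the sequence to get the start and end index distributions
- The answer can only lie in the paragraph, so all other positions (question, [CLS], padding) are masked out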

In [24]:
class BertForReadingComprehension(nn.Module):
    def __init__(self, model_type):
        super().__init__()

        config = BertConfig.from_pretrained(model_type, output_hidden_states=True)
        self.bert_model = BertModel.from_pretrained(model_type, config=config)

        # One linear head each for the start and end score of every token
        # (kept as nn.Sequential so saved state_dict keys stay the same)
        self.s_decoder = nn.Sequential(nn.Linear(config.hidden_size, 1))
        self.e_decoder = nn.Sequential(nn.Linear(config.hidden_size, 1))

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # Last-layer hidden states: (batch, seq_len, hidden_size)
        hidden = self.bert_model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        # squeeze(-1) keeps the batch dimension even when batch_size == 1
        s_decode = self.s_decoder(hidden).squeeze(-1)
        e_decode = self.e_decoder(hidden).squeeze(-1)

        # Mask the question part, because the answer only appears in the paragraph:
        # add 0 to paragraph positions (token_type_ids == 1) and -inf everywhere else.
        mask = torch.zeros_like(s_decode).masked_fill(token_type_ids != 1, float('-inf'))

        return s_decode + mask, e_decode + mask
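
Adding -inf works as a mask because softmax maps -inf to a probability of exactly zero. Every position outside the paragraph therefore drops out, while the paragraph positions still sum to one. Below is a pure-Python sketch of the same idea. The scores and segment ids are made up, and `masked_softmax` is a hypothetical helper, not part of the notebook.

```python
import math


def masked_softmax(scores, token_type_ids):
    """Softmax over scores where positions with token_type_id != 1 get -inf."""
    masked = [s if t == 1 else float("-inf") for s, t in zip(scores, token_type_ids)]
    m = max(masked)  # assumes at least one paragraph position
    exps = [math.exp(x - m) for x in masked]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# [CLS] and a question token (segment 0) have the largest raw scores but are masked out
probs = masked_softmax([5.0, 4.0, 1.0, 0.0], [0, 0, 1, 1])
print([round(p, 3) for p in probs])  # [0.0, 0.0, 0.731, 0.269]
```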


# Test Function

In [25]:
def test(model, data, args, source):
    if source == 'test':
        dataset = DRCD(data.test, args.model_type, args.device, args.language)
        print(f'Testing...{len(dataset)}')
    elif source == 'dev':
        dataset = DRCD(data.dev, args.model_type, args.device, args.language)
        print(f'Dev...{len(dataset)}')
    else:
        raise ValueError(f"source must be 'test' or 'dev', got {source!r}")
    # No need to shuffle during evaluation
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # eval() disables dropout; torch.no_grad() below stops gradient tracking
    model.eval()
    loss, i = 0.0, 0
    criterion = nn.BCELoss()
    with torch.no_grad():
        for batch in loader:
            i += 1
            inp_ids, tok_id, att_m = batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
            s_label, e_label = batch['s_tensor'], batch['e_tensor']

            s_output, e_output = model(input_ids=inp_ids,
                                       attention_mask=att_m,
                                       token_type_ids=tok_id)

            batch_loss = (criterion(torch.softmax(s_output, dim=-1), s_label)
                          + criterion(torch.softmax(e_output, dim=-1), e_label)) / 2

            loss += batch_loss.item()
        print(f'Result: LOSS: {loss} AVG: {loss / i}')
    return loss

# Train Function

In [26]:
import os

def train(data, args):
    dataset = DRCD(data.train, args.model_type, args.device, args.language)
    trainLoader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    model = BertForReadingComprehension(args.model_type).to(args.device)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
    criterion = nn.BCELoss()
    # Make sure the checkpoint directory exists before torch.save
    os.makedirs('model', exist_ok=True)
    min_loss, best_path = float('inf'), None
    for ei in range(args.epoch):
        model.train()
        epoch_loss, i, check_loss = 0.0, 0, 0.0
        for batch in trainLoader:
            inp_ids, tok_id, att_m = batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
            s_label, e_label = batch['s_tensor'], batch['e_tensor']

            s_output, e_output = model(input_ids=inp_ids,
                                       attention_mask=att_m,
                                       token_type_ids=tok_id)

            batch_loss = (criterion(torch.softmax(s_output, dim=-1), s_label)
                          + criterion(torch.softmax(e_output, dim=-1), e_label)) / 2

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # .item() detaches the value so the autograd graph is not kept alive
            epoch_loss += batch_loss.item()
            check_loss += batch_loss.item()
            i += 1
            if i % 1000 == 0:
                print(f'1000 batch LOSS {check_loss}')
                check_loss = 0.0

        # Evaluate on the dev set and keep the checkpoint with the lowest dev loss
        dev_loss = test(model, data, args, 'dev')
        if dev_loss < min_loss:
            model_name = 'BertForReadingComprehension_' + datetime.datetime.now().strftime('%m%d_%H%M') + '.pt'
            print('dev_loss:', dev_loss, ', min_loss:', min_loss)
            print('save current best model:', model_name)

            min_loss = dev_loss
            best_path = f'model/{model_name}'
            torch.save(model.state_dict(), best_path)

        print(f'===Epoch: {ei} Loss {epoch_loss}===')
    print('---Training Finished---')
    # Reload the best checkpoint so the returned model matches the lowest dev loss
    if best_path is not None:
        model.load_state_dict(torch.load(best_path, map_location=args.device))
    return model
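
Both `train` and `test` score every position with `nn.BCELoss` on softmax probabilities against one-hot labels. For comparison, Hugging Face's `BertForQuestionAnswering` treats the start and end positions as classes over the sequence and averages two cross-entropy losses. Below is a dependency-free sketch of that formulation. It is swapped in here for illustration only and is not what this notebook trains with. `log_softmax` and `span_cross_entropy` are hypothetical helper names.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax; -inf entries stay -inf."""
    m = max(logits)  # assumes at least one finite logit
    log_total = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_total for x in logits]


def span_cross_entropy(start_logits, end_logits, start_pos, end_pos):
    """Average of the start and end negative log-likelihoods (positions are indices)."""
    start_loss = -log_softmax(start_logits)[start_pos]
    end_loss = -log_softmax(end_logits)[end_pos]
    return (start_loss + end_loss) / 2


# Uniform scores over 4 positions: each loss is log(4)
print(round(span_cross_entropy([0.0] * 4, [0.0] * 4, 1, 2), 3))  # 1.386
```

In PyTorch the equivalent is `nn.CrossEntropyLoss()(s_output, start_positions)` with integer position targets. Examples whose answer was truncated away would need to be dropped, because a target at a -inf masked position gives an infinite loss.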

# Parameters

In [90]:
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=8, type=int)
parser.add_argument('--epoch', default=10, type=int)
parser.add_argument('--learning-rate', default=1e-5, type=float)
parser.add_argument('--weight-decay', default=0.001, type=float)
# Other options: 'hfl/chinese-bert-wwm', 'hfl/chinese-roberta-wwm-ext-large'
parser.add_argument('--model-type', default='hfl/chinese-roberta-wwm-ext', type=str)
# type=torch.device lets e.g. `--device cpu` or `--device cuda:1` be parsed from the command line
parser.add_argument('--device', default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), type=torch.device)
# 'zh-cn' converts the data to Simplified Chinese; 'zh-tw' keeps it in Traditional Chinese
parser.add_argument('--language', default='zh-tw', type=str)
# Parse an explicit (empty) list so the cell also runs inside Jupyter, which passes its own argv
args = parser.parse_args([])
args

Namespace(batch_size=8, device=device(type='cuda', index=0), epoch=10, language='zh-tw', learning_rate=1e-05, model_type='hfl/chinese-roberta-wwm-ext', weight_decay=0.001)
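
argparse turns each dashed option into an underscored attribute, so `--batch-size` becomes `args.batch_size`. Because the cell calls `parse_args` with an explicit list, you can override defaults from inside the notebook by putting flags in that list. Below is a standalone sketch with two of the options above.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=8, type=int)
parser.add_argument('--learning-rate', default=1e-5, type=float)

# Override one default; the other keeps its value
args = parser.parse_args(['--batch-size', '16'])
print(args.batch_size, args.learning_rate)  # 16 1e-05
```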

# DRCD Raw Data Loading

In [28]:
data = DRCDRawData(train_path='./dataset/DRCD_train.json', test_path='./dataset/DRCD_test.json', dev_path='./dataset/DRCD_dev.json')

# Training/Testing Process

In [42]:
# mode: 'train', 'test' or 'dev'
mode = 'train'
if mode == 'train':
    best_model = train(data, args)
elif mode in ('test', 'dev'):
    model_name = 'bertDRCD_0808_1213.pt'
    test_model = BertForReadingComprehension(args.model_type).to(args.device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    test_model.load_state_dict(torch.load(f'model/{model_name}', map_location=args.device))
    test(test_model, data, args, mode)

# Inference Model

In [77]:
class InferenceModel():
    def __init__(self, model_path, model_type, lang):
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.tokenizer = BertTokenizer.from_pretrained(model_type)
        self.model = BertForReadingComprehension(model_type).to(self.device)
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        # Disable dropout so predictions are deterministic
        self.model.eval()
        self.lang = lang

    def inference(self, content, question):
        content, question = self.clean_str(content, question)
        with torch.no_grad():
            token_tensor = self.tokenizer(str(question), str(content), max_length=512,
                                          truncation='only_second', padding='max_length')
            token = torch.tensor(token_tensor['input_ids']).unsqueeze(0).to(self.device)
            segment = torch.tensor(token_tensor['token_type_ids']).unsqueeze(0).to(self.device)
            mask = torch.tensor(token_tensor['attention_mask']).unsqueeze(0).to(self.device)
            answer_start, answer_end = self.model(input_ids=token, attention_mask=mask, token_type_ids=segment)
            # Pick the highest-scoring start and end positions independently
            answer_start, answer_end = answer_start.argmax(-1).item(), answer_end.argmax(-1).item()
            if answer_start > answer_end:
                return 'invalid span'
            tokens = self.tokenizer.convert_ids_to_tokens(token_tensor['input_ids'])
            # Join the tokens and drop WordPiece continuation markers ('##')
            answer = ''.join(t.replace('##', '') for t in tokens[answer_start:answer_end + 1])
            if self.lang == 'zh-tw':
                answer = convert(answer, 'zh-tw')
        return answer

    def clean_str(self, content, question):
        content, question = content.replace(' ', ''), question.replace(' ', '')
        if self.lang == 'zh-cn':
            content, question = convert(content, 'zh-cn'), convert(question, 'zh-cn')
        return content, question
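
`InferenceModel` takes the start and end argmax independently and returns 'invalid span' when the start lands after the end. A common refinement, similar in spirit to the span search in BERT's SQuAD post-processing, picks the pair with start <= end that has the highest combined score, optionally capping the answer length. The sketch below is a hypothetical helper over plain score lists. `max_answer_len=30` is an assumed value that this notebook does not set.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end) maximizing start_scores[start] + end_scores[end]
    subject to start <= end < start + max_answer_len (end is inclusive)."""
    best, best_score = None, float("-inf")
    n = len(start_scores)
    for start in range(n):
        for end in range(start, min(n, start + max_answer_len)):
            score = start_scores[start] + end_scores[end]
            if score > best_score:
                best, best_score = (start, end), score
    return best


# Independent argmaxes would give start=3, end=1 (an invalid span);
# the joint search returns the best valid pair instead.
start_scores = [0.0, 1.0, 0.0, 3.0, 0.0]
end_scores = [0.0, 4.0, 0.0, 0.0, 2.5]
print(best_span(start_scores, end_scores))  # (3, 4)
```

Inside `InferenceModel`, the scores would come from the model outputs via `.squeeze(0).tolist()`. Positions masked with -inf can never win, so the answer stays inside the paragraph.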

In [78]:
model_name = 'bertDRCD_1023_1619.pkl'
drcd_model = InferenceModel(f'model/{model_name}','hfl/chinese-roberta-wwm-ext','tw')

In [79]:
paragraph = "中央流行疫情指揮中心今公布國內新增1例境外移入COVID-19病例(案579)，\
為20多歲烏克蘭籍女子，11月3日自烏克蘭經土耳其入境臺灣。中央流行疫情指揮中心醫療應變組副組長羅一鈞表示，\
案579於11月3日來台工作，搭機前3日內檢驗陰性，入境時無上呼吸道症狀，\
入境後至防疫旅館進行居家檢疫，11月6日居家檢疫期間主動通報有鼻塞、嗅味覺異常情形，\
由衛生單位採檢送驗，於今日確診，目前住院隔離中。羅一鈞表示，衛生單位已掌握個案接觸者共54人，\
同行者8人皆持搭機前3日內檢驗陰性報告，目前無不適症狀，列為居家隔離；同班機前後2排座位旅客共10人（扣除3名同行者）列為居家隔離。\
羅一鈞表示，案579在台公司接觸者及旅館工作人員共7人，因僅配戴口罩，列為居家隔離（居家隔離人數共25人）。\
醫護人員17名均著適當防護裝備，列自主健康管理，班機機組人員皆為外籍人士，共12人，均已離境。指揮中心統計，\
截至目前國內累計104,017例新型冠狀病毒肺炎相關通報(含102,449例排除)，其中578例確診，分別為486例境外移入，55例本⼟病例，\
36例敦睦艦隊及1例不明； 另1例(案530)移除為空號。確診個案中7人死亡、526人解除隔離、45人住院隔離中。"
questions = ['案579何時來台工作', '多人少住院隔離中', '累積通報']
for q in questions:
    print(q + ':  '+ drcd_model.inference(paragraph, q).replace('##', ''))

案579何時來台工作:  11月3日
多人少住院隔離中:  45人
累積通報:  104,017例


# RESTful API

In [83]:
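from flask import Flask, request, jsonify
from flask_cors import CORS

app = Flask(__name__)
# Allow cross-origin requests, e.g. from a browser front end on another host
CORS(app)

@app.route("/predict", methods=['POST'])
def predict():
    # Expected JSON body: {"paragraph": "...", "question": ["...", "..."]}
    body = request.get_json(silent=True) or {}
    paragraph = body.get("paragraph")
    question = body.get("question")
    if not paragraph or not question:
        return jsonify({"result": "paragraph and question are required"}), 400
    # Accept a single question string as well as a list of questions
    if isinstance(question, str):
        question = [question]
    try:
        result = {'qa': {}}
        for q in question:
            if q:
                result['qa'][q] = drcd_model.inference(paragraph, q).replace('##', '')
        return jsonify(result)
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"}), 500

if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)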

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://0.0.0.0:8000/ (Press CTRL+C to quit)
10.137.33.108 - - [17/Nov/2020 16:58:14] "POST /predict HTTP/1.1" 200 -
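
Below is a minimal client sketch for the endpoint above that uses only the Python standard library. It assumes the server is reachable at http://localhost:8000 (the notebook binds 0.0.0.0:8000). `build_predict_request` and `ask` are hypothetical helper names. The JSON body mirrors what `predict()` reads: a paragraph string and a list of questions.

```python
import json
import urllib.request


def build_predict_request(paragraph, questions, url="http://localhost:8000/predict"):
    """Build (but do not send) a POST request matching the /predict JSON body."""
    payload = json.dumps({"paragraph": paragraph, "question": questions},
                         ensure_ascii=False).encode("utf-8")
    return urllib.request.Request(url, data=payload,
                                  headers={"Content-Type": "application/json"},
                                  method="POST")


def ask(paragraph, questions, url="http://localhost:8000/predict", timeout=30):
    """Send the request and return the decoded JSON response, e.g. {"qa": {...}}."""
    req = build_predict_request(paragraph, questions, url)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# With the server running: ask(paragraph, ['案579何時來台工作'])
```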


<div>
<img style="float: left;" src="attachment:image.png" width="750"/>
</div>