In [None]:
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
import os

# metrics

In [None]:
import re
import collections
import string
 
def normalize_answer(s):
    """Canonicalise an answer string before comparing it with another.

    The steps run in this order: lowercase the text, drop every character
    found in ``string.punctuation``, replace the standalone words
    "a", "an" and "the" with a space, and finally collapse every run of
    whitespace into a single space (also trimming both ends).
    """
    punctuation = set(string.punctuation)
    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    lowered = s.lower()
    without_punct = "".join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = article_pattern.sub(" ", without_punct)
    return " ".join(without_articles.split())
 
def get_tokens(s):
    """Return the whitespace tokens of the normalised answer; falsy input gives []."""
    return normalize_answer(s).split() if s else []
 
def compute_exact(a_gold, a_pred):
    """Exact-match score: 1 if both answers are equal after normalisation, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0
 
def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer.

    Both strings are tokenised with ``get_tokens``. When either side has no
    tokens the score is 1 if both are empty and 0 otherwise. Otherwise the
    score is the harmonic mean of precision and recall over the multiset of
    shared tokens (0 when nothing is shared).
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)

    # No-answer case: agreement is all-or-nothing.
    if not gold_toks or not pred_toks:
        return int(gold_toks == pred_toks)

    # Multiset intersection counts each shared token at most as often as it
    # appears on both sides.
    overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)

# utilities

In [None]:
def format_time(elapsed):
    '''
    Format a duration given in seconds as an ``h:mm:ss`` string.

    The value is rounded to the nearest whole second and rendered with
    ``str(datetime.timedelta(...))``, so hours are not zero-padded and
    durations of a day or more look like ``'1 day, 1:00:00'``.
    '''
    # Imported here so the helper works even when the later cell that does
    # ``import datetime`` has not been executed yet.
    import datetime

    # Round to the nearest second.
    elapsed_rounded = int(round(elapsed))

    # Render as h:mm:ss.
    return str(datetime.timedelta(seconds=elapsed_rounded))

def exactly_tokens(tokens):
    """Glue word-piece tokens back into a single string.

    A token starting with '##' is appended directly to the previous text
    (without its '##' prefix); any other token is preceded by one space.
    The first character of the joined result is then dropped.

    Example: ['a', 'b', '##ccd'] -> 'a bccd'
    """
    pieces = []
    for tok in tokens:
        is_continuation = tok[:2] == '##'
        pieces.append(tok[2:] if is_continuation else ' ' + tok)
    return ''.join(pieces)[1:]

def get_predAns(input_vector,start,end):
    """Turn predicted start/end indices into answer strings, one per row.

    For row ``i`` of ``input_vector`` the ids from ``start[i]`` through
    ``end[i]`` (inclusive) are converted to tokens with the module-level
    ``tokenizer`` and merged with ``exactly_tokens``.

    Returns a list ``[ans1, ans2, ..., ansn]`` in row order.
    """
    def decode_span(row_idx, ids):
        # +1 because ``end`` is an inclusive index.
        span_ids = ids[start[row_idx]:end[row_idx]+1]
        return exactly_tokens(tokenizer.convert_ids_to_tokens(span_ids))

    return [decode_span(i, vector) for i, vector in enumerate(input_vector)]

def get_metrics(predAns,ans):
    """Sum F1 and exact-match scores over paired predictions and gold answers.

    ``predAns`` and ``ans`` must have the same length (asserted). The return
    value is ``(f1_sum, em_sum)`` -- totals, not averages.
    """
    assert len(predAns)==len(ans)
    f1_total = 0
    em_total = 0
    for gold, pred in zip(ans, predAns):
        f1_total += compute_f1(gold, pred)
        em_total += compute_exact(gold, pred)
    return f1_total, em_total

# process raw data function

In [None]:
def get_paragraph_QA(data):
    """Split raw dataset lines into context snippets and a question/answer pair.

    Each line is expected to look like
    ``<s> snippet </s> <s> snippet </s> ... ||| question ||| answer``.

    Args:
        data: iterable of strings, one example per line.

    Returns:
        ``(paragraph, QA)`` where ``paragraph[i]`` is the list of snippet
        strings of line ``i`` (tags and surrounding whitespace removed) and
        ``QA[i]`` is ``[question, answer]``.

    Raises:
        ValueError: if, after the snippets are consumed, a line does not
            contain two ``'|||'`` separators.
    """
    end_tag = '</s>'
    separator = '|||'
    paragraph = []
    QA = []
    for line in data:
        snippets = []
        # Peel off one "<s> ... </s>" snippet at a time. ``find`` returns -1
        # when no closing tag is left, so no exception is needed for loop
        # control and unrelated errors are no longer silently swallowed.
        idx = line.find(end_tag)
        while idx != -1:
            snippet = line[:idx + len(end_tag)]
            line = line[len(snippet):]
            # [3:-4] removes the leading '<s>' and the trailing '</s>'.
            snippets.append(snippet.strip()[3:-4].strip())
            idx = line.find(end_tag)
        # What remains is "||| question ||| answer".
        idx_1 = line.index(separator)    # first '|||' index
        idx_2 = line.index(separator, idx_1 + len(separator))    # second '|||' index
        Q = line[idx_1 + len(separator):idx_2].strip()
        A = line[idx_2 + len(separator):].strip()
        QA.append([Q,A])
        paragraph.append(snippets)
    assert len(paragraph)==len(QA)
    return paragraph,QA

# get model input data function

In [None]:
def start_end_pos(a,b):
    """Locate the first occurrence of sequence ``b`` inside sequence ``a``.

    Args:
        a: sequence to search in (e.g. the token ids of the model input).
        b: sequence to search for (e.g. the token ids of the answer).

    Returns:
        ``(start, end)`` with ``end`` inclusive, i.e. ``a[start:end+1]``
        equals ``b`` element-wise. ``(0, 0)`` is returned when ``b`` does not
        occur in ``a``, or when either sequence is empty. Note that a
        one-element match at index 0 also yields ``(0, 0)``.
    """
    len_a, len_b = len(a), len(b)
    if len_b == 0 or len_b > len_a:
        return 0, 0
    # Only start positions that leave room for all of ``b`` are tried, so the
    # comparison never reads past the end of ``a``.
    for i in range(len_a - len_b + 1):
        if all(a[i + j] == b[j] for j in range(len_b)):
            return i, i + len_b - 1
    return 0, 0

In [None]:
# Every model input is truncated / padded to exactly this many tokens.
bert_input_len = 384
def get_data(mode,QA,paragraph,tokenizer):
    """Build fixed-length BERT inputs (and span labels) for every example.

    For example ``i`` the input is
    ``[CLS] question [SEP] snippet. snippet. ... [SEP] [PAD]...``:
    the first snippet of ``paragraph[i]`` is always appended, later snippets
    are appended only while the sequence still leaves room for the final
    ``[SEP]``. The sequence is then cut to ``bert_input_len - 1`` tokens,
    closed with ``[SEP]`` and padded to ``bert_input_len``.

    Args:
        mode: ``'test'`` builds inputs only; any other value also builds
            answer-span labels.
        QA: list of ``[question, answer]`` pairs; ``QA[i][1]`` is read only
            when ``mode != 'test'``.
        paragraph: list (parallel to ``QA``) of lists of snippet strings.
        tokenizer: object providing ``encode``, ``tokenize``,
            ``convert_tokens_to_ids``, ``sep_token_id`` and ``pad_token_id``
            (e.g. the ``BertTokenizer`` loaded in this notebook).

    Returns:
        ``mode == 'test'``: ``(input_vector, att_mask, segment_ids)``.
        otherwise: ``(input_vector, att_mask, segment_ids, start_pos,
        end_pos, ans)``.
        ``input_vector`` / ``segment_ids`` are lists of long tensors and
        ``att_mask`` a list of float tensors, each of length
        ``bert_input_len``; ``start_pos`` / ``end_pos`` are 1-D long tensors;
        ``ans`` is a list of gold answer strings, where ``'[CLS]'`` marks an
        answer that was not found in the (possibly truncated) input.
    """
    sep_id = tokenizer.sep_token_id
    # Fall back to 0 (the value previously hard-coded) if no pad id is set.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    input_vector = []
    att_mask = []
    segment_ids = []
    start_pos=[]
    end_pos=[]
    ans = []
    for i,doc in enumerate(paragraph):
        # encode(..., add_special_tokens=True) yields [CLS] question [SEP].
        # Built unconditionally so an example without snippets still gets a
        # valid input instead of reusing the previous example's ids.
        input_id = tokenizer.encode(QA[i][0], add_special_tokens=True)
        for count,snippet in enumerate(doc):
            snippet_ids = tokenizer.encode(snippet+'.', add_special_tokens=False)
            # Later snippets are dropped as soon as one no longer fits.
            if count>0 and len(input_id)+len(snippet_ids)>(bert_input_len-1):
                break
            input_id += snippet_ids

        # The first snippet is appended without a length check, so truncate
        # before closing the sequence; this keeps every sample at exactly
        # bert_input_len tokens.
        input_id = input_id[:bert_input_len-1] + [sep_id]
        num_pad = bert_input_len-len(input_id)
        input_id += [pad_id]*num_pad
        input_vector.append(torch.tensor(input_id, dtype=torch.long))

        # attention mask: 1 for real tokens, 0 for padding
        mask = [1]*(bert_input_len-num_pad)+[0]*num_pad
        att_mask.append(torch.tensor(mask, dtype=torch.float))

        # segment ids: 0 up to and including the first [SEP], 1 afterwards
        # (padding positions included)
        sep_index = input_id.index(sep_id)
        num_seg_a = sep_index + 1
        num_seg_b = len(input_id) - num_seg_a
        temp = [0]*num_seg_a + [1]*num_seg_b
        segment_ids.append(torch.tensor(temp, dtype=torch.long))

        # start / end position of the answer inside the input ids
        if mode !='test':
            ans_token = tokenizer.tokenize(QA[i][1])
            answer_ids = tokenizer.convert_tokens_to_ids(ans_token)
            start,end = start_end_pos(input_id,answer_ids)
            if start==0 and end==0:
                # Answer not found: the label points at position 0 ([CLS]).
                # Pass a one-element list; passing the bare string would be
                # iterated character by character and yield '[ C L S ]'.
                ans.append(exactly_tokens(['[CLS]']))
            else:
                ans.append(exactly_tokens(ans_token))
            start_pos.append(start)
            end_pos.append(end)
    start_pos = torch.tensor(start_pos, dtype=torch.long)
    end_pos = torch.tensor(end_pos, dtype=torch.long)

    print('done!!!')
    if mode != 'test':
        assert len(input_vector) == len(segment_ids) == len(att_mask) == len(start_pos) == len(end_pos)
        return input_vector,att_mask,segment_ids,start_pos,end_pos,ans
    else:
        assert len(input_vector) == len(segment_ids) == len(att_mask)
        return input_vector,att_mask,segment_ids
# input_vector:[tensor([1,2,3,...,55,6]),...,tensor([1,26,33,...,54,22])]
# att_mask:[tensor([1,1,1,...,0,0],...,tensor([1,1,1,...,1,0])]
# start_pos:tensor([33,54,87,1,2,...,333])
# end_pos:tensor([33,54,87,1,2,...,333])
# segment_ids:[tensor([0,0,0,...,1,1,1,1],...,tensor([0,0,0,...,1,1,1,1])]
# ans: [ans1,ans2,...,ansn]

# Pretrained model and output layer

In [None]:
# The encoder weights and the tokenizer vocabulary are loaded from the same
# pretrained checkpoint so that token ids and embeddings agree.
PRETRAINED_CHECKPOINT = 'bert-large-uncased-whole-word-masking-finetuned-squad'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)
bert = BertModel.from_pretrained(PRETRAINED_CHECKPOINT)

class QAModelOutput:
    """Plain container for the values returned by ``CustomeModel.forward``.

    Attributes:
        loss: whatever loss value the caller passes in (``None`` is allowed).
        start_logits: per-position scores for the answer start.
        end_logits: per-position scores for the answer end.
    """

    def __init__(self,loss,start_logits,end_logits):
        self.loss, self.start_logits, self.end_logits = loss, start_logits, end_logits

class CustomeModel(nn.Module):
    """Encoder plus a linear span-prediction head for extractive QA.

    The wrapped encoder produces one hidden vector per token; a single
    ``nn.Linear`` maps each vector to two scores (answer start, answer end).
    """

    def __init__(self,bert_model):
        """
        Args:
            bert_model: encoder module whose output exposes
                ``last_hidden_state``. If it has ``config.hidden_size`` the
                head is sized from it; otherwise 1024 is used.
        """
        super().__init__()
        self.bert = bert_model
        # Size the head from the encoder instead of hard-coding 1024, so the
        # same class also works with encoders of another width.
        encoder_config = getattr(bert_model, 'config', None)
        hidden_size = getattr(encoder_config, 'hidden_size', 1024)
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        output_hidden_states=None,
        start_positions=None,
        end_positions=None,
    ):
        """Run the encoder and the span head.

        All encoder-related arguments are forwarded to ``self.bert`` as-is.

        Returns:
            QAModelOutput with ``start_logits`` / ``end_logits`` (the last
            dimension of the head output split in two and squeezed) and
            ``loss``: the mean of the start and end cross-entropy losses when
            both ``start_positions`` and ``end_positions`` are given,
            otherwise ``None``.
        """
        outputs = self.bert(input_ids=input_ids,token_type_ids=token_type_ids,attention_mask=attention_mask,output_hidden_states=output_hidden_states)
        seq_out = outputs.last_hidden_state
        logits = self.qa_outputs(seq_out)
        # Split the two output channels into separate start / end score maps.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Each position is treated as a class; the label is the index of
            # the gold start (resp. end) token.
            loss_fct = nn.CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QAModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits
        )

# Expose only GPU index 2 to CUDA, then use it when CUDA is usable and fall
# back to the CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = CustomeModel(bert).to(device)

# Dataset class

In [None]:
# 'train' and 'val' items include labels (start, end, answer text); any other mode returns inputs only
from torch.utils.data import Dataset
class QA_Dataset(Dataset):
    '''Index-aligned view over the pre-built model inputs and optional labels.

    In 'train' and 'val' mode an item is
    (x_vector, att_mask, segment_ids, start, end, ans); in any other mode it
    is (x_vector, att_mask, segment_ids).
    '''

    # Modes whose items carry labels in addition to the model inputs.
    _LABELLED_MODES = ('train', 'val')

    def __init__(self, mode,x_vectors,att_mask,segment_ids,start=None, end=None,ans=None):
        self.mode = mode
        # model inputs
        self.x_vectors, self.att_mask, self.segment_ids = x_vectors, att_mask, segment_ids
        # labels (left as None when not supplied)
        self.start, self.end, self.ans = start, end, ans

    def __len__(self):
        return len(self.x_vectors)

    def __getitem__(self, idx):
        inputs = (self.x_vectors[idx], self.att_mask[idx], self.segment_ids[idx])
        if self.mode not in self._LABELLED_MODES:
            return inputs
        return inputs + (self.start[idx], self.end[idx], self.ans[idx])


# get train,val data

In [None]:
# Load the raw training and validation lines (one example per line, line
# endings kept).
with open('./train.txt','r') as f1:
    data1 = list(f1)
with open('./val.txt','r') as f2:
    data2 = list(f2)

In [None]:
from torch.utils.data import DataLoader
# train data: parse the raw lines, build model inputs and labels, then wrap
# them in a shuffled DataLoader.
train_paragraph, train_QA = get_paragraph_QA(data1)
(
    train_vector,
    train_att_mask,
    train_segment_ids,
    train_start_pos,
    train_end_pos,
    train_ans,
) = get_data('train', train_QA, train_paragraph, tokenizer)

batch_size = 3
train_dataset = QA_Dataset(
    mode='train',
    x_vectors=train_vector,
    att_mask=train_att_mask,
    segment_ids=train_segment_ids,
    start=train_start_pos,
    end=train_end_pos,
    ans=train_ans,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
)


In [None]:
# val data: same pipeline as for training, but evaluated one example at a
# time and in file order.
val_paragraph, val_QA = get_paragraph_QA(data2)
(
    val_vector,
    val_att_mask,
    val_segment_ids,
    val_start_pos,
    val_end_pos,
    val_ans,
) = get_data('val', val_QA, val_paragraph, tokenizer)

val_dataset = QA_Dataset(
    mode='val',
    x_vectors=val_vector,
    att_mask=val_att_mask,
    segment_ids=val_segment_ids,
    start=val_start_pos,
    end=val_end_pos,
    ans=val_ans,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

# Training

In [None]:
from transformers import AdamW
import os
# Number of training epochs. The original note here says the BERT authors
# suggest a value between 2 and 4 and that larger values tend to overfit;
# this notebook uses 5.
epochs = 5

# Total number of optimizer steps: batches per epoch times epochs.
total_steps = epochs * len(train_loader)

optimizer = AdamW(
    model.parameters(),
    lr=2e-5,   # args.learning_rate - default is 5e-5
    eps=1e-8,  # args.adam_epsilon  - default is 1e-8
)
from transformers import get_linear_schedule_with_warmup

# Learning-rate schedule with no warm-up steps, spanning total_steps updates.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [None]:
import random
import numpy as np
import time
import datetime

# Seeding is currently disabled, so results may differ between runs. To make
# runs repeatable, uncomment the following lines:
# seed_val = 42
#
# random.seed(seed_val)
# np.random.seed(seed_val)
# torch.manual_seed(seed_val)
# torch.cuda.manual_seed_all(seed_val)

# Wall-clock start of the whole training run.
total_t0 = time.time()

# Best validation F1 seen so far. This must live OUTSIDE the epoch loop:
# resetting it every epoch would overwrite the "best" checkpoint after each
# epoch regardless of whether the model actually improved.
max_f1 = 0
max_EM = 0  # kept for parity with max_f1; not read anywhere in this cell

# Per-epoch statistics (loss, F1, EM, durations) are appended to a text file.
# The ``with`` block closes the file even if training raises.
with open('./training_state.txt','w') as f:
    for epoch_i in range(0, epochs):
        # Switch to training mode. This does not run any training by itself;
        # it changes how layers such as dropout / batchnorm behave
        # (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
        model.train()

        # ========================================
        #               Training
        # ========================================

        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        print('Training...')

        # Start time of this epoch's training phase.
        t0 = time.time()

        # Reset the per-epoch accumulators.
        total_train_loss = 0
        f1_count = 0
        EM_count = 0

        # Iterate over the training set in mini-batches.
        for step,batch in enumerate(train_loader):

            # input_vector: tensor([[...],...,[...]]),shape:B*bert input len
            # att_mask: tensor([[...],...,[...]]),shape:B*bert input len
            # segment_ids: tensor([[...],...,[...]]),shape:B*bert input len
            # start,end: tensor([...]),shape:B,
            input_vector, att_mask, segment_ids, start, end, ans = batch

            # Print progress every 3000 steps.
            if step % 3000 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_loader), elapsed))

            # Move the inputs and labels to the target device.
            input_vector = input_vector.to(device)
            att_mask = att_mask.to(device)
            segment_ids = segment_ids.to(device)
            start = start.to(device)
            end = end.to(device)

            # Gradients accumulate in PyTorch, so clear them before each step.
            model.zero_grad()

            # forward
            output = model(input_ids=input_vector,token_type_ids=segment_ids,attention_mask=att_mask,output_hidden_states=True,start_positions=start,end_positions=end)

            # Choose the most probable start position / end position
            start_index = torch.argmax(output.start_logits, dim=1)
            end_index = torch.argmax(output.end_logits, dim=1)

            # Decode the predicted spans to text and accumulate the metrics.
            predAns = get_predAns(input_vector,start_index,end_index)
            f1_score, EM = get_metrics(predAns,ans)
            f1_count += f1_score
            EM_count += EM
            if step % 3000 == 0 and not step == 0:
                print('train pred ans: ',predAns)
                print('train exact ans: ',ans)

            # Accumulate the loss.
            total_train_loss += output.loss.item()

            # backward
            output.loss.backward()

            # Clip the gradient norm to avoid exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # update parameters
            optimizer.step()

            # update learning rate
            scheduler.step()

        # Average training loss (per batch) and F1 / EM (per example).
        train_avg_loss = total_train_loss/len(train_loader)
        train_avg_f1 = f1_count/len(train_dataset)
        train_avg_EM = EM_count/len(train_dataset)

        # Duration of this epoch's training phase.
        training_time = format_time(time.time() - t0)

        # print training loss,training (f1,em),training time
        print("")
        print("  Average training loss: {0:.4f}".format(train_avg_loss))
        # Index {1} selects EM; using {0} twice would print F1 in both places.
        print("  Average training f1: {0:.4f},EM: {1:.4f}".format(train_avg_f1,train_avg_EM))
        print("  Training epcoh took: {:}".format(training_time))

        # ========================================
        #               Validation
        # ========================================
        # After each training epoch, measure performance on the validation set.

        print("")
        print("Running Validation...")

        t0 = time.time()

        # Switch to evaluation mode.
        model.eval()

        # Tracking variables
        total_eval_loss = 0
        f1_count = 0
        EM_count = 0

        # Evaluate data for one epoch
        for step,batch in enumerate(val_loader):

            # input_vector: tensor([[...],...,[...]]),shape:B*bert input len
            # att_mask: tensor([[...],...,[...]]),shape:B*bert input len
            # segment_ids: tensor([[...],...,[...]]),shape:B*bert input len
            # start,end: tensor([...]),shape:B,
            input_vector, att_mask, segment_ids, start, end, ans = batch

            # Move the inputs and labels to the target device.
            input_vector = input_vector.to(device)
            att_mask = att_mask.to(device)
            segment_ids = segment_ids.to(device)
            start = start.to(device)
            end = end.to(device)

            # No parameter updates or gradients are needed during evaluation.
            with torch.no_grad():
                output = model(input_ids=input_vector,token_type_ids=segment_ids, attention_mask=att_mask,output_hidden_states=True,start_positions=start,end_positions=end)

            # Choose the most probable start position / end position
            start_index = torch.argmax(output.start_logits, dim=1)
            end_index = torch.argmax(output.end_logits, dim=1)

            # Decode the predicted spans to text and accumulate the metrics.
            predAns = get_predAns(input_vector,start_index,end_index)
            f1_score, EM = get_metrics(predAns,ans)
            f1_count += f1_score
            EM_count += EM
            if step % 1000 == 0 and not step == 0:
                print('val pred ans: ',predAns)
                print('val exact ans: ',ans)

            # Accumulate the loss.
            total_eval_loss += output.loss.item()

        # Average validation loss (per batch) and F1 / EM (per example).
        val_avg_loss = total_eval_loss / len(val_loader)
        val_avg_f1 = f1_count/len(val_dataset)
        val_avg_EM = EM_count/len(val_dataset)

        # Duration of this validation pass.
        validation_time = format_time(time.time() - t0)

        # Save a checkpoint only when validation F1 beats the best so far.
        if val_avg_f1>max_f1:
            max_f1=val_avg_f1
            torch.save({'Bert_model':model.state_dict()},'./val_best_model.pth')
            print('save new model with higher f1: {}\n'.format(val_avg_f1))

        print("  Validation Loss: {0:.4f}".format(val_avg_loss))
        print("  Validation f1: {0:.4f},EM: {1:.4f}".format(val_avg_f1,val_avg_EM))
        print("  Validation took: {:}".format(validation_time))

        # Record all statistics of this epoch.
        training_state = {
                'epoch': epoch_i + 1,
                'Train Loss': train_avg_loss,
                'Train f1': train_avg_f1,
                'Train EM': train_avg_EM,
                'Val Loss': val_avg_loss,
                'Val f1': val_avg_f1,
                'Val EM': val_avg_EM,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        f.write(str(training_state)+'\n')
        # Flush so finished epochs are on disk even if a later epoch fails.
        f.flush()
print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

# Inference test.txt

In [None]:
# Load the raw test lines (line endings kept).
with open('./test.txt','r') as f3:
    data3 = list(f3)

# test data: inputs only, evaluated one example at a time and in file order
test_paragraph, test_QA = get_paragraph_QA(data3)
(
    test_input_vector,
    test_att_mask,
    test_segment_ids,
) = get_data('test', test_QA, test_paragraph, tokenizer)

test_dataset = QA_Dataset(
    mode='test',
    x_vectors=test_input_vector,
    att_mask=test_att_mask,
    segment_ids=test_segment_ids,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Restore the best validation checkpoint. map_location lets a checkpoint that
# was saved on a GPU be loaded when only the CPU is available.
checkpoint = torch.load('./val_best_model.pth', map_location=device)
model.load_state_dict(checkpoint['Bert_model'])
model = model.to(device)
# Evaluation mode must be set explicitly: when this cell runs without the
# training cell having been executed in the same session, the model would
# otherwise still be in training mode (dropout active).
model.eval()

lines = []
for i,batch in enumerate(test_loader):
    print(i)
    input_vector,att_mask,segment_ids = batch
    # to device
    input_vector = input_vector.to(device)
    att_mask = att_mask.to(device)
    segment_ids = segment_ids.to(device)

    # No parameter updates or gradients are needed during inference.
    with torch.no_grad():
        output = model(input_ids=input_vector,token_type_ids=segment_ids,attention_mask=att_mask)

    # Choose the most probable start position / end position
    start_index = torch.argmax(output.start_logits, dim=1)
    end_index = torch.argmax(output.end_logits, dim=1)
    predAns = get_predAns(input_vector,start_index,end_index)

    # test_loader uses batch_size=1 without shuffling, so batch i holds
    # exactly example i and predAns has a single entry.
    lines.append(test_QA[i][0]+' ||| '+predAns[0]+'\n')

# Write "question ||| predicted answer" lines; the file is closed even if
# writing fails.
with open('./result.txt','w') as f:
    f.writelines(lines)