## import package

In [1]:
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import math
import numpy as np
import time
import torch, pandas as pd
import nltk
import re
import pickle
import os
# nltk.download('punkt')

from transformers import set_seed
set_seed(123)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
# Training data file

file="../data/divide_QA_data/data_fix_label_to_sen.pkl"
with open(file, 'rb') as f:
    data = pickle.load(f)

In [5]:
data['sub_q_true'] = [1 if x != None else -1 for x in data["q_label"]]
data['sub_r_true'] = [1 if x != None else -1 for x in data["r_label"]]
data['sub_both'] = data['sub_q_true'] * data['sub_r_true']
data.drop(index= data[data['sub_both'] == -1].index, inplace=True)
data.drop(columns=['sub_q_true', 'sub_r_true', 'sub_both'], inplace=True)
data.reset_index(drop=True, inplace=True)
data

Unnamed: 0,id,q,r,s,q',r',q_label,r_label,q_reidx,r_reidx
0,8,It can go both ways . We all doubt . It is wha...,True .,AGREE,It can go both ways . We all doubt . It is wha...,True .,"(0, 76)","(0, 5)","(0, 76)","(0, 5)"
1,8,It can go both ways . We all doubt . It is wha...,True .,AGREE,can go both ways . We all doubt . It is what y...,True,"(3, 74)","(0, 3)","(0, 76)","(0, 5)"
2,8,It can go both ways . We all doubt . It is wha...,True .,AGREE,It can go both ways . We all doubt . It is wha...,True,"(0, 76)","(0, 3)","(0, 76)","(0, 5)"
3,9,"once again , you seem to support the killing o...",based on the idea that people are dispensible ...,AGREE,seem to support the killing of certain people,based on the idea that people are dispensible ...,"(17, 61)","(0, 92)","(0, 81)","(0, 337)"
4,9,"once again , you seem to support the killing o...",based on the idea that people are dispensible ...,AGREE,you seem to support the killing of certain peo...,based on the idea that people are dispensible,"(13, 81)","(0, 44)","(0, 81)","(0, 337)"
...,...,...,...,...,...,...,...,...,...,...
36867,10001,good thing this argument has never been done !...,"And teen sex does n't , by the very nature of ...",DISAGREE,You are much better off making theft legal and...,"And teen sex does n't , by the very nature of ...","(111, 227)","(0, 57)","(0, 227)","(0, 200)"
36868,10002,"I know one thing , anything that happens , pol...",Was n't sinjin crowing about his plans to take...,DISAGREE,"I know one thing , anything that happens , pol...",Was n't sinjin crowing about his plans to take...,"(0, 108)","(0, 76)","(0, 643)","(0, 260)"
36869,10002,"I know one thing , anything that happens , pol...",Was n't sinjin crowing about his plans to take...,DISAGREE,FBI Arrests Three Men in Terror Plot that Targ...,Was n't sinjin crowing about his plans to take...,"(112, 195)","(0, 56)","(0, 643)","(0, 260)"
36870,10003,I enjoy Botany more than most things and I hav...,"Hi Smallax , welcome to the forum . I did a se...",AGREE,I enjoy Botany more than most things and I hav...,"Hi Smallax , welcome to the forum . I did a se...","(0, 106)","(0, 119)","(0, 442)","(0, 266)"


## Data process

In [14]:
# from sklearn.model_selection import train_test_split

train = data[:int(len(data)*0.9)].copy()
valid = data[int(len(data)*0.9):].copy()
del data
# train, valid = train_test_split(data, test_size=1/9, shuffle=False)
# valid, test = train_test_split(valid, test_size=0.5)
train.reset_index(drop=True, inplace=True)
valid.reset_index(drop=True, inplace=True)

In [15]:
train["s+r"] = train["s"] + ": " + train["r"]
valid["s+r"] = valid["s"] + ": " + valid["r"]

## Tokenizer

In [16]:
from transformers import AutoTokenizer

MODEL_NAME = "bert-base-cased"
# MODEL_NAME = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [17]:
train_data_q = train['q'].tolist()
valid_data_q = valid['q'].tolist()
# test_data_q = test['q'].tolist()

train_data_r = train['s+r'].tolist()
valid_data_r = valid['s+r'].tolist()
# train_data_r = train['r'].tolist()
# valid_data_r = valid['r'].tolist()
# test_data_r = test['r'].tolist()

train_s = train['s'].tolist()
valid_s = valid['s'].tolist()

train_q_label = train['q_label'].tolist()
valid_q_label = valid['q_label'].tolist()

train_r_label = train['r_label'].tolist()
valid_r_label = valid['r_label'].tolist()

train_q_reidx = train['q_reidx'].tolist()
valid_q_reidx = valid['q_reidx'].tolist()
# test_q_reidx = test['q_reidx'].tolist()

train_r_reidx = train['r_reidx'].tolist()
valid_r_reidx = valid['r_reidx'].tolist()
# test_r_reidx = test['r_reidx'].tolist()

In [18]:
train_encodings = tokenizer(train_data_q, train_data_r, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(valid_data_q, valid_data_r, truncation=True, padding=True, max_length=512, return_offsets_mapping=True)

## Dataset

In [19]:
def add_token_positions(encodings, q_label, r_label, q_reidx, r_reidx, s_data=None):
    if s_data is not None:
        for idx, s in enumerate(s_data):
            if s == "AGREE":
                r_label[idx] = (r_label[idx][0] + 7, r_label[idx][1] + 7) if r_label[idx] != None else None
            elif s == "DISAGREE":
                r_label[idx] = (r_label[idx][0] + 10, r_label[idx][1] + 10) if r_label[idx] != None else None
    q_starts, r_starts, q_ends, r_ends = [], [], [], []
    for idx, (q_l, q_r, r_l, r_r) in enumerate(zip(q_label, q_reidx, r_label, r_reidx)):
        # q_start, q_end, r_start, r_end = 0, 0, 0, 0

        if q_l == None or r_l == None:
            q_starts.append(0)
            q_ends.append(0)
            r_starts.append(0)
            r_ends.append(0)
            continue

        q_s = encodings.char_to_token(idx, q_l[0]-q_r[0], 0)
        q_e = encodings.char_to_token(idx, q_l[1]-q_r[0], 0)

        r_s = encodings.char_to_token(idx, r_l[0]-r_r[0], 1)    #2
        r_e = encodings.char_to_token(idx, r_l[1]-r_r[0], 1)

        if q_s == None and q_e == None or r_s == None and r_e == None:
            q_starts.append(0)
            q_ends.append(0)
            r_starts.append(0)
            r_ends.append(0)
            continue

        shift = 1
        while q_s is None:
            q_s = encodings.char_to_token(idx, q_l[0]-q_r[0] + shift, 0)
            shift += 1
        shift = 1
        while r_s is None:
            r_s = encodings.char_to_token(idx, r_l[0]-r_r[0] + shift, 1)    #2
            shift += 1

        shift = 1
        while q_e is None:
            q_e = encodings.char_to_token(idx, q_l[1]-q_r[0] - shift, 0)
            shift += 1
        shift = 1
        while r_e is None:
            r_e = encodings.char_to_token(idx, r_l[1]-r_r[0] - shift, 1)    #2
            shift += 1
            
        if q_s == None or q_e == None or r_s == None or r_e == None:
            print(idx, q_s, q_e, r_s, r_e)
        q_starts.append(q_s)
        q_ends.append(q_e)
        r_starts.append(r_s)
        r_ends.append(r_e)

    encodings.update({'q_start': q_starts, 'q_end': q_ends, 'r_start': r_starts, 'r_end': r_ends})
    return r_label, r_reidx

In [20]:
# Convert char_based_id to token_based_id
# Find the corossponding token id after input being tokenized
train_r_label, train_r_reidx =  add_token_positions(train_encodings, train_q_label, train_r_label, train_q_reidx, train_r_reidx, train_s)
valid_r_label, valid_r_reidx =  add_token_positions(val_encodings, valid_q_label, valid_r_label, valid_q_reidx, valid_r_reidx, valid_s)

In [22]:
class qrDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # item = {}
        # for key, val in self.encodings.items():
        #     if key != 'offset_mapping':
        #         item[key] = torch.tensor(val[idx])
        # return item
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

In [23]:
val_mappping = val_encodings['offset_mapping']
val_encodings.pop("offset_mapping")
train_dataset = qrDataset(train_encodings)
val_dataset = qrDataset(val_encodings)

In [24]:
train_dataset.encodings.keys(), val_dataset.encodings.keys()

(dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'q_start', 'q_end', 'r_start', 'r_end']),
 dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'q_start', 'q_end', 'r_start', 'r_end']))

## Model

In [28]:
from transformers import BertModel

class myModel(torch.nn.Module):

    def __init__(self):

        super(myModel, self).__init__()

        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.fc = nn.Linear(768, 4)
        

    def forward(self, input_ids, attention_mask, token_type_ids=None):   
        # output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        output_logits = self.fc(output[0])
        return output_logits



## Training

In [29]:
device

device(type='cuda')

In [30]:
# # Pack data into dataloader by batch
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [48]:
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm.auto import tqdm
from torch.nn.utils import clip_grad_norm_

# Set GPU / CPU
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Put model on device
model = myModel().to(device)
training_epoch = 3
loss_fct = CrossEntropyLoss()

# params = list(model.named_parameters())
# no_decay = ['bias,','LayerNorm']
# other = ['fc']
# no_main = no_decay + other

# optimizer_grouped_parameters = [
#     {'params':[p for n,p in params if not any(nd in n for nd in no_main)],'weight_decay':1e-2,'lr':1e-5},
#     {'params':[p for n,p in params if not any(nd in n for nd in other) and any(nd in n for nd in no_decay) ],'weight_decay':0,'lr':1e-5},
#     {'params':[p for n,p in params if any(nd in n for nd in other) and any(nd in n for nd in no_decay) ],'weight_decay':0,'lr':1e-2},
#     {'params':[p for n,p in params if any(nd in n for nd in other) and not any(nd in n for nd in no_decay) ],'weight_decay':1e-2,'lr':1e-2},
# ]
optim = AdamW(model.parameters(), lr=1e-5)

total_steps = len(train_loader) * training_epoch

# optimizer = optim.Adam(model.parameters(), lr=learning_rate0)

# optim = AdamW(optimizer_grouped_parameters, lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optim,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Grading

In [32]:
def get_output_post_fn(test, q_sub_output, r_sub_output):
    q_sub, r_sub = [], []
    for i in range(len(test)):

        q_sub_pred = q_sub_output[i].split()
        r_sub_pred = r_sub_output[i].split()

        if q_sub_pred is None:
            q_sub_pred = []
        q_sub_error_index = q_sub_pred.index('[SEP]') if '[SEP]' in q_sub_pred else -1

        if q_sub_error_index != -1:
            q_sub_pred = q_sub_pred[:q_sub_error_index]

        temp = r_sub_pred.copy()
        if r_sub_pred is None:
            r_sub_pred = []
        else:
            for j in range(len(temp)):
                if temp[j] == '[SEP]':
                    r_sub_pred.remove('[SEP]')
                if temp[j] == '[PAD]':
                    r_sub_pred.remove('[PAD]')

        q_sub.append(' '.join(q_sub_pred))
        r_sub.append(' '.join(r_sub_pred))
        if q_sub[-1] == "[CLS]":
            q_sub[-1] = test["q"][len(q_sub)-1]
        if r_sub[-1] == "[CLS]":
            r_sub[-1] = test["r"][len(r_sub)-1]

    return q_sub, r_sub

In [33]:
def nltk_token_string(sentence):
    # print(sentence)
    tokens = nltk.word_tokenize(sentence)
    for i in range(len(tokens)):
        if len(tokens[i]) == 1:
            tokens[i] = re.sub(r"[!\"#$%&\'()*\+, -.\/:;<=>?@\[\\\]^_`{|}~]", '', tokens[i])
    while '' in tokens:
        tokens.remove('')
    # tokens = ' '.join(tokens)
    return tokens

In [34]:
def lcs(X, Y):
    X_, Y_ = [], []
    X_ = nltk_token_string(X)
    Y_ = nltk_token_string(Y)

    m = len(X_)
    n = len(Y_)
 
    # declaring the array for storing the dp values
    L = [[None]*(n + 1) for i in range(m + 1)]
 
    """Following steps build L[m + 1][n + 1] in bottom up fashion
    Note: L[i][j] contains length of LCS of X[0..i-1]
    and Y[0..j-1]"""
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0 :
                L[i][j] = 0
            elif X_[i-1] == Y_[j-1]:
                L[i][j] = L[i-1][j-1]+1
            else:
                L[i][j] = max(L[i-1][j], L[i][j-1])
 
    # L[m][n] contains the length of LCS of X[0..n-1] & Y[0..m-1]
    return L[m][n], m, n


def acc_(full, sub):
    common, m, n = lcs(full, sub)
    union = m + n - common
    if union == 0:
        return 1
    accuracy = float(common/union)

    return accuracy

In [35]:
def get_acc(q_true, r_true, q_sub, r_sub):
    q_acc_sum = 0
    r_acc_sum = 0
    test_len = len(q_true)
    for i in range(test_len):
        q_accuracy = acc_(q_true[i], q_sub[i])
        r_accuracy = acc_(r_true[i], r_sub[i])

        q_acc_sum += q_accuracy
        r_acc_sum += r_accuracy

    print("q accuracy: ", q_acc_sum/test_len)
    print("r accuracy: ", r_acc_sum/test_len)
    return q_acc_sum/test_len, r_acc_sum/test_len

### Train model

In [37]:
def evaluate(valid_loader, valid_r, valid_q):
    model.eval()
    running_loss = 0.0
    total_loss = 0.0
    predict_pos, q_sub_output, r_sub_output = [], [], []
    q_true_output, r_true_output = [], []

    with torch.no_grad():
        loop = tqdm(valid_loader, leave=True, ncols=75)
        for batch_id, batch in enumerate(loop):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            q_start = batch['q_start'].to(device)
            r_start = batch['r_start'].to(device)
            q_end = batch['q_end'].to(device)
            r_end = batch['r_end'].to(device)

            # model output
            # outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            q_start_logits, r_start_logits, q_end_logits, r_end_logits = torch.split(outputs, 1, 2)

            q_start_logits = q_start_logits.squeeze(-1).contiguous()
            r_start_logits = r_start_logits.squeeze(-1).contiguous()
            q_end_logits = q_end_logits.squeeze(-1).contiguous()
            r_end_logits = r_end_logits.squeeze(-1).contiguous()

            q_start_loss = loss_fct(q_start_logits, q_start)
            r_start_loss = loss_fct(r_start_logits, r_start)
            q_end_loss = loss_fct(q_end_logits, q_end)
            r_end_loss = loss_fct(r_end_logits, r_end)

            loss = q_start_loss + r_start_loss + q_end_loss + r_end_loss

            running_loss += loss.item()
            total_loss += loss.item()
            if batch_id % 250 == 0 and batch_id != 0:
                print('Validation Epoch {} Batch {} Loss {:.4f}'.format(
                    batch_id + 1, batch_id, running_loss / 250))
                running_loss = 0.0

            q_start_prdict = torch.argmax(q_start_logits, 1).cpu().numpy()
            r_start_prdict = torch.argmax(r_start_logits, 1).cpu().numpy()
            q_end_prdict = torch.argmax(q_end_logits, 1).cpu().numpy()
            r_end_prdict = torch.argmax(r_end_logits, 1).cpu().numpy()
            # print(q_start_prdict, r_start_prdict, q_end_prdict, r_end_prdict)
            for i in range(len(input_ids)):
                predict_pos.append((q_start_prdict[i].item(), r_start_prdict[i].item(), q_end_prdict[i].item(), r_end_prdict[i].item()))

                q_true_s = val_mappping[batch_size * batch_id + i][q_start[i]][0]
                q_true_e = val_mappping[batch_size * batch_id + i][q_end[i]][-1]
                r_true_s = val_mappping[batch_size * batch_id + i][r_start[i]][0]
                r_true_e = val_mappping[batch_size * batch_id + i][r_end[i]][-1]
                q_true = valid_q[batch_size * batch_id + i][q_true_s:q_true_e]
                r_true = valid_r[batch_size * batch_id + i][r_true_s:r_true_e]

                q_s = val_mappping[batch_size * batch_id + i][predict_pos[-1][0]][0]
                q_e = val_mappping[batch_size * batch_id + i][predict_pos[-1][2]][-1]
                r_s = val_mappping[batch_size * batch_id + i][predict_pos[-1][1]][0]
                r_e = val_mappping[batch_size * batch_id + i][predict_pos[-1][3]][-1]
                q_sub = valid_q[batch_size * batch_id + i][q_s:q_e]
                r_sub = valid_r[batch_size * batch_id + i][r_s:r_e]

                q_sub_output.append(q_sub)
                r_sub_output.append(r_sub)
                q_true_output.append(q_true)
                r_true_output.append(r_true)

        print("evaluate loss: ", total_loss / len(valid_loader))
        # q_sub, r_sub = get_output_post_fn(valid, q_sub_output, r_sub_output)
    return q_sub_output, r_sub_output, q_true_output, r_true_output

In [49]:
best_acc = 0.0
for epoch in range(training_epoch):
    model.train()
    running_loss = 0.0
    total_loss = 0.0

    loop = tqdm(train_loader, leave=True, ncols=75)

    for batch_id, batch in enumerate(loop):
        # reset
        optim.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        q_start = batch['q_start'].to(device)
        r_start = batch['r_start'].to(device)
        q_end = batch['q_end'].to(device)
        r_end = batch['r_end'].to(device)

        # model output
        # outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        q_start_logits, r_start_logits, q_end_logits, r_end_logits = torch.split(outputs, 1, 2)

        q_start_logits = q_start_logits.squeeze(-1).contiguous()
        r_start_logits = r_start_logits.squeeze(-1).contiguous()
        q_end_logits = q_end_logits.squeeze(-1).contiguous()
        r_end_logits = r_end_logits.squeeze(-1).contiguous()

        q_start_loss = loss_fct(q_start_logits, q_start)
        r_start_loss = loss_fct(r_start_logits, r_start)
        q_end_loss = loss_fct(q_end_logits, q_end)
        r_end_loss = loss_fct(r_end_logits, r_end)

        loss = q_start_loss + r_start_loss + q_end_loss + r_end_loss

        # calculate loss
        loss.backward()
        running_loss += loss.item()
        total_loss += loss.item()
        # update parameters
        clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        scheduler.step()
        
        if batch_id % 500 == 0 and batch_id != 0 or batch_id == len(train_loader) - 1:
            print('Step {} Batch {} Loss {:.4f}'.format(
                batch_id + 1, batch_id, (running_loss / 500) if batch_id != len(train_loader) - 1 or(len(train_loader) % 500) ==0 else running_loss / (len(train_loader) % 500)))
            running_loss = 0.0

        loop.set_description('Epoch {}'.format(epoch + 1))
        loop.set_postfix(loss=total_loss/(batch_id+1))
    # evaluate(valid_loader)
    q_sub_output, r_sub_output, q_true_output, r_true_output = evaluate(valid_loader, valid_data_r, valid_data_q)
    # q_sub, r_sub = get_output_post_fn(valid, q_sub_output, r_sub_output)
    acc_q, acc_r = get_acc(q_true_output, r_true_output, q_sub_output, r_sub_output)
    acc = (acc_q + acc_r) / 2
    print("acc:", acc)
    if acc > best_acc:
        best_acc = acc
        best_model_name = str(best_acc)
        torch.save(model.state_dict(), best_model_name)
        print("save model----acc: ", best_acc)

Epoch 1:  45%|██████▋        | 501/1125 [03:50<04:36,  2.26it/s, loss=11.8]

Step 501 Batch 500 Loss 11.8277


Epoch 1:  89%|████████████▍ | 1001/1125 [07:42<00:57,  2.17it/s, loss=10.4]

Step 1001 Batch 1000 Loss 9.0669


Epoch 1: 100%|██████████████| 1125/1125 [08:39<00:00,  2.17it/s, loss=10.3]


Step 1125 Batch 1124 Loss 8.8425


100%|████████████████████████████████████| 125/125 [00:19<00:00,  6.30it/s]


evaluate loss:  9.143227420806884
q accuracy:  0.5091153116873126
r accuracy:  0.5033085593666576
acc: 0.5062119355269852
save model----acc:  0.5062119355269852


Epoch 2:  45%|██████▋        | 501/1125 [03:52<04:55,  2.11it/s, loss=8.51]

Step 501 Batch 500 Loss 8.5234


Epoch 2:  89%|████████████▍ | 1001/1125 [07:45<00:56,  2.18it/s, loss=8.54]

Step 1001 Batch 1000 Loss 8.5673


Epoch 2: 100%|██████████████| 1125/1125 [08:42<00:00,  2.15it/s, loss=8.55]


Step 1125 Batch 1124 Loss 8.5834


100%|████████████████████████████████████| 125/125 [00:19<00:00,  6.32it/s]


evaluate loss:  9.029457206726073
q accuracy:  0.5056663274236893
r accuracy:  0.510552382399677
acc: 0.5081093549116832
save model----acc:  0.5081093549116832


Epoch 3:  45%|██████▋        | 501/1125 [03:52<04:44,  2.19it/s, loss=8.13]

Step 501 Batch 500 Loss 8.1424


Epoch 3:  89%|████████████▍ | 1001/1125 [07:44<00:56,  2.18it/s, loss=8.16]

Step 1001 Batch 1000 Loss 8.2001


Epoch 3: 100%|██████████████| 1125/1125 [08:42<00:00,  2.15it/s, loss=8.19]


Step 1125 Batch 1124 Loss 8.3000


100%|████████████████████████████████████| 125/125 [00:19<00:00,  6.36it/s]


evaluate loss:  9.143028686523438
q accuracy:  0.5061502005370291
r accuracy:  0.5102139529881575
acc: 0.5081820767625933
save model----acc:  0.5081820767625933


In [6]:
# model = myModel().to(device)
model.load_state_dict(torch.load(best_model_name))

<All keys matched successfully>

## Predict

In [38]:
def predict(test_loader):
    predict_pos = []

    model.eval()

    q_sub_output, r_sub_output = [],[]

    loop = tqdm(test_loader, leave=True)
    for batch_id, batch in enumerate(loop):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)

        # model output
        # outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        
        q_start_logits, r_start_logits, q_end_logits, r_end_logits = torch.split(outputs, 1, 2)

        q_start_logits = q_start_logits.squeeze(-1).contiguous()
        r_start_logits = r_start_logits.squeeze(-1).contiguous()
        q_end_logits = q_end_logits.squeeze(-1).contiguous()
        r_end_logits = r_end_logits.squeeze(-1).contiguous()

        q_start_prdict = torch.argmax(q_start_logits, 1).cpu().numpy()
        r_start_prdict = torch.argmax(r_start_logits, 1).cpu().numpy()
        q_end_prdict = torch.argmax(q_end_logits, 1).cpu().numpy()
        r_end_prdict = torch.argmax(r_end_logits, 1).cpu().numpy()

        for i in range(len(input_ids)):
            predict_pos.append((q_start_prdict[i].item(), r_start_prdict[i].item(), q_end_prdict[i].item(), r_end_prdict[i].item()))
            k = 1
            while tokenizer.decode(input_ids[i][predict_pos[-1][0]])[0:2] == "##":
                predict_pos[-1] = (q_start_prdict[i].item()-k, r_start_prdict[i].item(), q_end_prdict[i].item(), r_end_prdict[i].item())
                i += 1
            k = 1
            while tokenizer.decode(input_ids[i][predict_pos[-1][1]])[0:2] == "##":
                predict_pos[-1] = (q_start_prdict[i].item(), r_start_prdict[i].item(), q_end_prdict[i].item()-k, r_end_prdict[i].item())
                i += 1
            k = 1

            if token_type_ids[i][predict_pos[-1][2]] != 0:
                predict_pos[-1] = (q_start_prdict[i].item(), r_start_prdict[i].item(), list(token_type_ids[i]).index(1) - 2, r_end_prdict[i].item())
            if token_type_ids[i][predict_pos[-1][1]] != 1:
                predict_pos[-1] = (q_start_prdict[i].item(), list(token_type_ids[i]).index(1), q_end_prdict[i].item(), r_end_prdict[i].item())
            # q_sub_output.append(q_sub)
            # r_sub_output.append(r_sub)
    
    return predict_pos  #q_sub_output, r_sub_output, predict_pos

In [39]:
test = pd.read_csv("../data/Batch_answers - test_data(no_label).csv")
test.tail()
test[['q','r']] = test[['q','r']].apply(lambda x: x.str.strip('\"'))
test.tail()
def split_sen(data_):    
    for i,(j,z) in enumerate(zip(data_["q"], data_["r"])):
        # print(i, print(data_["q"][i]))
        if len(j.split(" ")) > 200:
            n = math.ceil(len(j.split(" "))/200)
            tmp = j.split(" . ")
            n = math.ceil(len(tmp)/n)
            data_["q"][i] = [(" . ").join(tmp[idx : idx + n]) for idx in range(0, len(tmp), n)]
        else:   data_["q"][i] = [j]
        if len(z.split(" ")) > 200:
            n = math.ceil(len(z.split(" "))/200)
            tmp = z.split(" . ")
            n = math.ceil(len(tmp)/n)
            data_["r"][i] = [(" . ").join(tmp[idx : idx + n]) for idx in range(0, len(tmp), n)]
        else:   data_["r"][i] = [z]
    return data_

def re_idx(array):
    idx_list = np.array([len(x) for x in array])+3
    idx_list_ = np.cumsum(idx_list)
    s_list = idx_list_ - idx_list
    idx_list_ -= 4 
    return [(x, y ,i) for i,(x,y) in enumerate(zip(s_list, idx_list_))]

def re_pair(q, q_redix):
    return [[a,b] for (a,b) in zip(q, q_redix)]

test = split_sen(test)
test["q_reidx"] = test.apply(lambda x : re_idx(x["q"]), axis=1)
test["r_reidx"] = test.apply(lambda x : re_idx(x["r"]), axis=1)
test["q"] = test.apply(lambda x : re_pair(x["q"], x["q_reidx"]), axis=1)
test["r"] = test.apply(lambda x : re_pair(x["r"], x["r_reidx"]), axis=1)
test = test.explode('q').reset_index(drop=True)
test = test.explode('r').reset_index(drop=True)
test["q_reidx"] = test["q"].apply(lambda x : (x[1][0], x[1][1]))
test["q_sub_idx"] = test["q"].apply(lambda x : x[1][-1])
test["q"] = test["q"].apply(lambda x : x[0])
test["r_reidx"] = test["r"].apply(lambda x : (x[1][0], x[1][1]))
test["r_sub_idx"] = test["r"].apply(lambda x : x[1][-1])
test["r"] = test["r"].apply(lambda x : x[0])
test["s+r"] = test["s"] +": " + test["r"]
test.tail(10)   #2387

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  del sys.path[0]
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  if sys.path[0] == "":
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


Unnamed: 0,id,q,r,s,q_reidx,r_reidx,q_sub_idx,r_sub_idx,s+r
2378,5227,Why so many what ?,Different sciences . There are different spins...,AGREE,"(0, 17)","(0, 106)",0,0,AGREE: Different sciences . There are differen...
2379,6136,Why would someone make that up and pass it off...,"Well , they did .",AGREE,"(0, 65)","(0, 16)",0,0,"AGREE: Well , they did ."
2380,2271,You once said that you had done a detailed stu...,So doing a detailed study of something require...,DISAGREE,"(0, 318)","(0, 210)",0,0,DISAGREE: So doing a detailed study of somethi...
2381,4420,"Woodward was a fraud , and I recall that was a...",do you have proof of such a statement ? And wh...,DISAGREE,"(0, 77)","(0, 166)",0,0,DISAGREE: do you have proof of such a statemen...
2382,4071,Would you accept civil unions even though they...,No ... and you do n't have to either .,DISAGREE,"(0, 71)","(0, 37)",0,0,DISAGREE: No ... and you do n't have to either .
2383,9499,You are betraying your belief system .,Yep . ( I 'm assuming that by `` belief system...,AGREE,"(0, 37)","(0, 270)",0,0,AGREE: Yep . ( I 'm assuming that by `` belief...
2384,4611,"You are in a loud minority , railing against t...",Being in the minority or in the majority is ir...,DISAGREE,"(0, 77)","(0, 90)",0,0,DISAGREE: Being in the minority or in the majo...
2385,9328,You bet your XXX that 'd make me happy .,"Well , first , I probably would n't bet my XXX...",DISAGREE,"(0, 39)","(0, 237)",0,0,"DISAGREE: Well , first , I probably would n't ..."
2386,5225,you say `` f * * * the Constitution. ``,and gun nuts say f * * * the children when we ...,DISAGREE,"(0, 38)","(0, 99)",0,0,DISAGREE: and gun nuts say f * * * the childre...
2387,68,Your answers were without content or meaning ....,k,DISAGREE,"(0, 68)","(0, 0)",0,0,DISAGREE: k


In [40]:
test_data_q = test['q'].tolist()
test_data_r = test['s+r'].tolist()
test_q_reidx = test['q_reidx'].tolist()
test_r_reidx = test['r_reidx'].tolist()
test_encodings = tokenizer(test_data_q, test_data_r, truncation=True, padding=True, max_length=512, return_offsets_mapping=True)
test_offset_mapping = test_encodings["offset_mapping"]
test_encodings.pop("offset_mapping")
test_encodings.keys()
test_dataset = qrDataset(test_encodings)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [42]:
# q_sub_output, r_sub_output, predict_pos = predict(test_loader)
predict_pos = predict(test_loader)

100%|██████████| 299/299 [00:42<00:00,  7.04it/s]


In [48]:
q_sub, r_sub = [], []
for i in range(len(predict_pos)):
    if i == 2382:
        print(predict_pos[i][0], predict_pos[i][1], predict_pos[i][2], predict_pos[i][3])
    q_s = test_offset_mapping[i][predict_pos[i][0]][0]
    q_e = test_offset_mapping[i][predict_pos[i][2]][-1]
    r_s = test_offset_mapping[i][predict_pos[i][1]][0]
    r_e = test_offset_mapping[i][predict_pos[i][3]][-1]
    if i == 2382:
        print(test_data_q[i][q_s:q_e], test_data_r[i][r_s:r_e])
    q_pre_sen = test_data_q[i][q_s:q_e]
    r_pre_sen = test_data_r[i][r_s:r_e]
    q_sub.append(q_pre_sen)
    r_sub.append(r_pre_sen)

1 21 13 34
Would you accept civil unions even though they were less then marriage ? No ... and you do n't have to either .


In [49]:
test['q_sub'] = q_sub
test['r_sub'] = r_sub

In [51]:
test2 = test.copy()

In [53]:
ans_id, ans_q, ans_r = [], [], []
for id in set(test2["id"]):
    if id == 3890:
        print(id)
    frame = test2[test2["id"] == id]
    q_set =set(frame["q_sub_idx"])
    r_set =set(frame["r_sub_idx"])
    q_sub, r_sub = "", ""
    if len(q_set) == 1:
        q_sub = frame["q_sub"].iloc[0]
        if q_sub == "":
            q_sub = frame["q"].iloc[0]
    else:
        for idx in q_set:
            # find max len by q_set to find in frame
            q_frame = frame[frame["q_sub_idx"] == idx]
            max_idx = max(len(q) for q in q_frame["q_sub"])
            for q in q_frame["q_sub"]:
                if len(q) == max_idx:
                    q_sub += q
                    break
    if len(q_sub) == 0:
        if len(frame) == 1:
            q_sub = frame["q_sub"].iloc[0]
        else:
            q_sub = frame["q"][frame["q"].index[0]]
            for idx, q in enumerate(frame["q"][1:]):
                if frame["q_sub_idx"][frame["q"].index[0]+idx+1] != frame["q_sub_idx"][frame["q"].index[0]+idx]:
                    q_sub += q

    if len(r_set) == 1:
        r_sub = frame["r_sub"].iloc[0]
        if r_sub == "":
            r_sub = frame["r"].iloc[0]
    else:
        for idx in r_set:
            # find max len by q_set to find in frame
            r_frame = frame[frame["r_sub_idx"] == idx]
            max_idx = max(len(r) for r in r_frame["r_sub"])
            for r in r_frame["r_sub"]:
                if len(r) == max_idx:
                    r_sub += r
                    break

    if len(r_sub) == 0:
        if len(frame) == 1:
            r_sub = frame["r_sub"].iloc[0]
        else:
            r_sub = frame["r"][frame["r"].index[0]]
            for idx, r in enumerate(frame["r"][1:]):
                if frame["r_sub_idx"][frame["r"].index[0]+idx+1] != frame["r_sub_idx"][frame["r"].index[0]+idx]:
                    r_sub += r

    ans_id.append(id)
    ans_q.append('"'+q_sub+'"')
    ans_r.append('"'+r_sub+'"')

len(ans_id), len(ans_q), len(ans_r)

3890


(2016, 2016, 2016)

In [54]:
for q in ans_q:
    if q == '""':
        print(q)

for r in ans_r:
    if r == '""':
        print(r)

In [55]:
ans = pd.DataFrame({"id": ans_id, "q": ans_q, "r": ans_r})
# pd.set_option('display.max_colwidth', -1)
ans

Unnamed: 0,id,q,r
0,1,"""I got a good idea . however , they do tend to...","""By your own admission you havenÂ ’ t 'hung ou..."
1,2,"""Be sure to give your guns a big fat kiss toni...","""Actually , they did n't . The whole tragedy w..."
2,3,"""One of the biggest arguments against gun cont...","""Not quite . To be more correct regarding gove..."
3,4,"""compare the ' B ' specimen in your fossil lin...","""Comparison I could 've just circled the whole..."
4,5,"""There are some incedents that are beyond your...","""Well yes ."""
...,...,...,...
2011,8186,"""It seems that you would be willing to grant t...","""Sorry to hear you lost an hour 's worth of wo..."
2012,8187,"""Waiting until they are born likely gives them...","""I think they have more of a chance becasue , ..."
2013,8188,"""The government was right to tighten up the la...","""So those who will allow more gun control , sh..."
2014,8189,"""would you have a problem calling a guy marrag...","""Yes , I would . The term is not simply just w..."


In [57]:
ans.to_csv("submission_bert_"+best_model_name+".csv", index=False, encoding="utf-8")