In [23]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
import nltk
import pickle
import os
from transformers import BertTokenizerFast
from transformers import BertModel
from transformers import AdamW
from tqdm import tqdm
from transformers import set_seed
# CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA runtime
# is initialised, so set it before anything that may touch CUDA (set_seed also
# seeds torch's CUDA generators).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed the random / numpy / torch RNGs for reproducible runs.
set_seed(123)

# nltk 'punkt' tokenizer data; the download is skipped when already up to date.
nltk.download('punkt')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

[nltk_data] Downloading package punkt to /home/julia/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [24]:
def _mark_spans(encoding, idx, word_idx, label, label_mask, spans, offset, sequence_index):
    """Set ``label`` to 1 on every token covered by ``spans`` (mutates in place).

    Each span is a ``(start_char, end_char)`` pair. ``offset`` is subtracted
    from both chars before they are mapped to tokens of sequence
    ``sequence_index`` of example ``idx`` (0 = first text of the pair,
    1 = second). Spans whose start or end char maps to no token
    (``char_to_token`` returns None, e.g. whitespace or truncated text) are
    skipped. Inside a span, a token whose word id equals the previous
    token's word id (a word-continuation sub-token) gets ``label_mask`` 0.
    """
    for span in spans:
        start_tok = encoding.char_to_token(idx, span[0] - offset, sequence_index)
        end_tok = encoding.char_to_token(idx, span[1] - offset, sequence_index)
        if start_tok is None or end_tok is None:
            continue

        for tok in range(start_tok, end_tok + 1):
            label[tok] = 1
            if tok != 0 and word_idx[tok] == word_idx[tok - 1]:
                label_mask[tok] = 0


def add_label_token(encoding, q_label, r_label, q_reidx, r_reidx):
    """Build per-token labels and label masks for a batch of (q, r) text pairs.

    Args:
        encoding: fast-tokenizer encoding of the pairs. It must provide
            ``word_ids(batch_index=...)`` and ``char_to_token``. q is sequence 0
            and r is sequence 1.
        q_label, r_label: per example, a list of ``(start_char, end_char)`` spans.
        q_reidx, r_reidx: per example, an entry whose first element is
            subtracted from the span chars before mapping them to tokens.

    Returns:
        (labels, labels_mask): two lists holding one int numpy array per
        example, each as long as that example's ``word_ids``.
        ``labels`` is 1 on tokens inside a span and 0 elsewhere.
        ``labels_mask`` is 0 on tokens without a word id (special tokens) and
        on word-continuation sub-tokens inside spans, and 1 elsewhere.
    """
    labels = []
    labels_mask = []
    for idx, (q_ls, q_r, r_ls, r_r) in enumerate(zip(q_label, q_reidx, r_label, r_reidx)):
        word_idx = encoding.word_ids(batch_index=idx)
        label = np.zeros(len(word_idx), dtype=int)
        label_mask = np.array([0 if w is None else 1 for w in word_idx], dtype=int)

        # NOTE(review): if EITHER side has no spans, the whole example stays
        # unlabelled (the other side's spans are ignored too) — confirm intended.
        if not (q_ls == [] or r_ls == []):
            _mark_spans(encoding, idx, word_idx, label, label_mask, q_ls, q_r[0], 0)
            _mark_spans(encoding, idx, word_idx, label, label_mask, r_ls, r_r[0], 1)

        labels.append(label)
        labels_mask.append(label_mask)
    return labels, labels_mask

In [25]:
def get_input_pos(train_encoding):
    """Flag which tokens of each encoded example have a word id.

    For every example ``i`` in ``train_encoding``, each entry of
    ``train_encoding.word_ids(i)`` is mapped to 0 when it is ``None`` (tokens
    that belong to no word, e.g. special/padding tokens) and to 1 otherwise.

    Returns:
        list[list[int]]: one 0/1 list per example, in ``input_ids`` order.
    """
    n_examples = len(train_encoding["input_ids"])
    return [
        [0 if word_id is None else 1 for word_id in train_encoding.word_ids(example)]
        for example in range(n_examples)
    ]

In [26]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset over a tokenizer encoding.

    ``encodings`` maps feature names (``input_ids``, ``attention_mask``, ...)
    to per-example sequences. Item ``idx`` is a dict with the ``idx``-th entry
    of every feature, converted with ``torch.tensor``. Works with a
    ``BatchEncoding`` or a plain ``dict``; it must contain ``input_ids``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Subscript rather than attribute access so plain dicts work as well.
        return len(self.encodings["input_ids"])

In [27]:
from transformers.models.bert.modeling_bert import *
from transformers import BertModel
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchcrf import CRF

class Mymodel(BertPreTrainedModel):
    """BERT encoder + linear emission layer + CRF for token tagging.

    In ``forward``, positions where ``input_token_pos`` is 0 are dropped from
    the BERT output. The remaining hidden states are re-padded per batch,
    projected to ``num_labels`` emission scores and passed to a CRF.

    Args:
        config: BERT config; ``config.hidden_size`` sizes the classifier.
        lstm_hidden_size: hidden size of the optional BiLSTM head. The head is
            commented out, so this is currently unused.
        lstm_dropout_prob: dropout of the optional BiLSTM head; currently unused.
        num_labels: number of CRF tags.
    """

    def __init__(self, config, lstm_hidden_size, lstm_dropout_prob, num_labels):
        super().__init__(config)
        self.num_labels = num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Size from the config instead of a hard-coded 768 so other BERT widths work.
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        # lstm
        # self.bilstm = nn.LSTM(config.hidden_size, lstm_hidden_size, dropout=lstm_dropout_prob, batch_first=True, bidirectional=True)
        # self.classifier = nn.Linear(lstm_hidden_size*2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

        self.init_weights()

    def forward(self, input_ids, input_token_pos, token_type_ids, attention_mask, labels=None, labels_mask=None):
        """Run BERT + CRF. Store Viterbi tags in ``self.predict`` and return the loss when labelled.

        Args:
            input_ids, token_type_ids, attention_mask: BERT inputs
                (presumably (batch, seq_len) — TODO confirm).
            input_token_pos: per-token 0/1 flags aligned with ``input_ids``;
                only positions flagged 1 are kept.
            labels: optional per-token tags aligned with ``input_ids``.
            labels_mask: optional per-token 0/1 mask aligned with ``input_ids``.
                Kept positions with 0 are masked out of the CRF likelihood.

        Returns:
            The negated CRF log-likelihood (torchcrf's default ``sum`` reduction
            over the batch) when ``labels`` is given, otherwise ``None``.
            ``self.predict`` always receives one tag list per example, as long
            as that example's number of kept tokens.
        """
        outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # Indices of the kept tokens, computed once and reused for labels/masks.
        keep_idx = [starts.nonzero().squeeze(1) for starts in input_token_pos]

        origin_sequence_output = [layer[idx] for layer, idx in zip(sequence_output, keep_idx)]
        padded_sequence_output = pad_sequence(origin_sequence_output, batch_first=True)
        padded_sequence_output = self.dropout(padded_sequence_output)
        # lstm
        # lstm_output, _ = self.bilstm(padded_sequence_output)

        output = self.classifier(padded_sequence_output)
        # lstm
        # output = self.classifier(lstm_output)

        # pad_sequence appends rows to shorter examples. Mask them out so the
        # Viterbi path of one example does not depend on the batch's longest one.
        lengths = torch.tensor([len(idx) for idx in keep_idx], device=output.device)
        positions = torch.arange(output.size(1), device=output.device)
        seq_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
        self.predict = self.crf.decode(output, mask=seq_mask)

        if labels is not None:
            origin_label = [layer[idx] for layer, idx in zip(labels, keep_idx)]
            padded_label = pad_sequence(origin_label, batch_first=True)
            origin_loss_mask = [layer[idx] for layer, idx in zip(labels_mask, keep_idx)]
            padded_loss_mask = pad_sequence(origin_loss_mask, batch_first=True)
            # NOTE(review): torchcrf appears to assume each mask row is a contiguous
            # prefix of ones (it derives sequence ends from mask.sum()). labels_mask
            # zeroes sub-tokens inside spans, leaving holes — confirm this is intended.
            loss_mask = padded_loss_mask.gt(0)
            loss = self.crf(output, padded_label, loss_mask) * (-1)

            return loss

In [28]:
def run(batch):
    """Move one labelled batch to ``device`` and return the model's loss.

    Reads the module-level ``model`` and ``device``. ``batch`` must contain
    ``input_ids``, ``attention_mask``, ``token_type_ids``, ``input_token_pos``,
    ``labels`` and ``labels_mask`` tensors.
    """
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    input_token_pos = batch['input_token_pos'].to(device)
    labels = batch['labels'].to(device)
    labels_mask = batch['labels_mask'].to(device)

    # Call the module itself (not .forward) so nn.Module hooks are honoured.
    loss = model(input_ids, input_token_pos, token_type_ids, attention_mask, labels, labels_mask)

    return loss

In [29]:
def train(train_loader, epochs=4, log_every=500):
    """Fine-tune the module-level ``model`` with ``optim`` on ``train_loader``.

    Every ``log_every`` batches, the mean loss over exactly that window is
    printed and the weights are saved to ``bertNER_lstm_epoch{N}``. The same
    file is written again at the end of each epoch.

    Args:
        train_loader: iterable of batches accepted by ``run``.
        epochs: number of passes over the data (default 4).
        log_every: logging / checkpoint interval in batches (default 500).
    """
    for e in range(epochs):
        model.train()
        running_loss = 0.0
        model_file = f"bertNER_lstm_epoch{e+1}"

        loop = tqdm(train_loader, leave=True)
        loop.set_description(f'Epoch {e+1}')
        # Count from 1 so each logged window averages exactly `log_every` batches.
        for batch_id, batch in enumerate(loop, start=1):
            optim.zero_grad()

            loss = run(batch)
            loss.backward()
            optim.step()

            loss_value = loss.item()
            running_loss += loss_value
            if batch_id % log_every == 0:
                torch.save(model.state_dict(), model_file)
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    e+1, batch_id, running_loss / log_every))
                running_loss = 0.0

            loop.set_postfix(loss=loss_value)
        torch.save(model.state_dict(), model_file)

In [30]:
def pred(loader):
    """Decode predicted q / r span text for every example in ``loader``.

    Uses the module-level ``model``, ``device`` and ``tokenizer``. For each
    example, the kept tokens (``input_token_pos`` == 1) that the CRF tagged 1
    are collected, skipping the tokenizer's unknown token. Tokens with
    ``token_type_ids`` 0 form the q prediction and the rest form the r
    prediction. Each group is turned back into text with ``tokenizer.decode``.

    Returns:
        (predict_q, predict_r): two lists of strings, one entry per example,
        in loader order.
    """
    # [UNK] pieces (id 100 in BERT's vocab) are dropped from the predictions.
    unk_id = tokenizer.unk_token_id
    predict_q = []
    predict_r = []

    model.eval()

    loop = tqdm(loader, leave=True)
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for batch in loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            input_token_pos = batch['input_token_pos'].to(device)

            model(input_ids=input_ids, input_token_pos=input_token_pos,
                  attention_mask=attention_mask, token_type_ids=token_type_ids)
            predict_label = model.predict

            for ids, types, starts, tags in zip(input_ids, token_type_ids, input_token_pos, predict_label):
                keep = starts.nonzero().squeeze(1)
                # One host transfer per example instead of a GPU sync per token.
                kept_ids = ids[keep].tolist()
                kept_types = types[keep].tolist()
                q_ids = []
                r_ids = []
                for tok_id, tok_type, tag in zip(kept_ids, kept_types, tags):
                    if tag != 1 or tok_id == unk_id:
                        continue
                    if tok_type == 0:
                        q_ids.append(tok_id)
                    else:
                        r_ids.append(tok_id)
                predict_q.append(tokenizer.decode(q_ids))
                predict_r.append(tokenizer.decode(r_ids))

    return predict_q, predict_r

In [None]:
if __name__ == "__main__":
    # NOTE(review): pickle.load can execute arbitrary code — only load
    # data_relabel.pkl from a trusted source.
    with open("data_relabel.pkl", 'rb') as file:
        df_train = pickle.load(file)

    train_q = df_train["q"].tolist()
    train_r = df_train["r"].tolist()
    train_q_label_idx = df_train["q_label"].tolist()
    train_r_label_idx = df_train["r_label"].tolist()
    train_q_reidx = df_train["q_reidx"].tolist()
    train_r_reidx = df_train["r_reidx"].tolist()

    tokenizer = BertTokenizerFast.from_pretrained('dslim/bert-base-NER')
    train_encoding = tokenizer(train_q, train_r, truncation=True, padding=True)
    train_encoding["labels"], train_encoding["labels_mask"] = add_label_token(
        train_encoding, train_q_label_idx, train_r_label_idx, train_q_reidx, train_r_reidx)
    train_encoding["input_token_pos"] = get_input_pos(train_encoding)

    batch_size = 4
    lstm_hidden_size = 128
    lstm_dropout_prob = 0
    num_labels = 2

    train_dataset = MyDataset(train_encoding)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

    model = Mymodel.from_pretrained('dslim/bert-base-NER', lstm_hidden_size, lstm_dropout_prob, num_labels, ignore_mismatched_sizes=True).to(device)
    # transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
    # weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0)
    # to keep the previous optimiser configuration.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)

    train(train_loader)

In [32]:
def data_postprocess(df_test, df_answer):
    """Clean decoded predictions in ``df_answer`` for submission.

    ``df_answer`` is modified in place and also returned. For each of the
    ``q`` and ``r`` columns:

    1. WordPiece markers are removed: first every ``' ##'``, then any bare ``'##'``.
    2. Empty predictions are replaced by ``df_test``'s value at the same row label.
    3. Surrounding double quotes are stripped, then the value is wrapped in
       exactly one pair of double quotes.

    Raises:
        AssertionError: if the two frames have different lengths.
    """
    assert len(df_test) == len(df_answer), 'length not match'

    for col in ('q', 'r'):
        df_answer[col] = df_answer[col].apply(
            lambda text: text.replace(' ##', '').replace('##', ''))

        is_blank = (df_answer[col].apply(len) == 0).to_numpy()
        for row_label in df_answer.index[is_blank]:
            df_answer.loc[row_label, col] = df_test.loc[row_label, col]

        unquoted = df_answer[col].str.strip('"')
        df_answer[col] = '"' + unquoted + '"'

    return df_answer

In [33]:
def submit():
    """Predict spans for the unlabelled test CSV and write ``submission_ner.csv``.

    Relies on the module-level ``tokenizer`` and ``batch_size`` set in the
    training cell, and on the trained model used by ``pred``.
    """
    df_test = pd.read_csv("Batch_answers - test_data(no_label).csv", encoding = "utf-8")
    # The raw q / r text is quoted; strip the surrounding double quotes.
    for col in ('q', 'r'):
        df_test[col] = df_test[col].str.strip('"')

    encoding = tokenizer(df_test["q"].tolist(), df_test["r"].tolist(), truncation=True, padding=True)
    encoding["input_token_pos"] = get_input_pos(encoding)
    loader = DataLoader(MyDataset(encoding), batch_size, shuffle=False)

    predict_q, predict_r = pred(loader)
    df_answer = pd.DataFrame({'id': df_test['id'], 'q': predict_q, 'r': predict_r})

    data_postprocess(df_test, df_answer).to_csv('submission_ner.csv', index=False, encoding='utf-8')