In [1]:
import random
import os
import numpy as np
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModel
from tqdm import tqdm
torch.cuda.set_device(1)

In [2]:
BATCH_SIZE=8

In [3]:
def split_data(features):
    pos_features = []
    neg_features = []
    for feature in features:
        if feature['labels'][0] == 1:
            neg_features.append(feature)
        else:
            pos_features.append(feature)

    return pos_features, neg_features

In [4]:
def combine_shuffle(pos_data, neg_data):
    outputs = []
    for i in range(len(neg_data)):
        outputs.append(torch.cat([pos_data[i], neg_data[i]], 0))
    # shuffle
    rand_index = np.arange(outputs[0].shape[0])
    np.random.shuffle(rand_index)
    for i in range(len(outputs)):
        outputs[i] = outputs[i][rand_index]

    return outputs

In [5]:
def data_gen(data):
    all_input_ids = torch.tensor([f['input_ids'] for f in data], dtype=torch.long)
    all_input_mask = torch.tensor([f['input_mask'] for f in data], dtype=torch.long)
    all_segment_ids = torch.tensor([f['segment_ids'] for f in data], dtype=torch.long)
    all_labels = torch.tensor([f['labels'] for f in data], dtype=torch.float)

    return all_input_ids, all_input_mask, all_segment_ids, all_labels

In [6]:
def get_result(y_pred, y_true,
               acc_num, precision_num, recall_num,
               acc_num_ign, precision_num_ign, recall_num_ign, th=0.5):
    intrain = y_true['intrain']
    y_true = y_true['labels']
    y_pred = y_pred[1:]
    y_true = y_true[1:]
    y_pred[y_pred > th] = 1
    y_pred[y_pred < th] = 0

    y_add = y_pred + y_true
    y_add[y_add != 2] = 0
    y_add[y_add == 2] = 1

    recall_num += np.sum(y_true)
    precision_num += np.sum(y_pred)
    acc_num += np.sum(y_add)

    if not intrain:
        recall_num_ign += np.sum(y_true)
        precision_num_ign += np.sum(y_pred)
        acc_num_ign += np.sum(y_add)

    return acc_num, precision_num, recall_num, acc_num_ign, precision_num_ign, recall_num_ign

In [7]:
def evaluate(model, dev_features, device=torch.device("cuda")):
    eval_dataloader = DataLoader(dev_features, batch_size=BATCH_SIZE, collate_fn=data_gen,
                                 shuffle=False)

    model.eval()
    eval_index = 0
    acc_num = acc_num_ign = 0
    precision_num = precision_num_ign = 0
    recall_num = recall_num_ign = 0

    print("Start evaluating")
    with tqdm(total=len(eval_dataloader), desc='Evaluating') as pbar:
        for batch_data in eval_dataloader:
            input_ids, input_mask, segment_ids, _ = batch_data
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)

            with torch.no_grad():
                preds = model(input_ids, input_mask)

            preds = preds.cpu().numpy()
            for i in range(preds.shape[0]):
                acc_num, precision_num, recall_num, \
                acc_num_ign, precision_num_ign, recall_num_ign = get_result(preds[i], dev_features[eval_index],
                                                                            acc_num, precision_num, recall_num,
                                                                            acc_num_ign, precision_num_ign,
                                                                            recall_num_ign)
                eval_index += 1

            recall = acc_num / (recall_num + 1e-5)
            precision = acc_num / (precision_num + 1e-5)
            f1 = 2 * (recall * precision) / (recall + precision + 1e-5)

            recall_ign = acc_num_ign / (recall_num_ign + 1e-5)
            precision_ign = acc_num_ign / (precision_num_ign + 1e-5)
            f1_ign = 2 * (recall_ign * precision_ign) / (recall_ign + precision_ign + 1e-5)

            recall *= 100
            precision *= 100
            f1 *= 100
            recall_ign *= 100
            precision_ign *= 100
            f1_ign *= 100

            pbar.set_postfix({'Precision':'{:.5f}'.format(precision),'F1': '{:.5f}'.format(f1)})
            pbar.update(1)

    print('Precision:{:.3f}, Recall:{:.3f}, F1-score:{:.3f}'.format(precision, recall, f1))
    print('Precision_ignore:{:.3f}, Recall_ignore:{:.3f}, F1-score_ignore:{:.3f}'.format(precision_ign, recall_ign,
                                                                                         f1_ign))

In [8]:
class DLREModel(nn.Module):
    def __init__(self):
        super(DLREModel, self).__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.fc = nn.Linear(768, 97)
        
        
    def forward(self, input_ids, attention_mask):            
        input_ids=input_ids.cuda()
        attention_mask =attention_mask.cuda()
        bert_output = self.bert(input_ids, attention_mask=attention_mask)
        bert_cls_hidden_state = bert_output[0][:,0,:]       #提取[CLS]对应的隐藏状态
        output = self.fc(bert_cls_hidden_state)
        output = F.sigmoid(output)
        return output

In [9]:
# load data
train_feat_dir="data/train_annotated_cls_data.txt"
dev_feat_dir="data/dev_cls_data.txt"
print('loading data...')

train_features = []
with open(train_feat_dir, 'r') as f:
    for line in tqdm(f):
        train_features.append(json.loads(line))
dev_features = []
with open(dev_feat_dir, 'r') as f:
    for line in tqdm(f):
        dev_features.append(json.loads(line))

546it [00:00, 5455.66it/s]

loading data...


1189412it [05:18, 3733.74it/s]
392062it [01:53, 3469.22it/s]


In [10]:
pos_features, neg_features = split_data(train_features)

In [12]:
train_pos_dataloader = DataLoader(pos_features,
                                      batch_size=BATCH_SIZE//2,
                                      collate_fn=data_gen, shuffle=True)
neg_features = neg_features[:len(pos_features) * (len(neg_features) // len(pos_features))]  #使得neg是pos的整数倍
train_neg_dataloader = DataLoader(neg_features,
                                  batch_size=int(
                                      BATCH_SIZE//2 * (len(neg_features) // len(pos_features))),
                                  collate_fn=data_gen, shuffle=True)

In [13]:
print('start training...')
model = DLREModel()
optimizer = optim.Adam(model.parameters(), lr=1e-05)
becloss = nn.BCELoss()
model.train()
model.cuda()
becloss.cuda()

start training...


BCELoss()

In [14]:
gradient_accumulation_steps=8
for epoch in range(10):
    total_loss = 0
    for step, (pos_batch, neg_batch) in enumerate(zip(train_pos_dataloader, train_neg_dataloader)):
        pos_batch = tuple(t for t in pos_batch)
        neg_batch = tuple(t[0:BATCH_SIZE//2] for t in neg_batch)
        
        batch = combine_shuffle(pos_batch, neg_batch)

        input_ids, input_mask, segment_ids, relation_multi_label = batch
        input_ids = input_ids.cuda()
        input_mask = input_mask.cuda()
        relation_multi_label=relation_multi_label.cuda()
        
        outputs = model(input_ids, input_mask)
        
        if relation_multi_label.dtype != outputs.dtype:
            relation_multi_label = relation_multi_label.half()
        
        loss = becloss(outputs, relation_multi_label)
        
        total_loss += torch.mean(loss).item()
        
        if(gradient_accumulation_steps>0):
            loss = loss/gradient_accumulation_steps
        loss.backward()
        
        if((step+1)%gradient_accumulation_steps==0):
            optimizer.step()
            optimizer.zero_grad()
        
        
        
        if((step+1)%80==0):
            print('Epoch :{}[{}/{}({:.0f}%)]\t AVG-Loss:{:.6f}\t'.format(epoch,(step+1) * BATCH_SIZE,len(train_pos_dataloader.dataset)*2,100.0*step / len(train_pos_dataloader),total_loss/(step + 1)))
            total_loss = 0
    evaluate(model, dev_features)
    torch.save(model.state_dict(), "models/110checkpoint-"+str(epoch+1)+".pth")





Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.14it/s, Precision=0.00000, F1=0.00000]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=83.34746, F1=26.95232]


Precision:83.347, Recall:16.076, F1-score:26.952
Precision_ignore:83.280, Recall_ignore:14.672, F1-score_ignore:24.948


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.19it/s, Precision=99.99983, F1=54.54501]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=96.58332, F1=57.91088]


Precision:96.583, Recall:41.353, F1-score:57.911
Precision_ignore:96.538, Recall_ignore:39.377, F1-score_ignore:55.937


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.15it/s, Precision=99.99988, F1=66.66617]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=96.84889, F1=69.34211]


Precision:96.849, Recall:54.005, F1-score:69.342
Precision_ignore:96.818, Recall_ignore:52.067, F1-score_ignore:67.716


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.18it/s, Precision=99.99991, F1=81.48094]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:22<00:00, 10.06it/s, Precision=96.68055, F1=73.12043]


Precision:96.681, Recall:58.794, F1-score:73.120
Precision_ignore:96.462, Recall_ignore:57.369, F1-score_ignore:71.947


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.17it/s, Precision=99.99992, F1=89.65462]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.09it/s, Precision=96.49416, F1=78.47881]


Precision:96.494, Recall:66.133, F1-score:78.479
Precision_ignore:96.313, Recall_ignore:64.494, F1-score_ignore:77.255


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.17it/s, Precision=99.99993, F1=93.33277]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=95.38229, F1=78.83912]


Precision:95.382, Recall:67.187, F1-score:78.839
Precision_ignore:95.241, Recall_ignore:65.504, F1-score_ignore:77.622


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.20it/s, Precision=99.99993, F1=93.33277]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=95.66018, F1=79.43363]


Precision:95.660, Recall:67.914, F1-score:79.434
Precision_ignore:95.360, Recall_ignore:66.692, F1-score_ignore:78.490


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.18it/s, Precision=99.99993, F1=96.77363]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=95.23197, F1=78.38166]


Precision:95.232, Recall:66.599, F1-score:78.382
Precision_ignore:95.145, Recall_ignore:64.868, F1-score_ignore:77.142


Evaluating:   0%|          | 2/1430 [00:00<02:20, 10.17it/s, Precision=99.99994, F1=99.99944]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.09it/s, Precision=94.78281, F1=79.38962]


Precision:94.783, Recall:68.298, F1-score:79.390
Precision_ignore:94.470, Recall_ignore:66.776, F1-score_ignore:78.244


Evaluating:   0%|          | 2/1430 [00:00<02:19, 10.27it/s, Precision=99.99994, F1=99.99944]

Start evaluating


Evaluating: 100%|██████████| 1430/1430 [02:21<00:00, 10.08it/s, Precision=94.11505, F1=80.14968]


Precision:94.115, Recall:69.794, F1-score:80.150
Precision_ignore:93.916, Recall_ignore:67.842, F1-score_ignore:78.777
