In [1]:
## tool
import pickle
import tqdm
import numpy as np
import pandas as pd
import os
import time
import re
import textwrap
import matplotlib.pyplot as plt

## plot
import matplotlib.pyplot as plt
import seaborn as sns

from seqeval.metrics import f1_score, accuracy_score

## Bert
import transformers
from transformers import BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer, BertConfig

## module
from module.ner_trainer import NRE_Trainer

## torch
import torch
from torch.utils.data import DataLoader, random_split, Dataset

In [2]:
# Show which GPU is in use. torch.cuda.get_device_name raises when no CUDA
# device is present, so fall back to a plain message on CPU-only machines.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CUDA not available'

'GeForce RTX 2080 Ti'

## Preprocess

### Helper Function

In [3]:
def preprocess(text):
    """Encode ``text`` into token ids with the notebook-level ``tokenizer_chinese``.

    Returns whatever ``tokenizer_chinese.encode`` returns for the given text.
    """
    return tokenizer_chinese.encode(text)

In [4]:
def split_sentence(text):
    """Cut ``text`` into between one and five excerpts of at most 500 characters.

    The number of excerpts grows with the text length:

    * up to 500 characters   -> the whole text
    * up to 1000 characters  -> first 500 and last 500
    * up to 1500 characters  -> first, window around the midpoint, last
    * up to 2000 characters  -> first, windows around 25% and 75%, last
    * longer                 -> first, windows around 25%, 50% and 75%, last

    Excerpts may overlap (short texts) or leave gaps (long texts); this is a
    sampling of the article, not a partition.

    Args:
        text: the sequence (normally a ``str``) to cut.

    Returns:
        A list of slices of ``text`` in document order.
    """
    window = 500
    half = window // 2
    n = len(text)

    # Bug fix: the original tested ``< 500`` and ``> 500``, so a text of exactly
    # 500 characters matched neither and dropped into the "very long" branch,
    # whose slices then used negative start offsets and came back empty/wrong.
    if n <= window:
        return [text]

    mid = int(n / 2)
    point_1 = int(n * 0.25)
    point_2 = int(n * 0.75)

    # Centres of the interior windows, chosen by length bracket.
    if n <= 2 * window:
        centres = []
    elif n <= 3 * window:
        centres = [mid]
    elif n <= 4 * window:
        centres = [point_1, point_2]
    else:
        centres = [point_1, mid, point_2]

    interior = [text[c - half:c + half] for c in centres]
    return [text[:window]] + interior + [text[-window:]]

In [5]:
def pad_to_len(seqs, to_len, padding=0):
    """Force every list in ``seqs`` to exactly ``to_len`` items.

    Longer lists are cut at ``to_len``; shorter ones are extended on the right
    with ``padding``. The inputs are not modified; new lists are returned.
    """
    return [
        seq[:to_len] + [padding] * max(to_len - len(seq), 0)
        for seq in seqs
    ]

In [6]:
class Bert_dataset(Dataset):
    """Dataset over a list of sample dicts, plus the collate function that batches them.

    Each element of ``data`` must provide the keys ``'ID'``, ``'token_id'``,
    ``'label'`` and ``'mask'``.
    """

    # Every padded field is cut/extended to this fixed length (the value the
    # original code hard-coded; presumably BERT's maximum sequence length).
    MAX_LEN = 512
    # Fill value for the 'tag' field; every other field is filled with 0.
    TAG_PAD = 2

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return sample ``index`` re-keyed as id / words / tag / segment / mask."""
        sample = self.data[index]
        return {
            'id': sample['ID'],
            'words': sample['token_id'],
            'tag': sample['label'],
            # Single-segment input: one 0 per token.
            'segment': [0] * len(sample['token_id']),
            'mask': sample['mask'],
        }

    def collate_fn(self, samples):
        """Assemble a list of items from ``__getitem__`` into one batch dict.

        ``'id'`` is kept as a plain list. ``'words'``, ``'tag'``, ``'segment'``
        and ``'mask'`` are padded/truncated to ``MAX_LEN`` and returned as
        ``torch.long`` tensors. A key missing from any sample is left out of
        the batch.

        The constants are read through the class name rather than ``self``
        because this method does not otherwise touch instance state and the
        notebook also invokes it with a non-``Bert_dataset`` object as ``self``.
        The unused per-batch ``max(...)`` of the original was removed; it had
        no effect except raising on an empty ``samples`` list.
        """
        batch = {}

        if all('id' in sample for sample in samples):
            batch['id'] = [sample['id'] for sample in samples]

        for key in ('words', 'tag', 'segment', 'mask'):
            if any(key not in sample for sample in samples):
                continue
            fill = Bert_dataset.TAG_PAD if key == 'tag' else 0
            padded = pad_to_len(
                [sample[key] for sample in samples], Bert_dataset.MAX_LEN, fill
            )
            batch[key] = torch.tensor(padded).long()

        return batch

In [7]:
def str2list(s):
    """Parse the textual form of a list of quoted strings, e.g. ``"['a', 'b']"``.

    The outer brackets are dropped, the remainder is split on commas, and each
    piece has all spaces removed and its first and last character (the quotes)
    stripped. The literal ``'[]'`` maps to an empty list.
    """
    if s == '[]':
        return []

    items = []
    for piece in s[1:-1].split(','):
        compact = piece.replace(' ', '')
        items.append(str(compact[1:-1]))
    return items

In [8]:
def get_index(name_index, sentence_index):
    """Locate every occurrence of ``name_index`` inside ``sentence_index``.

    Args:
        name_index: the sub-sequence to look for (e.g. token ids of a name).
        sentence_index: the sequence to search in.

    Returns:
        A flat list with the positions in ``sentence_index`` covered by each
        match, in ascending order. Matches are non-overlapping: after a hit the
        scan resumes right behind it. An empty ``name_index`` yields ``[]``.

    Fixes over the previous implementation:
      * a partial match that reached the end of ``sentence_index`` indexed
        past the last element and raised ``IndexError``;
      * after a failed partial match the scan resumed *behind* the tokens it
        had already consumed, so a real match starting inside them was missed
        (``[1, 2]`` in ``[1, 1, 2]`` returned nothing).
    """
    n_name = len(name_index)
    if n_name == 0:
        return []

    positions = []
    start = 0
    # Last start position at which a full-length match can still fit.
    last_start = len(sentence_index) - n_name

    while start <= last_start:
        matched = all(
            name_index[k] == sentence_index[start + k] for k in range(n_name)
        )
        if matched:
            positions.extend(range(start, start + n_name))
            start += n_name  # keep matches non-overlapping
        else:
            start += 1       # retry from the very next position

    return positions

### Load Data

In [9]:
# Raw training table. The full frame is kept because later cells still read
# its 'article' and 'name' columns; train_df is the working copy without the
# columns listed below.
train_csv = pd.read_csv('data/train.csv')
unused_columns = ['hyperlink', 'content', 'domain', 'name']
train_df = train_csv.drop(columns=unused_columns)

# Load the pretrained Chinese BERT tokenizer (lower-casing disabled)
tokenizer_chinese = BertTokenizer.from_pretrained("bert-base-chinese", do_lower_case=False)

### Preprocess

In [10]:
%%time
# Per-article derived columns: character count, excerpts, token ids per
# excerpt, and the two parsed name lists.
articles = train_csv['article']
train_df['length']    = articles.apply(len)
train_df['sentences'] = articles.apply(split_sentence)
train_df['token_ids'] = train_df['sentences'].apply(lambda chunks: [preprocess(chunk) for chunk in chunks])
train_df['blacklist'] = train_csv['name'].apply(str2list)
train_df['names']     = pd.read_csv('data/name.csv')['0'].apply(str2list)

CPU times: user 14.7 s, sys: 46.4 ms, total: 14.8 s
Wall time: 14.8 s


In [11]:
## Drop unusable rows: the "article deleted" placeholder text and articles
## whose character count falls outside 75..2551 (both bounds inclusive)
is_real_article = train_df['article'] != '文章已被刪除 404 or 例外'
has_valid_length = train_df['length'].between(75, 2551)
train_df = train_df[is_real_article & has_valid_length]

In [12]:
# Renumber the rows 0..n-1 after filtering, then display how many remain.
train_df = train_df.reset_index(drop=True)
train_df.shape[0]

4590

In [13]:
# answer_json_type: news id -> ground-truth name lists (first row per id wins).
# data_collection:  one sample dict per (article, excerpt) pair.
answer_json_type = {}
data_collection  = []

for row in tqdm.tqdm(train_df.values):
    # Positional unpack: relies on the column order of train_df.
    news_ID, article, length, sentences, token_ids, blacklist, names = row

    # Tokenise every name once per article instead of once per excerpt.
    # encode() wraps the ids in one leading and one trailing special token,
    # which [1:-1] strips so only the name's own ids are searched for.
    blacklist_ids = [tokenizer_chinese.encode(black_name)[1:-1] for black_name in blacklist]
    name_ids_list = [tokenizer_chinese.encode(name)[1:-1] for name in names]

    for token_id, sentence in zip(token_ids, sentences):
        # Per-token 0/1 vectors, started as float zeros (same values the
        # previous torch.zeros(...).tolist() produced).
        one_hot  = [0.0] * len(token_id)   # 1 where a blacklist name occurs
        one_hot2 = [0.0] * len(token_id)   # 1 where any listed name occurs

        for black_name_ids in blacklist_ids:
            for i in get_index(black_name_ids, token_id):
                one_hot[i] = 1

        for name_ids in name_ids_list:
            for i in get_index(name_ids, token_id):
                one_hot2[i] = 1

        data_collection.append({
            'ID': news_ID,
            'original_article': article,
            'length': length,
            'sentence': sentence,
            'token_id': token_id,
            'blacklist': blacklist,
            'names': names,
            'label': one_hot,
            'mask': one_hot2,
        })

    if news_ID not in answer_json_type:
        answer_json_type[news_ID] = {'names': names, 'blacklist': blacklist}

100%|██████████| 4590/4590 [00:03<00:00, 1165.29it/s]


In [14]:
# Fixed seed so both random splits below are reproducible.
torch.manual_seed(0)

bert_dataset = Bert_dataset(data_collection)

# 80% for training; the remaining 20% is halved into validation and test.
n_total = len(bert_dataset)
n_train_split = int(n_total * 0.8)
train_dataset, holdout_dataset = random_split(
    bert_dataset, [n_train_split, n_total - n_train_split]
)

n_holdout = len(holdout_dataset)
n_valid_split = int(n_holdout * 0.5)
valid_dataset, test_dataset = random_split(
    holdout_dataset, [n_valid_split, n_holdout - n_valid_split]
)

## Hyperparameter
n_train = len(train_dataset)
n_valid = len(valid_dataset)
n_test  = len(test_dataset)
BATCH_SIZE = 8

In [15]:
# Number of training samples containing at least one positive (1) tag, next to the split size.
print(sum(1 for item in train_dataset if 1 in item['tag']), n_train)

242 3672


In [16]:
# Number of validation samples containing at least one positive (1) tag, next to the split size.
print(sum(1 for item in valid_dataset if 1 in item['tag']), n_valid)

27 459


In [17]:
# Number of test samples containing at least one positive (1) tag, next to the split size.
print(sum(1 for item in test_dataset if 1 in item['tag']), n_test)

24 459


In [18]:
# All three loaders batch with Bert_dataset.collate_fn. Use the bound method of
# the underlying dataset rather than a lambda that hands a random_split subset
# to the unbound method as `self`: the subsets are not Bert_dataset instances,
# so that form only works while collate_fn never touches `self`, and lambdas
# cannot be pickled should worker processes be enabled later.
collate = bert_dataset.collate_fn

train_loader = DataLoader(
    dataset = train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    collate_fn = collate
)

valid_loader = DataLoader(
    dataset = valid_dataset,
    batch_size = BATCH_SIZE,
    collate_fn = collate
)

test_loader = DataLoader(
    dataset = test_dataset,
    batch_size = BATCH_SIZE,
    collate_fn = collate
)

## Training

In [19]:
# Token-classification model on top of bert-base-chinese with three output
# labels; attention maps and hidden states are not returned.
model = BertForTokenClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=3,
    output_attentions=False,
    output_hidden_states=False,
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

In [20]:
# Wrap model and loaders in the project trainer, then register the label
# vocabulary in both directions (name -> index and index -> name).
trainer = NRE_Trainer(model, train_loader, valid_loader)
trainer.tag2idx = {
    'O': 0,
    'blacklist': 1,
    'PAD': 2,
}
trainer.tag_values = [
    'O',
    'blacklist',
    'PAD',
]

device:cuda


In [21]:
# Restore previously saved weights. map_location remaps the stored tensors onto
# the trainer's device, so a checkpoint written on a GPU also loads on a
# machine without one.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
trainer.model.load_state_dict(
    torch.load('params/best-model.pth', map_location=trainer.device)
)

<All keys matched successfully>

In [22]:
# Evaluate the restored weights via the trainer (test=False) and report the
# two values it returns.
valid_acc, valid_loss = trainer.evaluation(test=False)
print(
    f"device: {trainer.device} classification acc: {valid_acc: .4f} validation loss: {valid_loss:.4f}"
)

100%|██████████| 58/58 [00:05<00:00, 10.28it/s]

device: cuda classification acc:  0.9603 validation loss: 0.5652





In [None]:
# Run the project trainer's training loop for at most 50 epochs at a learning
# rate of 3e-5, with early stopping and parameter saving switched on.
# NOTE(review): the exact meaning of n_iter_no_change (presumably the number of
# epochs without validation improvement tolerated before stopping) and what is
# written to save_paths are defined in module.ner_trainer — confirm there.
trainer.training_process(early_stopping = True, 
                         n_iter_no_change = 5, 
                         max_epoch = 50, 
                         save_params = True, 
                         verbose = True, 
                         learning_rate = 3e-5, 
                         save_paths='model-best-fine-tune.pth')

100%|██████████| 459/459 [02:12<00:00,  3.47it/s]
100%|██████████| 58/58 [00:06<00:00,  9.45it/s]


 time: 2m 20s | epoch:0 train loss: 0.3837 | validation loss: 0.3709 acc: 0.9974

100%|██████████| 459/459 [02:22<00:00,  3.23it/s]
100%|██████████| 58/58 [00:05<00:00, 10.62it/s]


 time: 4m 49s | epoch:1 train loss: 0.3444 | validation loss: 0.3560 acc: 0.9974

100%|██████████| 459/459 [02:24<00:00,  3.19it/s]
100%|██████████| 58/58 [00:05<00:00,  9.99it/s]


 time: 7m 20s | epoch:2 train loss: 0.3479 | validation loss: 0.3541 acc: 0.9974

100%|██████████| 459/459 [02:16<00:00,  3.36it/s]
100%|██████████| 58/58 [00:05<00:00, 10.65it/s]


 time: 9m 44s | epoch:3 train loss: 0.3420 | validation loss: 0.3850 acc: 0.9974 -- 1

100%|██████████| 459/459 [02:09<00:00,  3.53it/s]
100%|██████████| 58/58 [00:05<00:00, 10.67it/s]


 time: 12m 0s | epoch:4 train loss: 0.3655 | validation loss: 0.3300 acc: 0.9974

100%|██████████| 459/459 [02:09<00:00,  3.53it/s]
100%|██████████| 58/58 [00:05<00:00, 10.67it/s]


 time: 14m 17s | epoch:5 train loss: 0.3475 | validation loss: 0.3748 acc: 0.9974 -- 1

100%|██████████| 459/459 [02:09<00:00,  3.53it/s]
100%|██████████| 58/58 [00:05<00:00, 10.66it/s]


 time: 16m 34s | epoch:6 train loss: 0.3536 | validation loss: 0.3171 acc: 0.9974

100%|██████████| 459/459 [02:09<00:00,  3.53it/s]
100%|██████████| 58/58 [00:05<00:00, 10.66it/s]


 time: 18m 51s | epoch:7 train loss: 0.3545 | validation loss: 0.3201 acc: 0.9974 -- 1

100%|██████████| 459/459 [02:14<00:00,  3.40it/s]
100%|██████████| 58/58 [00:06<00:00,  9.47it/s]


 time: 21m 13s | epoch:8 train loss: 0.3528 | validation loss: 0.3480 acc: 0.9974 -- 2

100%|██████████| 459/459 [02:26<00:00,  3.14it/s]
100%|██████████| 58/58 [00:06<00:00,  9.47it/s]


 time: 23m 46s | epoch:9 train loss: 0.3615 | validation loss: 0.3254 acc: 0.9974 -- 3

100%|██████████| 459/459 [02:26<00:00,  3.14it/s]
100%|██████████| 58/58 [00:06<00:00,  9.47it/s]


 time: 26m 19s | epoch:10 train loss: 0.3557 | validation loss: 0.3282 acc: 0.9974 -- 4

 81%|████████▏ | 374/459 [01:59<00:27,  3.14it/s]

In [None]:
# trainer.model.load_state_dict(torch.load('model-best-fine-tune.pth'))

## Predict

In [25]:
raw_answer = {}
# answer: news id -> list of distinct predicted name strings, merged over all
# excerpts of that article that ended up in the test split.
answer = {}

# Switch off dropout for inference; nothing earlier in this cell guarantees the
# model is in eval mode (e.g. after an interrupted training run).
model.eval()

# NOTE(review): only the token ids are fed to the model (no attention mask), so
# padding positions are attended to — confirm this matches how NRE_Trainer
# feeds the model during training.
with torch.no_grad():
    for batch in tqdm.tqdm(test_loader):
        b_input_token = batch['words'].to(trainer.device)
        output = model(b_input_token)
        # output[0] holds the per-token class scores; take the best class per token.
        label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)

        for i, input_token in enumerate(b_input_token.to('cpu').numpy()):

            words  = tokenizer_chinese.convert_ids_to_tokens(input_token)
            labels = label_indices[i]
            ID     = batch['id'][i]

            # Glue consecutive tokens labelled 1 ('blacklist') into one name.
            names = []
            name_string = ''

            for word, label in zip(words, labels):
                if label == 1:
                    name_string += word
                elif name_string != '':
                    names.append(name_string)
                    name_string = ''

            # Flush a name that runs up to the very last token; previously it
            # was silently dropped.
            if name_string != '':
                names.append(name_string)

            answer.setdefault(ID, []).extend(names)


# De-duplicate per article while keeping first-seen order, so the result is
# deterministic (list(set(...)) ordered the names arbitrarily).
for key in answer:
    answer[key] = list(dict.fromkeys(answer[key]))

100%|██████████| 58/58 [00:05<00:00, 10.36it/s]


In [26]:
def score(truth, predict):
    """F1 score between two collections of names.

    Args:
        truth: the reference names.
        predict: the predicted names.

    Returns:
        ``1`` when both collections are empty, ``0`` when exactly one of them
        is empty or when they share no element, otherwise the harmonic mean of
        recall (share of ``truth`` found in ``predict``) and precision (share
        of ``predict`` found in ``truth``).
    """
    # Emptiness is checked by truthiness so any empty collection (not only a
    # literal list) is handled here instead of dividing by zero below.
    if not truth or not predict:
        return 1 if (not truth and not predict) else 0

    recall = len([i for i in truth if i in predict]) / len(truth)
    precision = len([i for i in predict if i in truth]) / len(predict)

    # No overlap: the harmonic mean is undefined, score it as 0. This replaces
    # a bare ``except:`` that relied on ZeroDivisionError for the same case and
    # would also have hidden any unrelated error.
    if recall == 0 or precision == 0:
        return 0

    return 2 / ((1 / recall) + (1 / precision))

In [28]:
# Score every predicted article against its ground-truth blacklist, print one
# line per article, then the summed and the mean score.
total_score = 0

for ID, predicted_names in answer.items():
    blacklist = answer_json_type[ID]['blacklist']
    article_score = score(predicted_names, blacklist)
    total_score += article_score

    print(ID, 'Answer:', predicted_names,
          '黑名單:', blacklist,
          'Score:', article_score)

print(total_score)
print(total_score/len(answer))

4858 Answer: [] 黑名單: [] Score: 1
1177 Answer: [] 黑名單: [] Score: 1
153 Answer: [] 黑名單: [] Score: 1
1442 Answer: [] 黑名單: [] Score: 1
2901 Answer: [] 黑名單: [] Score: 1
3565 Answer: [] 黑名單: ['穆曉光'] Score: 0
1134 Answer: [] 黑名單: [] Score: 1
3964 Answer: [] 黑名單: [] Score: 1
2671 Answer: [] 黑名單: [] Score: 1
670 Answer: [] 黑名單: [] Score: 1
1818 Answer: [] 黑名單: [] Score: 1
4426 Answer: [] 黑名單: [] Score: 1
4864 Answer: [] 黑名單: [] Score: 1
4063 Answer: [] 黑名單: [] Score: 1
4139 Answer: [] 黑名單: [] Score: 1
1684 Answer: [] 黑名單: [] Score: 1
1910 Answer: [] 黑名單: [] Score: 1
4081 Answer: [] 黑名單: [] Score: 1
1130 Answer: [] 黑名單: [] Score: 1
2061 Answer: [] 黑名單: [] Score: 1
1973 Answer: [] 黑名單: [] Score: 1
4831 Answer: [] 黑名單: [] Score: 1
2488 Answer: [] 黑名單: [] Score: 1
1639 Answer: [] 黑名單: [] Score: 1
4028 Answer: [] 黑名單: [] Score: 1
1444 Answer: [] 黑名單: ['黃世陽', '黃顯雄'] Score: 0
3744 Answer: [] 黑名單: [] Score: 1
3499 Answer: [] 黑名單: [] Score: 1
4150 Answer: [] 黑名單: [] Score: 1
2651 Answer: [] 黑名單: ['李保承',