In [1]:
## tool
import pickle
import tqdm
import numpy as np
import pandas as pd
import os
import time
import re

import textwrap

## plot
import matplotlib.pyplot as plt
import seaborn as sns

from seqeval.metrics import f1_score, accuracy_score

## Bert
import transformers
from transformers import BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer, BertConfig

## module
from module.ner_trainer import NRE_Trainer

## torch
import torch
from torch.utils.data import DataLoader, random_split, Dataset

In [2]:
# Show which GPU this kernel sees. get_device_name() raises when no CUDA
# device is present, so fall back to a plain message on CPU-only machines.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu (no CUDA device available)'

'GeForce RTX 2080 Ti'

## Load Model

In [3]:
# Token-classification BERT with a 3-way label head; attentions and hidden
# states are switched off because only the logits are consumed below.
model = BertForTokenClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=3,
    output_attentions=False,
    output_hidden_states=False,
)

# NOTE(review): the two None arguments are presumably the train/validation
# loaders, which are not needed for inference -- confirm in module.ner_trainer.
NRE_trainer = NRE_Trainer(model, None, None)

# map_location remaps the stored tensors onto the trainer's device, so a
# checkpoint saved on one GPU still loads on a CPU-only or differently
# numbered machine instead of failing during deserialisation.
checkpoint_state = torch.load('params/best-model.pth', map_location=NRE_trainer.device)
NRE_trainer.model.load_state_dict(checkpoint_state)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

device:cuda


<All keys matched successfully>

## 預測名字

### Helper Functions

In [4]:
def preprocess(text):
    """Encode ``text`` into token ids.

    Delegates to the module-level ``tokenizer_chinese`` and returns the
    result of its ``encode`` call unchanged.
    """
    return tokenizer_chinese.encode(text)

In [5]:
def split_sentence(text, window=500):
    """Sample up to five fixed-size windows from ``text``.

    The article is not split exhaustively: a head window, a tail window and,
    depending on the length, up to three windows centred on the 25%, 50% and
    75% positions are taken. Windows may overlap and parts of long texts may
    not be covered at all.

    Args:
        text: The string to sample from.
        window: Window size in characters (default 500).

    Returns:
        A list of substrings, each at most ``window`` characters long:
          * ``len(text) <= window``        -> ``[text]``
          * ``<= 2 * window``              -> head, tail
          * ``<= 3 * window``              -> head, middle, tail
          * ``<= 4 * window``              -> head, 25%, 75%, tail
          * longer                         -> head, 25%, 50%, 75%, tail

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be a positive integer")

    n = len(text)

    # `<=` matters: a text of exactly `window` characters used to match
    # neither `< 500` nor `> 500` and fell through to the five-window branch,
    # producing duplicated and even empty slices.
    if n <= window:
        return [text]

    # Relative positions of the interior window centres for each length band.
    if n <= 2 * window:
        fractions = []
    elif n <= 3 * window:
        fractions = [0.5]
    elif n <= 4 * window:
        fractions = [0.25, 0.75]
    else:
        fractions = [0.25, 0.5, 0.75]

    half = window // 2
    sentences = [text[:window]]
    for fraction in fractions:
        # The centre is always > half in these bands, so start is never negative.
        start = int(n * fraction) - half
        sentences.append(text[start:start + window])
    sentences.append(text[-window:])

    return sentences

In [6]:
# def split_sentence(text):
#     sentences = re.findall(u'[^!?。\.\!\?]+[!?。\.\!\?]?', text, flags=re.U)
#     sentences_under512 = []
    
#     for sentence in sentences:
        
#         l = len(sentence)
        
#         if l > 512:
#             num = int(l/512)+1
#             sentences_under512 += textwrap.wrap(sentence, int(l/num))
#         else:
#             sentences_under512.append(sentence)
            
#     return sentences_under512

In [7]:
def pad_to_len(seqs, to_len, padding=0):
    """Force every sequence in ``seqs`` to exactly ``to_len`` items.

    Longer sequences are cut after ``to_len`` items; shorter ones are
    right-filled with ``padding``. The inputs are lists and are not modified.

    Returns:
        A new list holding one fixed-length list per input sequence.
    """
    def _fit(seq):
        missing = to_len - len(seq)
        filler = [padding] * missing if missing > 0 else []
        return seq[:to_len] + filler

    return [_fit(seq) for seq in seqs]

In [8]:
class Bert_dataset(Dataset):
    """Dataset over a list of sample dicts carrying 'ID' and 'token_id' keys."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index`` re-keyed as ``{'id', 'token'}``."""
        record = self.data[index]
        return {
            'id': record['ID'],
            'token': record['token_id'],
        }

    def collate_fn(self, samples):
        """Merge a list of samples into one batch dict.

        'id' values are gathered into a plain list; 'token' sequences are
        padded/truncated to 512 items with 0 and stacked into a tensor.
        A key is left out of the batch when any sample lacks it.
        """
        batch = {}

        if all('id' in sample for sample in samples):
            batch['id'] = [sample['id'] for sample in samples]

        if all('token' in sample for sample in samples):
            token_rows = pad_to_len([sample['token'] for sample in samples], 512, 0)
            batch['token'] = torch.tensor(token_rows)

        return batch

In [9]:
def str2list(s):
    """Parse a stringified list such as ``"['a', 'b']"`` into ``['a', 'b']``.

    The outer brackets are dropped, the remainder is split on commas, and for
    each piece every space is removed before its first and last character
    (the quotes) are stripped. Spaces inside an item are therefore removed
    too. The literal ``'[]'`` yields an empty list.
    """
    if s == '[]':
        return []

    items = []
    for piece in s[1:-1].split(','):
        compact = piece.replace(' ', '')
        items.append(str(compact[1:-1]))
    return items

In [10]:
def find_all_indexes(input_str, search_str):
    """Return the start offsets of every occurrence of ``search_str``.

    Occurrences may overlap, because the scan resumes one character after
    each hit. Offsets are returned in ascending order.
    """
    positions = []
    total = len(input_str)
    start = 0

    while start < total:
        hit = input_str.find(search_str, start)
        if hit < 0:
            break
        positions.append(hit)
        start = hit + 1

    return positions

In [11]:
def get_index(name_index, sentence_index):
    """Locate every occurrence of ``name_index`` inside ``sentence_index``.

    Both arguments are indexable sequences (e.g. lists of token ids). The
    sentence is scanned left to right; whenever the full name sequence
    matches, the positions it covers are recorded and the scan continues
    right after the match, so reported matches never overlap.

    Args:
        name_index: The sequence to search for.
        sentence_index: The sequence to search in.

    Returns:
        A flat, ascending list of all positions in ``sentence_index`` that
        belong to a match. Empty when ``name_index`` is empty or not found.
    """
    l_name = len(name_index)
    if l_name == 0:
        return []

    l_sentence = len(sentence_index)
    arr = []
    j = 0

    # Only start positions that leave room for the whole name are tried;
    # the previous version indexed past the end on a partial match at the tail.
    while j + l_name <= l_sentence:
        matched = all(
            name_index[k] == sentence_index[j + k] for k in range(l_name)
        )
        if matched:
            arr.extend(range(j, j + l_name))
            j += l_name
        else:
            # Advance by exactly one so a match starting inside a failed
            # partial match (e.g. [1, 1, 2] in [1, 1, 1, 2]) is not skipped.
            j += 1

    return arr

### Load Data

In [12]:
# Raw training table; train_df is a working copy without the columns that
# are not carried along ('name' is still read from train_csv later on).
train_csv = pd.read_csv('data/train.csv')
train_df = train_csv.drop(columns=['hyperlink', 'content', 'domain', 'name'])

# Load the pretrained Chinese BERT tokenizer (lower-casing disabled)
tokenizer_chinese = BertTokenizer.from_pretrained("bert-base-chinese", do_lower_case=False)

### Preprocess

In [13]:
%%time
# Character count of each raw article
train_df['length'] = train_csv['article'].apply(len)

# Windows sampled from each article, then the token ids of every window
train_df['sentences'] = train_csv['article'].apply(split_sentence)
train_df['token_ids'] = train_df['sentences'].apply(
    lambda chunks: [preprocess(chunk) for chunk in chunks]
)

# Parse the stringified name list from the raw table
train_df['blacklist'] = train_csv['name'].apply(str2list)

CPU times: user 14.9 s, sys: 10.6 ms, total: 14.9 s
Wall time: 14.9 s


In [14]:
# Flatten the per-article frame into one sample dict per sampled window.
# Columns are selected by name so that the cell keeps working when train_df
# later gains extra columns (e.g. the predicted 'name' column) and the cell
# is re-run; unpacking every column positionally would then fail.
sample_columns = ['news_ID', 'article', 'length', 'sentences', 'token_ids']

data_collection = []

for news_ID, article, length, sentences, token_ids in tqdm.tqdm(train_df[sample_columns].values):
    # token_ids[k] holds the ids of sentences[k], so the two are walked in step
    for token_id, sentence in zip(token_ids, sentences):
        data_collection.append({
            'ID': news_ID,
            'original_article': article,
            'length': length,
            'sentence': sentence,
            'token_id': token_id,
        })

100%|██████████| 5023/5023 [00:00<00:00, 298172.71it/s]


In [15]:
bert_dataset = Bert_dataset(data_collection)

# Batches of 32 samples in dataset order; the dataset's own collate_fn
# assembles each batch.
train_loader = DataLoader(
    bert_dataset,
    batch_size=32,
    collate_fn=bert_dataset.collate_fn,
)

In [16]:
# Sanity check: number of samples in the dataset and number of batches in the loader
print(len(bert_dataset), len(train_loader))

11137 349


## Predict

In [18]:
def _collect_names(words, labels, name_label=1):
    """Join each run of consecutive tokens tagged ``name_label`` into one string."""
    names = []
    current = ''
    for word, label in zip(words, labels):
        if label == name_label:
            current += word
        elif current:
            names.append(current)
            current = ''
    # Flush a run that extends to the last token; without this a name that
    # ends exactly at the end of the sequence would be dropped.
    if current:
        names.append(current)
    return names


# Maps news ID -> names predicted across all windows of that article
answer = {}

# Run inference on the trainer's model handle, since that is the object the
# checkpoint was loaded into, and switch to eval mode so dropout is disabled
# and predictions are deterministic.
inference_model = NRE_trainer.model
inference_model.eval()

with torch.no_grad():
    for batch in tqdm.tqdm(train_loader):

        ## input
        b_input_token = batch['token'].to(NRE_trainer.device)

        ## model
        output = inference_model(b_input_token)

        ## predicted label per token: argmax over the last axis of output[0]
        label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)

        ## post-processing: turn the labelled tokens of every row into names
        token_rows = b_input_token.to('cpu').numpy()
        for i, input_token in enumerate(token_rows):
            words = tokenizer_chinese.convert_ids_to_tokens(input_token)
            names = _collect_names(words, label_indices[i])
            # Windows of one article share an ID, so their names accumulate
            answer.setdefault(batch['id'][i], []).extend(names)

# De-duplicate while keeping first-seen order; list(set(...)) would give a
# different ordering from run to run because of string hash randomisation.
for key in answer:
    answer[key] = list(dict.fromkeys(answer[key]))

100%|██████████| 349/349 [02:06<00:00,  2.75it/s]


In [19]:
# Attach the predicted names to each article by looking up its news ID
train_df['name'] = train_df['news_ID'].apply(lambda news_id: answer[news_id])

In [20]:
# Display the frame, which now carries the predicted `name` column
train_df

Unnamed: 0,news_ID,article,length,sentences,token_ids,blacklist,name
0,1,理財基金量化交易追求絕對報酬 有效對抗牛熊市鉅亨網記者 鄭心芸2019/07/05 22:3...,1723,[理財基金量化交易追求絕對報酬 有效對抗牛熊市鉅亨網記者 鄭心芸2019/07/05 22:...,"[[101, 4415, 6512, 1825, 7032, 7030, 1265, 769...",[],"[詹姆斯·西蒙斯, 張堯勇, 鄭心芸]"
1,2,10月13日晚間發生Uber Eats黃姓外送人員職災死亡案件，北市府勞動局認定業者未依職業...,726,[10月13日晚間發生Uber Eats黃姓外送人員職災死亡案件，北市府勞動局認定業者未依職...,"[[101, 8108, 3299, 8124, 3189, 3241, 7279, 463...",[],"[賴香伶, 康水順, 黃]"
2,3,社會2019.10.08 09:53【法拍有詭4】飯店遭管委會斷水斷電員工怒吼：生計何去何從...,1154,[社會2019.10.08 09:53【法拍有詭4】飯店遭管委會斷水斷電員工怒吼：生計何去何...,"[[101, 4852, 3298, 9160, 119, 8108, 119, 8142,...",[],"[張慶輝, 林, 李育材, 李日順]"
3,4,文章已被刪除 404 or 例外,16,[文章已被刪除 404 or 例外],"[[101, 3152, 4995, 2347, 6158, 1165, 7370, 105...",[],[]
4,5,例稿名稱：臺灣屏東地方法院公示催告公告發文日期：中華民國108年9月20日發文字號：屏院進家...,671,[例稿名稱：臺灣屏東地方法院公示催告公告發文日期：中華民國108年9月20日發文字號：屏院進...,"[[101, 891, 4943, 1399, 4935, 8038, 5637, 4124...",[],"[沈君融, 陳世恒]"
...,...,...,...,...,...,...,...
5018,5019,香港特首林鄭月娥4日宣布撤回逃犯條例修訂，示威者斥為「太遲太少」，「一碗水救不了森林大火」，...,529,[香港特首林鄭月娥4日宣布撤回逃犯條例修訂，示威者斥為「太遲太少」，「一碗水救不了森林大火」...,"[[101, 7676, 3949, 4294, 7674, 3360, 6972, 329...",[],[林鄭月娥]
5019,5020,台股台股盤勢【華冠投顧】OTC轉強 小樹走高華冠投顧※來源：華冠投顧2019/07/15 1...,489,[台股台股盤勢【華冠投顧】OTC轉強 小樹走高華冠投顧※來源：華冠投顧2019/07/15 ...,"[[101, 1378, 5500, 1378, 5500, 4676, 1248, 523...",[],"[謝宗霖, 川普]"
5020,5021,近日教育部在媒體上宣布駁回世新大學社發所的停招申請案，但卻沒同時宣布其他大學類似的申請案，其...,2030,[近日教育部在媒體上宣布駁回世新大學社發所的停招申請案，但卻沒同時宣布其他大學類似的申請案，...,"[[101, 6818, 3189, 3136, 5509, 6956, 1762, 205...",[],"[江漢聲, 王英洲, 潘文忠, 朱俊彰]"
5021,5022,史上金額最大開發案「台北雙子星」最優申請人「南海團隊」香港商南海發展有限公司、馬來西亞商馬頓...,932,[史上金額最大開發案「台北雙子星」最優申請人「南海團隊」香港商南海發展有限公司、馬來西亞商馬...,"[[101, 1380, 677, 7032, 7540, 3297, 1920, 7274...",[],[]
