In [2]:
## tool
import pickle
import tqdm
import numpy as np
import pandas as pd
import os
import time
import re
import textwrap

## plot
import matplotlib.pyplot as plt
import seaborn as sns

from seqeval.metrics import f1_score, accuracy_score

## Bert
import transformers
from transformers import BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer, BertConfig

## module
from module.ner_trainer import NRE_Trainer

## torch
import torch
from torch.utils.data import DataLoader, random_split, Dataset

In [3]:
# Report which accelerator inference will run on. torch.cuda.get_device_name raises when no
# CUDA device is present, so fall back to a descriptive string on CPU-only machines.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu (CUDA not available)'
device_name

'GeForce RTX 2080 Ti'

## Load Model

In [4]:
# Build the bert-base-chinese token classifier and load the fine-tuned weights into it.
model = BertForTokenClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=3,  # must match the classifier head stored in the checkpoint below
    output_attentions=False,
    output_hidden_states=False,
)

NRE_trainer = NRE_Trainer(model, None, None)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's parameters wherever they live.
state_dict = torch.load('params/best-model-test.pth', map_location='cpu')
NRE_trainer.model.eval()  # this notebook only does inference: make sure dropout is off
NRE_trainer.model.load_state_dict(state_dict)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

device:cuda


<All keys matched successfully>

## 預測名字

#### Helper Functions

In [5]:
def preprocess(text, tokenizer=None):
    """Encode one text chunk into model input ids.

    Args:
        text: the chunk to encode.
        tokenizer: any object with an ``encode(text)`` method. Defaults to the
            notebook-level ``tokenizer_chinese``, looked up at call time so that global
            only needs to exist when the default is actually used.

    Returns:
        Whatever ``tokenizer.encode(text)`` returns. For a Hugging Face BERT tokenizer this
        is a list of ids, presumably including the [CLS]/[SEP] it adds by default.
    """
    if tokenizer is None:
        tokenizer = tokenizer_chinese
    return tokenizer.encode(text)

In [6]:
def split_sentence(text, window=500):
    """Cut an article into at most five fixed-size character windows.

    Behaviour by length:
    - ``len(text) <= window``: the text is returned whole.
    - Longer texts always keep the head (``text[:window]``) and the tail
      (``text[-window:]``).
    - Between them, extra windows of ``2 * (window // 2)`` characters are inserted:
      - up to ``3 * window``: one window centred at 1/2;
      - up to ``4 * window``: two windows centred at 1/4 and 3/4;
      - beyond that: three windows centred at 1/4, 1/2 and 3/4.

    For texts longer than ``4 * window`` the windows do not cover the whole article.

    The default ``window=500`` presumably keeps each chunk under BERT's 512-token limit
    once special tokens are added (about one token per Chinese character) — TODO confirm.

    Returns:
        A list of strings in document order.
    """
    n = len(text)
    # `<=` so a text of exactly `window` chars is a single chunk; with a strict `<`
    # it fell through to the last branch and produced an empty slice.
    if n <= window:
        return [text]

    half = window // 2
    if n <= 2 * window:
        centres = []
    elif n <= 3 * window:
        centres = [0.5]
    elif n <= 4 * window:
        centres = [0.25, 0.75]
    else:
        centres = [0.25, 0.5, 0.75]

    sentences = [text[:window]]
    for fraction in centres:
        centre = int(n * fraction)
        sentences.append(text[centre - half:centre + half])
    sentences.append(text[-window:])
    return sentences

In [7]:
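# Unused alternative kept for reference: split on sentence-ending punctuation, then wrap pieces longer than 512 characters.
# def split_sentence(text):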
#     sentences = re.findall(u'[^!?。\.\!\?]+[!?。\.\!\?]?', text, flags=re.U)
#     sentences_under512 = []
    
#     for sentence in sentences:
        
#         l = len(sentence)
        
#         if l > 512:
#             num = int(l/512)+1
#             sentences_under512 += textwrap.wrap(sentence, int(l/num))
#         else:
#             sentences_under512.append(sentence)
            
#     return sentences_under512

In [8]:
def pad_to_len(seqs, to_len, padding=0):
    """Truncate or right-pad every sequence in ``seqs`` to exactly ``to_len`` items.

    Args:
        seqs: iterable of lists.
        to_len: target length of each output list.
        padding: value appended to sequences shorter than ``to_len``.

    Returns:
        A new list of new lists; the input sequences are not modified.
    """
    def _fit(seq):
        shortfall = to_len - len(seq)
        filler = [padding] * shortfall if shortfall > 0 else []
        return seq[:to_len] + filler

    return [_fit(seq) for seq in seqs]

In [9]:
class Bert_dataset(Dataset):
    """Map-style dataset over pre-tokenised text chunks.

    Each element of ``data`` is a dict with at least an ``'ID'`` key and a
    ``'token_id'`` key (a list of token ids for one chunk); other keys are ignored.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        # Expose only what inference needs, under the batch-level key names.
        return {'id': record['ID'], 'token': record['token_id']}

    def collate_fn(self, samples):
        """Combine samples into one batch dict.

        ``'id'`` stays a plain Python list. ``'token'`` lists are truncated or padded
        with 0 to 512 and stacked into a tensor. A key is omitted from the batch
        entirely if any sample lacks it.
        """
        batch = {}

        if all('id' in s for s in samples):
            batch['id'] = [s['id'] for s in samples]

        if all('token' in s for s in samples):
            token_lists = [s['token'] for s in samples]
            batch['token'] = torch.tensor(pad_to_len(token_lists, 512, 0))

        return batch

In [10]:
def str2list(s):
    """Parse the stringified name list stored in the CSV ``name`` column.

    The column holds text such as ``"['王小明', '李大華']"``. Parsing it with
    ``ast.literal_eval`` keeps names that contain commas or spaces intact; the old
    split-on-',' parser broke them apart and removed every space.

    Returns:
        A list of name strings.
        - Lists/tuples passed directly are converted element-wise.
        - Missing cells (NaN/None as read by pandas), ``''`` and ``'[]'`` give ``[]``.
        - Text that is not a valid Python literal falls back to the original ad-hoc
          parser.
    """
    import ast

    if isinstance(s, (list, tuple)):
        return [str(name) for name in s]
    if not isinstance(s, str):
        # pandas reads an empty cell as NaN (a float): treat it as "no names".
        return []

    text = s.strip()
    if text in ('', '[]'):
        return []

    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return [str(name) for name in parsed]

    # Fallback for malformed text: strip the brackets, split on ',', then remove spaces
    # and the surrounding quote characters (the original behaviour).
    return [str(i.replace(' ', '')[1:-1]) for i in text[1:-1].split(',')]

In [11]:
def find_all_indexes(input_str, search_str):
    """Return the start offset of every (possibly overlapping) occurrence of ``search_str``.

    The scan resumes one character after each hit, so ``'aa'`` in ``'aaaa'`` is found
    at 0, 1 and 2. An empty ``input_str`` yields ``[]``. An empty ``search_str``
    matches at every offset ``0 .. len(input_str) - 1``.
    """
    def _scan():
        pos = 0
        while pos < len(input_str):
            pos = input_str.find(search_str, pos)
            if pos == -1:
                return
            yield pos
            pos += 1

    return list(_scan())

In [12]:
def get_index(name_index, sentence_index):
    """Return every position of ``sentence_index`` covered by an occurrence of ``name_index``.

    ``name_index`` is searched for as a contiguous run inside ``sentence_index``
    (for example, lists of token ids). Matches are found left to right and do not
    overlap. The positions of all matches are concatenated, e.g.
    ``get_index([1, 2], [1, 2, 3, 1, 2]) == [0, 1, 3, 4]``.

    An empty ``name_index`` returns ``[]``.
    """
    needle = list(name_index)
    width = len(needle)
    if width == 0:
        return []

    positions = []
    start = 0
    # Stopping at the last full-width start means a candidate match can never
    # run off the end (the previous hand-rolled scan raised IndexError there).
    last_start = len(sentence_index) - width
    while start <= last_start:
        if list(sentence_index[start:start + width]) == needle:
            positions.extend(range(start, start + width))
            start += width  # non-overlapping: continue after this match
        else:
            # Advance by one only, so a failed partial match does not skip a real
            # occurrence starting inside it (e.g. [1, 2] in [1, 1, 2]).
            start += 1
    return positions

#### Load Data

In [13]:
# Keep only the columns used downstream (news_ID first, then article). They are selected
# by name so later row unpacking does not depend on which other columns the CSV carries
# or their order. .copy() makes train_df independent of train_csv before columns are added.
train_csv = pd.read_csv('data/train.csv')
train_df = train_csv[['news_ID', 'article']].copy()

In [14]:
# Tokenizer that matches the bert-base-chinese checkpoint; input case is preserved.
tokenizer_chinese = BertTokenizer.from_pretrained(
    "bert-base-chinese",
    do_lower_case=False,
)

#### Preprocess

In [15]:
%%time
train_df['length']    = train_csv['article'].apply(lambda x: len(x))
train_df['sentences'] = train_csv['article'].apply(lambda x: split_sentence(x))
train_df['token_ids'] = train_df['sentences'].apply(lambda x: [preprocess(sentence) for sentence in x])
train_df['blacklist'] = train_csv['name'].apply(lambda x: str2list(x))

CPU times: user 14.8 s, sys: 23.8 ms, total: 14.9 s
Wall time: 14.9 s


In [16]:
# Flatten articles into one sample per text window.
# A fresh dict is built for every window: reusing a single dict per article made every
# appended entry alias the same object, so all windows of an article ended up holding
# the last window's sentence and token ids.
raw_data = []

for row in tqdm.tqdm(train_df.itertuples(index=False), total=len(train_df)):
    for token_id, sentence in zip(row.token_ids, row.sentences):
        raw_data.append({
            'ID': row.news_ID,
            'original_article': row.article,
            'length': row.length,
            'sentence': sentence,
            'token_id': token_id,
        })

100%|██████████| 5023/5023 [00:00<00:00, 326084.43it/s]


In [17]:
# Batch the window samples for inference. The dataset's own collate_fn pads token lists to 512.
bert_dataset = Bert_dataset(raw_data)
train_loader = DataLoader(
    dataset=bert_dataset,
    batch_size=32,
    collate_fn=bert_dataset.collate_fn,
)

In [18]:
# number of window samples, number of batches
print(*map(len, (bert_dataset, train_loader)))

11137 349


#### Predict Names

In [None]:
# Run the token classifier over every window. For each article ID, collect the distinct
# strings formed by consecutive tokens predicted as label 1 (presumably the "name" tag —
# confirm against the training label map).
answer = {}

ner_model = NRE_trainer.model  # the instance the fine-tuned weights were loaded into
ner_model.eval()  # inference only: make sure dropout is disabled

# Special tokens are never part of a name; they also end any name in progress.
special_tokens = {
    tokenizer_chinese.cls_token,
    tokenizer_chinese.sep_token,
    tokenizer_chinese.pad_token,
}

with torch.no_grad():
    for batch in tqdm.tqdm(train_loader):
        b_input_token = batch['token'].to(NRE_trainer.device)
        # collate_fn pads with id 0; mask those positions so real tokens do not attend to padding.
        attention_mask = (b_input_token != 0).long()
        output = ner_model(b_input_token, attention_mask=attention_mask)
        label_indices = output[0].argmax(dim=-1).cpu().numpy()

        for input_token, labels, news_id in zip(b_input_token.cpu().numpy(), label_indices, batch['id']):
            words = tokenizer_chinese.convert_ids_to_tokens(input_token)

            names = []
            name_string = ''
            for word, label in zip(words, labels):
                if label == 1 and word not in special_tokens:
                    # WordPiece continuation pieces carry a '##' prefix that is not part of the text.
                    name_string += word[2:] if word.startswith('##') else word
                elif name_string:
                    names.append(name_string)
                    name_string = ''
            if name_string:
                # A name running up to the last position would otherwise be lost.
                names.append(name_string)

            answer.setdefault(news_id, []).extend(names)

# Deduplicate while keeping first-seen order (set() would make the order arbitrary).
for key in answer:
    answer[key] = list(dict.fromkeys(answer[key]))

 81%|████████▏ | 284/349 [01:43<00:23,  2.73it/s]

In [None]:
# `answer` is keyed by the same news_ID values the samples were built from, so look each
# ID up directly. Offsetting by one would shift every article onto another article's names.
# Articles without an entry get an empty list.
train_df['name'] = train_df['news_ID'].map(lambda news_id: answer.get(news_id, []))

In [None]:
# Preview the first five rows with their predicted names.
train_df.head()