In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import warnings
import os
import sys
import re

# Put the parent directory on the import path (presumably where the
# project-local `modules` package used by later cells lives — confirm).
sys.path.append("../")
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

In [2]:
import os


# Absolute path of the directory one level above the notebook's working directory.
root_dir = os.path.abspath('../')

# Directory holding the segmented corpus and the generated train/valid files.
# `exist_ok=True` already makes the call idempotent, so no separate (and
# non-atomic) existence check is needed.
data_path = os.path.join(root_dir, 'datadrive/bailian')
os.makedirs(data_path, exist_ok=True)

# Output files written by `build_data()` and read back by the data loader.
train_path = os.path.join(data_path, 'train.csv')
valid_path = os.path.join(data_path, 'valid.csv')

# Pretrained Chinese BERT artefacts.
# NOTE(review): config/vocab are read from `datadrive/bert/...` while the
# PyTorch checkpoint is read from `datadrive/models/...` — confirm that both
# locations are intended.
model_dir = os.path.join(root_dir, 'datadrive/bert/chinese_L-12_H-768_A-12/')
init_checkpoint_pt = os.path.join(root_dir, 'datadrive/models/chinese_L-12_H-768_A-12/pytorch_model.bin')
bert_config_file = os.path.join(model_dir, 'bert_config.json')
vocab_file = os.path.join(model_dir, 'vocab.txt')



In [63]:
def build_data(seg_file=None, train_file=None, valid_file=None,
               valid_ratio=0.3, seed=None, delimiter='△△△'):
    """Convert a ``word/tag``-segmented corpus into character-level train/valid files.

    Every non-blank input line is scanned for ``word/tag`` tokens (tag = one or
    two lowercase letters, followed by a space or the end of the line). Each
    word is split into characters; the first character is labelled
    ``B_<tag>`` and the remaining ones ``I_<tag>``. Space and ``\\x7f``
    characters inside a word are replaced by the literal token ``unk``.

    Both output files start with the header ``0<delimiter>1`` and then hold
    one row per input line: space-joined labels, the delimiter, space-joined
    characters.

    Args:
        seg_file: Path of the segmented corpus. Defaults to
            ``final_baidu-23w.txt`` inside the module-level ``data_path``.
        train_file: Training output path. Defaults to module-level ``train_path``.
        valid_file: Validation output path. Defaults to module-level ``valid_path``.
        valid_ratio: A line goes to the validation file when its random draw
            in ``[0, 1)`` is ``<= valid_ratio``; otherwise to the training file.
        seed: Optional seed for a private random generator, making the split
            reproducible. ``None`` keeps using the global ``random`` module.
        delimiter: Column separator written between labels and characters.

    Raises:
        OSError: If the corpus cannot be read or an output cannot be written.
    """
    import random
    import re

    if seg_file is None:
        seg_file = os.path.join(data_path, 'final_baidu-23w.txt')
    if train_file is None:
        train_file = train_path
    if valid_file is None:
        valid_file = valid_path

    token_re = re.compile(r'(.+?)/(?:([a-z]{1,2})(?:$| ))')
    # A private generator gives a reproducible split without touching the
    # global random state; without a seed the global generator is used.
    draw = random.Random(seed).random if seed is not None else random.random
    header = f'0{delimiter}1\n'

    # Explicit UTF-8: the corpus is Chinese text, and the locale default
    # encoding is not guaranteed to be able to represent it.
    with open(seg_file, encoding='utf-8') as fin, \
            open(train_file, 'w', encoding='utf-8') as train_f, \
            open(valid_file, 'w', encoding='utf-8') as valid_f:
        train_f.write(header)
        valid_f.write(header)

        for line in fin:
            line = line.strip()
            if not line:
                continue

            # One draw per non-blank line, taken before parsing.
            fout = train_f if draw() > valid_ratio else valid_f

            words = []
            flags = []
            for word, flag in token_re.findall(line):
                # Space and DEL are replaced by a placeholder so that the
                # space-joined character column stays aligned with the labels.
                char_list = [c if c not in (' ', '\x7f') else 'unk' for c in word]
                words.extend(char_list)
                flags.append(f'B_{flag}')
                flags.extend([f'I_{flag}'] * (len(char_list) - 1))

            assert len(words) == len(flags)

            # A line without any word/tag token would yield an empty row.
            if not words:
                continue

            fout.write(delimiter.join([' '.join(flags), ' '.join(words)]) + '\n')
            
            
            
# (Re)generate train.csv / valid.csv from the segmented corpus. The split is
# random, so the row counts of the two files vary between runs.
build_data()
    

In [78]:
from modules import BertNerData as NerData

# Build the data object from the generated CSV files and the BERT vocabulary;
# CUDA is disabled for this run.
data = NerData.create(
    train_path,
    valid_path,
    vocab_file,
    data_type="bert_uncased",
    cuda=False,
    is_cls=False,
)



In [79]:
# Number of examples in the training and validation datasets, in that order.
tuple(len(loader.dataset) for loader in (data.train_dl, data.valid_dl))

(155954, 67275)

In [81]:
# Show the label vocabulary held by the data object (list position -> label string).
data.id2label

['<pad>',
 '[CLS]',
 'B_m',
 'B_w',
 'B_nt',
 'I_nt',
 'B_nr',
 'I_nr',
 'I_w',
 'X',
 'B_nz',
 'I_nz',
 'B_n',
 'I_n',
 'B_v',
 'I_v',
 'B_p',
 'I_m',
 'I_p',
 'B_r',
 'I_r',
 'B_vn',
 'I_vn',
 'B_vd',
 'I_vd',
 'B_a',
 'I_a',
 'B_t',
 'I_t',
 'B_ti',
 'I_ti',
 'B_u',
 'B_c',
 'B_d',
 'I_d',
 'B_ns',
 'I_ns',
 'I_c',
 'B_f',
 'B_ad',
 'I_ad',
 'I_f',
 'B_xc',
 'B_s',
 'I_s',
 'B_an',
 'I_an',
 'B_nw',
 'I_nw',
 'I_u',
 'B_xx',
 'I_xx',
 'I_xc',
 'B_q',
 'I_q']

In [80]:
from modules.models.bert_models import BertBiLSTMAttnCRF
# Instantiate the tagger on CPU with one output per known label, using the
# BERT config and checkpoint configured earlier.
model = BertBiLSTMAttnCRF.create(
    len(data.label2idx),
    bert_config_file,
    init_checkpoint_pt,
    enc_hidden_dim=256,
    use_cuda=False,
)
# Display the number of trainable parameters.
model.get_n_trainable_params()

1160181

In [82]:
from modules import NerLearner
num_epochs = 10

# Fine-tuned weights get their own file. Passing `init_checkpoint_pt` as
# `best_model_path` would point the learner's save target at the pretrained
# BERT checkpoint and overwrite it.
# NOTE(review): assumes NerLearner writes its best model to `best_model_path`
# (as the upstream ner-bert project does) — confirm against `modules`.
best_model_path = os.path.join(root_dir, 'datadrive/models/bailian_bert_bilstm_attn_crf.cpt')
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# t_total is the number of epochs times the number of training batches.
learner = NerLearner(model, data,
                     best_model_path=best_model_path,
                     lr=0.001, clip=1.0, sup_labels=data.id2label,
                     t_total=num_epochs * len(data.train_dl))

In [83]:
# Train for `num_epochs` epochs; 'f1' is passed as the target metric
# (presumably the one used to pick the best model — confirm in NerLearner.fit).
learner.fit(num_epochs, target_metric='f1')

2019-03-23 17:51:53,880 INFO: Resuming train... Current epoch 0.
train loss: 752.8927001953125:   0%|          | 6/9748 [02:37<71:24:21, 26.39s/it]