In [1]:
# Reload edited project modules automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import pandas as pd
import warnings
import os
import sys
import re

# Put the parent directory on the import path (presumably where the project's
# `modules` and `blnlp` packages live — confirm).
sys.path.append("../")
# NOTE(review): this silences every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

In [2]:
import os


# Every path below is derived from the parent of the current working directory.
root_dir = os.path.abspath('../')

# Corpus directory; created on first use so later cells can write into it.
data_path = os.path.join(root_dir, 'datadrive/bailian')
if not os.path.exists(data_path):
    os.makedirs(data_path, exist_ok=True)

# Train / validation CSV files inside the corpus directory.
train_path, valid_path = (
    os.path.join(data_path, file_name)
    for file_name in ('train.csv', 'valid_small.csv')
)

# Files that sit next to each other in the BERT model directory.
model_dir = os.path.join(root_dir, 'datadrive/models/chinese_L-12_H-768_A-12/')
init_checkpoint_pt, bert_config_file, vocab_file, model_pt = (
    os.path.join(model_dir, file_name)
    for file_name in (
        'bert_model.bin',
        'bert_config.json',
        'vocab.txt',
        'pytorch_model.bin',
    )
)



In [None]:
def build_data(seg_file=None, train_file=None, valid_file=None,
               valid_ratio=0.3, seed=None, encoding='utf-8'):
    """Convert a segmented, POS-tagged corpus into character-level B/I tag files.

    Every non-blank input line is scanned for ``word/flag`` items, where
    ``flag`` is one or two lowercase letters followed by a space or the end
    of the line. Each word is split into characters: the first character is
    tagged ``B_<flag>`` and the remaining ones ``I_<flag>``. Space and
    ``\\x7f`` characters inside a word are replaced by the token ``unk``.

    Both output files start with the header ``0△△△1``; each following row is
    ``<space-joined tags>△△△<space-joined characters>``. Both output files
    are overwritten on every call.

    Args:
        seg_file: Input corpus path. Defaults to ``final_baidu-23w.txt``
            inside the notebook-level ``data_path``.
        train_file: Training output path. Defaults to the notebook-level
            ``train_path``.
        valid_file: Validation output path. Defaults to the notebook-level
            ``valid_path``.
        valid_ratio: A row goes to the training file when a uniform draw in
            [0, 1) is greater than this value, otherwise to the validation
            file.
        seed: Seed for the split. ``None`` (the default) gives a different
            split on every call; pass an integer for a reproducible split.
        encoding: Text encoding used for all three files.

    Raises:
        ValueError: If ``valid_ratio`` is outside [0, 1].
    """
    import os
    import random
    import re

    if not 0.0 <= valid_ratio <= 1.0:
        raise ValueError(f'valid_ratio must be within [0, 1], got {valid_ratio!r}')

    # Fall back to the notebook-level paths so a bare `build_data()` still works.
    if seg_file is None:
        seg_file = os.path.join(data_path, 'final_baidu-23w.txt')
    if train_file is None:
        train_file = train_path
    if valid_file is None:
        valid_file = valid_path

    token_re = re.compile(r'(.+?)/(?:([a-z]{1,2})(?:$| ))')
    delimiter = '△△△'
    header = f'0{delimiter}1\n'
    # Private generator: the split never touches the global `random` state.
    rng = random.Random(seed)

    # Explicit encoding: the corpus is Chinese text, so the platform default
    # encoding must not decide how it is read or written.
    with open(seg_file, encoding=encoding) as fin, \
            open(train_file, 'w', encoding=encoding) as train_f, \
            open(valid_file, 'w', encoding=encoding) as valid_f:
        train_f.write(header)
        valid_f.write(header)

        for line in fin:
            line = line.strip()
            if not line:
                continue

            words = []
            flags = []
            for word, flag in token_re.findall(line):
                chars = ['unk' if c in (' ', '\x7f') else c for c in word]
                words.extend(chars)
                # `word` is never empty (the pattern needs at least one
                # character), so there is always exactly one B_ tag.
                flags.append(f'B_{flag}')
                flags.extend([f'I_{flag}'] * (len(chars) - 1))

            # A line without any word/flag item would produce a row with two
            # empty fields; skip it instead of writing it.
            if not words:
                continue
            assert len(words) == len(flags)

            fout = train_f if rng.random() > valid_ratio else valid_f
            fout.write(delimiter.join([' '.join(flags), ' '.join(words)]))
            fout.write('\n')
            
            
            
# Regenerate the files at train_path and valid_path from the segmented corpus
# (both files are opened for writing, so existing contents are replaced).
build_data()
    

In [5]:
# Scratch check of sklearn's `shuffle` on a tiny nested list.
from sklearn.utils import shuffle
shuffle([[1,2], [3, 4]])

[[3, 4], [1, 2]]

In [69]:
# Normal training run.

from modules import BertNerData as NerData

# Build the data object from the train/valid CSVs and the BERT vocab file.
# This notebook later reads `train_dl`, `valid_dl`, `label2idx` and
# `id2label` from it.
data = NerData.create(
    train_path,
    valid_path, 
    vocab_file,
    data_type="bert_uncased",
    is_cls=False
)

import torch
from importlib import reload
from modules.models import bert_models
# Explicit reload so that edits to bert_models are picked up in this kernel.
reload(bert_models)

# The first argument is the size of the label vocabulary; the BERT config and
# checkpoint paths come from the path-configuration cell.
model = bert_models.BertBiLSTMAttnCRF.create(
    len(data.label2idx),
    bert_config_file, 
    init_checkpoint_pt,
    enc_hidden_dim=256
)
model.get_n_trainable_params()

from modules.train import train
reload(train)
num_epochs = 20
# t_total is epochs * len(train_dl) — presumably the total number of training
# steps used by the learner's schedule; confirm in NerLearner.
# best_model_path is presumably where the best checkpoint is written.
learner = train.NerLearner(model, data,
                     best_model_path=model_pt,
                     lr=0.001, clip=1.0, sup_labels=data.id2label,
                     t_total=num_epochs * len(data.train_dl))

learner.fit(num_epochs, target_metric='f1')



HBox(children=(IntProgress(value=0, description='bert data', max=155954, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='bert data', max=499, style=ProgressStyle(description_width='i…

2019-03-26 12:38:13,079 INFO: Resuming train... Current epoch 0.


HBox(children=(IntProgress(value=0, max=9748), HTML(value='')))

In [4]:
# Sizes of the datasets behind the training and validation loaders.
len(data.train_dl.dataset), len(data.valid_dl.dataset)

(4999, 499)

In [13]:
# Resume training.

import torch
from importlib import reload
from modules.models import bert_models
from modules.models import released_models
from modules.data import bert_data
# Explicit reloads so that edits to these project modules are picked up.
reload(bert_data)
reload(bert_models)
reload(released_models)

# Rebuild the learner from the JSON config stored in the model directory
# instead of constructing data/model/learner by hand.
config_file = os.path.join(model_dir, 'pytorch_model.json')
learner = released_models.recover_from_config(config_file)
num_epochs = 20
# Presumably restores previously saved weights before training continues —
# confirm in the learner implementation.
learner.load_model()
learner.fit(num_epochs, target_metric='f1')


HBox(children=(IntProgress(value=0, description='bert data', max=155954, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='bert data', max=499, style=ProgressStyle(description_width='i…

2019-03-27 03:21:59,756 INFO: Resuming train... Current epoch 0.


HBox(children=(IntProgress(value=0, max=9748), HTML(value='')))

2019-03-27 07:03:02,948 INFO: 
epoch 1, average train epoch loss=13.276



HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

2019-03-27 07:03:41,551 INFO: on epoch 0 by max_f1: 0.918
2019-03-27 07:03:41,552 INFO: Saving new best model...


             precision    recall  f1-score   support

      <pad>      0.000     0.000     0.000         0
      [CLS]      1.000     1.000     1.000       499
        B_m      0.957     0.950     0.953       776
        B_w      0.996     0.995     0.996      6365
       B_nt      0.896     0.944     0.919      2248
       I_nt      0.901     0.950     0.925     11396
       B_nr      0.939     0.981     0.960      1692
       I_nr      0.920     0.979     0.949      2890
        I_w      0.990     1.000     0.995       578
          X      1.000     1.000     1.000       680
       B_nz      0.762     0.582     0.660       766
       I_nz      0.723     0.591     0.650      2368
        B_n      0.914     0.887     0.901      5111
        I_n      0.911     0.884     0.897      8455
        B_v      0.930     0.905     0.917      3741
        I_v      0.915     0.892     0.903      3807
        B_p      0.966     0.950     0.958       923
        I_m      0.974     0.978     0.976   

HBox(children=(IntProgress(value=0, max=9748), HTML(value='')))

In [12]:
# Inspect the class of the learner's training data loader.
type(learner.data.train_dl)

modules.data.bert_data.DataLoaderForTrain

In [15]:
from modules.data import bert_data
# NOTE: `reload` is not imported in this cell; it comes from an earlier cell's
# `from importlib import reload`, so this cell fails on a fresh kernel if run alone.
reload(bert_data)
# Build a prediction data loader over the validation CSV; the learner is
# passed in as well (presumably for its data settings — confirm in bert_data).
dl = bert_data.get_bert_data_loader_for_predict(valid_path, learner)

HBox(children=(IntProgress(value=0, description='bert data', max=499, style=ProgressStyle(description_width='i…

In [17]:
# Relies on the `learner` left in the kernel by the training/resume cell.
# Uncomment the next line to call load_model() before predicting.
# learner.load_model()
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

In [19]:
from modules.utils.plot_metrics import get_bert_span_report
# Per-tag precision / recall / f1-score / support table for the predictions
# made on `dl` (see the printed output).
clf_report = get_bert_span_report(dl, preds)
print(clf_report)

             precision    recall  f1-score   support

          c      0.420     0.427     0.424       625
         ti      0.540     0.572     0.556      2099
          a      0.439     0.430     0.435      1064
          p      0.499     0.494     0.497       923
          m      0.457     0.455     0.456       776
         nw      0.500     0.394     0.441        33
         xc      0.394     0.413     0.403        63
         an      0.231     0.246     0.238        61
          u      0.431     0.432     0.431      1765
         nz      0.385     0.299     0.337       766
          n      0.488     0.474     0.481      5111
          q      0.387     0.300     0.338        40
         nt      0.524     0.552     0.538      2248
         vn      0.409     0.423     0.416      1653
         vd      0.180     0.208     0.193        53
         nr      0.545     0.569     0.557      1692
         xx      0.000     0.000     0.000        31
          f      0.462     0.469     0.465   

In [10]:
from blnlp import pos
from importlib import reload
# Reload so that edits to blnlp.pos are picked up before the tagger is built.
reload(pos)
tagger = pos.PosTagger()


In [13]:
import time

text = '近日，编程猫（深圳点猫科技有限公司）正式对外宣布完成B轮1.2亿元融资。本轮融资由高瓴资本领投，清流资本、清晗基金跟投，天使轮投资者猎豹移动继续跟投。'
# Alternative sample sentence:
# text = '未来编程教育产业将蓬勃发展，编程猫作为提供工具与内容的企业，有望长期处于行业领跑者地位。'

# Time a single tagging call. perf_counter() is a monotonic, high-resolution
# clock meant for measuring intervals; time.time() is wall-clock time and can
# jump if the system clock is adjusted.
st = time.perf_counter()
res = tagger.cut(text)
ed = time.perf_counter()
print(ed - st)
res

HBox(children=(IntProgress(value=0, description='bert data', max=1, style=ProgressStyle(description_width='ini…




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


0.28635597229003906


[('近日', 't'),
 ('，', 'w'),
 ('编程猫', 'nt'),
 ('（', 'w'),
 ('深圳点猫科技有限公司', 'nt'),
 ('）', 'w'),
 ('正式', 'ad'),
 ('对外', 'd'),
 ('宣布完成', 'v'),
 ('B', 'xc'),
 ('轮', 'q'),
 ('1.2亿元', 'm'),
 ('融资', 'vn'),
 ('。', 'w'),
 ('本轮', 'r'),
 ('融资', 'vn'),
 ('由', 'p'),
 ('高瓴资本', 'nt'),
 ('领', 'v'),
 ('投', 'vn'),
 ('，', 'w'),
 ('清流资本', 'nt'),
 ('、', 'w'),
 ('清晗基金', 'nt'),
 ('跟', 'vd'),
 ('投', 'v'),
 ('，', 'w'),
 ('天使轮', 'nz'),
 ('投资者', 'n'),
 ('猎豹移动', 'nt'),
 ('继续跟', 'vd'),
 ('投', 'v'),
 ('。', 'w')]

In [64]:
# Inspect the `is_meta` flag on the data object held by the tagger's learner.
tagger.learner.data.is_meta

False