In [42]:
# Reload edited project modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import warnings
import os
import sys
import re

# Make packages located in the parent of the kernel's working directory importable.
sys.path.append("../")
# NOTE(review): this silences every warning for the whole kernel session (including
# deprecation / numerical warnings); consider narrowing it with category=... .
warnings.filterwarnings("ignore")

In [43]:
import os


# Project root, resolved relative to the kernel's current working directory.
root_dir = os.path.abspath('../bailian_nlp')

# POS corpus directory; created on first run so the generated CSVs can be written into it.
data_path = os.path.join(root_dir, 'datadrive/bailian/pos')
if not os.path.exists(data_path):
    os.makedirs(data_path, exist_ok=True)

train_path, valid_path = (
    os.path.join(data_path, fname) for fname in ('train_small.csv', 'valid.csv')
)

# Pretrained BERT assets (the directory name suggests 12 layers / 768 hidden / 12 heads)
# plus the POS model weights (used as the learner's best_model_path) and its config file.
model_dir = os.path.join(root_dir, 'datadrive/models/chinese_L-12_H-768_A-12/')
init_checkpoint_pt, bert_config_file, vocab_file, model_pt, config_file = (
    os.path.join(model_dir, fname)
    for fname in ('bert_model.bin', 'bert_config.json', 'vocab.txt', 'pos.bin', 'pos.json')
)



In [None]:
def build_data(seg_file=None, train_file=None, valid_file=None,
               valid_ratio=0.3, seed=42):
    """Convert a word/tag segmented corpus into character-level BIO training CSVs.

    Each non-empty input line looks like ``word1/tag1 word2/tag2 ...`` where a tag is
    one or two lowercase ASCII letters. Every word is split into characters: the first
    character is labelled ``B_<tag>``, the remaining ones ``I_<tag>``. Characters listed
    in ``replace_chars`` or that are whitespace are replaced by the token ``unk``.

    Each output file starts with the header row ``0△△△1``; every sample row is
    ``<space-joined labels>△△△<space-joined characters>``. Input lines with no
    ``word/tag`` match are skipped instead of being written as empty samples.

    Args:
        seg_file: segmented corpus path; defaults to ``final_baidu-23w.txt`` inside
            the notebook-level ``data_path``.
        train_file: training CSV path; defaults to the notebook-level ``train_path``.
        valid_file: validation CSV path; defaults to the notebook-level ``valid_path``.
        valid_ratio: probability in [0, 1] that a sample row goes to the validation file.
        seed: seed for a private RNG so re-running reproduces the same split
            (the global ``random`` state is not touched); ``None`` gives a fresh split.

    Raises:
        ValueError: if ``valid_ratio`` is outside [0, 1].
        FileNotFoundError: if ``seg_file`` does not exist (raised before either output
            file is truncated, because the input is opened first).
    """
    import random
    import re

    if not 0.0 <= valid_ratio <= 1.0:
        raise ValueError(f'valid_ratio must be within [0, 1], got {valid_ratio!r}')

    if seg_file is None:
        seg_file = os.path.join(data_path, 'final_baidu-23w.txt')
    if train_file is None:
        train_file = train_path
    if valid_file is None:
        valid_file = valid_path

    # Lazily grab everything up to "/<1-2 lowercase letters>" followed by a space or end of line.
    pattern = re.compile(r'(.+?)/(?:([a-z]{1,2})(?:$| ))')
    delimiter = '△△△'

    # Control / private-use / special-space characters that are mapped to 'unk'.
    replace_chars = {
        '\x97',
        '\uf076',
        "\ue405",
        "\ue105",
        "\ue415",
        '\x07',
        '\x7f',
        '\u3000',
        '\xa0',
        ' ',
    }

    rng = random.Random(seed)

    # Explicit UTF-8 so the Chinese corpus is read/written identically regardless of locale.
    with open(seg_file, encoding='utf-8') as fin, \
            open(train_file, 'w', encoding='utf-8') as train_f, \
            open(valid_file, 'w', encoding='utf-8') as valid_f:
        header = f'0{delimiter}1\n'
        train_f.write(header)
        valid_f.write(header)

        for line in fin:
            line = line.strip()
            if not line:
                continue

            words = []
            flags = []
            for word, flag in pattern.findall(line):
                chars = ['unk' if c in replace_chars or c.isspace() else c for c in word]
                words.extend(chars)
                flags.append(f'B_{flag}')
                flags.extend([f'I_{flag}'] * (len(chars) - 1))

            if not words:
                # Nothing matched "word/tag": writing the row would create an empty sample.
                continue

            assert len(words) == len(flags)

            fout = valid_f if rng.random() < valid_ratio else train_f
            fout.write(delimiter.join([' '.join(flags), ' '.join(words)]) + '\n')
            
            
            
# Regenerate the train/valid CSVs from the segmented corpus. Both output files are
# opened in 'w' mode, so any previously generated split is overwritten.
build_data()
    

In [14]:
# Normal training (build data loaders, model and learner, then fit).

from bailian_nlp.modules import BertNerData as NerData

# Build the train/valid datasets and loaders (exposed as data.train_dl, data.label2idx,
# data.id2label) from the CSVs written by build_data.
data = NerData.create(
    train_path,
    valid_path, 
    vocab_file,
    data_type="bert_uncased",
    is_cls=False,
    max_seq_len=64,
    batch_size=32
    
)

# NOTE(review): torch / nn are not referenced in this cell.
import torch
import torch.nn as nn
from importlib import reload
from bailian_nlp.modules.models import bert_models
# Reload so edits to bert_models are picked up without restarting the kernel.
reload(bert_models)

# One output label per entry of data.label2idx.
model = bert_models.BertBiLSTMAttnCRF.create(
    len(data.label2idx),
    bert_config_file, 
    init_checkpoint_pt,
    enc_hidden_dim=256
)
# NOTE(review): this is not the last expression of the cell, so its value is never
# displayed; wrap it in print()/display() to actually see the parameter count.
model.get_n_trainable_params()


from bailian_nlp.modules.train import train
reload(train)
num_epochs = 1
# t_total is presumably the total number of optimisation steps (for an LR schedule) —
# confirm against NerLearner. best_model_path=model_pt is where the best model is saved.
# NOTE(review): lr=1e-3 is much higher than the learning rates commonly used to
# fine-tune BERT weights; confirm whether the encoder is meant to be trained at this rate.
learner = train.NerLearner(model, data,
                     best_model_path=model_pt,
                     lr=0.001, clip=1.0, sup_labels=data.id2label,
                     t_total=num_epochs * len(data.train_dl))

learner.fit(num_epochs, target_metric='f1')


HBox(children=(IntProgress(value=0, description='bert data', max=999, style=ProgressStyle(description_width='i…

2019-04-13 11:35:04,723 DEBUG: get_data cost 0.557345s


HBox(children=(IntProgress(value=0, description='bert data', max=1000, style=ProgressStyle(description_width='…

2019-04-13 11:35:05,418 DEBUG: get_data cost 0.694552s
2019-04-13 11:35:06,874 INFO: Resuming train... Current epoch 0.


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

2019-04-13 11:37:17,937 INFO: 
epoch 1, average train epoch loss=136.53



HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

2019-04-13 11:39:07,543 INFO: on epoch 0 by max_f1: 0.271
2019-04-13 11:39:07,544 INFO: Saving new best model...


             precision    recall  f1-score   support

      <pad>      0.000     0.000     0.000         0
      [CLS]      1.000     0.872     0.932      1000
        B_w      0.282     0.358     0.316      4221
       B_nt      0.038     0.002     0.003      1984
       I_nt      0.503     0.849     0.631      7531
       B_ti      0.015     0.002     0.003      1580
       I_ti      0.664     0.835     0.739      3928
       B_nr      0.452     0.032     0.060      1305
       I_nr      0.420     0.815     0.555      2200
        I_w      0.000     0.000     0.000       305
          X      0.906     0.059     0.111       492
        B_v      0.101     0.171     0.127      2849
        I_v      0.082     0.125     0.099      2746
        B_r      0.000     0.000     0.000       487
        I_r      0.000     0.000     0.000       321
        B_n      0.111     0.093     0.101      3271
        I_n      0.218     0.531     0.309      5049
       B_vn      0.000     0.000     0.000   

In [16]:
# Sanity check: label2idx and id2label are presumably inverse mappings, so their sizes must agree.
assert len(data.label2idx) == len(data.id2label), 'label2idx and id2label have different sizes'
len(data.label2idx), len(data.id2label)

(55, 55)

In [39]:

# Keep the LAST batch of the training loader for shape inspection.
# `batch` is reset first so that, if the loader yields nothing, we fail loudly instead
# of silently reusing a stale `batch` left over from an earlier run of this cell.
batch = None
for batch in data.train_dl:
    pass
if batch is None:
    raise RuntimeError('data.train_dl yielded no batches')

# NOTE(review): positional layout assumed from how it is used here — batch[0] token ids,
# batch[-2] label mask, batch[-1] label ids; confirm against the loader's collate function.
labels_mask = batch[-2]
labels = batch[-1]
inputs = batch[0]

labels_mask.shape, labels.shape, inputs.shape


(torch.Size([7, 63]), torch.Size([7, 63]), torch.Size([7, 63]))

In [45]:
# Display the index -> label list used for decoding predictions (entry i presumably
# names label id i — confirm against NerData).
data.id2label

['<pad>',
 '[CLS]',
 'B_w',
 'B_nt',
 'I_nt',
 'B_ti',
 'I_ti',
 'B_nr',
 'I_nr',
 'I_w',
 'X',
 'B_v',
 'I_v',
 'B_r',
 'I_r',
 'B_n',
 'I_n',
 'B_vn',
 'I_vn',
 'B_c',
 'B_u',
 'B_ns',
 'I_ns',
 'B_d',
 'I_d',
 'B_nz',
 'I_nz',
 'B_p',
 'B_a',
 'B_t',
 'I_t',
 'B_nw',
 'I_nw',
 'B_ad',
 'I_ad',
 'B_m',
 'I_m',
 'B_vd',
 'I_vd',
 'I_p',
 'I_a',
 'B_f',
 'B_an',
 'I_an',
 'I_f',
 'B_xx',
 'I_xx',
 'B_s',
 'I_s',
 'I_c',
 'B_xc',
 'I_u',
 'B_q',
 'I_q',
 'I_xc']

In [40]:

# Per-sequence count of mask positions in the inspected batch (assuming a 0/1 mask,
# this is the number of real, non-padded positions).
lens = labels_mask.sum(dim=-1)
lens

tensor([63, 63, 63, 63, 25, 18, 16])

In [18]:
# Re-run training after reloading the `train` module, so edits to
# bailian_nlp.modules.train take effect without restarting the kernel.
# NOTE(review): this duplicates the learner setup of the earlier training cell; a small
# helper function would keep the two from drifting apart.
# NOTE(review): `model` is not re-created here, so this presumably continues from the
# weights already updated by the earlier fit. The captured output also shows leftover
# "sss" debug prints coming from the reloaded training code.
from bailian_nlp.modules.train import train
reload(train)
num_epochs = 1
learner = train.NerLearner(model, data,
                     best_model_path=model_pt,
                     lr=0.001, clip=1.0, sup_labels=data.id2label,
                     t_total=num_epochs * len(data.train_dl))

learner.fit(num_epochs, target_metric='f1')


2019-04-13 11:43:42,098 INFO: Resuming train... Current epoch 0.


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

sss tensor([[ 1, 27, 15,  ...,  5,  6,  7],
        [ 1,  3,  4,  ...,  6,  7,  8],
        [ 1,  3,  4,  ..., 35, 36, 28],
        ...,
        [ 1, 35,  2,  ...,  0,  0,  0],
        [ 1, 35,  2,  ...,  0,  0,  0],
        [ 1,  3,  4,  ...,  0,  0,  0]]) tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
torch.Size([32]) tensor([-164.7185, -142.0759, -192.6982, -234.6617, -194.9586, -203.6380,
        -125.6304, -198.8265, -163.9693, -171.6601, -237.0861, -188.7927,
        -194.9128,  -95.5633, -190.2001, -107.1882, -120.0904, -140.0920,
        -119.8892,  -29.8800,  -21.6516,  -17.7465,  -15.1057,  -16.8644,
         -15.3313,  -17.4735,  -14.7751,  -18.4119,  -15.0327,  -28.7552,
         -27.6639,  -19.1078], grad_fn=<ThSubBackward>)
sss tensor([[ 1,  3,  4,  ...,  2,  3,  4],
        [ 1,  3,  4,  ...,  5,  6, 25

In [None]:
# Resume training.
# NOTE(review): this cell imports from `blnlp`, whereas earlier cells use
# `bailian_nlp.modules`; confirm which package path is current.

from blnlp import pos
from importlib import reload
reload(pos)
tagger = pos.PosTagger()
tagger.init_env(for_train=True)

learner = tagger.learner
num_epochs = 5
# Presumably restores the previously saved (best) checkpoint before continuing — confirm.
learner.load_model()
learner.fit(num_epochs, target_metric='f1')


In [None]:
# Inspect which loader class backs the learner's training data.
type(learner.data.train_dl)

In [None]:
# Persist the learner's current weights (presumably to its best_model_path — confirm).
learner.save_model()

In [None]:
# Build an inference loader over the validation CSV.
# NOTE(review): `modules.data` is imported as a top-level package here, unlike the
# `bailian_nlp.modules` path used by earlier cells; confirm it resolves via sys.path.
from modules.data import bert_data
reload(bert_data)
dl = bert_data.get_bert_data_loader_for_predict(valid_path, learner)

In [None]:
# Predict on the validation loader with the learner's in-memory weights.
# Uncomment the next line to predict with the weights restored by load_model() instead.
# learner.load_model()
preds = learner.predict(dl)

In [None]:
# Span-level evaluation report for the predictions on the validation loader.
from modules.utils.plot_metrics import get_bert_span_report
clf_report = get_bert_span_report(dl, preds)
print(clf_report)

In [None]:
# Fresh tagger for inference (unlike the resume-training cell, init_env(for_train=True)
# is not called here).
from blnlp import pos
from importlib import reload
reload(pos)
tagger = pos.PosTagger()


In [None]:
import time

# Sample inputs for a quick smoke test / latency check of tagger.cut; the last one is timed.
# NOTE(review): these texts are longer than the max_seq_len=64 used by the training cell;
# confirm how cut() handles inputs beyond the model's maximum sequence length.
sample_texts = [
    '近日，编程猫（深圳点猫科技有限公司）正式对外宣布完成B轮1.2亿元融资。本轮融资由高瓴资本领投，清流资本、清晗基金跟投，天使轮投资者猎豹移动继续跟投。',
    '未来编程教育产业将蓬勃发展，编程猫作为提供工具与内容的企业，有望长期处于行业领跑者地位。',
    '美年大健康产业（集团）有限公司美年大健康产业（集团）有限公司美年大健康产业（集团）有限公司始创于2004年,是中国健康体检和医疗服务集团,总部位于上海,深耕布局北京、深圳、沈阳、广州、成都、武汉、...',
]
text = sample_texts[-1]

# perf_counter is monotonic and high-resolution, unlike time.time(), so it is the right
# clock for measuring elapsed time.
st = time.perf_counter()
res = tagger.cut(text)
ed = time.perf_counter()
print(f'tagger.cut: {ed - st:.4f}s for {len(text)} chars')
res

In [None]:
# Inspect the `is_meta` flag of the tagger's data object (its meaning is defined outside this notebook).
tagger.learner.data.is_meta