In [1]:
import os
import re
import codecs
try:
    import cPickle as pickle
except ImportError:
    import pickle
import gc
import datetime,time
import multiprocessing

import numpy as np
import tensorflow as tf
import keras
import keras.backend.tensorflow_backend as KTF
from keras.models import load_model

from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_accuracy

from tqdm import tqdm

from gensim.models import KeyedVectors
from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence

from conlleval import return_report

# Limit this process to 30% of the GPU's memory.
configure = tf.ConfigProto()
configure.gpu_options.per_process_gpu_memory_fraction = 0.3
session = tf.Session(config=configure)

# Register the session as the one used by the Keras TensorFlow backend.
KTF.set_session(session )

Using TensorFlow backend.


In [2]:
# Directory holding the data splits and the pretrained resources.
base_dir = "./data_sets/"
train_file_path = os.path.join(base_dir, "train.txt")
dev_file_path = os.path.join(base_dir, "dev.txt")
test_file_path = os.path.join(base_dir, "test.txt")

# Pretrained vectors; load_word2vec reads this file in word2vec text format.
embedding_path = os.path.join(base_dir, "word2vec.txt")

# Pickled character -> id mapping; read by char_mapping.
char_dico_path = os.path.join(base_dir, "maps.pkl")

# Output locations: saved model, dev-set result file, test-set result file.
model_path = './model/bilstm/bilstm_model.h5'
dev_result_path = "./temp/bilstm.txt"
test_result_path = "./result/bilstm/bilstm.txt"


# Tag vocabulary: 'O' plus B-/I-/E-/S- prefixed tags for the five entity
# types t, j, b, z and s. Id 0 ('O') is also the value padding_data uses to
# fill tag sequences.
tag_to_id ={
    'O': 0,   
    'B-t': 1,  'I-t': 2,  'E-t': 3,  'S-t': 4, 
    'B-j': 5,  'I-j': 6,  'E-j': 7,  'S-j': 8,
    'B-b': 9,  'I-b': 10, 'E-b': 11, 'S-b': 12,
    'B-z': 13, 'I-z': 14, 'E-z': 15, 'S-z': 16,
    'B-s': 17, 'I-s': 18, 'E-s': 19, 'S-s': 20}

# Sequence length for padding/truncation. Keep it in sync with the max_len
# defaults of the helper functions and with Config.max_len.
max_len = 200

In [3]:
def zero_digits(s):
    """
    Replace every digit character in ``s`` with ``'0'``.

    :param s: input string (for example one line of a data file)
    :return: a copy of ``s`` in which each digit has been replaced by ``'0'``
    """
    # Raw string on purpose: '\d' inside a plain literal is an invalid escape
    # sequence and produces a warning on recent Python versions.
    return re.sub(r'\d', '0', s)

def load_sentences(path, zeros=False):
    """
    Load a column-formatted tagging file into a list of sentences.

    Every non-blank line is split on whitespace into columns (the first column
    is the token, the last one the tag). Blank lines end the current sentence.
    Sentences whose first token contains 'DOCSTART' are skipped.

    :param path: path of a UTF-8 encoded data file
    :param zeros: when True, every digit is replaced by '0' via ``zero_digits``
    :return: list of sentences; each sentence is a list of column lists
    :raises AssertionError: if a non-blank line has fewer than two columns
    """
    sentences = []
    sentence = []
    # The context manager closes the file even when a malformed line raises.
    with codecs.open(path, 'r', 'utf8') as handle:
        for line in handle:
            line = line.rstrip()
            if zeros:
                line = zero_digits(line)
            if not line:
                # Blank line: close the sentence collected so far (if any).
                if len(sentence) > 0:
                    if 'DOCSTART' not in sentence[0][0]:
                        sentences.append(sentence)
                    sentence = []
                continue
            if line[0] == " ":
                # A leading space would be swallowed by split(); presumably the
                # token itself is a space, so '$' stands in for it.
                line = "$" + line[1:]
            word = line.split()
            assert len(word) >= 2, \
                "malformed line (expected token and tag columns): %r" % (line,)
            sentence.append(word)
    # The file may not end with a blank line; keep the trailing sentence too.
    if len(sentence) > 0 and 'DOCSTART' not in sentence[0][0]:
        sentences.append(sentence)
    return sentences

def char_mapping(dico_path):
    """
    Load the character vocabulary from a pickle file and build its inverse.

    :param dico_path: path of a pickle file containing a char -> id dict
    :return: tuple ``(char_to_id, id_to_char)``
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(dico_path, 'rb') as handle:
        forward = pickle.load(handle)
    backward = dict((idx, char) for char, idx in forward.items())
    return forward, backward
        
def prepare_dataset(sentences, char_to_id, tag_to_id, lower=False):
    """
    Convert loaded sentences into token / character-id / tag-id triples.

    :param sentences: list of sentences; each row's first column is the token
                      and its last column the tag
    :param char_to_id: token -> id dict; '<UNK>' is looked up for unknown tokens
    :param tag_to_id: tag -> id dict
    :param lower: when True, tokens are lower-cased before the id lookup
    :return: list of ``[tokens, char_ids, tag_ids]`` entries, one per sentence
    """
    def lookup(token):
        # Normalise case if requested, then fall back to the unknown entry.
        key = token.lower() if lower else token
        if key not in char_to_id:
            key = '<UNK>'
        return char_to_id[key]

    dataset = []
    for rows in sentences:
        tokens = [row[0] for row in rows]
        char_ids = list(map(lookup, tokens))
        tag_ids = [tag_to_id[row[-1]] for row in rows]
        dataset.append([tokens, char_ids, tag_ids])
    return dataset

def padding_data(data, max_len=200):
    """
    Bring every sample to exactly ``max_len`` positions.

    Samples shorter than ``max_len`` are right-padded with 0; samples of
    ``max_len`` or more positions are cut off at ``max_len``.

    :param data: iterable of ``[string, chars, tags]`` triples
    :param max_len: target sequence length
    :return: ``[strings, chars, targets]``, three parallel lists
    """
    strings, chars, targets = [], [], []
    for string, char, target in tqdm(data):
        # The shortfall is measured on the token list and reused for all three
        # sequences of the sample.
        shortfall = max_len - len(string)
        if shortfall > 0:
            filler = [0] * shortfall
            resized = (string + filler, char + filler, target + filler)
        else:
            resized = (string[:max_len], char[:max_len], target[:max_len])
        strings.append(resized[0])
        chars.append(resized[1])
        targets.append(resized[2])
    return [strings, chars, targets]

def to_one_hot(targets, num_classes):
    """
    Turn sequences of tag ids into one-hot encoded arrays.

    :param targets: iterable of tag-id sequences
    :param num_classes: number of distinct tags (width of the one-hot axis)
    :return: numpy array stacking the one-hot encoding of every sequence
    """
    encoded = [tf.keras.utils.to_categorical(tag_ids, num_classes) for tag_ids in targets]
    return np.array(encoded)

def get_train_dev_test(path, char_to_id, tag_to_id, num_classes, lower=False, max_len=200):
    """
    Read one data split and return it as model-ready arrays.

    :param path: path of the data file
    :param char_to_id: token -> id dict
    :param tag_to_id: tag -> id dict
    :param num_classes: number of tags, used for the one-hot encoding
    :param lower: lower-case tokens before the id lookup
    :param max_len: length every sequence is padded / truncated to
    :return: ``(X, y)``: padded character ids and one-hot encoded tags
    """
    samples = prepare_dataset(load_sentences(path), char_to_id, tag_to_id, lower=lower)
    padded = padding_data(samples, max_len=max_len)
    # padded is [strings, char_ids, tag_ids]; the raw strings are not needed here.
    char_ids, tag_ids = padded[1], padded[2]
    return np.array(char_ids), to_one_hot(tag_ids, num_classes)
    

def load_word2vec(embedding_path, word_index, embed_dim=128): 
    """
    Build an embedding matrix for ``word_index`` from pretrained word2vec vectors.

    Row ``i`` receives the pretrained vector of the entry whose id is ``i``.
    Entries missing from the pretrained file get the element-wise mean of all
    pretrained vectors. Row 0 is forced to zeros (presumably because id 0 is
    the padding id; confirm against the vocabulary).

    :param embedding_path: word2vec file in text (non-binary) format
    :param word_index: dict mapping each vocabulary entry to its integer id
    :param embed_dim: width of the matrix; assumed to equal the dimensionality
                      of the pretrained vectors (TODO confirm)
    :return: array of shape ``(len(word_index) + 1, embed_dim)``
    """
    keyed_vectors = KeyedVectors.load_word2vec_format(embedding_path, binary=False)

    pretrained = {token: keyed_vectors.word_vec(token) for token in keyed_vectors.wv.vocab}
    print('Load %s word vectors.' % len(pretrained))

    # Fallback vector for vocabulary entries without a pretrained embedding.
    stacked = np.stack(list(pretrained.values()))
    mean_vector = np.mean(stacked, axis=0)

    n_rows = len(word_index) + 1
    matrix = np.zeros((n_rows, embed_dim))
    gc.collect()
    for token, row in word_index.items():
        if row >= n_rows:
            # Ids beyond the matrix are ignored.
            continue
        matrix[row] = pretrained.get(token, mean_vector)
    matrix[0] = np.zeros((1, embed_dim))
    del pretrained
    return matrix

def one_hot_to_label(one_hot, id_to_tag):
    """
    Decode per-position score vectors into tag strings.

    :param one_hot: iterable of sequences; every position is an array whose
                    ``argmax`` is the predicted tag id
    :param id_to_tag: tag id -> tag string dict
    :return: list of tag-string lists, one per sequence
    """
    return [
        [id_to_tag[scores.argmax()] for scores in sequence]
        for sequence in one_hot
    ]
            
def write_result(fpath, labels, result_path, char_to_id, tag_to_id, id_to_tag, lower=False, max_len=200):
    """
    Write a "token gold_tag predicted_tag" file for evaluation.

    The gold data is re-read from ``fpath``. Sentence ``i`` is paired with
    ``labels[i]``. One line is written per token and a blank line follows
    every sentence.

    Sentences longer than ``max_len`` are written in full: positions beyond
    ``max_len`` (which the model never saw because of truncation) get the
    predicted tag 'O'. Previously such sentences were silently written as
    empty, and the branch meant to pad their predictions was unreachable.

    :param fpath: path of the gold data file
    :param labels: predicted tag strings, one sequence per sentence; not modified
    :param result_path: output file path; its directory is created if missing
    :param char_to_id: token -> id dict (needed by ``prepare_dataset``)
    :param tag_to_id: tag -> id dict
    :param id_to_tag: tag id -> tag string dict
    :param lower: lower-case tokens before the id lookup
    :param max_len: sequence length the predictions were produced with
    :raises IndexError: if ``labels`` has fewer sequences than the file has
                        sentences, or a sequence is shorter than its sentence
                        within the first ``max_len`` positions
    """
    sentences = load_sentences(fpath)
    data = prepare_dataset(sentences, char_to_id, tag_to_id, lower=lower)

    lines = []
    for i, (tokens, _, gold_ids) in enumerate(data):
        # Work on a copy so the caller's predictions are left untouched.
        predicted = list(labels[i][:max_len])
        overflow = len(tokens) - max_len
        if overflow > 0:
            predicted.extend(['O'] * overflow)
        for j, token in enumerate(tokens):
            lines.append(token + ' ' + id_to_tag[gold_ids[j]] + ' ' + predicted[j] + '\n')
        lines.append('\n')

    out_dir = os.path.dirname(result_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(result_path, 'w', encoding="utf-8") as fw:
        fw.writelines(lines)
            
        
def evaluate(result_path):
    """
    Print the evaluation report for a result file.

    :param result_path: file produced by ``write_result``
    :return: the F-score reported by ``return_report``
    """
    # return_report yields (report_lines, fscore), as unpacked at its other
    # call sites in this notebook; iterating the pair itself would print the
    # whole list and the score instead of the individual report lines.
    report, fscore = return_report(result_path)
    for item in report:
        print(item)
    return fscore

# 准备数据

In [4]:
char_to_id, id_to_char = char_mapping(char_dico_path)
# Pass the notebook-level max_len explicitly so the padded length is driven by
# the configuration cell instead of each helper's own default value.
X_train, y_train = get_train_dev_test(train_file_path, char_to_id, tag_to_id, len(tag_to_id), max_len=max_len)
X_dev, y_dev = get_train_dev_test(dev_file_path, char_to_id, tag_to_id, len(tag_to_id), max_len=max_len)
X_test, y_test = get_train_dev_test(test_file_path, char_to_id, tag_to_id, len(tag_to_id), max_len=max_len)
embedding_matrix = load_word2vec(embedding_path, char_to_id)
# Inverse tag mapping, used to turn predicted ids back into tag strings.
id_to_tag = {value: key for key, value in tag_to_id.items()}

100%|██████████| 3543/3543 [00:00<00:00, 86983.41it/s]
100%|██████████| 1517/1517 [00:00<00:00, 104461.65it/s]
100%|██████████| 2123/2123 [00:00<00:00, 20660.83it/s]


Load 2122 word vectors.




In [5]:
# Sanity check: print the array shapes of each split and of the embedding matrix.
print("训练集信息：句子：{}, 标签：{}".format(X_train.shape, y_train.shape))
print("验证集信息：句子：{}, 标签：{}".format(X_dev.shape, y_dev.shape))
print("测试集信息：句子：{}, 标签：{}".format(X_test.shape, y_test.shape))
print("Embedding信息：{}".format(embedding_matrix.shape))

训练集信息：句子：(3543, 200), 标签：(3543, 200, 21)
验证集信息：句子：(1517, 200), 标签：(1517, 200, 21)
测试集信息：句子：(2123, 200), 标签：(2123, 200, 21)
Embedding信息：(2124, 128)


# 超参数配置

In [6]:
class Config:
    """Hyper-parameters for the BiLSTM-CRF model and its training loop."""
    # Sequence length, taken from the notebook-level ``max_len`` so the model's
    # input shape cannot drift away from the configured padding length.
    # (Inside a class body the right-hand side resolves to the global name.)
    max_len = max_len
    # Number of units of the LSTM in each direction.
    lstm_dim = 100
    # Vocabulary size; the embedding layer adds one extra row for id 0.
    vocab_size = len(char_to_id)
    # Number of output tags.
    class_nums = len(tag_to_id)
    # Width of the embedding vectors; must match embedding_matrix.
    embedding_dim = 128
    dropout_rate = 0.5
    # Number of single-epoch fit rounds run by the training loop.
    epochs = 120
    batch_size = 128

config = Config()

# 搭建模型

In [7]:
# Input Layer: one sequence of config.max_len character ids per sample.
char_input = keras.layers.Input(shape=(config.max_len,), name="char_input_layer")
# Embedding Layer: initialised from embedding_matrix and frozen
# (trainable=False). The "+1" matches the extra row of embedding_matrix, and
# mask_zero=True marks positions with id 0 as masked for the following layers.
embedding_layer = keras.layers.Embedding(config.vocab_size+1,
                                         config.embedding_dim,
                                         weights=[embedding_matrix],
                                         input_length=config.max_len,
                                         mask_zero=True,
                                         trainable=False)(char_input)
# BiLSTM Layer: return_sequences=True keeps one output vector per position.
bilstm = keras.layers.Bidirectional(keras.layers.LSTM(config.lstm_dim, unroll=False, return_sequences=True))(embedding_layer)
# Dropout Layer
dropout_layer = keras.layers.Dropout(config.dropout_rate)(bilstm)
# CRF Layer: sparse_target=False because the labels are one-hot encoded
# (see to_one_hot).
crf = CRF(config.class_nums, sparse_target=False)
crf_layer = crf(dropout_layer)

model = keras.models.Model(inputs=char_input, outputs=crf_layer)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
char_input_layer (InputLayer (None, 200)               0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 200, 128)          271872    
_________________________________________________________________
bidirectional_1 (Bidirection (None, 200, 200)          183200    
_________________________________________________________________
dropout_1 (Dropout)          (None, 200, 200)          0         
_________________________________________________________________
crf_1 (CRF)                  (None, 200, 21)           4704      
Total params: 459,776
Trainable params: 187,904
Non-trainable params: 271,872
_________________________________________________________________


# 模型训练

In [8]:
# crf_loss and crf_accuracy are imported in the notebook's import cell.
adadelta = keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=1e-06)
model.compile(loss=crf_loss, optimizer=adadelta, metrics=[crf_accuracy])

# Create the output directories up front so a missing folder cannot abort the
# run after an epoch has already been trained.
for out_path in (model_path, dev_result_path):
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

# Train one epoch at a time, score the dev set after each epoch and keep only
# the checkpoint with the best dev F-score.
best_fscore = 0
for i in range(config.epochs):
    print('Epoch '+str(i+1)+"/"+str(config.epochs))
    # `epochs` is the Keras 2 name of the deprecated `nb_epoch` argument.
    model.fit(X_train, y_train, epochs=1, batch_size=config.batch_size)
    y_pred = model.predict(X_dev, batch_size=config.batch_size)
    labels = one_hot_to_label(y_pred, id_to_tag)
    write_result(dev_file_path, labels, dev_result_path, char_to_id, tag_to_id, id_to_tag,
                 max_len=config.max_len)
    report, fscore = return_report(dev_result_path)
    for item in report:
        print(item)
    if fscore > best_fscore:
        model.save(model_path)
        best_fscore = fscore

Epoch 1/120


  # This is added back by InteractiveShellApp.init_path()


Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 1604 phrases; correct: 422.

accuracy:  72.65%; precision:  26.31%; recall:   6.81%; FB1:  10.82

                b: precision:   0.00%; recall:   0.00%; FB1:   0.00  0

                j: precision:  18.74%; recall:   3.91%; FB1:   6.47  459

                s: precision:  14.47%; recall:   5.58%; FB1:   8.06  767

                t: precision:  60.65%; recall:  15.19%; FB1:  24.30  371

                z: precision:   0.00%; recall:   0.00%; FB1:   0.00  7

Epoch 2/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 3888 phrases; correct: 1941.

accuracy:  79.43%; precision:  49.92%; recall:  31.32%; FB1:  38.49

                b: precision:   0.00%; recall:   0.00%; FB1:   0.00  0

                j: precision:  46.72%; recall:  26.23%; FB1:  33.60  1235

                s: precision:  37.10%; recall:  27.70%; FB1:  31.72  1485

                t: precision:  73.42%; recall:  54.83%; FB1:  62.78  1106

          

processed 59449 tokens with 6198 phrases; found: 5575 phrases; correct: 4666.

accuracy:  91.73%; precision:  83.70%; recall:  75.28%; FB1:  79.27

                b: precision:  64.15%; recall:  52.04%; FB1:  57.46  159

                j: precision:  88.96%; recall:  85.73%; FB1:  87.31  2120

                s: precision:  77.05%; recall:  66.16%; FB1:  71.19  1708

                t: precision:  93.03%; recall:  86.50%; FB1:  89.64  1377

                z: precision:  38.39%; recall:  24.40%; FB1:  29.83  211

Epoch 15/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 5825 phrases; correct: 4840.

accuracy:  92.30%; precision:  83.09%; recall:  78.09%; FB1:  80.51

                b: precision:  73.88%; recall:  50.51%; FB1:  60.00  134

                j: precision:  89.50%; recall:  87.55%; FB1:  88.51  2152

                s: precision:  75.19%; recall:  69.48%; FB1:  72.22  1838

                t: precision:  90.11%; recall:  89.80%; FB1:  89.96  1476

         

processed 59449 tokens with 6198 phrases; found: 5943 phrases; correct: 5137.

accuracy:  93.85%; precision:  86.44%; recall:  82.88%; FB1:  84.62

                b: precision:  71.93%; recall:  62.76%; FB1:  67.03  171

                j: precision:  92.02%; recall:  92.73%; FB1:  92.37  2217

                s: precision:  78.82%; recall:  72.60%; FB1:  75.58  1832

                t: precision:  93.10%; recall:  91.15%; FB1:  92.12  1450

                z: precision:  65.93%; recall:  54.22%; FB1:  59.50  273

Epoch 28/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 5846 phrases; correct: 5109.

accuracy:  93.78%; precision:  87.39%; recall:  82.43%; FB1:  84.84

                b: precision:  75.78%; recall:  62.24%; FB1:  68.35  161

                j: precision:  92.15%; recall:  92.32%; FB1:  92.23  2204

                s: precision:  81.29%; recall:  72.75%; FB1:  76.78  1780

                t: precision:  93.73%; recall:  90.88%; FB1:  92.29  1436

         

processed 59449 tokens with 6198 phrases; found: 6086 phrases; correct: 5306.

accuracy:  94.43%; precision:  87.18%; recall:  85.61%; FB1:  86.39

                b: precision:  70.62%; recall:  69.90%; FB1:  70.26  194

                j: precision:  92.36%; recall:  94.00%; FB1:  93.17  2239

                s: precision:  80.21%; recall:  76.22%; FB1:  78.16  1890

                t: precision:  93.52%; recall:  92.57%; FB1:  93.04  1466

                z: precision:  72.05%; recall:  64.46%; FB1:  68.04  297

Epoch 41/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6142 phrases; correct: 5302.

accuracy:  94.13%; precision:  86.32%; recall:  85.54%; FB1:  85.93

                b: precision:  59.91%; recall:  70.92%; FB1:  64.95  232

                j: precision:  91.77%; recall:  94.23%; FB1:  92.98  2259

                s: precision:  81.47%; recall:  76.02%; FB1:  78.65  1856

                t: precision:  92.56%; recall:  93.25%; FB1:  92.90  1492

         

processed 59449 tokens with 6198 phrases; found: 6043 phrases; correct: 5329.

accuracy:  94.54%; precision:  88.18%; recall:  85.98%; FB1:  87.07

                b: precision:  74.33%; recall:  70.92%; FB1:  72.58  187

                j: precision:  91.32%; recall:  94.18%; FB1:  92.73  2269

                s: precision:  83.01%; recall:  76.17%; FB1:  79.44  1825

                t: precision:  94.39%; recall:  93.18%; FB1:  93.78  1462

                z: precision:  74.33%; recall:  67.17%; FB1:  70.57  300

Epoch 54/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6067 phrases; correct: 5374.

accuracy:  94.76%; precision:  88.58%; recall:  86.71%; FB1:  87.63

                b: precision:  78.09%; recall:  70.92%; FB1:  74.33  178

                j: precision:  93.63%; recall:  94.23%; FB1:  93.93  2214

                s: precision:  81.44%; recall:  77.88%; FB1:  79.62  1902

                t: precision:  94.03%; recall:  93.59%; FB1:  93.81  1474

         

processed 59449 tokens with 6198 phrases; found: 5941 phrases; correct: 5317.

accuracy:  94.59%; precision:  89.50%; recall:  85.79%; FB1:  87.60

                b: precision:  84.47%; recall:  69.39%; FB1:  76.19  161

                j: precision:  94.42%; recall:  93.05%; FB1:  93.73  2168

                s: precision:  81.49%; recall:  77.48%; FB1:  79.43  1891

                t: precision:  94.59%; recall:  92.10%; FB1:  93.33  1442

                z: precision:  82.08%; recall:  68.98%; FB1:  74.96  279

Epoch 67/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6057 phrases; correct: 5406.

accuracy:  94.86%; precision:  89.25%; recall:  87.22%; FB1:  88.23

                b: precision:  81.07%; recall:  69.90%; FB1:  75.07  169

                j: precision:  93.25%; recall:  94.77%; FB1:  94.00  2236

                s: precision:  82.51%; recall:  77.78%; FB1:  80.07  1875

                t: precision:  94.13%; recall:  94.19%; FB1:  94.16  1482

         

processed 59449 tokens with 6198 phrases; found: 6009 phrases; correct: 5410.

accuracy:  95.04%; precision:  90.03%; recall:  87.29%; FB1:  88.64

                b: precision:  85.28%; recall:  70.92%; FB1:  77.44  163

                j: precision:  94.36%; recall:  94.32%; FB1:  94.34  2199

                s: precision:  82.55%; recall:  78.73%; FB1:  80.60  1897

                t: precision:  94.98%; recall:  93.25%; FB1:  94.11  1454

                z: precision:  84.12%; recall:  75.00%; FB1:  79.30  296

Epoch 80/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6086 phrases; correct: 5417.

accuracy:  94.94%; precision:  89.01%; recall:  87.40%; FB1:  88.20

                b: precision:  75.40%; recall:  71.94%; FB1:  73.63  187

                j: precision:  92.59%; recall:  94.32%; FB1:  93.45  2241

                s: precision:  83.25%; recall:  78.98%; FB1:  81.06  1887

                t: precision:  94.60%; recall:  93.52%; FB1:  94.06  1464

         

processed 59449 tokens with 6198 phrases; found: 5961 phrases; correct: 5374.

accuracy:  94.96%; precision:  90.15%; recall:  86.71%; FB1:  88.40

                b: precision:  82.35%; recall:  71.43%; FB1:  76.50  170

                j: precision:  94.10%; recall:  94.27%; FB1:  94.19  2204

                s: precision:  83.29%; recall:  78.18%; FB1:  80.65  1867

                t: precision:  95.98%; recall:  91.83%; FB1:  93.86  1417

                z: precision:  80.86%; recall:  73.80%; FB1:  77.17  303

Epoch 93/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6025 phrases; correct: 5425.

accuracy:  95.09%; precision:  90.04%; recall:  87.53%; FB1:  88.77

                b: precision:  82.49%; recall:  74.49%; FB1:  78.28  177

                j: precision:  94.31%; recall:  94.14%; FB1:  94.22  2196

                s: precision:  83.15%; recall:  79.64%; FB1:  81.36  1905

                t: precision:  95.50%; recall:  93.11%; FB1:  94.29  1444

         

processed 59449 tokens with 6198 phrases; found: 5973 phrases; correct: 5393.

accuracy:  94.97%; precision:  90.29%; recall:  87.01%; FB1:  88.62

                b: precision:  80.00%; recall:  73.47%; FB1:  76.60  180

                j: precision:  93.92%; recall:  93.32%; FB1:  93.62  2186

                s: precision:  84.01%; recall:  79.79%; FB1:  81.85  1889

                t: precision:  96.32%; recall:  91.90%; FB1:  94.06  1413

                z: precision:  81.31%; recall:  74.70%; FB1:  77.86  305

Epoch 106/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6292 phrases; correct: 5513.

accuracy:  94.95%; precision:  87.62%; recall:  88.95%; FB1:  88.28

                b: precision:  73.74%; recall:  74.49%; FB1:  74.11  198

                j: precision:  92.26%; recall:  94.82%; FB1:  93.52  2261

                s: precision:  81.57%; recall:  81.00%; FB1:  81.28  1975

                t: precision:  91.42%; recall:  95.75%; FB1:  93.54  1551

        

processed 59449 tokens with 6198 phrases; found: 6170 phrases; correct: 5492.

accuracy:  95.13%; precision:  89.01%; recall:  88.61%; FB1:  88.81

                b: precision:  81.01%; recall:  73.98%; FB1:  77.33  179

                j: precision:  94.12%; recall:  94.55%; FB1:  94.33  2210

                s: precision:  81.91%; recall:  81.05%; FB1:  81.48  1968

                t: precision:  93.14%; recall:  95.41%; FB1:  94.26  1517

                z: precision:  81.76%; recall:  72.89%; FB1:  77.07  296

Epoch 119/120
Epoch 1/1
processed 59449 tokens with 6198 phrases; found: 6139 phrases; correct: 5477.

accuracy:  95.12%; precision:  89.22%; recall:  88.37%; FB1:  88.79

                b: precision:  77.08%; recall:  75.51%; FB1:  76.29  192

                j: precision:  93.34%; recall:  94.86%; FB1:  94.09  2236

                s: precision:  83.26%; recall:  80.49%; FB1:  81.85  1923

                t: precision:  94.71%; recall:  94.26%; FB1:  94.48  1474

        

In [9]:
# Reload the best checkpoint saved during training. The CRF layer, its loss and
# its metric are custom objects, so they must be handed to load_model; all of
# them (and load_model itself) come from the notebook's import cell, which
# means this cell no longer depends on the training cell having been run.
custom_objects = {"CRF": CRF,"crf_loss": crf_loss, "crf_accuracy": crf_accuracy}
model = load_model(model_path, custom_objects=custom_objects)

# Make sure the result directory exists before predictions are written.
test_result_dir = os.path.dirname(test_result_path)
if test_result_dir:
    os.makedirs(test_result_dir, exist_ok=True)

y_test_pred = model.predict(X_test, batch_size=config.batch_size)
labels = one_hot_to_label(y_test_pred, id_to_tag)
write_result(test_file_path, labels, test_result_path, char_to_id, tag_to_id, id_to_tag,
             max_len=config.max_len)
report, fscore = return_report(test_result_path)
for item in report:
    print(item)

processed 85933 tokens with 9098 phrases; found: 9021 phrases; correct: 7746.

accuracy:  93.04%; precision:  85.87%; recall:  85.14%; FB1:  85.50

                b: precision:  74.85%; recall:  70.35%; FB1:  72.53  485

                j: precision:  89.85%; recall:  92.56%; FB1:  91.19  3103

                s: precision:  83.36%; recall:  78.93%; FB1:  81.08  2710

                t: precision:  89.92%; recall:  92.07%; FB1:  90.98  2311

                z: precision:  62.62%; recall:  57.21%; FB1:  59.79  412

