#!/usr/bin/python
# -*- coding: UTF-8 -*-
# @time     :2019/5/18 23:51
# @author   :Mo
# @function :classify text with bert and (text-cnn, r-cnn or avt-cnn)
from __future__ import division, absolute_import
from keras.losses import categorical_crossentropy
from conf.path_config import path_webank_train, path_webank_dev, path_webank_test
from keras.layers import Conv1D, GlobalAveragePooling1D, GlobalMaxPooling1D
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import SpatialDropout1D, Dropout
from sklearn.metrics import classification_report
from keras.layers import CuDNNGRU, CuDNNLSTM
from keras.layers import Bidirectional
from keras.layers import RepeatVector
from keras.layers import Concatenate
from keras.layers import GRU, LSTM
from keras.layers import Multiply
from keras.layers import Permute
from keras.layers import Lambda
from keras.layers import Dense
from keras.models import Model
from keras import regularizers
import numpy as np
import codecs
import keras.backend as k_keras
import logging as logger
from keras_bert import Tokenizer
from ClassificationText.bert.keras_bert_layer import AttentionWeightedAverage
from ClassificationText.bert.keras_bert_embedding import KerasBertEmbedding
from ClassificationText.bert import args
from conf.feature_config import config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes, gpu_memory_fraction

def attention(inputs, single_attention_vector=False):
    # soft attention over the time axis: learn a softmax weight per time step
    # and rescale the inputs with it
    time_steps = k_keras.int_shape(inputs)[1]
    input_dim = k_keras.int_shape(inputs)[2]
    x = Permute((2, 1))(inputs)
    x = Dense(time_steps, activation='softmax')(x)
    if single_attention_vector:
        # share a single attention vector across all feature dimensions
        x = Lambda(lambda x: k_keras.mean(x, axis=1))(x)
        x = RepeatVector(input_dim)(x)
    a_probs = Permute((2, 1))(x)
    output_attention_mul = Multiply()([inputs, a_probs])
    return output_attention_mul
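

# A minimal usage sketch for attention() (illustrative only, not called by the training
# pipeline below); the input shape is an assumption, any (batch, time_steps, dim) tensor works.
def _demo_attention():
    from keras.layers import Input
    demo_input = Input(shape=(50, 768))   # e.g. 50 time steps of 768-dim BERT vectors
    attended = attention(demo_input)      # same shape, rescaled by learned attention weights
    pooled = GlobalMaxPooling1D()(attended)
    return Model(demo_input, pooled)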


class BertTextCnnModel():
    def __init__(self):
        # logger.info("BertTextCnnModel init start!")
        print("BertTextCnnModel init start!")
        self.config_path, self.checkpoint_path, self.dict_path = config_name, ckpt_name, vocab_file
        self.max_seq_len, self.filters, self.embedding_dim, self.keep_prob = args.max_seq_len, args.filters, args.embedding_dim, args.keep_prob
        self.activation, self.label = args.activation, args.label
        # read the BERT vocab and build the tokenizer
        self.token_dict = {}
        with codecs.open(self.dict_path, 'r', 'utf8') as reader:
            for line in reader:
                token = line.strip()
                self.token_dict[token] = len(self.token_dict)
        self.tokenizer = Tokenizer(self.token_dict)
        # pick one of text-cnn, r-cnn or avt-cnn here
        # self.build_model_text_cnn()
        # self.build_model_r_cnn()
        self.build_model_avt_cnn()
        # logger.info("BertTextCnnModel init end!")
        print("BertTextCnnModel init end!")

    def build_model_text_cnn(self):
        ######### text-cnn #########
        # bert embedding
        bert_inputs, bert_output = KerasBertEmbedding().bert_encode()
        # text cnn; note: despite the name, keep_prob is passed to dropout layers as the *drop* rate
        bert_output_embed = SpatialDropout1D(rate=self.keep_prob)(bert_output)
        concat_out = []
        for index, filter_size in enumerate(self.filters):
            x = Conv1D(name='TextCNN_Conv1D_{}'.format(index),
                       filters=int(self.embedding_dim / 2),
                       kernel_size=filter_size,
                       padding='valid',
                       kernel_initializer='normal',
                       activation='relu')(bert_output_embed)
            x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x)
            concat_out.append(x)
        x = Concatenate(axis=1)(concat_out)
        x = Dropout(self.keep_prob)(x)
        # final softmax classification layer
        dense_layer = Dense(self.label, activation=self.activation)(x)
        output_layers = [dense_layer]
        self.model = Model(bert_inputs, output_layers)
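
    # note: each of the len(self.filters) branches above yields embedding_dim / 2 features
    # after global max pooling, so the concatenated vector has
    # len(self.filters) * embedding_dim / 2 dimensions before dropout and the dense softmax.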

    def build_model_r_cnn(self):
        ######### r-cnn #########
        # bert embedding
        bert_inputs, bert_output = KerasBertEmbedding().bert_encode()
        # r-cnn: bidirectional rnn followed by a 1-d convolution and max pooling
        bert_output_embed = SpatialDropout1D(rate=self.keep_prob)(bert_output)
        # pick the recurrent cell: (CuDNN)LSTM or (CuDNN)GRU
        if args.use_lstm:
            layer_cell = CuDNNLSTM if args.use_cudnn_cell else LSTM
        else:
            layer_cell = CuDNNGRU if args.use_cudnn_cell else GRU
        x = Bidirectional(layer_cell(units=args.units, return_sequences=args.return_sequences,
                                     kernel_regularizer=regularizers.l2(args.l2 * 0.1),
                                     recurrent_regularizer=regularizers.l2(args.l2)
                                     ))(bert_output_embed)
        x = Dropout(args.keep_prob)(x)
        x = Conv1D(filters=int(self.embedding_dim / 2), kernel_size=2, padding='valid',
                   kernel_initializer='normal', activation='relu')(x)
        x = GlobalMaxPooling1D()(x)
        x = Dropout(args.keep_prob)(x)
        # final softmax classification layer
        dense_layer = Dense(self.label, activation=self.activation)(x)
        output_layers = [dense_layer]
        self.model = Model(bert_inputs, output_layers)

    def build_model_avt_cnn(self):
        ######### avt-cnn #########
        # bert embedding
        bert_inputs, bert_output = KerasBertEmbedding().bert_encode()
        # text cnn with three pooling branches per filter size:
        # max pooling (x), average pooling (y) and attention-weighted average (z)
        bert_output_embed = SpatialDropout1D(rate=self.keep_prob)(bert_output)
        concat_x = []
        concat_y = []
        concat_z = []
        for index, filter_size in enumerate(self.filters):
            conv = Conv1D(name='TextCNN_Conv1D_{}'.format(index),
                          filters=int(self.embedding_dim / 2),
                          kernel_size=filter_size,
                          padding='valid',
                          kernel_initializer='normal',
                          activation='relu')(bert_output_embed)
            x = GlobalMaxPooling1D(name='TextCNN_MaxPooling1D_{}'.format(index))(conv)
            y = GlobalAveragePooling1D(name='TextCNN_AveragePooling1D_{}'.format(index))(conv)
            z = AttentionWeightedAverage(name='TextCNN_Attention_{}'.format(index))(conv)
            concat_x.append(x)
            concat_y.append(y)
            concat_z.append(z)
        merge_x = Concatenate(axis=1)(concat_x)
        merge_y = Concatenate(axis=1)(concat_y)
        merge_z = Concatenate(axis=1)(concat_z)
        merge_xyz = Concatenate(axis=1)([merge_x, merge_y, merge_z])
        x = Dropout(self.keep_prob)(merge_xyz)
        # final softmax classification layer
        dense_layer = Dense(self.label, activation=self.activation)(x)
        output_layers = [dense_layer]
        self.model = Model(bert_inputs, output_layers)

    def compile_model(self):
        # categorical_crossentropy expects the one-hot labels built in classify_pair_corpus_webank
        self.model.compile(optimizer=args.optimizers,
                           loss=categorical_crossentropy,
                           metrics=args.metrics)

    def callback(self):
        # save the best checkpoint by val_loss and stop early once it plateaus
        c_b = [ModelCheckpoint(args.path_save_model, monitor='val_loss', verbose=1, save_best_only=True,
                               save_weights_only=False, mode='min'),
               EarlyStopping(min_delta=1e-9, patience=4, mode='min')
               ]
        return c_b

    def fit(self, x_train, y_train, x_dev, y_dev):
        self.model.fit(x_train, y_train, batch_size=args.batch_size,
                       epochs=args.epochs, validation_data=(x_dev, y_dev),
                       shuffle=True,
                       callbacks=self.callback())
        self.model.save(args.path_save_model)

    def load_model(self):
        print("BertTextCnnModel load_model start!")
        # logger.info("BertTextCnnModel load_model start!")
        self.model.load_weights(args.path_save_model)
        # logger.info("BertTextCnnModel load_model end!")
        print("BertTextCnnModel load_model end!")

    def process_pair(self, textss):
        # text preprocessing: takes a list of [text1, text2] pairs and
        # returns input ids, input masks and token-type ids
        input_ids = []
        input_masks = []
        input_type_ids = []
        for texts in textss:
            tokens_text = self.tokenizer.tokenize(texts[0])
            logger.info('Tokens1: %s', tokens_text)
            tokens_text2 = self.tokenizer.tokenize(texts[1])
            logger.info('Tokens2: %s', tokens_text2)
            input_id, input_type_id = self.tokenizer.encode(first=texts[0], second=texts[1], max_len=self.max_seq_len)
            input_mask = [0 if ids == 0 else 1 for ids in input_id]
            input_ids.append(input_id)
            input_type_ids.append(input_type_id)
            input_masks.append(input_mask)
        # convert the python lists to numpy arrays
        input_ids = np.array(input_ids)
        input_masks = np.array(input_masks)
        input_type_ids = np.array(input_type_ids)
        logger.info("process ok!")
        return input_ids, input_masks, input_type_ids

    def predict(self, sen_1, sen_2):
        input_ids, input_masks, input_type_ids = self.process_pair([[sen_1, sen_2]])
        return self.model.predict([input_ids, input_masks], batch_size=1)

    def predict_list(self, questions):
        # encodes and predicts pair by pair; see the batched sketch after this class
        label_preds = []
        for questions_pair in questions:
            input_ids, input_masks, input_type_ids = self.process_pair([questions_pair])
            label_pred = self.model.predict([input_ids, input_masks], batch_size=1)
            label_preds.append(label_pred[0])
        return label_preds
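

# A hedged alternative to predict_list (an illustrative sketch, not part of the original
# class): encoding all pairs in one call and predicting with a larger batch size is
# usually much faster than the pair-by-pair loop above. The function name and the
# batch_size default are assumptions, not from the original code.
def predict_list_batched(bert_model, questions, batch_size=32):
    input_ids, input_masks, _ = bert_model.process_pair(questions)
    return bert_model.model.predict([input_ids, input_masks], batch_size=batch_size)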


def classify_pair_corpus_webank(bert_model, path_webank):
    # data preprocessing: read the webank csv (question1, question2, label)
    from utils.text_tools import text_preprocess, txtRead
    webank_q_2_l = txtRead(path_webank, encodeType='utf-8')
    questions = []
    labels = []
    for ques_label in webank_q_2_l[1:]:  # skip the header line
        q_2_l = ques_label.split(',')
        q_1 = q_2_l[0]
        q_2 = "".join(q_2_l[1:-1])
        label = q_2_l[-1]
        questions.append([text_preprocess(q_1), text_preprocess(q_2)])
        label_int = int(label)
        labels.append([0, 1] if label_int == 1 else [1, 0])  # one-hot: [1, 0] = dissimilar, [0, 1] = similar
    questions = np.array(questions)
    labels = np.array(labels)
    input_ids, input_masks, input_type_ids = bert_model.process_pair(questions)
    return questions, labels, input_ids, input_masks, input_type_ids
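

# Assumed webank csv layout, inferred from the parsing above: a header line, then
# "question1,question2,label" rows where label 1 = similar and 0 = dissimilar
# (extra commas inside question2 are re-joined by the "".join above).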


def train():
    # 1. train
    bert_model = BertTextCnnModel()
    bert_model.compile_model()
    _, labels_train, input_ids_train, input_masks_train, _ = classify_pair_corpus_webank(bert_model, path_webank_train)
    _, labels_dev, input_ids_dev, input_masks_dev, _ = classify_pair_corpus_webank(bert_model, path_webank_dev)
    # questions_test, labels_test, input_ids_test, input_masks_test, _ = classify_pair_corpus_webank(bert_model, path_webank_test)
    print("process corpus ok!")
    bert_model.fit([input_ids_train, input_masks_train], labels_train, [input_ids_dev, input_masks_dev], labels_dev)
    print("bert_model fit ok!")


def test():
    # 2. test
    bert_model = BertTextCnnModel()
    bert_model.load_model()
    questions_test, labels_test, input_ids_test, input_masks_test, _ = classify_pair_corpus_webank(bert_model,
                                                                                                   path_webank_test)
    print('predict_list start! you will wait for a few minutes')
    labels_pred = bert_model.predict_list(questions_test)
    print('predict_list end!')
    labels_pred_np = np.array(labels_pred)
    labels_pred_np_arg = np.argmax(labels_pred_np, axis=1)
    labels_test_np = np.array(labels_test)
    labels_test_np_arg = np.argmax(labels_test_np, axis=1)
    target_names = ['dissimilar', 'similar']
    report_predict = classification_report(labels_test_np_arg, labels_pred_np_arg,
                                           target_names=target_names, digits=9)
    print(report_predict)


def predict():
    # 3. predict
    bert_model = BertTextCnnModel()
    bert_model.load_model()
    # warm-up prediction with dummy inputs
    pred = bert_model.predict(sen_1='jy', sen_2='myz')
    print(pred[0][1])
    while True:
        print("sen_1: ")
        sen_1 = input()
        print("sen_2: ")
        sen_2 = input()
        pred = bert_model.predict(sen_1=sen_1, sen_2=sen_2)
        print(pred[0][1])  # probability of the "similar" class


if __name__ == "__main__":
    train()
    # test()
    # predict()

# text-cnn, stopped by early stopping
# 100000/100000 [==============================] - 1546s 15ms/step - loss: 0.4168 - acc: 0.8108 - val_loss: 0.4379 - val_acc: 0.8008
# Epoch 00024: val_loss improved from 0.43926 to 0.43788, saving model to model_webank_tdt/bert_avt_cnn.h5
#                  precision       recall     f1-score   support
# dissimilar     0.800245600  0.782000000  0.791017601      5000
# similar        0.786859601  0.804800000  0.795728693      5000
# avg / total    0.793552600  0.793400000  0.793373147     10000

# text-rcnn, stopped by early stopping
# 100000/100000 [==============================] - 1671s 17ms/step - loss: 0.4627 - acc: 0.7971 - val_loss: 0.4810 - val_acc: 0.8018
#                  precision       recall     f1-score   support
# dissimilar     0.777479378  0.810600000  0.793694311      5000
# similar        0.802172551  0.768000000  0.784714417      5000
# avg / total    0.789825965  0.789300000  0.789204364     10000

# avt-cnn, stopped by early stopping
# 100000/100000 [==============================] - 1562s 16ms/step - loss: 0.4204 - acc: 0.8091 - val_loss: 0.4391 - val_acc: 0.7925
# Epoch 00015: val_loss improved from 0.44410 to 0.43914, saving model to model_webank_tdt/bert_avt_cnn.h5
#                  precision       recall     f1-score   support
# dissimilar     0.789808917  0.768800000  0.779162866      5000
# similar        0.774790571  0.795400000  0.784960032      5000
# avg / total    0.782299744  0.782100000  0.782061449     10000