Skip to content

Commit d4efab6

Browse files
fix dnn and evaluation bug
1 parent c957bc6 commit d4efab6

File tree

2 files changed

+19
-23
lines changed

2 files changed

+19
-23
lines changed

python/dnlp/core/dnn_crf.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
3030
self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
3131
else:
3232
self.input = tf.placeholder(tf.int32, [None, self.windows_size])
33+
3334
# 查找表层
3435
self.embedding_layer = self.get_embedding_layer()
3536
# 隐藏层
@@ -46,6 +47,9 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
4647

4748
if mode == 'predict':
4849
self.output = tf.squeeze(tf.transpose(self.output), axis=2)
50+
self.sess = tf.Session()
51+
self.sess.run(tf.global_variables_initializer())
52+
tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
4953
elif train == 'll':
5054
self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
5155
self.transition)
@@ -180,17 +184,15 @@ def generate_transition_update_index(self, correct_labels, current_labels):
180184
def predict(self, sentence: str, return_labels=False):
181185
if self.mode != 'predict':
182186
raise Exception('mode is not allowed to predict')
183-
with tf.Session() as sess:
184-
tf.global_variables_initializer().run()
185-
tf.train.Saver().restore(save_path=self.model_path, sess=sess)
186-
input = self.indices2input(self.sentence2indices(sentence))
187-
runner = [self.output, self.transition, self.transition_init]
188-
output, trans, trans_init = sess.run(runner, feed_dict={self.input: input})
189-
labels = self.viterbi(output, trans, trans_init)
190-
if not return_labels:
191-
return self.tags2words(sentence, labels)
192-
else:
193-
return self.tags2words(sentence, labels), labels
187+
188+
input = self.indices2input(self.sentence2indices(sentence))
189+
runner = [self.output, self.transition, self.transition_init]
190+
output, trans, trans_init = self.sess.run(runner, feed_dict={self.input: input})
191+
labels = self.viterbi(output, trans, trans_init)
192+
if not return_labels:
193+
return self.tags2words(sentence, labels)
194+
else:
195+
return self.tags2words(sentence, labels), self.tag2sequences(labels)
194196

195197
def get_embedding_layer(self) -> tf.Tensor:
196198
embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')

python/dnlp/utils/evaluation.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: UTF-8 -*-
22
import pickle
3-
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_OTHER, TAG_END, TAG_SINGLE
3+
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_END, TAG_SINGLE
44

55

66
def get_cws_statistics(correct_labels, predict_labels) -> (int, int, int):
@@ -33,7 +33,7 @@ def get_cws_statistics(correct_labels, predict_labels) -> (int, int, int):
3333
predicts[predict_start] = i
3434

3535
for predict in predicts:
36-
if corrects.get(predict) is not None and corrects[predict] == predicts[predict]:
36+
if predict in corrects and corrects[predict] == predicts[predict]:
3737
true_positive_count += 1
3838

3939
return true_positive_count, len(predicts), len(corrects)
@@ -72,22 +72,16 @@ def get_ner_statistics(correct_labels, predict_labels) -> (int, int, int):
7272
def evaluate_cws(model, data_path: str):
7373
with open(data_path, 'rb') as f:
7474
data = pickle.load(f)
75-
dictionary = data['dictionary']
76-
tags = data['tags']
77-
reversed_map = dict(zip(tags.values(), tags.keys()))
7875
characters = data['characters']
7976
labels_true = data['labels']
8077
c_count = 0
8178
p_count = 0
82-
r_count = -0
79+
r_count = 0
8380
for sentence, label in enumerate(characters, labels_true):
8481
words, labels_predict = model.predict(sentence, return_labels=True)
85-
seq = []
86-
for l in zip(labels_predict):
87-
seq.append(reversed_map[l])
88-
c, p, r = get_cws_statistics(label, seq)
82+
c, p, r = get_cws_statistics(label, labels_predict)
8983
c_count += c
9084
p_count += p
9185
r_count += r
92-
print(c / p)
93-
print(c / r)
86+
print(c_count / p_count)
87+
print(c_count / r_count)

0 commit comments

Comments
 (0)