@@ -30,6 +30,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
30
30
self .seq_length = tf .placeholder (tf .int32 , [self .batch_size ])
31
31
else :
32
32
self .input = tf .placeholder (tf .int32 , [None , self .windows_size ])
33
+
33
34
# 查找表层
34
35
self .embedding_layer = self .get_embedding_layer ()
35
36
# 隐藏层
@@ -46,6 +47,9 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
46
47
47
48
if mode == 'predict' :
48
49
self .output = tf .squeeze (tf .transpose (self .output ), axis = 2 )
50
+ self .sess = tf .Session ()
51
+ self .sess .run (tf .global_variables_initializer ())
52
+ tf .train .Saver ().restore (save_path = self .model_path , sess = self .sess )
49
53
elif train == 'll' :
50
54
self .ll_loss , _ = tf .contrib .crf .crf_log_likelihood (self .output , self .real_indices , self .seq_length ,
51
55
self .transition )
@@ -180,17 +184,15 @@ def generate_transition_update_index(self, correct_labels, current_labels):
180
184
def predict (self , sentence : str , return_labels = False ):
181
185
if self .mode != 'predict' :
182
186
raise Exception ('mode is not allowed to predict' )
183
- with tf .Session () as sess :
184
- tf .global_variables_initializer ().run ()
185
- tf .train .Saver ().restore (save_path = self .model_path , sess = sess )
186
- input = self .indices2input (self .sentence2indices (sentence ))
187
- runner = [self .output , self .transition , self .transition_init ]
188
- output , trans , trans_init = sess .run (runner , feed_dict = {self .input : input })
189
- labels = self .viterbi (output , trans , trans_init )
190
- if not return_labels :
191
- return self .tags2words (sentence , labels )
192
- else :
193
- return self .tags2words (sentence , labels ), labels
187
+
188
+ input = self .indices2input (self .sentence2indices (sentence ))
189
+ runner = [self .output , self .transition , self .transition_init ]
190
+ output , trans , trans_init = self .sess .run (runner , feed_dict = {self .input : input })
191
+ labels = self .viterbi (output , trans , trans_init )
192
+ if not return_labels :
193
+ return self .tags2words (sentence , labels )
194
+ else :
195
+ return self .tags2words (sentence , labels ), self .tag2sequences (labels )
194
196
195
197
def get_embedding_layer (self ) -> tf .Tensor :
196
198
embeddings = self .__get_variable ([self .dict_size , self .embed_size ], 'embeddings' )
0 commit comments