@@ -13,30 +13,35 @@ class SegLSTM(SegBase):
   def __init__(self):
     SegBase.__init__(self)
     self.dtype = tf.float32
+    # parameter initialization
     self.skip_window_left = constant.LSTM_SKIP_WINDOW_LEFT
     self.skip_window_right = constant.LSTM_SKIP_WINDOW_RIGHT
     self.window_size = self.skip_window_left + self.skip_window_right + 1
-    self.embed_size = 50
-    self.hidden_units = 100
+    self.embed_size = 100
+    self.hidden_units = 150
     self.tag_count = 4
     self.concat_embed_size = self.window_size * self.embed_size
+    self.vocab_size = constant.VOCAB_SIZE
+    self.alpha = 0.05
+    self.lam = 0.0005
+    self.eta = 0.02
+    self.dropout_rate = 0.2
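+    # alpha is the SGD learning rate; lam is the L2 weight-decay coefficient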
+    # data initialization
     trans = TransformDataLSTM()
     self.words_batch = trans.words_batch
     self.tags_batch = trans.labels_batch
     self.dictionary = trans.dictionary
-    self.vocab_size = constant.VOCAB_SIZE
-    self.alpha = 0.02
-    self.lam = 0.001
+    # model definition and initialization
     self.sess = tf.Session()
     self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
     self.x = tf.placeholder(self.dtype, shape=[self.concat_embed_size, None])
     self.x_plus = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
     self.embeddings = tf.Variable(
-      tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
-                        1.0 / math.sqrt(self.embed_size),
+      tf.random_uniform([self.vocab_size, self.embed_size], -8.0 / math.sqrt(self.embed_size),
+                        8.0 / math.sqrt(self.embed_size),
                        dtype=self.dtype), dtype=self.dtype, name='embeddings')
     self.w = tf.Variable(
-      tf.truncated_normal([self.tags_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size)),
+      tf.truncated_normal([self.tags_count, self.hidden_units], stddev=5.0 / math.sqrt(self.concat_embed_size)),
      dtype=self.dtype, name='w')
     self.b = tf.Variable(tf.zeros([self.tag_count, 1]), dtype=self.dtype, name='b')
     self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -1, 1), dtype=self.dtype, name='A')
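+    # A holds tag-to-tag transition scores and init_A the initial-tag scores; both feed the Viterbi decode in seg()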
@@ -55,7 +60,7 @@ def __init__(self):
     self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
     self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x_plus, dtype=self.dtype)
     tf.global_variables_initializer().run(session=self.sess)
-    self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
+    self.word_scores = tf.nn.relu(tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b)
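+    # word_scores: a [tag_count, sentence_length] matrix of per-character tag scores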
     self.loss = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores))
     self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
     self.params = [self.w, self.b] + self.lstm_variable
@@ -66,6 +71,8 @@ def __init__(self):
     self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
     self.grad_embed = tf.gradients(tf.multiply(self.map_matrix, self.word_scores), self.x_plus)
     self.saver = tf.train.Saver(self.params + [self.embeddings, self.A, self.init_A], max_to_keep=100)
+    self.regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam),
+                                                       self.params + [self.A, self.init_A])
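+    # L2 penalty over w, b, the LSTM kernels, and both transition matrices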

   def model(self, embeds):
     scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(embeds.T, 0)})
@@ -75,14 +82,15 @@ def model(self, embeds):
   def train_exe(self):
     self.sess.graph.finalize()
     last_time = time.time()
-    for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
-      self.train_sentence(sentence, tags, len(tags))
-      if sentence_index % 500 == 0:
-        print(sentence_index)
-        print(time.time() - last_time)
-        last_time = time.time()
-        print(self.sess.run(self.init_A))
-        self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % 0)
+    for i in range(3):
+      for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
+        self.train_sentence(sentence, tags, len(tags))
+        if sentence_index % 500 == 0:
+          print(sentence_index)
+          print(time.time() - last_time)
+          last_time = time.time()
+          print(self.sess.run(self.init_A))
+          self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % i)
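+          # three passes over the corpus, checkpointing to a per-epoch file every 500 sentences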

   def train_sentence(self, sentence, tags, length):
     sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
@@ -108,6 +116,7 @@ def train_sentence(self, sentence, tags, length):
                   feed_dict={self.indices: sparse_indices, self.shape: output_shape,
                              self.values: sparse_values})
     # update the parameters
+    # self.sess.run(self.regu)
     self.sess.run(self.train,
                   feed_dict={self.x_plus: np.expand_dims(update_embed.T, 0), self.map_matrix: sentence_matrix})
     self.sess.run(self.regularization)
@@ -119,7 +128,7 @@ def train_sentence(self, sentence, tags, length):
      grad = self.sess.run(self.grad_embed, feed_dict={self.x_plus: embed,
                                                       self.map_matrix: np.expand_dims(sentence_matrix[:, i], 1)})[0]

-      sentence_update_embed = (embed + self.alpha * grad) * (1 - self.lam)
+      sentence_update_embed = (embed - self.alpha * grad) * (1 - self.lam)
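+      # step against the gradient (descent), then shrink the window embeddings by the L2 weight-decay factor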
      self.embeddings = self.sess.run(self.update_embed_op,
                                      feed_dict={
                                        self.embedp: sentence_update_embed.reshape([self.window_size, self.embed_size]),
@@ -153,22 +162,22 @@ def gen_update_A(correct_tags, current_tags):

     return A_update, init_A_update, update_init

-  def seg(self, sentence, model_path='tmp/lstm-model0.ckpt'):
-    # tf.reset_default_graph()
+  def seg(self, sentence, model_path='tmp/lstm-model1.ckpt'):
     self.saver.restore(self.sess, model_path)
     seq = self.index2seq(self.sentence2index(sentence))
     sentence_embeds = tf.nn.embedding_lookup(self.embeddings, seq).eval(session=self.sess).reshape(
       [len(sentence), self.concat_embed_size])
     sentence_scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(sentence_embeds, 0)})
     init_A_val = self.init_A.eval(session=self.sess)
     A_val = self.A.eval(session=self.sess)
-    print(self.sess.run(self.w))
+    print(A_val)
     current_tags = self.viterbi(sentence_scores, A_val, init_A_val)
     return self.tags2words(sentence, current_tags), current_tags


 if __name__ == '__main__':
   seg = SegLSTM()
   # seg.train_exe()
-  res = seg.seg('我爱北京天安门')
+  print(seg.seg('我爱北京天安门'))
   print(seg.seg('小明来自南京师范大学'))
+  print(seg.seg('小明是上海理工大学的学生'))
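
The seg path above depends on self.viterbi, which lives in SegBase and is not part of this diff. For reference only, a minimal NumPy sketch of a decoder compatible with the interface used here -- assuming scores is the [tag_count, length] matrix produced by word_scores, A the tag-transition matrix, and init_A the initial-tag scores -- could look like this (names and shapes are assumptions, not the repository's actual implementation):

import numpy as np

def viterbi(scores, A, init_A):
  # scores: [tag_count, length] per-character tag scores (word_scores output)
  # A:      [tag_count, tag_count] transition scores, A[p, t] = score of tag p -> tag t
  # init_A: [tag_count] scores for the tag of the first character
  tag_count, length = scores.shape
  dp = np.zeros((tag_count, length))               # best score of a path ending in tag t at position i
  back = np.zeros((tag_count, length), dtype=int)  # backpointer to the previous tag
  dp[:, 0] = np.ravel(init_A) + scores[:, 0]
  for i in range(1, length):
    cand = dp[:, i - 1, None] + A                  # cand[p, t]: extend a path ending in p with tag t
    back[:, i] = np.argmax(cand, axis=0)
    dp[:, i] = cand[back[:, i], np.arange(tag_count)] + scores[:, i]
  tags = [int(np.argmax(dp[:, -1]))]               # best final tag, then walk the backpointers
  for i in range(length - 1, 0, -1):
    tags.append(int(back[tags[-1], i]))
  return tags[::-1]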