Skip to content

Commit bdcc3e4

Browse files
alter some code
1 parent 989844b commit bdcc3e4

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

seg_lstm.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,35 @@ class SegLSTM(SegBase):
1313
def __init__(self):
1414
SegBase.__init__(self)
1515
self.dtype = tf.float32
16+
# 参数初始化
1617
self.skip_window_left = constant.LSTM_SKIP_WINDOW_LEFT
1718
self.skip_window_right = constant.LSTM_SKIP_WINDOW_RIGHT
1819
self.window_size = self.skip_window_left + self.skip_window_right + 1
19-
self.embed_size = 50
20-
self.hidden_units = 100
20+
self.embed_size = 100
21+
self.hidden_units = 150
2122
self.tag_count = 4
2223
self.concat_embed_size = self.window_size * self.embed_size
24+
self.vocab_size = constant.VOCAB_SIZE
25+
self.alpha = 0.05
26+
self.lam = 0.0005
27+
self.eta = 0.02
28+
self.dropout_rate = 0.2
29+
# 数据初始化
2330
trans = TransformDataLSTM()
2431
self.words_batch = trans.words_batch
2532
self.tags_batch = trans.labels_batch
2633
self.dictionary = trans.dictionary
27-
self.vocab_size = constant.VOCAB_SIZE
28-
self.alpha = 0.02
29-
self.lam = 0.001
34+
# 模型定义和初始化
3035
self.sess = tf.Session()
3136
self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
3237
self.x = tf.placeholder(self.dtype, shape=[self.concat_embed_size, None])
3338
self.x_plus = tf.placeholder(self.dtype, shape=[1, None, self.concat_embed_size])
3439
self.embeddings = tf.Variable(
35-
tf.random_uniform([self.vocab_size, self.embed_size], -1.0 / math.sqrt(self.embed_size),
36-
1.0 / math.sqrt(self.embed_size),
40+
tf.random_uniform([self.vocab_size, self.embed_size], -8.0 / math.sqrt(self.embed_size),
41+
8.0 / math.sqrt(self.embed_size),
3742
dtype=self.dtype), dtype=self.dtype, name='embeddings')
3843
self.w = tf.Variable(
39-
tf.truncated_normal([self.tags_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size)),
44+
tf.truncated_normal([self.tags_count, self.hidden_units], stddev=5.0 / math.sqrt(self.concat_embed_size)),
4045
dtype=self.dtype, name='w')
4146
self.b = tf.Variable(tf.zeros([self.tag_count, 1]), dtype=self.dtype, name='b')
4247
self.A = tf.Variable(tf.random_uniform([self.tag_count, self.tag_count], -1, 1), dtype=self.dtype, name='A')
@@ -55,7 +60,7 @@ def __init__(self):
5560
self.lstm = tf.contrib.rnn.LSTMCell(self.hidden_units)
5661
self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, self.x_plus, dtype=self.dtype)
5762
tf.global_variables_initializer().run(session=self.sess)
58-
self.word_scores = tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b
63+
self.word_scores = tf.nn.relu(tf.matmul(self.w, tf.transpose(self.lstm_output[0])) + self.b)
5964
self.loss = tf.reduce_sum(tf.multiply(self.map_matrix, self.word_scores))
6065
self.lstm_variable = [v for v in tf.global_variables() if v.name.startswith('rnn')]
6166
self.params = [self.w, self.b] + self.lstm_variable
@@ -66,6 +71,8 @@ def __init__(self):
6671
self.update_embed_op = tf.scatter_update(self.embeddings, self.embed_index, self.embedp)
6772
self.grad_embed = tf.gradients(tf.multiply(self.map_matrix, self.word_scores), self.x_plus)
6873
self.saver = tf.train.Saver(self.params + [self.embeddings, self.A, self.init_A], max_to_keep=100)
74+
self.regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam),
75+
self.params + [self.A, self.init_A])
6976

7077
def model(self, embeds):
7178
scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(embeds.T, 0)})
@@ -75,14 +82,15 @@ def model(self, embeds):
7582
def train_exe(self):
    """Run training: 3 full passes (epochs) over the sentence corpus.

    Side effects: mutates TF variables via self.train_sentence, prints
    progress, and saves a checkpoint per epoch to tmp/lstm-model{i}.ckpt.
    """
    # Freeze the graph so no ops are accidentally added during training.
    self.sess.graph.finalize()
    last_time = time.time()
    # 3 epochs over the whole corpus (this commit added the outer loop).
    for i in range(3):
        for sentence_index, (sentence, tags) in enumerate(zip(self.words_batch, self.tags_batch)):
            self.train_sentence(sentence, tags, len(tags))
            # Progress report every 500 sentences: index, elapsed seconds
            # since the last report, and the current init_A values.
            if sentence_index % 500 == 0:
                print(sentence_index)
                print(time.time() - last_time)
                last_time = time.time()
                print(self.sess.run(self.init_A))
        # Checkpoint after each epoch; the %d is the epoch index, so this
        # save belongs inside the epoch loop (indentation was lost in the
        # diff paste — NOTE(review): confirm against the repository).
        self.saver.save(self.sess, 'tmp/lstm-model%d.ckpt' % i)
8694

8795
def train_sentence(self, sentence, tags, length):
8896
sentence_embeds = self.sess.run(self.lookup_op, feed_dict={self.sentence_holder: sentence}).reshape(
@@ -108,6 +116,7 @@ def train_sentence(self, sentence, tags, length):
108116
feed_dict={self.indices: sparse_indices, self.shape: output_shape,
109117
self.values: sparse_values})
110118
# 更新参数
119+
# self.sess.run(self.regu)
111120
self.sess.run(self.train,
112121
feed_dict={self.x_plus: np.expand_dims(update_embed.T, 0), self.map_matrix: sentence_matrix})
113122
self.sess.run(self.regularization)
@@ -119,7 +128,7 @@ def train_sentence(self, sentence, tags, length):
119128
grad = self.sess.run(self.grad_embed, feed_dict={self.x_plus: embed,
120129
self.map_matrix: np.expand_dims(sentence_matrix[:, i], 1)})[0]
121130

122-
sentence_update_embed = (embed + self.alpha * grad) * (1 - self.lam)
131+
sentence_update_embed = (embed - self.alpha * grad) * (1 - self.lam)
123132
self.embeddings = self.sess.run(self.update_embed_op,
124133
feed_dict={
125134
self.embedp: sentence_update_embed.reshape([self.window_size, self.embed_size]),
@@ -153,22 +162,22 @@ def gen_update_A(correct_tags, current_tags):
153162

154163
return A_update, init_A_update, update_init
155164

156-
def seg(self, sentence, model_path='tmp/lstm-model1.ckpt'):
    """Segment a Chinese sentence using a trained checkpoint.

    sentence: the raw sentence string to segment.
    model_path: checkpoint to restore before scoring (default is the
        epoch-1 model; training also writes model0 and model2 — verify
        which checkpoint is intended).
    Returns a (words, tags) pair from self.tags2words / self.viterbi.
    """
    self.saver.restore(self.sess, model_path)
    # Map characters to windowed index sequences, then look up embeddings
    # and flatten each window to one concat_embed_size vector per char.
    seq = self.index2seq(self.sentence2index(sentence))
    sentence_embeds = tf.nn.embedding_lookup(self.embeddings, seq).eval(session=self.sess).reshape(
        [len(sentence), self.concat_embed_size])
    # Per-character tag scores from the LSTM; x_plus expects a leading
    # batch dimension of 1, hence the expand_dims.
    sentence_scores = self.sess.run(self.word_scores, feed_dict={self.x_plus: np.expand_dims(sentence_embeds, 0)})
    init_A_val = self.init_A.eval(session=self.sess)
    A_val = self.A.eval(session=self.sess)
    # Debug output of the learned transition matrix (was self.w pre-commit).
    print(A_val)
    # Decode the best tag path with the transition matrices, then convert
    # tags back into a word segmentation of the input sentence.
    current_tags = self.viterbi(sentence_scores, A_val, init_A_val)
    return self.tags2words(sentence, current_tags), current_tags
168176

169177

170178
if __name__ == '__main__':
    # Build the model (loads training data and initializes the TF graph).
    seg = SegLSTM()
    # Training is disabled here; seg() restores a saved checkpoint instead.
    # seg.train_exe()
    print(seg.seg('我爱北京天安门'))
    print(seg.seg('小明来自南京师范大学'))
    print(seg.seg('小明是上海理工大学的学生'))

0 commit comments

Comments
 (0)