Closed
Description
When state_is_tuple=True, the cell state and hidden state have to be fed separately, so the code looks like the example below. Is there a simpler way to write this?
# reset all states at the beginning of every epoch
state1 = tl.layers.initialize_rnn_state(lstm1.initial_state)
state2 = tl.layers.initialize_rnn_state(lstm2.initial_state)
for step, (x, y) in enumerate(tl.iterate.ptb_iterator(train_data,
                                                      batch_size, num_steps)):
    feed_dict = {input_data: x, targets: y,
                 lstm1.initial_state.c: state1[0],
                 lstm1.initial_state.h: state1[1],
                 lstm2.initial_state.c: state2[0],
                 lstm2.initial_state.h: state2[1],
                 }
    # For training, enable dropout
    feed_dict.update(network.all_drop)
    # Fetch the final cell and hidden states so they can be fed into the next step
    _cost, state1_c, state1_h, state2_c, state2_h, _ = \
        sess.run([cost,
                  lstm1.final_state.c,
                  lstm1.final_state.h,
                  lstm2.final_state.c,
                  lstm2.final_state.h,
                  train_op],
                 feed_dict=feed_dict
                 )
    state1 = (state1_c, state1_h)
    state2 = (state2_c, state2_h)
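One possible simplification, offered as a sketch rather than an official TensorLayer API: because each state is a tuple of (c, h), the placeholders can be paired with their values by zipping the two tuples, which removes the per-field feed_dict entries. The helper name `state_feed` and the stand-in `StateTuple` are hypothetical and exist only so the example runs without TensorFlow; in the real loop, `lstm1.initial_state` and `state1` would take their place.

```python
from collections import namedtuple

# Hypothetical stand-in for TensorFlow's LSTMStateTuple (a namedtuple with
# fields c and h), so this sketch runs without TensorFlow installed.
StateTuple = namedtuple("StateTuple", ["c", "h"])


def state_feed(*pairs):
    """Build feed_dict entries from (placeholder_state, value_state) pairs.

    Each pair holds a tuple of placeholders and a tuple of values with the
    same layout, e.g. (lstm1.initial_state, state1).
    """
    feed = {}
    for placeholders, values in pairs:
        if len(placeholders) != len(values):
            raise ValueError("state structures differ in length")
        # Pair c with c and h with h, in tuple order.
        feed.update(zip(placeholders, values))
    return feed


# Strings stand in for placeholder tensors; lists stand in for numpy arrays.
init1 = StateTuple(c="lstm1_c", h="lstm1_h")
init2 = StateTuple(c="lstm2_c", h="lstm2_h")
state1 = StateTuple(c=[0.0], h=[0.1])
state2 = StateTuple(c=[0.2], h=[0.3])

feed_dict = {"input_data": "x", "targets": "y"}
feed_dict.update(state_feed((init1, state1), (init2, state2)))
```

In the training loop this would become `feed_dict.update(state_feed((lstm1.initial_state, state1), (lstm2.initial_state, state2)))`. On the fetch side, TensorFlow 1.x `Session.run` accepts nested structures as fetches, so fetching `lstm1.final_state` and `lstm2.final_state` whole, instead of their `.c` and `.h` fields, should return the state tuples directly and make the manual re-packing at the end of the loop unnecessary. This has not been checked against every TensorFlow version, so treat it as an assumption to verify.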