Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions tensorlayer/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,10 @@ def __init__(
DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper
except Exception:
DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0) # out_keep_prob)
cell_creator = lambda is_last=True: \
DropoutWrapper_fn(rnn_creator(),
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob if is_last else 1.0)
else:
cell_creator = rnn_creator
self.fw_cell = cell_creator()
Expand All @@ -392,11 +395,11 @@ def __init__(
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell

try:
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
except Exception:
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])

# Initial state of RNN
if fw_initial_state is None:
Expand Down Expand Up @@ -1076,7 +1079,10 @@ def __init__(
# cell_instance_fn1(),
# input_keep_prob=in_keep_prob,
# output_keep_prob=out_keep_prob)
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0)
cell_creator = lambda is_last=True: \
DropoutWrapper_fn(rnn_creator(),
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob if is_last else 1.0)
else:
cell_creator = rnn_creator
self.cell = cell_creator()
Expand All @@ -1090,10 +1096,10 @@ def __init__(
# cell_instance_fn2=cell_instance_fn # HanSheng
try:
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
except Exception: # when GRU
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])

# self.cell=cell_instance_fn() # HanSheng

Expand Down