diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py index 1bb79720c..d7936b8bf 100644 --- a/tensorlayer/layers/recurrent.py +++ b/tensorlayer/layers/recurrent.py @@ -378,7 +378,10 @@ def __init__( DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper except Exception: DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper - cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0) # out_keep_prob) + cell_creator = lambda is_last=True: \ + DropoutWrapper_fn(rnn_creator(), + input_keep_prob=in_keep_prob, + output_keep_prob=out_keep_prob if is_last else 1.0) else: cell_creator = rnn_creator self.fw_cell = cell_creator() @@ -392,11 +395,11 @@ def __init__( MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell try: - self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) - self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) + self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True) + self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True) except Exception: - self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) - self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) + self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]) + self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]) # Initial state of RNN if fw_initial_state is None: @@ -1076,7 +1079,10 @@ def __init__( # cell_instance_fn1(), # input_keep_prob=in_keep_prob, # output_keep_prob=out_keep_prob) - cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0) + cell_creator = lambda is_last=True: \ + DropoutWrapper_fn(rnn_creator(), + input_keep_prob=in_keep_prob, + output_keep_prob=out_keep_prob if is_last else 1.0) else: cell_creator = rnn_creator self.cell = cell_creator() @@ -1090,10 +1096,10 @@ def __init__( # cell_instance_fn2=cell_instance_fn # HanSheng try: # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng - self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) + self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True) except Exception: # when GRU # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng - self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) + self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]) # self.cell=cell_instance_fn() # HanSheng