tensorlayer/layers.py: 71 changes (38 additions, 33 deletions)

@@ -5623,53 +5623,58 @@ def __init__(
                 #                     cell_instance_fn1(),
                 #                     input_keep_prob=in_keep_prob,
                 #                     output_keep_prob=out_keep_prob)
-                cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0)  # out_keep_prob)
+                cell_creator = lambda is_last=True: \
+                    DropoutWrapper_fn(rnn_creator(),
+                                      input_keep_prob=in_keep_prob,
+                                      output_keep_prob=out_keep_prob if is_last else 1.0)  # out_keep_prob)
             else:
                 cell_creator = rnn_creator
-            self.fw_cell = cell_creator()
-            self.bw_cell = cell_creator()
-            # Apply multiple layers
-            if n_layer > 1:
-                try:
-                    MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
-                except:
-                    MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
-                cell_creator = lambda: rnn_creator()

-                # cell_instance_fn2=cell_instance_fn            # HanSheng
-                # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
-                self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
-                self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])

-            if dropout:
-                self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
-                self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+            # if dropout:
+            #     self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+            #     self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+
             # self.fw_cell=cell_instance_fn()
             # self.bw_cell=cell_instance_fn()
             # Initial state of RNN
-            if fw_initial_state is None:
-                self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=D_TYPE)  # dtype=tf.float32)
-            else:
-                self.fw_initial_state = fw_initial_state
-            if bw_initial_state is None:
-                self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=D_TYPE)  # dtype=tf.float32)
-            else:
-                self.bw_initial_state = bw_initial_state
+
+            self.fw_initial_state = fw_initial_state
+            self.bw_initial_state = bw_initial_state
             # Computes sequence_length
             if sequence_length is None:
                 try:  ## TF1.0
                     sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.stack(self.inputs))
                 except:  ## TF0.12
                     sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))

-            outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
-                cell_fw=self.fw_cell,
-                cell_bw=self.bw_cell,
-                inputs=self.inputs,
-                sequence_length=sequence_length,
-                initial_state_fw=self.fw_initial_state,
-                initial_state_bw=self.bw_initial_state,
-                **dynamic_rnn_init_args)
+            if n_layer > 1:
+                self.fw_cell = [cell_creator(is_last=(i == n_layer - 1)) for i in range(n_layer)]
+                self.bw_cell = [cell_creator(is_last=(i == n_layer - 1)) for i in range(n_layer)]
+                from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn
+                outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
+                    cells_fw=self.fw_cell,
+                    cells_bw=self.bw_cell,
+                    inputs=self.inputs,
+                    sequence_length=sequence_length,
+                    initial_states_fw=self.fw_initial_state,
+                    initial_states_bw=self.bw_initial_state,
+                    dtype=D_TYPE,
+                    **dynamic_rnn_init_args)
+
+            else:
+                self.fw_cell = cell_creator()
+                self.bw_cell = cell_creator()
+                outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
+                    cell_fw=self.fw_cell,
+                    cell_bw=self.bw_cell,
+                    inputs=self.inputs,
+                    sequence_length=sequence_length,
+                    initial_state_fw=self.fw_initial_state,
+                    initial_state_bw=self.bw_initial_state,
+                    dtype=D_TYPE,
+                    **dynamic_rnn_init_args)
+
             rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

             print("  n_params : %d" % (len(rnn_variables)))
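
Review note: with this change, input dropout is applied to every layer while output dropout is applied only to the top layer (`is_last=True`). That avoids applying dropout twice on the connection between two stacked layers, once on the lower layer's output and again on the upper layer's input. One caveat worth flagging: when `dropout` is falsy, `cell_creator` is bound to `rnn_creator`, which takes no arguments, so the `cell_creator(is_last=...)` calls in the `n_layer > 1` branch would raise a TypeError. A minimal sketch of the pattern with a guard, assuming the TF 1.x contrib API; `build_cells`, `make_cell`, and `use_dropout` are hypothetical names, not identifiers from this PR:

    import tensorflow as tf  # TF 1.x assumed

    def build_cells(make_cell, n_layer, use_dropout, in_keep_prob, out_keep_prob):
        """Build one cell per layer, dropping outputs only on the top layer."""
        cells = []
        for i in range(n_layer):
            cell = make_cell()
            if use_dropout:  # guard: a plain cell factory takes no is_last argument
                cell = tf.contrib.rnn.DropoutWrapper(
                    cell,
                    input_keep_prob=in_keep_prob,
                    # only the last (top) layer drops its outputs
                    output_keep_prob=out_keep_prob if i == n_layer - 1 else 1.0)
            cells.append(cell)
        return cells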
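For context on the second half of the diff: `tf.contrib.rnn.stack_bidirectional_dynamic_rnn` takes lists of forward and backward cells (one per layer) instead of a single `MultiRNNCell`, and when `initial_states_fw`/`initial_states_bw` are None it builds zero states from `dtype`, which is why the explicit `zero_state(...)` branches could be dropped. A minimal usage sketch, assuming TF 1.x; the shapes and variable names are illustrative only:

    import tensorflow as tf
    from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn

    batch_size, n_steps, n_in, n_hidden, n_layer = 32, 20, 100, 64, 3
    x = tf.placeholder(tf.float32, [batch_size, n_steps, n_in])

    # one distinct cell per layer and per direction
    cells_fw = [tf.contrib.rnn.BasicLSTMCell(n_hidden) for _ in range(n_layer)]
    cells_bw = [tf.contrib.rnn.BasicLSTMCell(n_hidden) for _ in range(n_layer)]

    # initial_states_fw/bw default to None, so zero states come from dtype
    outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
        cells_fw=cells_fw,
        cells_bw=cells_bw,
        inputs=x,
        dtype=tf.float32)

    # outputs concatenates both directions: [batch_size, n_steps, 2 * n_hidden]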