From 36ff5f95fcdba5cad69cb2504a48d51f806354eb Mon Sep 17 00:00:00 2001
From: Matthew Zeng
Date: Thu, 1 Feb 2018 15:52:59 +0800
Subject: [PATCH] use stack_bidirectional_dynamic_rnn for multi-layer BiDynamicRNN

---
 tensorlayer/layers.py | 71 +++++++++++++++++++++++--------------------
 1 file changed, 38 insertions(+), 33 deletions(-)

diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py
index e30f7cbaf..04a3722f4 100644
--- a/tensorlayer/layers.py
+++ b/tensorlayer/layers.py
@@ -5623,38 +5623,24 @@ def __init__(
                 #       cell_instance_fn1(),
                 #       input_keep_prob=in_keep_prob,
                 #       output_keep_prob=out_keep_prob)
-                cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0)  # out_keep_prob)
+                # `is_last` lets the multi-layer path keep output dropout on the
+                # last layer only.
+                cell_creator = lambda is_last=True: \
+                    DropoutWrapper_fn(rnn_creator(),
+                                      input_keep_prob=in_keep_prob,
+                                      output_keep_prob=out_keep_prob if is_last else 1.0)
             else:
-                cell_creator = rnn_creator
-            self.fw_cell = cell_creator()
-            self.bw_cell = cell_creator()
-            # Apply multiple layers
-            if n_layer > 1:
-                try:
-                    MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
-                except:
-                    MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
+                # Accept and ignore `is_last` so both branches share one call signature.
+                cell_creator = lambda is_last=True: rnn_creator()
 
-                # cell_instance_fn2=cell_instance_fn # HanSheng
-                # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
-                self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
-                self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
-                if dropout:
-                    self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
-                    self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+            # if dropout:
+            #     self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+            #     self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
             # self.fw_cell=cell_instance_fn()
             # self.bw_cell=cell_instance_fn()
             # Initial state of RNN
-            if fw_initial_state is None:
-                self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=D_TYPE)  # dtype=tf.float32)
-            else:
-                self.fw_initial_state = fw_initial_state
-            if bw_initial_state is None:
-                self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=D_TYPE)  # dtype=tf.float32)
-            else:
-                self.bw_initial_state = bw_initial_state
+
+            self.fw_initial_state = fw_initial_state
+            self.bw_initial_state = bw_initial_state
             # Computes sequence_length
             if sequence_length is None:
                 try:  ## TF1.0
@@ -5662,14 +5648,33 @@ def __init__(
                 except:  ## TF0.12
                     sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))
 
-            outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
-                cell_fw=self.fw_cell,
-                cell_bw=self.bw_cell,
-                inputs=self.inputs,
-                sequence_length=sequence_length,
-                initial_state_fw=self.fw_initial_state,
-                initial_state_bw=self.bw_initial_state,
-                **dynamic_rnn_init_args)
+            if n_layer > 1:
+                # One cell per layer; only the last layer keeps output dropout.
+                self.fw_cell = [cell_creator(is_last=(i == n_layer - 1)) for i in range(n_layer)]
+                self.bw_cell = [cell_creator(is_last=(i == n_layer - 1)) for i in range(n_layer)]
+                from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn
+                outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
+                    cells_fw=self.fw_cell,
+                    cells_bw=self.bw_cell,
+                    inputs=self.inputs,
+                    sequence_length=sequence_length,
+                    initial_states_fw=self.fw_initial_state,
+                    initial_states_bw=self.bw_initial_state,
+                    dtype=D_TYPE,
+                    **dynamic_rnn_init_args)
+
+            else:
+                self.fw_cell = cell_creator()
+                self.bw_cell = cell_creator()
+                outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
+                    cell_fw=self.fw_cell,
+                    cell_bw=self.bw_cell,
+                    inputs=self.inputs,
+                    sequence_length=sequence_length,
+                    initial_state_fw=self.fw_initial_state,
+                    initial_state_bw=self.bw_initial_state,
+                    dtype=D_TYPE,
+                    **dynamic_rnn_init_args)
+
             rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
             print(" n_params : %d" % (len(rnn_variables)))
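
Note: below is a minimal standalone sketch of the stacked-bidirectional pattern
this patch adopts, assuming TensorFlow 1.x with tf.contrib available; the names
n_layer, n_hidden, keep_prob, and make_cell are illustrative and not part of
tensorlayer:

    import tensorflow as tf

    n_layer, n_hidden, keep_prob = 2, 64, 0.7
    inputs = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, features]

    def make_cell(is_last):
        cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)
        # Dropout on every layer's input; output dropout only on the last
        # layer, so stacked layers do not drop the same activations twice.
        return tf.contrib.rnn.DropoutWrapper(
            cell,
            input_keep_prob=keep_prob,
            output_keep_prob=keep_prob if is_last else 1.0)

    cells_fw = [make_cell(i == n_layer - 1) for i in range(n_layer)]
    cells_bw = [make_cell(i == n_layer - 1) for i in range(n_layer)]

    # Unlike wrapping MultiRNNCell in bidirectional_dynamic_rnn, where each
    # direction only ever sees its own stack, stack_bidirectional_dynamic_rnn
    # feeds the concatenated fw/bw outputs of each layer into the next layer.
    outputs, states_fw, states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        cells_fw=cells_fw,
        cells_bw=cells_bw,
        inputs=inputs,
        dtype=tf.float32)
    # outputs has shape [batch, time, 2 * n_hidden]

One caveat: stack_bidirectional_dynamic_rnn expects initial_states_fw and
initial_states_bw to be lists with one state per layer (or None), so callers
who previously passed a single fw_initial_state/bw_initial_state must supply
per-layer lists when n_layer > 1.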