From b27f55f5b22cf93d2706feb585e3aa0cc88ecb74 Mon Sep 17 00:00:00 2001
From: Akaitsuki Brunestud
Date: Fri, 30 Dec 2016 16:48:38 +0800
Subject: [PATCH] [layers] update BiDynamicRNNLayer to be available,
 compatible with API of BiRNNLayer and DynamicRNNLayer

---
 tensorlayer/layers.py | 174 ++++++++++++++++++++++++++++++------------
 1 file changed, 124 insertions(+), 50 deletions(-)

diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py
index c0151a208..f070fcff0 100755
--- a/tensorlayer/layers.py
+++ b/tensorlayer/layers.py
@@ -2702,6 +2702,7 @@ def __init__(
         self.all_layers.extend( [self.outputs] )
         self.all_params.extend( rnn_variables )


+# Bidirectional Dynamic RNN
 class BiDynamicRNNLayer(Layer):
     """
@@ -2722,16 +2723,30 @@ class BiDynamicRNNLayer(Layer):
         The arguments for the cell initializer.
     n_hidden : a int
         The number of hidden units in the layer.
-    n_steps : a int
-        The sequence length.
+    initializer : initializer
+        The initializer for initializing the parameters.
+    sequence_length : a tensor, array or None
+        The sequence length of each row of input data, see ``Advanced Ops for Dynamic RNN``.
+        - If None, it uses ``retrieve_seq_length_op`` to compute the sequence_length, i.e. when the features of padding (on right hand side) are all zeros.
+        - If using word embedding, you may need to compute the sequence_length from the ID array (the integer features before word embedding) by using ``retrieve_seq_length_op2`` or ``retrieve_seq_length_op``.
+        - You can also input a numpy array.
+        - More details about TensorFlow dynamic_rnn in `Wild-ML Blog `_.
+    fw_initial_state : None or forward RNN State
+        If None, initial_state is zero_state.
+    bw_initial_state : None or backward RNN State
+        If None, initial_state is zero_state.
+    dropout : `tuple` of `float`: (input_keep_prob, output_keep_prob).
+        The input and output keep probability.
+    n_layer : a int, default is 1.
+        The number of RNN layers.
     return_last : boolean
         If True, return the last output, "Sequence input and single output"\n
         If False, return all outputs, "Synced sequence input and output"\n
         In other word, if you want to apply one or more RNN(s) on this layer, set to False.
     return_seq_2d : boolean
-        When return_last = False\n
-        if True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it.
-        if False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it.
+        - When return_last = False
+        - If True, return 2D Tensor [n_example, 2 * n_hidden], for stacking DenseLayer or computing cost after it.
+        - If False, return 3D Tensor [n_example/n_steps(max), n_steps(max), 2 * n_hidden], for stacking multiple RNN after it.
     name : a string or None
         An optional name to attach to this layer.

@@ -2740,20 +2755,23 @@ class BiDynamicRNNLayer(Layer):
     outputs : a tensor
         The output of this RNN.
         return_last = False, outputs = all cell_output, which is the hidden state.
-        cell_output.get_shape() = (?, n_hidden)
+        cell_output.get_shape() = (?, 2 * n_hidden)

-    final_state : a tensor or StateTuple
+    fw(bw)_final_state : a tensor or StateTuple
         When state_is_tuple = False,
         it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n
         When state_is_tuple = True, it stores two elements: (c, h), in that order.
         You can get the final state after each iteration during training, then
         feed it to the initial state of next iteration.

-    initial_state : a tensor or StateTuple
+    fw(bw)_initial_state : a tensor or StateTuple
         It is the initial state of this RNN layer, you can use it to initialize
         your state at the begining of each epoch or iteration according to your
         training procedure.
+    sequence_length : a tensor or array, shape = [batch_size]
+        The sequence lengths computed by Advanced Ops or the given sequence lengths.
+

     Notes
     -----
     Input dimension should be rank 3 : [batch_size, n_steps(max), n_features], if no, please see :class:`ReshapeLayer`.
@@ -2768,59 +2786,118 @@ def __init__(
         self,
         layer = None,
         cell_fn = tf.nn.rnn_cell.LSTMCell,
-        cell_init_args = {'state_is_tuple' : True},
-        n_hidden = 64,
+        cell_init_args = {},
+        n_hidden = 100,
         initializer = tf.random_uniform_initializer(-0.1, 0.1),
-        # n_steps = 5,
+        sequence_length = None,
+        fw_initial_state = None,
+        bw_initial_state = None,
+        dropout = None,
+        n_layer = 1,
         return_last = False,
-        # is_reshape = True,
         return_seq_2d = False,
-        name = 'birnn_layer',
+        name = 'bi_dyrnn_layer',
     ):
         Layer.__init__(self, name=name)
         self.inputs = layer.outputs

-        print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, n_steps:%d, in_dim:%d %s, cell_fn:%s " % (self.name, n_hidden,
-            n_steps, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__))
-        print(" Untested !!!")
+        print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d" %
+              (self.name, n_hidden, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__, dropout, n_layer))

-        self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
-        # self.initial_state = cell.zero_state(batch_size, dtype=tf.float32)
-        # state = self.initial_state
+        # Input dimension should be rank 3 [batch_size, n_steps(max), n_features]
+        try:
+            self.inputs.get_shape().with_rank(3)
+        except:
+            raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps(max), n_features]")
+
+        # Get the batch_size
+        fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0]
+        if fixed_batch_size.value:
+            batch_size = fixed_batch_size.value
+            print(" batch_size (concurrent processes): %d" % batch_size)
+        else:
+            from tensorflow.python.ops import array_ops
+            batch_size = array_ops.shape(self.inputs)[0]
+            print(" non specified batch_size, uses a tensor instead.")
+        self.batch_size = batch_size

         with tf.variable_scope(name, initializer=initializer) as vs:
-            outputs, states = tf.nn.bidirectional_dynamic_rnn(
-                cell_fw=cell,
-                cell_bw=cell,
-                dtype=tf.float64,
-                sequence_length=X_lengths,
-                inputs=X)
-
-            output_fw, output_bw = outputs
-            states_fw, states_bw = states
-
-            result = tf.contrib.learn.run_n(
-                {"output_fw": output_fw, "output_bw": output_bw, "states_fw": states_fw, "states_bw": states_bw},
-                n=1,
-                feed_dict=None)
-            rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            # Creates the cell function
+            self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
+            self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)

-            print(" n_params : %d" % (len(rnn_variables)))
+            # Apply dropout
+            if dropout:
+                if type(dropout) in [tuple, list]:
+                    in_keep_prob = dropout[0]
+                    out_keep_prob = dropout[1]
+                elif isinstance(dropout, float):
+                    in_keep_prob, out_keep_prob = dropout, dropout
+                else:
+                    raise Exception("Invalid dropout type (must be a 2-D tuple of "
+                                    "float)")
+                self.fw_cell = tf.nn.rnn_cell.DropoutWrapper(
+                    self.fw_cell,
+                    input_keep_prob=in_keep_prob,
+                    output_keep_prob=out_keep_prob)
+                self.bw_cell = tf.nn.rnn_cell.DropoutWrapper(
+                    self.bw_cell,
+                    input_keep_prob=in_keep_prob,
+                    output_keep_prob=out_keep_prob)
+            # Apply multiple layers
+            if n_layer > 1:
+                print(" n_layer: %d" % n_layer)
+                self.fw_cell = tf.nn.rnn_cell.MultiRNNCell([self.fw_cell] * n_layer)
+                self.bw_cell = tf.nn.rnn_cell.MultiRNNCell([self.bw_cell] * n_layer)
+            # Initial state of RNN
+            if fw_initial_state is None:
+                self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=tf.float32)
+            else:
+                self.fw_initial_state = fw_initial_state
+            if bw_initial_state is None:
+                self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=tf.float32)
+            else:
+                self.bw_initial_state = bw_initial_state
+            # Computes sequence_length
+            if sequence_length is None:
+                sequence_length = retrieve_seq_length_op(
+                    self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))

-            if return_last:
-                # 2D Tensor [batch_size, n_hidden]
-                self.outputs = output_fw
-            else:
-                if return_seq_2d:
-                    # PTB tutorial:
-                    # 2D Tensor [n_example, n_hidden]
-                    self.outputs = tf.reshape(tf.concat(1, output_fw), [-1, n_hidden])
+            outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
+                cell_fw=self.fw_cell,
+                cell_bw=self.bw_cell,
+                inputs=self.inputs,
+                sequence_length=sequence_length,
+                initial_state_fw=self.fw_initial_state,
+                initial_state_bw=self.bw_initial_state,
+            )
+            rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+
+            print(" n_params : %d" % (len(rnn_variables)))
+            # Manage the outputs
+            outputs = tf.concat(-1, outputs)
+            if return_last:
+                # [batch_size, 2 * n_hidden]
+                self.outputs = advanced_indexing_op(outputs, sequence_length)
             else:
-                    # :
-                    # 3D Tensor [n_example/n_steps, n_steps, n_hidden]
-                    self.outputs = tf.reshape(tf.concat(1, output_fw), [-1, n_steps, n_hidden])
+                # [batch_size, n_step(max), 2 * n_hidden]
+                if return_seq_2d:
+                    # PTB tutorial:
+                    # 2D Tensor [n_example, 2 * n_hidden]
+                    self.outputs = tf.reshape(tf.concat(1, outputs), [-1, 2 * n_hidden])
+                else:
+                    # :
+                    # 3D Tensor [batch_size, n_steps(max), 2 * n_hidden]
+                    max_length = tf.shape(outputs)[1]
+                    batch_size = tf.shape(outputs)[0]
+                    self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, 2 * n_hidden])
+                    # self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, 2 * n_hidden])

-            self.final_state = state
+        # Final state
+        self.fw_final_states = states_fw
+        self.bw_final_states = states_bw
+
+        self.sequence_length = sequence_length

         self.all_layers = list(layer.all_layers)
         self.all_params = list(layer.all_params)
@@ -2830,9 +2907,6 @@ def __init__(

         self.all_params.extend( rnn_variables )

-
-
-
 ## Shape layer
 class FlattenLayer(Layer):
     """
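
Below is a minimal usage sketch of the updated layer (not part of the patch). It assumes a
word-ID input fed through an EmbeddingInputlayer; the vocabulary/embedding sizes and the
trailing DenseLayer are placeholders chosen for illustration only, while the dropout tuple,
sequence_length helper, and output shapes follow the docstring above.

    # Usage sketch (illustrative only): BiDynamicRNNLayer on top of word embeddings.
    # `vocab_size` and `embed_size` are made-up values for this example.
    import tensorflow as tf
    import tensorlayer as tl

    vocab_size, embed_size = 10000, 200
    input_seqs = tf.placeholder(dtype=tf.int64, shape=[None, None], name="input_seqs")

    net = tl.layers.EmbeddingInputlayer(
        inputs=input_seqs,
        vocabulary_size=vocab_size,
        embedding_size=embed_size,
        name='embedding_layer')
    net = tl.layers.BiDynamicRNNLayer(
        layer=net,
        cell_fn=tf.nn.rnn_cell.LSTMCell,
        n_hidden=100,
        dropout=(0.7, 0.7),              # (input_keep_prob, output_keep_prob)
        sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
        return_last=False,
        return_seq_2d=True,              # 2D output [n_example, 2 * n_hidden] for a DenseLayer
        name='bi_dyrnn_layer')
    net = tl.layers.DenseLayer(net, n_units=vocab_size, act=tf.identity, name='output_layer')

    # If stateful training is needed, net.fw_final_states / net.bw_final_states can be
    # fed back into fw_initial_state / bw_initial_state on the next iteration,
    # as described in the docstring.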