174 changes: 124 additions & 50 deletions tensorlayer/layers.py
@@ -2702,6 +2702,7 @@ def __init__(
self.all_layers.extend( [self.outputs] )
self.all_params.extend( rnn_variables )


# Bidirectional Dynamic RNN
class BiDynamicRNNLayer(Layer):
"""
@@ -2722,16 +2723,30 @@ class BiDynamicRNNLayer(Layer):
The arguments for the cell initializer.
n_hidden : an int
The number of hidden units in the layer.
initializer : initializer
The initializer for initializing the parameters.
sequence_length : a tensor, array or None
The sequence length of each row of input data, see ``Advanced Ops for Dynamic RNN``.
- If None, it uses ``retrieve_seq_length_op`` to compute the sequence_length, i.e. when the features of padding (on the right-hand side) are all zeros.
- If using word embedding, you may need to compute the sequence_length from the ID array (the integer features before word embedding) by using ``retrieve_seq_length_op2`` or ``retrieve_seq_length_op``.
- You can also input a numpy array (a usage sketch follows this parameter list).
- More details about TensorFlow dynamic_rnn in `Wild-ML Blog <http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/>`_.
fw_initial_state : None or forward RNN State
If None, initial_state is zero_state.
bw_initial_state : None or backward RNN State
If None, initial_state is zero_state.
dropout : `tuple` of `float`: (input_keep_prob, output_keep_prob).
The input and output keep probability.
n_layer : an int, default is 1.
The number of RNN layers.
return_last : boolean
If True, return the last output, "Sequence input and single output"\n
If False, return all outputs, "Synced sequence input and output"\n
In other words, if you want to apply one or more RNN(s) on this layer, set to False.
return_seq_2d : boolean
- When return_last = False
- If True, return 2D Tensor [n_example, 2 * n_hidden], for stacking DenseLayer or computing cost after it.
- If False, return 3D Tensor [n_example/n_steps(max), n_steps(max), 2 * n_hidden], for stacking multiple RNN after it.
name : a string or None
An optional name to attach to this layer.
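
A minimal construction sketch (illustrative only, not part of this diff; ``x``, the vocabulary and embedding sizes, and the upstream EmbeddingInputlayer are assumptions, using the TF 0.x-era API this file targets):

>>> import tensorflow as tf
>>> import tensorlayer as tl
>>> # x holds zero-padded ID sequences, shape [batch_size, n_steps(max)]
>>> x = tf.placeholder(tf.int32, [None, None])
>>> net = tl.layers.EmbeddingInputlayer(inputs=x, vocabulary_size=10000,
...                 embedding_size=128, name='embedding')
>>> net = tl.layers.BiDynamicRNNLayer(net,
...                 cell_fn=tf.nn.rnn_cell.LSTMCell,
...                 n_hidden=100,
...                 dropout=(0.7, 0.7),
...                 sequence_length=tl.layers.retrieve_seq_length_op2(x),
...                 return_seq_2d=True,    # 2D output: [n_example, 2 * n_hidden]
...                 name='bi_dyrnn_layer')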

@@ -2740,20 +2755,23 @@ class BiDynamicRNNLayer(Layer):
outputs : a tensor
The output of this RNN.
When return_last = False, outputs = all cell_output, which are the hidden states.
cell_output.get_shape() = (?, 2 * n_hidden)

fw(bw)_final_state : a tensor or StateTuple
When state_is_tuple = False,
it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n
When state_is_tuple = True, it stores two elements: (c, h), in that order.
You can get the final state after each iteration during training, then
feed it to the initial state of the next iteration.

fw(bw)_initial_state : a tensor or StateTuple
It is the initial state of this RNN layer, you can use it to initialize
your state at the beginning of each epoch or iteration according to your
training procedure.

sequence_length : a tensor or array, shape = [batch_size]
The sequence lengths computed by Advanced Ops or the given sequence lengths.
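
A sketch of carrying state across iterations (an assumption for illustration; ``sess``, ``train_op``, ``x`` and ``batches`` are not defined in this diff, and state_is_tuple = False is assumed so the state is a single feedable tensor):

>>> state_fw = sess.run(net.fw_initial_state)
>>> for batch_x in batches:
...     _, state_fw = sess.run([train_op, net.fw_final_states],
...                 feed_dict={x: batch_x, net.fw_initial_state: state_fw})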

Notes
-----
Input dimension should be rank 3 : [batch_size, n_steps(max), n_features]; if not, please see :class:`ReshapeLayer`.
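
For example (a hedged sketch, not from the original notes; ``n_steps`` and ``n_features`` are assumed to be known), a rank-2 tensor of shape [batch_size * n_steps, n_features] can be reshaped to rank 3 with:

>>> net = tl.layers.ReshapeLayer(net, shape=[-1, n_steps, n_features], name='reshape')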
@@ -2768,59 +2786,118 @@ def __init__(
self,
layer = None,
cell_fn = tf.nn.rnn_cell.LSTMCell,
cell_init_args = {},
n_hidden = 100,
initializer = tf.random_uniform_initializer(-0.1, 0.1),
# n_steps = 5,
sequence_length = None,
fw_initial_state = None,
bw_initial_state = None,
dropout = None,
n_layer = 1,
return_last = False,
# is_reshape = True,
return_seq_2d = False,
name = 'bi_dyrnn_layer',
):
Layer.__init__(self, name=name)
self.inputs = layer.outputs

print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, n_steps:%d, in_dim:%d %s, cell_fn:%s " % (self.name, n_hidden,
n_steps, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__))
print(" Untested !!!")
print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d" %
(self.name, n_hidden, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__, dropout, n_layer))

# Input dimension should be rank 3 [batch_size, n_steps(max), n_features]
try:
self.inputs.get_shape().with_rank(3)
except:
raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps(max), n_features]")

# Get the batch_size
fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0]
if fixed_batch_size.value:
batch_size = fixed_batch_size.value
print(" batch_size (concurrent processes): %d" % batch_size)
else:
from tensorflow.python.ops import array_ops
batch_size = array_ops.shape(self.inputs)[0]
print(" non specified batch_size, uses a tensor instead.")
self.batch_size = batch_size

with tf.variable_scope(name, initializer=initializer) as vs:
# Creates the cell functions
self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)

print(" n_params : %d" % (len(rnn_variables)))
# Apply dropout
if dropout:
if type(dropout) in [tuple, list]:
in_keep_prob = dropout[0]
out_keep_prob = dropout[1]
elif isinstance(dropout, float):
in_keep_prob, out_keep_prob = dropout, dropout
else:
raise Exception("Invalid dropout type (must be a 2-D tuple of "
"float)")
self.fw_cell = tf.nn.rnn_cell.DropoutWrapper(
self.fw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
self.bw_cell = tf.nn.rnn_cell.DropoutWrapper(
self.bw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
# Apply multiple layers
if n_layer > 1:
print(" n_layer: %d" % n_layer)
self.fw_cell = tf.nn.rnn_cell.MultiRNNCell([self.fw_cell] * n_layer)
self.bw_cell = tf.nn.rnn_cell.MultiRNNCell([self.bw_cell] * n_layer)
# Initial state of RNN
if fw_initial_state is None:
self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=tf.float32)
else:
self.fw_initial_state = fw_initial_state
if bw_initial_state is None:
self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=tf.float32)
else:
self.bw_initial_state = bw_initial_state
# Computes sequence_length
if sequence_length is None:
sequence_length = retrieve_seq_length_op(
self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))
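# Hedged aside (not in this diff): retrieve_seq_length_op is assumed to count,
# for each example, the time steps whose feature vector is non-zero, roughly:
#   used = tf.sign(tf.reduce_max(tf.abs(inputs), reduction_indices=2))
#   lengths = tf.cast(tf.reduce_sum(used, reduction_indices=1), tf.int32)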

outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
cell_fw=self.fw_cell,
cell_bw=self.bw_cell,
inputs=self.inputs,
sequence_length=sequence_length,
initial_state_fw=self.fw_initial_state,
initial_state_bw=self.bw_initial_state,
)
rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

print(" n_params : %d" % (len(rnn_variables)))
# Manage the outputs
outputs = tf.concat(-1, outputs)
if return_last:
# [batch_size, 2 * n_hidden]
self.outputs = advanced_indexing_op(outputs, sequence_length)
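# Hedged aside (not in this diff): advanced_indexing_op is assumed to gather,
# for each example, the output at its last valid time step, roughly:
#   index = tf.range(0, batch_size) * max_length + (sequence_length - 1)
#   relevant = tf.gather(tf.reshape(outputs, [-1, 2 * n_hidden]), index)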
else:
# [batch_size, n_step(max), 2 * n_hidden]
if return_seq_2d:
# PTB tutorial:
# 2D Tensor [n_example, 2 * n_hidden]
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, 2 * n_hidden])
else:
# <akara>:
# 3D Tensor [batch_size, n_steps(max), 2 * n_hidden]
max_length = tf.shape(outputs)[1]
batch_size = tf.shape(outputs)[0]
self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, 2 * n_hidden])
# self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, 2 * n_hidden])

# Final state
self.fw_final_states = states_fw
self.bw_final_states = states_bw

self.sequence_length = sequence_length

self.all_layers = list(layer.all_layers)
self.all_params = list(layer.all_params)
@@ -2830,9 +2907,6 @@ def __init__(
self.all_params.extend( rnn_variables )
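
A follow-on sketch (an illustrative assumption, not part of this diff; ``n_classes`` and ``targets`` are placeholders): with return_seq_2d = True the 2-D output feeds a DenseLayer and a cost directly, e.g.

>>> net = tl.layers.DenseLayer(net, n_units=n_classes, act=tf.identity, name='output')
>>> cost = tl.cost.cross_entropy(net.outputs, targets)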


## Shape layer
class FlattenLayer(Layer):
"""