From dcef47168e1d630fe02634f6a542df0c3571228d Mon Sep 17 00:00:00 2001 From: CTTC Date: Tue, 16 May 2017 11:30:06 +0800 Subject: [PATCH] fix minor errors when adjusting codes for RNN --- tensorlayer/layers.py | 81 +++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py index a29d0ed1a..a89296273 100755 --- a/tensorlayer/layers.py +++ b/tensorlayer/layers.py @@ -15,7 +15,7 @@ from six.moves import xrange import random, warnings import copy - +import inspect # __all__ = [ # "Layer", # "DenseLayer", @@ -3397,7 +3397,10 @@ def __init__( # for input_ in tf.split(1, num_steps, inputs)] # outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state) outputs = [] - self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args) + if 'reuse' in inspect.getargspec(cell_fn.__init__).args: + self.cell = cell = cell_fn(num_units=n_hidden, reuse=tf.get_variable_scope().reuse, **cell_init_args) + else: + self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args) if initial_state is None: self.initial_state = cell.zero_state(batch_size, dtype=tf.float32) # 1.2.3 state = self.initial_state @@ -3560,8 +3563,7 @@ def __init__( raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]") with tf.variable_scope(name, initializer=initializer) as vs: - self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args) - self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args) + rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args) # Apply dropout if dropout: if type(dropout) in [tuple, list]: @@ -3576,14 +3578,14 @@ def __init__( DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper except: DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper - self.fw_cell = DropoutWrapper_fn( - self.fw_cell, - input_keep_prob=in_keep_prob, - output_keep_prob=out_keep_prob) - self.bw_cell = DropoutWrapper_fn( - self.bw_cell, - input_keep_prob=in_keep_prob, - output_keep_prob=out_keep_prob) + cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), + input_keep_prob=in_keep_prob, + output_keep_prob=1.0) # out_keep_prob) + else: + cell_creator = rnn_creator + self.fw_cell = cell_creator() + self.bw_cell = cell_creator() + # Apply multiple layers if n_layer > 1: try: # TF1.0 @@ -3592,13 +3594,11 @@ def __init__( MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell try: - self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer, - state_is_tuple=True) - self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer, - state_is_tuple=True) + self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) + self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) except: - self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer) - self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer) + self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) + self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) # Initial state of RNN if fw_initial_state is None: @@ -3938,7 +3938,7 @@ def __init__( # Creats the cell function # cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng - self.cell = cell_fn(num_units=n_hidden, **cell_init_args) + rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args) # Apply dropout if dropout: @@ -3960,9 +3960,11 @@ def __init__( # cell_instance_fn1(), # input_keep_prob=in_keep_prob, # output_keep_prob=out_keep_prob) - self.cell = DropoutWrapper_fn(self.cell, + cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0)#out_keep_prob) - + else: + cell_creator = rnn_creator + self.cell = cell_creator() # Apply multiple layers if n_layer > 1: try: @@ -3973,10 +3975,10 @@ def __init__( # cell_instance_fn2=cell_instance_fn # HanSheng try: # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng - self.cell = MultiRNNCell_fn([self.cell] * n_layer, state_is_tuple=True) + self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True) except: # when GRU # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng - self.cell = MultiRNNCell_fn([self.cell] * n_layer) + self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) if dropout: self.cell = DropoutWrapper_fn(self.cell, @@ -4179,8 +4181,7 @@ def __init__( with tf.variable_scope(name, initializer=initializer) as vs: # Creats the cell function # cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng - self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args) - self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args) + rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args) # Apply dropout if dropout: @@ -4202,15 +4203,13 @@ def __init__( # cell_instance_fn1(), # input_keep_prob=in_keep_prob, # output_keep_prob=out_keep_prob) - - self.fw_cell = DropoutWrapper_fn( - self.fw_cell, - input_keep_prob=in_keep_prob, - output_keep_prob=out_keep_prob) - self.bw_cell = DropoutWrapper_fn( - self.bw_cell, - input_keep_prob=in_keep_prob, - output_keep_prob=out_keep_prob) + cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), + input_keep_prob=in_keep_prob, + output_keep_prob=1.0) # out_keep_prob) + else: + cell_creator = rnn_creator + self.fw_cell = cell_creator() + self.bw_cell = cell_creator() # Apply multiple layers if n_layer > 1: try: @@ -4220,8 +4219,8 @@ def __init__( # cell_instance_fn2=cell_instance_fn # HanSheng # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) - self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer) - self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer) + self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) + self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)]) # self.fw_cell=cell_instance_fn() # self.bw_cell=cell_instance_fn() # Initial state of RNN @@ -5256,17 +5255,17 @@ def sampled_loss(inputs, labels): # ============ Seq Encode Layer ============= # Create the internal multi-layer cell for our RNN. try: # TF1.0 - single_cell = tf.contrib.rnn.GRUCell(size) + cell_creator = lambda: tf.contrib.rnn.GRUCell(size) except: - single_cell = tf.nn.rnn_cell.GRUCell(size) + cell_creator = lambda: tf.nn.rnn_cell.GRUCell(size) if use_lstm: try: # TF1.0 - single_cell = tf.contrib.rnn.BasicLSTMCell(size) + cell_creator = lambda: tf.contrib.rnn.BasicLSTMCell(size) except: - single_cell = tf.nn.rnn_cell.BasicLSTMCell(size) + cell_creator = lambda: tf.nn.rnn_cell.BasicLSTMCell(size) - cell = single_cell + cell = cell_creator() if num_layers > 1: try: # TF1.0 cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)