From fbc593eb02cc939a7d4a14cd4ede550f38235c3b Mon Sep 17 00:00:00 2001
From: Patrick
Date: Wed, 22 Feb 2017 00:28:48 +0800
Subject: [PATCH] use tf.contrib.rnn.* for TF1.0

---
 tensorlayer/layers.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py
index 3e29ac41a..164e6bb78 100644
--- a/tensorlayer/layers.py
+++ b/tensorlayer/layers.py
@@ -4678,12 +4678,23 @@ def sampled_loss(inputs, labels):
 
     # ============ Seq Encode Layer =============
     # Create the internal multi-layer cell for our RNN.
-    single_cell = tf.nn.rnn_cell.GRUCell(size)
+    try: # TF1.0
+      single_cell = tf.contrib.rnn.GRUCell(size)
+    except:
+      single_cell = tf.nn.rnn_cell.GRUCell(size)
+
     if use_lstm:
-      single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
+      try: # TF1.0
+        single_cell = tf.contrib.rnn.BasicLSTMCell(size)
+      except:
+        single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
+
     cell = single_cell
     if num_layers > 1:
-      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
+      try: # TF1.0
+        cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
+      except:
+        cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
 
     # ============== Seq Decode Layer ============
     # The seq2seq function: we use embedding for the input and attention.
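
Note (not part of the patch): the change above applies the same try/except fallback in three places. A minimal standalone sketch of that compatibility pattern is shown below, assuming either TF 1.0 (tf.contrib.rnn.*) or an earlier release (tf.nn.rnn_cell.*) is installed; the helper name create_rnn_cell is hypothetical and only for illustration, it does not exist in tensorlayer.

    import tensorflow as tf

    def create_rnn_cell(size, num_layers=1, use_lstm=False):
        # Hypothetical helper: resolve the cell classes from the TF 1.0 API
        # (tf.contrib.rnn.*) and fall back to the pre-1.0 API
        # (tf.nn.rnn_cell.*) when the new names are unavailable.
        try:  # TF1.0
            GRUCell = tf.contrib.rnn.GRUCell
            BasicLSTMCell = tf.contrib.rnn.BasicLSTMCell
            MultiRNNCell = tf.contrib.rnn.MultiRNNCell
        except AttributeError:  # pre-TF1.0
            GRUCell = tf.nn.rnn_cell.GRUCell
            BasicLSTMCell = tf.nn.rnn_cell.BasicLSTMCell
            MultiRNNCell = tf.nn.rnn_cell.MultiRNNCell

        # Same structure as the patched code: GRU by default, LSTM on request,
        # optionally stacked into a multi-layer cell.
        single_cell = BasicLSTMCell(size) if use_lstm else GRUCell(size)
        if num_layers > 1:
            return MultiRNNCell([single_cell] * num_layers)
        return single_cell

    # Example usage: a 2-layer GRU stack of width 128.
    cell = create_rnn_cell(128, num_layers=2)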