Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 40 additions & 41 deletions tensorlayer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from six.moves import xrange
import random, warnings
import copy

import inspect
# __all__ = [
# "Layer",
# "DenseLayer",
Expand Down Expand Up @@ -3397,7 +3397,10 @@ def __init__(
# for input_ in tf.split(1, num_steps, inputs)]
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
outputs = []
self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
if 'reuse' in inspect.getargspec(cell_fn.__init__).args:
self.cell = cell = cell_fn(num_units=n_hidden, reuse=tf.get_variable_scope().reuse, **cell_init_args)
else:
self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
if initial_state is None:
self.initial_state = cell.zero_state(batch_size, dtype=tf.float32) # 1.2.3
state = self.initial_state
Expand Down Expand Up @@ -3560,8 +3563,7 @@ def __init__(
raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]")

with tf.variable_scope(name, initializer=initializer) as vs:
self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)
# Apply dropout
if dropout:
if type(dropout) in [tuple, list]:
Expand All @@ -3576,14 +3578,14 @@ def __init__(
DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper
except:
DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper
self.fw_cell = DropoutWrapper_fn(
self.fw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
self.bw_cell = DropoutWrapper_fn(
self.bw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
input_keep_prob=in_keep_prob,
output_keep_prob=1.0) # out_keep_prob)
else:
cell_creator = rnn_creator
self.fw_cell = cell_creator()
self.bw_cell = cell_creator()

# Apply multiple layers
if n_layer > 1:
try: # TF1.0
Expand All @@ -3592,13 +3594,11 @@ def __init__(
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell

try:
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer,
state_is_tuple=True)
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer,
state_is_tuple=True)
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
except:
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer)
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer)
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])

# Initial state of RNN
if fw_initial_state is None:
Expand Down Expand Up @@ -3938,7 +3938,7 @@ def __init__(

# Creates the cell function
# cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
self.cell = cell_fn(num_units=n_hidden, **cell_init_args)
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)

# Apply dropout
if dropout:
Expand All @@ -3960,9 +3960,11 @@ def __init__(
# cell_instance_fn1(),
# input_keep_prob=in_keep_prob,
# output_keep_prob=out_keep_prob)
self.cell = DropoutWrapper_fn(self.cell,
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
input_keep_prob=in_keep_prob, output_keep_prob=1.0)#out_keep_prob)

else:
cell_creator = rnn_creator
self.cell = cell_creator()
# Apply multiple layers
if n_layer > 1:
try:
Expand All @@ -3973,10 +3975,10 @@ def __init__(
# cell_instance_fn2=cell_instance_fn # HanSheng
try:
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
self.cell = MultiRNNCell_fn([self.cell] * n_layer, state_is_tuple=True)
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
except: # when GRU
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
self.cell = MultiRNNCell_fn([self.cell] * n_layer)
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])

if dropout:
self.cell = DropoutWrapper_fn(self.cell,
Expand Down Expand Up @@ -4179,8 +4181,7 @@ def __init__(
with tf.variable_scope(name, initializer=initializer) as vs:
# Creates the cell function
# cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)

# Apply dropout
if dropout:
Expand All @@ -4202,15 +4203,13 @@ def __init__(
# cell_instance_fn1(),
# input_keep_prob=in_keep_prob,
# output_keep_prob=out_keep_prob)

self.fw_cell = DropoutWrapper_fn(
self.fw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
self.bw_cell = DropoutWrapper_fn(
self.bw_cell,
input_keep_prob=in_keep_prob,
output_keep_prob=out_keep_prob)
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
input_keep_prob=in_keep_prob,
output_keep_prob=1.0) # out_keep_prob)
else:
cell_creator = rnn_creator
self.fw_cell = cell_creator()
self.bw_cell = cell_creator()
# Apply multiple layers
if n_layer > 1:
try:
Expand All @@ -4220,8 +4219,8 @@ def __init__(

# cell_instance_fn2=cell_instance_fn # HanSheng
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer)
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer)
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
# self.fw_cell=cell_instance_fn()
# self.bw_cell=cell_instance_fn()
# Initial state of RNN
Expand Down Expand Up @@ -5256,17 +5255,17 @@ def sampled_loss(inputs, labels):
# ============ Seq Encode Layer =============
# Create the internal multi-layer cell for our RNN.
try: # TF1.0
single_cell = tf.contrib.rnn.GRUCell(size)
cell_creator = lambda: tf.contrib.rnn.GRUCell(size)
except:
single_cell = tf.nn.rnn_cell.GRUCell(size)
cell_creator = lambda: tf.nn.rnn_cell.GRUCell(size)

if use_lstm:
try: # TF1.0
single_cell = tf.contrib.rnn.BasicLSTMCell(size)
cell_creator = lambda: tf.contrib.rnn.BasicLSTMCell(size)
except:
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
cell_creator = lambda: tf.nn.rnn_cell.BasicLSTMCell(size)

cell = single_cell
cell = cell_creator()
if num_layers > 1:
try: # TF1.0
cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
Expand Down