1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -79,6 +79,7 @@ To release a new version, please update the changelog as followed:
### Deprecated

### Fixed
- RNN updates: remove warnings, fix the seq_len=0 case, add unit test (PR #1033)

### Removed

24 changes: 15 additions & 9 deletions tensorlayer/layers/recurrent.py
@@ -105,7 +105,10 @@ class RNN(Layer):
Similar to the DynamicRNN in TL 1.x.

If the `sequence_length` is provided in RNN's forwarding and both `return_last_output` and `return_last_state`
are set as `True`, the forward function will automatically ignore the paddings.
are set as `True`, the forward function will automatically ignore the paddings. Note that if `return_last_output`
is set as `False`, the synced sequence outputs will still include the outputs that correspond to the padded steps,
but users are free to select which slice of the outputs to use in subsequent processing.

The `sequence_length` should be a list of integers which indicates the length of each sequence.
It is recommended to use
`tl.layers.retrieve_seq_length_op3 <https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#compute-sequence-length-3>`__
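As a side note for review, here is a minimal sketch of the behaviour described above: with `return_last_output=False` the synced outputs still contain the padded steps, and the caller picks the slice it needs. The tensor names, shapes, and the `tf.gather_nd` selection below are illustrative assumptions, not code from this PR.

```python
import tensorflow as tf

# Hedged sketch only. `outputs` stands in for the synced sequence outputs a
# tl.layers.RNN-style layer returns when return_last_output=False; padded
# time steps are still present in it.
batch, max_steps, units = 4, 6, 5
outputs = tf.random.uniform([batch, max_steps, units])
seq_len = tf.constant([6, 3, 4, 1])       # real length of each sequence

# Select the last valid time step of every sequence by hand.
last_step = tf.maximum(seq_len - 1, 0)    # index of the last real step per sequence
batch_idx = tf.range(batch)
last_valid = tf.gather_nd(outputs, tf.stack([batch_idx, last_step], axis=1))
print(last_valid.shape)                   # (4, 5)
```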
@@ -244,16 +247,15 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
"but got an actual length of a sequence %d" % i
)

sequence_length = [i - 1 for i in sequence_length]
sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
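The one-line change above is the seq_len=0 fix from the changelog: each length is shifted to the index of its last valid time step, and empty sequences are now clamped to 0 instead of producing -1. A tiny illustration with made-up lengths:

```python
sequence_length = [4, 1, 0]  # made-up lengths, including one empty sequence

old = [i - 1 for i in sequence_length]                   # [3, 0, -1] -> -1 would index the final (padded) step
new = [i - 1 if i >= 1 else 0 for i in sequence_length]  # [3, 0, 0]  -> empty sequences clamp to step 0
```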

# set warning
if (not self.return_last_state or not self.return_last_output) and sequence_length is not None:
warnings.warn(
'return_last_output is set as %s ' % self.return_last_output +
'and return_last_state is set as %s. ' % self.return_last_state +
'When sequence_length is provided, both are recommended to set as True. ' +
'Otherwise, padding will be considered while RNN is forwarding.'
)
# if (not self.return_last_output) and sequence_length is not None:
# warnings.warn(
# 'return_last_output is set as %s ' % self.return_last_output +
# 'When sequence_length is provided, it is recommended to set as True. ' +
# 'Otherwise, padding will be considered while RNN is forwarding.'
# )

# return the last output, iterating each seq including padding ones. No need to store output during each
# time step.
@@ -274,6 +276,7 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
self.cell.reset_recurrent_dropout_mask()

# recurrent computation
# FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times.
for time_step in range(total_steps):

cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
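The FIXME above flags a possible optimisation rather than something this PR implements: when `sequence_length` is given, the loop only needs to run up to the longest real sequence. A rough sketch of that idea, reusing names from this hunk (`total_steps`, `inputs`, `states`, `self.cell`) and assuming `sequence_length` has already been shifted to last-valid indices as shown earlier:

```python
# Sketch of the FIXME idea only -- not what the PR implements.
if sequence_length is not None:
    # lengths were converted to last-valid indices above, so the longest
    # sequence needs max(sequence_length) + 1 iterations
    total_steps = int(max(sequence_length)) + 1
    # (for synced sequence outputs, the skipped tail steps would then need
    # to be filled with zeros to keep the output shape unchanged)

for time_step in range(total_steps):
    cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
```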
@@ -758,6 +761,7 @@ def forward(self, inputs, fw_initial_state=None, bw_initial_state=None, **kwargs
return outputs


'''
class ConvRNNCell(object):
"""Abstract object representing an Convolutional RNN Cell."""

@@ -1071,6 +1075,8 @@ def __init__(
self._add_layers(self.outputs)
self._add_params(rnn_variables)

'''


# @tf.function
def retrieve_seq_length_op(data):
56 changes: 56 additions & 0 deletions tests/layers/test_layers_recurrent.py
@@ -26,7 +26,13 @@ def setUpClass(cls):
cls.hidden_size = 8
cls.num_steps = 6

cls.data_n_steps = np.random.randint(low=cls.num_steps // 2, high=cls.num_steps + 1, size=cls.batch_size)
cls.data_x = np.random.random([cls.batch_size, cls.num_steps, cls.embedding_size]).astype(np.float32)

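# zero out every time step beyond a sequence's sampled length, so the real lengths can be recovered from the data (e.g. by retrieve_seq_length_op3)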
for i in range(cls.batch_size):
    for j in range(cls.data_n_steps[i], cls.num_steps):
        cls.data_x[i][j][:] = 0

cls.data_y = np.zeros([cls.batch_size, 1]).astype(np.float32)
cls.data_y2 = np.zeros([cls.batch_size, cls.num_steps]).astype(np.float32)

@@ -865,6 +871,56 @@ def forward(self, x):
print(output.shape)
print(state)

def test_dynamic_rnn_with_fake_data(self):

    class CustomisedModel(tl.models.Model):

        def __init__(self):
            super(CustomisedModel, self).__init__()
            self.rnnlayer = tl.layers.LSTMRNN(
                units=8, dropout=0.1, in_channels=4, return_last_output=True, return_last_state=False
            )
            self.dense = tl.layers.Dense(in_channels=8, n_units=1)

        def forward(self, x):
            z = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
            z = self.dense(z[:, :])
            return z

    rnn_model = CustomisedModel()
    print(rnn_model)
    optimizer = tf.optimizers.Adam(learning_rate=0.01)
    rnn_model.train()

    for epoch in range(50):
        with tf.GradientTape() as tape:
            pred_y = rnn_model(self.data_x)
            loss = tl.cost.mean_squared_error(pred_y, self.data_y)

        gradients = tape.gradient(loss, rnn_model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

        if (epoch + 1) % 10 == 0:
            print("epoch %d, loss %f" % (epoch, loss))

    filename = "dynamic_rnn.h5"
    rnn_model.save_weights(filename)

    # Testing saving and restoring of RNN weights
    rnn_model2 = CustomisedModel()
    rnn_model2.eval()
    pred_y = rnn_model2(self.data_x)
    loss = tl.cost.mean_squared_error(pred_y, self.data_y)
    print("MODEL INIT loss %f" % (loss))

    rnn_model2.load_weights(filename)
    pred_y = rnn_model2(self.data_x)
    loss = tl.cost.mean_squared_error(pred_y, self.data_y)
    print("MODEL RESTORE W loss %f" % (loss))

    import os
    os.remove(filename)


if __name__ == '__main__':
