Dynamic RNN without padding (#1023)
* without_padding

* return output

* check actual_length type and values

* set warning and unittest

* unittest of dynamic rnn

* yapf

* check the length of initial_length

* check the length of initial_length

* minor update to unittest and yapf

* update change log

* fix state return bug and unittest passed

* add documentation

* yapf format
ArnoldLIULJ authored and zsdonghao committed Jul 19, 2019
1 parent 6957a25 commit 95c4ca2
Showing 3 changed files with 212 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -72,6 +72,7 @@ To release a new version, please update the changelog as followed:
### Added
- Support nested layer customization (#PR 1015)
- Support string dtype in InputLayer (#PR 1017)
- Support Dynamic RNN in RNN (#PR 1023)

### Changed

@@ -97,6 +98,8 @@ To release a new version, please update the changelog as followed:
- @zsdonghao
- @ChrisWu1997: #1010 #1015
- @warshallrho: #1017 #1021
- @ArnoldLIULJ: #1023
- @JingqingZ: #1023

## [2.1.0]

120 changes: 113 additions & 7 deletions tensorlayer/layers/recurrent.py
@@ -1,11 +1,13 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.decorators import deprecated_alias
from tensorlayer.layers.core import Layer
import warnings

# TODO: uncomment
__all__ = [
@@ -99,6 +101,33 @@ class RNN(Layer):
    >>> outputs = tl.layers.Dense(n_units=1)(rnn_out2)
    >>> rnn_model = tl.models.Model(inputs=inputs, outputs=outputs)

    An example where the sequences have different lengths and contain padding,
    similar to the DynamicRNN in TL 1.x.
    If `sequence_length` is provided in the RNN's forwarding and both `return_last_output` and `return_last_state`
    are set as `True`, the forward function will automatically ignore the paddings.
    The `sequence_length` should be a list of integers that indicates the actual length of each sequence.
    It is recommended to use
    `tl.layers.retrieve_seq_length_op3 <https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#compute-sequence-length-3>`__
    to calculate the `sequence_length`.
    >>> data = [[[1], [2], [0], [0], [0]], [[1], [2], [3], [0], [0]], [[1], [2], [6], [1], [1]]]
    >>> data = tf.convert_to_tensor(data, dtype=tf.float32)
    >>> class DynamicRNNExample(tl.models.Model):
    >>>     def __init__(self):
    >>>         super(DynamicRNNExample, self).__init__()
    >>>         self.rnnlayer = tl.layers.RNN(
    >>>             cell=tf.keras.layers.SimpleRNNCell(units=6, dropout=0.1), in_channels=1, return_last_output=True,
    >>>             return_last_state=True
    >>>         )
    >>>     def forward(self, x):
    >>>         z, s = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
    >>>         return z, s
    >>> model = DynamicRNNExample()
    >>> model.eval()
    >>> output, state = model(data)
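    >>> # with the example data above, `output` should have shape (3, 6) -- the
    >>> # last valid output of each of the 3 sequences -- and `state` should be
    >>> # a list holding one (3, 6) tensor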

    Notes
    -----
    Input dimension should be rank 3: [batch_size, n_steps, n_features]; if not, please see layer :class:`Reshape`.
@@ -157,49 +186,116 @@ def build(self, inputs_shape):
            self._trainable_weights.append(var)

    # @tf.function
    def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
"""
Parameters
----------
inputs : input tensor
The input of a network
sequence_length: None or list of integers
The actual length of each sequence in batch without padding.
If provided, when `return_last_output` and `return_last_state` are `True`,
the RNN will perform in the manner of a dynamic RNN, i.e.
the RNN will return the actual last output / state without padding.
initial_state : None or list of Tensor (RNN State)
If None, `initial_state` is zero state.
**kwargs: dict
Some attributes can be updated during forwarding
such as `return_last_output`, `return_seq_2d`, `return_last_state`.
"""

        if kwargs:
            for attr in kwargs:
                if attr in self.__dict__:
                    setattr(self, attr, kwargs[attr])

        batch_size = inputs.get_shape().as_list()[0]
        total_steps = inputs.get_shape().as_list()[1]

        # checking the type and values of sequence_length
        if sequence_length is not None:
            if isinstance(sequence_length, list):
                pass
            elif isinstance(sequence_length, tf.Tensor):
                pass
            elif isinstance(sequence_length, np.ndarray):
                sequence_length = sequence_length.tolist()
            else:
                raise TypeError(
                    "The argument sequence_length should be either None or a list of integers. "
                    "Type got %s" % type(sequence_length)
                )
            if len(sequence_length) != batch_size:
                raise ValueError(
                    "The argument sequence_length should contain %d " % batch_size +
                    "elements indicating the actual length of each sequence, but got only %d. " % len(sequence_length)
                )
            for i in sequence_length:
                if not (type(i) is int or (isinstance(i, tf.Tensor) and i.dtype.is_integer)):
                    raise TypeError(
                        "The argument sequence_length should be either None or a list of integers. "
                        "One element of sequence_length has the type %s" % type(i)
                    )
                if i > total_steps:
                    raise ValueError(
                        "The actual length of a sequence should not be longer than "
                        "that of the longest sequence (total steps) in this mini-batch. "
                        "Total steps of this mini-batch %d, " % total_steps +
                        "but got an actual length of a sequence %d" % i
                    )

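            # convert each length to the zero-based index of that sequence's last
            # valid time step, used below to gather the final outputs and states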
            sequence_length = [i - 1 for i in sequence_length]

        # set warning
        if (not self.return_last_state or not self.return_last_output) and sequence_length is not None:
            warnings.warn(
                'return_last_output is set as %s ' % self.return_last_output +
                'and return_last_state is set as %s. ' % self.return_last_state +
                'When sequence_length is provided, both are recommended to be set as True. ' +
                'Otherwise, padding will be considered while RNN is forwarding.'
            )

        # return the last output, iterating each seq including padding ones. No need to store output during each
        # time step.
        if self.return_last_output and sequence_length is None:
            outputs = [-1]
        else:
            outputs = list()

        # initialize the states if provided
        states = initial_state if initial_state is not None else self.cell.get_initial_state(inputs)
        if not isinstance(states, list):
            states = [states]

        stored_states = list()

        # initialize the cell
        self.cell.reset_dropout_mask()
        self.cell.reset_recurrent_dropout_mask()

        # recurrent computation
        for time_step in range(total_steps):

            cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
            stored_states.append(states)

            if self.return_last_output and sequence_length is None:
                outputs[-1] = cell_output
            else:
                outputs.append(cell_output)

        # prepare to return results
        if self.return_last_output and sequence_length is None:
            outputs = outputs[-1]

        elif self.return_last_output and sequence_length is not None:
            outputs = tf.convert_to_tensor(outputs)
            outputs = tf.gather(outputs, sequence_length, axis=0)
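            # outputs now has shape [batch_size, batch_size, units]: row i holds every
            # batch element's output at sequence i's last valid step, so the diagonal
            # entry outputs[i][i] is exactly sequence i's last real output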

            outputs_without_padding = []
            for i in range(batch_size):
                outputs_without_padding.append(outputs[i][i][:])
            outputs = tf.convert_to_tensor(outputs_without_padding)
        else:
            if self.return_seq_2d:
                # PTB tutorial: stack dense layer after that, or compute the cost from the output
@@ -210,7 +306,17 @@ def forward(self, inputs, initial_state=None, **kwargs):
                # 3D Tensor [batch_size, n_steps, n_hidden]
                outputs = tf.reshape(tf.concat(outputs, 1), [-1, total_steps, self.cell.units])

        if self.return_last_state and sequence_length is None:
            return outputs, states
        elif self.return_last_state and sequence_length is not None:

            stored_states = tf.convert_to_tensor(stored_states)
            stored_states = tf.gather(stored_states, sequence_length, axis=0)
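            # stored_states held, per time step, the list of state tensors; after the
            # conversion and gather it has shape [batch_size, num_states, batch_size, units],
            # so entry [b, i, b, :] is state i of batch element b at its own last valid step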

            states = []
            for i in range(stored_states.shape[1]):
                states.append(tf.convert_to_tensor([stored_states[b, i, b, :] for b in range(batch_size)]))

            return outputs, states
        else:
            return outputs
96 changes: 96 additions & 0 deletions tests/layers/test_layers_recurrent.py
@@ -769,6 +769,102 @@ def test_target_mask_op(self):
        if fail_flag:
            self.fail("Wrong data shape not detected.")

    def test_dynamic_rnn(self):
        batch_size = 3
        num_steps = 5
        embedding_size = 6

        hidden_size = 4
        inputs = tl.layers.Input([batch_size, num_steps, embedding_size])

        rnn_layer = tl.layers.RNN(
            cell=tf.keras.layers.LSTMCell(units=hidden_size, dropout=0.1), in_channels=embedding_size,
            return_last_output=True, return_last_state=True
        )

        rnn_layer.is_train = False

        print(tl.layers.retrieve_seq_length_op3(inputs))
        _ = rnn_layer(inputs, sequence_length=tl.layers.retrieve_seq_length_op3(inputs))
        _ = rnn_layer(inputs, sequence_length=np.array([5, 5, 5]))

        # test exceptions
        except_flag = False
        try:
            _ = rnn_layer(inputs, sequence_length=1)
            except_flag = True
        except TypeError as e:
            print(e)

        try:
            _ = rnn_layer(inputs, sequence_length=["str", 1, 2])
            except_flag = True
        except TypeError as e:
            print(e)

        try:
            _ = rnn_layer(inputs, sequence_length=[10, 2, 2])
            except_flag = True
        except ValueError as e:
            print(e)

        try:
            _ = rnn_layer(inputs, sequence_length=[1])
            except_flag = True
        except ValueError as e:
            print(e)

        if except_flag:
            self.fail("Exception not detected.")

        # test warning
        for _ in range(5):
            _ = rnn_layer(inputs, sequence_length=[5, 5, 5], return_last_output=False, return_last_state=True)
            _ = rnn_layer(inputs, sequence_length=[5, 5, 5], return_last_output=True, return_last_state=False)

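        # with every sequence at full length, the dynamic path should reproduce
        # the padded (static) result exactly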
        x = rnn_layer(inputs, sequence_length=None, return_last_output=True, return_last_state=True)
        y = rnn_layer(inputs, sequence_length=[5, 5, 5], return_last_output=True, return_last_state=True)

        assert len(x) == 2
        assert len(y) == 2

        for i, j in zip(x, y):
            self.assertTrue(np.allclose(i, j))

    def test_dynamic_rnn_with_seq_len_op2(self):
        data = [[[1], [2], [0], [0], [0]], [[1], [2], [3], [0], [0]], [[1], [2], [6], [1], [1]]]
        data = tf.convert_to_tensor(data, dtype=tf.float32)

        class DynamicRNNExample(tl.models.Model):

            def __init__(self):
                super(DynamicRNNExample, self).__init__()

                self.rnnlayer = tl.layers.RNN(
                    cell=tf.keras.layers.SimpleRNNCell(units=6, dropout=0.1), in_channels=1, return_last_output=True,
                    return_last_state=True
                )

            def forward(self, x):
                z0, s0 = self.rnnlayer(x, sequence_length=None)
                z1, s1 = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
                z2, s2 = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x), initial_state=s1)
                print(z0)
                print(z1)
                print(z2)
                print("===")
                print(s0)
                print(s1)
                print(s2)
                return z2, s2

        model = DynamicRNNExample()
        model.eval()

        output, state = model(data)
        print(output.shape)
        print(state)


if __name__ == '__main__':

