From 7a28a2b08f91e7b2e0bc5df34b41d209a5fa5317 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:05:45 +0100
Subject: [PATCH 1/8] remove return_last_state warning for RNN

---
 tensorlayer/layers/recurrent.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py
index 66ffe2211..4635f2f69 100644
--- a/tensorlayer/layers/recurrent.py
+++ b/tensorlayer/layers/recurrent.py
@@ -247,11 +247,10 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
             sequence_length = [i - 1 for i in sequence_length]
 
         # set warning
-        if (not self.return_last_state or not self.return_last_output) and sequence_length is not None:
+        if (not self.return_last_output) and sequence_length is not None:
             warnings.warn(
                 'return_last_output is set as %s ' % self.return_last_output +
-                'and return_last_state is set as %s. ' % self.return_last_state +
-                'When sequence_length is provided, both are recommended to set as True. ' +
+                'When sequence_length is provided, it is recommended to set as True. ' +
                 'Otherwise, padding will be considered while RNN is forwarding.'
             )
 

From 2e3458396d7f6e4e912061bf0577a002d7b27898 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:36:46 +0100
Subject: [PATCH 2/8] comment warning, fix if seq_len=0

---
 tensorlayer/layers/recurrent.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py
index 4635f2f69..398fdbe28 100644
--- a/tensorlayer/layers/recurrent.py
+++ b/tensorlayer/layers/recurrent.py
@@ -244,15 +244,15 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
                     "but got an actual length of a sequence %d" % i
                 )
 
-        sequence_length = [i - 1 for i in sequence_length]
+        sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
 
         # set warning
-        if (not self.return_last_output) and sequence_length is not None:
-            warnings.warn(
-                'return_last_output is set as %s ' % self.return_last_output +
-                'When sequence_length is provided, it is recommended to set as True. ' +
-                'Otherwise, padding will be considered while RNN is forwarding.'
-            )
+        # if (not self.return_last_output) and sequence_length is not None:
+        #     warnings.warn(
+        #         'return_last_output is set as %s ' % self.return_last_output +
+        #         'When sequence_length is provided, it is recommended to set as True. ' +
+        #         'Otherwise, padding will be considered while RNN is forwarding.'
+        #     )
 
         # return the last output, iterating each seq including padding ones. No need to store output during each
         # time step.
@@ -273,6 +273,7 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
         self.cell.reset_recurrent_dropout_mask()
 
         # recurrent computation
+        # FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times.
         for time_step in range(total_steps):
 
             cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
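Note on PATCH 1/8 and 2/8: the layer shifts each entry of `sequence_length` by one so it can be used directly as the index of the last valid time step, and the clamp added in PATCH 2 keeps a zero-length sequence from producing the index -1. A minimal NumPy sketch of that gather logic (illustrative variable names only, not the layer's actual code):

    import numpy as np

    # Simulated synced RNN outputs: [batch_size, max_steps, hidden_size].
    batch_size, max_steps, hidden_size = 3, 6, 8
    outputs = np.random.random([batch_size, max_steps, hidden_size]).astype(np.float32)

    # Valid length of each sequence; note the zero-length one.
    sequence_length = [6, 4, 0]

    # Index of the last valid step. Without the clamp, a length of 0 gives -1,
    # which NumPy silently wraps around to the last *padded* step.
    last_step = [i - 1 if i >= 1 else 0 for i in sequence_length]

    # Gather the last valid output of every sequence in the batch.
    last_output = outputs[np.arange(batch_size), last_step, :]
    print(last_output.shape)  # (3, 8)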
From aeb0fb06c19634597b939a502ab1d0331d088bcd Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:39:06 +0100
Subject: [PATCH 3/8] comment outdated classes

---
 tensorlayer/layers/recurrent.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py
index 398fdbe28..5d461c34d 100644
--- a/tensorlayer/layers/recurrent.py
+++ b/tensorlayer/layers/recurrent.py
@@ -758,6 +758,7 @@ def forward(self, inputs, fw_initial_state=None, bw_initial_state=None, **kwargs
         return outputs
 
 
+'''
 class ConvRNNCell(object):
     """Abstract object representing an Convolutional RNN Cell."""
 
@@ -1071,6 +1072,7 @@ def __init__(
         self._add_layers(self.outputs)
         self._add_params(rnn_variables)
 
+'''
 
 # @tf.function
 def retrieve_seq_length_op(data):

From e932fe898fccb65539e138560b4297dd03cbc5e3 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:43:01 +0100
Subject: [PATCH 4/8] test dynamic rnn with fake data

---
 tests/layers/test_layers_recurrent.py | 50 ++++++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py
index b974b5b8b..517586700 100644
--- a/tests/layers/test_layers_recurrent.py
+++ b/tests/layers/test_layers_recurrent.py
@@ -18,7 +18,7 @@ class Layer_RNN_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.batch_size = 2
+        cls.batch_size = 10
 
         cls.vocab_size = 20
         cls.embedding_size = 4
@@ -26,7 +26,13 @@ def setUpClass(cls):
         cls.hidden_size = 8
         cls.num_steps = 6
 
+        cls.data_n_steps = np.random.randint(low=cls.num_steps // 2, high=cls.num_steps + 1, size=cls.batch_size)
         cls.data_x = np.random.random([cls.batch_size, cls.num_steps, cls.embedding_size]).astype(np.float32)
+
+        for i in range(cls.batch_size):
+            for j in range(cls.data_n_steps[i], cls.num_steps):
+                cls.data_x[i][j][:] = 0
+
         cls.data_y = np.zeros([cls.batch_size, 1]).astype(np.float32)
         cls.data_y2 = np.zeros([cls.batch_size, cls.num_steps]).astype(np.float32)
 
@@ -865,6 +871,48 @@ def forward(self, x):
         print(output.shape)
         print(state)
 
+    def test_dynamic_rnn_with_fake_data(self):
+
+        class CustomisedModel(tl.models.Model):
+
+            def __init__(self):
+                super(CustomisedModel, self).__init__()
+                self.rnnlayer = tl.layers.LSTMRNN(
+                    units=8, dropout=0.1, in_channels=4,
+                    return_last_output=True,
+                    return_last_state=False
+                )
+                self.dense = tl.layers.Dense(in_channels=8, n_units=1)
+
+            def forward(self, x):
+                z = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
+                z = self.dense(z[:, :])
+                return z
+
+        rnn_model = CustomisedModel()
+        print(rnn_model)
+        optimizer = tf.optimizers.Adam(learning_rate=0.01)
+        rnn_model.train()
+
+        for epoch in range(50):
+            with tf.GradientTape() as tape:
+                pred_y = rnn_model(self.data_x)
+                loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))
+
+            if (epoch + 1) % 10 == 0:
+                print("epoch %d, loss %f" % (epoch, loss))
+
+        # Testing saving and restoring of RNN weights
+        rnn_model2 = CustomisedModel()
+        rnn_model2.eval()
+        pred_y = rnn_model2(self.data_x)
+        loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+        print("MODEL INIT loss %f" % (loss))
+
 
 if __name__ == '__main__':
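Note on PATCH 4/8: the test zero-pads each sequence past its sampled length and recovers the lengths with `tl.layers.retrieve_seq_length_op3`, which infers lengths from all-zero padding steps. A rough reimplementation of that idea in plain TensorFlow, to show what the test relies on (a sketch under the assumption that padding steps are entirely zero and appear only at the tail; not TensorLayer's code):

    import numpy as np
    import tensorflow as tf

    def seq_length_from_zero_padding(data):
        # data: [batch_size, max_steps, n_features], tail-padded with all-zero steps.
        # A step counts as valid if any of its features is non-zero.
        valid_step = tf.reduce_any(tf.not_equal(data, 0.0), axis=2)
        return tf.reduce_sum(tf.cast(valid_step, tf.int32), axis=1)

    x = np.random.random([2, 6, 4]).astype(np.float32)
    x[0, 4:, :] = 0  # first sample: 4 valid steps
    x[1, 5:, :] = 0  # second sample: 5 valid steps
    print(seq_length_from_zero_padding(x).numpy())  # [4 5]

Such a count undercounts if a genuinely valid step happens to be all zeros, which the test avoids by drawing its data from a continuous uniform distribution.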
From 6173d1a73e178da6be953cf7743ddd8643be6aa3 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:45:04 +0100
Subject: [PATCH 5/8] test dynamic rnn with fake data

---
 tests/layers/test_layers_recurrent.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py
index 517586700..9ae9324d9 100644
--- a/tests/layers/test_layers_recurrent.py
+++ b/tests/layers/test_layers_recurrent.py
@@ -905,6 +905,9 @@ def forward(self, x):
             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
 
+        filename = "dynamic_rnn.h5"
+        rnn_model.save_weights(filename)
+
         # Testing saving and restoring of RNN weights
         rnn_model2 = CustomisedModel()
         rnn_model2.eval()
@@ -912,6 +915,11 @@ def forward(self, x):
         loss = tl.cost.mean_squared_error(pred_y, self.data_y)
         print("MODEL INIT loss %f" % (loss))
 
+        rnn_model2.load_weights(filename)
+        pred_y = rnn_model2(self.data_x)
+        loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+        print("MODEL RESTORE W loss %f" % (loss))
+
 
 if __name__ == '__main__':

From ccf27f6455ad79ab3592c99ab2de7f3a25688403 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 10:50:46 +0100
Subject: [PATCH 6/8] test dynamic rnn with fake data and test saving and
 restoring of dynamic rnn

---
 tests/layers/test_layers_recurrent.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py
index 9ae9324d9..500445a76 100644
--- a/tests/layers/test_layers_recurrent.py
+++ b/tests/layers/test_layers_recurrent.py
@@ -920,6 +920,9 @@ def forward(self, x):
         loss = tl.cost.mean_squared_error(pred_y, self.data_y)
         print("MODEL RESTORE W loss %f" % (loss))
 
+        import os
+        os.remove(filename)
+
 
 if __name__ == '__main__':
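Note on PATCH 5/8 and 6/8: the weights are round-tripped through "dynamic_rnn.h5" and the file is then deleted with a bare `os.remove`, which leaks the checkpoint if an assertion fails in between. One possible hardening (a sketch only, not part of this PR; `rnn_model` and `CustomisedModel` are the objects defined in the test above):

    import os
    import tempfile

    filename = os.path.join(tempfile.gettempdir(), "dynamic_rnn.h5")
    try:
        rnn_model.save_weights(filename)
        rnn_model2 = CustomisedModel()
        rnn_model2.eval()
        rnn_model2.load_weights(filename)
    finally:
        # Clean up even if loading or a later assertion fails.
        if os.path.exists(filename):
            os.remove(filename)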
From bf8ff1034a18dc2317cd2a175f8fa4a4896b7ecb Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Thu, 1 Aug 2019 11:06:23 +0100
Subject: [PATCH 7/8] yapf format and solve travis-ci problem

---
 tensorlayer/layers/recurrent.py       | 6 +++++-
 tests/layers/test_layers_recurrent.py | 7 ++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py
index 5d461c34d..1814dd727 100644
--- a/tensorlayer/layers/recurrent.py
+++ b/tensorlayer/layers/recurrent.py
@@ -105,7 +105,10 @@ class RNN(Layer):
     Similar to the DynamicRNN in TL 1.x.
 
     If the `sequence_length` is provided in RNN's forwarding and both `return_last_output` and `return_last_state`
-    are set as `True`, the forward function will automatically ignore the paddings.
+    are set as `True`, the forward function will automatically ignore the paddings. Note that if `return_last_output`
+    is set as `False`, the synced sequence outputs will still include outputs that correspond to paddings,
+    but users are free to select which slice of outputs to be used in the following procedure.
+
     The `sequence_length` should be a list of integers which indicates the length of each sequence.
     It is recommended to use `tl.layers.retrieve_seq_length_op3 `__
@@ -1074,6 +1077,7 @@ def __init__(
 '''
 
+
 # @tf.function
 def retrieve_seq_length_op(data):
     """An op to compute the length of a sequence from input shape of [batch_size, n_step(max), n_features],
diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py
index 500445a76..4309eae02 100644
--- a/tests/layers/test_layers_recurrent.py
+++ b/tests/layers/test_layers_recurrent.py
@@ -18,7 +18,7 @@ class Layer_RNN_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.batch_size = 10
+        cls.batch_size = 2
 
         cls.vocab_size = 20
         cls.embedding_size = 4
@@ -878,9 +878,7 @@ class CustomisedModel(tl.models.Model):
 
             def __init__(self):
                 super(CustomisedModel, self).__init__()
                 self.rnnlayer = tl.layers.LSTMRNN(
-                    units=8, dropout=0.1, in_channels=4,
-                    return_last_output=True,
-                    return_last_state=False
+                    units=8, dropout=0.1, in_channels=4, return_last_output=True, return_last_state=False
                 )
                 self.dense = tl.layers.Dense(in_channels=8, n_units=1)
 
@@ -924,7 +922,6 @@ def forward(self, x):
         os.remove(filename)
 
-
 if __name__ == '__main__':
     unittest.main()

From 22ca25f914661f3408a0d0034af1c599a1e56e68 Mon Sep 17 00:00:00 2001
From: Jingqing Zhang
Date: Fri, 2 Aug 2019 14:14:03 +0100
Subject: [PATCH 8/8] update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b12ea687c..5b22341f0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -79,6 +79,7 @@ To release a new version, please update the changelog as followed:
 ### Deprecated
 
 ### Fixed
+- RNN updates: remove warnings, fix if seq_len=0, unit tests (PR #1033)
 
 ### Removed
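Note on the docstring added in PATCH 7/8: with `return_last_output=False` the synced outputs still contain values at padded steps, and the caller chooses which slice to use. One common way to handle that is to mask the padded positions before computing a sequence loss; a short sketch in plain TensorFlow (illustrative only, not a TensorLayer API):

    import tensorflow as tf

    batch_size, max_steps, hidden_size = 2, 6, 8
    outputs = tf.random.uniform([batch_size, max_steps, hidden_size])  # synced RNN outputs
    sequence_length = tf.constant([4, 6])

    # Boolean mask: True at real steps, False at padding.
    mask = tf.sequence_mask(sequence_length, maxlen=max_steps)

    # Zero out padded steps so they contribute nothing to a pooled loss.
    masked = outputs * tf.cast(mask, outputs.dtype)[:, :, tf.newaxis]
    print(masked.shape)  # (2, 6, 8)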