model.reset_states() does not work for bidirectional-RNNs in tf.keras. #34055

Closed
keithchugg opened this issue Nov 7, 2019 · 3 comments
Labels
comp:keras Keras related issues TF 2.0 Issues relating to TensorFlow 2.0 type:bug Bug

@keithchugg

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): YES.
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.14.6
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.0.0
  • Python version: 3.7.4
  • GPU model and memory: none (MacBook Pro, Core i5, Iris Graphics 6100, 1.5 GB)

Describe the current behavior
State handling in RNNs wrapped in Bidirectional differs between the old keras (with TF 1.x) and tf.keras. In the old keras with TF 1.x, setting stateful=True in a bidi-RNN had no effect -- i.e., all bidi-RNN models behaved as if stateful=False. Therefore model.reset_states() had nothing to reset.

In the new tf.keras, stateful=True in a bidi-RNN does have an effect -- both the fwd-RNN and the bwd-RNN are stateful. This is a good change IMO -- even though stateful bidi-RNNs are unusual, this is the best way to implement it. However, in tf.keras, model.reset_states() does not do anything for bidi-RNN models (SimpleRNN, GRU, LSTM).

Describe the expected behavior

For the minimal example script provided below, here is the output:

FWD::
non_stateful: [ 1.   -0.5   0.25]
stateful: [ 1.   -0.5   0.25]
delta: [0. 0. 0.]
BWD::
non_stateful: [1. 0. 0.]
stateful: [1. 0. 0.]
delta: [0. 0. 0.]
FWD::
non_stateful: [ 1.   -0.5   0.25]
stateful: [ 0.875   -0.4375   0.21875]
delta: [-0.125    0.0625  -0.03125]
BWD::
non_stateful: [1. 0. 0.]
stateful: [ 0.875  0.25  -0.5  ]
delta: [-0.125  0.25  -0.5  ]

** RESETING STATES in STATEFUL MODEL **

FWD::
non_stateful: [ 1.   -0.5   0.25]
stateful: [ 0.890625   -0.4453125   0.22265625]
delta: [-0.109375    0.0546875  -0.02734375]
BWD::
non_stateful: [1. 0. 0.]
stateful: [ 0.890625  0.21875  -0.4375  ]
delta: [-0.109375  0.21875  -0.4375  ]

The results after the STATE RESET should be the same as the first set of results -- i.e., in the last (third) set of results the stateful and non-stateful models should produce identical outputs, with delta equal to zero.
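
The numbers in the output above can be checked by hand. With activation=None, use_bias=False, an input weight of 1.0 and a recurrent weight of -0.5 (the toy weights used in the script), each direction reduces to the linear recurrence h_t = x_t - 0.5 * h_(t-1). The following is a minimal pure-Python sketch of that recurrence, not Keras code: the names linear_rnn and run_calls are made up for illustration, and it assumes the backward direction processes the reversed sequence and has its output reversed back.

```python
def linear_rnn(x, w_in, w_rec, h0=0.0):
    """Run h_t = w_in * x_t + w_rec * h_(t-1); return all outputs and the final state."""
    h = h0
    outputs = []
    for x_t in x:
        h = w_in * x_t + w_rec * h
        outputs.append(h)
    return outputs, h


def run_calls(x, n_calls, w_in=1.0, w_rec=-0.5):
    """Emulate n_calls stateful predictions with no state reset in between."""
    fwd_state, bwd_state = 0.0, 0.0
    results = []
    for _ in range(n_calls):
        # Forward direction: start from the state left by the previous call.
        fwd_out, fwd_state = linear_rnn(x, w_in, w_rec, fwd_state)
        # Backward direction: run on the reversed input, then reverse the output.
        bwd_out, bwd_state = linear_rnn(x[::-1], w_in, w_rec, bwd_state)
        results.append((fwd_out, bwd_out[::-1]))
    return results


x = [1.0, 0.0, 0.0]
results = run_calls(x, 3)
for i, (fwd, bwd) in enumerate(results, start=1):
    print(f"call {i}: FWD {fwd}  BWD {bwd}")

# A working reset is equivalent to starting over from zero states.
after_reset = run_calls(x, 1)[0]
print(f"after reset: FWD {after_reset[0]}  BWD {after_reset[1]}")
```

The three calls reproduce the "stateful" rows of the three result sets in the log, which shows that the third set is what a model with no reset at all would produce. The "after reset" line matches the first set, which is the expected behavior described above.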

Code to reproduce the issue

import numpy as np
TF2 = True
if TF2:
	### currently, there is a bug in tf.keras: model.reset_states() does not work
	from tensorflow.keras.layers import Input, Dense, SimpleRNN, GRU, LSTM, Bidirectional
	from tensorflow.keras.models import Model
else:
	### in the old keras, bidi-RNNs with stateful=True behave the same as stateful=False
	from keras.layers import Input, Dense, SimpleRNN, GRU, LSTM, Bidirectional
	from keras.models import Model

sequence_length = 3
feature_dim = 1
features_in = Input(batch_shape=(1, sequence_length, feature_dim)) 

rnn_out = Bidirectional( SimpleRNN(1, activation=None, use_bias=False, return_sequences=True, return_state=False, stateful=False))(features_in)
stateless_model = Model(inputs=[features_in], outputs=[rnn_out])

stateful_rnn_out = Bidirectional( SimpleRNN(1, activation=None, use_bias=False, return_sequences=True, return_state=False, stateful=True))(features_in)
stateful_model = Model(inputs=features_in, outputs=stateful_rnn_out)

toy_weights = [ np.asarray([[1.0]], dtype=np.float32), np.asarray([[-0.5]], dtype=np.float32), np.asarray([[1.0]], dtype=np.float32), np.asarray([[-0.5]], dtype=np.float32)]

stateless_model.set_weights(toy_weights)
stateful_model.set_weights(toy_weights)

x_in = np.zeros(sequence_length)
x_in[0] = 1
x_in = x_in.reshape( (1, sequence_length, feature_dim) )

def print_bidi_out(non_stateful_out, stateful_out):
	fb = ['FWD::', 'BWD::']

	for i in range(2):
		print(fb[i])
		print(f'non_stateful: {non_stateful_out.T[i]}')
		print(f'stateful: {stateful_out.T[i]}')
		print(f'delta: {stateful_out.T[i]-non_stateful_out.T[i]}')


non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length,2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length,2))
print_bidi_out(non_stateful_out, stateful_out)

non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length,2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length,2))
print_bidi_out(non_stateful_out, stateful_out)

print('\n** RESETING STATES in STATEFUL MODEL **\n')
stateful_model.reset_states()
non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length,2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length,2))
print_bidi_out(non_stateful_out, stateful_out)
@keithchugg keithchugg changed the title from "model.reset_states() does not work for bidirectional-RNNs." to "model.reset_states() does not work for bidirectional-RNNs in tf.keras." on Nov 7, 2019
@keithchugg
Author

Note: this is an issue with tf.keras vs. keras (not TF 1.x vs TF 2.0)

@gadagashwini-zz gadagashwini-zz self-assigned this Nov 8, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.0 Issues relating to TensorFlow 2.0 comp:keras Keras related issues type:bug Bug labels Nov 8, 2019
@gadagashwini-zz
Contributor

I could replicate the issue on Colab with TF 2.0.
Please take a look at the gist. Thanks!

@ymodak ymodak assigned qlzh727 and unassigned ymodak Nov 15, 2019
tensorflow-copybara pushed a commit that referenced this issue Nov 19, 2019
The self.stateful value was overridden in the base_layer.__init__().

See #34055 for more details.

PiperOrigin-RevId: 281184004
Change-Id: I74c47a555cae8b045ee78b5c9a0144c4f9569978
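
The commit message describes a constructor-ordering problem. The following is a rough sketch of that kind of bug using hypothetical stand-in classes (BaseLayer, InnerRNN, BuggyWrapper, FixedWrapper and reset_model_states are all made up here and are not the TensorFlow source). It assumes that the base constructor assigns a default stateful = False and that the model-level reset skips any layer whose stateful flag is false.

```python
class BaseLayer:
    """Hypothetical stand-in for a base layer class."""

    def __init__(self):
        self.stateful = False  # default assigned by the base constructor


class InnerRNN(BaseLayer):
    """Hypothetical recurrent layer holding a single scalar state."""

    def __init__(self, stateful=False):
        super().__init__()
        self.stateful = stateful
        self.state = 0.0

    def reset_states(self):
        self.state = 0.0


class BuggyWrapper(BaseLayer):
    """Copies the flag BEFORE the base constructor runs, so it is lost."""

    def __init__(self, layer):
        self.layer = layer
        self.stateful = layer.stateful  # set too early ...
        super().__init__()              # ... and overwritten with False here

    def reset_states(self):
        self.layer.reset_states()


class FixedWrapper(BuggyWrapper):
    """Copies the flag AFTER the base constructor has run."""

    def __init__(self, layer):
        super().__init__(layer)
        self.stateful = layer.stateful


def reset_model_states(layers):
    """Assumed model-level reset: only layers flagged as stateful are reset."""
    for layer in layers:
        if hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
            layer.reset_states()


buggy = BuggyWrapper(InnerRNN(stateful=True))
buggy.layer.state = 0.25          # pretend a prediction left some state behind
reset_model_states([buggy])
print(buggy.stateful, buggy.layer.state)

fixed = FixedWrapper(InnerRNN(stateful=True))
fixed.layer.state = 0.25
reset_model_states([fixed])
print(fixed.stateful, fixed.layer.state)
```

In this sketch the buggy wrapper reports stateful as False, so the model-level reset skips it and the inner state survives, while the fixed wrapper is reset to zero. Calling reset_states() directly on the wrapper bypasses the flag check in the sketch; whether that also works as a workaround on an affected TensorFlow build would need to be verified there.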
@qlzh727 qlzh727 closed this as completed Nov 19, 2019