System information
Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.14.6
TensorFlow installed from (source or binary): binary
TensorFlow version (use command below): 2.0.0
Python version: 3.7.4
GPU model and memory: none (MacBook Pro, Core i5, Iris Graphics 6100, 1.5 GB)
Describe the current behavior
State handling in RNNs wrapped in Bidirectional has changed between keras with TF 1.x and tf.keras. In the old keras with TF 1.x, using stateful=True in a bidi-RNN had no effect, i.e., all bidi-RNN models behaved as if stateful=False. Consequently, model.reset_states() did nothing.
In the new tf.keras, stateful=True in a bidi-RNN does have an effect: both the fwd-RNN and the bwd-RNN are stateful. This is a good change IMO; even though stateful bidi-RNNs are unusual, this is the best way to implement them. However, in tf.keras, model.reset_states() does not do anything for bidi-RNN models (SimpleRNN, GRU, LSTM).
Describe the expected behavior
For the minimal example script provided below, here is the output:
The results after the STATE RESET should be the same as the first set of results -- i.e., the last (third) set of results should produce the same result for the stateful and non-stateful models (same as the first set of results).
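To make the expected numbers concrete, here is a pure-Python sketch of the arithmetic, with no TensorFlow involved. It assumes what the script's settings imply: with activation=None and use_bias=False, each direction computes h_t = kernel * x_t + recurrent_kernel * h_{t-1}, using the toy weights 1.0 and -0.5, and the backward output is re-aligned to input time order. `ToyBidiRNN` and `run_rnn` are hypothetical names for illustration only, not Keras APIs.

```python
KERNEL = 1.0      # input weight, same value as in toy_weights
RECURRENT = -0.5  # recurrent weight, same value as in toy_weights


def run_rnn(xs, h0=0.0):
    """Linear RNN: h_t = KERNEL * x_t + RECURRENT * h_{t-1}."""
    h, outputs = h0, []
    for x in xs:
        h = KERNEL * x + RECURRENT * h
        outputs.append(h)
    return outputs, h


class ToyBidiRNN:
    """Hypothetical stand-in for a bidirectional RNN with optional state carry-over."""

    def __init__(self, stateful):
        self.stateful = stateful
        self.reset_states()

    def reset_states(self):
        self.fwd_state = 0.0
        self.bwd_state = 0.0

    def predict(self, xs):
        fwd, fwd_last = run_rnn(xs, self.fwd_state)
        bwd, bwd_last = run_rnn(xs[::-1], self.bwd_state)
        if self.stateful:
            # Carry the final state of each direction into the next call.
            self.fwd_state, self.bwd_state = fwd_last, bwd_last
        return fwd, bwd[::-1]  # backward output re-aligned to input time order


x_in = [1.0, 0.0, 0.0]
stateless, stateful = ToyBidiRNN(False), ToyBidiRNN(True)

first = stateful.predict(x_in)   # starts from zero state
second = stateful.predict(x_in)  # starts from the state left by the first call
stateful.reset_states()
third = stateful.predict(x_in)   # should match the first call again

print("first :", first)
print("second:", second)
print("third :", third)
```

In this toy model the third result equals the first, which is the behavior the issue expects from stateful_model after model.reset_states().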
Code to reproduce the issue
import numpy as np

TF2 = True
if TF2:
    ### currently, there is a bug in tf.keras: model.reset_states() does not work
    from tensorflow.keras.layers import Input, Dense, SimpleRNN, GRU, LSTM, Bidirectional
    from tensorflow.keras.models import Model
else:
    ### in the old keras, bidi-RNNs with stateful=True behave the same as stateful=False
    from keras.layers import Input, Dense, SimpleRNN, GRU, LSTM, Bidirectional
    from keras.models import Model

sequence_length = 3
feature_dim = 1
features_in = Input(batch_shape=(1, sequence_length, feature_dim))

# Two models on the same input, identical except for the stateful flag
rnn_out = Bidirectional(SimpleRNN(1, activation=None, use_bias=False, return_sequences=True, return_state=False, stateful=False))(features_in)
stateless_model = Model(inputs=[features_in], outputs=[rnn_out])
stateful_rnn_out = Bidirectional(SimpleRNN(1, activation=None, use_bias=False, return_sequences=True, return_state=False, stateful=True))(features_in)
stateful_model = Model(inputs=features_in, outputs=stateful_rnn_out)

# Same weights in both models
toy_weights = [np.asarray([[1.0]], dtype=np.float32), np.asarray([[-0.5]], dtype=np.float32), np.asarray([[1.0]], dtype=np.float32), np.asarray([[-0.5]], dtype=np.float32)]
stateless_model.set_weights(toy_weights)
stateful_model.set_weights(toy_weights)

# Unit impulse at the first time step
x_in = np.zeros(sequence_length)
x_in[0] = 1
x_in = x_in.reshape((1, sequence_length, feature_dim))

def print_bidi_out(non_stateful_out, stateful_out):
    fb = ['FWD::', 'BWD::']
    for i in range(2):
        print(fb[i])
        print(f'non_stateful: {non_stateful_out.T[i]}')
        print(f'stateful: {stateful_out.T[i]}')
        print(f'delta: {stateful_out.T[i]-non_stateful_out.T[i]}')

# First call: both models start from zero state
non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length, 2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length, 2))
print_bidi_out(non_stateful_out, stateful_out)

# Second call: the stateful model starts from the state left by the first call
non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length, 2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length, 2))
print_bidi_out(non_stateful_out, stateful_out)

print('\n** RESETTING STATES in STATEFUL MODEL **\n')
stateful_model.reset_states()

# Third call: after the reset, both models should agree again
non_stateful_out = stateless_model.predict(x_in).reshape((sequence_length, 2))
stateful_out = stateful_model.predict(x_in).reshape((sequence_length, 2))
print_bidi_out(non_stateful_out, stateful_out)
keithchugg changed the title from "model.reset_states() does not work for bidirectional-RNNs." to "model.reset_states() does not work for bidirectional-RNNs in tf.keras." on Nov 7, 2019
The self.stateful value was overridden in base_layer.__init__().
See #34055 for more details.
PiperOrigin-RevId: 281184004
Change-Id: I74c47a555cae8b045ee78b5c9a0144c4f9569978
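The kind of override described in the commit message can be sketched in plain Python. This is only an illustration of the pattern, not the actual TensorFlow source: all class names and `reset_model_states` are hypothetical, and the assumption that a model-level reset skips layers whose stateful flag is False is mine, made to show why a lost flag would turn the reset into a no-op.

```python
class BaseLayer:
    def __init__(self):
        self.stateful = False  # base-class default


class InnerRNN(BaseLayer):
    """Hypothetical recurrent layer holding a single scalar state."""

    def __init__(self, stateful):
        super().__init__()
        self.stateful = stateful
        self.state = 0.0

    def reset_states(self):
        self.state = 0.0


class BuggyWrapper(BaseLayer):
    def __init__(self, layer):
        self.layer = layer
        self.stateful = layer.stateful  # set before the base initializer runs...
        super().__init__()              # ...which overwrites it with False

    def reset_states(self):
        self.layer.reset_states()


class FixedWrapper(BuggyWrapper):
    def __init__(self, layer):
        super().__init__(layer)
        self.stateful = layer.stateful  # re-applied after the base initializer


def reset_model_states(layers):
    """Assumed model-level reset: only touches layers flagged as stateful."""
    for layer in layers:
        if getattr(layer, "stateful", False):
            layer.reset_states()


buggy = BuggyWrapper(InnerRNN(stateful=True))
fixed = FixedWrapper(InnerRNN(stateful=True))
buggy.layer.state = fixed.layer.state = 0.25  # pretend a prediction left state behind

reset_model_states([buggy, fixed])
print("buggy wrapper state after reset:", buggy.layer.state)
print("fixed wrapper state after reset:", fixed.layer.state)
```

In this sketch the buggy wrapper keeps its leftover state because its own flag reads False, while calling `buggy.reset_states()` directly would still clear it.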