# Stateful RNNs

In this reading notebook you will learn how to retain the state of an RNN when processing long sequences.

In [1]:
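import tensorflow as tf
tf.__version__

# Cap the first GPU at 1024 MB of memory. Virtual device configuration must be
# set before the GPU is initialized, and it is skipped on machines without a GPU.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)]
    )

print('GPU name: {}'.format(tf.test.gpu_device_name()))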
  

GPU name: /device:GPU:0


So far you have trained RNNs on entire sequences, possibly of varying length. In some applications, such as financial time series modeling or real-time speech processing, the input sequence can be very long. 

One way to process such sequences is to split them into shorter subsequences and feed these to the model as consecutive batches. However, the RNN's internal state is normally reset between batches, so whatever the model learned from earlier subsequences is lost. Persisting the RNN cell's state between batches lets the model treat consecutive batches as one continuous sequence.
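A minimal NumPy sketch of the chopping step, assuming the sequence length divides evenly into the number of chunks (the array values are made up for illustration):

```python
import numpy as np

# Two long sequences of 12 time steps, 1 feature each: shape (2, 12, 1)
long_sequences = np.arange(24, dtype=np.float32).reshape(2, 12, 1)

# Chop along the time axis (axis=1) into 3 consecutive chunks of 4 steps.
# np.split requires the length to divide evenly; np.array_split does not.
chunks = np.split(long_sequences, 3, axis=1)
print([c.shape for c in chunks])   # [(2, 4, 1), (2, 4, 1), (2, 4, 1)]

# Each chunk would be one batch; the chunks in order reconstruct the full sequences
assert np.array_equal(np.concatenate(chunks, axis=1), long_sequences)
```

Without a persisted state, the model would see each of these chunks as an unrelated, independent sequence.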

## Stateful and non-stateful RNN models
We will begin by creating two versions of the same RNN model. The first is a regular RNN that does not retain its state.

In [55]:
# Create a regular (non-stateful) RNN

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Bidirectional, LSTM

gru = Sequential([
    GRU(5, input_shape=(None, 1), name='rnn')
])

gru.summary()
# states = gru.layers[0].states
# print(len(states), states[0].shape)
# states[0]

Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
rnn (GRU)                    (None, 5)                 120       
Total params: 120
Trainable params: 120
Non-trainable params: 0
_________________________________________________________________


To persist RNN cell states between batches, you can use the `stateful` argument when you initialize an RNN layer. The default value of this argument is `False`. This argument is available for all RNN layer types.

In [16]:
# Create a stateful RNN

stateful_gru = Sequential([
    GRU(5, stateful=True, batch_input_shape=(2, None, 1), name='stateful_rnn')
])

stateful_gru.summary()
states = stateful_gru.layers[0].states
print(len(states), states[0].shape)
states[0]

Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
stateful_rnn (GRU)           (2, 5)                    120       
Total params: 120
Trainable params: 120
Non-trainable params: 0
_________________________________________________________________
1 (2, 5)


<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)>

Note that as well as setting `stateful=True`, we have also specified the `batch_input_shape`. This fixes the number of elements in a batch, and also gives the sequence length (here `None`, so any length is accepted) and the number of features. So the above model will always require a batch of exactly 2 sequences.

When using stateful RNNs, it is necessary to supply this argument to the first layer of a `Sequential` model. The layer stores one state vector per batch element in a variable of fixed shape `(batch_size, units)`, so the batch size must be known when the model is built. The model then assumes that the i-th element of every batch it receives is a continuation of the sequence in the i-th element of the previous batch.
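To satisfy that assumption, a single long series can be cut into `batch_size` parallel streams and then chunked along time. The sketch below is a NumPy illustration; `make_stateful_batches` is a hypothetical helper (not a Keras function), and it assumes that trailing values which do not fill a whole chunk can be dropped.

```python
import numpy as np

def make_stateful_batches(series, batch_size, chunk_len):
    """Arrange a 1-D series into batches of shape (batch_size, chunk_len, 1)
    so that row i of batch k continues row i of batch k-1.
    Hypothetical helper; leftover values that don't fill a chunk are dropped."""
    steps_per_stream = len(series) // batch_size
    steps_per_stream -= steps_per_stream % chunk_len   # keep whole chunks only
    usable = series[:batch_size * steps_per_stream]
    # Each row is one contiguous stretch of the original series
    streams = usable.reshape(batch_size, steps_per_stream, 1)
    return [streams[:, t:t + chunk_len, :]
            for t in range(0, steps_per_stream, chunk_len)]

series = np.arange(24, dtype=np.float32)
batches = make_stateful_batches(series, batch_size=2, chunk_len=3)
print(len(batches))                      # 4
print(batches[0][:, :, 0].tolist())      # [[0.0, 1.0, 2.0], [12.0, 13.0, 14.0]]
print(batches[1][:, :, 0].tolist())      # [[3.0, 4.0, 5.0], [15.0, 16.0, 17.0]]
```

Row 0 of the second batch (3, 4, 5) picks up exactly where row 0 of the first batch (0, 1, 2) stopped, which is the alignment a stateful layer relies on.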

When defining a model with a stateful RNN using the functional API, you instead need to specify the `batch_shape` argument of the `Input` layer, as follows:

In [42]:
# Redefine the same stateful RNN using the functional API

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

inputs = Input(batch_shape=(2, None, 1))
outputs = GRU(5, stateful=True, name='stateful_rnn')(inputs)

stateful_gru = Model(inputs=inputs, outputs=outputs)

states = stateful_gru.layers[1].states
print(len(states), states[0].shape)
states[0]

1 (2, 5)


<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)>

In [64]:
inputs = Input(batch_shape=(2, None, 1))
outputs = Bidirectional(layer=LSTM(5, stateful=True, name='stateful_rnn'))(inputs)
# outputs = Bidirectional(layer=LSTM(5, stateful=True, name='stateful_rnn'),
#                         backward_layer = GRU(5, stateful=True, name='backward_stateful_rnn', go_backwards=True)
#                        )(inputs)

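# Use a distinct name so the GRU model `stateful_gru` defined above is not overwritten
stateful_bilstm = Model(inputs=inputs, outputs=outputs)
stateful_bilstm.summary()

# Only backward_layer reports non-None states below. `Bidirectional.layer` is the
# original LSTM that was passed in; Bidirectional builds and runs copies of it
# (`forward_layer` and `backward_layer`), so the original is never built and its
# states stay None. Use `.forward_layer` to inspect the forward states.
states = stateful_bilstm.layers[1].layer.states
print(stateful_bilstm.layers[1].layer.name)
print(len(states))
for state in states:
    if state is not None:
        print(state.shape, state)

print('-'*50)

states = stateful_bilstm.layers[1].backward_layer.states
print(stateful_bilstm.layers[1].backward_layer.name)
print(len(states))
for state in states:
    if state is not None:
        print(state.shape, state)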



Model: "functional_75"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_43 (InputLayer)        [(2, None, 1)]            0         
_________________________________________________________________
bidirectional_35 (Bidirectio (2, 10)                   280       
Total params: 280
Trainable params: 280
Non-trainable params: 0
_________________________________________________________________
stateful_rnn
2
--------------------------------------------------
backward_stateful_rnn
2
(2, 5) <tf.Variable 'bidirectional_35/backward_stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)>
(2, 5) <tf.Variable 'bidirectional_35/backward_stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)>


### Inspect the RNN states
We can inspect the RNN layer states by retrieving the recurrent layer from each model, and looking at the `states` property.

In [25]:
# Retrieve the RNN layer and inspect the internal state

gru.get_layer('rnn').states

[None]

In [26]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states

[<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=float32)>]

Note that the stateful RNN stores a separate state for each element in a batch, which is why the shape of the state Variable is `(2, 5)`: 2 batch elements, each with a 5-unit state.
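The per-element layout can be illustrated with one step of a plain tanh RNN in NumPy. This is a simplified stand-in (a GRU adds gates), and the weights are random placeholders; the point is only that the state has one row per sequence and that row i depends on sequence i alone.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, units, features = 2, 5, 1
W_x = rng.normal(scale=0.5, size=(features, units))   # placeholder input weights
W_h = rng.normal(scale=0.5, size=(units, units))      # placeholder recurrent weights

def step(x_t, h):
    """One tanh-RNN step: x_t is (batch, features), h is (batch, units)."""
    return np.tanh(x_t @ W_x + h @ W_h)

x_t = np.array([[-4.0], [-40.0]])            # first time step of both sequences
h = step(x_t, np.zeros((batch_size, units)))
print(h.shape)                               # (2, 5): one state row per sequence

# Row 0 of the batched state equals the state computed for sequence 0 alone
h_row0 = step(x_t[:1], np.zeros((1, units)))
print(np.allclose(h[0], h_row0[0]))          # True
```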

### Create a simple sequence dataset
We will demonstrate the effect of statefulness on a simple sequence dataset consisting of two sequences.

In [27]:
# Create the sequence dataset

sequence_data = tf.constant([
    [[-4.], [-3.], [-2.], [-1.], [0.], [1.], [2.], [3.], [4.]],
    [[-40.], [-30.], [-20.], [-10.], [0.], [10.], [20.], [30.], [40.]]
], dtype=tf.float32)
sequence_data.shape

TensorShape([2, 9, 1])

### Process the sequence batch with both models

Now see what happens when you pass the batch of sequences through both models:

In [28]:
# Process the batch with both models

_1 = gru(sequence_data)
_2 = stateful_gru(sequence_data)

In [29]:
# Retrieve the RNN layer and inspect the internal state

gru.get_layer('rnn').states

[None]

In [30]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states

[<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[ 0.56634927, -0.5143654 ,  0.11323077, -0.3098507 ,  0.38285917],
        [-0.09284443,  0.45850012, -0.21524496,  0.42768195, -0.628173  ]],
       dtype=float32)>]

The stateful RNN model has updated and retained its state after having processed the input sequence batch. This internal state could then be used as the initial state for processing a continuation of both sequences in the next batch.
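Why the retained state matters can be sketched in NumPy with a simple tanh RNN (an assumed stand-in for the GRU, with random placeholder weights): processing the continuation from the carried-over state gives a different result from restarting at zeros, which is what a non-stateful layer would do.

```python
import numpy as np

rng = np.random.default_rng(1)
W_x = rng.normal(scale=0.5, size=(1, 5))   # placeholder tanh-RNN weights
W_h = rng.normal(scale=0.5, size=(5, 5))

def final_state(x, h0):
    """Run a simple tanh RNN over x (batch, T, features) starting from h0."""
    h = h0
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    return h

first = np.array([[[-4.], [-3.], [-2.]], [[-40.], [-30.], [-20.]]])
nxt = np.array([[[-1.], [0.], [1.]], [[-10.], [0.], [10.]]])
zeros = np.zeros((2, 5))

carried = final_state(first, zeros)          # state after the first batch
continued = final_state(nxt, carried)        # stateful: start from the carried state
restarted = final_state(nxt, zeros)          # non-stateful: start from zeros
print(np.allclose(continued, restarted))     # False
```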

### Resetting the internal state
If you need a stateful RNN to forget (or re-initialise) its state, you can call the RNN layer's `reset_states()` method. Calling `reset_states()` on the model instead resets every stateful layer in it.

In [31]:
# Reset the internal state of the stateful RNN model

stateful_gru.get_layer('stateful_rnn').reset_states()

In [32]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states

[<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=float32)>]

Note that `reset_states()` resets the state to `0.`, which is the default initial state for the RNN layers in TensorFlow.
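As a rough mental model, a stateful layer behaves like the hypothetical NumPy class below: it keeps a `(batch_size, units)` state between calls and `reset_states()` sets it back to zeros. The class name, the tanh cell and the random weights are all assumptions for illustration; this is not how Keras implements GRU.

```python
import numpy as np

class StatefulTanhRNN:
    """Hypothetical NumPy stand-in for a stateful RNN layer (tanh cell, not GRU).
    Keeps a (batch_size, units) state between calls, like stateful=True."""

    def __init__(self, batch_size, units, features, seed=0):
        rng = np.random.default_rng(seed)
        self.W_x = rng.normal(scale=0.5, size=(features, units))
        self.W_h = rng.normal(scale=0.5, size=(units, units))
        self.states = np.zeros((batch_size, units))

    def __call__(self, x):
        """Process x of shape (batch_size, T, features); return the final state."""
        for t in range(x.shape[1]):
            self.states = np.tanh(x[:, t, :] @ self.W_x + self.states @ self.W_h)
        return self.states

    def reset_states(self):
        """Re-initialise the state to zeros, mirroring the Keras default."""
        self.states = np.zeros_like(self.states)

rnn = StatefulTanhRNN(batch_size=2, units=5, features=1)
rnn(np.array([[[-4.], [-3.]], [[-40.], [-30.]]]))
print(np.any(rnn.states != 0))   # True: the state was updated
rnn.reset_states()
print(np.all(rnn.states == 0))   # True: back to zeros
```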

### Retaining internal state across batches
Passing a sequence to a stateful layer as several subsequences produces the same final output as passing the whole sequence at once.

In [33]:
# Reset the internal state of the stateful RNN model and process the full sequences

stateful_gru.get_layer('stateful_rnn').reset_states()
_ = stateful_gru(sequence_data)
stateful_gru.get_layer('stateful_rnn').states

[<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[ 0.56634927, -0.5143654 ,  0.11323077, -0.3098507 ,  0.38285917],
        [-0.09284443,  0.45850012, -0.21524496,  0.42768195, -0.628173  ]],
       dtype=float32)>]

In [19]:
# Break the sequences into batches

sequence_batch1 = sequence_data[:, :3, :]
sequence_batch2 = sequence_data[:, 3:6, :]
sequence_batch3 = sequence_data[:, 6:, :]

print("First batch:", sequence_batch1)
print("\nSecond batch:", sequence_batch2)
print("\nThird batch:", sequence_batch3)

First batch: tf.Tensor(
[[[ -4.]
  [ -3.]
  [ -2.]]

 [[-40.]
  [-30.]
  [-20.]]], shape=(2, 3, 1), dtype=float32)

Second batch: tf.Tensor(
[[[ -1.]
  [  0.]
  [  1.]]

 [[-10.]
  [  0.]
  [ 10.]]], shape=(2, 3, 1), dtype=float32)

Third batch: tf.Tensor(
[[[ 2.]
  [ 3.]
  [ 4.]]

 [[20.]
  [30.]
  [40.]]], shape=(2, 3, 1), dtype=float32)


Note that the first element of every batch comes from the first sequence, and the second element of every batch comes from the second sequence, so element i of each batch continues element i of the previous batch.

In [34]:
# Reset the internal state of the stateful RNN model and process the batches in order

stateful_gru.get_layer('stateful_rnn').reset_states()
_ = stateful_gru(sequence_batch1)
_ = stateful_gru(sequence_batch2)
_ = stateful_gru(sequence_batch3)
stateful_gru.get_layer('stateful_rnn').states

[<tf.Variable 'stateful_rnn/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[ 0.56634927, -0.5143654 ,  0.11323077, -0.3098507 ,  0.38285917],
        [-0.09284443,  0.45850012, -0.21524496,  0.42768195, -0.628173  ]],
       dtype=float32)>]

Notice that the internal state of the stateful RNN after processing all three batches in order is the same as it was earlier when we processed the entire sequences at once.

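This property can be exploited when training stateful RNNs, provided that each example in a batch continues the same sequence as the corresponding example in the previous batch, and that the batches are fed in order (for example, with `shuffle=False` in `model.fit`, which shuffles by default).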
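The same equivalence can be checked without TensorFlow. The NumPy sketch below uses a simple tanh RNN with random placeholder weights (an assumption, not the GRU above) and the same two sequences and chunk boundaries as `sequence_batch1` to `sequence_batch3`: carrying the state from chunk to chunk reproduces the final state of a single pass over the whole sequences.

```python
import numpy as np

def run_rnn(x, h0, W_x, W_h):
    """Simple tanh RNN over x (batch, T, features) from initial state h0; returns final state."""
    h = h0
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    return h

rng = np.random.default_rng(42)
W_x = rng.normal(scale=0.5, size=(1, 5))
W_h = rng.normal(scale=0.5, size=(5, 5))

# Same two sequences as sequence_data above, as a NumPy array of shape (2, 9, 1)
seq = np.stack([np.arange(-4, 5), np.arange(-40, 41, 10)]).astype(float)[:, :, None]
h0 = np.zeros((2, 5))

h_whole = run_rnn(seq, h0, W_x, W_h)

h = h0
for chunk in (seq[:, :3], seq[:, 3:6], seq[:, 6:]):
    h = run_rnn(chunk, h, W_x, W_h)    # carry the state into the next chunk

print(np.allclose(h_whole, h))         # True
```

The chunked loop performs exactly the same sequence of updates as the single pass, which is why the states agree.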

## Further reading and resources
* https://www.tensorflow.org/guide/keras/rnn#cross-batch_statefulness

In [8]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
import numpy as np
import math

data_dim = 16
timesteps = 8
num_classes = 10
batch_size = 32
num_epochs = 50

# 期望输入数据尺寸: (batch_size, timesteps, data_dim)
# 请注意，我们必须提供完整的 batch_input_shape，因为网络是有状态的。
# 第 k 批数据的第 i 个样本是第 k-1 批数据的第 i 个样本的后续。
model = Sequential()
model.add(LSTM(32, stateful=True,
               batch_input_shape=(batch_size, timesteps, data_dim)))
# model.add(LSTM(32, return_sequences=True, stateful=True))
# model.add(LSTM(32, stateful=True))
model.add(Dense(10, activation='softmax'))
model.summary()

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# 生成虚拟训练数据
x_train = np.random.random((batch_size * 10, timesteps, data_dim))      # (320, 8, 16)
y_train = np.random.random((batch_size * 10, num_classes))              # (320, 10)

# 生成虚拟验证数据
x_val = np.random.random((batch_size * 3, timesteps, data_dim))     # (96, 8, 16)
y_val = np.random.random((batch_size * 3, num_classes))

for i in range(num_epochs):
    print("Epoch {:d}/{:d}".format(i+1, num_epochs))
    model.fit(x_train, y_train, batch_size=batch_size, epochs=1, validation_data=(x_val, y_val), shuffle=False)
    model.reset_states()

score, _ = model.evaluate(x_val, y_val, batch_size=batch_size)      # 返回误差值和度量值
rmse = math.sqrt(score)
print("\nMSE: {:.3f}, RMSE: {:.3f}".format(score, rmse))

pre = model.predict(x_val, batch_size=batch_size)

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_9 (LSTM)                (32, 32)                  6272      
_________________________________________________________________
dense_3 (Dense)              (32, 10)                  330       
Total params: 6,602
Trainable params: 6,602
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 4

In [68]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

model = Sequential()
model.add(LSTM(32, batch_input_shape=(4, 10, 16)))
# model.add(LSTM(32, return_sequences=True, stateful=True))
# model.add(LSTM(32, stateful=True))
model.add(Dense(10, activation='softmax'))
model.summary()

Model: "sequential_16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (4, 32)                   6272      
_________________________________________________________________
dense (Dense)                (4, 10)                   330       
Total params: 6,602
Trainable params: 6,602
Non-trainable params: 0
_________________________________________________________________
