### Stateful RNN

Until now, we have only used **stateless** RNNs: at each training iteration the
model starts with a hidden state full of zeros, then it updates this state at each
time step, and after the last time step, it throws it away as it is not needed
anymore. What if we instructed the RNN to preserve this final state after
processing a training batch and use it as the initial state for the next training
batch? This way the model could learn long-term patterns despite only
backpropagating through short sequences. This is called a **stateful** RNN.
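
To make the idea concrete, here is a minimal sketch that uses a toy scalar recurrence in plain Python rather than a Keras layer. The names `rnn_step` and `run` and the weights 0.5 and 0.8 are made up for illustration. It suggests that processing a sequence in two chunks while carrying the final state forward reaches the same final state as processing the whole sequence at once, whereas restarting from zero does not:

```python
import math


def rnn_step(state, x, w_x=0.5, w_h=0.8):
    """One step of a toy scalar RNN: new_state = tanh(w_x * x + w_h * state)."""
    return math.tanh(w_x * x + w_h * state)


def run(sequence, initial_state=0.0):
    """Process a sequence step by step and return the final state."""
    state = initial_state
    for x in sequence:
        state = rnn_step(state, x)
    return state


sequence = [0.1, -0.4, 0.7, 0.2, -0.9, 0.3]
first_half, second_half = sequence[:3], sequence[3:]

full = run(sequence)
# Stateful: the second chunk starts from the first chunk's final state.
stateful = run(second_half, initial_state=run(first_half))
# Stateless: the second chunk starts from zero, so the first half is forgotten.
stateless = run(second_half)

print(f"full={full:.6f} stateful={stateful:.6f} stateless={stateless:.6f}")
```

Only the forward state is carried over in this sketch, which mirrors the point above: the model can still see the effect of earlier text even though each training step only backpropagates through one short sequence.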

* First, note that a stateful RNN only makes sense if each input sequence in a
batch starts exactly where the corresponding sequence in the previous batch
left off. So the first thing we need to do to build a stateful RNN is to use
sequential and **nonoverlapping** input sequences (rather than the shuffled and
overlapping sequences we used to train stateless RNNs). When creating the
`tf.data.Dataset`, we must therefore use `shift=length` (instead of `shift=1`) when
calling the `window()` method. Moreover, we must **not** call the `shuffle()`
method.

* Unfortunately, batching is much harder when preparing a dataset for a
stateful RNN than it is for a stateless RNN. Indeed, if we were to call
`batch(32)`, then 32 *consecutive* windows would be put in the *same* batch, and
the following batch would not continue each of these windows where it left
off. The first batch would contain windows 1 to 32 and the second batch
would contain windows 33 to 64, so if you consider, say, the first window of
each batch (i.e., windows 1 and 33), you can see that they are not
consecutive. The simplest solution to this problem is to just use a **batch size
of 1.**
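
The two points above can be sketched in plain Python, without `tf.data`. The helper `windows()` is hypothetical and only mimics the effect of `window(..., shift=..., drop_remainder=True)` on a small list; the batch size of 2 is chosen just to keep the output readable:

```python
def windows(sequence, length, shift):
    """Return every full window of `length` items, taken every `shift` items."""
    return [
        sequence[start:start + length]
        for start in range(0, len(sequence) - length + 1, shift)
    ]


tokens = list(range(13))
length = 3

# Each window holds length + 1 tokens: the inputs plus one extra token
# so that the targets can be the inputs shifted by one step.
overlapping = windows(tokens, length + 1, shift=1)
consecutive = windows(tokens, length + 1, shift=length)
print(consecutive)  # [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9, 10, 11, 12]]

# Naive batching: put 2 consecutive windows in each batch.
batches = [consecutive[i:i + 2] for i in range(0, len(consecutive), 2)]
first_rows = [batch[0] for batch in batches]
print(first_rows)  # [[0, 1, 2, 3], [6, 7, 8, 9]]
```

The first row of the second batch starts at token 6, but the first row of the first batch ended at token 3, so a state carried over from one batch to the next would not line up with the text.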

In [1]:
import pathlib
import tensorflow as tf

filepath = pathlib.Path("datasets") / "shakespeare.txt"
with open(filepath, "r") as f_:
    shakespear_txt = f_.read()
print(shakespear_txt[:80])
print()
print("Distinct characters:", "".join(sorted(set(shakespear_txt.lower()))))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

Distinct characters: 
 !$&',-.3:;?abcdefghijklmnopqrstuvwxyz


In [None]:
text_vec_layer = tf.keras.layers.TextVectorization(
    split="character", standardize="lower"
)
text_vec_layer.adapt([shakespear_txt])
encoded = text_vec_layer([shakespear_txt])[0]
encoded

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([21,  7, 10, ..., 22, 28, 12], dtype=int64)>

In [3]:
encoded -= 2
n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded)
print("\rDistinct w/o <PAD> and <UNK>: ", n_tokens)
print(f"Total dataset size: {dataset_size:_}")

Distinct w/o <PAD> and <UNK>:  39
Total dataset size: 1_115_394


In [None]:
def to_dataset_for_stateful_rnn(sequence, length):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=length, drop_remainder=True)  # shift = length
    ds = ds.flat_map(lambda window: window.batch(length + 1)).batch(1)  # batch size 1
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [None]:
length = 100
stateful_train_set = to_dataset_for_stateful_rnn(encoded[:1_000_000], length)
stateful_valid_set = to_dataset_for_stateful_rnn(encoded[1_000_000:1_060_000], length)
stateful_test_set = to_dataset_for_stateful_rnn(encoded[1_060_000:], length)

In [6]:
# simple example using to_dataset_for_stateful_rnn:
list(to_dataset_for_stateful_rnn(tf.range(10), 3))

[(<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[0, 1, 2]])>,
  <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[1, 2, 3]])>),
 (<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[3, 4, 5]])>,
  <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[4, 5, 6]])>),
 (<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[6, 7, 8]])>,
  <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[7, 8, 9]])>)]

Batching is harder, but it is not impossible. For example, we could chop
Shakespeare’s text into 32 texts of equal length, create one dataset of
consecutive input sequences for each of them, and finally use
`tf.data.Dataset.zip(datasets).map(lambda *windows: tf.stack(windows))` to
create proper consecutive batches, where the $n^{th}$ input sequence in a batch
starts off exactly where the $n^{th}$ input sequence ended in the previous batch.

If you'd like to have more than one window per batch, you can use the `to_batched_dataset_for_stateful_rnn()` function defined below instead of `to_dataset_for_stateful_rnn()`:

In [None]:
import numpy as np


def to_non_overlapping_windows(sequence, length):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=length, drop_remainder=True)
    return ds.flat_map(lambda window: window.batch(length + 1))


def to_batched_dataset_for_stateful_rnn(sequence, length, batch_size=32):
    parts = np.array_split(sequence, batch_size)
    datasets = tuple(to_non_overlapping_windows(part, length) for part in parts)
    ds = tf.data.Dataset.zip(datasets).map(lambda *windows: tf.stack(windows))
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [8]:
list(to_batched_dataset_for_stateful_rnn(tf.range(20), length=3, batch_size=2))

[(<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 0,  1,  2],
         [10, 11, 12]])>,
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 1,  2,  3],
         [11, 12, 13]])>),
 (<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 3,  4,  5],
         [13, 14, 15]])>,
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 4,  5,  6],
         [14, 15, 16]])>),
 (<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 6,  7,  8],
         [16, 17, 18]])>,
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[ 7,  8,  9],
         [17, 18, 19]])>)]

<hr>
Code breakdown:
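
The same steps can be sketched in plain Python, without `tf.data`, to see what each one contributes. The helper names `split_evenly`, `windows`, and `batched_windows` are hypothetical and exist only for this illustration; `split_evenly` is meant to mimic how `np.array_split()` gives the leftover items to the first parts:

```python
def split_evenly(sequence, n_parts):
    """Split into n_parts contiguous chunks; the first chunks get the leftover items."""
    size, extra = divmod(len(sequence), n_parts)
    parts, start = [], 0
    for i in range(n_parts):
        end = start + size + (1 if i < extra else 0)
        parts.append(sequence[start:end])
        start = end
    return parts


def windows(sequence, length):
    """Non-overlapping windows of length + 1 items, shifted by `length`."""
    return [
        sequence[start:start + length + 1]
        for start in range(0, len(sequence) - length, length)
    ]


def batched_windows(sequence, length, batch_size):
    """Build batches in which row n continues row n of the previous batch."""
    per_part = [windows(part, length) for part in split_evenly(sequence, batch_size)]
    batches = []
    # zip() stops at the shortest part, like tf.data.Dataset.zip().
    for rows in zip(*per_part):
        inputs = [row[:-1] for row in rows]
        targets = [row[1:] for row in rows]
        batches.append((inputs, targets))
    return batches


batches = batched_windows(list(range(20)), length=3, batch_size=2)
for inputs, targets in batches:
    print(inputs, targets)
```

Each part of the text becomes one row of the batch, so the state kept for row *n* always continues the same stretch of text from one batch to the next.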

<br>

Now, let’s create the stateful RNN. We need to set the `stateful` argument to
`True` when creating each recurrent layer. Because the stateful RNN needs
to know the batch size (since it will preserve a state for each input sequence
in the batch), we must also set the `batch_input_shape` argument in the
first layer. Note that we can leave the second dimension unspecified, since
the input sequences could have any length:

In [None]:
tf.random.set_seed(42)
tf.keras.backend.clear_session()

model = tf.keras.Sequential(
    [
        tf.keras.layers.Embedding(
            input_dim=n_tokens, output_dim=16, batch_input_shape=[1, None]
        ),  # This is with batch_size = 1
        tf.keras.layers.GRU(128, return_sequences=True, stateful=True),
        tf.keras.layers.Dense(n_tokens, activation="softmax"),
    ]
)
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (1, None, 16)             624       
                                                                 
 gru (GRU)                   (1, None, 128)            56064     
                                                                 
 dense (Dense)               (1, None, 39)             5031      
                                                                 
Total params: 61719 (241.09 KB)
Trainable params: 61719 (241.09 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
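
The parameter counts in the summary can be checked by hand. The sketch below assumes the GRU layer uses the Keras default `reset_after=True`, which keeps two bias vectors per set of weights; the variable names are only for illustration:

```python
n_tokens = 39  # vocabulary size, from the output above
embedding_size = 16
units = 128

# Embedding: one vector of size embedding_size per token.
embedding_params = n_tokens * embedding_size

# GRU: 3 sets of weights (update gate, reset gate, candidate state), each with
# an input kernel, a recurrent kernel, and (assuming reset_after=True) two biases.
gru_params = 3 * (embedding_size * units + units * units + 2 * units)

# Dense: one weight per (GRU unit, token) pair, plus one bias per token.
dense_params = units * n_tokens + n_tokens

total = embedding_params + gru_params + dense_params
print(embedding_params, gru_params, dense_params, total)  # 624 56064 5031 61719
```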


<hr>
Recall:

* The inputs of the Embedding layer will be 2D with shape: [*batch_size*, *window_length*]
* The output of the Embedding layer will be 3D with shape: [*batch_size*, *window_length*, *embedding_size*]
* The Dense layer should have *n_tokens* units (the same as the `input_dim` of the Embedding layer)
* We want to output a probability for each character, and these probabilities should sum to 1, so we use the softmax activation.
* The output of the model (that is, of the final Dense layer) should have shape: [1, *window_length*, *n_tokens*]
<hr>

At the end of each epoch, we need to *reset the states* before we go back to the
beginning of the text.

In [9]:
class ResetStatesCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        self.model.reset_states()
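
To see why the reset matters, here is a minimal sketch with a toy stateful recurrence in plain Python. The `ToyStatefulRNN` class and its weights are hypothetical and only imitate the way a layer created with `stateful=True` keeps its state between batches:

```python
import math


class ToyStatefulRNN:
    """A toy scalar RNN that keeps its state between calls."""

    def __init__(self):
        self.state = 0.0

    def reset_states(self):
        self.state = 0.0

    def process(self, batch):
        for x in batch:
            self.state = math.tanh(0.5 * x + 0.8 * self.state)
        return self.state


batches = [[0.1, -0.4, 0.7], [0.2, -0.9, 0.3]]  # consecutive chunks of one "text"

rnn = ToyStatefulRNN()
epoch_1 = [rnn.process(batch) for batch in batches]

# Without a reset, the first batch of the next pass starts from the state
# left over from the end of the text.
stale = [rnn.process(batch) for batch in batches]

# With a reset, the next pass starts from zero, exactly like the first epoch.
rnn.reset_states()
fresh = [rnn.process(batch) for batch in batches]

print(epoch_1 == fresh, epoch_1 == stale)  # True False
```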

Compile and train using the `ResetStatesCallback`:

In [None]:
from IPython.display import display, Markdown

code = """
model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam", metrics=["accuracy"])
history = model.fit(stateful_train_set, validation_data=stateful_valid_set,
                epochs=10, callbacks=[ResetStatesCallback()])
"""
display(Markdown("```python\n{}\n```".format(code)))

```python

model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam", metrics=["accuracy"])
history = model.fit(stateful_train_set, validation_data=stateful_valid_set,
                epochs=10, callbacks=[ResetStatesCallback()])

```

***Extra***: converting the stateful RNN to a stateless RNN and using it.

To use the model with different batch sizes, we need to create a stateless copy:

In [None]:
stateless_model = tf.keras.Sequential(
    [
        tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16),
        tf.keras.layers.GRU(128, return_sequences=True),
        tf.keras.layers.Dense(n_tokens, activation="softmax"),
    ]
)

To set the weights, we first need to build the model (so the weights get created):

In [12]:
stateless_model.build(tf.TensorShape([None, None]))

In [13]:
stateless_model.set_weights(model.get_weights())

In [None]:
shakespeare_model = tf.keras.Sequential(
    [text_vec_layer, tf.keras.layers.Lambda(lambda X: X - 2), stateless_model]
)