![Py4Eng](../logo.png)

# Gated  Recurrent Unit
## Yoav Ram

In this session we will expand over RNN with GRU.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import jax 
import jax.numpy as np
print('jax', jax.__version__, jax.default_backend())
import numpy as onp
import optax # pip install optax

from collections import Counter
from random import uniform
import pickle

jax 0.4.30 gpu


# Data

The data is just text data, in this case Shakespear's writing.

In [2]:
filename = '../data/shakespear.txt'
with open(filename, 'rt') as f:
    text = f.read()

print("Number of characters: {}".format(len(text)))
print("Number of unique characters: {}".format(len(set(text))))
print("Number of lines: {}".format(text.count('\n')))
print("Number of words: {}".format(text.count(' ')))
print()
print("Excerpt:")
print("*" * len("Excerpt:"))
print(text[:500])

Number of characters: 99993
Number of unique characters: 62
Number of lines: 3298
Number of words: 15893

Excerpt:
********
That, poor contempt, or claim'd thou slept so faithful,
I may contrive our father; and, in their defeated queen,
Her flesh broke me and puttance of expedition house,
And in that same that ever I lament this stomach,
And he, nor Butly and my fury, knowing everything
Grew daily ever, his great strength and thought
The bright buds of mine own.

BIONDELLO:
Marry, that it may not pray their patience.'

KING LEAR:
The instant common maid, as we may less be
a brave gentleman and joiner: he that finds u


We start by creating 
- a list `chars` of the unique characters
- `data_size` the number of total characters
- `vocab_size` the number of unique characters
- `int_to_char` a dictionary from index to char
- `char_to_int` a dictionary from char to index
and then we convert `data` from a string to a NumPy array of integers representing the chars.

In [3]:
chars = list(set(text))
data_size, vocab_size = len(text), len(chars)

# char to int and vice versa
int_to_char = dict(enumerate(chars)) #  == { i: ch for i,ch in enumerate(chars) }
char_to_int = dict(zip(int_to_char.values(), int_to_char.keys())) # { ch: i for i,ch in enumerate(chars) }

def onehot_encode(text):
    ints = [char_to_int[c] for c in text]
    ints = np.array(ints, dtype=int)
    return jax.nn.one_hot(ints, vocab_size)

def onehot_decode(data):
    ints = data.argmax(axis=1).tolist()
    chars = (int_to_char[k] for k in ints)
    return str.join('', chars)

X = onehot_encode(text)

# GRU model

The GRU extends RNN. It avoids the vanishing gradient problem for the vanilla RNN, and is more efficient than LSTM (long-short time memory).
To compute the update to the hidden memory layer $h_t$, it first computes a _reset gate_ $r_t$ and an update gate $z_t$ that are used to interpoltate between the candidate memory $\tilde h_t$ and the next $h_t$.

- $x_t$ is the $t$ character, one-hot encoded and a 1D array of length `vocab_size`.
- $h_t$ is the state of the hidden memory layer after seeing $t$ characters, encoded as a 1D array of numbers (neurons...)
- $r_t$ is the _reset gate_
- $z_t$ is the _update gate_
- $\tilde h_t$ is the candidate hidden memory
- $\widehat y_t$ is the prediction of the network after seeing $t$ characters, encoded as a 1D array of probabilities of length `vocab_size`
- $\sigma(x)$ is the sigmoid/logistic function
- $\circ$ is the Hadamard or element-wise product, $x \circ y = (x_1 y_1, \ldots x_n y_n)$.

The model is then written as:
$$
z_t = \sigma{\left(W_x^z x_t + W_h^z h_{t-1} + b_z\right)}
$$
$$
r_t = \sigma{\left(W_x^r x_t + W_h^r h_{t-1} + b_r\right)}
$$
$$
\tilde h_t = \tanh{\left(W_x^h x_t + W_h^h (r_t \circ h_{t-1}) + b_h\right)}
$$
$$
h_t = (1-z_t) \circ h_{t-1} + z_t \circ \tilde h_t
$$
$$
\hat y_t = \mathrm{softmax}\left(W_h^y h_t + b_y\right)
$$
$$
x_{t+1} \sim \mathrm{Cat}(\hat{y}_t)
$$

and we set $h_0 = (0, \ldots, 0)$.

This operation will be performed by our `step` function.

## Multi-layer GRU
We are going to layer multiple GRUs, so that the output of the first is the input of the second etc. This is done by the `layered_step` function.

The `feed_forward` function will loop over a sequence of $x=(x_1, x_2, \ldots, x_k)$ of some arbitray size - similar to batches in the FFN and CNN frameworks.

In [4]:
def step(params, x, h):
    Wxz, Whz, Wxr, Whr, Wxh, Whh, Why, bz, br, bh, by = params
    z = jax.nn.sigmoid(Wxz @ x + Whz @ h + bz)
    r = jax.nn.sigmoid(Wxr @ x + Whr @ h + br)
    tildeh = jax.nn.tanh(Wxh @ x + Whh @ (r * h) + bh)
    h = (1 - z) * h + z * tildeh
    yhat = jax.nn.softmax(Why @ h + by)
    return yhat, h

def layered_step(params, x, h):
    for i in range(len(params)):
        x, h[i] = step(params[i], x, h[i])
    return x, h

In [5]:
def feed_forward(params, x, h):
    yhat = np.zeros_like(x)
    
    for t in range(len(x)):
        yhat_t, h = layered_step(params, x[t], h)        
        yhat = yhat.at[t, :].set(yhat_t) # equivalent to NumPy's yhat[t, :] = yhat_t

    return yhat, h

In [6]:
def NLL(params, x, y, h):
    yhat, h = feed_forward(params, x, h)    
    loss = -(y * np.log(yhat)).sum()
    return loss, h

We initialize the network parameters so we can test `feed_forward`.

In [7]:
def init_params(key):
    subkeys = jax.random.split(key, 7)
    Wxr = jax.random.normal(subkeys[0], (h_size, vocab_size)) * 0.01 
    Whr = jax.random.normal(subkeys[1], (h_size, h_size)) * 0.01
    Wxz = jax.random.normal(subkeys[2], (h_size, vocab_size)) * 0.01 
    Whz = jax.random.normal(subkeys[3], (h_size, h_size)) * 0.01    
    Wxh = jax.random.normal(subkeys[4], (h_size, vocab_size)) * 0.01 
    Whh = jax.random.normal(subkeys[5], (h_size, h_size)) * 0.01
    Why = jax.random.normal(subkeys[6], (vocab_size, h_size)) * 0.01 
    bz = np.zeros(h_size,) 
    br = np.zeros(h_size,) 
    bh = np.zeros(h_size,) 
    by = np.zeros(vocab_size) 
    params = Wxz, Whz, Wxr, Whr, Wxh, Whh, Why, bz, br, bh, by
    return params

In [8]:
h_size = 100 # number of units in hidden layer
nlayers = 3
key = jax.random.key(412) # generate new key based on the seed "42"

init_keys = jax.random.split(key, nlayers) # one key per layer
params = [init_params(k) for k in init_keys] # init params per layer
h = [np.zeros(h_size) for _ in range(len(params))] # init hidden vector per layer

x, y = X[:25], X[1:26]
%timeit yhat, _ = feed_forward(params, x, h)
yhat, _ = feed_forward(params, x, h)
print(onehot_decode(yhat))
print(onehot_decode(y))
loss, h = NLL(params, x, y, h)
print(loss)

127 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
ooooooooooooooooooooooooo
hat, poor contempt, or cl
103.17666


## Back propagation by automatic differentiation

This works in the same way as it did with RNN.

In [9]:
backprop = jax.value_and_grad(NLL, has_aux=True)

(loss, h), grads = backprop(params, x, y, h)
for params_i, grads_i in zip(params, grads): # loop over layers
    for p, g in zip(params_i, grads_i): # loop over params of layer
        assert p.shape == g.shape
        assert not (g == 0).all()

# Adam optimizer with Optax

We can use a JAX implementation of the Adam optimizer from the [Optax](https://optax.readthedocs.io/) library.
We first create the optimizer and initialize its state.

In [10]:
optimizer = optax.adam(learning_rate=0.001) # 0.001 is the default from Kingma et al 2014
opt_state = optimizer.init(params)

We then use the optimizer to compute the updates, and apply them.

In [11]:
(loss, h), grads = backprop(params, x, y, h)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates) 

# JITing the training step

We write a function that does all this, and pass it to `jax.jit`, which [just-in-time compiles the function](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) so it can be executed efficiently in XLA.

In [12]:
@jax.jit 
def update_params(params, opt_state, x, y, h):
    (loss, h), grads = backprop(params, x, y, h)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, h, opt_state, loss

In [13]:
%timeit update_params(params, opt_state, x, y, h)

695 µs ± 145 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)

103.119
103.05258


# Sampling from the network

Finally, instead of a `predict` function, we have a `sample` function, which, given the parameters and the number of samples we want, produces a sample of text from the network.

It does so by drawing a random seed for $x_0$ and drawing $x_t$ for $t>0$ from the distribution given by $\widehat y_t$.

![](https://www.researchgate.net/profile/Aven-Zhou/publication/337006979/figure/fig3/AS:821430174380045@1572855623911/An-Illustration-of-the-Generating-Sequence-in-an-RNN.png)

In [15]:
def sample(params, num_samples, key):
    h = [np.zeros(h_size) for _ in range(len(params))]
    
    x = np.zeros((num_samples, vocab_size), dtype=float)
    key, subkey = jax.random.split(key)
    seed_char = jax.random.choice(subkey, vocab_size)
    x = x.at[0, seed_char].set(1)
    
    for t in range(1, num_samples):
        yhat, h = layered_step(params, x[t-1], h)
        # draw from output distribution
        key, subkey = jax.random.split(key)
        i = jax.random.choice(subkey, vocab_size, p=yhat)
        x = x.at[t, i].set(1)
    return onehot_decode(x)

print(sample(params, 100, jax.random.key(1)))

MjnrjaYnG,bAPAUqLR Y.RIFUsTnSfkcz
lV;C,hHJ: obR-SyEQpGheDZze.CUmlrK,ow;ZAcL.:kS.frRpHcdmKL-H:k
VWME'


# Training the network

We setup the training - the sequence length to unroll the network, the number of batches, parameter initialization, Adam optimizer.

In [53]:
seq_length = 25
max_batches = 10000000
h = np.zeros(h_size)
pos = 0
batch = 0 
losses = []
key = jax.random.key(28)
key, *init_keys = jax.random.split(key)
params = [init_params(k) for k in init_keys]
h = [np.zeros(h_size) for _ in range(len(params))]

optimizer = optax.adam(learning_rate=0.001) # you can try with 0.01
opt_state = optimizer.init(params)

Now we can train the GRU!

In [54]:
%%time
while batch <= max_batches:
    if pos + seq_length + 1 >= data_size:
        # reset data position and hidden state
        pos, h = 0, [np.zeros(h_size) for _ in range(len(params))]
        
    x = X[pos : pos + seq_length]
    y = X[pos + 1 : pos + seq_length + 1]
    pos += seq_length
        
    params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
    losses.append(loss)
    
    if batch % (max_batches // 10) == 0:
        print('batch {:d}, loss {:.6f}, pos {}'.format(batch, loss, pos))
        print()
        
        with open("../data/gru3-jax-params-{}.pkl".format(batch), 'wb') as file:
            pickle.dump(params, file)
        
        key, subkey = jax.random.split(key)        
        sample_text = sample(params, 200, subkey)
        print(sample_text)
        print('-'*80)
    batch += 1

batch 0, loss 103.179733, pos 25

hApMNxL:jTeJmQ!oXYxlJIfZeD.Lcleqz
D.QP'fAEEmBjEgmJaCtimu:GfgsFIvLyHQjmDWM'ozUGkd? ;k!H
bQQ:cyRVvECO,dzIls
JWUobZpCuIykaL:jdsn-ky?lkem'KUMKPW'MrkUILiMmrlVRG zOy-BftpZ'mrK'WCaoGUe.jTdu:Lw,T?KdneD.gg.af:
--------------------------------------------------------------------------------
batch 1000000, loss 26.493454, pos 6275

publet:
I'll see them, True, I am boy.

SfiLd:
I holl our lood but a worth to their marriegh lenscury Matcius. Not his love.

STEUS:
More done be dot longer without 'em if from it with kous?

AlTORAR:
--------------------------------------------------------------------------------
batch 2000000, loss 33.737244, pos 12525

R-Our all Philose nobles;
And be a reclain the so but go the prayer: and says Buthoth,
When becontly with her wormand.

OCTAVIUS Cages, but creems. Pildes of report had do you and the very fields asug
--------------------------------------------------------------------------------
batch 3000000, loss 15.674427, pos 18

# Load parameters from specific batch

In [56]:
batch = 6000000
with open("../data/gru3-jax-params-{}.pkl".format(batch), 'rb') as file:
    params = pickle.load(file)

In [57]:
print(sample(params, 500, jax.random.key(12)))

man's
Montles, and I frower, if you wouldsple the mind of all let him to-mnom.

NORFON:
Then be degided the bruble every issemboos o'
rag,
Not came of our age.

Hostest withse.

BERTRY:
As erame without my lives of sister betwent your face
To-night this your foot a furge of your mourer sol
Did amen my delised we do.

ANGELO:
Sir any fair suind;
For: the sliem. This thren charge in grace I saw the of most heels.

BOTTOF,
Rement things in know?

IAGO:
A pleasure and even me nobly right too and lea


# References

- Andrej Karpathy's [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) blogpost
- Cho et al. 2014. [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078). arXiv:1406.1078

# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com).

This work is licensed under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)