![Py4Eng](../logo.png)

# Recurrent Neural Networks
## Yoav Ram

In this session we will understand:
- what recurrent neural networks are and how they work, and
- how memory and state can be implemented in neural networks



In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import jax 
import jax.numpy as np
print('jax', jax.__version__, jax.default_backend())
import optax # pip install optax

from collections import Counter
from random import uniform

jax 0.4.35 cpu


# Data

In developing this RNN we will follow [Andrej Karpathy](http://cs.stanford.edu/people/karpathy/)'s [blogpost about RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness) ([original code gist](https://gist.github.com/karpathy/d4dee566867f8291f086) with BSD License).

The data is just text, in this case Shakespeare's writing.

In [2]:
filename = '../data/shakespear.txt'
with open(filename, 'rt') as f:
    text = f.read()

print("Number of characters: {}".format(len(text)))
print("Number of unique characters: {}".format(len(set(text))))
print("Number of lines: {}".format(text.count('\n')))
print("Number of words: {}".format(text.count(' ')))
print()
print("Excerpt:")
print("*" * len("Excerpt:"))
print(text[:500])

Number of characters: 99993
Number of unique characters: 62
Number of lines: 3298
Number of words: 15893

Excerpt:
********
That, poor contempt, or claim'd thou slept so faithful,
I may contrive our father; and, in their defeated queen,
Her flesh broke me and puttance of expedition house,
And in that same that ever I lament this stomach,
And he, nor Butly and my fury, knowing everything
Grew daily ever, his great strength and thought
The bright buds of mine own.

BIONDELLO:
Marry, that it may not pray their patience.'

KING LEAR:
The instant common maid, as we may less be
a brave gentleman and joiner: he that finds u


We start by creating:
- a list `chars` of the unique characters
- `data_size`, the total number of characters
- `vocab_size`, the number of unique characters
- `int_to_char`, a dictionary from index to char
- `char_to_int`, a dictionary from char to index

We then convert `text` from a string to a one-hot encoded JAX array `X`, with one row per character.

In [3]:
chars = list(set(text))
data_size, vocab_size = len(text), len(chars)

# char to int and vice versa
int_to_char = dict(enumerate(chars)) #  == { i: ch for i,ch in enumerate(chars) }
char_to_int = dict(zip(int_to_char.values(), int_to_char.keys())) # { ch: i for i,ch in enumerate(chars) }

def onehot_encode(text):
    ints = [char_to_int[c] for c in text]
    ints = np.array(ints, dtype=int)
    return jax.nn.one_hot(ints, vocab_size)

def onehot_decode(data):
    ints = data.argmax(axis=1).tolist()
    chars = (int_to_char[k] for k in ints)
    return str.join('', chars)

X = onehot_encode(text)
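
As a quick sanity check of the encoding, here is a minimal sketch of the same round trip on a toy string. The names `toy_text`, `toy_chars`, `toy_char_to_int`, and `toy_int_to_char` are illustrative and not part of the notebook. The sketch also sorts the characters, an assumption made here so that each character gets the same index on every run (the order of `list(set(text))` can differ between Python sessions).

```python
import jax
import jax.numpy as np

toy_text = "hello"                   # hypothetical toy input
toy_chars = sorted(set(toy_text))    # sorted for a reproducible order
toy_char_to_int = {ch: i for i, ch in enumerate(toy_chars)}
toy_int_to_char = {i: ch for i, ch in enumerate(toy_chars)}

# encode: one row per character, one column per vocabulary entry
ints = np.array([toy_char_to_int[c] for c in toy_text])
toy_X = jax.nn.one_hot(ints, len(toy_chars))

# decode: the position of the 1 in each row is the character index
decoded = "".join(toy_int_to_char[int(k)] for k in toy_X.argmax(axis=1))
print(toy_X.shape, decoded)
```

Each row of `toy_X` contains a single 1, so decoding with `argmax` should recover the original string.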

# RNN model

- $x_t$ is the $t$-th character, one-hot encoded as a 1D array of length `vocab_size`.
- $h_t$ is the state of the hidden memory layer after seeing $t$ characters, encoded as a 1D array of numbers (neurons...)
- $\widehat y_t$ is the prediction of the network after seeing $t$ characters, encoded as a 1D array of probabilities of length `vocab_size`

The model is then written as:

$$
h_t = \tanh{\left(W_x^h x_t + W_h^h h_{t-1} + b_h\right)}
$$
$$
\hat y_t = \mathrm{softmax}\left(W_h^y h_t + b_y\right)
$$
$$
x_{t+1} \sim \mathrm{Cat}(\hat{y}_t)
$$


and we set $h_0 = (0, \ldots, 0)$.

This operation will be performed by our `step` function.

The `feed_forward` function will loop over a sequence $x=(x_1, x_2, \ldots, x_k)$ of some arbitrary length $k$, similar to batches in the FFN and CNN frameworks.

In [5]:
def step(params, x, h):
    Wxh, Whh, Why, bh, by = params
    h = np.tanh(Wxh @ x + Whh @ h + bh)        
    yhat = jax.nn.softmax(Why @ h + by)
    return yhat, h

def feed_forward(params, x, h):
    yhat = np.zeros_like(x)
    
    for t in range(len(x)):
        yhat_t, h = step(params, x[t], h)        
        yhat = yhat.at[t, :].set(yhat_t) # equivalent to NumPy's yhat[t, :] = yhat_t

    return yhat, h
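
The Python loop in `feed_forward` is unrolled step by step when JAX traces it. One possible alternative, not used in this notebook, is `jax.lax.scan`, which expresses the same recurrence as a single primitive. The sketch below assumes the same `step` function as above and uses toy sizes (`vocab`, `hidden`, `T`) and a hypothetical name `feed_forward_scan`, all chosen only for illustration.

```python
import jax
import jax.numpy as np

def step(params, x, h):
    Wxh, Whh, Why, bh, by = params
    h = np.tanh(Wxh @ x + Whh @ h + bh)
    yhat = jax.nn.softmax(Why @ h + by)
    return yhat, h

def feed_forward_scan(params, x, h):
    # scan expects body(carry, input) -> (new_carry, output);
    # here the carry is the hidden state and the output is yhat_t
    def body(h, x_t):
        yhat_t, h = step(params, x_t, h)
        return h, yhat_t
    h, yhat = jax.lax.scan(body, h, x)
    return yhat, h

# toy sizes, for illustration only
vocab, hidden, T = 5, 8, 6
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
toy_params = (
    jax.random.normal(k1, (hidden, vocab)) * 0.01,
    jax.random.normal(k2, (hidden, hidden)) * 0.01,
    jax.random.normal(k3, (vocab, hidden)) * 0.01,
    np.zeros(hidden),
    np.zeros(vocab),
)
toy_x = jax.nn.one_hot(jax.random.randint(k4, (T,), 0, vocab), vocab)
toy_yhat, toy_h = feed_forward_scan(toy_params, toy_x, np.zeros(hidden))
print(toy_yhat.shape, toy_h.shape)
```

With a scan, compile time should depend less on the sequence length, though for short sequences the difference is likely small.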

We initialize the network parameters so we can test `feed_forward`.

In [6]:
h_size = 100 # number of units in hidden layer

def init_params(key):
    subkeys = jax.random.split(key, 3)
    Wxh = jax.random.normal(subkeys[0], (h_size, vocab_size)) * 0.01 
    Whh = jax.random.normal(subkeys[1], (h_size, h_size)) * 0.01
    Why = jax.random.normal(subkeys[2], (vocab_size, h_size)) * 0.01 
    bh = np.zeros(h_size,) # hidden layer bias
    by = np.zeros(vocab_size) # readout layer bias
    params = Wxh, Whh, Why, bh, by
    return params

In [8]:
key = jax.random.key(420) # generate a new key from the seed 420
params = init_params(key)

x, y = X[:25], X[1:26]
h = np.zeros(h_size)
%timeit yhat, _ = feed_forward(params, x, h)
yhat, _ = feed_forward(params, x, h)
print(onehot_decode(yhat))
print(onehot_decode(y))

42.6 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
odtvudbXQGo.X;vR,bvudQGo.
hat, poor contempt, or cl


### Back propagation

Back propagation works, as before, using the chain rule. 
It is similar to the [FFN example](FFN.ipynb), except that the $h$ layer adds a bit of complexity, but not much.

The details of the gradient calculation can be found in Stanford's ["Convolutional Neural Networks for Visual Recognition" course](http://cs231n.github.io/neural-networks-case-study/#grad).

What's important to discuss is that instead of back propagating a single step of the network $t$, we back propagate over a sequence of steps, that is over $x=(x_1, \ldots, x_k)$ for some arbitrary $k$.

How? By "unrolling" the network.

![](https://www.researchgate.net/profile/Hamid-Rabiee/publication/341956650/figure/fig1/AS:11431281078694336@1660200861643/An-RNN-unrolled-through-the-time-The-same-structure-is-repeated-at-adjacent-time-steps.ppm)

For example, for a sequence of length 3, the input is $x=[x_1, x_2, x_3]$, and we can write

$$
h_1 = \tanh{\big(W_x^h x_1 + W_h^h h_0 + b_h\big)}
$$
$$
\hat{y}_1 = \mathrm{softmax}\big(W_h^y h_1 + b_y\big)
$$
$$
h_2 = \tanh{\big(W_x^h x_2 + W_h^h h_1 + b_h\big)}
$$
$$
\hat{y}_2 = \mathrm{softmax}\big(W_h^y h_2 + b_y\big)
$$
$$
h_3 = \tanh{\big(W_x^h x_3 + W_h^h h_2 + b_h\big)}
$$
$$
\hat{y}_3 = \mathrm{softmax}\big(W_h^y h_3 + b_y\big)
$$

The NLL or loss (negative log likelihood, cross entropy) for a single step is
$$
NLL = -\sum_{k=0}^{61}{x_{t+1,k} \log{\hat{y}_{t,k}}}
$$
where $k$ runs over all characters from 0 to 61 (for 62 characters).
After unrolling the network, the NLL is computed by summing over all $\hat{y}_t$ together, 
$$
NLL = -\sum_{t=1}^{3}{\sum_{k=0}^{61}{x_{t+1,k} \log{\hat{y}_{t,k}}}}
$$
The gradients are computed for this NLL with respect to the $W$ and $b$ parameters.
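
To make the formula concrete, here is a small worked example with made-up numbers: two time steps over a three-character vocabulary. The arrays `toy_yhat` and `toy_y` are hypothetical and are not produced by the network.

```python
import jax.numpy as np

# hypothetical predictions for 2 time steps over a 3-character vocabulary
toy_yhat = np.array([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]])
# one-hot targets: character 0 at the first step, character 1 at the second
toy_y = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

toy_nll = -(toy_y * np.log(toy_yhat)).sum()

# because the targets are one-hot, only the probabilities assigned
# to the true characters contribute to the sum
same = -(np.log(0.7) + np.log(0.8))
print(toy_nll, same)
```

Both expressions give about 0.58. The same reasoning offers a sanity check for an untrained network: if all 62 characters are predicted with roughly equal probability, a sequence of 25 characters has a loss of about $25 \log 62 \approx 103.18$.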

In [13]:
def NLL(params, x, y, h):
    yhat, h = feed_forward(params, x, h)    
    loss = -(y * np.log(yhat)).sum()
    return loss, h

loss, h = NLL(params, x, y, h)
print(loss)

103.16748


## Automatic differentiation with JAX
Now instead of manually deriving the gradient and implementing it as a Python program, we use JAX's automatic differentiation. [`jax.grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#jax-first-transformation-grad) takes a function `f(a, b, c)` and returns a function `dfda(a, b, c)` that returns the gradient of `f` with respect to `a` at the values of `a`, `b`, and `c`. It does so by automating the procedure we did manually using the chain rule.

In our case, `f` is `NLL`, `a` is `params`, and the remaining arguments are `x`, `y`, and `h`; that is, we use `grad` on `NLL(params, x, y, h)` to get `backprop(params, x, y, h)`.

The function [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#value-and-grad) is used to return both `f(a,b,c)` (the "value") and the `dfda` (the "grad"). 
Finally, `has_aux=True` means that `f` returns two values: the value to be differentiated and an auxiliary value. In our case, the value to differentiate is `loss` and the auxiliary value is `h`. This is important because we need to keep track of both `h` and `loss`.
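
The return structure of `jax.value_and_grad` with `has_aux=True` can be easier to see on a toy function. In this sketch, `toy_loss` and its arguments are hypothetical and unrelated to the RNN; the function returns a scalar to differentiate together with an auxiliary array.

```python
import jax
import jax.numpy as np

def toy_loss(w, x):
    pred = w * x
    loss = (pred ** 2).sum()   # the value to differentiate
    return loss, pred          # pred is passed through as the auxiliary value

toy_backprop = jax.value_and_grad(toy_loss, has_aux=True)

# the result is nested as ((value, aux), grad);
# the gradient is taken with respect to the first argument, w
(toy_value, toy_aux), toy_grad = toy_backprop(3.0, np.array([1.0, 2.0]))
print(toy_value, toy_aux, toy_grad)
```

Here the loss is $w^2 (1^2 + 2^2) = 45$ at $w = 3$, and its derivative with respect to $w$ is $2 w \cdot 5 = 30$. The auxiliary value is returned as is and is not differentiated.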

In [14]:
backprop = jax.value_and_grad(NLL, has_aux=True)

(loss, h), grads = backprop(params, x, y, h)
for p, g in zip(params, grads):
    assert p.shape == g.shape
    assert not (g == 0).all()

# Adam optimizer with Optax

We can use a JAX implementation of the Adam optimizer from the [Optax](https://optax.readthedocs.io/) library.
We first create the optimizer and initialize its state.

In [15]:
optimizer = optax.adam(learning_rate=0.001) # 0.001 is the default from Kingma & Ba (2014)
opt_state = optimizer.init(params)

We then use the optimizer to compute the updates, and apply them.

In [16]:
(loss, h), grads = backprop(params, x, y, h)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates) 

# JITing the training step

We write a function that does all this, and pass it to `jax.jit`, which [just-in-time compiles the function](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) so it can be executed efficiently in XLA.

In [11]:
@jax.jit # decreases runtime from 380 ms to <1 ms!
def update_params(params, opt_state, x, y, h):
    (loss, h), grads = backprop(params, x, y, h)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, h, opt_state, loss

In [16]:
%timeit update_params(params, opt_state, x, y, h)

569 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
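
As a minimal sketch of what `jax.jit` does, the toy function below (hypothetical, not part of the notebook) runs a short Python loop that is unrolled when the function is traced. The compiled version should return the same value as the eager one. When timing JAX code, calling `.block_until_ready()` on the result is the usual way to make sure the computation has finished, because JAX dispatches work asynchronously.

```python
import jax
import jax.numpy as np

def toy_unrolled(w, h):
    for _ in range(25):          # Python loop, unrolled at trace time
        h = np.tanh(w @ h)
    return h.sum()

toy_fast = jax.jit(toy_unrolled)

w = np.eye(4) * 0.5
h0 = np.ones(4)

eager = toy_unrolled(w, h0)
# the first call traces and compiles; later calls reuse the compiled code
compiled = toy_fast(w, h0).block_until_ready()
print(eager, compiled)
```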


In [17]:
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)

103.08142
102.99289


# Sampling from the network

Finally, instead of a `predict` function, we have a `sample` function, which, given the parameters and the number of samples we want, produces a sample of text from the network.

It does so by drawing a random seed character for $x_0$ and then drawing each $x_t$ for $t>0$ from the distribution given by the previous prediction $\widehat y_{t-1}$.

![](https://www.researchgate.net/profile/Aven-Zhou/publication/337006979/figure/fig3/AS:821430174380045@1572855623911/An-Illustration-of-the-Generating-Sequence-in-an-RNN.png)

In [18]:
def sample(params, num_samples, key):
    h = np.zeros(h_size)
    
    x = np.zeros((num_samples, vocab_size), dtype=float)
    key, subkey = jax.random.split(key)
    seed_char = jax.random.choice(subkey, vocab_size)
    x = x.at[0, seed_char].set(1)
    
    for t in range(1, num_samples):
        yhat, h = step(params, x[t-1], h)
        # draw from output distribution
        key, subkey = jax.random.split(key)
        i = jax.random.choice(subkey, vocab_size, p=yhat)
        x = x.at[t, i].set(1)
    return onehot_decode(x)

print(sample(params, 100, jax.random.key(1)))

'QRiQjqr
HhvSvEJk!fq !-UEKIrGLbadVoO;:HWlZFfyh!cGtPRz
WCJDdC :EBoiYHyM;Dvak FDG Li!zlaeBYkclFbVOm'PN
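
The drawing step can be checked in isolation. In this sketch, `toy_probs` is a made-up output distribution over three characters; many draws from it should have frequencies close to the probabilities, and reusing the same subkey should reproduce the same draw, which is why `sample` splits the key before every draw.

```python
import jax
import jax.numpy as np

toy_probs = np.array([0.1, 0.6, 0.3])   # hypothetical output distribution
key = jax.random.key(0)
key, subkey = jax.random.split(key)

# many independent draws at once; frequencies should approach toy_probs
draws = jax.random.choice(subkey, 3, shape=(2000,), p=toy_probs)
freqs = np.bincount(draws, length=3) / draws.size
print(freqs)

# the same subkey always gives the same draw
first = jax.random.choice(subkey, 3, p=toy_probs)
again = jax.random.choice(subkey, 3, p=toy_probs)
print(first, again)
```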


# Training the network

We set up the training: the sequence length over which to unroll the network, the number of batches, parameter initialization, and the Adam optimizer.

In [22]:
seq_length = 25
max_batches = 1000000
h = np.zeros(h_size)
pos = 0
batch = 0 
losses = []
key = jax.random.key(8)

key, subkey = jax.random.split(key)
params = init_params(subkey)

optimizer = optax.adam(learning_rate=0.001) # you can try with 0.01
opt_state = optimizer.init(params)
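
Training reads the text in consecutive windows of `seq_length` characters, where the target window is the input window shifted by one character. The sketch below shows this slicing on a toy string; `toy_text` and `toy_seq_length` are hypothetical, and unlike the training loop, which wraps around to the start of the data, this version simply stops at the end.

```python
toy_text = "abcdefghij"      # hypothetical data
toy_seq_length = 3
pos = 0
windows = []

# stop when the shifted target window would run past the end of the text
while pos + toy_seq_length + 1 <= len(toy_text):
    x_win = toy_text[pos : pos + toy_seq_length]
    y_win = toy_text[pos + 1 : pos + toy_seq_length + 1]   # shifted by one
    windows.append((x_win, y_win))
    pos += toy_seq_length

print(windows)
# [('abc', 'bcd'), ('def', 'efg'), ('ghi', 'hij')]
```

Because consecutive windows are adjacent in the text, the hidden state `h` returned for one window is a meaningful starting state for the next one.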

Now we can train the RNN!

In [16]:
%%time
while batch <= max_batches:
    if pos + seq_length + 1 >= data_size:
        # reset data position and hidden state
        pos, h = 0, np.zeros(h_size) 
        
    x = X[pos : pos + seq_length]
    y = X[pos + 1 : pos + seq_length + 1]
    pos += seq_length
        
    params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
    losses.append(loss)
    
    if batch % (max_batches // 10) == 0:
        print('batch {:d}, loss {:.6f}, pos {}'.format(batch, loss, pos))
        print()
        
        key, subkey = jax.random.split(key)        
        sample_text = sample(params, 200, subkey)
        print(sample_text)
        print('-'*80)
    batch += 1

batch 0, loss 103.177963, pos 25

 l;HJCdR
rr.BzkNsLdtu.xcABpQP'yOtBTB:nknaC Irzl
'IWX yJM-ZTVRQHtjBLkeyu;kbcObr-.P?ztjKpED gxjHiA
;tXUVFb:fbzcu

Y,ysuPDYTNdq ?m!
WGbFHzQLl,MacrAqngQkFV-TDsd
CREJjtHcQYytIiOaZRHUYvOPR,QnufWwAmVxsE ibdi
--------------------------------------------------------------------------------
batch 100000, loss 31.441458, pos 650

JoX:
Straight has, what, my in stact purfeand.

HOLON:
I why cwerigh's thilks Endees I did upfice undoyos intain,
With majuce and not she menter in'tinnse come your voice
To see.
sine streaths expire.
--------------------------------------------------------------------------------
batch 200000, loss 38.215614, pos 1275

Thinst bureand you speak, but hother Midgnt death; thy spirit---chart fespity: and will day whilus lord her proforitle not, that so muct had ingespet, must brought:
Yad this me dote you are renoullawe
--------------------------------------------------------------------------------
batch 300000, loss 36.385902, pos 1900

O

# References

- Andrej Karpathy's [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) blogpost.
- Vinyals et al. (2014) [Show and Tell: A Neural Image Caption Generator](https://arxiv.org/abs/1411.4555).
- [Obama-RNN](https://medium.com/@samim/obama-rnn-machine-generated-political-speeches-c8abd18a2ea0) by samim.
- [Making a Predictive Keyboard using Recurrent Neural Networks](http://curiousily.com/data-science/2017/05/23/tensorflow-for-hackers-part-5.html) by Venelin Valkov.

# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com).

This work is licensed under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)