![Py4Eng](img/logo.png)

# Recurrent Neural Networks
## Yoav Ram

In this session we will understand:
- what recurrent neural networks are and how they work,
- how memory and state can be implemented in neural networks, and
- how JAX can be used for high-performance numerical computing as a NumPy replacement

**Please use the correct kernel**: in the notebook menu bar, click `Kernel`, then `Change kernel...` then choose `conda_tensorflow2_p37`.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import jax # http://jax.readthedocs.io
import jax.numpy as np
import optax # https://optax.readthedocs.io

from collections import Counter
from random import uniform

# JAX

[JAX](https://jax.readthedocs.io/en/latest/index.html) combines automatic differentiation and a machine-learning specific compiler ([XLA](https://www.tensorflow.org/xla)) for high-performance numerical computing.
- JAX provides a familiar NumPy-style API for ease of adoption,
- JAX includes composable function transformations for compilation, batching (i.e. vectorization), automatic differentiation, and parallelization,
- The same code executes on multiple backends, including CPU, GPU, and TPU (Google's Tensor Processing Unit)

When using JAX we can mostly use the NumPy API, with some important differences:
- JAX arrays are immutable, so we cannot use item assignment
- random number generation requires us to provide a random key at every call (the random number generator is stateless)

JAX allows us to just-in-time compile functions and importantly to compute gradients automatically. We will see these features as we proceed.
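As a minimal sketch of these points (the function `f` and the array values are made up for illustration and are not part of the model developed below):

```python
import jax
import jax.numpy as np  # same alias as the rest of this notebook

# Arrays are immutable: .at[...].set(...) returns a new array
# instead of modifying the original in place.
a = np.zeros(3)
b = a.at[1].set(5.0)  # `a` itself is unchanged

# jit and grad are function transformations that can be composed.
def f(x):
    return (x ** 2).sum()

df = jax.jit(jax.grad(f))           # compiled gradient of f
g = df(np.array([1.0, 2.0, 3.0]))   # analytically, the gradient is 2 * x
```

Random keys are discussed separately, when we initialize the network parameters.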

# Data

In developing this RNN we will follow [Andrej Karpathy](http://cs.stanford.edu/people/karpathy/)'s [blogpost about RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness) ([original code gist](https://gist.github.com/karpathy/d4dee566867f8291f086) with BSD License).

The data is plain text, in this case Shakespeare's writing.

In [2]:
filename = '../data/shakespear.txt'
with open(filename, 'rt') as f:
    text = f.read()

print("Number of characters: {}".format(len(text)))
print("Number of unique characters: {}".format(len(set(text))))
print("Number of lines: {}".format(text.count('\n')))
print("Number of words: {}".format(text.count(' ')))
print()
print("Excerpt:")
print("*" * len("Excerpt:"))
print(text[:500])

Number of characters: 99993
Number of unique characters: 62
Number of lines: 3298
Number of words: 15893

Excerpt:
********
That, poor contempt, or claim'd thou slept so faithful,
I may contrive our father; and, in their defeated queen,
Her flesh broke me and puttance of expedition house,
And in that same that ever I lament this stomach,
And he, nor Butly and my fury, knowing everything
Grew daily ever, his great strength and thought
The bright buds of mine own.

BIONDELLO:
Marry, that it may not pray their patience.'

KING LEAR:
The instant common maid, as we may less be
a brave gentleman and joiner: he that finds u


We start by creating
- a list `chars` of the unique characters
- `data_size` the number of total characters
- `vocab_size` the number of unique characters
- `int_to_char` a dictionary from index to char
- `char_to_int` a dictionary from char to index

and then we convert `text` from a string to a one-hot encoded JAX array `X`, with one row per character.

In [3]:
chars = list(set(text))
data_size, vocab_size = len(text), len(chars)

# char to int and vice versa
int_to_char = dict(enumerate(chars)) #  == { i: ch for i,ch in enumerate(chars) }
char_to_int = dict(zip(int_to_char.values(), int_to_char.keys())) # { ch: i for i,ch in enumerate(chars) }

def onehot_encode(text):
    ints = [char_to_int[c] for c in text]
    ints = np.array(ints, dtype=int)
    return jax.nn.one_hot(ints, vocab_size)

def onehot_decode(data):
    ints = data.argmax(axis=1).tolist()
    chars = (int_to_char[k] for k in ints)
    return str.join('', chars)

X = onehot_encode(text)
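To see what the encoding looks like, here is a toy round trip on a short string. This is only a sketch: the `toy_*` names are hypothetical and not used elsewhere in the notebook. It uses `sorted(set(...))` so that the character-to-index mapping is the same in every Python session, whereas the iteration order of a plain `set` of strings can differ between sessions.

```python
import jax
import jax.numpy as np

toy_text = "hello"                  # hypothetical input
toy_chars = sorted(set(toy_text))   # unique characters in a fixed order
toy_char_to_int = {ch: i for i, ch in enumerate(toy_chars)}
toy_int_to_char = {i: ch for i, ch in enumerate(toy_chars)}

# encode: one row per character, one column per vocabulary entry
ints = np.array([toy_char_to_int[c] for c in toy_text])
onehot = jax.nn.one_hot(ints, len(toy_chars))

# decode: the position of the 1 in each row gives the character index
decoded = "".join(toy_int_to_char[k] for k in onehot.argmax(axis=1).tolist())
```

Here `onehot` has shape `(5, 4)`: five characters, four unique symbols.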

# RNN model

- $x_t$ is the $t$-th character, one-hot encoded as a 1D array of length `vocab_size`.
- $h_t$ is the state of the hidden memory layer after seeing $t$ characters, encoded as a 1D array of numbers (neurons...)
- $\widehat y_t$ is the prediction of the network after seeing $t$ characters, encoded as a 1D array of probabilities of length `vocab_size`

The model is then written as:

$$
h_t = \tanh{\big(x_t \cdot W_x^h + h_{t-1} \cdot W_h^h + b_h\big)}
$$
$$
\widehat y_t = softmax\big(h_t \cdot W_h^y + b_y\big)
$$

and we set $h_0 = (0, \ldots, 0)$.

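This operation will be performed by our `step` function. Note that the code stores each weight matrix transposed relative to the equations, so it computes `Wxh @ x` rather than $x_t \cdot W_x^h$.

The `feed_forward` function will loop over a sequence $x=(x_1, x_2, \ldots, x_k)$ of some arbitrary length $k$, similar to batches in the FFN and CNN frameworks.
The loss function `cross_entropy`, defined later, is computed from the parameters, the input sequence, the target sequence, and the initial hidden state.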

In [4]:
def step(params, x, h):
    Wxh, Whh, Why, bh, by = params
    h = np.tanh(Wxh @ x + Whh @ h + bh)        
    yhat = jax.nn.softmax(Why @ h + by) # softmax function implemented in JAX
    return yhat, h
    
def feed_forward(params, x, h):
    yhat = np.zeros_like(x)
    
    for t in range(len(x)):
        yhat_t, h = step(params, x[t], h)
        # this is the JAX syntax that replaces NumPy item assignment
        yhat = yhat.at[t, :].set(yhat_t) # equivalent to yhat[t, :] = yhat_t

    return yhat, h
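An alternative to the Python `for` loop, shown here only as a sketch and not as the method used in this notebook, is `jax.lax.scan`. It expresses the loop as a single primitive, so `jax.jit` does not need to unroll it step by step while tracing. The name `feed_forward_scan` and the toy sizes are hypothetical; `step` is repeated so that the block is self-contained.

```python
import jax
import jax.numpy as np

def step(params, x, h):
    Wxh, Whh, Why, bh, by = params
    h = np.tanh(Wxh @ x + Whh @ h + bh)
    yhat = jax.nn.softmax(Why @ h + by)
    return yhat, h

def feed_forward_scan(params, x, h0):
    # lax.scan expects body(carry, input) -> (new_carry, output);
    # the carry is the hidden state, the per-step output is yhat_t
    def body(h, x_t):
        yhat_t, h = step(params, x_t, h)
        return h, yhat_t
    h, yhat = jax.lax.scan(body, h0, x)
    return yhat, h

# toy sizes, chosen only for illustration
vocab_size, h_size, k = 5, 4, 7
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(keys[0], (h_size, vocab_size)) * 0.01,
    jax.random.normal(keys[1], (h_size, h_size)) * 0.01,
    jax.random.normal(keys[2], (vocab_size, h_size)) * 0.01,
    np.zeros(h_size),
    np.zeros(vocab_size),
)
x = jax.nn.one_hot(np.arange(k) % vocab_size, vocab_size)
yhat, h = feed_forward_scan(params, x, np.zeros(h_size))
```

Each row of `yhat` is a probability distribution over the vocabulary, and `h` is the hidden state after the last step.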

We initialize the network parameters so we can test `feed_forward`.

[JAX pseudo-random number generation](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) works differently than in NumPy: the generator is stateless, so we need to provide a "random key" (i.e. seed) at every function call. Random functions consume the key but do not modify it, so feeding the same key to a random function will always generate the same sample. To generate different and independent samples, we must `split()` the key.
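A small sketch of this behaviour (the seed and array shapes are arbitrary choices for illustration):

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key gives the same sample.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting gives a fresh subkey for each draw.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
key, subkey = jax.random.split(key)
d = jax.random.normal(subkey, (3,))
```

Here `a` and `b` are identical, while `c` and `d` are different samples.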

In [5]:
h_size = 100 # number of units in hidden layer

def init_params(key):
    key, subkey = jax.random.split(key) # split the key - one for consumption and one for the rest of the program
    Wxh = jax.random.normal(subkey, (h_size, vocab_size)) * 0.01 
    key, subkey = jax.random.split(key) # split the key...
    Whh = jax.random.normal(subkey, (h_size, h_size)) * 0.01
    key, subkey = jax.random.split(key) # split the key
    Why = jax.random.normal(subkey, (vocab_size, h_size)) * 0.01 
    bh = np.zeros(h_size) # hidden layer bias
    by = np.zeros(vocab_size) # readout layer bias
    params = Wxh, Whh, Why, bh, by
    return params

In [6]:
key = jax.random.PRNGKey(42) # generate new key based on the seed "42"
params = init_params(key)

x, y = X[:25], X[1:26]
h = np.zeros(h_size)
yhat, h = feed_forward(params, x, h)
print(onehot_decode(yhat))

y'JPTT:vJ'TUvQs:R:sTTv'TU


# Back propagation

Back propagation works, as before, using the chain rule. 
It is similar to the [FFN example](FFN.ipynb), except that the $h$ layer adds a bit of complexity, but not much.

The details of the gradient calculation can be found in Stanford's ["Convolutional Neural Networks for Visual Recognition" course](http://cs231n.github.io/neural-networks-case-study/#grad).

What's important to discuss is that instead of back propagating a single step of the network $t$, we back propagate over a sequence of steps, that is over $x=(x_1, \ldots, x_k)$ for some arbitrary $k$.

![rolled RNN](img/rolled_rnn.png)

How? By "unrolling" the network.

![Unrolled RNN](img/unrolled_rnn.png)

For example, for $k=3$, the input is $x=(x_1, x_2, x_3)$, and we can write

$$
h_1 = \tanh{\big(x_1 \cdot W_x^h + h_0 \cdot W_h^h + b_h\big)}
$$
$$
\widehat y_1 = softmax\big(h_1 \cdot W_h^y + b_y\big)
$$
$$
h_2 = \tanh{\big(x_2 \cdot W_x^h + h_1 \cdot W_h^h + b_h\big)}
$$
$$
\widehat y_2 = softmax\big(h_2 \cdot W_h^y + b_y\big)
$$
$$
h_3 = \tanh{\big(x_3 \cdot W_x^h + h_2 \cdot W_h^h + b_h\big)}
$$
$$
\widehat y_3 = softmax\big(h_3 \cdot W_h^y + b_y\big)
$$

The cross entropy is computed by summing the per-step cross entropies of $\widehat y_1, \ldots, \widehat y_k$, and then the gradient of this sum is computed with respect to the various $W$ and $b$ parameters.

In [7]:
def cross_entropy(params, x, y, h):
    yhat, h = feed_forward(params, x, h)    
    loss = -(y * np.log(yhat)).sum()
    return loss, h

loss, h = cross_entropy(params, x, y, h)
print(loss)

103.17685
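As a sanity check on this value (a back-of-the-envelope sketch, not part of the original notebook): with one-hot targets the loss is

$$
L = -\sum_{t=1}^{k} \sum_{i} y_{t,i} \log \widehat y_{t,i}
$$

and because the weights are initialized close to zero, the initial predictions are nearly uniform over the vocabulary. The initial loss should therefore be close to $k \log V$, where $V$ is `vocab_size`:

```python
import math

k = 25            # sequence length used for x and y above
vocab_size = 62   # number of unique characters reported above
baseline = k * math.log(vocab_size)  # loss of a uniform prediction
print(round(baseline, 2))
```

This gives about 103.18, in line with the loss printed above.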


## Automatic differentiation with JAX
Now instead of manually deriving the gradient and implementing it as a Python program, we use JAX's automatic differentiation. [`jax.grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#jax-first-transformation-grad) takes a function `f(a, b, c)` and returns a function `dfda(a, b, c)` that returns the gradient of `f` with respect to `a` at the values of `a`, `b`, and `c`. It does so by automating the procedure we did manually using the chain rule.

In our case, `f` is `cross_entropy`, `a` is `params`, and the remaining arguments are `x`, `y`, and `h`.

The function [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#value-and-grad) is used to return both `f(a,b,c)` (the "value") and the `dfda` (the "grad"). 
Finally, `has_aux` means that `f` returns two values: the value that needs to be differentiated, and an auxiliary value. In our case, the value to differentiate is `loss` and the auxiliary value is `h`.
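A toy illustration of `jax.value_and_grad` with `has_aux=True`, using a made-up function `f(a, b)` rather than the network's loss:

```python
import jax
import jax.numpy as np

def f(a, b):
    value = (a * b).sum()  # scalar to differentiate
    aux = a + b            # extra output that is passed through untouched
    return value, aux

value_and_grad_f = jax.value_and_grad(f, has_aux=True)

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
(value, aux), dfda = value_and_grad_f(a, b)
# value is 1*3 + 2*4 = 11, and the gradient with respect to `a` is `b`
```

Note the shape of the result: a `(value, aux)` pair first, then the gradient with respect to the first argument.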

In [8]:
back_propagation = jax.value_and_grad(cross_entropy, has_aux=True)

(loss, h), grads = back_propagation(params, x, y, h)
for p, g in zip(params, grads):
    assert p.shape == g.shape
    assert not (g == 0).all()

# Adam optimizer with Optax

We can use a JAX implementation of the Adam optimizer from the [Optax](https://optax.readthedocs.io/) library.
We first create the optimizer and initialize its state.

In [9]:
optimizer = optax.adam(learning_rate=0.001) # 0.001 is the default from Kingma & Ba 2014
opt_state = optimizer.init(params)

We then use the optimizer to compute the updates, and apply them.

In [10]:
(loss, h), grads = back_propagation(params, x, y, h)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# JITing the training step

We write a function that does all this, and pass it to `jax.jit`, which [just-in-time compiles the function](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) so it can be executed efficiently in XLA.

In [11]:
@jax.jit # decreases runtime from 380 ms to <1 ms!
def update_params(params, opt_state, x, y, h):
    (loss, h), grads = back_propagation(params, x, y, h)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, h, opt_state, loss

In [12]:
%timeit update_params(params, opt_state, x, y, h)

625 µs ± 192 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
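A caveat when timing JAX code: JAX dispatches work asynchronously, so a call can return before the computation has finished, and the first call of a jitted function also includes compilation time. A more careful measurement, sketched here on a made-up function `f` rather than on `update_params`, warms up once and then waits for each result with `block_until_ready()`.

```python
import time
import jax
import jax.numpy as np

@jax.jit
def f(x):
    return np.tanh(x @ x).sum()

x = np.ones((200, 200))
f(x).block_until_ready()  # warm-up call: triggers compilation

n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    out = f(x).block_until_ready()  # wait until the result is computed
elapsed = (time.perf_counter() - start) / n_runs  # seconds per call
```

For a function that returns a tuple, such as `update_params`, recent JAX versions also offer `jax.block_until_ready`, which waits for every array in the returned structure; whether it is available depends on the installed version.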


In [13]:
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)
params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
print(loss)

103.09128
103.0037


# Sampling from the network

Finally, instead of a `predict` function, we have a `sample` function, which, given the parameters ($W$s and $b$s) and the number of characters, produces a sample of text from the network.

It does so by drawing a random seed character $x_0$ and then drawing each $x_t$ for $t>0$ from the distribution that the network predicts after seeing $x_{t-1}$.

![](img/sampling_rnn.png)
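The drawing step can be sketched in isolation. In this toy example (the probabilities are made up), we draw many indices with `jax.random.choice` and check that the empirical frequencies approach the given distribution `p`.

```python
import jax
import jax.numpy as np

probs = np.array([0.1, 0.2, 0.7])  # a made-up output distribution
key = jax.random.PRNGKey(0)

# draw 10,000 indices from {0, 1, 2} with the given probabilities
draws = jax.random.choice(key, 3, shape=(10000,), p=probs)

# empirical frequency of each index
freqs = np.bincount(draws, length=3) / draws.size
```

Because the next character is drawn rather than chosen by `argmax`, repeated calls with different keys produce different texts from the same parameters.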

In [16]:
def sample(params, num_samples, key):
    h = np.zeros(h_size)
    
    x = np.zeros((num_samples, vocab_size), dtype=float)
    key, subkey = jax.random.split(key)
    seed_char = jax.random.choice(subkey, vocab_size)
    x = x.at[0, seed_char].set(1)
    
    for t in range(1, num_samples):
        yhat, h = step(params, x[t-1], h)
        # draw from output distribution
        key, subkey = jax.random.split(key)
        i = jax.random.choice(subkey, vocab_size, p=yhat)
        x = x.at[t, i].set(1)
    return onehot_decode(x)

print(sample(params, 100, jax.random.PRNGKey(1)))

nBNUBF?,qHuQvQ;EhC:?eCsz;VT,PJgML!
ap'HoKSm:AuCiPIyNDqolZxLle';W
UbHAdpxQMhemxPeJUCDKMGWbhiKmg!a-nyw


# Training the network

In [19]:
seq_length = 25
max_batches = 100000
h = np.zeros(h_size)
pos = 0
batch = 0 
key = jax.random.PRNGKey(8)

key, subkey = jax.random.split(key)
params = init_params(subkey)

optimizer = optax.adam(learning_rate=0.001) # you can try with 0.01
opt_state = optimizer.init(params)
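The training loop below walks through the text in consecutive windows of `seq_length` characters, where the target `y` is the input `x` shifted by one character, so the network learns to predict the next character. A sketch of this windowing on a toy string (the string and the window length of 4 are hypothetical, and plain strings are used instead of one-hot arrays):

```python
toy_text = "hello world"
seq_length = 4
pos = 0
windows = []

# stop when there is no room left for a full window plus its shifted target
while pos + seq_length + 1 < len(toy_text):
    x = toy_text[pos : pos + seq_length]
    y = toy_text[pos + 1 : pos + seq_length + 1]  # x shifted by one character
    windows.append((x, y))
    pos += seq_length

print(windows)  # [('hell', 'ello'), ('o wo', ' wor')]
```

In the real loop the hidden state `h` is carried over from one window to the next, and it is reset only when the position wraps around to the start of the text.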

In [20]:
while batch <= max_batches:
    if pos + seq_length + 1 >= data_size:
        # reset data position and hidden state
        pos, h = 0, np.zeros(h_size) 
        
    x = X[pos : pos + seq_length]
    y = X[pos + 1 : pos + seq_length + 1]
    pos += seq_length
        
    params, h, opt_state, loss = update_params(params, opt_state, x, y, h)
    
    if batch % (max_batches // 10) == 0:
        print('batch {:d}, loss {:.6f}, pos {}'.format(batch, loss, pos))
        print()
        
        key, subkey = jax.random.split(key)        
        sample_text = sample(params, 200, subkey)
        print(sample_text)
        print('-'*80)
    batch += 1

batch 0, loss 103.167320, pos 25

:EsndaVZoNN
zBRKbrVCu
?PczTf-ASMCzQz;yRyla:.NBEoA.Ik:SdxHGQDZfnCizrRwSusRUPMUNH
-jBCiYTL :O?inhcosCkWDpU;FUBPuooveSbu- vQKVq:j,'oIJUpnBfrEexlPNcqyOfRpDHQ bVoaZLdiCnPfvSC.hMlGZnWvtM-ZefyuFI!c,D?bL:hUVh
--------------------------------------------------------------------------------
batch 10000, loss 55.802208, pos 50075

did bearur to my yot be ant hear.

ENTRAGL:
Sar,
I heve far axlobsupreve seven't gion arr.

PUSIR:
Ch, dave, this tey,
Thiuet, I walls. Tould wall. 
FyOSDETAL:
So your laddind
Twem dith thes pently an
--------------------------------------------------------------------------------
batch 20000, loss 46.399208, pos 150

QUUCES IFHDLIMAS:
Heat net you there whele as the als hes are, why have, Heirenveng: he hould he dateage fore deen of our stay dot,
Than se dey alang for to the suntly corske'ns of huse insuld leak:
B
--------------------------------------------------------------------------------
batch 30000, loss 39.962032, pos 50200

My

# References

- [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) by Andrej Karpathy
- [Obama-RNN](https://medium.com/@samim/obama-rnn-machine-generated-political-speeches-c8abd18a2ea0) by samim
- [Making a Predictive Keyboard using Recurrent Neural Networks](http://curiousily.com/data-science/2017/05/23/tensorflow-for-hackers-part-5.html) by Venelin Valkov
- [JAX tutorial](https://colinraffel.com/blog/you-don-t-know-jax.html) by Colin Raffel

# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com) and is part of the [_Data Science with Python_](https://github.com/yoavram/DataSciPy) workshop.

This work is licensed under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)