In [50]:
import jax.numpy as jnp
from jax import grad, vmap, jit, random, value_and_grad,vjp
import jax.nn as nn
import tqdm

from scipy.ndimage.interpolation import shift


In [2]:
a=jnp.ones((10))



In [3]:

def f(x):
    return x**3+x**2+2

In [4]:
grad_fn=grad(f)

In [5]:
grad_fn(1.0)

DeviceArray(5., dtype=float32)

In [6]:
v_grad=vmap(grad_fn)

In [7]:
v_grad(a)

DeviceArray([5., 5., 5., 5., 5., 5., 5., 5., 5., 5.], dtype=float32)

RNN using Jax?


In [8]:
import numpy as np
import haiku as hk
import jax
import optax

In [54]:
def mso(samples,horizon=12):
    t=jnp.arange(0,samples)
    X=jnp.sin(0.2*t)+jnp.sin(0.311*t)+jnp.sin(0.42*t)+jnp.sin(0.51*t)
    y=shift(X,horizon)
    return X,y

In [55]:
X,y=mso(25000)

In [14]:
def rnn_forward(x_t,h_prev,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(x_t,h_prev)
    g=hk.Linear(1)
    y_t=nn.relu(g(h_t))
    return y_t,h_t

In [15]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))

In [16]:
key=random.PRNGKey(0)
params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))

In [21]:
def naive_loss_fn(params,X,y,t):
    """
    X: input timeseries
    y: output timeseries
    t: timestep to calculate the loss
    """
    h_t=jnp.array([0.0]*32)
    for i in range(t):
        x_t=X[np.array([t])]
        y_t_pred,h_t=rnn_forward_pure.apply(params,x_t,h_t)
    y_t=y[np.array([t-12])]
    mse_loss=((y_t_pred-y_t)**2).sum()
    return mse_loss


In [22]:
optimizer = optax.adam(0.01)
opt_state = optimizer.init(params)

In [25]:
#Lets design our training loop
h_t=jnp.array([0.0]*32) #The initial state of the RNN
for t in tqdm.tqdm(range(X.shape[0])):
    x_t=X[np.array([t])]
    if t>=12: 
        loss,dy_dparam=value_and_grad(naive_loss_fn)(params,X,y,t)
        updates, opt_state = optimizer.update(dy_dparam, opt_state)
        params=optax.apply_updates(params,updates)
        print(loss.block_until_ready())

  0%|                                       | 14/25000 [00:00<03:44, 111.07it/s]

0.98111427
0.43839055
0.056675002
0.019133534
0.124444045
0.119675964
0.015401036
0.061219186
0.44251958
1.0347172
1.4704335


  0%|                                        | 26/25000 [00:00<17:07, 24.30it/s]

1.451438
1.0023844
0.43446332
0.07185803
0.005695258
0.08627631


  0%|                                        | 32/25000 [00:01<23:30, 17.70it/s]

0.120873034
0.060311727
0.0015008628
0.03781329
0.13768934


  0%|                                        | 36/25000 [00:01<28:12, 14.75it/s]

0.18202233
0.11048316
0.010379525
0.04591453


  0%|                                        | 39/25000 [00:02<32:19, 12.87it/s]

0.29595083
0.66427755


  0%|                                        | 41/25000 [00:02<35:41, 11.66it/s]

0.95169675
1.0171465


  0%|                                        | 43/25000 [00:02<40:47, 10.20it/s]

0.87482846
0.6519342


  0%|                                        | 45/25000 [00:03<44:36,  9.32it/s]

0.47451505
0.4018551


  0%|                                        | 48/25000 [00:03<50:39,  8.21it/s]

0.44508317
0.59490895


  0%|                                        | 50/25000 [00:03<55:41,  7.47it/s]

0.7882939
0.8658472


  0%|                                      | 52/25000 [00:04<1:01:02,  6.81it/s]

0.6544822
0.2090999


  0%|                                      | 54/25000 [00:04<1:05:50,  6.31it/s]

0.02005328
0.8854042


  0%|                                      | 56/25000 [00:04<1:09:59,  5.94it/s]

3.3535821
7.063162


  0%|                                      | 58/25000 [00:05<1:12:43,  5.72it/s]

10.566997
12.001057


  0%|                                      | 59/25000 [00:05<1:14:28,  5.58it/s]

10.31275


  0%|                                      | 61/25000 [00:05<1:20:43,  5.15it/s]

6.195523
1.9214138


  0%|                                      | 63/25000 [00:06<1:21:31,  5.10it/s]

0.0049081906
1.5474458


  0%|                                      | 64/25000 [00:06<1:22:41,  5.03it/s]

5.4562364


  0%|                                      | 65/25000 [00:06<1:23:47,  4.96it/s]

9.157268


  0%|                                      | 66/25000 [00:06<1:24:49,  4.90it/s]

10.324558


  0%|                                      | 67/25000 [00:07<1:26:13,  4.82it/s]

8.3790455


  0%|                                      | 68/25000 [00:07<1:27:52,  4.73it/s]

4.6943827


  0%|                                      | 69/25000 [00:07<1:28:42,  4.68it/s]

1.4446418


  0%|                                      | 70/25000 [00:07<1:29:34,  4.64it/s]

0.044687808


  0%|                                        | 71/25000 [00:08<47:43,  8.71it/s]

0.36271983





KeyboardInterrupt: 

### Problems with the naive implementation?
- We need to unroll the entirity of input until the time step 0 to calculate loss at time t
- Gradient calculation time increases as t increases
- Batch implementation is tricky as we cant parallize this array 

Properties we want to have our P_RNN to have:
- forward function that takes params, hidden_state, input -> hidden_state
- sensitivity function that takes params, hidden_state, p - (input,hidden_state) trajectory

In [91]:
def rnn_forward(i_t,h_tminus1,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(i_t,h_tminus1)
    return h_t

In [114]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))
key=random.PRNGKey(0)
rnn_params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))

In [104]:
def sensitivity(rnn_params,hidden_state,rnn_trajectory):
    rnn_jac_theta=jax.jacfwd(rnn_forward_pure.apply)
    rnn_jac_hidden=jax.jacfwd(rnn_forward_pure.apply,argnums=2)
    del_h_tminus1_theta=rnn_jac_theta(rnn_params,rnn_trajectory[0][0],rnn_trajectory[0][1])
    del_h_t_theta=del_h_tminus1_theta
    for t in range(1,len(rnn_trajectory)):
        i_t,h_tminus1=rnn_trajectory[t]
        del_f_theta=rnn_jac_theta(rnn_params,i_t,h_tminus1)
        del_f_h_tminus1=rnn_jac_hidden(rnn_params,i_t,h_tminus1)
        # del_h_t_theta=del_f_h_tminus1*del_h_tminus1_theta
        del_h_t_theta= jax.tree_map(lambda x: jnp.tensordot(del_f_h_tminus1,x,axes=1), del_h_tminus1_theta) 
        # del_h_t_theta+=del_f_theta
        del_h_t_theta=jax.tree_multimap(lambda x, y: x+y, del_h_t_theta, del_f_theta)
        del_h_tminus1_theta=del_h_t_theta
    return del_h_t_theta

In [105]:
truncation=10

In [106]:
def g(h_t):
    g_fun=hk.Linear(1)
    return g_fun(h_t)

g=hk.without_apply_rng(hk.transform(g))
g_params=g.init(key,jnp.array([1.0]*32))

def loss_fn(g_params,h_t,y_t):
    y_t_pred=g.apply(g_params,h_t)
    return ((y_t_pred-y_t)**2).sum()

In [115]:
optimizer_rnn = optax.adam(0.001)
opt_state_rnn = optimizer_rnn.init(rnn_params)
optimizer_g = optax.adam(0.001)
opt_state_g = optimizer_g.init(g_params)

In [117]:
h_tminus1=jnp.array([0.0]*32) #The initial state of the RNN
trajectories=[]
losses=[]
for t in tqdm.tqdm(range(X.shape[0])):
    x_t,y_t=X[np.array([t])],y[np.array([t])]
    trajectories.append((x_t,h_tminus1))
    trajectories=trajectories[-truncation:]
    h_t=rnn_forward_pure.apply(rnn_params,x_t,h_tminus1)
    #Predict the output and get the loss
    loss,grad_g_params=value_and_grad(loss_fn)(g_params,h_t,y_t)
    #Calculate the gradient wrt to rnn_params using the sensitivities of the rnn and g layers
    rnn_sensitivity_t=sensitivity(rnn_params,h_t,trajectories)
    g_sensitivity_t=jax.jacfwd(loss_fn,argnums=1)(g_params,h_t,y_t)
    grad_rnn_params=jax.tree_multimap(lambda x: jnp.tensordot(g_sensitivity_t,x,axes=1), rnn_sensitivity_t) 
    #Update the gradients
    updates, opt_state_rnn = optimizer_rnn.update(grad_rnn_params, opt_state_rnn)
    rnn_params=optax.apply_updates(rnn_params,updates)
    updates, opt_state_g = optimizer_g.update(grad_g_params, opt_state_g)
    g_params=optax.apply_updates(g_params,updates)
    #Calculate average loss and print if it is print frequency
    losses.append(loss)
    if len(losses)==100:
        print("Loss: ",np.mean(losses))
        losses=[]
    h_tminus1=h_t #Previous state becomes the current state before the next iteration

  0%|▏                                      | 101/25000 [00:07<32:51, 12.63it/s]

Loss:  1.9928762


  1%|▎                                      | 201/25000 [00:15<32:39, 12.66it/s]

Loss:  1.6692373


  1%|▍                                      | 301/25000 [00:23<32:25, 12.70it/s]

Loss:  1.7385086


  2%|▋                                      | 401/25000 [00:32<36:44, 11.16it/s]

Loss:  1.4898416


  2%|▊                                      | 501/25000 [00:41<35:00, 11.67it/s]

Loss:  1.6864971


  2%|▉                                      | 601/25000 [00:50<33:24, 12.17it/s]

Loss:  1.6806474


  3%|█                                      | 701/25000 [00:58<32:08, 12.60it/s]

Loss:  1.2669876


  3%|█▏                                     | 801/25000 [01:05<31:55, 12.63it/s]

Loss:  1.8941128


  4%|█▍                                     | 901/25000 [01:13<32:26, 12.38it/s]

Loss:  1.5674942


  4%|█▌                                    | 1001/25000 [01:22<33:32, 11.92it/s]

Loss:  1.5030847


  4%|█▋                                    | 1102/25000 [01:30<31:56, 12.47it/s]

Loss:  1.6027762


  5%|█▊                                    | 1202/25000 [01:38<31:33, 12.57it/s]

Loss:  1.3700947


  5%|█▉                                    | 1302/25000 [01:46<31:34, 12.51it/s]

Loss:  1.3286232


  6%|██▏                                   | 1402/25000 [01:55<34:49, 11.29it/s]

Loss:  2.1271584


  6%|██▎                                   | 1483/25000 [02:02<32:30, 12.06it/s]


KeyboardInterrupt: 