In [20]:
import jax.numpy as jnp
from jax import grad, vmap, jit, random, value_and_grad,vjp
import jax.nn as nn
import tqdm


from scipy.ndimage.interpolation import shift


### RNN using Jax?


In [21]:
import numpy as np
import haiku as hk
import jax
import optax

In [22]:
def mso(samples,horizon=12):
    t=jnp.arange(0,samples)
    X=jnp.sin(0.2*t)+jnp.sin(0.311*t)+jnp.sin(0.42*t)+jnp.sin(0.51*t)
    y=shift(X,horizon)
    return X,y

In [23]:
X,y=mso(25000)

In [24]:
def rnn_forward(x_t,h_prev,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(x_t,h_prev)
    g=hk.Linear(1)
    y_t=nn.relu(g(h_t))
    return y_t,h_t

In [25]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))

In [26]:
key=random.PRNGKey(0)
params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))

In [27]:
def naive_loss_fn(params,X,y,t):
    """
    X: input timeseries
    y: output timeseries
    t: timestep to calculate the loss
    """
    h_t=jnp.array([0.0]*32)
    for i in range(t):
        x_t=X[np.array([t])]
        y_t_pred,h_t=rnn_forward_pure.apply(params,x_t,h_t)
    y_t=y[np.array([t-12])]
    mse_loss=((y_t_pred-y_t)**2).sum()
    return mse_loss


In [28]:
optimizer = optax.adam(0.01)
opt_state = optimizer.init(params)

In [29]:
#Lets design our training loop
h_t=jnp.array([0.0]*32) #The initial state of the RNN
for t in tqdm.tqdm(range(X.shape[0])):
    x_t=X[np.array([t])]
    if t>=12: 
        loss,dy_dparam=value_and_grad(naive_loss_fn)(params,X,y,t)
        updates, opt_state = optimizer.update(dy_dparam, opt_state)
        params=optax.apply_updates(params,updates)
        print(loss.block_until_ready())

  0%|                                        | 16/25000 [00:00<16:43, 24.89it/s]

0.123220906
0.020707674
0.0
0.0


  0%|                                        | 19/25000 [00:00<20:48, 20.00it/s]

0.0
0.0
0.0


  0%|                                        | 21/25000 [00:01<22:56, 18.15it/s]

0.0
0.0
0.0


  0%|                                        | 25/25000 [00:01<27:39, 15.05it/s]

0.0
0.0
1.0624077e-31


  0%|                                        | 27/25000 [00:01<30:19, 13.73it/s]

1.9617304
6.5988936
11.017821


  0%|                                        | 29/25000 [00:01<33:11, 12.54it/s]

12.605366
10.642206


  0%|                                        | 31/25000 [00:01<36:08, 11.51it/s]

6.515843
2.5347004


  0%|                                        | 33/25000 [00:02<39:05, 10.64it/s]

0.3414339
0.0776509


  0%|                                        | 36/25000 [00:02<43:43,  9.51it/s]

0.7336017
1.1804924


  0%|                                        | 38/25000 [00:02<47:36,  8.74it/s]

0.98111427
0.43839055


  0%|                                        | 40/25000 [00:03<51:23,  8.09it/s]

0.056675002
0.019133534


  0%|                                        | 42/25000 [00:03<55:29,  7.50it/s]

0.124444045
0.119675964


  0%|                                        | 44/25000 [00:03<58:46,  7.08it/s]

0.015401036
0.061219186


  0%|                                      | 46/25000 [00:03<1:01:39,  6.74it/s]

0.44251958
1.0347172


  0%|                                      | 48/25000 [00:04<1:05:13,  6.38it/s]

1.4704335
1.451438


  0%|                                      | 49/25000 [00:04<1:06:14,  6.28it/s]

1.0023844


  0%|                                      | 51/25000 [00:04<1:11:59,  5.78it/s]

0.43446332
0.07185803


  0%|                                      | 53/25000 [00:05<1:13:07,  5.69it/s]

0.005695258
0.08627631


  0%|                                      | 55/25000 [00:05<1:15:23,  5.51it/s]

0.120873034
0.060311727


  0%|                                      | 57/25000 [00:05<1:17:58,  5.33it/s]

0.0015008628
0.03781329


  0%|                                      | 59/25000 [00:06<1:20:32,  5.16it/s]

0.13768934
0.18202233


  0%|                                        | 60/25000 [00:06<45:35,  9.12it/s]

0.11048316





KeyboardInterrupt: 

### Problems with the naive implementation?
- We need to unroll the entirity of input until the time step 0 to calculate loss at time t
- Gradient calculation time increases as t increases
- Batch implementation is tricky as we cant parallize this array 

Properties we want to have our P_RNN to have:
- forward function that takes params, hidden_state, input -> hidden_state
- sensitivity function that takes params, hidden_state, p - (input,hidden_state) trajectory

In [44]:
def rnn_forward(i_t,h_tminus1,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(i_t,h_tminus1)
    return h_t

In [45]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))
key=random.PRNGKey(0)
rnn_params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))
rnn_forward_pure_jit=rnn_forward_pure.apply

In [55]:
def sensitivity(rnn_params,hidden_state,rnn_trajectory):
    inputs,hidden_states=rnn_trajectory[0],rnn_trajectory[1]
    rnn_jac_theta=jax.jacrev(rnn_forward_pure.apply)
    rnn_jac_hidden=jax.jacrev(rnn_forward_pure.apply,argnums=2)
    del_h_tminus1_theta=rnn_jac_theta(rnn_params,inputs[0],hidden_states[0])
    del_h_t_theta=del_h_tminus1_theta
    
    def sensitivity_calc(del_h_tminus1_theta,trajectory):
        i_t,h_tminus1=trajectory
        del_f_theta=rnn_jac_theta(rnn_params,i_t,h_tminus1)
        del_f_h_tminus1=rnn_jac_hidden(rnn_params,i_t,h_tminus1)
        # del_h_t_theta=del_f_h_tminus1*del_h_tminus1_theta
        del_h_t_theta= jax.tree_map(lambda x: jnp.tensordot(del_f_h_tminus1,x,axes=1), del_h_tminus1_theta) 
        # del_h_t_theta+=del_f_theta
        del_h_t_theta=jax.tree_multimap(lambda x, y: x+y, del_h_t_theta, del_f_theta)
        return del_h_t_theta,None
    
    if len(rnn_trajectory)>1:
        del_h_t_theta,_=jax.lax.scan(sensitivity_calc,del_h_tminus1_theta,(inputs[1:],hidden_states[1:]))
    return del_h_t_theta
#sensitivity=jit(sensitivity)

In [56]:
truncation=100

In [57]:
def g(h_t):
    out=nn.relu(hk.Linear(32)(h_t))
    out=hk.Linear(1)(out)
    return out

g=hk.without_apply_rng(hk.transform(g))
g_params=g.init(key,jnp.array([1.0]*32))

def loss_fn(g_params,h_t,y_t):
    y_t_pred=g.apply(g_params,h_t)
    return ((y_t_pred-y_t)**2).sum()

loss_fn=jit(loss_fn)

In [58]:
optimizer_rnn = optax.adam(0.0001)
opt_state_rnn = optimizer_rnn.init(rnn_params)
optimizer_g = optax.adam(0.001)
opt_state_g = optimizer_g.init(g_params)

In [59]:
h_tminus1=jnp.array([0.0]*32) #The initial state of the RNN
trajectories=[jnp.zeros((truncation,1)),jnp.zeros((truncation,32))]

def update_trajectories(trajectories,i,h):
    trajectories[0]=jnp.concatenate((trajectories[0][1:],i.reshape(1,-1)),axis=0)
    trajectories[1]=jnp.concatenate((trajectories[1][1:],h.reshape(1,-1)),axis=0)
    return trajectories
    
losses=[]
for epoch in range(10):
    for t in tqdm.tqdm(range(X.shape[0])):
        x_t,y_t=X[np.array([t])],y[np.array([t])]
        #Update trajectories
        trajectories=update_trajectories(trajectories,x_t,h_tminus1)
        h_t=rnn_forward_pure_jit(rnn_params,x_t,h_tminus1)
        #Predict the output and get the loss
        loss,grad_g_params=value_and_grad(loss_fn)(g_params,h_t,y_t)
        #Calculate the gradient wrt to rnn_params using the sensitivities of the rnn and g layers
        rnn_sensitivity_t=sensitivity(rnn_params,h_t,trajectories)
        g_sensitivity_t=jax.jacfwd(loss_fn,argnums=1)(g_params,h_t,y_t)
        grad_rnn_params=jax.tree_multimap(lambda x: jnp.tensordot(g_sensitivity_t,x,axes=1), rnn_sensitivity_t) 
        #Update the gradients
        updates, opt_state_rnn = optimizer_rnn.update(grad_rnn_params, opt_state_rnn)
        rnn_params=optax.apply_updates(rnn_params,updates)
        updates, opt_state_g = optimizer_g.update(grad_g_params, opt_state_g)
        g_params=optax.apply_updates(g_params,updates)
        #Calculate average loss and print if it is print frequency
        losses.append(loss)
        if len(losses)==100:
            print("Loss: ",np.mean(losses))
            losses=[]
        h_tminus1=h_t #Previous state becomes the current state before the next iteration

  0%|                                                                       | 29/25000 [00:03<54:12,  7.68it/s]


KeyboardInterrupt: 

Awesome! It works, and its super fast too! Let's do some sanity check for the sensitivity

In [None]:
sample_x=X[0:20]

h_t=jnp.array([0.0]*32)
def test_sensitivity(rnn_params,sample_x,h_t):
    for i in range(20):
        h_t=rnn_forward_pure_jit(rnn_params,sample_x[np.array([i])],h_t)
    return h_t

true_sensitivity=jax.jacrev(test_sensitivity)(rnn_params,sample_x,h_tminus1)

hidden_states=[]
for i in range(20):
    hidden_states.append(h_t)
    h_t=rnn_forward_pure_jit(rnn_params,sample_x[np.array([i])],h_t)
print("Hidden state calculated correctly: ",jnp.allclose(h_t,test_sensitivity(rnn_params,sample_x,h_tminus1)))
hidden_states=jnp.stack(hidden_states)
our_sensitivity=sensitivity(rnn_params,None,[sample_x.reshape(-1,1),hidden_states])
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['vanilla_rnn/linear']['w'],our_sensitivity['vanilla_rnn/linear']['w'],atol=0.0001))

Next Step? create a haiku model that can be transformed, it automatically manages state and is differentiable with grad

Notes:
1. Maybe we should use haiku states
2. trajectories is updated in forward pass as a state variable
3. .apply is differentiable, and should use the trajectories to calculate sensitivities internally

In [3]:
import sys
sys.path.append('../')


In [4]:
from src.models import MultiplicativeRNN

In [5]:
def rnn_forward(obs,act,last_state):
    rnn=MultiplicativeRNN(10,4,32)
    out,state=rnn(obs,act,last_state)
    return out,state

In [6]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))


In [8]:
key=random.PRNGKey(0)
sample_o=jax.random.normal(key,[20,10])
sample_a=jax.random.normal(key,[20,4])

h_t=MultiplicativeRNN.initial_state(10,4,32,50)
rnn_params=rnn_forward_pure.init(key,sample_o[0],sample_a[0],h_t)
def test_sensitivity(rnn_params,sample_o,sample_a,h_t):
    for i in range(20):
        out,h_t=jit(rnn_forward_pure.apply)(rnn_params,sample_o[i],sample_a[i],h_t)
    return out

true_sensitivity=jax.jacrev(test_sensitivity)(rnn_params,sample_o,sample_a,h_t)

hidden_states=[]
for i in range(20):
    out,h_t=rnn_forward_pure.apply(rnn_params,sample_o[i],sample_a[i],h_t)
print("Hidden state calculated correctly: ",jnp.allclose(out,test_sensitivity(rnn_params,sample_o,sample_a,h_t)))
our_sensitivity=MultiplicativeRNN.sensitivity(rnn_forward_pure.apply,rnn_params,h_t)
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['multiplicative_rnn']['w_h'],
                                                        our_sensitivity['multiplicative_rnn']['w_h'],atol=0.0001))

Hidden state calculated correctly:  True
Sensitivity calculated correctly:  True


In [14]:
def x(a):
    return (a**2).sum(),2

In [18]:
a=value_and_grad(x)

In [19]:
a(jnp.array([12.0,12.0]))

TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(288., dtype=float32), 2).

In [63]:
from __future__ import print_function
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def app(f, x):
  return f(x)
jit_a=jit(app)
print(jit_a(lambda x: 2 * x, 3))

GOOD JOB!