In [16]:
import jax.numpy as jnp
from jax import grad, vmap, jit, random, value_and_grad,vjp
import jax.nn as nn
import tqdm


from scipy.ndimage.interpolation import shift


### RNN using Jax?


In [17]:
import numpy as np
import haiku as hk
import jax
import optax

In [None]:
def mso(samples,horizon=12):
    t=jnp.arange(0,samples)
    X=jnp.sin(0.2*t)+jnp.sin(0.311*t)+jnp.sin(0.42*t)+jnp.sin(0.51*t)
    y=shift(X,horizon)
    return X,y

In [None]:
X,y=mso(25000)

In [None]:
def rnn_forward(x_t,h_prev,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(x_t,h_prev)
    g=hk.Linear(1)
    y_t=nn.relu(g(h_t))
    return y_t,h_t

In [None]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))

In [None]:
key=random.PRNGKey(0)
params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))

In [None]:
def naive_loss_fn(params,X,y,t):
    """
    X: input timeseries
    y: output timeseries
    t: timestep to calculate the loss
    """
    h_t=jnp.array([0.0]*32)
    for i in range(t):
        x_t=X[np.array([t])]
        y_t_pred,h_t=rnn_forward_pure.apply(params,x_t,h_t)
    y_t=y[np.array([t-12])]
    mse_loss=((y_t_pred-y_t)**2).sum()
    return mse_loss


In [None]:
optimizer = optax.adam(0.01)
opt_state = optimizer.init(params)

In [None]:
#Lets design our training loop
h_t=jnp.array([0.0]*32) #The initial state of the RNN
for t in tqdm.tqdm(range(X.shape[0])):
    x_t=X[np.array([t])]
    if t>=12: 
        loss,dy_dparam=value_and_grad(naive_loss_fn)(params,X,y,t)
        updates, opt_state = optimizer.update(dy_dparam, opt_state)
        params=optax.apply_updates(params,updates)
        print(loss.block_until_ready())

### Problems with the naive implementation?
- We need to unroll the entirity of input until the time step 0 to calculate loss at time t
- Gradient calculation time increases as t increases
- Batch implementation is tricky as we cant parallize this array 

Properties we want to have our P_RNN to have:
- forward function that takes params, hidden_state, input -> hidden_state
- sensitivity function that takes params, hidden_state, p - (input,hidden_state) trajectory

In [None]:
def rnn_forward(i_t,h_tminus1,hidden_size=32):
    f=hk.VanillaRNN(hidden_size=hidden_size)
    h_t,h_t=f(i_t,h_tminus1)
    return h_t

In [None]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))
key=random.PRNGKey(0)
rnn_params=rnn_forward_pure.init(key,jnp.array([1.0]),jnp.array([1.0]*32))
rnn_forward_pure_jit=rnn_forward_pure.apply

In [None]:
def sensitivity(rnn_params,hidden_state,rnn_trajectory):
    inputs,hidden_states=rnn_trajectory[0],rnn_trajectory[1]
    rnn_jac_theta=jax.jacrev(rnn_forward_pure.apply)
    rnn_jac_hidden=jax.jacrev(rnn_forward_pure.apply,argnums=2)
    del_h_tminus1_theta=rnn_jac_theta(rnn_params,inputs[0],hidden_states[0])
    del_h_t_theta=del_h_tminus1_theta
    
    def sensitivity_calc(del_h_tminus1_theta,trajectory):
        i_t,h_tminus1=trajectory
        del_f_theta=rnn_jac_theta(rnn_params,i_t,h_tminus1)
        del_f_h_tminus1=rnn_jac_hidden(rnn_params,i_t,h_tminus1)
        # del_h_t_theta=del_f_h_tminus1*del_h_tminus1_theta
        del_h_t_theta= jax.tree_map(lambda x: jnp.tensordot(del_f_h_tminus1,x,axes=1), del_h_tminus1_theta) 
        # del_h_t_theta+=del_f_theta
        del_h_t_theta=jax.tree_multimap(lambda x, y: x+y, del_h_t_theta, del_f_theta)
        return del_h_t_theta,None
    
    if len(rnn_trajectory)>1:
        del_h_t_theta,_=jax.lax.scan(sensitivity_calc,del_h_tminus1_theta,(inputs[1:],hidden_states[1:]))
    return del_h_t_theta
#sensitivity=jit(sensitivity)

In [None]:
truncation=100

In [None]:
def g(h_t):
    out=nn.relu(hk.Linear(32)(h_t))
    out=hk.Linear(1)(out)
    return out

g=hk.without_apply_rng(hk.transform(g))
g_params=g.init(key,jnp.array([1.0]*32))

def loss_fn(g_params,h_t,y_t):
    y_t_pred=g.apply(g_params,h_t)
    return ((y_t_pred-y_t)**2).sum()

loss_fn=jit(loss_fn)

In [None]:
optimizer_rnn = optax.adam(0.0001)
opt_state_rnn = optimizer_rnn.init(rnn_params)
optimizer_g = optax.adam(0.001)
opt_state_g = optimizer_g.init(g_params)

In [None]:
h_tminus1=jnp.array([0.0]*32) #The initial state of the RNN
trajectories=[jnp.zeros((truncation,1)),jnp.zeros((truncation,32))]

def update_trajectories(trajectories,i,h):
    trajectories[0]=jnp.concatenate((trajectories[0][1:],i.reshape(1,-1)),axis=0)
    trajectories[1]=jnp.concatenate((trajectories[1][1:],h.reshape(1,-1)),axis=0)
    return trajectories
    
losses=[]
for epoch in range(10):
    for t in tqdm.tqdm(range(X.shape[0])):
        x_t,y_t=X[np.array([t])],y[np.array([t])]
        #Update trajectories
        trajectories=update_trajectories(trajectories,x_t,h_tminus1)
        h_t=rnn_forward_pure_jit(rnn_params,x_t,h_tminus1)
        #Predict the output and get the loss
        loss,grad_g_params=value_and_grad(loss_fn)(g_params,h_t,y_t)
        #Calculate the gradient wrt to rnn_params using the sensitivities of the rnn and g layers
        rnn_sensitivity_t=sensitivity(rnn_params,h_t,trajectories)
        g_sensitivity_t=jax.jacfwd(loss_fn,argnums=1)(g_params,h_t,y_t)
        grad_rnn_params=jax.tree_multimap(lambda x: jnp.tensordot(g_sensitivity_t,x,axes=1), rnn_sensitivity_t) 
        #Update the gradients
        updates, opt_state_rnn = optimizer_rnn.update(grad_rnn_params, opt_state_rnn)
        rnn_params=optax.apply_updates(rnn_params,updates)
        updates, opt_state_g = optimizer_g.update(grad_g_params, opt_state_g)
        g_params=optax.apply_updates(g_params,updates)
        #Calculate average loss and print if it is print frequency
        losses.append(loss)
        if len(losses)==100:
            print("Loss: ",np.mean(losses))
            losses=[]
        h_tminus1=h_t #Previous state becomes the current state before the next iteration

Awesome! It works, and its super fast too! Let's do some sanity check for the sensitivity

In [None]:
sample_x=X[0:20]

h_t=jnp.array([0.0]*32)
def test_sensitivity(rnn_params,sample_x,h_t):
    for i in range(20):
        h_t=rnn_forward_pure_jit(rnn_params,sample_x[np.array([i])],h_t)
    return h_t

true_sensitivity=jax.jacrev(test_sensitivity)(rnn_params,sample_x,h_tminus1)

hidden_states=[]
for i in range(20):
    hidden_states.append(h_t)
    h_t=rnn_forward_pure_jit(rnn_params,sample_x[np.array([i])],h_t)
print("Hidden state calculated correctly: ",jnp.allclose(h_t,test_sensitivity(rnn_params,sample_x,h_tminus1)))
hidden_states=jnp.stack(hidden_states)
our_sensitivity=sensitivity(rnn_params,None,[sample_x.reshape(-1,1),hidden_states])
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['vanilla_rnn/linear']['w'],our_sensitivity['vanilla_rnn/linear']['w'],atol=0.0001))

Next Step? create a haiku model that can be transformed, it automatically manages state and is differentiable with grad

Notes:
1. Maybe we should use haiku states
2. trajectories is updated in forward pass as a state variable
3. .apply is differentiable, and should use the trajectories to calculate sensitivities internally

In [99]:
import sys
sys.path.append('../')
from jax import lax

In [19]:
from src.models.rnn import MultiplicativeRNN,rnn_transform

In [59]:
def rnn_forward(inputs,last_state):
    rnn=MultiplicativeRNN(10,4,32)
    out,state=rnn(inputs,last_state)
    return out,state

In [60]:
rnn_forward_pure=hk.without_apply_rng(hk.transform(rnn_forward))


In [52]:
key=random.PRNGKey(0)
sample_o=jax.random.normal(key,[20,10])
sample_a=jax.random.normal(key,[20,4])

h_t=MultiplicativeRNN.initial_state(10,4,32,50)
rnn_params=rnn_forward_pure.init(key,sample_o[0],sample_a[0],h_t)
def test_sensitivity(rnn_params,sample_o,sample_a,h_t):
    for i in range(10):
        out,h_t=(rnn_forward_pure.apply)(rnn_params,sample_o[i],sample_a[i],h_t)
    return out

true_sensitivity=jax.jacrev(test_sensitivity)(rnn_params,sample_o,sample_a,h_t)
out_true=test_sensitivity(rnn_params,sample_o,sample_a,h_t).block_until_ready()
h_t=MultiplicativeRNN.initial_state(10,4,32,50)
hidden_states=[]
for i in range(10):
    out,h_t=rnn_forward_pure.apply(rnn_params,sample_o[i],sample_a[i],h_t)

print("Hidden state calculated correctly: ",jnp.allclose(out,out_true,atol=0.01))
our_sensitivity=MultiplicativeRNN.sensitivity(rnn_forward_pure.apply,rnn_params,h_t)
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['multiplicative_rnn']['w_h'],
                                                        our_sensitivity['multiplicative_rnn']['w_h'],atol=0.0001))

TypeError: rnn_forward() takes 2 positional arguments but 3 were given

In [71]:
transformed_f=custom_jvp(rnn_forward_pure.apply)

In [72]:
forward_fn=hk.without_apply_rng(hk.transform(rnn_forward)).apply


In [258]:
from jax import custom_jvp,jvp
import jax.numpy as jnp

# f :: a -> b
def jvp_through_time(rnn_forward,rnn_params,rnn_state,tangents):
        """Returns the jacobian of the current hidden state with respect to the RNN params 
            until the length of the trajectory. 

            In future this will be replaced with with JAX grad primitives
        Args:
            rnn_forward_pure ([type]): Pure haiku RNN transformed function 
            rnn_params ([type]): RNN parameters for the transformed function
            rnn_trajectory (tuple(jnp.array,jnp.array)): Tuple of (inputs,last hidden_state)

        Returns:
            jax.numpy.array: Tensor containing the sensitivities as a Jacobian matrix
        """
        hidden_state,trajectory=rnn_state
        last_hidden_states=trajectory.last_hidden_states
        observations=trajectory.observations
        last_actions=trajectory.last_actions
        rnn_params_t,inputs_t,last_state_t=tangents
        rnn_forward_out_tangents=(rnn_params_t,inputs_t,last_state_t[0])
        def rnn_forward_out(params,inputs,last_hidden_state):
            out,_=rnn_forward(params,inputs,(last_hidden_state,None))
            return out


        rnn_jac_hidden=jax.jacrev(rnn_forward_out,argnums=2)
        _,jvp_h_tminus1_theta=jvp(rnn_forward_out,(rnn_params,(observations[0],last_actions[0]),last_hidden_states[0]),
                                          rnn_forward_out_tangents) #Ignore the output for now
        jvp_h_t_theta=jvp_h_tminus1_theta
        def sensitivity_calc(carry_state,trajectory):
            jvp_h_tminus1_theta,rnn_forward_out_tangents=carry_state
            o_t,a_tminus1,h_tminus1=trajectory
            _,jvp_f_theta=jvp(rnn_forward_out,(rnn_params,(o_t,a_tminus1),h_tminus1),
                                          rnn_forward_out_tangents)
            del_f_h_tminus1=rnn_jac_hidden(rnn_params,(o_t,a_tminus1),h_tminus1)
            jvp_h_t_theta=jnp.tensordot(del_f_h_tminus1,jvp_h_tminus1_theta,axes=1)+jvp_f_theta
            return (jvp_h_t_theta,rnn_forward_out_tangents),None

        def scan_all_prev():
            scan_all_prev,_=jax.lax.scan(sensitivity_calc,(jvp_h_tminus1_theta,rnn_forward_out_tangents),(observations[1:],last_actions[1:],
                                                                                last_hidden_states[1:]))
            return scan_all_prev[0]
        return scan_all_prev()

def f_jvp(primals, tangents):
        rnn_params,inputs,last_state=primals
        rnn_params_t,inputs_t,last_state_t=tangents
        primal_out=transformed_f(rnn_params,inputs,last_state)
        tangent_out_1=jvp_through_time(forward_fn,rnn_params,primal_out[1],tangents)
        return primal_out,(tangent_out_1,primal_out[1]) #JVP for last_state is NotImplemented hence returns primal output

transformed_f.defjvp(f_jvp)

In [259]:
def rnn_forward_out(params,inputs,last_hidden_state):
            out,_=forward_fn(params,inputs,(last_hidden_state,None))
            return out

In [261]:
jit_fn=jax.jacrev(transformed_f)

In [279]:
key=random.PRNGKey(0)
sample_o=jax.random.normal(key,[20,10])
sample_a=jax.random.normal(key,[20,4])
h_t=MultiplicativeRNN.initial_state(10,4,32,4)
rnn_params=rnn_forward_pure.init(key,(sample_o[0],sample_a[0]),h_t)
def test_sensitivity(rnn_params,sample_o,sample_a,h_t):
    for i in range(10):
        out,h_t=(rnn_forward_pure.apply)(rnn_params,(sample_o[i],sample_a[i]),h_t)
    return out

true_sensitivity=jax.jacfwd(test_sensitivity)(rnn_params,sample_o,sample_a,h_t)
out_true=test_sensitivity(rnn_params,sample_o,sample_a,h_t).block_until_ready()

h_t=MultiplicativeRNN.initial_state(10,4,32,4)
rnn_params=rnn_forward_pure.init(key,(sample_o[0],sample_a[0]),h_t)
for i in range(10):
    h_tminus1=h_t
    out,h_t=(transformed_f)(rnn_params,(sample_o[i],sample_a[i]),h_t)
our_sensitivity=jit_fn(rnn_params,(sample_o[9],sample_a[9]),h_tminus1)[0]
out=test_sensitivity(rnn_params,sample_o,sample_a,h_t).block_until_ready()

print("Hidden state calculated correctly: ",jnp.allclose(out,out_true,atol=0.01))
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['multiplicative_rnn']['w_h'],
                                                        our_sensitivity['multiplicative_rnn']['w_h'],atol=0.0001))

Hidden state calculated correctly:  True
Sensitivity calculated correctly:  True


In [280]:
our_sensitivity['multiplicative_rnn']['w_h']

DeviceArray([[[[ 5.71306013e-02, -3.69638503e-02, -1.84737682e-01,
                -5.97500168e-02],
               [ 1.13539644e-01, -6.99471012e-02, -3.35696071e-01,
                -1.11328594e-01],
               [ 7.35889599e-02, -4.62183021e-02, -2.25740716e-01,
                -7.40753785e-02],
               ...,
               [ 7.82930329e-02, -4.95377854e-02, -2.43257105e-01,
                -7.95406923e-02],
               [ 6.72445372e-02, -4.37715091e-02, -2.19151676e-01,
                -7.07369745e-02],
               [ 5.51667400e-02, -3.62026580e-02, -1.82693779e-01,
                -5.87135665e-02]],

              [[-1.18411472e-03,  4.35174501e-04,  1.59353600e-03,
                 7.07001716e-04],
               [-7.17509480e-04,  2.81363784e-04,  7.05261133e-04,
                 3.74341704e-04],
               [-9.01502208e-04,  3.07681039e-04,  1.08134688e-03,
                 5.04102500e-04],
               ...,
               [-1.23758125e-03,  4.17796255e-04,

In [189]:
true_sensitivity['multiplicative_rnn']['w_h']

DeviceArray([[[[ 5.71416616e-02, -3.69503610e-02, -1.84733570e-01,
                -5.97531162e-02],
               [ 1.13548018e-01, -6.99368417e-02, -3.35692942e-01,
                -1.11330934e-01],
               [ 7.35982060e-02, -4.62069884e-02, -2.25737289e-01,
                -7.40779638e-02],
               ...,
               [ 7.83026516e-02, -4.95260283e-02, -2.43253514e-01,
                -7.95433968e-02],
               [ 6.72517940e-02, -4.37625907e-02, -2.19149023e-01,
                -7.07390010e-02],
               [ 5.51725850e-02, -3.61954533e-02, -1.82691619e-01,
                -5.87151907e-02]],

              [[-1.18910312e-03,  4.29016683e-04,  1.59164739e-03,
                 7.07926345e-04],
               [-7.21266435e-04,  2.76689127e-04,  7.03833590e-04,
                 3.74950789e-04],
               [-9.05676920e-04,  3.02516506e-04,  1.07976596e-03,
                 5.04875497e-04],
               ...,
               [-1.24187779e-03,  4.12450405e-04,

In [None]:
jit_fn=jit(jax.jacrev(trf.apply))

In [222]:
import time

In [223]:
start=time.time()
our_sensitivity=MultiplicativeRNN.sensitivity(rnn_forward_pure.apply,rnn_params,h_t)
print(time.time()-start)
start=time.time()
our_sensitivity=jit_fn(rnn_params,(sample_o,sample_a),h_t)
print(time.time()-start)

0.18496108055114746


TypeError: dot_general requires contracting dimensions to have the same shape, got [4] and [20].

In [8]:
trf=rnn_transform(MultiplicativeRNN,10,4,32)

In [9]:
key=random.PRNGKey(0)
sample_o=jax.random.normal(key,[20,10])
sample_a=jax.random.normal(key,[20,4])
h_t=MultiplicativeRNN.initial_state(10,4,32,50)
rnn_params=rnn_forward_pure.init(key,sample_o[0],sample_a[0],h_t)
def test_sensitivity(rnn_params,sample_o,sample_a,h_t):
    for i in range(10):
        out,h_t=(rnn_forward_pure.apply)(rnn_params,sample_o[i],sample_a[i],h_t)
    return out

true_sensitivity=jax.jacrev(test_sensitivity)(rnn_params,sample_o,sample_a,h_t)
out_true=test_sensitivity(rnn_params,sample_o,sample_a,h_t).block_until_ready()

h_t=MultiplicativeRNN.initial_state(10,4,32,50)
rnn_params=rnn_forward_pure.init(key,sample_o[0],sample_a[0],h_t)
h_tminus1=None
for i in range(10):
    h_tminus1=h_t
    out,h_t=(trf.apply)(rnn_params,(sample_o[i],sample_a[i]),h_t)
our_sensitivity=jax.jacrev(trf.apply)(rnn_params,(sample_o[9],sample_a[9]),h_tminus1)[0]
out=test_sensitivity(rnn_params,sample_o,sample_a,h_t).block_until_ready()

print("Hidden state calculated correctly: ",jnp.allclose(out,out_true,atol=0.01))
print("Sensitivity calculated correctly: ",jnp.allclose(true_sensitivity['multiplicative_rnn']['w_h'],
                                                        our_sensitivity['multiplicative_rnn']['w_h'],atol=0.0001))



Hidden state calculated correctly:  True
Sensitivity calculated correctly:  True


In [None]:
our_sensitivity['multiplicative_rnn']['w_h']

GOOD JOB!