In [1]:
import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax.experimental import optimizers
import jax.nn as nn
from functools import partial
from jax import lax
from tqdm import tnrange
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import explained_variance_score
import scipy
import time

In [2]:
def RSLDS(N, K, H, C_syn, in_no):
    def init_fun(N, K, H, C_syn, in_no, batch_size):
        C_syn = C_syn
        
        W_zu = onp.random.randn(N, K, K)*0.01
        W_zx = onp.random.randn(N, K, K, H)*0.01
        W_xx = onp.random.randn(N, K, H, H)*0.01
        W_xu = onp.random.randn(N, K, H)*0.01
        W_yx = onp.random.randn(N*H)*0.01
        
        b_z = onp.random.randn(N, K, K)*0.01
        b_x = onp.random.randn(N, K, H)*0.01
        b_y = onp.random.randn(1)*0.01
        
        U_scale = onp.random.randn(in_no)*0.01+1
        Z_init = onp.random.randn(batch_size, N, K)*0.01
        X_init = onp.random.randn(batch_size, N, H)*0.01
        
        return (W_zx, W_zu, b_z,
                W_xx, W_xu, b_x,
                W_yx, b_y, U_scale, Z_init, X_init)
    
    def apply_fun(params, inputs, temp, batch_size):
        Z_init = nn.softmax(params[-2] / temp, -1) #(B,N,K)
        X_init = params[-1] #(B,N,H)
        
        def apply_fun_scan(params, Z, X, U_raw, temp):
            W_zx, W_zu, b_z,
            W_xx, W_xu, b_x,
            W_yx, b_y, U_scale, _, __ = params
            
            U_scaled = U_raw * U_scale.reshape(1,-1)
            U = np.matmul(U_scaled, C_syn.T) # (B, N)
            batch = U.shape[0]
            
            W_zu_part = np.sum(np.expand_dims(W_zu, 0) * np.expand_dims(Z, -1), -2) #(B,N,K)
            W_zx_part = np.sum(np.expand_dims(W_zx, 0) * np.expand_dims(Z,[-1,-2]), -3) #(B,N,K,H)
            b_z_part = np.sum(np.expand_dims(b_z, 0) * np.expand(Z, -1), -2) #(B,N,K)
            
            Z_new_raw = np.matmul(W_zx_part, np.expand_dims(X, -1)).squeeze(-1) \
                + W_zu_part * np.expand_dims(U, -1) + b_z_part
            Z_new = nn.softmax(Z_new_raw / temp, -1) #(B,N,K)
            
            W_xu_part = np.sum(np.expand_dims(W_xu, 0) * np.expand_dims(Z_new, -1), -2) #(B,N,H)
            W_xx_part = np.sum(np.expand_dims(W_xx, 0) * np.expand_dims(Z_new,[-1,-2]), -3) #(B,N,H,H)
            b_x_part = np.sum(np.expand_dims(b_x, 0) * np.expand(Z_new, -1), -2) #(B,N,H)
            
            X_new = np.matmul(W_xx_part, np.expand_dims(X, -1)).squeeze(-1) \
                + W_xu_part * np.expand_dims(U, -1) + b_x_part #(B,N,H)
            
            Y = np.sum(X_new.reshape(batch,-1) * W_yx.reshape(1,-1), -1) + b_y
            
            return Z_new, X_new, Y, Z_new, X_new
            
        # Change input from (B,T,I) to (T,B,I)
        inputs = np.moveaxis(inputs, 1, 0)
        f = partial(apply_fun_scan, params)
        _, __, Y, Z_new, X_new = lax.scan(f, Z_init, X_init, inputs, temp)
        return Y, Z_new, X_new
    return init_fun, apply_fun

In [3]:
base_dir = "/media/hdd01/sklee/"
experiment = "clust4-60"
cell_type = "CA1"
E_neural_file = "Espikes_neural.npz"
V_file = "V_diff.npy"
eloc_file = "Elocs_T10_Ne2000_gA0.6_tauA1_gN0.8_Ni200_gG0.1_gB0.1_Er0.5_Ir7.4_random_NR_rep1000_stimseed1.npy"

E_neural = scipy.sparse.load_npz(base_dir+cell_type+"_"+experiment+"/data/"+E_neural_file)
V = onp.load(base_dir+cell_type+"_"+experiment+"/data/"+V_file)
eloc = onp.load(base_dir+cell_type+"_"+experiment+"/data/"+eloc_file)

den_idx = onp.unique(eloc[880:1120,0])
e_idx = onp.where(np.isin(eloc[:,0], den_idx) == True)[0]

In [4]:
T_train = 999 * 1000 * 50
T_test = 1 * 1000 * 50
H = 3 # H
N = 4 # N
K = 3 # K
in_no = 299

batch_length = 50000
batch_size = 9
iter_no = 9990
epoch_no = iter_no*batch_length*batch_size//T_train

In [5]:
C_syn = onp.zeros((N, in_no))
for i in range(in_no):
    idx = e_idx[i]
    if eloc[idx,0] == den_idx[0]:
        C_syn[0,i] = 1.
    elif eloc[idx,0] == den_idx[1]:
        C_syn[1,i] = 1.
    elif eloc[idx,0] == den_idx[2]:
        C_syn[2,i] = 1.
    elif eloc[idx,0] == den_idx[3]:
        C_syn[3,i] = 1.

In [6]:
V_train = V[:T_train]
V_test = V[T_train:T_train + T_test]
test_E_neural = E_neural[T_train:T_train+T_test].toarray()
train_E_neural = E_neural[:T_train]

train_idx = onp.empty((epoch_no, T_train//batch_length//batch_size))
for i in range(epoch_no):
    part_idx = onp.arange(0, T_train, batch_length*batch_size)
    onp.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()

In [7]:
init_fun, rnn = RSLDS(N, K, H, C_syn, in_no)
params = init_fun(N, K, H, C_syn, in_no, batch_size)

def mse_loss(params, inputs, targets, temp, batch):
    preds = rnn(params, inputs, temp, batch)
    return np.mean((preds - targets)**2)

@jit
def update(params, x, y, opt_state, temp, batch):
    loss, grads = value_and_grad(mse_loss)(params, x, y, temp, batch)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, loss


In [8]:

opt_init, opt_update, get_params = optimizers.adam(0.0025)
opt_state = opt_init(params)

In [9]:
temp_list = np.logspace(0,-3,20)
temp_count = 0

for i in tnrange(iter_no):
    if (i%20 == 19) & (temp_count < 19):
        temp_count += 1
    temp = temp_list[temp_count]
    
    batch_idx = int(train_idx[i])
    batch_E_neural = train_E_neural[batch_idx : batch_idx+batch_length*batch_size].toarray().reshape(batch_size, batch_length, -1)
    batch_E_neural = batch_E_neural[:,:,e_idx].astype(float)
    batch_V = V_train[batch_idx : batch_idx+batch_length*batch_size].reshape(batch_size, -1)
    
    params, opt_state, loss = update(params, batch_E_neural, batch_V, opt_state, temp, batch_size)
    print(loss)
    

  for i in tnrange(iter_no):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9990.0), HTML(value='')))




ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[50000,9,299])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
While tracing the function update at <ipython-input-7-1b9c45d4ebd6>:8, this concrete value was not available in Python because it depends on the value of the arguments to update at <ipython-input-7-1b9c45d4ebd6>:8 at flattened positions [11], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)