# Importing Packages

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import pandas as pd
import scipy as sc
import scipy.io as io
from functools import partial

import jax.numpy as jnp
import jax.scipy as jsc
from jax import grad, jit, vmap, pmap, random, lax, value_and_grad, tree_map, devices
import jax.example_libraries.optimizers as jeo

from tqdm import trange
import pickle

In [None]:
# `from jax.config import config` imported the `jax.config` *module*, which newer
# JAX releases removed; the same `config` object is importable from `jax` itself.
from jax import config

# Run all JAX computations in float64 (the Cholesky factorisations below rely on
# only a 1e-6 diagonal jitter).
config.update("jax_enable_x64", True)
# Raise an error as soon as a NaN is produced (debugging aid; adds overhead).
config.update("jax_debug_nans", True)

In [None]:
# num_mach: number of devices the filters are sharded over with `pmap` (the
# leading axis of every sharded array). attemptno: run index used in the name of
# the saved-results pickle.
num_mach, attemptno = 10, 1

# Prior Functions

In [None]:
@jit
def AlpEnvelope(Xarr, TRise, TauDiff, Lag):
  """Peak-normalised difference-of-exponentials envelope, delayed by ``Lag``.

  Rise time constant is ``TRise ** 2``; decay time constant is
  ``TRise ** 2 + TauDiff ** 2 + 1e-8``, so decay always exceeds rise. The curve
  ``exp(-t / tau_decay) - exp(-t / tau_rise)`` is evaluated at ``Xarr - Lag``
  and divided by its analytic maximum, so the peak value is 1. Times before
  ``Lag`` are replaced by t = 1000, where the curve is ~0 for the small time
  constants used here (presumably chosen so no exp of a large positive
  argument is ever evaluated).
  """
  tau_rise = TRise ** 2.0
  tau_decay = tau_rise + TauDiff ** 2 + 1e-8  # 1e-8 keeps tau_rise != tau_decay

  # Location and height of the curve's maximum.
  t_peak = (jnp.log(tau_rise / tau_decay) * tau_decay * tau_rise) / (tau_rise - tau_decay)
  peak_value = jnp.exp(- t_peak / tau_decay) - jnp.exp(- t_peak / tau_rise)

  shifted = Xarr - Lag
  shifted = jnp.where(shifted < 0.0, 1000, shifted)
  return (jnp.exp(- shifted / tau_decay) - jnp.exp(- shifted / tau_rise)) / peak_value

In [None]:
@jit
def Squared_exp(I, J, Sigma_f, Ell):
  """Squared-exponential covariance ``Sigma_f**2 * exp(-(I - J)**2 / (2 * Ell**2))``.

  Evaluated elementwise with broadcasting, so a column ``I`` and a row ``J``
  produce a full covariance matrix.
  """
  sq_dist = (I - J) ** 2
  return Sigma_f ** 2.0 * jnp.exp(-sq_dist / (2.0 * Ell ** 2))

# Obtaining Training Data

## Importing Spike Trains and Finger Movement

In [None]:
# Load the full data set; its first index level selects a data set (see `.loc[i]` below).
alldata = pd.read_hdf('./Data/Allfinger_veldata.h5') # Import DataFrame

In [None]:
# Inspect data set 0 (cross-section on the first index level).
alldata.xs(0)

In [None]:
# Index labels of the data sets used for training (saved as "Training Sets").
trainind = (0, 1, 2)

In [None]:
alldata.loc[0].to_numpy().shape

In [None]:
# One DataFrame per training set, in the order given by `trainind`.
data = [alldata.loc[i] for i in trainind]

In [None]:
data[2]

In [None]:
# Number of rows (time bins) in each training set; num_tbin below must not exceed the minimum.
datlen = [len(i.index) for i in data]
print(min(datlen))

In [None]:
# Number of time bins used from each training set (num_tbin) and number of spike
# channels / filters (num_filt). num_tbin cannot exceed the shortest training
# set, which is now checked explicitly instead of only stated.
num_tbin = 159913
assert num_tbin <= min(datlen), (
    f"num_tbin={num_tbin} exceeds the shortest training set ({min(datlen)} bins)")
n = num_tbin - 1  # index of the last output time bin
num_filt = 130
# Split the filters evenly across devices, rounding num_filt down to a multiple of num_mach.
batch_size = num_filt // num_mach
num_filt = num_mach * batch_size
print(num_filt)

In [None]:
# All to correspond to dims (nummach, numsam, numfilt, k + 1, 1)

In [None]:
spikedat = [data[i].spikes.to_numpy()[:num_tbin, :num_filt].T[:, :, None] for i in trainind]

In [None]:
print(spikedat[2].shape)

In [None]:
ytime = [(data[i].index / np.timedelta64(1, 's')).to_numpy()[:num_tbin] for i in trainind] # Get spikes/output time array

In [None]:
# Get x velocities
xraw = [data[i].finger_vel.x.to_numpy()[:num_tbin].reshape(n + 1, 1) for i in trainind]

In [None]:
xraw[0].shape

## Standardising Data

In [None]:
# Set variance to 1.0
xvel = [xraw[i] / np.std(xraw[i]) for i in trainind]

In [None]:
del data # Clear data from memory
del alldata

In [None]:
for i in trainind:
  plt.plot(ytime[i], xvel[i])

## Setting up Filter Sizes

In [None]:
# Time bin size
time_bin = ytime[0][1] - ytime[0][0]
print(time_bin)

f_maxt = 1.0 # Filter horizon

k = np.floor(f_maxt/time_bin).astype(np.int16) # Maximum index of filter data

ftime = np.linspace(0.0, f_maxt, k + 1).reshape((k + 1, 1)) # Filter corresponding time array

## Getting FFT of Spike Data

In [None]:
# FFT of spike train
spikepad = [np.hstack((spikedat[i], np.zeros((num_filt, k, 1)))) for i in trainind]
spikefft = [np.fft.rfft(spikepad[i], axis=1) for i in trainind]
fftlen = [np.shape(spikefft[i])[1] for i in trainind]
spikefft = [spikefft[i].reshape(num_mach, batch_size, fftlen[i], 1) for i in trainind]
spikefft = [[spikefft[j][i] for j in range(len(trainind))] for i in range(num_mach)]
spikefft = jnp.asarray(spikefft)
xvel = jnp.asarray(xvel)

In [None]:
print(spikefft.shape)
print(xvel.shape)

# ELBO 

In [None]:
# Batched diagonal-matrix builder: (batch, m) vectors -> (batch, m, m) matrices.
Diag = vmap(lambda vec: jnp.diag(vec))

In [None]:
@jit
def Solver(Kmm, Diff):
  """Solve ``Kmm @ x = Diff`` for ``x``, with ``Kmm`` symmetric positive definite.

  ``assume_a="pos"`` declares ``Kmm`` positive definite. It replaces the old
  ``sym_pos=True`` flag, which was deprecated and then removed from
  ``jax.scipy.linalg.solve``. Leading batch dimensions are supported.
  """
  return jsc.linalg.solve(Kmm, Diff, assume_a="pos", check_finite=True)

In [None]:
# Solve against a leading batch of right-hand sides (one per Monte Carlo sample
# in Neg_ELBO), sharing the same Kmm for all of them.
V_Solver = vmap(Solver, in_axes=(None, 0))

In [None]:
@jit
def Likelihoods(Predictions, Velocity, Sigma_n, N):
  """Gaussian log-likelihood of ``Velocity`` given ``Predictions``.

  ``Sigma_n`` is the noise standard deviation. Squared errors are summed over
  axis -2 (time), which is expected to hold ``N + 1`` entries.
  """
  noise_var = Sigma_n ** 2
  log_normaliser = (N + 1) * jnp.log(2 * jnp.pi * noise_var)
  sq_error = jnp.sum((Velocity - Predictions) ** 2, axis=-2)
  return -0.5 * (log_normaliser + sq_error / noise_var)

In [None]:
# Map Likelihoods over the leading axis of predictions and velocities
# (presumably the training-set axis); noise std and N are shared.
Likely = vmap(Likelihoods, in_axes=(0, 0, None, None))

In [None]:
@partial(jit, static_argnums=(2, 3))
def Irfft(Fft, Array, K, N):
  """Multiply two spectra, invert the real FFT and keep the first ``N + 1`` bins.

  The inverse transform has length ``N + K + 1`` along axis -2, so with
  zero-padded inputs the product is a linear convolution. The result is summed
  over axis 1 (the filter axis in this notebook).
  """
  length = N + K + 1
  time_domain = jnp.fft.irfft(Fft * Array, length, axis=-2)
  return time_domain[:, :, :N + 1].sum(axis=1)

In [None]:
# Apply Irfft once per entry of Array's leading axis, reusing the same filter spectra.
VIrfft = vmap(Irfft, in_axes=(None, 0, None, None))

In [None]:
@partial(jit, static_argnums = range(10, 17))
def Neg_ELBO(Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Number_F, Num_Base, M, Num_Filt, Batch_Size, K, N,
              Indices, Spike_Fft, Velocity, F_Time, Subkeys):
  """Monte Carlo estimate of the (scaled) negative ELBO for this device's filters.

  Written to run under ``pmap(..., axis_name="machs")``: filter outputs and KL
  terms are summed over devices with ``lax.psum``.

  Each filter is a sparse variational GP with ``M`` inducing points in the
  whitened parameterisation ``u = C v`` (``C = chol(Kmm)``), multiplied by an
  ``AlpEnvelope`` window. ``Number_F`` filter samples are drawn by pathwise
  conditioning: a random-Fourier-feature prior sample (``Num_Base`` features)
  plus ``Knm Kmm^{-1} (u - prior(Z))``. The sampled filters are convolved with
  the spike trains via FFT, summed over filters and devices, and scored
  against ``Velocity`` with a Gaussian likelihood.

  Args:
    Sigma_f, Ell: squared-exponential kernel amplitude and lengthscale.
    Sigma_n: observation noise standard deviation.
    Z_Fractions: unconstrained inducing-point locations, mapped below into
      [Lag**2, F_Time[-1, 0]].
    V_Vector: whitened variational mean.
    L_Diag, L_ODiag: log-diagonal and strictly-lower entries of the Cholesky
      factor of the whitened variational covariance.
    TRise, TauDiff, Lag: envelope parameters (Lag is squared before use).
    Number_F, Num_Base, M, Num_Filt, Batch_Size, K, N: static sizes - filter
      samples, Fourier features, inducing points, total filters, filters on
      this device, last filter index, last output index.
    Indices: row (Indices[0]) and column (Indices[1]) positions filled from L_ODiag.
    Spike_Fft: rFFT of zero-padded spike trains; leading axis = training set.
    Velocity: target outputs; leading axis = training set.
    F_Time: filter time grid (the caller passes shape (K + 1, 1)).
    Subkeys: PRNG keys; the first four are used.

  Returns:
    ``(KL - mean over samples of log-likelihood) / (Num_Filt * N + 1)``.
  """

  # Lag is optimised as a square root; square it, then map the inducing-point
  # fractions through a sine into [Lag, F_Time[-1, 0]].
  Lag = Lag ** 2.0
  
  Z_Vector = (0.5 * jnp.sin( jnp.pi * (Z_Fractions - 0.5)) + 0.5) * (F_Time[-1, 0] - Lag) + Lag

  # Variational Cholesky factor: exp(L_Diag) on the diagonal, L_ODiag below it.
  D = Diag(jnp.exp(L_Diag))

  L_Matrix = D.at[:, Indices[0], Indices[1]].set(L_ODiag)

  # KL(q(v) || N(0, I)) summed over this device's filters:
  # 0.5 * (-log|L L^T| + tr(L L^T) + v^T v - M) per filter.
  KL = 0.5 * (- jnp.sum(jnp.log(jnp.diagonal(L_Matrix, axis1 = 1, axis2 = 2) ** 2)) + \
                jnp.sum(L_Matrix ** 2) + jnp.sum(V_Vector ** 2) - Batch_Size * M)

  # Expectation term
  # Random Fourier features for the SE prior: frequencies ~ N(0, 1/Ell^2),
  # phases ~ U(0, 2*pi), feature weights Omegas ~ N(0, 1).
  Thetas = random.normal(Subkeys[0], (Number_F, Batch_Size, 1, Num_Base)) * (1.0 / Ell)

  Taus = random.uniform(Subkeys[1], (Number_F, Batch_Size, 1, Num_Base)) * 2.0 * jnp.pi

  Omegas = random.normal(Subkeys[2], (Number_F, Batch_Size, Num_Base, 1))

  Constant = (Sigma_f * jnp.sqrt(2.0 / Num_Base))

  ZT = Z_Vector.transpose(0, 2, 1)
 
  # Prior feature maps on the filter grid (Phi1) and at the inducing points (Phi2).
  Phi1 = Constant * jnp.cos(F_Time * Thetas + Taus)
  Phi2 = Constant * jnp.cos(Z_Vector * Thetas + Taus)

  Kmm = Squared_exp(Z_Vector, ZT, Sigma_f, Ell)
  Knm = Squared_exp(F_Time, ZT, Sigma_f, Ell)

  # C = chol(Kmm) with a 1e-6 diagonal jitter.
  C = jnp.linalg.cholesky(Kmm + jnp.eye(M) * 1e-6)

  # Covariance and mean of u = C v under q.
  V_u = C @ L_Matrix @ L_Matrix.transpose(0, 2, 1) @ C.transpose(0, 2, 1)

  Mu_u = C @ V_Vector

  V_uChol = jnp.linalg.cholesky(V_u + 1e-6 * jnp.eye(M))

  # Number_F samples of u ~ q(u).
  U_Samples = Mu_u + V_uChol @ random.normal(Subkeys[3], (Number_F, Batch_Size, M, 1))

  # Kmm^{-1} (u - prior sample at Z): weights of the pathwise-conditioning update.
  # (Not to be confused with V_u above.)
  Vu = V_Solver(Kmm + 1e-6 * jnp.eye(M), U_Samples - Phi2 @ Omegas)

  # Posterior filter samples on F_Time, windowed by the envelope.
  F_Samples = (Phi1 @ Omegas + Knm @ Vu) * AlpEnvelope(F_Time, TRise, TauDiff, Lag)

  F_Fft = jnp.fft.rfft(F_Samples, n = N + K + 1, axis = -2)

  # Convolve with each training set's spikes and sum over this device's filters ...
  Filter_Out = VIrfft(F_Fft, Spike_Fft, K, N)

  # ... then sum the contributions of all devices.
  Filter = lax.psum(Filter_Out, axis_name="machs")

  # Log-likelihood summed over training sets: one value per filter sample.
  Likelihood = Likely(Filter, Velocity, Sigma_n, N).sum(axis=0)

  KL = lax.psum(KL, axis_name="machs")
  Exp = jnp.mean(Likelihood)  # Monte Carlo average over the Number_F samples
                  
  # NOTE(review): the normaliser is Num_Filt * N + 1, not e.g. Num_Filt * (N + 1);
  # confirm this is intended (it only rescales the objective).
  return (KL-Exp)/(Num_Filt * N + 1)

In [None]:
# Device-parallel Neg_ELBO. in_axes: the 10 trainable parameters are sharded
# except Sigma_n (broadcast); the 8 size/index arguments are broadcast; then
# Spike_Fft sharded, Velocity and F_Time broadcast, Subkeys sharded.
PNeg = pmap(
    Neg_ELBO,
    axis_name="machs",
    in_axes=(0, 0, None) + (0,) * 7 + (None,) * 8 + (0, None, None, 0),
    static_broadcasted_argnums=range(10, 17),
)

In [None]:
# Neg_ELBO value plus its gradients with respect to the 10 trainable arguments (positions 0-9).
Grad_Bound = value_and_grad(Neg_ELBO, argnums=tuple(range(10)))

# Training the Model

In [None]:
@jit
def MCalc(Grad, M, B1 = 0.9):
  """Adam first-moment update: exponential moving average of the gradient."""
  new_part = (1 - B1) * Grad
  return new_part + B1 * M

In [None]:
@jit
def MBias(M, Step, B1 = 0.9):
  """Bias-correct Adam's first moment; ``Step`` counts from zero."""
  correction = 1 - B1 ** (Step + 1)
  return M / correction

In [None]:
@jit
def VCalc(Grad, V, B2 = 0.99):
  """Adam second-moment update: exponential moving average of the squared gradient."""
  grad_sq = jnp.square(Grad)
  return (1 - B2) * grad_sq + B2 * V

In [None]:
@jit
def VBias(V, Step, B2 = 0.99):
  """Bias-correct Adam's second moment; ``Step`` counts from zero."""
  correction = 1 - B2 ** (Step + 1)
  return V / correction

In [None]:
@jit
def CFinState(X, Mhat, Vhat, Step_Size = 1e-2, Eps = 1e-8):
  """Adam parameter step with the coarse-phase learning rate (default 1e-2)."""
  scale = jnp.sqrt(Vhat) + Eps
  return X - Step_Size * Mhat / scale

In [None]:
@jit
def MFinState(X, Mhat, Vhat, Step_Size = 1e-3, Eps = 1e-8):
  """Adam parameter step with the mid-phase learning rate (default 1e-3)."""
  scale = jnp.sqrt(Vhat) + Eps
  return X - Step_Size * Mhat / scale

In [None]:
@jit
def FFinState(X, Mhat, Vhat, Step_Size = 1e-4, Eps = 1e-8):
  """Adam parameter step with the fine-phase learning rate (default 1e-4)."""
  scale = jnp.sqrt(Vhat) + Eps
  return X - Step_Size * Mhat / scale

In [None]:
@jit
def CAdam(Step, X, Grad, M, V):
  """One Adam update of every entry of ``X`` using the coarse step size (CFinState).

  Args:
    Step: zero-based iteration count used for bias correction.
    X, Grad, M, V: tuples of arrays with one entry per parameter - values,
      gradients, and first / second moment estimates.

  Returns:
    Updated ``(X, M, V)`` tuples.
  """
  M = tree_map(MCalc, Grad, M) # First  moment estimate.
  V = tree_map(VCalc, Grad, V)  # Second moment estimate.
  # tree_map needs Step as a tuple with one entry per parameter; size it from X
  # instead of hard-coding 10, so any number of parameters works.
  Step = tuple(Step * jnp.ones(len(X)))
  Mhat = tree_map(MBias, M, Step) # Bias correction.
  Vhat = tree_map(VBias, V, Step) # Bias correction.

  X = tree_map(CFinState, X, Mhat, Vhat)

  return X, M, V

In [None]:
@jit
def MAdam(Step, X, Grad, M, V):
  """One Adam update of every entry of ``X`` using the mid step size (MFinState).

  Args:
    Step: zero-based iteration count used for bias correction.
    X, Grad, M, V: tuples of arrays with one entry per parameter - values,
      gradients, and first / second moment estimates.

  Returns:
    Updated ``(X, M, V)`` tuples.
  """
  M = tree_map(MCalc, Grad, M) # First  moment estimate.
  V = tree_map(VCalc, Grad, V)  # Second moment estimate.
  # tree_map needs Step as a tuple with one entry per parameter; size it from X
  # instead of hard-coding 10, so any number of parameters works.
  Step = tuple(Step * jnp.ones(len(X)))
  Mhat = tree_map(MBias, M, Step) # Bias correction.
  Vhat = tree_map(VBias, V, Step) # Bias correction.

  X = tree_map(MFinState, X, Mhat, Vhat)

  return X, M, V

In [None]:
@jit
def FAdam(Step, X, Grad, M, V):
  """One Adam update of every entry of ``X`` using the fine step size (FFinState).

  Args:
    Step: zero-based iteration count used for bias correction.
    X, Grad, M, V: tuples of arrays with one entry per parameter - values,
      gradients, and first / second moment estimates.

  Returns:
    Updated ``(X, M, V)`` tuples.
  """
  M = tree_map(MCalc, Grad, M) # First  moment estimate.
  V = tree_map(VCalc, Grad, V)  # Second moment estimate.
  # tree_map needs Step as a tuple with one entry per parameter; size it from X
  # instead of hard-coding 10, so any number of parameters works.
  Step = tuple(Step * jnp.ones(len(X)))
  Mhat = tree_map(MBias, M, Step) # Bias correction.
  Vhat = tree_map(VBias, V, Step) # Bias correction.

  X = tree_map(FFinState, X, Mhat, Vhat)

  return X, M, V

In [None]:
# in_axes: Iter and Sigma_n are broadcast; the other 9 trainable parameters are
# sharded; the 8 size/index arguments are broadcast; Spike_Fft is sharded;
# Velocity and F_Time are broadcast; Subkeys, Key, Mad and Vad are sharded.
@partial(pmap, axis_name = "machs", in_axes=(None, 0, 0, None, 0, 0, 0, 0, 0, 0, 0, None, 
                                    None, None, None, None, None, None, None,  
                                    0, None, None, 0, 0, 0, 0,),
                static_broadcasted_argnums = range(11, 18))

def _CUpdate(Iter, Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Coarse_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
              Indices, Spike_Fft, Velocity, F_Time, Subkeys, Key, Mad, Vad):
  """One coarse-phase training step on every device (Adam via CAdam, step size 1e-2).

  Evaluates Neg_ELBO and its gradients for this device's filters with the given
  Subkeys, applies one Adam update (Iter is the bias-correction step), and
  splits Key into a new Key plus 4 fresh subkeys for the next call.

  Returns:
    The 10 updated parameters, then Subkeys (4, 2), Key, Mad, Vad and the
    Neg_ELBO value. Sigma_n enters broadcast but comes back once per device.
  """
  
  Value, Grads = Grad_Bound(Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Coarse_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
                  Indices, Spike_Fft, Velocity, F_Time, Subkeys)
  
  X, Mad, Vad = CAdam(Iter, (Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag), Grads, Mad, Vad)

  # New carry key plus the 4 subkeys Neg_ELBO consumes on the next step.
  Key, *Subkeys = random.split(Key, 5)
  Subkeys = jnp.asarray(Subkeys).astype(jnp.uint32).reshape(4, 2)

  return *X, Subkeys, Key, Mad, Vad, Value

In [None]:
# in_axes: Iter and Sigma_n are broadcast; the other 9 trainable parameters are
# sharded; the 8 size/index arguments are broadcast; Spike_Fft is sharded;
# Velocity and F_Time are broadcast; Subkeys, Key, Mad and Vad are sharded.
@partial(pmap, axis_name = "machs", in_axes=(None, 0, 0, None, 0, 0, 0, 0, 0, 0, 0, None, 
                                    None, None, None, None, None, None, None, 
                                    0, None, None, 0, 0, 0, 0,),
                static_broadcasted_argnums = range(11, 18))

def _MUpdate(Iter, Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Mid_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
              Indices, Spike_Fft, Velocity, F_Time, Subkeys, Key, Mad, Vad):
  """One mid-phase training step on every device (Adam via MAdam, step size 1e-3).

  Evaluates Neg_ELBO and its gradients for this device's filters with the given
  Subkeys, applies one Adam update (Iter is the bias-correction step), and
  splits Key into a new Key plus 4 fresh subkeys for the next call.

  Returns:
    The 10 updated parameters, then Subkeys (4, 2), Key, Mad, Vad and the
    Neg_ELBO value. Sigma_n enters broadcast but comes back once per device.
  """
  
  Value, Grads = Grad_Bound(Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Mid_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
                  Indices, Spike_Fft, Velocity, F_Time, Subkeys)
  
  X, Mad, Vad = MAdam(Iter, (Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag), Grads, Mad, Vad)

  # New carry key plus the 4 subkeys Neg_ELBO consumes on the next step.
  Key, *Subkeys = random.split(Key, 5)
  Subkeys = jnp.asarray(Subkeys).astype(jnp.uint32).reshape(4, 2)

  return *X, Subkeys, Key, Mad, Vad, Value

In [None]:
# in_axes: Iter and Sigma_n are broadcast; the other 9 trainable parameters are
# sharded; the 8 size/index arguments are broadcast; Spike_Fft is sharded;
# Velocity and F_Time are broadcast; Subkeys, Key, Mad and Vad are sharded.
@partial(pmap, axis_name = "machs", in_axes=(None, 0, 0, None, 0, 0, 0, 0, 0, 0, 0, None, 
                                    None, None, None, None, None, None, None, 
                                    0, None, None, 0, 0, 0, 0,),
                static_broadcasted_argnums = range(11, 18))

def _FUpdate(Iter, Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Fine_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
              Indices, Spike_Fft, Velocity, F_Time, Subkeys, Key, Mad, Vad):
  """One fine-phase training step on every device (Adam via FAdam, step size 1e-4).

  Evaluates Neg_ELBO and its gradients for this device's filters with the given
  Subkeys, applies one Adam update (Iter is the bias-correction step), and
  splits Key into a new Key plus 4 fresh subkeys for the next call.

  Returns:
    The 10 updated parameters, then Subkeys (4, 2), Key, Mad, Vad and the
    Neg_ELBO value. Sigma_n enters broadcast but comes back once per device.
  """
  
  Value, Grads = Grad_Bound(Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag, Num_Fine_Fs, Num_Base, M, Num_Filt, Batch_Size, K, N,
                  Indices, Spike_Fft, Velocity, F_Time, Subkeys)
  
  X, Mad, Vad = FAdam(Iter, (Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag), Grads, Mad, Vad)

  # New carry key plus the 4 subkeys Neg_ELBO consumes on the next step.
  Key, *Subkeys = random.split(Key, 5)
  Subkeys = jnp.asarray(Subkeys).astype(jnp.uint32).reshape(4, 2)

  return *X, Subkeys, Key, Mad, Vad, Value

In [None]:
@partial(pmap, in_axes = (0, 0, None) + (0,) * 7)
def Init_Adam(Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
              TauDiff, Lag):
  """Zero-initialise Adam's first and second moments for the 10 parameters.

  Mapped over devices like the parameters (Sigma_n is broadcast in). Returns
  the same tuple of zero arrays twice, used as ``(Mad, Vad)``.
  """
  params = (Sigma_f, Ell, Sigma_n, Z_Fractions, V_Vector, L_Diag, L_ODiag, TRise,
            TauDiff, Lag)
  zeros = tree_map(jnp.zeros_like, params)
  return zeros, zeros
  

In [None]:
@jit
def Set(Matrix, Indices, NewVals):
  """Return a copy of the batched ``Matrix`` with entries at
  ``(Indices[0], Indices[1])`` replaced by ``NewVals`` in every batch element."""
  rows, cols = Indices[0], Indices[1]
  return Matrix.at[:, rows, cols].set(NewVals)

In [None]:
# Device-parallel Diag: maps over the leading (device) axis.
PDiag = pmap(Diag, in_axes=0)

In [None]:
# Device-parallel Set: matrices and values are sharded, the index arrays are shared.
PSet = pmap(Set, in_axes=(0, None, 0))

In [None]:
# Undo the device-first layout: one (num_mach, batch_size, fftlen, 1) array per
# training set. Iterate over spikefft's training-set axis instead of a hard-coded 3.
spikeffta = [np.asarray(spikefft[:, j]) for j in range(spikefft.shape[1])]

In [None]:
# Merge the device and batch axes: (num_filt, fftlen, 1) per training set
# (works for any number of training sets, not just 3).
spikefftf = [a.reshape(num_filt, flen, 1) for a, flen in zip(spikeffta, fftlen)]

In [None]:
@jit
def LowerSolv(MatA, MatB):
  """Solve ``MatA @ X = MatB`` for ``X``, where ``MatA`` is lower triangular."""
  solution = jsc.linalg.solve_triangular(MatA, MatB, trans=0, lower=True)
  return solution

In [None]:
# Batched lower-triangular solve over the leading (filter) axis of both arguments.
VLS = vmap(LowerSolv, in_axes=0)

In [None]:
@partial(jit, static_argnums=(1, 2, 3))
def OUTconvolve(Filters, Num_Filt, K, N, Spike_FFT):
  """Convolve each filter with its spike train via FFT and sum over filters.

  Filters (Num_Filt rows, K + 1 taps) are zero-padded by N bins to length
  N + K + 1, multiplied in frequency space by ``Spike_FFT`` (rFFT of spike
  trains padded to the same length), summed over axis 0 and transformed back.
  Returns the first N + 1 time bins.
  """
  zero_tail = jnp.zeros((Num_Filt, N, 1))
  padded = jnp.hstack((Filters, zero_tail))
  summed_spectrum = jnp.sum(jnp.fft.rfft(padded, axis=-2) * Spike_FFT, axis=0)
  full_signal = jnp.fft.irfft(summed_spectrum, K + N + 1, axis=-2)
  return full_signal[:N + 1]

# X Velocity

In [None]:
key = random.PRNGKey(4)

# Initialising the parameters

# Generative parameters: one value per filter, laid out as (num_mach, batch_size, 1, 1).
# TRise, TauDiff and Lag are stored as square roots: AlpEnvelope squares TRise and
# TauDiff, and Neg_ELBO squares Lag before use.
isigma_f = 1.0 * jnp.ones((num_mach, batch_size, 1, 1))
iell = 0.002 * jnp.ones((num_mach, batch_size, 1, 1))
isigma_n = 0.000000005
itrise = np.sqrt(0.05) * jnp.ones((num_mach, batch_size, 1, 1))
itaudiff = np.sqrt(0.01) * jnp.ones((num_mach, batch_size, 1, 1))
ilag = np.sqrt(0.01) * jnp.ones((num_mach, batch_size, 1, 1))

xsigma_f = isigma_f
xell = iell
xsigma_n = isigma_n
xtrise = itrise
xtaudiff = itaudiff
xlag = ilag

# Variational parameters
num_f = 15    # filter samples per Monte Carlo ELBO estimate
num_b = 100   # random Fourier features per prior sample
num_ind = 20  # inducing points per filter
# Inducing-point fractions start evenly spaced in [0, 1]; store them in the
# arcsin form that Neg_ELBO maps back through a sine.
iz_final = jnp.tile(jnp.linspace(0.0, 1.0, num_ind).reshape((num_ind, 1)), (num_mach, batch_size, 1, 1))
iz_fracs = (jnp.arcsin(2.0 * (iz_final - 0.5)) / jnp.pi) + 0.5
key, *subkeys = random.split(key, 4)
iv_vector = 0.01 * random.normal(subkeys[-3], (num_mach, batch_size, num_ind, 1))
il_diag = 0.01 * random.normal(subkeys[-2], (num_mach, batch_size, num_ind))
# One entry per strictly-lower-triangular element of each num_ind x num_ind factor.
il_odiag = 0.1 * random.normal(subkeys[-1], (num_mach, batch_size, num_ind * (num_ind - 1) // 2))
indices = jnp.asarray(jnp.tril_indices(num_ind, -1))

xz_fracs = iz_fracs
xv_vector = iv_vector
xl_diag = il_diag
xl_odiag = il_odiag

# Random number generator: one key per device, split into a carry key plus the
# 4 subkeys Neg_ELBO uses. This matches the (4, 2) subkeys returned by
# _CUpdate/_MUpdate/_FUpdate. The former 1 + 4 * num_mach split gave each device
# 40 subkeys (36 unused), and the shape changed once the returned subkeys were
# fed back into the update functions.
iopt_key = np.asarray([random.PRNGKey(i) for i in range(num_mach)]).astype(np.uint32)
itest = vmap(random.split, in_axes=(0, None))(iopt_key, 1 + 4)
iopt_key = np.asarray(itest[:, 0]).astype(np.uint32)
iopt_subkey = np.asarray(itest[:, 1:]).astype(np.uint32)

opt_key = iopt_key
opt_subkey = iopt_subkey
print(opt_subkey.shape)

In [None]:
# Zero Adam moments (mad: first, vad: second), laid out per device like the parameters.
mad, vad = Init_Adam(
    isigma_f, iell, isigma_n, iz_fracs, iv_vector,
    il_diag, il_odiag, itrise, itaudiff, ilag,
)

In [None]:
step = 0
xelbo_history = np.zeros(20000)

In [None]:
coarse_steps = 1500
mid_steps = 16500
fine_steps = 2000

In [None]:
extra_coarse_steps = 0
extra_mid_steps = 0
extra_fine_steps = 0

In [None]:
xelbo_history = np.hstack((xelbo_history, np.zeros(sum([extra_coarse_steps, extra_mid_steps, extra_fine_steps]))))

In [None]:
step = 0
for i in trange(coarse_steps):  
  xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, st, kt, mad, vad, xvalue = _CUpdate(i, xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, num_f, num_b, num_ind, num_filt, batch_size, k, n, indices, spikefft, xvel, ftime, opt_subkey, opt_key, mad, vad)
  xsigma_n = xsigma_n[0]
  xelbo_history[step] = -xvalue[0]
  step += 1

In [None]:
for i in trange(mid_steps): 
  xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, st, kt, mad, vad, xvalue = _MUpdate(i, xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, num_f, num_b, num_ind, num_filt, batch_size, k, n, indices, spikefft, xvel, ftime, opt_subkey, opt_key, mad, vad)
  xsigma_n = xsigma_n[0]
  xelbo_history[step] = -xvalue[0]
  step += 1

In [None]:
for i in trange(fine_steps):  
  xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, st, kt, mad, vad, xvalue = _FUpdate(i, xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag, num_f, num_b, num_ind, num_filt, batch_size, k, n, indices, spikefft, xvel, ftime, opt_subkey, opt_key, mad, vad)
  xsigma_n = xsigma_n[0]
  xelbo_history[step] = -xvalue[0]
  step += 1

In [None]:
# ELBO trace from iteration 10000 onwards.
plt.plot(xelbo_history[10000:])

In [None]:
# Neg_ELBO (one entry per device) at the initial parameters ...
print(PNeg(isigma_f, iell, isigma_n, iz_fracs, iv_vector, il_diag,
                il_odiag, itrise, itaudiff, ilag, num_f, num_b,
                num_ind, num_filt, batch_size, k, n, indices, spikefft, xvel, 
                ftime, iopt_subkey))

In [None]:
# ... and at the trained parameters, using the same initial subkeys so both
# estimates use the same Monte Carlo draws.
print(PNeg(xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag,
                xl_odiag, xtrise, xtaudiff, xlag, num_f, num_b,
                num_ind, num_filt, batch_size, k, n, indices, spikefft, xvel, 
                ftime, iopt_subkey))

In [None]:
# Snapshot the trained parameters under xf* names (plain aliases of the x* arrays).
xfsigma_f, xfell, xfsigma_n, xfz_fracs, xfv_vector, xfl_diag, xfl_odiag, xftrise, xftaudiff, xflag = xsigma_f, xell, xsigma_n, xz_fracs, xv_vector, xl_diag, xl_odiag, xtrise, xtaudiff, xlag

In [None]:
# Everything needed to inspect or resume this run: final/initial parameters (in
# their optimised, square-root form for TRise/TauDiff/Lag), ELBO history,
# schedule, sizes and the Adam moment state.
xoutdict = {"Parameters": (xfsigma_f, xfell, xfsigma_n, xfz_fracs, xfv_vector, xfl_diag, xfl_odiag, xftrise, xftaudiff, xflag),
          "ELBO History": xelbo_history,
          "Training Sets": trainind,
          "Steps":(coarse_steps, mid_steps, fine_steps, extra_coarse_steps, extra_mid_steps, extra_fine_steps),
          "Initial Parameters": (isigma_f, iell, isigma_n, iz_fracs, iv_vector, il_diag, il_odiag, itrise, itaudiff, ilag),
          "Training Parameters": (num_f, num_b, num_ind, num_filt),
          "mad": mad,
          "vad": vad,
          "Velocity Trained": "x"}

In [None]:
# Saved as <n>_<attemptno>xvel.pkl. NOTE: unpickling executes code - only load
# files you created yourself.
with open('./Data/Hyperparameters/' + str(n) + '_' + str(attemptno) + 'xvel' + '.pkl', 'wb') as f:
    pickle.dump(xoutdict, f)

In [None]:
# Lags are optimised as square roots (Neg_ELBO squares them), so square them here.
# NOTE(review): ilag and xflag are rebound in place, so re-running this cell
# squares them a second time - run it exactly once after the cells above.
ilag = ilag ** 2

# Inducing-point locations, mapped from the fractions exactly as in Neg_ELBO.
iz_vector = (0.5 * np.sin( np.pi * (iz_fracs - 0.5)) + 0.5) * (ftime[-1, 0] - ilag) + ilag

# Variational Cholesky factors: exp(log-diagonal) plus strictly-lower entries.
idz = PDiag(np.exp(il_diag))
print(il_diag.shape)
print(idz.shape)
il_matrix = PSet(idz, indices, il_odiag)


xflag = xflag ** 2

# ftime[-1] has shape (1,), so this broadcasts the same way as ftime[-1, 0] above.
xfz_vector = (0.5 * np.sin( np.pi * (xfz_fracs - 0.5)) + 0.5) * (ftime[-1] - xflag) + xflag

xfd = PDiag(np.exp(xfl_diag))
xfl_matrix = PSet(xfd, indices, xfl_odiag)
print(xfd.shape)
print(xfl_matrix.shape)

In [None]:
# Merge the (num_mach, batch_size) device layout into one row per filter.
# ilag / xflag are already squared at this point.
isigma_ff = isigma_f.reshape(num_filt, 1)
iellf = iell.reshape(num_filt, 1)
iz_vectorf = iz_vector.reshape(num_filt, num_ind, 1)
iv_vectorf = iv_vector.reshape(num_filt, num_ind, 1)
il_matrixf = il_matrix.reshape(num_filt, num_ind, num_ind)
itrisef = itrise.reshape(num_filt, 1)
itaudifff = itaudiff.reshape(num_filt, 1)
ilagf = ilag.reshape(num_filt, 1)

xfsigma_ff = xfsigma_f.reshape(num_filt, 1)
xfellf = xfell.reshape(num_filt, 1)
xfz_vectorf = xfz_vector.reshape(num_filt, num_ind, 1)
xfv_vectorf = xfv_vector.reshape(num_filt, num_ind, 1)
xfl_matrixf = xfl_matrix.reshape(num_filt, num_ind, num_ind)
xftrisef = xftrise.reshape(num_filt, 1)
xftaudifff = xftaudiff.reshape(num_filt, 1)
xflagf = xflag.reshape(num_filt, 1)

In [None]:
# Trained kernel amplitude Sigma_f per filter.
print(xfsigma_ff)

In [None]:
# Trained kernel lengthscale per filter.
print(xfellf)

In [None]:
# Trained observation-noise standard deviation.
print(xfsigma_n)

In [None]:
# Trained TRise per filter (square root of the rise time constant; AlpEnvelope squares it).
print(xftrisef)

In [None]:
# Trained TauDiff per filter (AlpEnvelope squares it).
print(xftaudifff)

In [None]:
# Trained lag per filter (already squared).
print(xflagf)

## Variational Predictions

In [None]:
fpred_points = ftime.copy()

In [None]:
scisigmaf = isigma_ff[:, None]
sciellf = iellf[:, None]
scitrisef = itrisef[:, None]
scitaudifff = itaudifff[:, None]
scilagf = ilagf[:, None]

ikmm = Squared_exp(iz_vectorf, iz_vectorf.transpose(0, 2, 1), scisigmaf, sciellf)
icpred = np.linalg.cholesky(ikmm + 1e-6 * np.eye(num_ind))
del ikmm
ikzast = Squared_exp(iz_vectorf, fpred_points.T, scisigmaf, sciellf)
# ibzast = np.zeros((num_filt, num_ind, k + 1))
# for i in range(num_filt):
#   ibzast[i] = sc.linalg.solve_triangular(icpred[i], ikzast[i], lower = True)
ibzast = VLS(icpred, ikzast)
del ikzast
del icpred
ibzastT = ibzast.transpose(0, 2, 1)
ipredenv = AlpEnvelope(fpred_points, scitrisef, scitaudifff, scilagf)
ipredenvT = ipredenv.transpose(0, 2, 1)
imeanpred = ipredenv * (ibzastT @ iv_vectorf)
ikastast = Squared_exp(fpred_points, fpred_points.T, scisigmaf, sciellf)
icovpred = ipredenv * (ikastast + ibzastT @ (il_matrixf @ il_matrixf.transpose(0, 2, 1) - np.eye(num_ind)) @ ibzast) * ipredenvT

In [None]:
# Filter posterior under the TRAINED parameters, evaluated on fpred_points.
# Per-filter hyperparameters reshaped to (num_filt, 1, 1) for broadcasting.
xscfsigmaf = xfsigma_ff[:, None]
xscfellf = xfellf[:, None]
xscftrisef = xftrisef[:, None]
xscftaudifff = xftaudifff[:, None]
xscflagf = xflagf[:, None]

xfkmm = Squared_exp(xfz_vectorf, xfz_vectorf.transpose(0, 2, 1), xscfsigmaf, xscfellf)
xfcpred = np.linalg.cholesky(xfkmm + 1e-6 * np.eye(num_ind))
xfkzast = Squared_exp(xfz_vectorf, fpred_points.T, xscfsigmaf, xscfellf)
# B = C^{-1} K_z*, solved per filter. VLS returns a new array, so the former
# np.zeros pre-allocation of xfbzast was dead code.
xfbzast = VLS(xfcpred, xfkzast)
xfbzastT = xfbzast.transpose(0, 2, 1)
xfpredenv = AlpEnvelope(fpred_points, xscftrisef, xscftaudifff, xscflagf)
xfpredenvT = xfpredenv.transpose(0, 2, 1)
# Whitened sparse-GP predictive mean and covariance, modulated by the envelope.
xfmeanpred = xfpredenv * (xfbzastT @ xfv_vectorf)

xfkastast = Squared_exp(fpred_points, fpred_points.T, xscfsigmaf, xscfellf)
xfcovpred = xfpredenv * (xfkastast + xfbzastT @ (xfl_matrixf @ xfl_matrixf.transpose(0, 2, 1) - np.eye(num_ind)) @ xfbzast) * xfpredenvT

In [None]:
print(imeanpred.shape)

In [None]:
ipredindenv = AlpEnvelope(iz_vectorf, scitrisef, scitaudifff, scilagf)
index = 2
plt.plot(fpred_points, imeanpred[index])
is2 = np.sqrt(np.diag(icovpred[index]))
ifill = (imeanpred[index].flatten() - is2, imeanpred[index].flatten() + is2)
plt.fill_between(fpred_points.flatten(), *ifill, alpha = 0.1)
plt.scatter(iz_vectorf[index], (ipredindenv * (icpred @ iv_vectorf))[index], s = 100, marker = '+', c='k', alpha = 0.6)
#plt.xlim([ftime[0], ftime[-1]])

In [None]:
xpredindenv = AlpEnvelope(xfz_vectorf, xscftrisef, xscftaudifff, xscflagf)

plt.plot(fpred_points, xfmeanpred[index])
xs2 = np.sqrt(np.diag(xfcovpred[index]))
fill = (xfmeanpred[index].flatten() - xs2, xfmeanpred[index].flatten() + xs2)
plt.fill_between(fpred_points.flatten(), *fill, alpha = 0.1)
plt.scatter(xfz_vectorf[index], (xpredindenv * (xfcpred @ xfv_vectorf))[index], s = 100, marker = '+', c='k', alpha = 0.6)
#plt.xlim([ftime[0], ftime[-1]])

In [None]:
xenvelopes = AlpEnvelope(ftime, xscftrisef, xscftaudifff, xscflagf)
for i in xenvelopes:
  plt.plot(ftime, i)

In [None]:
for index in range(num_filt):
  plt.plot(fpred_points, xfmeanpred[index])
  xs2 = np.sqrt(np.diag(xfcovpred[index]))
  xfill = (xfmeanpred[index].flatten() - xs2, xfmeanpred[index].flatten() + xs2)
  plt.fill_between(fpred_points.flatten(), *xfill, alpha = 0.1)
  plt.scatter(xfz_vectorf[index], (xpredindenv * (xfcpred @ xfv_vectorf))[index], s = 100, marker = '+', c='k', alpha = 0.6)
  #plt.xlim([ftime[0], ftime[-1]])

In [None]:
@partial(jit, static_argnums = (1, 2,))
def OUTconvolve(Filters, K, N, Spike_FFT):
  """Convolve every filter with its spike train via FFT and sum over filters.

  4-argument variant: the number of filters is taken from ``Filters.shape[0]``
  instead of the global ``num_filt``. Everything uses ``jnp``; the earlier
  ``np.hstack`` / ``np.fft`` calls cannot operate on the tracers that ``jit``
  passes in, and raised at trace time.

  NOTE(review): this redefinition shadows the earlier 5-argument
  ``OUTconvolve(Filters, Num_Filt, K, N, Spike_FFT)``. Later cells that pass
  ``Num_Filt`` will not match this signature, so consider renaming one of them.

  Args:
    Filters: filters of K + 1 taps, stacked along axis 0 (shape (num_filters, K + 1, 1)).
    K, N: last filter index and last output index (static).
    Spike_FFT: rFFT of the spike trains, zero-padded to length N + K + 1.

  Returns:
    The first N + 1 bins of the summed convolution.
  """
  Pad = jnp.hstack((Filters, jnp.zeros((Filters.shape[0], N, 1))))
  Pred_fft = jnp.fft.rfft(Pad, axis = -2)
  Pred_sum = jnp.sum(Pred_fft * Spike_FFT, axis = 0)
  Pred_fitfft = jnp.fft.irfft(Pred_sum, K + N + 1, axis = -2)[ : N + 1]
  return Pred_fitfft

In [None]:
# Pointwise posterior std of each initial-parameter filter, shaped (num_filt, k + 1, 1).
is2 = np.sqrt(np.diagonal(icovpred, axis1=1, axis2=2)).reshape((num_filt, k + 1, 1))

# Predicted output on training set 0 from the initial mean filters and from the
# mean -/+ 2 std filters (4-argument OUTconvolve).
# NOTE(review): convolving mean +/- 2 std filters gives a heuristic band, not a
# propagated predictive interval - confirm that is intended.
imeanpredict = OUTconvolve(imeanpred, k, n, spikefftf[0])
imeanpredictns2 = OUTconvolve(imeanpred - 2 * is2, k, n, spikefftf[0])
imeanpredictps2 = OUTconvolve(imeanpred + 2 * is2, k, n, spikefftf[0])

In [None]:
# Same for the trained filters, on training set 2.
xfs2 = np.sqrt(np.diagonal(xfcovpred, axis1=1, axis2=2)).reshape((num_filt, k + 1, 1))

xfmeanpredict = OUTconvolve(xfmeanpred, k, n, spikefftf[2])
print(xfmeanpredict.shape)
xfmeanpredictns2 = OUTconvolve(xfmeanpred - 2 * xfs2, k, n, spikefftf[2])
xfmeanpredictps2 = OUTconvolve(xfmeanpred + 2 * xfs2, k, n, spikefftf[2])

In [None]:
# Initial-parameter prediction vs truth on training set 0; band widened by 2 noise std.
plt.plot(ytime[0], imeanpredict)
plt.plot(ytime[0], xvel[0])
plt.fill_between(ytime[0].flatten(), imeanpredictns2.flatten() - 2 * isigma_n,  
                  imeanpredictps2.flatten() + 2 * isigma_n, alpha = 0.5)

In [None]:
# Trained prediction vs truth on training set 2.
plt.plot(ytime[2], xfmeanpredict)
plt.plot(ytime[2], xvel[2])
plt.fill_between(ytime[2].flatten(), xfmeanpredictns2.flatten() - 2 * xfsigma_n,  
                  xfmeanpredictps2.flatten() + 2 * xfsigma_n, alpha = 0.5)

In [None]:
# Zoom in on the first 3000 bins of training set 2. The trained predictions
# above were computed from spikefftf[2], so truth and time axis must also come
# from set 2 (previously they were taken from set 0, so unrelated series were
# compared).
xytesttru = xvel[:, :3000]
xytimetru = ytime[2][:3000]
xfmeanpredicttru = xfmeanpredict[:3000]
xfmeanpredictns2tru = xfmeanpredictns2[:3000]
xfmeanpredictps2tru = xfmeanpredictps2[:3000]
plt.plot(xytimetru, xfmeanpredicttru)
plt.plot(xytimetru, xytesttru[2])
plt.fill_between(xytimetru.flatten(), xfmeanpredictns2tru.flatten() - 2 * xfsigma_n,  
                  xfmeanpredictps2tru.flatten() + 2 * xfsigma_n, alpha = 0.5)

# On Test Data

In [None]:
# Reload the data set (it was deleted from memory after the training arrays were built).
alldata = pd.read_hdf('./Data/Allfinger_veldata.h5') # Import DataFrame

In [None]:
# Index labels of the held-out test data sets.
testind = (3,)

In [None]:
testdat = [alldata.loc[i] for i in testind]

In [None]:
testdat[0]

In [None]:
# Bins used from each test set and number of spike channels.
# NOTE(review): tes_filt presumably must equal num_filt for the trained filters to apply.
num_tes = 161393
tes_filt = 130
ntest = num_tes - 1
In [None]:
testspike = [testdat[i].spikes.to_numpy()[:num_tes, :tes_filt].T[:, :, None] for i in range(len(testind))]

In [None]:
testpad = [jnp.hstack((i, jnp.zeros((num_filt, k, 1)))) for i in testspike]
print(testpad[0].shape)
testspikefft = jnp.asarray([np.fft.rfft(i, axis = -2) for i in testpad])
print(testspikefft[0].shape)

In [None]:
# Time stamps in seconds of each test set, truncated to num_tes bins.
testtime = [(testdat[i].index / np.timedelta64(1, 's')).to_numpy()[:num_tes] for i in range(len(testind))] # Get spikes/output time array

In [None]:
# Get x velocities as (num_tes, 1) columns
xrawtes = [testdat[i].finger_vel.x.to_numpy()[:num_tes].reshape(ntest + 1, 1) for i in range(len(testind))]

## Standardising Data

In [None]:
# Set variance to 1.0
xtest = np.asarray([xrawtes[i] / np.std(xrawtes[i]) for i in range(len(testind))])

In [None]:
del testdat # Clear data from memory
del alldata

In [None]:
for i in range(len(testind)):
  plt.plot(testtime[i], xtest[i])

## Plotting vs Test

In [None]:
# Trained-filter predictions (mean and mean -/+ 2 std filters) on the first test set.
# NOTE(review): these 5-argument calls need the OUTconvolve(Filters, Num_Filt, K, N,
# Spike_FFT) definition; the later 4-argument redefinition shadows it.
xtmeanpredict = OUTconvolve(xfmeanpred, tes_filt, k, ntest, testspikefft[0])
xtmeanpredictns2 = OUTconvolve(xfmeanpred - 2 * xfs2, tes_filt, k, ntest, testspikefft[0])
xtmeanpredictps2 = OUTconvolve(xfmeanpred + 2 * xfs2, tes_filt, k, ntest, testspikefft[0])

In [None]:
# Prediction vs truth on the test set; band widened by 2 noise std.
plt.plot(testtime[0], xtmeanpredict)
plt.plot(testtime[0], xtest[0])
plt.fill_between(testtime[0].flatten(), xtmeanpredictns2.flatten() - 2 * xfsigma_n,  
                  xtmeanpredictps2.flatten() + 2 * xfsigma_n, alpha = 0.5)

## $R^2$

## X Velocity

### Test Data

In [None]:
# Trained mean-filter predictions on every test set.
# NOTE(review): 5-argument OUTconvolve call - see the shadowing redefinition above.
alltestpredict = np.asarray([OUTconvolve(xfmeanpred, num_filt, k, ntest, i) for i in testspikefft])

In [None]:
# Predicted vs true (standardised) test velocity.
plt.scatter(xtest.flatten(), alltestpredict.flatten(), s=.005)

In [None]:
# Pooled R^2 over test sets: 1 - SSE / SST, with SST taken about each set's own mean.
tess = np.sum(np.asarray([np.square(xtest[i] - alltestpredict[i]) for i in range(len(testind))]).flatten())
ttss = np.sum(np.asarray([np.square(xtest[i] - np.mean(xtest[i])) for i in range(len(testind))]).flatten())
tr2 = 1 - tess/ttss
print(tr2)

### Training Data

In [None]:
# Trained mean-filter predictions on every training set.
# NOTE(review): 5-argument OUTconvolve call - see the shadowing redefinition above.
allpredict = np.asarray([OUTconvolve(xfmeanpred, num_filt, k, n, i) for i in spikefftf])

In [None]:
# Predicted vs true (standardised) training velocity.
plt.scatter(xvel.flatten(), allpredict.flatten(), s=.005)

In [None]:
# Pooled R^2 over training sets: 1 - SSE / SST, with SST taken about each set's own mean.
ess = np.sum(np.asarray([np.square(xvel[i] - allpredict[i]) for i in range(len(trainind))]).flatten())
tss = np.sum(np.asarray([np.square(xvel[i] - np.mean(xvel[i])) for i in range(len(trainind))]).flatten())
r2 = 1 - ess/tss
print(r2)