In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import scipy.special as ssp
import sys
import scipy.io as sio
sys.path.append("../code/")
from vae import *
from util import *
from keras.callbacks import ModelCheckpoint

## import plot packages
from matplotlib.ticker import FormatStrFormatter
from matplotlib import ticker

%load_ext autoreload
%autoreload 2
%matplotlib inline

## load macaque data

In [None]:
## Trial table from the Chewie_20161006 .mat file.
## Column order noted by the original author:
## starttime, endtime, number, tgtontime, gocuetime, tgtdir, tgtid
trials_mat_path = "../data/chewie_data/Chewie_20161006_trials_array.mat"
trial_dat = sio.loadmat(trials_mat_path)

In [None]:
## load spike data
def _object_array(items):
    """Pack `items` into a 1-D object array holding one element per item.

    Elements are assigned one at a time, so NumPy is never asked to infer a
    common shape. `np.array(...)` on variable-length inputs raises ValueError
    from NumPy 1.24 onwards.
    """
    packed = np.empty(len(items), dtype=object)
    for pos, item in enumerate(items):
        packed[pos] = item
    return packed

dat_ = sio.loadmat('../data/chewie_data/Chewie_20161006_seq.mat')

## one list of trials per direction index
dat_all = [[] for _ in range(8)]
## first 8 sorted unique values of column 5 of the trial table
## (column 5 is `tgtdir` according to the column list in the trial-loading cell)
tar_dir = np.unique(trial_dat['trials_array'][:,5])[:8]
## for every trial: position of its column-5 value among the sorted unique values
trial_dat_id = np.unique(trial_dat['trials_array'][:,5],return_inverse=True)[1]

## NOTE(review): 251 is hard-coded; presumably the number of trials in the file -- confirm.
for trial_id in range(251):
    ## trials whose `T` entry is zero are skipped
    if dat_['seq'][0][trial_id]['T'][0,0] != 0:
        ## transpose `y` and keep columns 63 onward
        ## NOTE(review): the reason for the offset 63 is not stated here -- confirm.
        dat_all[trial_dat_id[trial_id]].append(dat_['seq'][0][trial_id]['y'].T[:,63:])

## dat_all[direction][trial] -> 2-D array for that trial
dat_all = _object_array([_object_array(dat_all[ii]) for ii in range(8)])

In [None]:
## randomly split into batches
def _object_array(items):
    """Pack `items` into a 1-D object array holding one element per item.

    Elements are assigned one at a time, so NumPy is never asked to infer a
    common shape. `np.array(...)` on variable-length inputs raises ValueError
    from NumPy 1.24 onwards.
    """
    packed = np.empty(len(items), dtype=object)
    for pos, item in enumerate(items):
        packed[pos] = item
    return packed

np.random.seed(666)

## trial_ls[direction][batch] -> indices of the trials of that direction in that batch
trial_ls = []
for ii in range(8):
    shuffled_trials = np.random.permutation(np.arange(dat_all[ii].shape[0]))
    chunks = np.array_split(shuffled_trials, 24)
    ## Shuffle chunk positions instead of the chunk list itself: the chunks can
    ## have unequal lengths, and permuting such a list would require NumPy to
    ## build a ragged array.
    chunk_order = np.random.permutation(24)
    trial_ls.append([chunks[kk] for kk in chunk_order])

x_all = []
u_all = []
cu_all = []
for ii in range(24): # 24 batches
    x_tr = []
    u_tr = []
    cu_tr = []
    for jj in range(8): # 8 different directions
        ## stack the selected trials of this direction along the first axis
        x_tmp = np.concatenate(dat_all[jj][trial_ls[jj][ii]])
        ## cu: direction index repeated for every row; u: all zeros
        cu_tmp = np.full((x_tmp.shape[0], 1), float(jj))
        u_tmp = np.zeros((x_tmp.shape[0], 1))
        x_tr.append(x_tmp)
        cu_tr.append(cu_tmp)
        u_tr.append(u_tmp)
    x_all.append(np.concatenate(x_tr))
    u_all.append(np.concatenate(u_tr))
    cu_all.append(np.concatenate(cu_tr))

## Batches may differ in their number of rows, so they are kept as one object
## array per variable (element = one batch).
x_all = _object_array(x_all)
u_all = _object_array(u_all)
cu_all = _object_array(cu_all)

In [None]:
## Batches 0-19 -> training, 20-21 -> validation, 22 onward -> test.
x_train, x_valid, x_test = x_all[:20], x_all[20:22], x_all[22:]
u_train, u_valid, u_test = u_all[:20], u_all[20:22], u_all[22:]

## fit vae

In [None]:
## Seed NumPy's global RNG, then build the model.
np.random.seed(666)
vae_kwargs = dict(
    dim_x=x_all[0].shape[-1],  # size of the last axis of the first batch
    dim_z=4,
    gen_nodes=60,
    n_blk=2,
    mdl='poisson',
    learning_rate=5e-4,
)
vae = vae_mdl(**vae_kwargs)

In [None]:
## Train for 1000 epochs, one generator step per batch. The checkpoint callback
## is configured to monitor "val_loss" with save_best_only / save_weights_only.
model_chk_path = '../results/macaque_4d_999_vae.h5' ##999, 777
mcp = ModelCheckpoint(
    model_chk_path,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
)
train_gen = custom_data_generator(x_train, u_train)
valid_gen = custom_data_generator(x_valid, u_valid)
s_n = vae.fit_generator(
    train_gen,
    steps_per_epoch=len(x_train),
    epochs=1000,
    verbose=1,
    validation_data=valid_gen,
    validation_steps=len(x_valid),
    callbacks=[mcp],
)

In [None]:
## Validation loss stored in the training history, with labelled axes.
plt.plot(s_n.history['val_loss'])
plt.xlabel('Epoch')
plt.ylabel('Validation loss');

In [None]:
## Load the weights stored at the checkpoint path into `vae`.
## The path is restated here, so this cell does not rely on the
## `model_chk_path` assigned in the training cell.
model_chk_path = '../results/macaque_4d_999_vae.h5'
vae.load_weights(model_chk_path);

In [None]:
## Run the model over every batch, one generator step per batch.
all_batches_gen = custom_data_generator(x_all, u_all)
outputs = vae.predict_generator(all_batches_gen, steps=len(x_all))
# post_mean, post_log_var, z_sample,fire_rate, lam_mean, lam_log_var, z_mean, z_log_var
latent_variance = outputs[0].var(axis=0)
print(latent_variance)  ## variance of each latent dimension

In [None]:
## First model output for every trial, grouped by direction
## (the output-order comment in the prediction cell lists it as post_mean).
## z_pred_all[direction] is a list with one array per trial. The outer
## container is filled element by element because directions can hold
## different numbers of trials and trials can have different lengths;
## np.array() on such a nested list raises ValueError on NumPy >= 1.24.
z_pred_all = np.empty(dat_all.shape[0], dtype=object)
for ii in range(dat_all.shape[0]):
    z_pred_tmp = []
    for jj in range(dat_all[ii].shape[0]):
        x_trial = dat_all[ii][jj]
        ## second model input: one zero per row of the trial
        u_trial = np.zeros((x_trial.shape[0], 1))
        z_pred_tmp.append(vae.predict([x_trial, u_trial])[0])
    z_pred_all[ii] = z_pred_tmp

## make plots

In [None]:
## posterior mean
c_vec = np.array(['red','orange','green','blue','indigo','pink','brown','gray'])
## direction label of every row, flattened over all batches
c_all = np.concatenate(cu_all).reshape(-1).astype('int')

fsz = 14

post_mean = outputs[0]
point_colors = c_vec[c_all % 8]

fig = plt.figure(figsize=(8,4))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)

## (axes, x label, y label, column on x, column on y)
## note: the displayed labels 1-4 map to columns 3, 1, 2, 0 of the posterior mean
panels = (
    (ax1, 'Latent 1', 'Latent 2', 3, 1),
    (ax2, 'Latent 3', 'Latent 4', 2, 0),
)
for ax, x_label, y_label, x_col, y_col in panels:
    ax.set_xlabel(x_label, fontsize=fsz, fontweight='normal')
    ax.set_ylabel(y_label, fontsize=fsz, fontweight='normal')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.scatter(post_mean[:, x_col], post_mean[:, y_col], s=1, c=point_colors, alpha=0.5)
    plt.setp(ax.get_xticklabels(), fontsize=fsz)
    plt.setp(ax.get_yticklabels(), fontsize=fsz)

plt.tight_layout()

In [None]:
## posterior mean averaged across trials/repeats
## Expects `z_pred_all[direction][trial]` to be a 2-D array whose columns are
## the latent dimensions, and `fsz` (font size) to be defined already.

def _mean_trajectory(trials, dim_x, dim_y, n_steps=20):
    """Average the first `n_steps` rows of two columns over trials.

    Trials with fewer than `n_steps` rows are left out.

    Returns:
        (mean_x, mean_y): arrays of length `n_steps`, or (None, None) when
        no trial is long enough (avoids a 0/0 division).
    """
    kept = [tr for tr in trials if len(tr) >= n_steps]
    if not kept:
        return None, None
    mean_x = sum(tr[:n_steps, dim_x] for tr in kept) / len(kept)
    mean_y = sum(tr[:n_steps, dim_y] for tr in kept) / len(kept)
    return mean_x, mean_y


def _plot_mean_panel(ax, z_by_dir, colors, dim_x, dim_y, xlabel, ylabel, fontsize):
    """Draw one trial-averaged trajectory per direction on `ax`."""
    ax.set_xlabel(xlabel, fontsize=fontsize, fontweight='normal')
    ax.set_ylabel(ylabel, fontsize=fontsize, fontweight='normal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    for ndir, color in enumerate(colors):
        mean_x, mean_y = _mean_trajectory(z_by_dir[ndir], dim_x, dim_y)
        if mean_x is None:
            ## no trial of this direction has enough time steps
            continue
        ax.plot(mean_x, mean_y, '-x', c=color)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize)
    plt.setp(ax.get_yticklabels(), fontsize=fontsize)


c_vec = np.array(['red','orange','green','blue','indigo','pink','brown','gray'])
## Direction label of every row. Taken from cu_all: u_all holds only zeros,
## so it carries no direction information.
c_all = np.array(np.concatenate(cu_all).reshape(-1), dtype='int')

fig = plt.figure(figsize=(8,4))

## the displayed labels 1-4 map to columns 3, 1, 2, 0
ax1 = plt.subplot(1,2,1)
_plot_mean_panel(ax1, z_pred_all, c_vec, 3, 1, 'Latent 1', 'Latent 2', fsz)

ax2 = plt.subplot(1,2,2)
_plot_mean_panel(ax2, z_pred_all, c_vec, 2, 0, 'Latent 3', 'Latent 4', fsz)

plt.tight_layout()

## compute log likelihood

In [None]:
## sample u
## one (rows, 1) array per test batch for every u value in range(1), i.e. only u = 0
u_fake = np.array([[np.full((x_batch.shape[0], 1), float(u_val)) for x_batch in x_test]
                   for u_val in range(1)])

## compute loglik
np.random.seed(666)
## NOTE(review): the positional 500 is presumably a sample count -- confirm in util.
lik_all = compute_marginal_lik_poisson(vae, x_test, u_fake, 500, log_opt = True)

lik_per_batch = []
for jj in range(len(lik_all)):
    ## sum over the last axis of log(x!) computed as loggamma(x + 1)
    log_factorial = ssp.loggamma(x_test[jj] + 1).sum(axis=-1)
    lik_per_batch.append(lik_all[jj].mean(axis=0) - log_factorial)
lik_use = np.concatenate(lik_per_batch)

## save as np.save("../results/lik_vae_chewie.npy", lik_use)

## compute firing rate

In [None]:
## compute firing rate for vae
## Model output 3 for every trial, grouped by direction (the output-order
## comment in the prediction cell lists it as fire_rate).
## Stored under its own name so that `z_pred_all` (the per-trial posterior
## means used by the trial-average plot) is not overwritten.
## The outer container is filled element by element because np.array() on a
## nested list with unequal lengths raises ValueError on NumPy >= 1.24.
fire_rate_all_vae = np.empty(dat_all.shape[0], dtype=object)
for ii in range(dat_all.shape[0]):
    fire_rate_tmp = []
    for jj in range(dat_all[ii].shape[0]):
        x_trial = dat_all[ii][jj]
        ## second model input: one zero per row of the trial
        u_trial = np.zeros((x_trial.shape[0], 1))
        fire_rate_tmp.append(vae.predict([x_trial, u_trial])[3])
    fire_rate_all_vae[ii] = fire_rate_tmp

## save as np.save("../results/fire_rate_vae_chewie.npy", fire_rate_all_vae)