In [None]:
## load packages
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import scipy.stats as ss
import sys
sys.path.append('../code/');

import encoding as enc
import decoding as dec
import util_plot as uplot

seed = 888;
np.random.seed(seed) # fix randomness

%load_ext autoreload
%autoreload 2
%matplotlib inline

# Deconvolution in MATLAB (reference only — run it in MATLAB; skip it here)

In [None]:
# Reference MATLAB code (needs deconvolveCa on the MATLAB path). It cannot execute in
# this Python kernel, so it is kept verbatim as a string: the cell becomes a no-op and
# "Restart & Run All" no longer stops here. Paste the script into MATLAB to regenerate
# the deconvolved file that the next section loads.
# NOTE(review): the save path was set to 12_c_0_deconv_new.mat so it matches the file
# the loader cell reads — confirm this is the intended output name.
MATLAB_DECONV_SCRIPT = r"""
load('../data/Hippocampus/results (12).mat');
c = C_raw';
shape = size(c);
K = shape(2);
T = shape(1);
cd = zeros(K,T);
s = zeros(K,T);
bs = zeros(K,1);
pars = zeros(K,1);
sn = zeros(K,1);
% deconvolve each column of c independently: AR(1) model with fixed pars = 0.95,
% smin = 0, and neither pars nor the baseline optimized
for n = 1:K
    [cd(n,:), s(n,:), options] = deconvolveCa(c(:,n), 'foopsi', 'ar1', 'smin', 0, 'optimize_pars', false, 'pars', 0.95, 'optimize_b', false);
    bs(n,1) = options.b;
    pars(n,1) = options.pars;
    sn(n,1) = options.sn;
end
save("../data/Hippocampus/12_c_0_deconv_new.mat", 'cd', 's', 'bs', 'pars', 'sn');
"""

# load deconvolved data

In [None]:
## load deconvolved data (deconvolution done in MATLAB with smin=0)
# NOTE(review): the MATLAB deconvolution cell above calls deconvolveCa with 'foopsi';
# confirm whether the file loaded below came from that call or from a hard-threshold run.
import h5py

with h5py.File('../data/Hippocampus/SessInfo (15).mat', 'r') as f:
    # inspect the layout of the session file (read as HDF5)
    print(list(f.keys()))
    print(list(f['SessInfo']))
    print(list(f['SessInfo']['Behavior']))
    info = f['SessInfo']['Behavior']['treadPos'][()]

deconv = sio.loadmat("../data/Hippocampus/12_c_0_deconv_new.mat")
len_use = 53600  # number of frames analysed; must be even for the 2x downsampling below
hd = info[:len_use, 0].copy()
# divide each row of s by its noise level sn, then put time on axis 0
Y_real = (deconv['s'] / deconv['sn']).T[:len_use, :]
assert hd.shape[0] == Y_real.shape[0], "behaviour and neural data have different lengths"
assert hd.shape[0] % 2 == 0, "need an even number of frames for pairwise downsampling"
# position 0 is mapped to 1 (presumably the same point on a circular track — confirm)
hd[hd == 0] = 1
# rescale to degrees; assumes treadPos lies in [0, 1] so hd ends up in [-180, 180] (TODO confirm)
hd = hd * 360 - 180

# downsample 2x: Fortran-order reshape pairs frames (2t, 2t+1); sum activity, average position
n_half = Y_real.shape[0] // 2
Y_real = Y_real.reshape(2, n_half, Y_real.shape[1], order="F").sum(axis=0)
hd = hd.reshape(2, n_half, order="F").mean(axis=0)

# only use running data: time point i counts as running when hd[i + k] != hd[i]
k = 6
step = hd[k:] - hd[:-k]
print('fraction of stationary time points:', np.mean(step == 0))
tempp = np.where(step != 0)[0]  # running time points
tempq = np.where(step == 0)[0]  # stationary time points
# NOTE: the last k time points belong to neither tempp nor tempq.
hd_enc = hd[tempp]
Y_real_enc = Y_real[tempp, :]

# set up params

In [None]:
## where the encoding runners write their outputs
path = './hippo_rlt/'

## encoding hyper-parameters
n_epochs = 1800       # number of training epochs
learning_rate = 1e-2  # optimiser step size
gen_nodes = 15        # width of each hidden layer

thresh = 0        # threshold for the Bernoulli model
gam_shift = 1e-4  # shift for the gamma model; keep it a small number
# location parameter for the SNG model: one entry per column of Y_real, all equal to thresh
factor = np.full(Y_real.shape[1], thresh, dtype=float)

## decoding hyper-parameters
bin_len = 121  # number of bins used for the Bayesian decoding analysis

In [None]:
## randomly split data: cut the running data into contiguous chunks, then assign whole
## chunks to train / validation / test (presumably so temporally adjacent samples stay together)
n_chunks = 56        # number of contiguous chunks
n_train_chunks = 38  # perm[:38]    -> training set
n_valid_chunks = 6   # perm[38:44]  -> validation set; remaining chunks -> test set

np.random.seed(seed)
perm = np.random.permutation(n_chunks)
# NOTE(review): cperm is not used in the visible notebook, but drawing it advances the
# global RNG, so it is kept to leave any downstream random state unchanged.
cperm = np.random.permutation(Y_real_enc.shape[1])

Y_split = np.array_split(Y_real_enc, n_chunks, axis=0)
hd_split = np.array_split(hd_enc, n_chunks, axis=0)


def _gather_chunks(chunks, idx):
    """Concatenate the chunks selected by ``idx`` (in that order) along axis 0."""
    return np.concatenate([chunks[i] for i in idx], axis=0)


valid_end = n_train_chunks + n_valid_chunks
train_idx = perm[:n_train_chunks]
valid_idx = perm[n_train_chunks:valid_end]
test_idx = perm[valid_end:]

Y_train = _gather_chunks(Y_split, train_idx)
hd_train = _gather_chunks(hd_split, train_idx)

Y_valid = _gather_chunks(Y_split, valid_idx)
hd_valid = _gather_chunks(hd_split, valid_idx)

Y_test = _gather_chunks(Y_split, test_idx)
hd_test = _gather_chunks(hd_split, test_idx)

nTrain = Y_train.shape[0]
nValid = Y_valid.shape[0]
nTest = Y_real_enc.shape[0] - nTrain - nValid
assert nTest == Y_test.shape[0], "train/valid/test sizes do not add up"

# all chunks in perm order: rows [0, nTrain) are train, the next nValid rows are
# validation, and the remaining nTest rows are test
Y_const = _gather_chunks(Y_split, perm)
hd_const = _gather_chunks(hd_split, perm)

hd_bins_use = np.linspace(-180, 180, bin_len)[1:]
hd_temp = uplot.getBins(hd_const, hd_bins_use)

# train encoding models

In [None]:
## Fit one encoding model per observation family. All models share the same data split
## and optimisation settings; only the SNG model uses a smaller step size.
split_args = (Y_train, Y_valid, hd_train, hd_valid, hd_test)
fit_kwargs = dict(n_epochs=n_epochs, learning_rate=learning_rate,
                  gen_nodes=gen_nodes, bin_len=bin_len, path=path)

# Poisson
(spl_values_poi, yfit_poi,
 cost_poi, test_cost_poi) = enc.poissonRunner(*split_args, **fit_kwargs)

# Bernoulli (threshold `thresh`)
(spl_values_ber, yfit_ber,
 cost_ber, test_cost_ber) = enc.bernoulliRunner(*split_args, thresh=thresh, **fit_kwargs)

# gamma (shift `gam_shift`)
(spl_values_gam, spl_k_values_gam, spl_theta_values_gam, yfit_gam,
 gam_k, gam_theta, cost_gam, test_cost_gam) = enc.gammaRunner(*split_args, factor=gam_shift,
                                                              **fit_kwargs)

# SNG (location parameter `factor`), step size 5e-3
(spl_values_sngr, spl_p_values_sngr, spl_theta_values_sngr, sngr_k, sngr_loc,
 yfit_sngr, sngr_p, sngr_theta, cost_sngr, test_cost_sngr) = enc.sngRlxRunner(
    *split_args, factor, **{**fit_kwargs, 'learning_rate': 5e-3})

# Bayesian decoding analysis

In [None]:
## Decode every row of Y_const (train, validation and test) with each fitted model.
## Each call returns (mean decode, decode, likelihood matrix).
dec_kwargs = dict(bin_len=bin_len)

# Poisson
poi_mean_decode, poi_decode, poi_lik_mat = dec.poissonDecoding(
    Y_const, spl_values_poi, **dec_kwargs)

# Bernoulli
ber_mean_decode, ber_decode, ber_lik_mat = dec.bernoulliDecoding(
    Y_const, thresh, spl_values_ber, **dec_kwargs)

# gamma
gam_mean_decode, gam_decode, gam_lik_mat = dec.gammaDecoding(
    Y_const, spl_k_values_gam, spl_theta_values_gam, factor=gam_shift, **dec_kwargs)

# SNG
sngr_mean_decode, sngr_decode, sngr_lik_mat = dec.sngRlxDecoding(
    Y_const, spl_theta_values_sngr, spl_p_values_sngr, sngr_k, sngr_loc, **dec_kwargs)

# check encoding results

In [None]:
## Draw synthetic observations from each fitted model; every sampler gets the same fixed seed.
sample_kwargs = dict(seed=seed)

y_poisson, p_poi = enc.samplePoisson(yfit_poi, **sample_kwargs)
y_bernoulli = enc.sampleBernoulli(yfit_ber, **sample_kwargs)
y_sngr = enc.sampleSngRlx(sngr_p, sngr_k, sngr_theta, sngr_loc, **sample_kwargs)
y_gamma = enc.sampleGamma(gam_k, gam_theta, **sample_kwargs)

## diagnostic plots

In [None]:
## Observed mean vs. the mean of samples drawn from each fitted model.
fig_rate = uplot.plotSampleRatem(
    Y_const, hd_const,
    y_poisson, y_bernoulli, y_gamma, y_sngr,
    ss=0, bins=10,
)

In [None]:
## Observed variance vs. the variance of samples drawn from each fitted model.
## Note the argument order: the samples come before hd_const here.
fig_var = uplot.plotSampleVarm(
    Y_const,
    y_poisson, y_bernoulli, y_gamma, y_sngr,
    hd_const,
    ss=0, bins=10,
)

In [None]:
## Fitted vs. empirical CDF for one column of Y_const (selected by nid) at position bins 0, 20 and 40.
nid = 10
fig_cdfslab = uplot.plotCdfSlabm(
    Y_const, hd_temp,
    spl_values_poi, spl_values_ber,
    spl_theta_values_sngr, sngr_k, sngr_loc, spl_p_values_sngr,
    spl_theta_values_gam, spl_k_values_gam,
    hd_bins_use,
    bin_list=[0, 20, 40], size=[1, 3], nid=nid,
)

In [None]:
## Plot fitted tuning curves against empirical ones from the training and test sets.
num = 1
nid = 10
tc_cols = np.s_[:, nid:]  # every row, columns from nid onward

tuning_curve = enc.get_tc(Y_train, hd_train, bin_len)
tuning_curve2 = enc.get_tc(Y_test, hd_test, bin_len)

fig_tc = uplot.plotTC(
    num,
    spl_values_poi[tc_cols], spl_values_ber[tc_cols],
    spl_values_gam[tc_cols], spl_values_sngr[tc_cols],
    tuning_curve[tc_cols], tuning_curve2[tc_cols],
    bin_len,
)

In [None]:
## Plot fitted parameter curves: Poisson mean, SNG scale and SNG non-zero probability.
# Raw strings keep the LaTeX backslashes literal. The runtime text is unchanged, but
# '\m' / '\l' in normal strings are invalid escape sequences (SyntaxWarning on Python >= 3.12).
param_curves = [spl_values_poi, spl_theta_values_sngr, spl_p_values_sngr]
param_labels = [r'Mean $\mathbf{\lambda}$', r'Scale $\mathbf{a}$', r'Non-zero prob $\mathbf{p}$']
fig_params = uplot.plotParams(param_curves, param_labels, bin_len)

# check decoding results

In [None]:
## Decoding mean absolute error on the training rows and on the held-out test rows of Y_const.
_, poi_error = dec.error(poi_mean_decode[:], hd_const[:])
_, ber_error = dec.error(ber_mean_decode[:], hd_const[:])
_, sngr_error = dec.error(sngr_mean_decode[:], hd_const[:])
_, gam_error = dec.error(gam_mean_decode[:], hd_const[:])


def _rounded_mae(err, start, stop, count):
    """Sum of |err[start:stop]| divided by ``count``, rounded to 2 decimals."""
    return np.round(np.abs(err[start:stop]).sum() / count, 2)


test_start = nTrain + nValid  # rows before this index are train + validation

# training error: the first nTrain rows
print('training error: poisson:', _rounded_mae(poi_error, None, nTrain, nTrain),
      'bernoulli:', _rounded_mae(ber_error, None, nTrain, nTrain),
      'gamma:', _rounded_mae(gam_error, None, nTrain, nTrain),
      'SNG:', _rounded_mae(sngr_error, None, nTrain, nTrain))

# test error: every row after train + validation
print('test error: poisson:', _rounded_mae(poi_error, test_start, None, nTest),
      'bernoulli:', _rounded_mae(ber_error, test_start, None, nTest),
      'gamma:', _rounded_mae(gam_error, test_start, None, nTest),
      'SNG:', _rounded_mae(sngr_error, test_start, None, nTest))


In [None]:
## CI coverage rate across several confidence levels (this can be slow).
## The outputs are ordered SNG, Poisson, Bernoulli, gamma.
## To compute the CI for the SNG model only, see getVaryCI in decoding.py.
ci_lik_mats = (sngr_lik_mat, poi_lik_mat, ber_lik_mat, gam_lik_mat)              # SNG first
ci_mean_decodes = (poi_mean_decode, ber_mean_decode, sngr_mean_decode, gam_mean_decode)  # Poisson first
width_med_ci, width_mean_ci, conf_rate_ci = dec.getVaryCI(
    Y_const, hd_const, hd_bins_use,
    *ci_lik_mats,
    *ci_mean_decodes,
    nTrain, nValid, nTest,
)

In [None]:
## plot log posterior likelihood trace
# The decoding above ran on Y_const, i.e. the running-only data with its chunks in
# `perm` order. Undo that permutation so the results can be laid out in time order.
xx = np.arange(Y_real_enc.shape[0])
xx_split = np.array_split(xx, perm.shape[0])

# rlt[i]: index (within the running-only data) of row i of Y_const
# trial_rlt[i]: rank in `perm` of the chunk that row i came from
rlt = np.concatenate([xx_split[p] for p in perm])
trial_rlt = np.concatenate([np.full(xx_split[p].shape[0], ii) for ii, p in enumerate(perm)])

order = np.argsort(rlt)  # maps Y_const row order -> time order
trial_order = np.zeros(Y_real.shape[0]) - 1
trial_order[tempp] = trial_rlt[order]

# NOTE(review): the stationary-period results (*_mean_decode2 / *_lik_mat2) are not
# defined in any earlier cell of this notebook (NameError on a fresh kernel). They are
# computed here by decoding the stationary time points (tempq) with the same fitted
# models — confirm this matches how they were originally produced.
Y_still = Y_real[tempq, :]
poi_mean_decode2, poi_decode2, poi_lik_mat2 = dec.poissonDecoding(
    Y_still, spl_values_poi, bin_len=bin_len)
ber_mean_decode2, ber_decode2, ber_lik_mat2 = dec.bernoulliDecoding(
    Y_still, thresh, spl_values_ber, bin_len=bin_len)
gam_mean_decode2, gam_decode2, gam_lik_mat2 = dec.gammaDecoding(
    Y_still, spl_k_values_gam, spl_theta_values_gam, factor=gam_shift, bin_len=bin_len)
sngr_mean_decode2, sngr_decode2, sngr_lik_mat2 = dec.sngRlxDecoding(
    Y_still, spl_theta_values_sngr, spl_p_values_sngr, sngr_k, sngr_loc, bin_len=bin_len)

n_time = Y_real.shape[0]


def _to_time_order(running_vals, still_vals):
    """Place running-period values (in Y_const row order) and stationary-period values
    on the full downsampled time axis. The last k time points belong to neither set
    and stay 0."""
    full = np.zeros((n_time,) + running_vals.shape[1:])
    full[tempp] = running_vals[order]
    full[tempq] = still_vals
    return full


poi_decode_all = _to_time_order(poi_mean_decode, poi_mean_decode2)
ber_decode_all = _to_time_order(ber_mean_decode, ber_mean_decode2)
sngr_decode_all = _to_time_order(sngr_mean_decode, sngr_mean_decode2)
gam_decode_all = _to_time_order(gam_mean_decode, gam_mean_decode2)  # gamma model's own stationary decode

poi_lik_all = _to_time_order(poi_lik_mat, poi_lik_mat2)
ber_lik_all = _to_time_order(ber_lik_mat, ber_lik_mat2)
sngr_lik_all = _to_time_order(sngr_lik_mat, sngr_lik_mat2)
gam_lik_all = _to_time_order(gam_lik_mat, gam_lik_mat2)

hd_bin_num = uplot.getBinNum(hd, hd_bins_use)
poi_bin_num = uplot.getBinNum(poi_decode_all, hd_bins_use)
ber_bin_num = uplot.getBinNum(ber_decode_all, hd_bins_use)
gam_bin_num = uplot.getBinNum(gam_decode_all, hd_bins_use)
sngr_bin_num = uplot.getBinNum(sngr_decode_all, hd_bins_use)

# plotting window, in downsampled time points
s_te = 4200
e_te = s_te + 3000


def _log_rel_lik(lik, floor=-10):
    """Per-row log(lik / row max), clipped to [floor, 0]. All-zero rows (e.g. the
    unassigned last k time points) give 0/0 = NaN; those warnings are silenced."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.clip(np.log(lik / lik.max(axis=1, keepdims=True)), floor, 0)


poi_temp = _log_rel_lik(poi_lik_all)
ber_temp = _log_rel_lik(ber_lik_all)
gam_temp = _log_rel_lik(gam_lik_all)
sngr_temp = _log_rel_lik(sngr_lik_all)

fig_conf = uplot.plotPosSuper2(hd_bin_num[s_te:e_te], sngr_temp[s_te:e_te].T)

In [None]:
## Plot CI coverage rate vs. confidence level, and CI coverage rate vs. CI width.
conf_level_list = np.concatenate([
    np.linspace(0.9, 0.999, 10)[:-1],  # 9 evenly spaced levels from 0.9 to 0.988
    [1 - 1e-2, 1 - 5e-3, 1 - 1e-3],    # extra levels closer to 1
])
print(conf_level_list)

fig_cimr = uplot.plotCIcov_mean(conf_rate_ci, width_mean_ci, drop=True)
# use a -log10(1 - level) scale so levels close to 1 are spread out
fig_cilr = uplot.plotCIcov_lev(
    -np.log10(1 - conf_rate_ci),
    -np.log10(1 - conf_level_list),
    real=True, drop=True,
)

# save results

In [None]:
## save results
import os

# np.savez raises FileNotFoundError if the target folder is missing, so create it first.
os.makedirs(path, exist_ok=True)

encoding_results = {'poi': [spl_values_poi, yfit_poi, cost_poi, test_cost_poi],
                    'ber': [spl_values_ber, yfit_ber, cost_ber, test_cost_ber],
                    'sngr': [spl_values_sngr, spl_p_values_sngr, spl_theta_values_sngr, sngr_k, sngr_loc,
                             yfit_sngr, sngr_p, sngr_theta, cost_sngr, test_cost_sngr],
                    'gam': [spl_values_gam, spl_k_values_gam, spl_theta_values_gam, yfit_gam,
                            gam_k, gam_theta, cost_gam, test_cost_gam]}

decoding_results = {'poi': [poi_mean_decode, poi_decode, poi_lik_mat],
                    'ber': [ber_mean_decode, ber_decode, ber_lik_mat],
                    'sngr': [sngr_mean_decode, sngr_decode, sngr_lik_mat],
                    'gam': [gam_mean_decode, gam_decode, gam_lik_mat],
                    'stats': [width_med_ci, width_mean_ci, conf_rate_ci]}

# Both dicts are stored as 0-d object arrays under the default keys 'arr_0' (encoding)
# and 'arr_1' (decoding). Reload them with np.load(..., allow_pickle=True)['arr_0'].item().
np.savez(os.path.join(path, 'rlts.npz'), encoding_results, decoding_results)