## Import packages

In [None]:
import numpy as np
import numpy.random as npr
from matplotlib import pyplot as plt
%matplotlib inline
from matplotlib import cm
import matplotlib
from scipy import io
import seaborn as sns
import time


from multiprocessing import Pool

# from ssa_functions import fit_ssa, get_sample_weights, weighted_pca, weighted_rrr
import sys
sys.path.append('/Users/andrew/Documents/Projects/Churchland/Sparsity/code/ssa')

from ssa import fit_ssa, weighted_pca
from ssa.util import get_sample_weights

## Load Data

In [None]:
# Monkey whose data are analyzed in this notebook; used to build the
# file names for every load and save below.
monkName = "Balboa"

In [None]:
# Directory that holds the firing-rate files.
load_folder = '/Users/sherryan/sca_data/'

# Full path of this monkey's rate file, then load it.
rate_file = load_folder + monkName + '_outAndBack_redYellowConds_rawRates.mat'
data = io.loadmat(rate_file)

# PSTH tensor, indexed downstream as C x N x T (conditions x neurons x time).
data_array = data['data']

## Preprocess data


In [None]:
# Keep every 10th time bin (downsample by a factor of 10).
data_downsamp = data_array[:, :, ::10]

# Sizes: conditions, neurons, downsampled time bins per condition.
numConds, numN, trlDur = data_downsamp.shape

# Concatenate conditions: C x N x T -> N x (C*T). Columns are grouped by
# condition, each group holding trlDur consecutive time bins.
data_concat = data_downsamp.swapaxes(0, 1).reshape(numN, numConds * trlDur)

# Firing-rate range (max - min over all concatenated bins) of each neuron, as N x 1.
fr_range = np.ptp(data_concat, axis=1, keepdims=True)

# Within-condition time index of every concatenated column: 0..trlDur-1, repeated once per condition.
timeMask = np.tile(np.arange(trlDur), numConds)

# Time bins (downsampled) used for fitting ssa/pca.
# target on: 20
# move on:   77
# return:    200
trainTimes = np.arange(20, 230)

# Boolean mask over concatenated columns that selects the training bins.
trainMask = np.isin(timeMask, trainTimes)

In [None]:
# Remove the cross-condition mean (mean over the condition axis) from every condition.
data_scm = data_downsamp - data_downsamp.mean(axis=0, keepdims=True)

# Concatenate conditions: C x N x T -> N x (C*T), grouped by condition.
data_scm_concat = np.swapaxes(data_scm, 0, 1).reshape(data_scm.shape[1], -1)

# Soft-normalize each neuron by (its raw firing-rate range + 5).
data_scm_norm = data_scm_concat / (fr_range + 5)

# Mean-center every neuron across all concatenated time points.
dMean = np.repeat(data_scm_norm.mean(axis=1, keepdims=True), data_scm_norm.shape[1], axis=1)
data_mc = data_scm_norm - dMean

## Choose the magnitude of the sparsity penalty
- this is a single free parameter that tells ssa how much to prioritize sparsity (relative to reconstruction error)

#### Choose data and hyperparameters

In [None]:
# Data for the lambda sweep. The model takes time points in rows, so the
# N x CT mean-centered matrix is transposed to CT x N. np.array makes an
# independent copy.
fit_data = np.array(data_mc.T)

# Per-time-point weights for the weighted fits.
sample_weights = get_sample_weights(fit_data)

# Number of latent dimensions to fit.
R_est = 8

# Sparsity penalties to test: 10 log-spaced values from 1e-5 to 1e1.
lambdaRange = np.logspace(-5, 1, num=10)
numLambda = len(lambdaRange)

# Optimizer settings.
numEpochs = 3000
learningRate = .001

# Whether to print the model error while fitting.
vBose = True

# Use a soft orthogonality penalty instead of the hard constraint. The hard
# constraint greatly slows SSA down, especially on large datasets. The soft
# penalty is about 10x faster and still finds orthogonal dimensions.
hardOrthFlag = False
lam_orth = 4




In [None]:
# One reconstruction per sparsity penalty: CT x N x numLambda.
X_pred = np.zeros(fit_data.shape + (numLambda,))

for L, lam in enumerate(lambdaRange):
    # Fit SSA with this sparsity penalty.
    model, latent, x_pred, losses = fit_ssa(
        X=fit_data, sample_weight=sample_weights, R=R_est,
        lam_sparse=lam, lr=learningRate, n_epochs=numEpochs,
        orth=hardOrthFlag, lam_orthog=lam_orth,
    )

    # Convert the reconstruction (presumably a torch tensor) to numpy and store it.
    x_pred = x_pred.detach().numpy()
    X_pred[:, :, L] = x_pred

In [None]:
# Save the lambda-sweep reconstructions together with the data they approximate.
saveDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/data/sparsityLambda_tests/'
savePath = saveDir + monkName + '_' + 'centerOutReaching_sparsityLambdaTest'
np.savez(savePath, predData=X_pred, data=fit_data)

In [None]:
# Weighted PCA baseline for the reconstruction-error comparison.

# Fit on the training time points only.
U_est, V_est = weighted_pca(fit_data[trainMask, :], R_est, sample_weights[trainMask])

# Project every time point onto the PCA dimensions.
pca_latent = fit_data @ U_est

# Original data as one column (column-major), matching the reconstructions.
oRates = fit_data.reshape(-1, 1, order='F')

# Error when reconstructing through V_est (assumed R x N -- confirm with weighted_pca) ...
reconData_pca = (pca_latent @ V_est).reshape(-1, 1, order='F')
reconError_pca = ((oRates - reconData_pca) ** 2).sum(axis=0)

# ... and through U_est.T instead of V_est.
reconData_pca = (pca_latent @ U_est.T).reshape(-1, 1, order='F')
reconError_pca_U = ((oRates - reconData_pca) ** 2).sum(axis=0)




In [None]:

# set text, tick and axis-label colors to black ('k') for the saved pdf
plt.rcParams['text.color'] = 'k'
plt.rcParams['xtick.color'] = 'k'
plt.rcParams['ytick.color'] = 'k'
plt.rcParams['axes.labelcolor'] = 'k'


# plot the reconstruction errors calculated above as a function of sparsity lambda 

# reshape reconstructed population from size CT x N x K to CTN x K
# (the largest lambda, the last slice, is left out of the plot)
reconData = X_pred[:,:,:-1].reshape(-1,numLambda-1,order = 'F')

# reshape original data to CTN x 1 vectors
oRates = fit_data.reshape(-1,1, order = 'F')

# calculate total error between reconstruction and original rates (one value per lambda)
reconError = np.sum( (oRates - reconData)**2, axis = 0)

# plot error
plt.figure(figsize = (5,5));
plt.plot(lambdaRange[0:-1], reconError, color = 'r',label = 'reconstruction error');

# change x scale to a logscale
plt.xscale('log')

# add labels
plt.xlabel('sparsity penalty')
plt.ylabel('reconstruction error')

    
# mark the default lambda penalty for this dataset
# NOTE(review): ymin/ymax are hard-coded and may not match this dataset's error scale
plt.vlines(x = 0.015,ymin = 2000,ymax = 5000,color = 'k',label = 'default lam_sparse for this dataset');

# mark pca reconstruction error
# NOTE(review): xmin = 10e-5 equals 1e-4, not lambdaRange[0] (1e-5) -- confirm intent
plt.hlines(y = reconError_pca,xmin = 10e-5,xmax = 1,color = 'b',linestyle = '--',label = 'pca reconstruction error');


# add a legend
legend = plt.legend();
# force the legend text to black
for text in legend.get_texts():
    text.set_color('black');
    
# save figure
# save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_reconstructionError.pdf',dpi = 'figure')


#### how does ability to reconstruct activity change with R_est? ####

In [None]:
# Data for the dimensionality sweep: the soft-normalized,
# cross-condition-mean-subtracted rates as CT x N (transpose of N x CT).
# np.array makes an independent copy.
fit_data = np.array(data_scm_norm.T)

# Per-time-point weights for the weighted fits.
sample_weights = get_sample_weights(fit_data)

# Candidate numbers of latent dimensions: 2 through 24.
R_est = np.arange(2, 25)
numR = R_est.size

# Sparsity penalty.
sLambda = 0.01

# Optimizer settings.
numEpochs = 3000
learningRate = .001

# Whether to print the model error while fitting.
vBose = True

# Soft orthogonality penalty instead of the hard constraint. The hard
# constraint greatly slows SSA down, especially on large datasets. The soft
# penalty is about 10x faster and still finds orthogonal dimensions.
hardOrthFlag = False
lam_orth = 10

# Condition index of every concatenated time point: each condition id repeated trlDur times.
condMask = np.repeat(np.arange(numConds), trlDur)


In [None]:
# Reconstructions for every candidate dimensionality: CT x N x numR.
X_pred = np.zeros((fit_data.shape[0], fit_data.shape[1], numR))

# For each R: the smallest (over dimensions) of the peak (over time)
# cross-condition variance of the latents.
minMax_var = np.zeros((numR, 1))

# cycle through the candidate dimensionalities
for R in np.arange(numR):

    # Fit SSA. The keyword names and the 4-value return match the other
    # fit_ssa calls in this notebook (lam_sparse / lam_orthog). The older
    # lam / lam2 / verbose keywords with 3-value unpacking do not.
    model, latent, x_pred, losses = fit_ssa(X=fit_data, sample_weight=sample_weights,
                                            R=R_est[R], lam_sparse=sLambda, lr=learningRate,
                                            n_epochs=numEpochs, orth=hardOrthFlag,
                                            lam_orthog=lam_orth)

    # reconstructed data and latents as numpy arrays
    x_pred = x_pred.detach().numpy()
    latent = latent.detach().numpy()

    # CT x K -> T x C x K. Conditions are concatenated in blocks of T time
    # points, hence column-major order. numConds replaces the hard-coded 8.
    latent = np.reshape(latent, [-1, numConds, R_est[R]], order='F')

    # Cross-condition variance (T x K) -> max over time (axis 0, one value per
    # dimension) -> min over dimensions. The original np.max had no axis and
    # collapsed everything to a single global maximum.
    ccVar = np.var(latent, axis=1)
    minMax_var[R] = np.min(np.max(ccVar, axis=0))

    # save the reconstructions
    X_pred[:, :, R] = x_pred

In [None]:
# Save the dimensionality-sweep reconstructions, the fitted data and the
# per-R min-max cross-condition variance.
saveDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/data/reaching/'
savePath = saveDir + monkName + '_' + 'centerOutReaching_numDimsTest'
np.savez(savePath, predData=X_pred, data=fit_data, minMax_ccVar=minMax_var)

In [None]:
# change the default label color to white 
plt.rcParams['text.color'] = 'w'
plt.rcParams['xtick.color'] = 'w'
plt.rcParams['ytick.color'] = 'w'
plt.rcParams['axes.labelcolor'] = 'w'


# plot the total reconstruction error as a function of the number of latent dimensions (R_est)

# reshape reconstructed population from size CT x N x numR to CTN x numR
reconData = X_pred.reshape(-1,numR,order = 'F')

# reshape original data to CTN x 1 vectors
oRates = fit_data.reshape(-1,1, order = 'F')

# calculate total error between reconstruction and original rates (one value per R)
reconError = np.sum( (oRates - reconData)**2, axis = 0)

# plot error
plt.figure(figsize = (5,5));
plt.plot(R_est, reconError, color = 'r',label = 'reconstruction error');

# add labels
plt.xlabel('number of dimensions')
plt.ylabel('reconstruction error')

# add a legend
legend = plt.legend();
# force the legend text to black (the rcParams above make text white)
for text in legend.get_texts():
    text.set_color('black');

## Fit ssa with the chosen sparsity penalty

In [None]:
# Fit SSA on the training time points with the chosen sparsity penalty.
# Previously lam_sparse=None was passed, so sparsityLambda was defined but
# never used.
sparsityLambda = 0.01
R_est = 8
model, ssa_latent, x_pred, losses = fit_ssa(X=fit_data[trainMask, :],
                                            sample_weight=sample_weights[trainMask],
                                            R=R_est, lam_sparse=sparsityLambda,
                                            lr=learningRate, n_epochs=numEpochs * 2,
                                            orth=hardOrthFlag, lam_orthog=lam_orth)

# Encoder weights; fc1.weight is used below as R x N (assumption -- confirm with ssa).
ssaW = model.fc1.weight.detach().numpy()

# Project all time points (not only the training ones) into the SSA dimensions: CT x R.
ssa_latent = fit_data @ ssaW.T





In [None]:
# Pooled reconstruction R^2 of SSA on the training time points: every
# (time, neuron) entry is stacked into one column (column-major).
X_hat = x_pred.detach().numpy().reshape(-1, 1, order='F')
X = fit_data[trainMask, :].reshape(-1, 1, order='F')

residual = X - X_hat
centered = X - np.mean(X)
SS_res = np.sum(residual ** 2)
SS_tot = np.sum(centered ** 2)
R2_ssa = 1 - SS_res / SS_tot

print('SSA R2: ' + str(R2_ssa))

In [None]:
# Gram matrix of the learned encoder (fc1) weights. U is not constrained to
# be orthogonal, so the off-diagonal entries show how far from orthogonal it is.
W_enc = model.fc1.weight.detach().numpy()
plt.imshow(W_enc @ W_enc.T, clim=[-0.1, 0.1], cmap='RdBu');
plt.colorbar();

In [None]:
# Gram matrix of the learned decoder (fc2) weights. Whether V is constrained
# to be orthogonal depends on the orth flag passed to fit_ssa.
W_dec = model.fc2.weight.detach().numpy()
plt.imshow(W_dec.T @ W_dec, clim=[-0.1, 0.1], cmap='RdBu');
plt.colorbar();

### Fit PCA for comparison

In [None]:
# Weighted PCA fitted on the training time points only.
U_est, V_est = weighted_pca(fit_data[trainMask, :], R_est, sample_weights[trainMask])

# Project every time point onto the PCA dimensions.
pca_latent = fit_data @ U_est

# Reconstruct the training data through the R_est-dimensional subspace and
# compute a pooled R^2 over all entries (column-major flattening).
X_train = fit_data[trainMask, :]
X_hat = (X_train @ U_est @ V_est).reshape(-1, 1, order='F')
X = X_train.reshape(-1, 1, order='F')

residual = X - X_hat
centered = X - np.mean(X)
SS_res = np.sum(residual ** 2)
SS_tot = np.sum(centered ** 2)
R2_pca = 1 - SS_res / SS_tot

print('PCA R2: ' + str(R2_pca))

def get_sses_pred(y_test,y_test_pred):
    """Return the column-wise sum of squared prediction errors.

    Sums (y_test_pred - y_test)**2 over axis 0, so a 2-D input gives one
    value per column and a 1-D input gives a scalar.
    """
    errors = np.subtract(y_test_pred, y_test)
    return np.sum(np.square(errors), axis=0)

def get_sses_mean(y_test):
    """Return the column-wise sum of squared deviations from the column means.

    This is the error of a predictor that always outputs each column's mean,
    i.e. the total sum of squares used as the denominator of R^2.
    """
    deviations = y_test - np.mean(y_test, axis=0)
    return np.sum(np.square(deviations), axis=0)
# Cross-check of the pooled PCA R^2 from per-column sums of squares.
sses_mean = get_sses_mean(X)
sses = get_sses_pred(X, X_hat)
print(1 - sses.sum() / sses_mean.sum())

In [None]:
np.shape(X)

In [None]:
# Mean (and std) of the absolute pairwise correlations between latent
# dimensions, for SSA and for PCA. Only entries above the diagonal are kept,
# so each pair of dimensions is counted once.
upper = np.triu_indices(R_est, k=1)

ssa_rho = np.abs(np.corrcoef(ssa_latent, rowvar=False))[upper]
pca_rho = np.abs(np.corrcoef(pca_latent, rowvar=False))[upper]

mPCA, stdPCA = np.mean(pca_rho), np.std(pca_rho)
mSSA, stdSSA = np.mean(ssa_rho), np.std(ssa_rho)

print('mPCA (std) ' + str(mPCA) + '(' + str(stdPCA) + ')' )
print('mSSA (std) ' + str(mSSA) + '(' + str(stdSSA) + ')' )



In [None]:
# Cumulative fraction of the total variance captured by the PCA and SSA latents.
totVar = np.var(fit_data, axis=0).sum()

pcaVar = np.var(pca_latent, axis=0)
# SSA dimensions sorted from most to least variance.
ssaVar = np.flip(np.sort(np.var(ssa_latent, axis=0)))

cumSum_pca = np.cumsum(pcaVar) / totVar
cumSum_ssa = np.cumsum(ssaVar) / totVar

# PCA in red, SSA in black.
for curve, col in ((cumSum_pca, 'r'), (cumSum_ssa, 'k')):
    plt.plot(curve, linewidth=2, color=col)

# NOTE(review): this is the ratio of the sums of the two cumulative curves
# (the areas under them), not the ratio of total variance explained, which
# would be cumSum_ssa[-1] / cumSum_pca[-1]. Confirm which is intended.
print('ratio of variance explained: ' + str(cumSum_ssa.sum() / cumSum_pca.sum()))

### define some plotting colors
 - ssa is purple
 - pca is washed out purple

In [None]:
# Colormap positions to sample: evenly spaced from 0.55 in steps of 0.05,
# stopping before 1.
colorIdx = np.arange(0.55, 1, 0.4 / 8)

# Shared cubehelix settings. SSA colors use gamma 1.1; the washed-out PCA
# colors use gamma 0.7 and hue 0.3.
palette_kw = dict(start=0.1, rot=0.6, dark=0.15, light=0.8, as_cmap=True)
ssa_cMap = sns.cubehelix_palette(gamma=1.1, **palette_kw)(colorIdx)
pca_cMap = sns.cubehelix_palette(gamma=0.7, hue=0.3, **palette_kw)(colorIdx)


In [None]:
# Preview the first 8 SSA colors as thick lines.
for k in range(8):
    plt.plot(np.arange(0, 10) * k, color=ssa_cMap[k], linewidth=8);

In [None]:
# Preview the first 8 PCA colors as thick lines.
for k in range(8):
    plt.plot(np.arange(0, 10) * k, color=pca_cMap[k], linewidth=8);

## Plot latents

### Plot when reordering by time of maximal influence

In [None]:
# Order latent dimensions by the time at which their across-condition
# variance peaks.

# CT x K -> T x C x K. Conditions are concatenated in blocks of T, hence
# column-major order; 8 conditions are assumed.
rs_ssa_latent = np.reshape(ssa_latent, (-1, 8, R_est), order='F')
rs_pca_latent = np.reshape(pca_latent, (-1, 8, R_est), order='F')

# Across-condition variance at each time point, per dimension (T x K).
ssa_var = rs_ssa_latent.var(axis=1)
pca_var = rs_pca_latent.var(axis=1)

# Plotting order: dimensions sorted by the time bin of peak variance.
ssa_order = np.argsort(np.argmax(ssa_var, axis=0))
pkIdx = np.argmax(pca_var, axis=0)
pca_order = np.argsort(pkIdx)

#### PCA

In [None]:
#Get indices of each trial
# t_idxs[j] holds the rows of the concatenated (CT x K) latents that belong to condition j
T=data_downsamp.shape[2] #Length of time per condition
trs=np.arange(0,8)
t_idxs=[np.arange(T*tr,T*(tr+1)) for tr in trs]

# define some useful time points (downsampled bins)
tgt_idx=20
move_idx=77
ret_idx=200

# set text/tick/label colors to white ('w')
# NOTE(review): white labels are invisible on a white background; other cells use 'k' -- confirm
plt.rcParams['text.color'] = 'w'
plt.rcParams['xtick.color'] = 'w'
plt.rcParams['ytick.color'] = 'w'
plt.rcParams['axes.labelcolor'] = 'w'


# one row per PCA dimension, ordered by time of peak cross-condition variance (pca_order)
fig,ax=plt.subplots(R_est,1,figsize=(15,20))
for i in range(R_est):
    for j in range(len(trs)):

        # condition j's trace, shifted by the dimension's mean over all time points and conditions
        ax[i].plot(pca_latent[:,pca_order[i]][t_idxs[j]] - np.mean(pca_latent[:,pca_order[i]]),linewidth=2.25,color=pca_cMap[j,:])
        
        # event markers: target on (gray), outward move and return (black)
        ax[i].plot([tgt_idx,tgt_idx],[-1.7,1.7],'gray',linewidth=.5)
        ax[i].plot([move_idx,move_idx],[-2,2],'k',linewidth=.5)
        ax[i].plot([ret_idx,ret_idx],[-2,2],'k',linewidth=.5)

        ax[i].set_xlim([0,T+1])
        ax[i].set_ylim([-2.5, 2.5])

        # only the bottom panel gets x ticks and an x label
        if i<R_est-1:
            ax[i].set_xticks([])
        else:
            ax[i].set_xlabel('Time (10ms bins)')
            
        ax[i].set_yticks([])
        # labels number the plotting position, not the original PCA index
        ax[i].set_ylabel('Dim. '+str(i+1))

    ax[0].set_title('Weighted PCA')

# save figure

# save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
# NOTE(review): '_noMean_sub_' differs from the SSA figure's '_noMeanSub_' naming
plt.savefig(figDir + monkName + '_noMean_sub_pcaProj.pdf',dpi = 'figure')

#### SSA

In [None]:
#Get indices of each trial
# t_idxs[j] holds the rows of the concatenated (CT x K) latents that belong to condition j
T=data_downsamp.shape[2] #Length of time per condition
trs=np.arange(0,8)
t_idxs=[np.arange(T*tr,T*(tr+1)) for tr in trs]


# one row per SSA dimension, ordered by time of peak cross-condition variance (ssa_order);
# tgt_idx/move_idx/ret_idx and figDir come from an earlier cell
fig,ax=plt.subplots(R_est,1,figsize=(15,20))
for i in range(R_est):
    for j in range(len(trs)):

        # condition j's trace, shifted by the dimension's mean over all time points and conditions
        ax[i].plot(ssa_latent[:,ssa_order[i]][t_idxs[j]] - np.mean(ssa_latent[:,ssa_order[i]]),linewidth=2.25,color=ssa_cMap[j,:])
        
        # event markers and a dashed zero line
        ax[i].plot([tgt_idx,tgt_idx],[-1.7,1.7],'gray',linewidth=.5)
        ax[i].plot([move_idx,move_idx],[-2,2],'k',linewidth=.5)
        ax[i].plot([ret_idx,ret_idx],[-2,2],'k',linewidth=.5)
        ax[i].plot([0,T],[0,0],'k--')

        ax[i].set_xlim([0,T+1])
        ax[i].set_ylim([-2.5, 2.5])

        # only the bottom panel gets x ticks and an x label
        if i<R_est-1:
            ax[i].set_xticks([])
        else:
            ax[i].set_xlabel('Time (10ms bins)')
            
        ax[i].set_yticks([])
        # labels number the plotting position, not the original SSA index
        ax[i].set_ylabel('Dim. '+str(i+1))

    ax[0].set_title('SSA')

# save
plt.savefig(figDir + monkName + '_noMeanSub_ssaProj.pdf',dpi = 'figure')

### Calculate the across-condition variance in three time windows: preparation, movement, hold
- target on: 20
- outward movement: 77
- return movement: 200

In [None]:
# Analysis windows, in downsampled time bins (30 bins each).
prepWindow = np.arange(22, 52)    # shortly after target onset (bin 20)
holdWindow = np.arange(130, 160)
moveWindow = np.arange(195, 225)  # spans the return movement (bin 200)

In [None]:
# Fraction of each dimension's total across-condition variance that falls
# inside each window. Result columns: 0 = prep, 1 = move, 2 = hold.
windows = (prepWindow, moveWindow, holdWindow)

ssaVar = np.zeros((R_est, 3))
pcaVar = np.zeros((R_est, 3))

# number of conditions (trs holds one index per condition)
numConds = trs.shape[0]

for d in range(R_est):
    # Reshape each projection to T x C, then take the across-condition variance at each time.
    projVar_ssa = ssa_latent[:, d].reshape([-1, numConds], order='F').var(axis=1)
    projVar_pca = pca_latent[:, d].reshape([-1, numConds], order='F').var(axis=1)

    # Normalizers: this dimension's summed across-condition variance.
    totVar_ssa = projVar_ssa.sum()
    totVar_pca = projVar_pca.sum()

    for col, win in enumerate(windows):
        ssaVar[d, col] = projVar_ssa[win].sum() / totVar_ssa
        pcaVar[d, col] = projVar_pca[win].sum() / totVar_pca


#### Plot variances 
- prep vs. move
- move vs. hold
- hold vs. prep 

In [None]:
# define plotting colors
pcaColor = 'lightslategray'
ssaColor = 'purple'

# Draw directly on the axes returned by plt.subplots. The previous
# plt.subplot(1,3,k) re-lookup could add a second axes on top of each panel
# on some matplotlib versions.
fig, axes = plt.subplots(nrows = 1,ncols = 3,figsize = (8,5))

# (x column, y column, x label, y label) for each panel.
# Columns of ssaVar/pcaVar: 0 = prep, 1 = move, 2 = hold.
panels = ((1, 0, 'move', 'prep'),   # prep vs. move
          (2, 1, 'hold', 'move'),   # move vs. hold
          (0, 2, 'prep', 'hold'))   # hold vs. prep

for panel_ax, (xc, yc, xlab, ylab) in zip(axes, panels):
    panel_ax.plot(ssaVar[:, xc], ssaVar[:, yc], 'o', color=ssaColor, mfc=ssaColor, ms=10)
    panel_ax.plot(pcaVar[:, xc], pcaVar[:, yc], 'o', color=pcaColor, mfc=pcaColor, ms=10)
    panel_ax.set_xlim([-0.05, 0.5])
    panel_ax.set_ylim([-0.05, 0.5])
    panel_ax.set_xlabel(xlab)
    panel_ax.set_ylabel(ylab)

fig.tight_layout();

### Bootstrap neurons, run weighted PCA, and collect the resulting latents

##### define a function to feed to the parallel pool

In [None]:
def runPCA_reaching(counter):
    """Resample neurons once (with replacement) and return weighted-PCA latents.

    Parameters
    ----------
    counter : int
        Index of this bootstrap repetition. It seeds a private random
        generator, so each repetition draws a reproducible resample that does
        not depend on which worker process runs it.

    Returns
    -------
    ndarray
        The resampled data ``X_samp`` (CT x N) projected onto its ``R_est``
        weighted-PCA dimensions.

    Reads the notebook globals ``fit_data`` and ``R_est``.
    """
    # One generator per repetition. With the global np.random state, forked
    # Pool workers inherit identical RNG state and draw duplicate resamples.
    rng = np.random.default_rng(int(counter))

    # fit_data is CT x N (time points x neurons)
    numN = fit_data.shape[1]

    # draw numN neurons with replacement
    nIdx = rng.choice(numN, numN)
    X_samp = fit_data[:, nIdx]

    # sample weights for the resampled data
    sWeights = get_sample_weights(X_samp)

    # run weighted pca (only the encoder is needed for the latents)
    uEst, _ = weighted_pca(X_samp, R_est, sWeights)

    # pca latents
    return X_samp @ uEst

In [None]:
# Number of bootstrap repetitions.
nBoot = 1000

# Run the bootstraps in a process pool (Pool() defaults to one worker per CPU).
# Results are collected inside the `with` block, so the workers are shut
# down cleanly afterwards instead of being left running. pool.map returns
# results in task order, which keeps runs comparable.
with Pool() as pool:
    out = pool.map(runPCA_reaching, range(nBoot))

In [None]:
# One latent matrix per bootstrap repetition.
pca_bs_L = list(out)

#### load projections into prep, move, and posture space 
- calculated via Gamal's orthogonal dimensionality-reduction method

In [None]:
# Directory holding the precomputed prep / move / posture projections.
load_folder = '/Users/andrew/Documents/Projects/Churchland/Sparsity/data/reaching/'
data = io.loadmat(load_folder + monkName + '_gamalLoadings.mat')

# Each projection gets a trailing singleton axis for stacking. They are
# regressed against CT x K latents later, so presumably CT x k -- confirm.
gPrepProj, gMoveProj, gPostProj = (data[key][:, :, None]
                                   for key in ('prepProj', 'moveProj', 'postProj'))

# Stack along the last axis: [..., 0] = prep, [..., 1] = move, [..., 2] = posture.
gProj = np.concatenate((gPrepProj, gMoveProj, gPostProj), axis=2)

In [None]:
np.shape(gPrepProj)

#### regress pca and ssa latents against prep, move, or posture dimensions
- ask how well the prep, move, or posture projections (ground truth) can reconstruct ssa/pca projection in a single dimension

In [None]:
def _regression_r2(X, Y):
    """R^2 of each column of Y regressed (least squares, no intercept) on X.

    X is (samples x p), Y is (samples x K); returns a length-K array.
    """
    # lstsq avoids explicitly inverting X.T @ X (numerically fragile when X is
    # ill-conditioned) and gives the same B when X has full column rank.
    B = np.linalg.lstsq(X, Y, rcond=None)[0]
    Y_hat = X @ B
    ss_tot = np.sum((Y - np.mean(Y, axis=0))**2, axis=0)
    ss_res = np.sum((Y - Y_hat)**2, axis=0)
    return 1 - ss_res / ss_tot


# number of bootstraps
numBoots = len(pca_bs_L)

# R2 of every bootstrap PCA dimension when reconstructed from the prep, move
# or posture projections (last axis: 0 = prep, 1 = move, 2 = posture)
allPCA_R2 = np.zeros((numBoots, R_est, 3))
for ii, boot_latent in enumerate(pca_bs_L):
    for e in range(3):
        allPCA_R2[ii, :, e] = _regression_r2(gProj[:, :, e], boot_latent)

# best epoch (max R2 across epochs) for each bootstrap dimension
maxR2_pca = np.max(allPCA_R2, axis=2)

# same analysis for the SSA latents
ssa_R2 = np.zeros((1, R_est, 3))
for e in range(3):
    ssa_R2[0, :, e] = _regression_r2(gProj[:, :, e], ssa_latent)

# take max across epochs
maxR2_ssa = np.max(ssa_R2, axis=2).squeeze()



#### Plot 

In [None]:
# set text/tick/label colors to black for the pdf
plt.rcParams['text.color'] = 'k'
plt.rcParams['xtick.color'] = 'k'
plt.rcParams['ytick.color'] = 'k'
plt.rcParams['axes.labelcolor'] = 'k'

# flatten the (numBoots x R_est) PCA R2 values into one column
pcaR2_recon = np.reshape(maxR2_pca,[-1,1],order = 'F')


# x positions for the PCA points: small jitter around 1
pcaLoc = np.random.normal(loc = 1,scale = 0.03,size = (numBoots*R_est,1))

# individual bootstrap values (very faint)
plt.plot(pcaLoc,pcaR2_recon,'o',color = (0.4,0.4,0.4),ms = 4,alpha = 0.01);

# mean +/- std of PCA performance
plt.errorbar(1,np.mean(pcaR2_recon),np.std(pcaR2_recon),color = 'k',lw = 3,zorder = 3)
plt.plot(1,np.mean(pcaR2_recon),'o',color = 'k',ms = 8,zorder = 3);


# x positions for the SSA points: small jitter around 2
ssaLoc = np.random.normal(loc = 2,scale = 0.01,size = (R_est,1))
plt.plot(ssaLoc,maxR2_ssa,'o',color = 'purple',alpha = 0.5);

# Mean +/- std of SSA performance. This uses the std, as the comment always
# said and as the PCA bar does; it previously drew std/sqrt(n) (SEM), which
# made the two error bars different statistics.
plt.errorbar(2,np.mean(maxR2_ssa),np.std(maxR2_ssa),color = 'purple',lw = 3)
plt.plot(2,np.mean(maxR2_ssa),'o',color = 'purple',ms = 8);



# clean up
plt.xlim((0.5, 2.5));plt.ylim((0,1));
plt.xticks(np.array([1,2]),('PCA','SSA'));
plt.yticks(np.array([0,0.5, 1]));
plt.title(monkName);
plt.ylabel('R2');


# save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_ssaVsPCA_reconError.pdf',dpi = 'figure')




In [None]:
#Get indices of each trial
# t_idxs[j] holds the rows of the concatenated (CT x K) latents that belong to condition j
T=data_downsamp.shape[2] #Length of time per condition
trs=np.arange(0,8)
t_idxs=[np.arange(T*tr,T*(tr+1)) for tr in trs]


# SSA latents in their original dimension order, without mean subtraction
# (unlike the ordered SSA figure); this figure is not saved
fig,ax=plt.subplots(R_est,1,figsize=(10,20))
for i in range(R_est):
    for j in range(len(trs)):

        ax[i].plot(ssa_latent[:,i][t_idxs[j]],linewidth=2.25,color=ssa_cMap[j,:])
        
        # event markers (indices set in an earlier cell) and a dashed zero line
        ax[i].plot([tgt_idx,tgt_idx],[-1.7,1.7],'gray',linewidth=.5)
        ax[i].plot([move_idx,move_idx],[-2,2],'k',linewidth=.5)
        ax[i].plot([ret_idx,ret_idx],[-2,2],'k',linewidth=.5)
        ax[i].plot([0,T],[0,0],'k--')

        ax[i].set_xlim([0,T+1])
        ax[i].set_ylim([-2, 2])

        # only the bottom panel gets x ticks and an x label
        if i<R_est-1:
            ax[i].set_xticks([])
        else:
            ax[i].set_xlabel('Time (10ms bins)')
            
        ax[i].set_yticks([])
        ax[i].set_ylabel('Dim. '+str(i+1))

    ax[0].set_title('SSA')


#### Plot kinematics for one example condition 

In [None]:
# load kinematic data 
# correct folder
load_folder='/Users/andrew/Documents/Projects/Churchland/Sparsity/data/rawRates/'

# Load this monkey's kinematics. monkName replaces the hard-coded 'Balboa',
# matching every other load/save in the notebook.
data=io.loadmat(load_folder + monkName + '_outAndBack_redYellowConds_kinematics.mat')

# pull out the kinematics
# position data is 2 (X and Y) x T x N (single neuron recordings);
# speed is averaged over axis 1 below, so presumably T x N
# the condition ID is saved in the same file 
position= data['allPos']
speed =   data['allSpeed']

# take average across recordings
mPos = np.mean(position,axis = 2)
mSpeed = np.mean(speed,axis = 1)

# downsample both by 10 to match resolution of neural data
mPos = mPos[:,np.arange(0,mPos.shape[1],10)]
mSpeed = mSpeed[np.arange(0,mSpeed.shape[0],10)]

# distance from the start position (mean of the first 5 bins) as a function of time
startPos = np.mean(mPos[:,np.arange(5)],axis = 1)
posDiff  = startPos[:,None] -mPos
dist     = np.linalg.norm(posDiff,axis = 0)

# plot: distance (top) and speed (bottom)
plt.figure(figsize = (8,5));
plt.subplot(211);
plt.plot(dist,color = 'k',lw = 2,label = 'distance from center');plt.ylim([-5,250]);
legend = plt.legend();
plt.setp(legend.get_texts(),color = 'k')
plt.subplot(212);plt.ylim([-5,125]);
plt.plot(mSpeed,color = 'slategrey',lw=2,label = 'speed');
legend = plt.legend();
plt.setp(legend.get_texts(),color = 'k')


# define some useful time points
tgt_idx=20
move_idx=77
ret_idx=200

# event markers on the speed panel
plt.plot([tgt_idx,tgt_idx],[0,120],'k',linewidth=.5);
plt.plot([move_idx,move_idx],[0,120],'k',linewidth=.5);
plt.plot([ret_idx,ret_idx],[0,120],'k',linewidth=.5);

# save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_kinematics.pdf',dpi = 'figure')






### Plot gamal projections 

In [None]:
### define some new colormaps
# colormap positions to sample: from 0.55 in steps of 0.05, stopping before 1
colorIdx = np.arange(0.55,1,0.4/8)

# define prep colors (one color row per sampled position)
prep_cMap = sns.cubehelix_palette(start = 2,rot = 0,dark = 0.15, light = 0.8,as_cmap = True,gamma = 1.1)
prep_cMap = prep_cMap(colorIdx)

# move colors
move_cMap = sns.cubehelix_palette(start = 2.35,rot = 0,dark = 0.15, light = 0.8,as_cmap = True,gamma = 0.7)
move_cMap = move_cMap(colorIdx)

# posture colors
post_cMap = sns.cubehelix_palette(start = 2.8,rot = 0,dark = 0.15, light = 0.8,as_cmap = True,gamma = 0.7)
post_cMap = post_cMap(colorIdx)

# preview the first 8 colors of each palette, one subplot per palette
plt.figure(figsize = (5,5));
plt.subplot(311);
for i in np.arange(8):
    plt.plot(np.arange(0,10)*i,color = prep_cMap[i,:],linewidth = 8);

plt.subplot(312);
for i in np.arange(8):
    plt.plot(np.arange(0,10)*i,color = move_cMap[i,:],linewidth = 8);

plt.subplot(313);
for i in np.arange(8):
    plt.plot(np.arange(0,10)*i,color = post_cMap[i,:],linewidth = 8);

# stack the three palettes along a new last axis:
# g_cMap[j,:,E] is condition j's color for epoch E (0 = prep, 1 = move, 2 = posture)
prep_cMap = prep_cMap[:,:,None]
move_cMap = move_cMap[:,:,None]
post_cMap = post_cMap[:,:,None]

g_cMap = np.concatenate((prep_cMap, move_cMap, post_cMap),axis = 2)

In [None]:
# plot the prep / move / posture projections (one column per epoch)
#Get indices of each trial
# t_idxs[j] holds the rows belonging to condition j in the concatenated time axis
T=data_downsamp.shape[2] #Length of time per condition
trs=np.arange(0,8)
t_idxs=[np.arange(T*tr,T*(tr+1)) for tr in trs]

# define some useful time points (downsampled bins)
tgt_idx=20
move_idx=77
ret_idx=200

# set text/tick/label colors to black for the pdf
plt.rcParams['text.color'] = 'k'
plt.rcParams['xtick.color'] = 'k'
plt.rcParams['ytick.color'] = 'k'
plt.rcParams['axes.labelcolor'] = 'k'


# R_est rows x 3 columns (E: 0 = prep, 1 = move, 2 = posture)
# NOTE(review): indexes gProj[:, i, E] for i < R_est, so each epoch must
# have at least R_est projection dimensions -- confirm against gPrepProj.shape
fig,ax=plt.subplots(R_est,3,figsize=(15,20))
for E in range(3):

    for i in range(R_est):
        for j in range(len(trs)):

            
            ax[i,E].plot(gProj[:,i,E][t_idxs[j]],linewidth=2.25,color=g_cMap[j,:,E])
        
            # event markers
            ax[i,E].plot([tgt_idx,tgt_idx],[-100,100],'gray',linewidth=.5)
            ax[i,E].plot([move_idx,move_idx],[-100,100],'k',linewidth=.5)
            ax[i,E].plot([ret_idx,ret_idx],[-100,100],'k',linewidth=.5)

            ax[i,E].set_xlim([0,T+1])
            ax[i,E].set_ylim([-120, 120])

            # only the bottom row gets x ticks and an x label
            if i<R_est-1:
                ax[i,E].set_xticks([])
            else:
                ax[i,E].set_xlabel('Time (10ms bins)')
            
            ax[i,E].set_yticks([])
            ax[i,E].set_ylabel('Dim. '+str(i+1))

        ax[0,E].set_title('Gamal Projections')

# save figure

# save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_gamalProj.pdf',dpi = 'figure')