# Run SCA (named "SSA" in some variables and comments below) and weighted PCA on center-out (red and yellow, long delay) reaching data. Plot projections in the PCA and SCA dimensions.
- first generate Balboa figures, then Alex.
- Use default orthogonality and sparsity penalties for SCA

## Import packages

In [None]:
import numpy as np
import numpy.random as npr
from scipy import io
import seaborn as sns
import time

import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from sca.models import fit_sca, weighted_pca
from sca.util import get_sample_weights
from plotly.subplots import make_subplots

# add appropriate directories to search path
import sys
sys.path.append('/Users/andrew/Documents/Projects/Churchland/Sparsity/code/andrewPython/')

from centerOutReaching_utils import calculateEpochOccupancy
from centerOutReaching_utils import calculateOccDispersion
from centerOutReaching_utils import calculateAI
from centerOutReaching_utils import calculateChanceAI
from parallelFunctions import bootstrapNeurons_SCA_PCA

import time


## Load Data

In [None]:
# which monkey are we working with?
# (this name is used to build the input data file name and the figure/output file names below)
monkName = 'Balboa'

In [None]:
# directory that holds the raw firing-rate files
load_folder='/Users/andrew/Documents/Projects/Churchland/Sparsity/data/rawRates/'

# read the .mat file belonging to the selected monkey
ratesFile = load_folder + monkName + '_outAndBack_redYellowConds_rawRates.mat'
data=io.loadmat(ratesFile)

# extract the PSTH tensor
# (C x N x T: conditions x neurons x time, matching the unpacking order used in preprocessing)
data_array=data['data']

## Preprocess data


In [None]:
# Downsample in time: keep every 10th sample along the last (time) axis
data_downsamp=data_array[:,:,::10]

# pull out some useful numbers
numConds,numN,trlDur = data_downsamp.shape

# Concatenate all the conditions (so the matrix is size N x TC instead of C x N x T);
# conditions are laid end-to-end, each spanning trlDur consecutive columns
data_concat=data_downsamp.swapaxes(0,1).reshape(numN,numConds*trlDur)

# firing-rate range of each neuron, kept as an N x 1 column for broadcasting
fr_range=np.ptp(data_concat,axis=1)[:,None]

# make a time mask: the within-trial time index of every column of the concatenated data
timeMask = np.tile(np.arange(trlDur),numConds)

# define the times we want to use for ssa/pca
# target on: 20
# move on:   77
# return:    200
trainTimes = np.arange(20,230)

# define a 'training mask' for convenience
# (np.isin replaces np.in1d, which is deprecated in NumPy 2.0; the rest of this file already uses np.isin)
trainMask = np.isin(timeMask,trainTimes)

# Subtract cross-condition mean
data_scm=data_downsamp-np.mean(data_downsamp,axis=0,keepdims=True)

# Concatenate all the conditions (so the matrix is size N x TC instead of C x N x T)
data_scm_concat=data_scm.swapaxes(0,1).reshape(numN,numConds*trlDur)

# Soft normalize (divide each neuron by its fr range + 5)
data_scm_norm=data_scm_concat/(fr_range+5)

# mean-center the data (each neuron's mean over all concatenated time points)
dMean = np.tile(np.mean(data_scm_norm,axis=1,keepdims=True),(1,data_scm_norm.shape[1]))
data_mc = data_scm_norm - dMean

# rename the data for convenience
# Note that model requires (T x N) input rather than (N x T), which is why the data is transposed here
fit_data=np.copy(data_mc.T)

# how much to weight each timestep (passed to fit_sca and weighted_pca below)
sample_weights=get_sample_weights(fit_data)

# number of dimensions to find
R_est=8

# Whether or not to impose hard orthogonality constraint (passed as `orth` to fit_sca)
hardOrthFlag = True



## Fit SCA

In [None]:
# plot the bias vector of the model's first layer (fc1)
# NOTE(review): `model` is only created by the fit_sca cell that follows, so this cell
# raises a NameError if the notebook is run top-to-bottom — confirm intended cell order
bias = model.fc1.bias.detach().numpy()
plt.plot(bias);

In [None]:
# fit model (and report how long the fit takes)
# The wall-clock timer is imported locally: a later cell rebinds the name `time`
# to a list of sample indices, after which `time.time()` fails if this cell is re-run.
from time import time as wallClock

start = wallClock()
model,ssa_latent, x_pred,losses=fit_sca(X=fit_data[trainMask,:],sample_weight = sample_weights[trainMask],
                                R=R_est,orth = hardOrthFlag)
end = wallClock()
print(end-start)

# grab weights
ssaW = model.fc1.weight.detach().numpy()

# project all of the data (not just the training window) into the ssa dimensions;
# this replaces the training-window latents returned by fit_sca
# NOTE(review): only the fc1 weights are applied here; the fc1 bias is not added — confirm this is intended
ssa_latent = fit_data @ ssaW.T

# plot loss
plt.plot(losses);



In [None]:
# reconstruction R2 of the SCA fit, pooled over all neurons and training time points

# model prediction and matching data, both flattened (column-major) into column vectors
trainPred = x_pred.detach().numpy()
trainData = fit_data[trainMask,:]
X_hat = trainPred.reshape(-1,1,order = 'F')
X = trainData.reshape(-1,1,order='F')

# residual and total sums of squares
residual = X - X_hat
centered = X - np.mean(X)
SS_res = np.sum(residual**2)
SS_tot = np.sum(centered**2)

R2_ssa = 1 - (SS_res/SS_tot)

# display results
print(f'SSA R2: {R2_ssa!s}')

In [None]:
# Gram matrix of the learned U weights (fc1), i.e. all pairwise dot products between rows
# (per the original note, U is not constrained to be orthogonal)
U_weights = model.fc1.weight.detach().numpy()
U_dProd = U_weights @ U_weights.T

# show it as an image
fig = px.imshow(U_dProd)
fig.update_layout(height = 300, width = 300,title = 'dot product of U')

In [None]:
# Gram matrix of the learned V weights (fc2), i.e. all pairwise dot products between rows
# (per the original note, V may or may not be constrained to be orthogonal, depending on the orthogonality flag)
V_weights = model.fc2.weight.detach().numpy()
V_dProd = V_weights @ V_weights.T

# show the magnitudes as an image, with the color range clipped to [0, 0.5]
fig = px.imshow(np.abs(V_dProd),range_color = [0,0.5])
fig.update_layout(height = 300, width = 300,title = 'dot product of V')

### Fit PCA for comparison

In [None]:
# fit weighted PCA on the training window, then project all of the data
trainData = fit_data[trainMask,:]
U_est,V_est = weighted_pca(trainData,R_est,sample_weights[trainMask])
pca_latent = fit_data@U_est

# reconstruction R2 on the training window, pooled over all neurons and time points
X = trainData
X_hat = (X @U_est @ V_est).reshape(-1,1,order = 'F')
X = X.reshape(-1,1,order = 'F')

residual = X - X_hat
centered = X - np.mean(X)
SS_tot = np.sum(centered**2)
SS_res = np.sum(residual**2)

R2_pca = 1 - (SS_res/SS_tot)

# display results
print(f'PCA R2: {R2_pca!s}')

### define some plotting colors
 - ssa is purple -> light purple
 - I used https://davidjohnstone.net/lch-lab-colour-gradient-picker

In [None]:
# per-condition trace colors for the SCA plots (dark purple -> light purple, one entry per condition)
sca_cMap  = [
    '#5e0044',
    '#6f144e',
    '#812858',
    '#933c62',
    '#a5506d',
    '#b76477',
    '#c97881',
    '#db8c8c',
]

## Plot latents

### Plot when reordering by time of maximal influence

In [None]:
# calculate across-condition variance of each projection as a function of time,
# then order the dimensions by the time at which that variance peaks

# reshape latents to be size T x C x K
# (column-major reshape undoes the condition-by-condition concatenation; the number of
#  conditions comes from the data instead of being hard-coded)
rs_sca_latent = np.reshape(ssa_latent,(-1,numConds,R_est),order = 'F')

# calculate across condition variance (T x K)
sca_var = np.var(rs_sca_latent,axis = 1)

# find peak occupancy time of each dimension
pkIdx = np.argmax(sca_var,axis = 0)

# define plotting order (earliest peak first)
sca_order = np.argsort(pkIdx)

# resort ssa_latents by time of maximum occupancy
rs_sca_latent = rs_sca_latent[:,:,sca_order]

# do the same for the pca projections
rs_pca_latent = np.reshape(pca_latent,(-1,numConds,R_est),order = 'F')

pca_var = np.var(rs_pca_latent,axis = 1)
pkIdx = np.argmax(pca_var,axis = 0)
pca_order = np.argsort(pkIdx)
rs_pca_latent = rs_pca_latent[:,:,pca_order]

# define some useful time points (in downsampled samples from trial start)
tgt_idx=20
move_idx=77
ret_idx=200


In [None]:
# # alternatively, order latents by the time that occupancy crosses some threshold
#
# # reshape both latents to be size T x C x K
# rs_sca_latent = np.reshape(ssa_latent,(-1,8,R_est),order = 'F')
#
# # calculate across condition variance
# sca_var = np.var(rs_sca_latent,axis = 1)
#
# # for each occupancy trace, calculate when it crosses some (low) threshold
# threshTime_sca = np.zeros(R_est)
# threshold = 0.01
# for ii in range(R_est):
#     threshTime_sca[ii] = int(np.argwhere(sca_var[:,ii] > threshold)[0])
#
# # sort this list
# sca_order = np.argsort(threshTime_sca)
#
# # re-arrange latents
# rs_sca_latent = rs_sca_latent[:,:,sca_order]
#
# # now do the same for pca
# rs_pca_latent = np.reshape(pca_latent,(-1,8,R_est),order = 'F')
#
# pca_var = np.var(rs_pca_latent,axis = 1)
#
# threshTime_pca = np.zeros(R_est)
# threshold = 0.01
# for ii in range(R_est):
#     threshTime_pca[ii] = int(np.argwhere(pca_var[:,ii] > threshold)[0])
#
# # sort this list
# pca_order = np.argsort(threshTime_pca)
# rs_pca_latent = rs_pca_latent[:,:,pca_order]

In [None]:
# Define the directory for our new figures
# (prefix for the figure files written by later cells; a later cell reassigns figDir to the parent folder)
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/defaultSCAParameters/'

#### SCA

In [None]:


# Plot every SCA dimension in its own row, one trace per condition

# range for y axis
yRange = [-1.8,1.8]

fig = make_subplots(rows=R_est,cols = 1,shared_xaxes = True,vertical_spacing = 0)

for ii in range(R_est):

    # one trace per condition, colored by condition
    for jj in range(numConds):
        latTrace = go.Scatter(y = rs_sca_latent[:,jj,ii], line = go.scatter.Line(color = sca_cMap[jj],width = 2.5),showlegend = False)
        fig.add_trace(latTrace,row = ii+1,col=1)

    # add a vertical line for scale
    scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
                            line = go.scatter.Line(color = 'black',width = 5))
    fig.add_trace(scaleLine,row = ii+1,col = 1)

# mark target onset, movement onset and return on the bottom row only.
# These are added once, after the traces exist (previously they sat inside the
# condition loop and were drawn once per condition, stacking identical shapes).
for evtIdx in (tgt_idx, move_idx, ret_idx):
    fig.add_vline(x = evtIdx,row = R_est,col = 1, line_color = 'black')

fig.update_layout(height = 2000,width =600,title = 'SCA ' + monkName,title_font_color = 'black',
                  paper_bgcolor = 'white',
                  plot_bgcolor = 'white')
fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = yRange)
# hide every x axis, then re-enable only the bottom one with two labelled ticks
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                 ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = 1)


# save
#fig.write_image(figDir + monkName + 'SCA_' + str(R_est) + 'dims.pdf')
#fig.show()


### Calculate the fractional occupancy for each epoch for each dimension
    - for reference:
    tgt_idx=20
    move_idx=77
    ret_idx=200

In [None]:
# sample indices for the whole trial, where 0 is the start of the trial (not target onset)
# NOTE: this rebinds the name `time`, hiding the `time` module imported at the top of the file
time = list(range(rs_sca_latent.shape[0]))

# preparation: target onset -> movement onset (outward reach), plus the 30 samples before the return
prepTime = [*range(tgt_idx,move_idx), *range(ret_idx - 30, ret_idx)]

# execution: the first 30 samples of the outward reach and of the return reach
moveTime = [*range(move_idx,move_idx + 30), *range(ret_idx, ret_idx + 30 )]

# posture: everything between the end of outward execution and the start of return preparation
postTime = list(range(move_idx + 30,ret_idx - 30))

# sanity check: plot an indicator trace for each epoch
fig = go.Figure(go.Scatter(y = np.isin(time,prepTime), name = 'prep'))
for epochSamples, epochName in ((moveTime, 'move'), (postTime, 'post')):
    fig.add_trace(go.Scatter(y = np.isin(time,epochSamples), name = epochName))

# label the task events
for evtIdx, evtLabel in ((tgt_idx, 'target on'), (move_idx, 'move out'), (ret_idx, 'return move')):
    fig.add_vline(x = evtIdx, line_color = 'white',annotation_text = evtLabel,annotation_position = 'top')

fig.update_layout(height = 400,width = 700)

In [None]:
# fractional occupancy of each dimension in each epoch, for SCA and PCA
# (both calls share the same epoch definitions)
epochArgs = dict(prepTimes = prepTime, moveTimes = moveTime, postTimes = postTime, projTimes = time)
fractOcc_sca = calculateEpochOccupancy(rs_sca_latent, **epochArgs)
fractOcc_pca = calculateEpochOccupancy(rs_pca_latent, **epochArgs)

In [None]:
# dispersion of each dimension (per the original note: sum of absolute differences
# between its fractional occupancies), for SCA and PCA
disp_sca, disp_pca = (calculateOccDispersion(occ) for occ in (fractOcc_sca, fractOcc_pca))

In [None]:
# side-by-side heatmaps of fractional occupancy (SCA left, PCA right) on a shared color scale
zMin = 0
zMax = 0.7

# initialize a figure
fig = make_subplots(rows=1,cols = 2,shared_xaxes = False,shared_yaxes = False,horizontal_spacing = 0.1, subplot_titles = ['SCA','PCA'])

# one heatmap per method: column 1 is SCA, column 2 is PCA
for colIdx, occ in enumerate((fractOcc_sca, fractOcc_pca), start = 1):
    htPlt = go.Heatmap(z = occ,zmin = zMin, zmax = zMax, colorscale = 'Brwnyl')
    fig.add_trace(htPlt, row=1, col=colIdx)

fig.update_yaxes(autorange = 'reversed',title = 'dimension')
fig.update_xaxes(title = 'epoch',tickmode = 'array',
                 tickvals = [0,1,2],
                 ticktext = ['prep.','exec.','post.'])

fig.update_layout(width = 800,height = 500,title = 'Epoch Sparsity ' + monkName)

# save figure
#fig.write_image(figDir + monkName + '_SCAPCA_' + str(R_est) + 'dims_epochSparsityMap.pdf')
#fig.show()

In [None]:
# display the max value of each row of the SCA fractional-occupancy matrix
# (rows are plotted as dimensions and columns as epochs in the heatmap above)
print(np.max(fractOcc_sca,axis = 1))

### generate PCA and SCA latents from bootstrapped neuron populations
    - using the same population for each

In [None]:
# build one argument list per bootstrap for the function we're going to parallelize
# (each entry holds R_est, fit_data, and trainMask; the entries reference the same
#  objects rather than separate copies)

# number of bootstraps
numBoots = 500

inputList = [[R_est,fit_data,trainMask] for _boot in range(numBoots)]


In [None]:
# run bootstrapping function
from multiprocessing.pool import Pool

# Set up a parallel pool that uses all available workers. The context manager
# releases the worker processes even if a bootstrap task raises; starmap only
# returns once every task has finished, so no results are lost on exit.
# (A later pool.close() call on the already-terminated pool is a harmless no-op.)
with Pool() as pool:

    # bootstrap: one call of bootstrapNeurons_SCA_PCA per entry of inputList
    output = pool.starmap(bootstrapNeurons_SCA_PCA,inputList)

In [None]:
# close the pool (no further tasks can be submitted to it)
# NOTE: close() is not followed by join(), so this does not wait for the workers to exit
pool.close()

In [None]:
# collect the per-bootstrap outputs into two arrays of size (rows of fit_data) x R_est x numReps, then save
numReps = len(output)
latentShape = (fit_data.shape[0],R_est,numReps)
sca_latents_all = np.zeros(latentShape)
pca_latents_all = np.zeros(latentShape)

# element 0 of each output goes to the SCA array, element 1 to the PCA array
for ii, bootOut in enumerate(output):
    sca_latents_all[:,:,ii] = bootOut[0]
    pca_latents_all[:,:,ii] = bootOut[1]


# save everything
# (np.save stores the dict as a pickled object array, so it has to be read back with allow_pickle=True)
saveDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/data/reaching/sca_pca_defaultParams_bootstrap/'
saveName = monkName + '_sca_pca_bootstrappedLatents_' + str(R_est) + 'Dims_v2.npy'
np.save(saveDir + saveName, {'sca_latents': sca_latents_all,'pca_latents':pca_latents_all})

In [None]:
# load previously saved bootstrap latents if we need to
# NOTE(review): this reads '..._8Dims.npy', whereas the save cell writes
# '..._<R_est>Dims_v2.npy' — confirm which file is meant to be analysed
# NOTE: `data` is rebound here (it previously held the loaded .mat contents)
saveDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/data/reaching/sca_pca_defaultParams_bootstrap/'
loadName = monkName + '_sca_pca_bootstrappedLatents_8Dims.npy'

# the file holds a pickled dict wrapped in a 0-d array; .item() unwraps it
data = np.load(saveDir + loadName,allow_pickle=True).item()
sca_latents, pca_latents = data['sca_latents'], data['pca_latents']


In [None]:
# for each set of sca (and pca) latents, calculate the epoch sparsity matrix,
# then the mean dispersion across dimensions (one summary value per bootstrap)
# NOTE: fractOcc_* and disp_* are overwritten on every iteration, so after this
# cell they hold the values from the last bootstrap, not from the full-population fit

# initialize two vectors to hold results
numBoots = sca_latents.shape[2]
sparseIdx_sca = np.zeros(numBoots)
sparseIdx_pca = np.zeros(numBoots)

# cycle through bootstraps
for ii in range(numBoots):

    # reshape pca and sca projections to be T x C x K
    # (the number of conditions comes from the data instead of being hard-coded)
    sca_rs = sca_latents[:,:,ii].reshape(-1,numConds,R_est,order = 'F')
    pca_rs = pca_latents[:,:,ii].reshape(-1,numConds,R_est,order = 'F')


    # calculate fractional occupancy
    fractOcc_sca = calculateEpochOccupancy(sca_rs,prepTimes = prepTime, moveTimes = moveTime, postTimes = postTime, projTimes = time)
    fractOcc_pca = calculateEpochOccupancy(pca_rs,prepTimes = prepTime, moveTimes = moveTime, postTimes = postTime, projTimes = time)

    # calculate dispersion for each dimension (sum of absolute difference between each fractional occupancy row)
    disp_sca = calculateOccDispersion(fractOcc_sca)
    disp_pca = calculateOccDispersion(fractOcc_pca)

    # save mean across dimensions (assigning into the array already stores a copy)
    sparseIdx_sca[ii] = np.mean(disp_sca)
    sparseIdx_pca[ii] = np.mean(disp_pca)




In [None]:
# Summary figure: mean +/- std (across bootstraps) of the epoch-sparsity index, SCA vs PCA

# marker colors: SCA first, PCA second
eBarColors = ['#5e0044','#a79ba4']

# combine data into lists to make life easy (same order as the colors)
sparseIdx = [np.copy(sparseIdx_sca), np.copy(sparseIdx_pca)]

# define some plotting parameters
eBarThickness = 5
eBarWidth = 10
markerSize = 15
markerLineWidth = 2

# build one marker-with-error-bar trace per method; x position 0 is SCA, 1 is PCA
summaryTraces = []
for xPos, (bootVals, markerColor) in enumerate(zip(sparseIdx, eBarColors)):
    errorBar = dict(
        type = 'data',
        array = [np.std(bootVals)],visible = True,thickness = eBarThickness,width = eBarWidth)
    markerStyle = dict(
        color = markerColor,
        size = markerSize,
        line = dict(
            color = 'black',
            width = markerLineWidth
        )
    )
    summaryTraces.append(go.Scatter(
        x = [xPos],y = [np.mean(bootVals)],
        error_y = errorBar,
        marker = markerStyle,showlegend = False
    ))

# sca
fig = go.Figure(data = summaryTraces[0])

# pca
pcaTrace = summaryTraces[1]
fig.add_trace(pcaTrace)

# clean up figure
fig.update_layout(height =500,width =350,title = 'SCA vs. PCA epoch sparsity ' + monkName,title_font_color = 'black',
                  paper_bgcolor = 'white',
                  plot_bgcolor = 'white')
fig.update_yaxes(showgrid = False,zeroline = False,visible = False)
fig.update_yaxes(color = 'black',ticks = 'outside',visible = True,showline = True,linewidth = 1.5,tickwidth = 1.5)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                 ticks = 'outside',tickvals = [0,1],ticktext = ['SCA','PCA'],visible = True,tickwidth = 1.5)

# save figure
# fig.write_image(figDir + monkName + '_SCA_vs_PCA_' + str(R_est) + 'dims_bootstrap_epochSparsitySummary.pdf')
# fig.show()

In [None]:
# bootstrap p-value: the fraction of bootstraps in which sparseIdx_sca < sparseIdx_pca
# NOTE: despite its name, `scaBigger` flags the bootstraps where the SCA value is the SMALLER one
scaBigger = [int(scaVal < pcaVal) for scaVal, pcaVal in zip(*sparseIdx)]
pVal = sum(scaBigger) / len(scaBigger)
print(f'bootstrap pVal: {pVal!s} ({numBoots!s} bootstraps)')

### Calculate the (pairwise) alignment indices between prep_out, move_out, posture, prep_return, and move_return
    for reference:
    tgt_idx=20
    move_idx=77
    ret_idx=200

In [None]:
# define the sample indices that make up each epoch

# prep (outward and return)
prepOut = list(np.arange(tgt_idx,move_idx - 10))
prepRtn = list(np.arange(ret_idx - 30,ret_idx - 10))

# move (outward and return)
moveOut = list(np.arange(move_idx,move_idx + 20))
moveRtn = list(np.arange(ret_idx,ret_idx + 20))

# posture
post = list(np.arange(move_idx + 30,ret_idx - 50))

# place everything in a list (this order defines the rows/columns of allAI)
epochTimes = [prepOut, moveOut, post, prepRtn, moveRtn]
numEpochs = len(epochTimes)

# number of dimensions to use for the alignment index calculation
numDims = 10

# 'numEpochs' x 'numEpochs' matrix to hold results
allAI = np.zeros((numEpochs,numEpochs))

# cycle through all ordered pairs of epochs
for ii, timesA in enumerate(epochTimes):
    for jj, timesB in enumerate(epochTimes):

        # rows of fit_data that fall in epoch ii / epoch jj
        X1 = fit_data[np.isin(timeMask,timesA),:]
        X2 = fit_data[np.isin(timeMask,timesB),:]

        # alignment index between the two epochs
        allAI[ii,jj] = calculateAI(X1,X2,numDims)

# calculate chance AI
chanceAI = calculateChanceAI(fit_data,numDims)

In [None]:
# calculate median chance AI
medChanceAI = np.median(chanceAI)
print(f'chance AI: {medChanceAI!s}')

# plot the pairwise alignment indices as a heatmap
fig = go.Figure(go.Heatmap(z = allAI,zmin = 0, zmax = 1,colorscale = 'Electric'))

# label both axes with the epoch names (same order as epochTimes)
epochTickVals = [0,1,2,3,4]
epochTickText = ['prep out','move out','posture','prep rtn','move rtn']
yAxisStyle = dict(
    tickmode = 'array',
    tickvals = epochTickVals,
    ticktext = epochTickText)
xAxisStyle = dict(
    tickmode = 'array',
    tickvals = epochTickVals,
    ticktext = epochTickText,
    tickcolor = 'black')
fig.update_layout(yaxis = yAxisStyle,
                xaxis = xAxisStyle,
                width = 500,height = 500,title = monkName + ' alignment indices')


fig.show()

# save figure (as html).
fig.write_html(figDir + monkName + '_' + str(R_est) + 'dims_aligmentIndex.html')


In [None]:
# Heatmaps of epoch dispersion over a grid of orthogonality / sparsity lambdas,
# one subplot per run with a different number of dimensions.
# NOTE(review): numDimCount, numDimsInRun, disp_all_rs, numSparse, numOrthog,
# orthogLambdas and sparsityLambdas are not defined anywhere in the visible part of
# this notebook, so this cell raises a NameError unless they come from elsewhere —
# it looks like a leftover from a lambda-sweep analysis; TODO confirm or remove

# initialize a figure
fig = make_subplots(rows=1,cols = numDimCount,shared_xaxes = True,shared_yaxes = True,horizontal_spacing = 0.01, subplot_titles = [str(x) + ' dims' for x in numDimsInRun] )

# cycle through runs with different number of dimensions
for dd in range(numDimCount):

    # reshape this run's column into a numSparse x numOrthog grid (column-major)
    tempDisp = disp_all_rs[:,dd]
    tempDisp = tempDisp.reshape(numSparse,numOrthog,order = 'F')

    # plot
    tempImage = go.Heatmap(z = tempDisp, zmin = 0,zmax = 2,x = orthogLambdas,y = sparsityLambdas,colorscale = 'Magma')
    fig.add_trace(tempImage, row = 1, col=dd+1)



# fix some odd bug with the y axis tick labels
fig.update_layout(yaxis = dict(
    tickmode = 'array',
    tickvals = [1e-4, 1e-2, 1, 100]))

# both lambda axes are shown on a log scale
fig.update_xaxes(type = 'log')
fig.update_yaxes(type = 'log')
fig.update_layout(height = 600,width = 1000,xaxis_title = 'orthog lambda',yaxis_title = 'sparse lambdas',coloraxis_colorbar = dict(title = 'mean dispersion'),title = 'Epoch Dispersion')

fig.write_image(figDir + monkName + 'orth_sparse_lambdaSweeps_epochDispersion.pdf')
fig.show()

### Regress sca latents against gamal latents

In [None]:
# directory that holds the gamal loadings/projections
load_folder='/Users/andrew/Documents/Projects/Churchland/Sparsity/data/reaching/'

# load data
# NOTE: `data` is rebound here to the contents of the gamal file
gamalFile = load_folder + monkName + '_gamalLoadings.mat'
data=io.loadmat(gamalFile)

# pull out prep, move, and posture projections, adding a trailing axis to each
# ((N x k) per the original note — TODO confirm: the regression below uses them as
#  predictors of the latents, so their first axis has to match the latents' first axis)
gPrepProj=data['prepProj'][..., np.newaxis]
gMoveProj=data['moveProj'][..., np.newaxis]
gPostProj=data['postProj'][..., np.newaxis]

# stack all projections along the new third axis (index 0: prep, 1: move, 2: posture)
gProj = np.concatenate([gPrepProj,gMoveProj,gPostProj],axis = 2)

In [None]:
def _epochRegressionR2(latents, epochProj):
    """Regress every latent trace onto each epoch's projections and return the R2.

    For each epoch e, the predictors are epochProj[:, :, e] (no intercept term),
    and every column latents[:, k, b] is fit by ordinary least squares.

    Parameters
    ----------
    latents : ndarray, shape (samples, dims, bootstraps)
    epochProj : ndarray, shape (samples, k, epochs)

    Returns
    -------
    ndarray, shape (bootstraps, dims, epochs)
        R2 of reconstructing each latent from each epoch's projections.
    """
    numSamples, numK, numB = latents.shape
    numE = epochProj.shape[2]

    # treat every (dimension, bootstrap) pair as one column so that a single
    # least-squares solve per epoch handles all of them at once
    Y = latents.reshape(numSamples, numK * numB)
    ss_tot = np.sum((Y - np.mean(Y, axis = 0))**2, axis = 0)

    r2 = np.zeros((numB, numK, numE))
    for e in range(numE):

        # predictors for this epoch
        X = epochProj[:, :, e]

        # regression weights (least squares; avoids forming and inverting X.T @ X)
        B = np.linalg.lstsq(X, Y, rcond = None)[0]

        # residual sum of squares of the reconstruction
        ss_res = np.sum((Y - X @ B)**2, axis = 0)

        # columns are ordered dimension-major, so reshape to (dims, bootstraps) and transpose
        r2[:, :, e] = (1 - (ss_res/ss_tot)).reshape(numK, numB).T

    return r2


# number of bootstraps
numBoots = pca_latents.shape[2]

# number of SCA dimensions
R_est = sca_latents.shape[1]

# R2 of each bootstrapped pca latent (numBoots x dims x 3);
# the last axis corresponds to the reconstruction from the prep, move, or posture projections
allPCA_R2 = _epochRegressionR2(pca_latents, gProj)

# take max R2 across each epoch
maxR2_pca = np.max(allPCA_R2,axis = 2)

# redo regression analysis for ssa projections
allSCA_R2 = _epochRegressionR2(sca_latents, gProj)

# take max R2 across each epoch
maxR2_sca = np.max(allSCA_R2,axis = 2)


In [None]:
# fraction of entries where x1 exceeds x2
# NOTE(review): x1 and x2 are only assigned in the following cell (flattened maxR2_pca
# and maxR2_sca), so this cell raises a NameError when run in order — confirm cell order
y = x1 > x2
print(np.sum(y)/x1.shape[0])

In [None]:
# flatten the max-R2 values (column-major) for pca (x1) and sca (x2)
x1 = maxR2_pca.reshape(-1,1,order='F')
x2 = maxR2_sca.reshape(-1,1,order='F')

# mean and standard deviation over all bootstraps and dimensions
m1, s1 = np.mean(x1), np.std(x1)
m2, s2 = np.mean(x2), np.std(x2)

# one error bar per method: pca at x = 0 (blue), sca at x = 1 (red)
for xPos, avg, spread, barColor in ((0, m1, s1, 'b'), (1, m2, s2, 'r')):
    plt.errorbar(xPos,avg,spread,color=barColor);
plt.ylim([0,1]);
plt.xlim([-0.5,1.5]);

# # save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_ssaVsPCA_reconError_test.pdf',dpi = 'figure')

In [None]:
# change color of axis labels so we can see them in the pdf
plt.rcParams.update({
    'text.color': 'k',
    'xtick.color': 'k',
    'ytick.color': 'k',
    'axes.labelcolor': 'k',
})

# ---- PCA (plotted around x = 1) ----

# flatten pca max-R2 values into a column
pcaR2_recon = np.reshape(maxR2_pca,[-1,1],order = 'F')

# x positions with a small amount of jitter around 1
pcaLoc = np.random.normal(loc = 1,scale = 0.03,size = (numBoots*R_est,1))

# individual values (nearly transparent so density is visible)
plt.plot(pcaLoc,pcaR2_recon,'o',color = (0.4,0.4,0.4),ms = 4,alpha = 0.01);

# mean and std of pca performance, drawn on top
pcaMean = np.mean(pcaR2_recon)
plt.errorbar(1,pcaMean,np.std(pcaR2_recon),color = 'k',lw = 3,zorder = 3)
plt.plot(1,pcaMean,'o',color = 'k',ms = 8,zorder = 3);

# ---- SCA (plotted around x = 2) ----

# flatten sca max-R2 values into a column
scaR2_recon = np.reshape(maxR2_sca,[-1,1],order = 'F')

# jittered x positions for the sca values
ssaLoc = np.random.normal(loc = 2,scale = 0.01,size = (numBoots*R_est,1))
plt.plot(ssaLoc,scaR2_recon,'o',color = 'purple',alpha = 0.01);

# mean and std of sca performance
scaMean = np.mean(scaR2_recon)
plt.errorbar(2,scaMean,np.std(scaR2_recon) ,color = 'black',lw = 3)
plt.plot(2,scaMean,'o',color = 'black',ms = 8);

# ---- clean up ----
plt.xlim((0.5, 2.5))
plt.ylim((0,1));
plt.xticks(np.array([1,2]),('PCA','SSA'));
plt.yticks(np.array([0,0.5, 1]));
plt.title(monkName);
plt.ylabel('R2');


# # save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
#plt.savefig(figDir + monkName + '_ssaVsPCA_reconError.pdf',dpi = 'figure')


In [None]:
# display the mean of the flattened pca max-R2 values (shown as the cell's output)
np.mean(pcaR2_recon)

### Cycle through number of dimensions and calculate R2 of reconstructed neurons

In [None]:
# Sweep the number of dimensions and record the training-window reconstruction R2
# of SCA and weighted PCA for each setting.
# The per-setting fits are kept in their own names so that this sweep does not
# overwrite `model`, `ssa_latent`, `x_pred` or `pca_latent` from the R_est-dimensional
# fits that the earlier analysis and plotting cells rely on.

# number of dimensions to test
numDims = np.arange(5,55,5)

# initialize vectors to hold results (NaN until filled)
R2_sca = np.full(numDims.shape[0], np.nan)
R2_pca = np.full(numDims.shape[0], np.nan)

# training-window data and weights are the same for every setting
trainData = fit_data[trainMask,:]
trainWeights = sample_weights[trainMask]

# flattened (column-major) data and its total sum of squares, shared by both methods
X = trainData.reshape(-1,1,order='F')
SS_tot = np.sum( (X - np.mean(X))**2)

# cycle through number of dimensions
for ii,jj in enumerate(numDims):

    # fit SCA; only the model prediction (third return value) is needed here
    sweepFit = fit_sca(X=trainData,sample_weight = trainWeights,
                       R=jj,orth = hardOrthFlag)
    X_hat = sweepFit[2].detach().numpy().reshape(-1,1,order = 'F')

    # SCA reconstruction R2
    SS_res = np.sum( (X - X_hat)**2)
    R2_sca[ii] = 1 - (SS_res/SS_tot)

    # fit PCA and reconstruct the training-window data
    sweepU,sweepV = weighted_pca(trainData,jj,trainWeights)
    X_hat = (trainData @ sweepU @ sweepV).reshape(-1,1,order = 'F')

    # PCA reconstruction R2
    SS_res = np.sum( (X - X_hat)**2)
    R2_pca[ii] = 1 - (SS_res/SS_tot)



In [None]:
# plot reconstruction R2 as a function of the number of dimensions

# change color of axis labels so we can see them in the pdf
plt.rcParams.update({
    'text.color': 'k',
    'xtick.color': 'k',
    'ytick.color': 'k',
    'axes.labelcolor': 'k',
})

# sca results in black
plt.plot(numDims, R2_sca,'k',label='sca')

# pca results in gray
plt.plot(numDims,R2_pca,color=[0.5,0.5,0.5],label='pca')

# clean up
plt.xlim((0, 50))
plt.ylim((0,1));
plt.xticks(np.array([0,25,50]));
plt.yticks(np.array([0,0.5, 1]));
plt.title(monkName);
plt.ylabel('R2');


# # save directory
figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/centerOutReaching/'

# save
plt.savefig(figDir + monkName + '_scaVsPCA_reconError.pdf',dpi = 'figure')